/*
 * Decompiled with CFR 0.152.
 */
package org.apache.uniffle.server.block;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.apache.uniffle.common.exception.InvalidRequestException;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.common.util.RssUtils;
import org.apache.uniffle.server.ShuffleTaskInfo;
import org.apache.uniffle.server.block.ShuffleBlockIdManager;
import org.apache.uniffle.shaded.guava.annotations.VisibleForTesting;
import org.apache.uniffle.shaded.guava.collect.Maps;
import org.apache.uniffle.shaded.guava.collect.Sets;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * In-memory {@link ShuffleBlockIdManager} that records finished block ids per application and
 * shuffle.
 *
 * <p>For every (appId, shuffleId) pair it keeps a fixed-size array of {@link Roaring64NavigableMap}
 * bitmaps; the block ids of partition {@code p} are stored in bitmap {@code p % bitmapNum}.
 *
 * <p>Locking: {@link Roaring64NavigableMap} is not thread-safe, so every read or write of an
 * individual bitmap in this class happens while holding that bitmap's own monitor. The outer maps
 * are created through {@link JavaUtils#newConcurrentMap()}.
 */
public class DefaultShuffleBlockIdManager
implements ShuffleBlockIdManager {
    private static final Logger LOG = LoggerFactory.getLogger(DefaultShuffleBlockIdManager.class);

    /** appId -> (shuffleId -> bitmaps indexed by {@code partitionId % bitmapNum}). */
    private final Map<String, Map<Integer, Roaring64NavigableMap[]>> partitionsToBlockIds = JavaUtils.newConcurrentMap();

    /**
     * Adds to {@code resultBitmap} every block id in {@code bitmap} whose partition (as decoded by
     * {@code blockIdLayout}) is in {@code requestPartitions}.
     *
     * <p>This method does not lock {@code bitmap}; callers that share it with writers must hold its
     * monitor while calling.
     *
     * @return {@code resultBitmap}, for call chaining
     */
    @VisibleForTesting
    public static Roaring64NavigableMap getBlockIdsByPartitionId(Set<Integer> requestPartitions, Roaring64NavigableMap bitmap, Roaring64NavigableMap resultBitmap, BlockIdLayout blockIdLayout) {
        bitmap.forEach(blockId -> {
            if (requestPartitions.contains(blockIdLayout.getPartitionId(blockId))) {
                resultBitmap.addLong(blockId);
            }
        });
        return resultBitmap;
    }

    /** Registers {@code appId} with an empty shuffle map if it is not already known. */
    @Override
    public void registerAppId(String appId) {
        this.partitionsToBlockIds.computeIfAbsent(appId, key -> JavaUtils.newConcurrentMap());
    }

    /**
     * Records finished block ids for one shuffle of an application.
     *
     * @param taskInfo receives the per-partition count of newly recorded block ids
     * @param appId application id; must have been registered via {@link #registerAppId(String)}
     * @param shuffleId shuffle id
     * @param partitionToBlockIds partition id -> block ids reported for that partition
     * @param bitmapNum number of bitmaps the shuffle is (or will be) split into; must be positive
     *     and must match the value used when the shuffle was first seen
     * @return the number of block ids that were not already recorded
     * @throws RssException if {@code appId} is not registered (e.g. expired)
     * @throws InvalidRequestException if {@code bitmapNum} is not positive or differs from the
     *     shuffle's existing bitmap count
     */
    @Override
    public int addFinishedBlockIds(ShuffleTaskInfo taskInfo, String appId, Integer shuffleId, Map<Integer, long[]> partitionToBlockIds, int bitmapNum) {
        Map<Integer, Roaring64NavigableMap[]> shuffleIdToPartitions = this.partitionsToBlockIds.get(appId);
        if (shuffleIdToPartitions == null) {
            throw new RssException("appId[" + appId + "] is expired!");
        }
        // Reject before computeIfAbsent: a zero-length array would be stored permanently for this
        // shuffle (and later cause division by zero), a negative size would fail array creation.
        if (bitmapNum <= 0) {
            throw new InvalidRequestException("Request expects " + bitmapNum + " bitmaps, but bitmapNum must be positive!");
        }
        // Use computeIfAbsent's return value instead of a second get(): a concurrent
        // removeBlockIdByShuffleId could drop the entry in between and leave us with null.
        Roaring64NavigableMap[] blockIds = shuffleIdToPartitions.computeIfAbsent(shuffleId, key -> newEmptyBitmaps(bitmapNum));
        if (blockIds.length != bitmapNum) {
            throw new InvalidRequestException("Request expects " + bitmapNum + " bitmaps, but there are " + blockIds.length + " bitmaps!");
        }
        int totalUpdatedBlockCount = 0;
        for (Map.Entry<Integer, long[]> entry : partitionToBlockIds.entrySet()) {
            Integer partitionId = entry.getKey();
            Roaring64NavigableMap bitmap = blockIds[partitionId % bitmapNum];
            int updatedBlockCount = 0;
            synchronized (bitmap) {
                for (long blockId : entry.getValue()) {
                    // Only ids not yet present are counted, so duplicates (e.g. from a repeated
                    // report) do not inflate the block number kept in taskInfo.
                    if (!bitmap.contains(blockId)) {
                        bitmap.addLong(blockId);
                        updatedBlockCount++;
                    }
                }
            }
            totalUpdatedBlockCount += updatedBlockCount;
            taskInfo.incBlockNumber(shuffleId, partitionId, updatedBlockCount);
        }
        return totalUpdatedBlockCount;
    }

    /**
     * Returns the serialized bitmap of finished block ids belonging to {@code partitions}.
     *
     * @return {@code null} if {@code appId} is unknown, an empty array if the shuffle is unknown,
     *     otherwise the serialized bitmap
     * @throws RssException if the number of collected ids differs from the block number tracked in
     *     {@code taskInfo} for the requested partitions
     * @throws IOException if serialization fails
     */
    @Override
    public byte[] getFinishedBlockIds(ShuffleTaskInfo taskInfo, String appId, Integer shuffleId, Set<Integer> partitions, BlockIdLayout blockIdLayout) throws IOException {
        Map<Integer, Roaring64NavigableMap[]> shuffleIdToPartitions = this.partitionsToBlockIds.get(appId);
        if (shuffleIdToPartitions == null) {
            LOG.warn("Empty blockIds for app: {}. This should not happen", appId);
            return null;
        }
        Roaring64NavigableMap[] blockIds = shuffleIdToPartitions.get(shuffleId);
        if (blockIds == null) {
            LOG.warn("Empty blockIds for app: {}, shuffleId: {}", appId, shuffleId);
            return new byte[0];
        }
        // Group the requested partitions by the bitmap that holds them so each bitmap is scanned once.
        long expectedBlockNumber = 0L;
        HashMap<Integer, HashSet<Integer>> bitmapIndexToPartitions = Maps.newHashMap();
        for (int partitionId : partitions) {
            bitmapIndexToPartitions
                    .computeIfAbsent(partitionId % blockIds.length, key -> Sets.newHashSet())
                    .add(partitionId);
            expectedBlockNumber += taskInfo.getBlockNumber(shuffleId, partitionId);
        }
        Roaring64NavigableMap res = Roaring64NavigableMap.bitmapOf();
        for (Map.Entry<Integer, HashSet<Integer>> entry : bitmapIndexToPartitions.entrySet()) {
            Roaring64NavigableMap bitmap = blockIds[entry.getKey()];
            // Hold the same monitor addFinishedBlockIds uses while mutating this bitmap.
            synchronized (bitmap) {
                getBlockIdsByPartitionId(entry.getValue(), bitmap, res, blockIdLayout);
            }
        }
        long actualBlockNumber = res.getLongCardinality();
        if (actualBlockNumber != expectedBlockNumber) {
            throw new RssException("Inconsistent block number for partitions: " + partitions + ". Expected: " + expectedBlockNumber + ", actual: " + actualBlockNumber);
        }
        return RssUtils.serializeBitMap(res);
    }

    /** Drops the block ids of the given shuffles of {@code appId}; unknown ids are ignored. */
    @Override
    public void removeBlockIdByShuffleId(String appId, List<Integer> shuffleIds) {
        Optional.ofNullable(this.partitionsToBlockIds.get(appId)).ifPresent(shuffleIdToBitmaps -> {
            for (Integer shuffleId : shuffleIds) {
                shuffleIdToBitmaps.remove(shuffleId);
            }
        });
    }

    /** Drops all block ids recorded for {@code appId}. */
    @Override
    public void removeBlockIdByAppId(String appId) {
        this.partitionsToBlockIds.remove(appId);
    }

    /** Returns the number of block ids recorded across all applications and shuffles. */
    @Override
    public long getTotalBlockCount() {
        return this.partitionsToBlockIds.values().stream()
                .flatMap(shuffleIdToBitmaps -> shuffleIdToBitmaps.values().stream())
                .flatMap(bitmaps -> Arrays.stream(bitmaps))
                .mapToLong(DefaultShuffleBlockIdManager::cardinalityOf)
                .sum();
    }

    /**
     * Returns the number of block ids recorded for the given shuffles of {@code appId}.
     *
     * <p>Unknown apps or shuffles contribute 0; duplicate shuffle ids are counted once.
     */
    @Override
    public long getBlockCountByShuffleId(String appId, List<Integer> shuffleIds) {
        Map<Integer, Roaring64NavigableMap[]> shuffleIdToBitmaps = this.partitionsToBlockIds.get(appId);
        if (shuffleIdToBitmaps == null) {
            return 0L;
        }
        long count = 0L;
        for (Integer shuffleId : new HashSet<>(shuffleIds)) {
            Roaring64NavigableMap[] bitmaps = shuffleIdToBitmaps.get(shuffleId);
            if (bitmaps == null) {
                continue;
            }
            for (Roaring64NavigableMap bitmap : bitmaps) {
                count += cardinalityOf(bitmap);
            }
        }
        return count;
    }

    /** Returns whether {@code appId} is currently registered. */
    @Override
    public boolean contains(String appId) {
        return this.partitionsToBlockIds.containsKey(appId);
    }

    /**
     * Returns the number of bitmaps used by the given shuffle.
     *
     * <p>NOTE(review): throws {@link NullPointerException} if the app or shuffle is unknown.
     */
    @Override
    public long getBitmapNum(String appId, int shuffleId) {
        return this.partitionsToBlockIds.get(appId).get(shuffleId).length;
    }

    /** Creates {@code count} empty bitmaps, one per bitmap slot of a shuffle. */
    private static Roaring64NavigableMap[] newEmptyBitmaps(int count) {
        Roaring64NavigableMap[] bitmaps = new Roaring64NavigableMap[count];
        for (int i = 0; i < count; i++) {
            bitmaps[i] = Roaring64NavigableMap.bitmapOf();
        }
        return bitmaps;
    }

    /**
     * Reads a bitmap's cardinality under its monitor. getLongCardinality may update internal
     * caches, so it must not race with writers or other readers.
     */
    private static long cardinalityOf(Roaring64NavigableMap bitmap) {
        synchronized (bitmap) {
            return bitmap.getLongCardinality();
        }
    }
}

