/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tez.runtime.library.common.shuffle.orderedgrouped;

import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.WritableComparator;
import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.tez.common.InputContextUtils;
import org.apache.tez.common.RssTezConfig;
import org.apache.tez.common.RssTezUtils;
import org.apache.tez.common.TezUtilsInternal;
import org.apache.tez.common.UmbilicalUtils;
import org.apache.tez.common.counters.TaskCounter;
import org.apache.tez.common.counters.TezCounter;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.dag.records.TezTaskID;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.InputContext;
import org.apache.tez.runtime.library.common.ConfigUtils;
import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils;
import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.ExceptionReporter;
import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.RMRssShuffleScheduler;
import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.ShuffleInputEventHandlerOrderedGrouped;
import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.ShuffleScheduler;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.record.reader.KeyValuesReader;
import org.apache.uniffle.client.record.reader.RMRecordsReader;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.shaded.com.google.common.annotations.VisibleForTesting;
import org.apache.uniffle.shaded.org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Ordered-grouped shuffle input that reads records merged on the Uniffle shuffle servers.
 *
 * <p>Lifecycle, as implied by the code in this class:
 * <ol>
 *   <li>Input events are passed to {@link #handleEvents(List)}. They are forwarded to a
 *       {@link ShuffleInputEventHandlerOrderedGrouped} wired to an {@link RMRssShuffleScheduler}.</li>
 *   <li>{@link #run()} fetches the partition-to-shuffle-server assignment.</li>
 *   <li>{@link #waitForEvents()} blocks until every input is accounted for (succeeded or skipped).
 *       It then asks the servers to sort-merge each partition ({@code startSortMerge}) and starts an
 *       {@link RMRecordsReader}.</li>
 *   <li>Records are consumed through {@link #getKeyValuesReader()}.</li>
 *   <li>{@link #shutdown()} closes the reader.</li>
 * </ol>
 *
 * <p>NOTE(review): {@code partitionIds} and the {@code partitionIdToSuccess*} maps are
 * package-private and are not written anywhere in this class. Presumably the scheduler fills them
 * as inputs complete; confirm. They are plain {@link HashMap}/{@link HashSet} instances that
 * {@link #waitForEvents()} reads, so their thread-safety depends on that writer.
 */
@InterfaceAudience.Private
@InterfaceStability.Unstable
public class RMRssShuffle
implements ExceptionReporter {
    private static final Logger LOG = LoggerFactory.getLogger(RMRssShuffle.class);
    /** Sleep between completion checks in {@link #waitForEvents()}, in milliseconds. */
    private static final long WAIT_FOR_EVENTS_SLEEP_MS = 100L;
    private final Configuration conf;
    private final RssConf rssConf;
    private final InputContext inputContext;
    private final int numInputs;
    private final int shuffleId;
    private final ApplicationAttemptId applicationAttemptId;
    // Application id sent to the shuffle servers; this class uses the attempt id's string form.
    private final String appId;
    private ShuffleInputEventHandlerOrderedGrouped eventHandler;
    private final TezTaskAttemptID tezTaskAttemptID;
    private final String srcNameTrimmed;
    // Read from "tez.rss.client.type" (default "GRPC_NETTY") and passed to RMRecordsReader.
    private final String clientType;
    // Partition -> assigned shuffle servers; null until run() has been called.
    private Map<Integer, List<ShuffleServerInfo>> partitionToServers;
    private final AtomicBoolean isShutDown = new AtomicBoolean(false);
    final TezCounter skippedInputCounter;
    final TezCounter inputRecordCounter;
    final Map<Integer, Set<InputAttemptIdentifier>> partitionIdToSuccessMapTaskAttempts = new HashMap<Integer, Set<InputAttemptIdentifier>>();
    final Map<Integer, Set<TezTaskID>> partitionIdToSuccessTezTasks = new HashMap<Integer, Set<TezTaskID>>();
    final Set<Integer> partitionIds = new HashSet<Integer>();
    // Created by waitForEvents() only when there is at least one partition to read.
    private RMRecordsReader reader = null;
    private RMRssShuffleScheduler scheduler;

    /**
     * Wires up the scheduler and event handler and marks the input as ready.
     *
     * @param inputContext Tez context of this input
     * @param conf job configuration; also converted to an {@link RssConf}
     * @param numInputs number of inputs that must succeed or be skipped before reading starts
     * @param shuffleId Uniffle shuffle id this input reads from
     * @param applicationAttemptId attempt whose string form is used as the Uniffle app id
     * @throws IOException if creating the scheduler or event handler fails
     */
    public RMRssShuffle(InputContext inputContext, Configuration conf, int numInputs, int shuffleId, ApplicationAttemptId applicationAttemptId) throws IOException {
        this.inputContext = inputContext;
        this.conf = conf;
        this.rssConf = RssTezConfig.toRssConf(conf);
        this.numInputs = numInputs;
        this.shuffleId = shuffleId;
        this.applicationAttemptId = applicationAttemptId;
        this.clientType = conf.get("tez.rss.client.type", "GRPC_NETTY");
        this.appId = this.applicationAttemptId.toString();
        this.srcNameTrimmed = TezUtilsInternal.cleanVertexName(inputContext.getSourceVertexName());
        LOG.info(this.srcNameTrimmed + ": Shuffle assigned with " + numInputs + " inputs.");
        this.skippedInputCounter = inputContext.getCounters().findCounter(TaskCounter.NUM_SKIPPED_INPUTS);
        this.inputRecordCounter = inputContext.getCounters().findCounter(TaskCounter.INPUT_RECORDS_PROCESSED);
        this.scheduler = new RMRssShuffleScheduler(this.inputContext, this.conf, numInputs, this, null, null, System.currentTimeMillis(), null, false, 0, this.srcNameTrimmed, this);
        this.eventHandler = new ShuffleInputEventHandlerOrderedGrouped(inputContext, (ShuffleScheduler) this.scheduler, ShuffleUtils.isTezShuffleHandler(conf));
        this.tezTaskAttemptID = InputContextUtils.getTezTaskAttemptID(this.inputContext);
        inputContext.inputIsReady();
    }

    /**
     * Forwards input events to the event handler. Once {@link #shutdown()} has run, events are
     * logged and ignored instead.
     *
     * @param events events to process
     * @throws IOException if the event handler fails
     */
    public void handleEvents(List<Event> events) throws IOException {
        if (!this.isShutDown.get()) {
            this.eventHandler.handleEvents(events);
        } else {
            LOG.info(this.srcNameTrimmed + ": Ignoring events since already shutdown. EventCount: " + events.size());
        }
    }

    /**
     * Requests the partition-to-shuffle-server assignment for this shuffle. It must be called before
     * {@link #waitForEvents()}, because {@link #reportUniqueBlockIds()} and
     * {@link #createRMRecordsReader(Set)} read that assignment.
     *
     * @throws IOException if the request fails
     */
    public void run() throws IOException {
        this.partitionToServers = UmbilicalUtils.requestShuffleServer(this.inputContext.getApplicationId(), this.conf, this.tezTaskAttemptID, this.shuffleId);
    }

    /** Closes the reader, if one was created. Only the first call has any effect. */
    public void shutdown() {
        if (!this.isShutDown.getAndSet(true)) {
            if (this.reader != null) {
                this.reader.close();
            }
            LOG.info("Shutting down Shuffle for source: " + this.srcNameTrimmed);
        }
    }

    /**
     * Polls until every input has succeeded or been skipped. It then reports the unique block ids
     * for each partition and, if any partitions exist, starts the records reader.
     *
     * @throws InterruptedException if interrupted while polling
     */
    public void waitForEvents() throws InterruptedException {
        // NOTE(review): this loop does not check isShutDown, so a shutdown() during the wait
        // does not end it; only completion or interruption does.
        while (!this.allInputTaskAttemptDone()) {
            Thread.sleep(WAIT_FOR_EVENTS_SLEEP_MS);
        }
        this.reportUniqueBlockIds();
        if (!this.partitionIds.isEmpty()) {
            this.reader = this.createRMRecordsReader(this.partitionIds);
            this.reader.start();
        }
    }

    /** Returns true when the succeeded tasks (summed over partitions) plus skipped inputs equal numInputs. */
    private boolean allInputTaskAttemptDone() {
        long succeeded = this.partitionIdToSuccessTezTasks.values().stream().mapToInt(Set::size).sum();
        return succeeded + this.skippedInputCounter.getValue() == (long) this.numInputs;
    }

    /**
     * Asks the shuffle servers to sort-merge each partition, using only the blocks written by the
     * attempts recorded as successful for that partition.
     *
     * <p>For each partition this method fetches the partition's block-id bitmap and keeps the blocks
     * whose task attempt id is in the successful-attempt set. It then calls {@code startSortMerge}
     * with that filtered bitmap. The write client created here is always closed.
     *
     * @throws RssException if a partition has no shuffle-server assignment (for example, when
     *     {@link #run()} was not called)
     */
    public void reportUniqueBlockIds() {
        ShuffleWriteClient writeClient = RssTezUtils.createShuffleClient(this.conf);
        try {
            int maxAttemptNo = RssTezUtils.getMaxAttemptNo(this.conf);
            for (int partitionId : this.partitionIds) {
                Set<ShuffleServerInfo> servers = this.getShuffleServers(partitionId);
                Roaring64NavigableMap blockIdBitmap = writeClient.getShuffleResult(null, servers, this.appId, this.shuffleId, partitionId);
                Roaring64NavigableMap taskIdBitmap = RssTezUtils.fetchAllRssTaskIds(this.partitionIdToSuccessMapTaskAttempts.get(partitionId), this.numInputs, this.applicationAttemptId.getAttemptId(), maxAttemptNo);
                Roaring64NavigableMap uniqueBlockIdBitMap = Roaring64NavigableMap.bitmapOf(new long[0]);
                blockIdBitmap.forEach(blockId -> {
                    long taId = RssTezUtils.getTaskAttemptId(blockId);
                    // Drop blocks written by attempts not recorded as successful.
                    if (taskIdBitmap.contains(taId)) {
                        uniqueBlockIdBitMap.add(blockId);
                    }
                });
                writeClient.startSortMerge(servers, this.appId, this.shuffleId, partitionId, uniqueBlockIdBitMap);
            }
        } finally {
            // The client is created per call, so release it here.
            writeClient.close();
        }
    }

    /**
     * Returns a mutable copy of the shuffle servers assigned to a partition.
     *
     * @throws RssException if there is no assignment for the partition
     */
    private Set<ShuffleServerInfo> getShuffleServers(int partitionId) {
        List<ShuffleServerInfo> servers = this.partitionToServers == null ? null : this.partitionToServers.get(partitionId);
        if (servers == null) {
            throw new RssException("No shuffle servers assigned to partition " + partitionId + " of shuffle " + this.shuffleId + " for source " + this.srcNameTrimmed + "; was run() called?");
        }
        return new HashSet<ShuffleServerInfo>(servers);
    }

    /**
     * Returns the merged records.
     *
     * <p>If no reader was created (no partitions, or {@link #waitForEvents()} has not completed),
     * this returns an empty reader. Its {@code next()} returns false, and its accessors throw
     * {@link IOException}.
     */
    public KeyValuesReader getKeyValuesReader() {
        if (this.reader == null) {
            return new KeyValuesReader(){

                @Override
                public boolean next() {
                    return false;
                }

                @Override
                public Object getCurrentKey() throws IOException {
                    throw new IOException("No data available");
                }

                @Override
                public Iterable getCurrentValues() throws IOException {
                    throw new IOException("No data available");
                }
            };
        }
        return this.reader.keyValuesReader();
    }

    /**
     * Builds a reader for the given partitions. Key and value classes come from the intermediate
     * input configuration, the key comparator comes from {@link WritableComparator#get(Class)}, and
     * each increment the reader reports is added to the input-records counter.
     *
     * @param partitionIds partitions to read
     * @return a new, not yet started reader
     */
    @VisibleForTesting
    public RMRecordsReader createRMRecordsReader(Set partitionIds) {
        Class keyClass = ConfigUtils.getIntermediateInputKeyClass(this.conf);
        Class valueClass = ConfigUtils.getIntermediateInputValueClass(this.conf);
        WritableComparator rawComparator = WritableComparator.get(keyClass);
        return new RMRecordsReader(this.appId, this.shuffleId, partitionIds, this.partitionToServers, this.rssConf, keyClass, valueClass, rawComparator, true, null, false, inc -> this.inputRecordCounter.increment(inc), this.clientType);
    }

    /** Not expected to be called for this input; always throws {@link RssException}. */
    @Override
    public void reportException(Throwable t) {
        throw new RssException("should never happen!");
    }

    /** Not expected to be called for this input; always throws {@link RssException}. */
    public void killSelf(Exception exception, String message) {
        throw new RssException("should never happen!");
    }
}

