/*
 * Decompiled with CFR 0.152.
 */
package org.apache.uniffle.client.impl.grpc;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.uniffle.client.api.ClientInfo;
import org.apache.uniffle.client.common.ShuffleServerPushCostTracker;
import org.apache.uniffle.client.impl.grpc.ShuffleServerGrpcClient;
import org.apache.uniffle.client.request.RssGetInMemoryShuffleDataRequest;
import org.apache.uniffle.client.request.RssGetShuffleDataRequest;
import org.apache.uniffle.client.request.RssGetShuffleIndexRequest;
import org.apache.uniffle.client.request.RssGetSortedShuffleDataRequest;
import org.apache.uniffle.client.request.RssSendShuffleDataRequest;
import org.apache.uniffle.client.response.RssGetInMemoryShuffleDataResponse;
import org.apache.uniffle.client.response.RssGetShuffleDataResponse;
import org.apache.uniffle.client.response.RssGetShuffleIndexResponse;
import org.apache.uniffle.client.response.RssGetSortedShuffleDataResponse;
import org.apache.uniffle.client.response.RssSendShuffleDataResponse;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ReadSegment;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.NotRetryException;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssFetchFailedException;
import org.apache.uniffle.common.netty.client.TransportClient;
import org.apache.uniffle.common.netty.client.TransportClientFactory;
import org.apache.uniffle.common.netty.client.TransportConf;
import org.apache.uniffle.common.netty.client.TransportContext;
import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataRequest;
import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataResponse;
import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataV2Request;
import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataV3Request;
import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexRequest;
import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexResponse;
import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataRequest;
import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataResponse;
import org.apache.uniffle.common.netty.protocol.GetSortedShuffleDataRequest;
import org.apache.uniffle.common.netty.protocol.GetSortedShuffleDataResponse;
import org.apache.uniffle.common.netty.protocol.RpcResponse;
import org.apache.uniffle.common.netty.protocol.SendShuffleDataRequest;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.shaded.com.google.common.annotations.VisibleForTesting;
import org.apache.uniffle.shaded.org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ShuffleServerGrpcNettyClient
extends ShuffleServerGrpcClient {
    // ---- state shared by every instance ----
    private static final Logger LOG = LoggerFactory.getLogger(ShuffleServerGrpcNettyClient.class);
    /** Backing counter for {@link #requestId()}; shared by all client instances. */
    private static final AtomicLong counter = new AtomicLong();

    // ---- per-instance state, assigned in the full constructor ----
    /** Port passed to the transport client factory and reported through {@link #getClientInfo()}. */
    private int nettyPort;
    /** Factory used by {@code getTransportClient()} to obtain a client for host:nettyPort. */
    private TransportClientFactory clientFactory;

    /**
     * Creates a client backed by a freshly constructed {@link RssConf}.
     *
     * <p>Marked {@code @VisibleForTesting}; delegates to the {@code (RssConf, String, int, int)}
     * constructor.
     *
     * @param host server host name
     * @param grpcPort server gRPC port
     * @param nettyPort server Netty port
     */
    @VisibleForTesting
    public ShuffleServerGrpcNettyClient(String host, int grpcPort, int nettyPort) {
        this(new RssConf(), host, grpcPort, nettyPort);
    }

    /**
     * Creates a client whose retry count, RPC timeout, page size, max order and small cache size
     * are read from {@code rssConf}.
     *
     * <p>When {@code rssConf} is {@code null}, a new {@link RssConf} is used and each of those
     * settings falls back to the default value declared on its {@link RssClientConf} option.
     *
     * @param rssConf client configuration; may be {@code null}
     * @param host server host name
     * @param grpcPort server gRPC port
     * @param nettyPort server Netty port
     */
    public ShuffleServerGrpcNettyClient(RssConf rssConf, String host, int grpcPort, int nettyPort) {
        this(
            rssConf != null ? rssConf : new RssConf(),
            host,
            grpcPort,
            nettyPort,
            rssConf != null
                ? rssConf.getInteger(RssClientConf.RPC_MAX_ATTEMPTS)
                : RssClientConf.RPC_MAX_ATTEMPTS.defaultValue().intValue(),
            rssConf != null
                ? rssConf.getLong(RssClientConf.RPC_TIMEOUT_MS)
                : RssClientConf.RPC_TIMEOUT_MS.defaultValue().longValue(),
            rssConf != null
                ? rssConf.getInteger(RssClientConf.RPC_NETTY_PAGE_SIZE)
                : RssClientConf.RPC_NETTY_PAGE_SIZE.defaultValue().intValue(),
            rssConf != null
                ? rssConf.getInteger(RssClientConf.RPC_NETTY_MAX_ORDER)
                : RssClientConf.RPC_NETTY_MAX_ORDER.defaultValue().intValue(),
            rssConf != null
                ? rssConf.getInteger(RssClientConf.RPC_NETTY_SMALL_CACHE_SIZE)
                : RssClientConf.RPC_NETTY_SMALL_CACHE_SIZE.defaultValue().intValue());
    }

    /**
     * Creates a client with every tunable supplied explicitly.
     *
     * <p>The gRPC-side settings are forwarded to the superclass constructor, together with the
     * event-loop thread count read from {@code rssConf}. The Netty side is set up by building a
     * {@link TransportClientFactory} from a {@link TransportContext} over a
     * {@link TransportConf} of the same configuration.
     *
     * @param rssConf client configuration; must not be {@code null} because it is dereferenced here
     * @param host server host name
     * @param grpcPort server gRPC port
     * @param nettyPort server Netty port
     * @param maxRetryAttempts retry limit forwarded to the superclass
     * @param rpcTimeoutMs RPC timeout in milliseconds forwarded to the superclass
     * @param pageSize page size forwarded to the superclass
     * @param maxOrder max order forwarded to the superclass
     * @param smallCacheSize small cache size forwarded to the superclass
     */
    public ShuffleServerGrpcNettyClient(
            RssConf rssConf,
            String host,
            int grpcPort,
            int nettyPort,
            int maxRetryAttempts,
            long rpcTimeoutMs,
            int pageSize,
            int maxOrder,
            int smallCacheSize) {
        super(
            host,
            grpcPort,
            maxRetryAttempts,
            rpcTimeoutMs,
            true,
            pageSize,
            maxOrder,
            smallCacheSize,
            rssConf.get(RssClientConf.RSS_CLIENT_GRPC_EVENT_LOOP_THREADS));
        this.nettyPort = nettyPort;
        this.clientFactory =
            new TransportClientFactory(new TransportContext(new TransportConf(rssConf)));
    }

    /**
     * Describes this client as a {@code GRPC_NETTY} client.
     *
     * @return client info wrapping a {@link ShuffleServerInfo} built from the host, the inherited
     *     port and the Netty port
     */
    @Override
    public ClientInfo getClientInfo() {
        ShuffleServerInfo serverInfo = new ShuffleServerInfo(host, port, nettyPort);
        return new ClientInfo(ClientType.GRPC_NETTY, serverInfo);
    }

    /**
     * Sends the blocks of every shuffle contained in {@code request} over the Netty transport.
     *
     * <p>For each shuffle id the per-partition byte sizes are summed. A retried action then
     * obtains a transport client, calls {@code requirePreAllocation}, stamps the returned
     * require id and the current time on the request, and sends it synchronously. Retries stop on
     * {@link OutOfMemoryError} and {@link NotRetryException}. Processing stops at the first
     * shuffle whose send ultimately fails.
     *
     * @param request blocks grouped by shuffle id and partition id, plus retry settings
     * @return a response with {@code SUCCESS} when every shuffle was sent; otherwise the last
     *     recorded failure status ({@code INTERNAL_ERROR} if none was recorded). In both cases
     *     it carries the partition ids returned by {@code requirePreAllocation}.
     */
    @Override
    public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest request) {
        Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks = request.getShuffleIdToBlocks();
        int stageAttemptNumber = request.getStageAttemptNumber();
        boolean isSuccessful = true;
        AtomicReference<StatusCode> failedStatusCode = new AtomicReference<>(StatusCode.INTERNAL_ERROR);
        HashSet<Integer> needSplitPartitionIds = new HashSet<>();
        for (Map.Entry<Integer, Map<Integer, List<ShuffleBlockInfo>>> stb : shuffleIdToBlocks.entrySet()) {
            int shuffleId = stb.getKey();
            int size = 0;
            int blockNum = 0;
            ArrayList<Integer> partitionIds = new ArrayList<>();
            ArrayList<Integer> partitionRequireSizes = new ArrayList<>();
            // Collect, per partition, the number of bytes to reserve; also total bytes and blocks.
            for (Map.Entry<Integer, List<ShuffleBlockInfo>> ptb : stb.getValue().entrySet()) {
                int partitionRequireSize = 0;
                for (ShuffleBlockInfo sbi : ptb.getValue()) {
                    partitionRequireSize += sbi.getSize();
                    ++blockNum;
                }
                size += partitionRequireSize;
                partitionIds.add(ptb.getKey());
                partitionRequireSizes.add(partitionRequireSize);
            }
            ShuffleServerPushCostTracker costTracker = request.getCostTracker();
            // The require id is a placeholder (0) here; the real one is set inside the retried action.
            SendShuffleDataRequest sendShuffleDataRequest = new SendShuffleDataRequest(
                ShuffleServerGrpcNettyClient.requestId(),
                request.getAppId(),
                shuffleId,
                stageAttemptNumber,
                0L,
                stb.getValue(),
                System.currentTimeMillis());
            int allocateSize = size + sendShuffleDataRequest.encodedLength();
            int finalBlockNum = blockNum;
            try {
                RetryUtils.retryWithCondition(() -> {
                    TransportClient transportClient = this.getTransportClient();
                    Pair<Long, List<Integer>> result = this.requirePreAllocation(
                        request.getAppId(),
                        shuffleId,
                        partitionIds,
                        partitionRequireSizes,
                        allocateSize,
                        request.getRetryMax(),
                        request.getRetryIntervalMax(),
                        failedStatusCode,
                        costTracker);
                    long requireId = result.getLeft();
                    // Typed addAll: no raw Collection cast is needed for a List<Integer>.
                    needSplitPartitionIds.addAll(result.getRight());
                    if (requireId == -1L) {
                        ClientInfo clientInfo = this.getClientInfo();
                        if (clientInfo != null && costTracker != null) {
                            costTracker.recordRequireBufferFailure(clientInfo.getShuffleServerInfo().getId());
                        }
                        throw new RssException(String.format("requirePreAllocation failed! size[%s], host[%s], port[%s]", allocateSize, this.host, this.port));
                    }
                    sendShuffleDataRequest.setRequireId(requireId);
                    sendShuffleDataRequest.setTimestamp(System.currentTimeMillis());
                    long start = System.currentTimeMillis();
                    RpcResponse rpcResponse = transportClient.sendRpcSync(sendShuffleDataRequest, this.rpcTimeout);
                    if (LOG.isDebugEnabled()) {
                        // Fully parameterized; renders the same text as the former concatenation.
                        LOG.debug(
                            "Do sendShuffleData to {}:{} rpc cost:{} ms for {} bytes with {} blocks",
                            this.host,
                            this.port,
                            System.currentTimeMillis() - start,
                            allocateSize,
                            finalBlockNum);
                    }
                    if (rpcResponse.getStatusCode() != StatusCode.SUCCESS) {
                        failedStatusCode.set(StatusCode.fromCode(rpcResponse.getStatusCode().statusCode()));
                        String msg = "Can't send shuffle data with " + finalBlockNum + " blocks to " + this.host + ":" + this.port + ", statusCode=" + rpcResponse.getStatusCode() + ", errorMsg:" + rpcResponse.getRetMessage();
                        if (NOT_RETRY_STATUS_CODES.contains(rpcResponse.getStatusCode())) {
                            throw new NotRetryException(msg);
                        }
                        throw new RssException(msg);
                    }
                    return rpcResponse;
                }, null, request.getRetryIntervalMax(), this.maxRetryAttempts, t -> !(t instanceof OutOfMemoryError) && !(t instanceof NotRetryException));
            }
            catch (Throwable throwable) {
                // Deliberate best-effort: report the failure via the response instead of propagating.
                LOG.warn("Failed to send shuffle data due to ", throwable);
                isSuccessful = false;
                break;
            }
        }
        RssSendShuffleDataResponse response = isSuccessful ? new RssSendShuffleDataResponse(StatusCode.SUCCESS) : new RssSendShuffleDataResponse(failedStatusCode.get());
        response.setNeedSplitPartitionIds(needSplitPartitionIds);
        return response;
    }

    /**
     * Fetches in-memory shuffle data for one partition over the Netty transport.
     *
     * <p>The RPC is re-sent for as long as the server answers {@code NO_BUFFER}; between attempts
     * the inherited {@code waitOrThrow} is invoked with the attempt index and the start time.
     *
     * @param request identifies the app, shuffle, partition, last block id, read buffer size and
     *     expected task ids
     * @return the body and buffer segments of a {@code SUCCESS} reply
     * @throws RssFetchFailedException when the final reply has any status other than
     *     {@code SUCCESS}
     */
    @Override
    public RssGetInMemoryShuffleDataResponse getInMemoryShuffleData(RssGetInMemoryShuffleDataRequest request) {
        TransportClient client = getTransportClient();
        GetMemoryShuffleDataRequest memoryDataRequest = new GetMemoryShuffleDataRequest(
            ShuffleServerGrpcNettyClient.requestId(),
            request.getAppId(),
            request.getShuffleId(),
            request.getPartitionId(),
            request.getLastBlockId(),
            request.getReadBufferSize(),
            System.currentTimeMillis(),
            request.getExpectedTaskIds());
        String requestInfo = "appId[" + request.getAppId()
            + "], shuffleId[" + request.getShuffleId()
            + "], partitionId[" + request.getPartitionId()
            + "], lastBlockId[" + request.getLastBlockId() + "]";
        long start = System.currentTimeMillis();

        GetMemoryShuffleDataResponse memoryDataResponse;
        StatusCode statusCode;
        for (int attempt = 0; ; attempt++) {
            RpcResponse reply = client.sendRpcSync(memoryDataRequest, rpcTimeout);
            memoryDataResponse = (GetMemoryShuffleDataResponse) reply;
            statusCode = reply.getStatusCode();
            if (statusCode != StatusCode.NO_BUFFER) {
                break;
            }
            waitOrThrow(request, attempt, requestInfo, statusCode, start);
        }

        switch (statusCode) {
            case SUCCESS: {
                LOG.info(
                    "GetInMemoryShuffleData size:{}(bytes) from {}:{} for {} cost:{}(ms)",
                    memoryDataResponse.body().size(),
                    host,
                    nettyPort,
                    requestInfo,
                    System.currentTimeMillis() - start);
                return new RssGetInMemoryShuffleDataResponse(
                    StatusCode.SUCCESS, memoryDataResponse.body(), memoryDataResponse.getBufferSegments());
            }
            default: {
                String msg = "Can't get shuffle in memory data from " + host + ":" + nettyPort
                    + " for " + requestInfo + ", errorMsg:" + memoryDataResponse.getRetMessage();
                LOG.error(msg);
                throw new RssFetchFailedException(msg);
            }
        }
    }

    /**
     * Fetches the local shuffle index for one partition over the Netty transport.
     *
     * <p>The RPC is re-sent for as long as the server answers {@code NO_BUFFER}; between attempts
     * the inherited {@code waitOrThrow} is invoked with the attempt index and the start time.
     *
     * @param request identifies the app, shuffle, partition and partition range layout
     * @return the index body, file length and storage ids of a {@code SUCCESS} reply
     * @throws RssFetchFailedException when the final reply has any status other than
     *     {@code SUCCESS}
     */
    @Override
    public RssGetShuffleIndexResponse getShuffleIndex(RssGetShuffleIndexRequest request) {
        TransportClient client = getTransportClient();
        GetLocalShuffleIndexRequest indexRequest = new GetLocalShuffleIndexRequest(
            ShuffleServerGrpcNettyClient.requestId(),
            request.getAppId(),
            request.getShuffleId(),
            request.getPartitionId(),
            request.getPartitionNumPerRange(),
            request.getPartitionNum());
        String requestInfo = "appId[" + request.getAppId()
            + "], shuffleId[" + request.getShuffleId()
            + "], partitionId[" + request.getPartitionId() + "]";
        long start = System.currentTimeMillis();

        GetLocalShuffleIndexResponse indexResponse;
        StatusCode statusCode;
        for (int attempt = 0; ; attempt++) {
            RpcResponse reply = client.sendRpcSync(indexRequest, rpcTimeout);
            indexResponse = (GetLocalShuffleIndexResponse) reply;
            statusCode = reply.getStatusCode();
            if (statusCode != StatusCode.NO_BUFFER) {
                break;
            }
            waitOrThrow(request, attempt, requestInfo, statusCode, start);
        }

        switch (statusCode) {
            case SUCCESS: {
                LOG.info(
                    "GetShuffleIndex size:{}(bytes) from {}:{} for {} cost:{}(ms)",
                    indexResponse.body().size(),
                    host,
                    nettyPort,
                    requestInfo,
                    System.currentTimeMillis() - start);
                return new RssGetShuffleIndexResponse(
                    StatusCode.SUCCESS,
                    indexResponse.body(),
                    indexResponse.getFileLength(),
                    indexResponse.getStorageIds());
            }
            default: {
                String msg = "Can't get shuffle index from " + host + ":" + nettyPort
                    + " for " + requestInfo + ", errorMsg:" + indexResponse.getRetMessage();
                LOG.error(msg);
                throw new RssFetchFailedException(msg);
            }
        }
    }

    /**
     * Fetches a range of local shuffle data for one partition over the Netty transport.
     *
     * <p>The request variant depends on the incoming request: the base request when no storage
     * id is specified, the V3 request when a storage id is specified and next-read-segment
     * reporting is enabled, and the V2 request otherwise. The RPC is re-sent for as long as the
     * server answers {@code NO_BUFFER}; between attempts the inherited {@code waitOrThrow} is
     * invoked with the attempt index and the start time.
     *
     * @param request identifies the app, shuffle, partition, offset/length and optional storage
     *     details
     * @return the body of a {@code SUCCESS} reply
     * @throws RssFetchFailedException when the final reply has any status other than
     *     {@code SUCCESS}
     */
    @Override
    public RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest request) {
        TransportClient client = getTransportClient();

        GetLocalShuffleDataRequest dataRequest;
        if (!request.storageIdSpecified()) {
            dataRequest = new GetLocalShuffleDataRequest(
                ShuffleServerGrpcNettyClient.requestId(),
                request.getAppId(),
                request.getShuffleId(),
                request.getPartitionId(),
                request.getPartitionNumPerRange(),
                request.getPartitionNum(),
                request.getOffset(),
                request.getLength(),
                System.currentTimeMillis());
        } else if (request.isNextReadSegmentsReportEnabled()) {
            dataRequest = new GetLocalShuffleDataV3Request(
                ShuffleServerGrpcNettyClient.requestId(),
                request.getAppId(),
                request.getShuffleId(),
                request.getPartitionId(),
                request.getPartitionNumPerRange(),
                request.getPartitionNum(),
                request.getOffset(),
                request.getLength(),
                request.getStorageId(),
                ReadSegment.from(request.getNextReadSegments()),
                System.currentTimeMillis(),
                request.getTaskAttemptId());
        } else {
            dataRequest = new GetLocalShuffleDataV2Request(
                ShuffleServerGrpcNettyClient.requestId(),
                request.getAppId(),
                request.getShuffleId(),
                request.getPartitionId(),
                request.getPartitionNumPerRange(),
                request.getPartitionNum(),
                request.getOffset(),
                request.getLength(),
                request.getStorageId(),
                System.currentTimeMillis());
        }

        String requestInfo = "appId[" + request.getAppId()
            + "], shuffleId[" + request.getShuffleId()
            + "], partitionId[" + request.getPartitionId() + "]";
        long start = System.currentTimeMillis();

        GetLocalShuffleDataResponse dataResponse;
        StatusCode statusCode;
        for (int attempt = 0; ; attempt++) {
            RpcResponse reply = client.sendRpcSync(dataRequest, rpcTimeout);
            dataResponse = (GetLocalShuffleDataResponse) reply;
            statusCode = reply.getStatusCode();
            if (statusCode != StatusCode.NO_BUFFER) {
                break;
            }
            waitOrThrow(request, attempt, requestInfo, statusCode, start);
        }

        switch (statusCode) {
            case SUCCESS: {
                LOG.info(
                    "GetShuffleData size:{}(bytes) from {}:{} for {} cost:{}(ms)",
                    dataResponse.body().size(),
                    host,
                    nettyPort,
                    requestInfo,
                    System.currentTimeMillis() - start);
                return new RssGetShuffleDataResponse(StatusCode.SUCCESS, dataResponse.body());
            }
            default: {
                String msg = "Can't get shuffle data from " + host + ":" + nettyPort
                    + " for " + requestInfo + ", errorMsg:" + dataResponse.getRetMessage();
                LOG.error(msg);
                throw new RssFetchFailedException(msg);
            }
        }
    }

    /**
     * Fetches sorted shuffle data for one partition, starting at the requested block id, over the
     * Netty transport.
     *
     * <p>The RPC is re-sent for as long as the server answers {@code NO_BUFFER}; between attempts
     * the inherited {@code waitOrThrow} is invoked with the attempt index and the start time.
     *
     * @param request identifies the app, shuffle, partition and block id
     * @return the return message, body, next block id and merge state of a {@code SUCCESS} reply
     * @throws RssFetchFailedException when the final reply has any status other than
     *     {@code SUCCESS}
     */
    @Override
    public RssGetSortedShuffleDataResponse getSortedShuffleData(RssGetSortedShuffleDataRequest request) {
        TransportClient client = getTransportClient();
        GetSortedShuffleDataRequest sortedRequest = new GetSortedShuffleDataRequest(
            ShuffleServerGrpcNettyClient.requestId(),
            request.getAppId(),
            request.getShuffleId(),
            request.getPartitionId(),
            request.getBlockId(),
            0,
            System.currentTimeMillis());
        String requestInfo = String.format(
            "appId[%s], shuffleId[%d], partitionId[%d], blockId[%d]",
            request.getAppId(),
            request.getShuffleId(),
            request.getPartitionId(),
            request.getBlockId());
        long start = System.currentTimeMillis();

        GetSortedShuffleDataResponse sortedResponse;
        StatusCode statusCode;
        for (int attempt = 0; ; attempt++) {
            RpcResponse reply = client.sendRpcSync(sortedRequest, rpcTimeout);
            sortedResponse = (GetSortedShuffleDataResponse) reply;
            statusCode = reply.getStatusCode();
            if (statusCode != StatusCode.NO_BUFFER) {
                break;
            }
            waitOrThrow(request, attempt, requestInfo, statusCode, start);
        }

        switch (statusCode) {
            case SUCCESS: {
                LOG.info(
                    "GetSortedShuffleData from {}:{} for {} cost {} ms",
                    host,
                    nettyPort,
                    requestInfo,
                    System.currentTimeMillis() - start);
                return new RssGetSortedShuffleDataResponse(
                    StatusCode.SUCCESS,
                    sortedResponse.getRetMessage(),
                    sortedResponse.body(),
                    sortedResponse.getNextBlockId(),
                    sortedResponse.getMergeState());
            }
            default: {
                String msg = String.format(
                    "Can't get sorted shuffle data from %s:%d for %s, errorMsg: %s",
                    host,
                    nettyPort,
                    requestInfo,
                    sortedResponse.getRetMessage());
                LOG.error(msg);
                throw new RssFetchFailedException(msg);
            }
        }
    }

    /**
     * Hands out the next request id from the class-wide counter.
     *
     * @return the counter's value before it is atomically incremented, so successive calls
     *     (from any instance) receive different values
     */
    public static long requestId() {
        return counter.getAndIncrement();
    }

    /**
     * Obtains a transport client for this server's host and Netty port from the client factory.
     *
     * @return the client returned by {@code clientFactory.createClient(host, nettyPort)}
     * @throws RssException wrapping any exception raised while creating the client
     */
    private TransportClient getTransportClient() {
        try {
            return this.clientFactory.createClient(this.host, this.nettyPort);
        }
        catch (Exception e) {
            // The broad catch may also capture an InterruptedException (NOTE(review): confirm
            // that createClient declares it). Restore the interrupt status so callers further
            // up the stack can still observe the interruption after it is wrapped.
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new RssException("create transport client failed", e);
        }
    }
}

