package com.geoway.atlas.web.api.v2.job;

import com.geoway.atlas.data.vector.spark.common.rpc.common.AtlasRpcServerException;
import com.geoway.atlas.web.api.v2.component.bean.AtlasGisToolkitBeanFactory;
import com.geoway.atlas.web.api.v2.component.rpc.RpcClientProxy;
import com.geoway.atlas.web.api.v2.utils.ResponseBuilder;
import lombok.Getter;
import lombok.SneakyThrows;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import javax.servlet.http.HttpServletRequest;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Manages asynchronous jobs: resolves job/task ids from HTTP requests, tracks the
 * {@link Future} of every running job grouped by task id, and forwards job lifecycle
 * events (start / finish / cancel / query) to the RPC client API.
 *
 * <p>Thread-safety: the future registry is a concurrent map, and the methods that
 * coordinate on it ({@link #putFuture}, {@link #finishJob}, {@link #stopFutures},
 * {@link #cancelJob}) synchronize on this instance. {@link #finishJob} releases the
 * monitor while it waits for a registration, so it never blocks {@link #putFuture}.
 *
 * @author zhaotong 2022/9/6 16:59
 */

@Component
public class JobManager {

    public static final String JOB_ID = "jobid";

    public static final String TASK_ID = "taskid";

    /** Interval, in milliseconds, between two checks of the future registry in {@link #finishJob}. */
    private static final long REGISTRATION_POLL_MILLIS = 50L;

    /** Budget of each waiting phase in {@link #finishJob}: 2000 polls of 50 ms, expressed in nanoseconds. */
    private static final long REGISTRATION_TIMEOUT_NANOS = 2000L * REGISTRATION_POLL_MILLIS * 1_000_000L;

    private static final long NANOS_PER_MILLI = 1_000_000L;

    /** Source of generated job ids for requests that do not carry a {@value #JOB_ID} parameter. */
    private final AtomicInteger autoId = new AtomicInteger(0);

    @Autowired
    private RpcClientProxy rpcClientProxy;

    /** Registry of running futures: task id -> (job id -> future). */
    @Getter
    private final ConcurrentMap<String, ConcurrentMap<String, Future<?>>> futureConcurrentMap = new ConcurrentHashMap<>();

    /**
     * Resolves the job id of a request.
     *
     * @param request the incoming request
     * @return the value of the {@value #JOB_ID} request parameter, or, when that parameter
     *         is blank, the next value of an internal counter rendered as a decimal string
     */
    public String getJobId(HttpServletRequest request) {
        String jobId = request.getParameter(JOB_ID);
        if (StringUtils.isBlank(jobId)) {
            jobId = Integer.toString(autoId.getAndIncrement());
        }

        return jobId;
    }

    /**
     * Resolves the task id of a request.
     *
     * @param request the incoming request
     * @return the value of the {@value #TASK_ID} request parameter
     * @throws RuntimeException if the parameter is missing or blank
     */
    public String getTaskId(HttpServletRequest request) {
        String taskId = request.getParameter(TASK_ID);
        if (StringUtils.isBlank(taskId)) {
            throw new RuntimeException("无法获取请求的taskid!");
        }

        return taskId;
    }

    /**
     * Serializes a map to a JSON string using the shared object mapper.
     * Serialization failures are rethrown unchanged (via {@code @SneakyThrows}).
     *
     * @param map the map to serialize
     * @return the JSON representation of {@code map}
     */
    @SneakyThrows
    public String toJSON(Map<String, Object> map) {
        return AtlasGisToolkitBeanFactory.getObjectMapper().writeValueAsString(map);
    }

    /**
     * Parses a JSON string into a map using the shared object mapper.
     * Parsing failures are rethrown unchanged (via {@code @SneakyThrows}).
     *
     * @param jsonString the JSON text to parse
     * @return the parsed map
     */
    @SneakyThrows
    @SuppressWarnings("unchecked") // the mapper is asked for a raw Map; keys of a parsed JSON object are strings
    public Map<String, Object> fromJSON(String jsonString) {
        return (Map<String, Object>) AtlasGisToolkitBeanFactory.getObjectMapper().readValue(jsonString, Map.class);
    }

    /**
     * Reports the start of a job to the RPC client API.
     *
     * @param rawTaskId task id
     * @param jobId     job id
     * @param response  response information, sent as JSON
     * @return the value returned by the RPC {@code startJob} call
     */
    public String startJob(String rawTaskId, String jobId, Map<String, Object> response) {
        return rpcClientProxy.getSparkRpcClientApi().startJob(rawTaskId, jobId, toJSON(response));
    }

    /**
     * Registers the future of a job under its task and wakes up any {@link #finishJob}
     * call that is waiting for this registration.
     *
     * @param rawTaskId task id
     * @param jobId     job id
     * @param future    the future executing the job
     */
    public synchronized void putFuture(String rawTaskId, String jobId, Future<?> future) {
        ConcurrentMap<String, Future<?>> jobs = futureConcurrentMap.get(rawTaskId);
        if (jobs == null) {
            ConcurrentMap<String, Future<?>> created = new ConcurrentHashMap<>();
            // putIfAbsent keeps a map that was inserted concurrently through the exposed getter.
            jobs = futureConcurrentMap.putIfAbsent(rawTaskId, created);
            if (jobs == null) {
                jobs = created;
            }
        }

        jobs.put(jobId, future);
        notifyAll();
    }

    /**
     * Removes the job's future from the registry and reports the job as finished to the
     * RPC client API.
     *
     * <p>A job may finish before its future has been registered by {@link #putFuture}, so
     * this method first waits for the registration: up to 100 seconds for the task entry
     * to appear and then up to another 100 seconds for the job entry. If the registration
     * does not show up in time the result is reported anyway, without touching the registry.
     *
     * <p>The wait uses {@link Object#wait(long)} rather than {@code Thread.sleep}: sleeping
     * inside a synchronized method keeps the monitor, which would block the very
     * {@link #putFuture} call being waited for. An interrupt received while waiting does
     * not abort the wait; the interrupt status is restored once the result has been reported.
     *
     * @param rawTaskId task id
     * @param jobId     job id
     * @param response  response information, sent as JSON
     */
    public synchronized void finishJob(String rawTaskId, String jobId, Map<String, Object> response) {
        boolean interrupted = false;
        try {
            long deadline = System.nanoTime() + REGISTRATION_TIMEOUT_NANOS;
            boolean taskSeen = false;
            while (true) {
                ConcurrentMap<String, Future<?>> jobs = futureConcurrentMap.get(rawTaskId);
                if (jobs != null) {
                    // Concurrent maps hold no null values, so non-null means the job was registered.
                    if (jobs.remove(jobId) != null) {
                        break;
                    }
                    if (!taskSeen) {
                        // The task entry exists: start a fresh budget for the job entry.
                        taskSeen = true;
                        deadline = System.nanoTime() + REGISTRATION_TIMEOUT_NANOS;
                    }
                }

                long remainingMillis = (deadline - System.nanoTime()) / NANOS_PER_MILLI;
                if (remainingMillis <= 0) {
                    // Timed out: report the result without a registry entry to remove.
                    break;
                }
                try {
                    // Bounded wait: putFuture notifies, but entries added directly through
                    // the exposed registry map do not, so the registry is still polled.
                    wait(Math.min(REGISTRATION_POLL_MILLIS, remainingMillis));
                } catch (InterruptedException e) {
                    interrupted = true;
                }
            }

            rpcClientProxy.getSparkRpcClientApi().finishJob(rawTaskId, jobId, toJSON(response));
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Requests cancellation, with interruption, of every registered future of a task.
     * Cancelled futures are left in the registry.
     *
     * @param rawTaskId task id
     */
    public synchronized void stopFutures(String rawTaskId) {
        ConcurrentMap<String, Future<?>> jobs = futureConcurrentMap.get(rawTaskId);
        if (jobs == null) {
            return;
        }
        for (Future<?> future : jobs.values()) {
            try {
                future.cancel(true);
            } catch (RuntimeException ignored) {
                // Best effort: a failing cancel must not prevent cancelling the remaining futures.
            }
        }
    }

    /**
     * Reports a job as cancelled to the RPC client API, attaching a failure response
     * stating that the task was stopped.
     *
     * @param taskId task id
     * @param jobId  job id
     */
    public synchronized void cancelJob(String taskId, String jobId) {
        Map<String, Object> failedObj = ResponseBuilder.buildFailed(new AtlasRpcServerException("任务" + taskId + "已被停止!"));
        rpcClientProxy.getSparkRpcClientApi().cancelJob(taskId, jobId, toJSON(failedObj));
    }

    /**
     * Fetches the stored response of a job from the RPC client API.
     *
     * @param jobId job id (the RPC call is made with an empty string as its first argument)
     * @return the parsed job response, or a failure response when the RPC call returns
     *         a blank result
     */
    public Map<String, Object> getJobInfo(String jobId) {
        String response = rpcClientProxy.getSparkRpcClientApi().getJobReponse("", jobId);
        if (StringUtils.isNotBlank(response)) {
            return fromJSON(response);
        } else {
            return ResponseBuilder.buildFailed(new RuntimeException("无法找到任务id"));
        }
    }
}
