package com.geoway.atlas.web.api.v2.service.pkg.impl.assigin;

import com.geoway.atlas.function.parser.common.QualifiedName;
import com.geoway.atlas.web.api.v2.exception.AtlasException;
import org.apache.commons.lang3.ArrayUtils;

import java.util.*;
import java.util.function.Function;

/**
 * Assignment plan implementing the {@code agg_value} function: each assigned field receives the
 * value of a user-supplied aggregation SQL expression.
 *
 * <p>Optionally, the first argument names a user-defined "rate" field. When present, that field
 * name inside the expressions is mapped to the {@code RATE_NAME} column of the statistic view
 * (see {@link #getFieldNameMap()}), and {@link #isNeedShapeArea()} reports {@code true}.
 *
 * @author zhaotong 2024/12/19 18:26
 */
public class AggValueFunctionPlan extends AssignFunctionPlan {

    /**
     * User-defined aggregation expressions, in argument order. The i-th expression is paired
     * positionally with the i-th entry of the assign-field map in {@link #getAssignSelectSql}.
     */
    private final List<String> expressions = new ArrayList<>();

    /**
     * Name of the user-defined rate field, or {@code null} when none was supplied.
     */
    private QualifiedName userDefineRateName = null;

    /**
     * When {@code true}, statistic SQL and unused-field computation are delegated to the parent
     * implementation; when {@code false} (the default), this class supplies its own.
     */
    private boolean isGroupByAssign = false;

    @Override
    public String functionName() {
        return "agg_value";
    }

    /**
     * Shape areas are required only when a rate field was supplied, because the rate column is
     * computed from the intersection and base shape areas (see {@link #getStatisticSql}).
     *
     * @return {@code true} if a user-defined rate field name was given
     */
    @Override
    public boolean isNeedShapeArea() {
        return userDefineRateName != null;
    }

    /**
     * @return always {@code true} for this plan
     */
    @Override
    public boolean needRepair() {
        return true;
    }

    /**
     * Parses the function arguments. The number of aggregation expressions must match the number
     * of assigned fields one-to-one (checked later in {@link #getAssignSelectSql}).
     *
     * <p>Accepted shapes:
     * <pre>
     * agg_value(rateCol, expr1, expr2, ..., recomputeField1, recomputeField2, ...[, groupByFlag])
     *     i.e. agg_value(col, string, string, ...string, col1, col2, ...[, boolean])
     * agg_value(expr1, expr2, ..., recomputeField1, recomputeField2, ...[, groupByFlag])
     *     i.e. agg_value(string, string, ...string, col1, col2, ...[, boolean])
     * </pre>
     * The optional trailing boolean sets {@link #isGroupByAssign} and must be the last argument.
     *
     * <p>State owned by this class is reset on each call. Recompute fields are appended to the
     * inherited {@code computeFields}, which this method does not clear.
     *
     * @param args the function arguments
     * @throws AtlasException if fewer than two arguments are given, no expression is present, or
     *                        the arguments do not follow the shape above
     */
    @Override
    public void setArgs(Object[] args) {
        if (args == null || args.length <= 1) {
            throw new AtlasException("请输入聚合表达式和重算面积字段名称!");
        }

        // Reset this class's own state so a repeated call does not accumulate expressions.
        expressions.clear();
        userDefineRateName = null;
        isGroupByAssign = false;

        int index = 0;

        // Optional leading field name: the user-defined rate field.
        if (args[index] instanceof QualifiedName) {
            userDefineRateName = (QualifiedName) args[index];
            index++;
        }

        // One or more aggregation expressions.
        while (index < args.length && args[index] instanceof String) {
            expressions.add((String) args[index]);
            index++;
        }

        if (expressions.isEmpty()) {
            throw new AtlasException("请检查是否存在表达式输入参数！");
        }

        // Zero or more fields to recompute.
        while (index < args.length && args[index] instanceof QualifiedName) {
            computeFields.add((QualifiedName) args[index]);
            index++;
        }

        // Optional trailing boolean flag; anything else, or anything after it, is an error.
        if (index < args.length) {
            if (!(args[index] instanceof Boolean)) {
                throw new AtlasException(String.format("请检查输入参数，当前期望输入参数: %d,实际输入参数: %d", index, args.length));
            }
            isGroupByAssign = (boolean) args[index];
            index++;
            if (index != args.length) {
                throw new AtlasException(String.format("请检查输入参数，当前期望输入参数: %d,实际输入参数: %d", index, args.length));
            }
        }
    }

    /**
     * Builds the statistic SQL over the intermediate layer.
     *
     * <p>In group-by mode the parent's statistic SQL is used. Otherwise the result depends on
     * whether a rate field was given. With a rate field, the result is a query that adds the
     * {@code RATE_NAME} column (intersection area divided by base area). Without one, the
     * layer name itself is returned unchanged.
     *
     * @param middleLayer  the intermediate layer/view name
     * @param assignFields the assigned field names (used only in group-by mode)
     * @param oidField     the object id field name (used only in group-by mode)
     * @return SQL text, or {@code middleLayer} itself when no rate column is needed
     */
    @Override
    public String getStatisticSql(String middleLayer, Collection<String> assignFields, String oidField) {
        if (isGroupByAssign) {
            return super.getStatisticSql(middleLayer, assignFields, oidField);
        }
        if (isNeedShapeArea()) {
            return String.format("select *, %s / %s as %s  from %s", intersectShapeAreaName, baseShapeAreaName, RATE_NAME, middleLayer);
        }
        return middleLayer;
    }

    /**
     * @return the parent's unused fields in group-by mode, otherwise a new empty (mutable) list
     */
    @Override
    public List<String> unusedFields(List<String> fields, Map<String, String> assignFieldMap) {
        if (isGroupByAssign) {
            return super.unusedFields(fields, assignFieldMap);
        }
        return new ArrayList<>();
    }

    /**
     * Builds the select list {@code expr_i as assignedName_i, ...}, with rate-field references
     * rewritten via {@link #getFieldNameMap()}.
     *
     * @param statisticViewAlias alias of the statistic view used when rewriting field references
     * @param assignFieldMaps    assign-field map; its values are used as the output column names
     * @return comma-separated select expressions
     * @throws AtlasException if the number of expressions differs from the number of map entries
     */
    @Override
    protected String getAssignSelectSql(String statisticViewAlias, Map<String, String> assignFieldMaps) {
        if (assignFieldMaps.size() != expressions.size()) {
            throw new AtlasException("请检查输入参数，表达式数量和赋值字段数量不匹配！");
        }

        Map<String, Function<String, String>> fieldNameMap = getFieldNameMap();

        // Expressions are paired with map entries purely by position. This is only meaningful
        // if the map iterates in argument order.
        // NOTE(review): confirm callers pass an ordered map, e.g. LinkedHashMap.
        Iterator<String> expressionIt = expressions.iterator();
        List<String> sqls = new ArrayList<>(expressions.size());
        for (Map.Entry<String, String> entry : assignFieldMaps.entrySet()) {
            sqls.add(replaceSqlFields(expressionIt.next(), statisticViewAlias, fieldNameMap)
                    + " as " + entry.getValue());
        }

        return String.join(", ", sqls);
    }

    /**
     * Emits {@code case when view.field is null then <default> else view.field end as field}.
     * The default value from {@code defaultValueMap} is inserted verbatim.
     * NOTE(review): confirm stored defaults are already valid SQL literals.
     */
    @Override
    protected String getDefaultValExpr(String viewName, String fieldName) {
        return String.format("case when %s.%s is null then %s else %s.%s end as %s",
                viewName, fieldName, defaultValueMap.get(fieldName), viewName, fieldName, fieldName);
    }

    /**
     * Emits the plain qualified reference {@code view.field} for fields without a default value.
     */
    @Override
    protected String getUndefineValExpr(String viewName, String fieldName) {
        return String.format("%s.%s", viewName, fieldName);
    }

    /**
     * Returns the replacement names for specific fields inside the SQL expressions.
     *
     * @return map from a field's simple name to a function that, given the view alias, yields
     *         {@code alias.RATE_NAME}; empty when no rate field was supplied
     */
    private Map<String, Function<String, String>> getFieldNameMap() {
        Map<String, Function<String, String>> fieldNameMap = new HashMap<>();
        if (userDefineRateName != null) {
            fieldNameMap.put(userDefineRateName.getSimpleLastName(),
                    alias -> String.format("%s.%s", alias, RATE_NAME));
        }
        return fieldNameMap;
    }
}
