package com.geoway.atlas.web.api.v2.service.pkg.impl.assigin;

import com.geoway.atlas.web.api.v2.exception.AtlasException;

import java.util.*;

/**
 * Assignment function plan for {@code sum_rate}.
 *
 * <p>For each row of the intermediate layer, it computes an overlap rate,
 * {@code intersectShapeAreaName / baseShapeAreaName}. When assigning, it aggregates those rates
 * with {@code sum(...)}, caps the result at {@code 1.0}, and optionally rounds it to a configured
 * number of decimal places. Exactly one assign field is supported; that field receives the rate.
 * The grouping used by the aggregation is defined outside this class (presumably by the base
 * feature; confirm against {@code AssignFunctionPlan}).
 *
 * @author zhaotong 2024/12/23 16:24
 */
public class SumRateFunctionPlan extends AssignFunctionPlan {

    /**
     * Number of decimal places to round the summed rate to. A negative value (the initial value)
     * disables rounding.
     */
    private int precision = -1;

    @Override
    public String functionName() {
        return "sum_rate";
    }

    @Override
    public boolean needRepair() {
        return true;
    }

    /**
     * Returns the fields this function does not need: all of {@code fields} except
     * {@code baseShapeAreaName}.
     *
     * @param fields         candidate field names; this list is copied, not modified
     * @param assignFieldMap not used by this implementation
     * @return a new mutable list with the first occurrence of {@code baseShapeAreaName} removed
     */
    @Override
    public List<String> unusedFields(List<String> fields, Map<String, String> assignFieldMap) {
        List<String> removeFields = new ArrayList<>(fields);
        removeFields.remove(baseShapeAreaName);
        return removeFields;
    }

    /**
     * Builds the statistic query. It selects every column of {@code middleLayer} and adds a
     * {@code RATE_NAME} column equal to {@code intersectShapeAreaName / baseShapeAreaName}.
     *
     * @param middleLayer  name of the intermediate layer/view to read from
     * @param assignFields not used by this implementation
     * @param oidField     not used by this implementation
     * @return the SQL select statement
     */
    public String getStatisticSql(String middleLayer, Collection<String> assignFields, String oidField) {
        return String.format("select *, %s / %s as %s  from %s", intersectShapeAreaName, baseShapeAreaName, RATE_NAME, middleLayer);
    }

    /**
     * Configures the optional rounding precision.
     *
     * <p>Only a single, non-null argument is considered. Any other shape ({@code null} array,
     * zero or several arguments, or a {@code null} element) leaves the current precision
     * unchanged. Any {@link Number} is accepted, including {@code Integer}, {@code Long},
     * {@code Short} and {@code Double}. Its {@link Number#intValue()} is used, so fractional
     * values are truncated toward zero.
     *
     * @param args function arguments; {@code args[0]} is the number of decimal places
     * @throws AtlasException if the single argument is not a {@link Number}
     */
    @Override
    public void setArgs(Object[] args) {
        if (args == null || args.length != 1 || args[0] == null) {
            return;
        }
        Object arg = args[0];
        // The previous unboxing cast "(int) arg" only succeeded for java.lang.Integer. It rejected
        // Long/Double/etc. with a message claiming only numeric types are supported.
        if (!(arg instanceof Number)) {
            throw new AtlasException("请检查输入的精度类型，仅支持数字类型！");
        }
        precision = ((Number) arg).intValue();
    }

    /**
     * Builds the select expression that assigns the aggregated rate to the single target field.
     * The expression sums {@code RATE_NAME}, caps the sum at {@code 1.0}, and, when
     * {@code precision >= 0}, rounds it to {@code precision} decimal places before the cap
     * comparison.
     *
     * @param statisticViewAlias alias of the statistic view that exposes {@code RATE_NAME}
     * @param assignFieldMaps    must contain exactly one entry; its value is the output column name
     * @return the SQL select expression, aliased to the target field
     * @throws AtlasException if {@code assignFieldMaps} does not contain exactly one entry
     */
    @Override
    protected String getAssignSelectSql(String statisticViewAlias, Map<String, String> assignFieldMaps) {
        if (assignFieldMaps.size() != 1) {
            throw new AtlasException("请检查输入赋值字段，仅支持1个赋值字段作为结果图层比例字段");
        }
        // Size is exactly 1 here, so the first value always exists.
        String nRateName = assignFieldMaps.values().iterator().next();
        if (precision < 0) {
            return String.format("case when sum(%s.%s) > 1.0 then 1.0 else sum(%s.%s) end as %s",
                    statisticViewAlias, RATE_NAME, statisticViewAlias, RATE_NAME, nRateName);
        } else {
            return String.format("case when round(sum(%s.%s), %d) > 1.0 then 1.0 else round(sum(%s.%s), %d) end as %s",
                    statisticViewAlias, RATE_NAME, precision, statisticViewAlias, RATE_NAME, precision, nRateName);
        }
    }

    /**
     * Builds an expression that replaces a {@code null} value of {@code viewName.fieldName} with
     * the default configured in {@code defaultValueMap}. The default is inserted verbatim into
     * the SQL. If no default is configured, it renders as the literal text {@code null}, so the
     * expression leaves nulls as they are.
     *
     * @param viewName  view/alias that owns the field
     * @param fieldName field to read; also used as the output alias
     * @return the SQL {@code case} expression
     */
    @Override
    protected String getDefaultValExpr(String viewName, String fieldName) {
        String defaultValue = defaultValueMap.get(fieldName);
        return String.format("case when %s.%s is null then %s else %s.%s end as %s", viewName, fieldName, defaultValue, viewName, fieldName, fieldName);
    }

    /**
     * Builds an expression for fields with no configured default. It replaces a {@code null}
     * value of {@code viewName.fieldName} with {@code 0.0}.
     *
     * @param viewName  view/alias that owns the field
     * @param fieldName field to read; also used as the output alias
     * @return the SQL {@code case} expression
     */
    @Override
    protected String getUndefineValExpr(String viewName, String fieldName) {
        // The default value is 0.0.
        String defaultValue = "0.0";
        return String.format("case when %s.%s is null then %s else %s.%s end as %s", viewName, fieldName, defaultValue, viewName, fieldName, fieldName);
    }
}
