package com.geoway.atlas.function.parser.common;

import com.geoway.atlas.function.parser.FunctionDSLBaseBaseVisitor;
import com.geoway.atlas.function.parser.FunctionDSLBaseParser;
import org.antlr.v4.runtime.Token;
import org.apache.commons.lang3.StringUtils;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;

/**
 * @author zhaotong 2024/12/19 15:55
 */
public class FunctionDSLAstBuilder extends FunctionDSLBaseBaseVisitor<Object> {


    @Override
    public FunctionPlan visitSingleStatement(FunctionDSLBaseParser.SingleStatementContext ctx) {
        return (FunctionPlan) visit(ctx.functionExpression());
    }

    @Override
    public QualifiedName visitQualifiedNameVal(FunctionDSLBaseParser.QualifiedNameValContext ctx) {
        QualifiedName qualifiedName = new QualifiedName();
        for(FunctionDSLBaseParser.IdentifierContext identifierContext: ctx.qualifiedName().identifier()){
            qualifiedName.add(((String) visit(identifierContext)).toLowerCase());
        }
        return qualifiedName;
    }

    @Override
    public FunctionPlan visitFunctionCall(FunctionDSLBaseParser.FunctionCallContext ctx) {
        String functionName = (String) visit(ctx.functionName());
        Object[] args = new Object[ctx.argument.size()];
        for(int i = 0; i < ctx.argument.size(); i++){
            args[i] = visit(ctx.argument.get(i));
        }
        return FunctionPlanFactory.getFunctionPlan(functionName, args);
    }

    @Override
    public String visitFunctionName(FunctionDSLBaseParser.FunctionNameContext ctx) {
        if(ctx.identFunc != null && StringUtils.isNotBlank(ctx.identFunc.getText())){
            return ctx.identFunc.getText();
        }else {
            return String.join(".", visitQualifiedName(ctx.qualifiedName()));
        }
    }

    @Override
    public QualifiedName visitQualifiedName(FunctionDSLBaseParser.QualifiedNameContext ctx) {
        QualifiedName qualifiedName = new QualifiedName();
        for(FunctionDSLBaseParser.IdentifierContext identifierContext: ctx.identifier()){
            qualifiedName.add((String) visit(identifierContext));
        }
        return qualifiedName;
    }

    @Override
    public String visitUnquotedIdentifier(FunctionDSLBaseParser.UnquotedIdentifierContext ctx) {
        if (ctx != null) {
            if (ctx.IDENTIFIER() != null) {
                return ctx.IDENTIFIER().getSymbol().getText();
            }
        }

        return null;
    }

    @Override
    public Object visitQuotedIdentifierAlternative(FunctionDSLBaseParser.QuotedIdentifierAlternativeContext ctx) {
        return visit(ctx.quotedIdentifier());
    }

    @Override
    public String visitQuotedIdentifier(FunctionDSLBaseParser.QuotedIdentifierContext ctx) {
        if (ctx != null) {
            if (ctx.BACKQUOTED_IDENTIFIER() != null) {
                return ctx.BACKQUOTED_IDENTIFIER().getSymbol().getText();
            }
        }

        return null;
    }

    @Override
    public Object visitNullLiteral(FunctionDSLBaseParser.NullLiteralContext ctx) {
        return null;
    }

    @Override
    public String visitStringLiteral(FunctionDSLBaseParser.StringLiteralContext ctx) {
        return createString(ctx);
    }

    private String createString(FunctionDSLBaseParser.StringLiteralContext ctx) {
        return ctx.stringLit().stream()
                .map(this::visitStringLit)
                .map(FunctionParserUtils::string)
                .collect(Collectors.joining());
    }

    @Override
    public Token visitStringLit(FunctionDSLBaseParser.StringLitContext ctx) {
        if (ctx != null) {
            if (ctx.STRING_LITERAL() != null) {
                return ctx.STRING_LITERAL().getSymbol();
            }
        }

        return null;
    }

    @Override
    public Boolean visitBooleanValue(FunctionDSLBaseParser.BooleanValueContext ctx) {
        return Boolean.valueOf(ctx.getText());
    }

    @Override
    public Double visitExponentLiteral(FunctionDSLBaseParser.ExponentLiteralContext ctx) {
        return Double.valueOf(numericLiteral(ctx, ctx.getText(), /* exponent values don't have a suffix */
                new BigDecimal(Double.MIN_VALUE), new BigDecimal(Double.MAX_VALUE), "double"));
    }

    @Override
    public BigDecimal visitDecimalLiteral(FunctionDSLBaseParser.DecimalLiteralContext ctx) {
        return new BigDecimal(ctx.getText());
    }

    @Override
    public Object visitIntegerLiteral(FunctionDSLBaseParser.IntegerLiteralContext ctx) {
        BigDecimal bd = new BigDecimal(ctx.getText());
        if(bd.compareTo(new BigDecimal(Integer.MIN_VALUE)) >= 0 &&
                bd.compareTo(new BigDecimal(Integer.MAX_VALUE)) <=0){
            return (Integer) bd.intValue();
        }else {
            return (Long) bd.longValue();
        }
    }

    @Override
    public Long visitBigIntLiteral(FunctionDSLBaseParser.BigIntLiteralContext ctx) {
        String rawStrippedQualifier = ctx.getText().substring(0, ctx.getText().length() - 1);
        return Long.valueOf(numericLiteral(ctx, rawStrippedQualifier,
                new BigDecimal(Long.MIN_VALUE),
                new BigDecimal(Long.MAX_VALUE),
                "long"));
    }

    @Override
    public Double visitDoubleLiteral(FunctionDSLBaseParser.DoubleLiteralContext ctx) {
        String rawStrippedQualifier = ctx.getText().substring(0, ctx.getText().length() - 1);
        return Double.valueOf(numericLiteral(ctx, rawStrippedQualifier,
                new BigDecimal(Double.MIN_VALUE),
                new BigDecimal(Double.MAX_VALUE),
                "double"));
    }

    @Override
    public Float visitFloatLiteral(FunctionDSLBaseParser.FloatLiteralContext ctx) {
        String rawStrippedQualifier = ctx.getText().substring(0, ctx.getText().length() - 1);
        return Float.valueOf(numericLiteral(ctx, rawStrippedQualifier,
                new BigDecimal(Float.MIN_VALUE),
                new BigDecimal(Float.MAX_VALUE),
                "float"));
    }


    /** Create a numeric literal expression. */
    private String numericLiteral(
            FunctionDSLBaseParser.NumberContext ctx,
            String rawStrippedQualifier,
            BigDecimal minValue,
            BigDecimal maxValue,
            String typeName){
        BigDecimal rawBigDecimal = new BigDecimal(rawStrippedQualifier);
        if (rawBigDecimal.compareTo(minValue) < 0 || rawBigDecimal.compareTo(maxValue) > 0) {
            throw new RuntimeException(String.format("值 %s 越界，必须在数据类型 %s 的值域[%s, %s]范围内",
                    rawStrippedQualifier, typeName,minValue.toString(), maxValue.toString()));
        }
        return rawStrippedQualifier;
    }
}
