package com.geoway.atlas.function.parser.common;

import org.antlr.v4.runtime.Token;
import org.apache.commons.lang3.StringUtils;

import java.math.BigDecimal;
import java.nio.CharBuffer;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Utility methods used by the function parser for string-literal tokens and argument values.
 *
 * <p>String-literal unescaping follows Spark SQL's rules:
 * <ul>
 *   <li>raw literals (prefixed with {@code r}/{@code R}) are returned verbatim without their
 *       quotes;</li>
 *   <li>all other literals have their backslash escape sequences decoded.</li>
 * </ul>
 */
public class FunctionParserUtils {

    /** Matches a 16-bit unicode escape (backslash, 'u', 4 hex digits) at the start of the input. */
    public static final Pattern U16_CHAR_PATTERN = Pattern.compile("\\\\u([a-fA-F0-9]{4})(?s).*");
    /** Matches a 32-bit unicode escape ({@code \U} + 8 hex digits) at the start of the input. */
    public static final Pattern U32_CHAR_PATTERN = Pattern.compile("\\\\U([a-fA-F0-9]{8})(?s).*");
    /** Matches an octal escape (backslash + {@code [01][0-7][0-7]}) at the start of the input. */
    public static final Pattern OCTAL_CHAR_PATTERN = Pattern.compile("\\\\([01][0-7]{2})(?s).*");
    /** Matches a backslash followed by any single character at the start of the input. */
    public static final Pattern ESCAPED_CHAR_PATTERN = Pattern.compile("\\\\((?s).)(?s).*");

    /**
     * Length of the longest escape sequence recognised above ({@code \U} + 8 hex digits).
     *
     * <p>Because every pattern ends in {@code (?s).*}, matching a window of this many characters
     * gives the same result as matching the entire remainder of the literal.
     */
    private static final int MAX_ESCAPE_LENGTH = 10;

    /**
     * Unescape backslash-escaped string enclosed by quotes.
     *
     * <p>Raw literals ({@code r'...'} / {@code R'...'}) are returned without the prefix and quotes
     * and without any unescaping. For all other literals the first and last characters (the
     * quotes) are dropped and escape sequences are decoded, trying in this order:
     * <ol>
     *   <li>16-bit unicode escapes;</li>
     *   <li>32-bit unicode escapes;</li>
     *   <li>octal escapes;</li>
     *   <li>single-character escapes.</li>
     * </ol>
     *
     * @param b the input SQL string, including its enclosing quotes
     * @return the unescaped SQL string
     * @throws IllegalArgumentException if a non-raw {@code b} is too short to hold enclosing quotes
     */
    private static String unescapeSQLString(String b) {
        if (b.startsWith("r") || b.startsWith("R")) {
            return b.substring(2, b.length() - 1);
        }
        if (b.length() < 2) {
            throw new IllegalArgumentException("Not a quoted string literal: " + b);
        }

        // Skip the first and last quotations enclosing the string literal.
        int end = b.length() - 1;
        StringBuilder sb = new StringBuilder(b.length());
        int pos = 1;

        while (pos < end) {
            char c = b.charAt(pos);
            if (c != '\\') {
                // Non-escaped character: none of the patterns can match, so skip the regex work.
                sb.append(c);
                pos++;
                continue;
            }

            // Only a bounded window is handed to the regexes. Copying the whole remainder for
            // every character made the original loop quadratic in the literal length.
            String window = b.substring(pos, Math.min(pos + MAX_ESCAPE_LENGTH, end));
            Matcher u16Matcher = U16_CHAR_PATTERN.matcher(window);
            Matcher u32Matcher = U32_CHAR_PATTERN.matcher(window);
            Matcher octalMatcher = OCTAL_CHAR_PATTERN.matcher(window);
            Matcher escapedMatcher = ESCAPED_CHAR_PATTERN.matcher(window);

            if (u16Matcher.matches()) {
                // 16-bit unicode escape: backslash, 'u', 4 hex digits.
                sb.append((char) Integer.parseInt(u16Matcher.group(1), 16));
                pos += 6;
            } else if (u32Matcher.matches()) {
                // 32-bit unicode escape. Parsed as long so the full unsigned 32-bit range fits.
                long codePoint = Long.parseLong(u32Matcher.group(1), 16);
                if (codePoint < 0x10000) {
                    sb.append((char) (codePoint & 0xFFFF));
                } else {
                    // Encode as a UTF-16 surrogate pair.
                    int highSurrogate = (int) ((codePoint - 0x10000) / 0x400 + 0xD800);
                    int lowSurrogate = (int) ((codePoint - 0x10000) % 0x400 + 0xDC00);
                    sb.append((char) highSurrogate);
                    sb.append((char) lowSurrogate);
                }
                pos += 10;
            } else if (octalMatcher.matches()) {
                // Octal escape: backslash followed by [01][0-7][0-7].
                sb.append((char) Integer.parseInt(octalMatcher.group(1), 8));
                pos += 4;
            } else if (escapedMatcher.matches()) {
                // Single-character escape such as \n or \'.
                appendEscapedChar(sb, escapedMatcher.group(1).charAt(0));
                pos += 2;
            } else {
                // A lone backslash right before the closing quote: keep it as-is.
                sb.append(c);
                pos++;
            }
        }
        return sb.toString();
    }

    /**
     * Appends the character denoted by a single-character escape (the char after the backslash).
     *
     * @param sb the builder to append to
     * @param n  the character that followed the backslash
     */
    private static void appendEscapedChar(StringBuilder sb, char n) {
        switch (n) {
            case '0':
                sb.append('\u0000');
                break;
            case '\'':
                sb.append('\'');
                break;
            case '"':
                sb.append('\"');
                break;
            case 'b':
                sb.append('\b');
                break;
            case 'n':
                sb.append('\n');
                break;
            case 'r':
                sb.append('\r');
                break;
            case 't':
                sb.append('\t');
                break;
            case 'Z':
                // Ctrl+Z (ASCII 26).
                sb.append('\u001A');
                break;
            case '\\':
                sb.append('\\');
                break;
            // Mirrors MySQL: \% and \_ keep their backslash so that, when the string is later
            // used as a LIKE pattern, they still denote literal '%' / '_' rather than wildcards.
            case '%':
                sb.append("\\%");
                break;
            case '_':
                sb.append("\\_");
                break;
            default:
                // Unknown escapes decode to the character itself.
                sb.append(n);
                break;
        }
    }

    /**
     * Returns the unescaped contents of a string-literal token.
     *
     * @param token the literal token, whose text includes its enclosing quotes
     * @return the literal's value with quotes removed and escapes decoded
     */
    public static String string(Token token) {
        return unescapeSQLString(token.getText());
    }

    /**
     * Converts a numeric argument to {@code double}.
     *
     * @param val    the value to convert; any {@link Number} is accepted
     * @param errMsg message to use when {@code val} is not numeric; a default message is used when
     *               this is blank or {@code null}
     * @return {@code val} as a double
     * @throws IllegalArgumentException if {@code val} is {@code null} or not a {@link Number}
     */
    public static double getDouble(Object val, String errMsg) {
        if (val instanceof Number) {
            // Same result as the previous per-type casts for Double/Float/Integer/Long/BigDecimal,
            // and additionally supports Short, Byte, BigInteger, etc.
            return ((Number) val).doubleValue();
        }
        // Default message: "Please check the input parameter; only numeric types are supported!"
        String msg = StringUtils.isNotBlank(errMsg) ? errMsg : "请检查输入参数，仅支持数值类型！";
        throw new IllegalArgumentException(msg);
    }

    /**
     * Renders a value as an expression fragment.
     *
     * <ul>
     *   <li>Strings are wrapped in single quotes.</li>
     *   <li>Other values use {@link Object#toString()}.</li>
     *   <li>{@code null} yields {@code null}.</li>
     * </ul>
     *
     * <p>NOTE(review): single quotes and backslashes inside a String are not escaped. Such a value
     * therefore will not round-trip through {@link #string(Token)}, which decodes backslash
     * escapes. Confirm whether callers pre-escape before changing this.
     *
     * @param obj the value to render, may be {@code null}
     * @return the rendered fragment, or {@code null} when {@code obj} is {@code null}
     */
    public static String getStrExprWithNull(Object obj) {
        if (obj == null) {
            return null;
        }
        if (obj instanceof String) {
            return "'" + obj + "'";
        }
        return obj.toString();
    }
}
