/*
 * Decompiled with CFR 0.152.
 */
package tech.amikos.chromadb.embeddings;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxTensorLike;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URL;
import java.nio.LongBuffer;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.nio.file.attribute.FileAttribute;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.zip.GZIPInputStream;
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.primitives.Floats;
import tech.amikos.chromadb.EFException;
import tech.amikos.chromadb.Embedding;
import tech.amikos.chromadb.embeddings.EmbeddingFunction;

public class DefaultEmbeddingFunction
implements EmbeddingFunction {
    /** Name of the model; also the last segment of the cache directory. */
    public static final String MODEL_NAME = "all-MiniLM-L6-v2";

    /** File name under which the downloaded model archive is stored. */
    private static final String ARCHIVE_FILENAME = "onnx.tar.gz";

    /** Location the model archive is downloaded from. */
    private static final String MODEL_DOWNLOAD_URL = "https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz";

    /** Lower-case hex SHA-256 digest the downloaded archive must match. */
    private static final String MODEL_SHA256_CHECKSUM = "913d7300ceae3b2dbc2c50d1de4baacab4be7b9380491c27fab7418616a16ec3";

    /** Per-user cache directory: {@code <user.home>/.cache/chroma/onnx_models/<MODEL_NAME>}. */
    public static final Path MODEL_CACHE_DIR =
            Paths.get(System.getProperty("user.home"), ".cache", "chroma", "onnx_models", MODEL_NAME);

    /** Directory inside the cache that holds the model and is handed to the tokenizer. */
    private static final Path modelPath = MODEL_CACHE_DIR.resolve("onnx");

    /** The ONNX model file loaded into the inference session. */
    private static final Path modelFile = modelPath.resolve("model.onnx");

    /** Tokenizer created from {@link #modelPath}. */
    private final HuggingFaceTokenizer tokenizer;

    /** ONNX Runtime environment used to create the session and the input tensors. */
    private final OrtEnvironment env;

    /** Inference session for {@link #modelFile}; package-private as in the original declaration. */
    final OrtSession session;

    /**
     * Scales every row of {@code v} to unit Euclidean (L2) length.
     *
     * <p>A row whose norm is exactly zero is divided by {@code 1.0E-12f} instead of
     * zero, so an all-zero row stays all-zero rather than turning into NaN.
     *
     * <p>Each row is processed using its own length, so an empty batch yields an
     * empty result and rows of differing lengths are handled independently.
     *
     * @param v matrix whose rows are the vectors to normalize; not modified
     * @return a newly allocated matrix with the same row lengths as {@code v}
     * @throws NullPointerException if {@code v} or any of its rows is {@code null}
     */
    public static float[][] normalize(float[][] v) {
        float[][] normalized = new float[v.length][];
        for (int row = 0; row < v.length; ++row) {
            float[] source = v[row];

            // Accumulate in float to keep the same rounding as the per-row float sum.
            float sumOfSquares = 0.0f;
            for (float component : source) {
                sumOfSquares += component * component;
            }

            float norm = (float) Math.sqrt(sumOfSquares);
            if (norm == 0.0f) {
                // Avoid division by zero for all-zero rows.
                norm = 1.0E-12f;
            }

            float[] scaled = new float[source.length];
            for (int col = 0; col < source.length; ++col) {
                scaled[col] = source[col] / norm;
            }
            normalized[row] = scaled;
        }
        return normalized;
    }

    /**
     * Creates the embedding function, downloading and unpacking the model first
     * when the model file is not already present in the cache directory.
     *
     * <p>The tokenizer is configured with {@code padding=MAX_LENGTH} and
     * {@code maxLength=256}.
     *
     * @throws EFException if the model cannot be downloaded or unpacked, or if the
     *                     tokenizer or the ONNX Runtime session cannot be created
     */
    public DefaultEmbeddingFunction() throws EFException {
        if (!this.validateModel()) {
            this.downloadAndSetupModel();
        }

        // Plain map instead of double-brace initialisation: no anonymous subclass,
        // no hidden reference to the instance under construction.
        Map<String, String> tokenizerOptions = new HashMap<String, String>();
        tokenizerOptions.put("padding", "MAX_LENGTH");
        tokenizerOptions.put("maxLength", "256");
        Map<String, String> tokenizerConfig = Collections.unmodifiableMap(tokenizerOptions);

        try {
            this.tokenizer = HuggingFaceTokenizer.newInstance(modelPath, tokenizerConfig);
            this.env = OrtEnvironment.getEnvironment();
            // SessionOptions owns a native handle; it is only needed while the
            // session is being created, so release it right afterwards.
            try (OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
                this.session = this.env.createSession(modelFile.toString(), options);
            }
        }
        catch (OrtException | IOException e) {
            throw new EFException(e);
        }
    }

    /**
     * Runs the model over {@code documents} and returns one L2-normalized
     * embedding per document.
     *
     * <p>Steps: tokenize the batch, feed {@code input_ids}, {@code attention_mask}
     * and {@code token_type_ids} to the session, take the first output as a
     * 3-D float array, then sum it along axis 1 weighted by the attention mask,
     * divide by the mask sum (clipped to at least {@code 1.0E-9}) and finally
     * pass the result through {@link #normalize(float[][])}.
     *
     * <p>The input tensors wrap native memory and are closed before returning.
     * An empty {@code documents} list short-circuits to an empty result without
     * invoking the tokenizer or the model.
     *
     * @param documents texts to embed
     * @return one embedding per document, in input order
     * @throws OrtException if tensor creation or inference fails
     */
    public List<List<Float>> forward(List<String> documents) throws OrtException {
        if (documents.isEmpty()) {
            return new ArrayList<List<Float>>();
        }

        Encoding[] encodings = this.tokenizer.batchEncode(documents, true, false);

        int totalTokens = 0;
        int maxIds = 0;
        for (Encoding encoding : encodings) {
            int length = encoding.getIds().length;
            totalTokens += length;
            maxIds = Math.max(maxIds, length);
        }

        // Row-major flattening of the batch into primitive buffers (no boxing).
        // NOTE(review): assumes ids, attention mask and type ids of one Encoding
        // all have the same length - confirm against the tokenizer.
        long[] inputIds = new long[totalTokens];
        long[] attentionMask = new long[totalTokens];
        long[] tokenTypeIds = new long[totalTokens];
        int offset = 0;
        for (Encoding encoding : encodings) {
            long[] ids = encoding.getIds();
            System.arraycopy(ids, 0, inputIds, offset, ids.length);
            System.arraycopy(encoding.getAttentionMask(), 0, attentionMask, offset, ids.length);
            System.arraycopy(encoding.getTypeIds(), 0, tokenTypeIds, offset, ids.length);
            offset += ids.length;
        }

        long[] inputShape = new long[]{encodings.length, maxIds};
        INDArray lastHiddenState;
        try (OnnxTensor inputTensor = OnnxTensor.createTensor(this.env, LongBuffer.wrap(inputIds), inputShape);
             OnnxTensor attentionTensor = OnnxTensor.createTensor(this.env, LongBuffer.wrap(attentionMask), inputShape);
             OnnxTensor tokenTypeTensor = OnnxTensor.createTensor(this.env, LongBuffer.wrap(tokenTypeIds), inputShape)) {
            Map<String, OnnxTensorLike> inputs = new HashMap<String, OnnxTensorLike>();
            inputs.put("input_ids", inputTensor);
            inputs.put("attention_mask", attentionTensor);
            inputs.put("token_type_ids", tokenTypeTensor);
            try (OrtSession.Result results = this.session.run(inputs)) {
                // Copy the output into ND4J before the result is closed.
                lastHiddenState = Nd4j.create((float[][][]) results.get(0).getValue());
            }
        }

        // Mask-weighted mean over axis 1.
        INDArray attMask = Nd4j.create(Arrays.stream(attentionMask).asDoubleStream().toArray(), inputShape, 'c');
        INDArray expandedMask = Nd4j.expandDims(attMask, 2).broadcast(lastHiddenState.shape());
        INDArray summed = lastHiddenState.mul(expandedMask).sum(new int[]{1});
        // Clip the denominator so a fully masked row cannot divide by zero.
        INDArray[] clippedSumMask = Nd4j.getExecutioner().exec(
                (CustomOp) new ClipByValue(expandedMask.sum(new int[]{1}), 1.0E-9, Double.MAX_VALUE));
        INDArray embeddings = summed.div(clippedSumMask[0]);

        float[][] embeddingsArray = DefaultEmbeddingFunction.normalize(embeddings.toFloatMatrix());
        List<List<Float>> embeddingsList = new ArrayList<List<Float>>(embeddingsArray.length);
        for (float[] embedding : embeddingsArray) {
            embeddingsList.add(Floats.asList(embedding));
        }
        return embeddingsList;
    }

    /**
     * Computes the SHA-256 digest of a file.
     *
     * @param filePath path of the file to hash
     * @return the digest as 64 lower-case hexadecimal characters
     * @throws IOException              if the file cannot be opened or read
     * @throws NoSuchAlgorithmException if SHA-256 is not available
     */
    private static String getSHA256Checksum(String filePath) throws IOException, NoSuchAlgorithmException {
        MessageDigest sha256 = MessageDigest.getInstance("SHA-256");
        byte[] chunk = new byte[1024];
        try (FileInputStream input = new FileInputStream(filePath)) {
            for (int read = input.read(chunk); read != -1; read = input.read(chunk)) {
                sha256.update(chunk, 0, read);
            }
        }

        byte[] hash = sha256.digest();
        StringBuilder hex = new StringBuilder(hash.length * 2);
        for (byte value : hash) {
            // High nibble first, then low nibble; forDigit yields lower-case digits.
            hex.append(Character.forDigit((value >> 4) & 0xF, 16));
            hex.append(Character.forDigit(value & 0xF, 16));
        }
        return hex.toString();
    }

    /**
     * Unpacks a gzip-compressed tar archive into {@code extractDir}.
     *
     * <p>Directory entries are created as directories; every other entry is
     * written as a file, creating parent directories as needed and overwriting
     * an existing file of the same name.
     *
     * <p>Each entry name is resolved against the extraction directory and
     * rejected when the normalized result lies outside it, so a crafted archive
     * (for example one containing {@code ../} segments or absolute names) cannot
     * write elsewhere on disk.
     *
     * @param tarGzPath  the {@code .tar.gz} archive to read
     * @param extractDir directory the entries are extracted into
     * @throws IOException if reading the archive or writing an entry fails, or if
     *                     an entry would be written outside {@code extractDir}
     */
    private static void extractTarGz(Path tarGzPath, Path extractDir) throws IOException {
        Path root = extractDir.toAbsolutePath().normalize();
        try (InputStream fileIn = Files.newInputStream(tarGzPath);
             GZIPInputStream gzipIn = new GZIPInputStream(fileIn);
             TarArchiveInputStream tarIn = new TarArchiveInputStream(gzipIn)) {
            byte[] buffer = new byte[8192];
            TarArchiveEntry entry;
            while ((entry = tarIn.getNextTarEntry()) != null) {
                Path entryPath = root.resolve(entry.getName()).normalize();
                if (!entryPath.startsWith(root)) {
                    throw new IOException("Archive entry is outside of the target directory: " + entry.getName());
                }

                if (entry.isDirectory()) {
                    Files.createDirectories(entryPath);
                    continue;
                }

                Path parent = entryPath.getParent();
                if (parent != null) {
                    Files.createDirectories(parent);
                }
                // The tar stream reports end-of-stream at the end of the current
                // entry, so this loop copies exactly one entry's content.
                try (OutputStream out = Files.newOutputStream(entryPath)) {
                    int len;
                    while ((len = tarIn.read(buffer)) != -1) {
                        out.write(buffer, 0, len);
                    }
                }
            }
        }
    }

    /**
     * Ensures the model archive is present in the cache directory, verifies its
     * SHA-256 checksum, extracts it and removes the archive afterwards.
     *
     * <p>The network connection is opened only when the archive is not already
     * cached. The download goes to a temporary file in the cache directory and is
     * moved to its final name once complete, so an interrupted download does not
     * leave a partial file under the archive's name.
     *
     * @throws EFException      if downloading, hashing or extracting fails
     * @throws RuntimeException if the archive's checksum does not match the expected value
     */
    private void downloadAndSetupModel() throws EFException {
        try {
            if (!Files.exists(MODEL_CACHE_DIR)) {
                Files.createDirectories(MODEL_CACHE_DIR);
            }

            Path archivePath = MODEL_CACHE_DIR.resolve(ARCHIVE_FILENAME);
            if (!Files.exists(archivePath)) {
                System.err.println("Model not found under " + archivePath + ". Downloading...");
                Path partial = Files.createTempFile(MODEL_CACHE_DIR, ARCHIVE_FILENAME, ".part");
                try {
                    try (InputStream in = new URL(MODEL_DOWNLOAD_URL).openStream()) {
                        Files.copy(in, partial, StandardCopyOption.REPLACE_EXISTING);
                    }
                    Files.move(partial, archivePath, StandardCopyOption.REPLACE_EXISTING);
                }
                finally {
                    // No-op after a successful move; cleans up after a failed download.
                    Files.deleteIfExists(partial);
                }
            }

            if (!MODEL_SHA256_CHECKSUM.equals(DefaultEmbeddingFunction.getSHA256Checksum(archivePath.toString()))) {
                throw new RuntimeException("Checksum does not match. Delete the whole directory " + MODEL_CACHE_DIR + " and try again.");
            }

            DefaultEmbeddingFunction.extractTarGz(archivePath, MODEL_CACHE_DIR);

            // Removing the archive is best effort: the model is already extracted,
            // so a failure here is reported but must not fail the setup.
            if (!archivePath.toFile().delete()) {
                System.err.println("Could not delete model archive " + archivePath);
            }
        }
        catch (IOException | NoSuchAlgorithmException e) {
            throw new EFException(e);
        }
    }

    /**
     * Reports whether the model file is already available locally.
     *
     * @return {@code true} if {@code modelFile} exists and is a normal file
     */
    private boolean validateModel() {
        java.io.File candidate = modelFile.toFile();
        if (!candidate.exists()) {
            return false;
        }
        return candidate.isFile();
    }

    /**
     * Embeds a single query string.
     *
     * @param query the text to embed
     * @return the embedding built from the first (only) row returned by {@link #forward(List)}
     * @throws EFException if inference fails with an {@link OrtException}
     */
    @Override
    public Embedding embedQuery(String query) throws EFException {
        try {
            List<String> singleDocument = Collections.singletonList(query);
            List<List<Float>> vectors = this.forward(singleDocument);
            List<Float> first = vectors.get(0);
            return Embedding.fromList(first);
        }
        catch (OrtException cause) {
            throw new EFException(cause);
        }
    }

    /**
     * Embeds a batch of documents.
     *
     * @param documents the texts to embed
     * @return one {@link Embedding} per document, in the order returned by {@link #forward(List)}
     * @throws EFException if inference fails with an {@link OrtException}
     */
    @Override
    public List<Embedding> embedDocuments(List<String> documents) throws EFException {
        try {
            List<List<Float>> vectors = this.forward(documents);
            List<Embedding> embeddings = new ArrayList<Embedding>();
            for (List<Float> vector : vectors) {
                embeddings.add(new Embedding(vector));
            }
            return embeddings;
        }
        catch (OrtException cause) {
            throw new EFException(cause);
        }
    }

    /**
     * Array convenience overload: views the array as a list and delegates to
     * {@link #embedDocuments(List)}.
     *
     * @param documents the texts to embed
     * @return one {@link Embedding} per document
     * @throws EFException if the delegate fails
     */
    @Override
    public List<Embedding> embedDocuments(String[] documents) throws EFException {
        List<String> asList = Arrays.asList(documents);
        return this.embedDocuments(asList);
    }
}

