/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.internal.memory;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import lombok.NonNull;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.autodiff.samediff.internal.memory.AbstractMemoryMgr;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Memory manager that caches released arrays and hands their buffers back out on later allocations.
 *
 * <p>Released arrays are grouped by {@link DataType} into {@link ArrayStore}s, each kept sorted by
 * buffer length. An allocation reuses a cached buffer of exactly the requested length, or the next
 * larger buffer unless that buffer is both above {@code smallArrayThreshold} bytes and more than
 * {@code largerArrayMaxMultiple} times the requested length. The total cached size is bounded by
 * {@code maxMemFrac * totalMemBytes}; when a release would exceed that bound, cached arrays are
 * evicted (and closed) in the order they were released, oldest first.
 *
 * <p>This class performs no synchronization.
 */
public class ArrayCacheMemoryMgr
extends AbstractMemoryMgr {
    /** Fraction of total memory that the cache may occupy. */
    private final double maxMemFrac;
    /** Cached arrays at or below this many bytes may be reused for any smaller request. */
    private final long smallArrayThreshold;
    /** Max ratio (cached length / requested length) for reusing a larger, non-small cached buffer. */
    private final double largerArrayMaxMultiple;
    /** Upper bound on the total bytes held in the cache. */
    private final long maxCacheBytes;
    /** Total memory used as the basis for {@link #maxCacheBytes}. */
    private final long totalMemBytes;
    /** Bytes currently held in the cache (sum over all ArrayStores). */
    private long currentCacheSize = 0L;
    /** One store per data type. */
    private final Map<DataType, ArrayStore> arrayStores = new HashMap<>();
    /** Ids of cached arrays in release order; the first element is evicted first. */
    private final LinkedHashSet<Long> lruCache = new LinkedHashSet<>();
    /** Maps the ids in {@link #lruCache} to the cached arrays. */
    private final Map<Long, INDArray> lruCacheValues = new HashMap<>();

    /** Creates a cache using 25% of memory, a 1024-byte small-array threshold and a 2.0x reuse multiple. */
    public ArrayCacheMemoryMgr() {
        this(0.25, 1024L, 2.0);
    }

    /**
     * @param maxMemFrac             fraction of total memory the cache may use; must be in (0, 1)
     * @param smallArrayThreshold    cached arrays of at most this many bytes may be reused for any
     *                               smaller request; must be >= 0
     * @param largerArrayMaxMultiple a larger-than-requested cached buffer above the small-array
     *                               threshold is only reused if its length is at most this multiple of
     *                               the requested length; must be >= 1.0
     */
    public ArrayCacheMemoryMgr(double maxMemFrac, long smallArrayThreshold, double largerArrayMaxMultiple) {
        Preconditions.checkArgument(maxMemFrac > 0.0 && maxMemFrac < 1.0, "Maximum memory fraction for cache must be between 0.0 and 1.0, got %s", maxMemFrac);
        Preconditions.checkArgument(smallArrayThreshold >= 0L, "Small array threshold must be >= 0, got %s", smallArrayThreshold);
        Preconditions.checkArgument(largerArrayMaxMultiple >= 1.0, "Larger array max multiple must be >= 1.0, got %s", largerArrayMaxMultiple);
        this.maxMemFrac = maxMemFrac;
        this.smallArrayThreshold = smallArrayThreshold;
        this.largerArrayMaxMultiple = largerArrayMaxMultiple;
        if (this.isCpu()) {
            this.totalMemBytes = Pointer.maxBytes();
        } else {
            // Uses the total memory reported for the first CUDA device
            Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
            List<?> devList = (List<?>) p.get("cuda.devicesInformation");
            Map<?, ?> m = (Map<?, ?>) devList.get(0);
            this.totalMemBytes = (Long) m.get("cuda.totalMemory");
        }
        this.maxCacheBytes = (long) (maxMemFrac * (double) this.totalMemBytes);
    }

    /** Returns true unless the executioner reports the "CUDA" backend. */
    private boolean isCpu() {
        String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
        return !"CUDA".equalsIgnoreCase(backend);
    }

    /**
     * Removes a suitable array from the cache for the given type/shape and deducts its buffer size
     * from {@link #currentCacheSize}.
     *
     * @return the reused array, or null if no cached buffer fits
     */
    private INDArray takeFromCache(DataType dataType, long[] shape) {
        ArrayStore store = this.arrayStores.get(dataType);
        if (store == null) {
            return null;
        }
        INDArray arr = store.get(shape);
        if (arr != null) {
            this.currentCacheSize -= (long) dataType.width() * arr.data().length();
        }
        return arr;
    }

    @Override
    public INDArray allocate(boolean detached, DataType dataType, long ... shape) {
        INDArray cached = this.takeFromCache(dataType, shape);
        if (cached != null) {
            return cached;
        }
        return Nd4j.createUninitializedDetached(dataType, shape);
    }

    @Override
    public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
        if (descriptor.isEmpty()) {
            INDArray ret = Nd4j.create(descriptor);
            if (detached) {
                ret = ret.detach();
            }
            return ret;
        }
        DataType dataType = descriptor.dataType();
        long[] shape = descriptor.getShape();
        char order = descriptor.getOrder();
        INDArray cached = this.takeFromCache(dataType, shape);
        if (cached != null) {
            if (cached.ordering() != order) {
                cached.setOrder(order);
            }
            return cached;
        }
        // Honour the requested ordering for freshly allocated arrays too, not only for cached ones
        return Nd4j.createUninitializedDetached(dataType, order, shape);
    }

    /**
     * Returns an array to the cache, or closes it if it cannot be cached.
     *
     * <p>Arrays without a buffer, UTF8 arrays and arrays larger than the whole cache are closed
     * (when closeable) instead of being cached. Otherwise the oldest cached arrays are evicted until
     * this one fits.
     *
     * @throws NullPointerException  if {@code array} is null
     * @throws IllegalStateException if the same array is released twice without being reallocated
     */
    @Override
    public void release(@NonNull INDArray array) {
        if (array == null) {
            throw new NullPointerException("array is marked non-null but is null");
        }
        long id = array.getId();
        Preconditions.checkState(!this.lruCache.contains(id), "Array was released multiple times: id=%s, shape=%ndShape", id, array);
        if (array.data() == null) {
            // Nothing to cache without a buffer
            if (array.closeable()) {
                array.close();
            }
            return;
        }
        DataType dt = array.dataType();
        if (dt == DataType.UTF8) {
            // String arrays have variable-length buffers and are never cached; they must also stay
            // out of the LRU bookkeeping, otherwise a later eviction would hit a closed array.
            if (array.closeable()) {
                array.close();
            }
            return;
        }
        long thisBytes = array.data().length() * (long) dt.width();
        if (thisBytes > this.maxCacheBytes) {
            // Would not fit even with an empty cache
            if (array.closeable()) {
                array.close();
            }
            return;
        }
        this.evictUntilFits(thisBytes);
        this.cacheArray(array);
    }

    /** Evicts and closes cached arrays, oldest release first, until {@code requiredBytes} more fit. */
    private void evictUntilFits(long requiredBytes) {
        Iterator<Long> iter = this.lruCache.iterator();
        while (this.currentCacheSize + requiredBytes > this.maxCacheBytes && iter.hasNext()) {
            long oldestId = iter.next();
            iter.remove();
            INDArray oldest = this.lruCacheValues.remove(oldestId);
            if (oldest == null) {
                continue;
            }
            DataType odt = oldest.dataType();
            long oldestBytes = (long) odt.width() * oldest.data().length();
            this.arrayStores.get(odt).removeObject(oldest);
            this.currentCacheSize -= oldestBytes;
            if (oldest.closeable()) {
                oldest.close();
            }
        }
    }

    /** Adds the array to its data-type store and to the LRU bookkeeping, and counts its bytes. */
    private void cacheArray(INDArray array) {
        DataType dt = array.dataType();
        this.arrayStores.computeIfAbsent(dt, k -> new ArrayStore()).add(array);
        this.currentCacheSize += array.data().length() * (long) dt.width();
        this.lruCache.add(array.getId());
        this.lruCacheValues.put(array.getId(), array);
    }

    /** Closes every cached array and resets all cache bookkeeping. */
    @Override
    public void close() {
        for (ArrayStore store : this.arrayStores.values()) {
            store.close();
        }
        this.lruCache.clear();
        this.lruCacheValues.clear();
        this.currentCacheSize = 0L;
    }

    public double getMaxMemFrac() {
        return this.maxMemFrac;
    }

    public long getSmallArrayThreshold() {
        return this.smallArrayThreshold;
    }

    public double getLargerArrayMaxMultiple() {
        return this.largerArrayMaxMultiple;
    }

    public long getMaxCacheBytes() {
        return this.maxCacheBytes;
    }

    public long getTotalMemBytes() {
        return this.totalMemBytes;
    }

    public long getCurrentCacheSize() {
        return this.currentCacheSize;
    }

    public Map<DataType, ArrayStore> getArrayStores() {
        return this.arrayStores;
    }

    public LinkedHashSet<Long> getLruCache() {
        return this.lruCache;
    }

    public Map<Long, INDArray> getLruCacheValues() {
        return this.lruCacheValues;
    }

    /**
     * Cached arrays of a single data type, kept sorted ascending by buffer length so that the
     * smallest buffer at least as long as a request can be found by binary search.
     */
    public class ArrayStore {
        /** Cached arrays; only indices [0, size) are valid. */
        private INDArray[] sorted = new INDArray[1000];
        /** Buffer lengths parallel to {@link #sorted}; ascending over [0, size). */
        private long[] lengths = new long[1000];
        private long lengthSum;
        private long bytesSum;
        private int size;

        /** Inserts the array at its sorted position, growing the backing arrays when full. */
        private void add(@NonNull INDArray array) {
            if (array == null) {
                throw new NullPointerException("array is marked non-null but is null");
            }
            if (this.size == this.sorted.length) {
                this.sorted = Arrays.copyOf(this.sorted, 2 * this.sorted.length);
                this.lengths = Arrays.copyOf(this.lengths, 2 * this.lengths.length);
            }
            long length = array.data().length();
            int idx = Arrays.binarySearch(this.lengths, 0, this.size, length);
            if (idx < 0) {
                idx = -idx - 1;
            }
            int tail = this.size - idx;
            System.arraycopy(this.sorted, idx, this.sorted, idx + 1, tail);
            System.arraycopy(this.lengths, idx, this.lengths, idx + 1, tail);
            this.sorted[idx] = array;
            this.lengths[idx] = length;
            ++this.size;
            this.lengthSum += length;
            this.bytesSum += length * (long) array.dataType().width();
        }

        /**
         * Removes and returns a cached buffer wrapped in a new array of the requested shape.
         *
         * <p>An exact length match is preferred. Otherwise the next larger buffer is used, unless it is
         * above the small-array threshold and longer than {@code largerArrayMaxMultiple} times the
         * requested length.
         *
         * @return a new array over the cached buffer, or null if nothing suitable is cached
         */
        private INDArray get(long[] shape) {
            if (this.size == 0) {
                return null;
            }
            long length = shape.length == 0 ? 1L : ArrayUtil.prodLong(shape);
            if (length <= 0L) {
                // Zero-length requests never reuse a real buffer
                return null;
            }
            int idx = Arrays.binarySearch(this.lengths, 0, this.size, length);
            if (idx < 0) {
                idx = -idx - 1;
                if (idx >= this.size) {
                    // Every cached buffer is smaller than requested
                    return null;
                }
                INDArray candidate = this.sorted[idx];
                long candidateLength = candidate.data().length();
                long candidateBytes = candidateLength * (long) candidate.dataType().width();
                boolean tooLarge = candidateLength > (long) ((double) length * ArrayCacheMemoryMgr.this.largerArrayMaxMultiple);
                if (candidateBytes > ArrayCacheMemoryMgr.this.smallArrayThreshold && tooLarge) {
                    return null;
                }
            }
            INDArray arr = this.removeIdx(idx);
            ArrayCacheMemoryMgr.this.lruCache.remove(arr.getId());
            ArrayCacheMemoryMgr.this.lruCacheValues.remove(arr.getId());
            // A fresh INDArray over the cached buffer, reshaped to the requested shape
            return Nd4j.create(arr.data(), shape);
        }

        /**
         * Removes exactly this array instance (compared by identity) from the store.
         *
         * @throws IllegalStateException if the array is not present
         */
        private void removeObject(INDArray array) {
            long length = array.data().length();
            int idx = Arrays.binarySearch(this.lengths, 0, this.size, length);
            Preconditions.checkState(idx >= 0, "Cannot remove array from ArrayStore: no array with this length exists in the cache");
            // binarySearch lands somewhere inside the run of equal lengths; rewind to its start
            int start = idx;
            while (start > 0 && this.lengths[start - 1] == length) {
                --start;
            }
            int found = -1;
            for (int i = start; i < this.size && this.lengths[i] == length; ++i) {
                if (this.sorted[i] == array) {
                    found = i;
                    break;
                }
            }
            Preconditions.checkState(found >= 0, "Cannot remove array: not found in ArrayCache");
            this.removeIdx(found);
        }

        /** Removes the entry at {@code idx}, shifting later entries left, and returns it. */
        private INDArray removeIdx(int idx) {
            INDArray arr = this.sorted[idx];
            int tail = this.size - idx - 1;
            if (tail > 0) {
                System.arraycopy(this.sorted, idx + 1, this.sorted, idx, tail);
                System.arraycopy(this.lengths, idx + 1, this.lengths, idx, tail);
            }
            --this.size;
            // Clear the now-unused last slot (stays within bounds even when the store was full)
            this.sorted[this.size] = null;
            this.lengths[this.size] = 0L;
            long length = arr.data().length();
            this.bytesSum -= length * (long) arr.dataType().width();
            this.lengthSum -= length;
            return arr;
        }

        /** Closes all closeable cached arrays and drops every reference held by this store. */
        private void close() {
            for (int i = 0; i < this.size; ++i) {
                if (this.sorted[i].closeable()) {
                    this.sorted[i].close();
                }
                this.sorted[i] = null;
                this.lengths[i] = 0L;
            }
            this.lengthSum = 0L;
            this.bytesSum = 0L;
            this.size = 0;
        }

        public INDArray[] getSorted() {
            return this.sorted;
        }

        public long[] getLengths() {
            return this.lengths;
        }

        public long getLengthSum() {
            return this.lengthSum;
        }

        public long getBytesSum() {
            return this.bytesSum;
        }

        public int getSize() {
            return this.size;
        }
    }
}

