/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.test.checkpointing;

import java.io.IOException;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.state.CheckpointListener;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.configuration.TaskManagerOptions;
import org.apache.flink.contrib.streaming.state.RocksDBStateBackend;
import org.apache.flink.runtime.state.AbstractStateBackend;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.runtime.state.filesystem.FsStateBackend;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.test.util.MiniClusterWithClientResource;
import org.apache.flink.util.TestLogger;
import org.junit.Assert;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

public class KeyedStateCheckpointingITCase
extends TestLogger {
    /** Maximum state size handed to the memory state backends: 10 MiB. */
    protected static final int MAX_MEM_STATE_SIZE = 10 * 1024 * 1024;
    /** Total number of records produced by the job (two sources emitting half of it each). */
    protected static final int NUM_STRINGS = 10_000;
    /** Number of distinct keys the generated integers are reduced to. */
    protected static final int NUM_KEYS = 40;
    /** Number of TaskManagers started by the mini cluster. */
    protected static final int NUM_TASK_MANAGERS = 2;
    /** Number of slots per TaskManager in the mini cluster. */
    protected static final int NUM_TASK_SLOTS = 2;
    /** Job parallelism: exactly the total number of slots of the mini cluster. */
    protected static final int PARALLELISM = NUM_TASK_MANAGERS * NUM_TASK_SLOTS;

    @ClassRule
    public static final MiniClusterWithClientResource MINI_CLUSTER_RESOURCE =
            new MiniClusterWithClientResource(
                    new MiniClusterResourceConfiguration.Builder()
                            .setConfiguration(getConfiguration())
                            .setNumberTaskManagers(NUM_TASK_MANAGERS)
                            .setNumberSlotsPerTaskManager(NUM_TASK_SLOTS)
                            .build());

    /** Per-test scratch directory for file-system and RocksDB backends. */
    @Rule
    public final TemporaryFolder tmpFolder = new TemporaryFolder();

    /**
     * Builds the configuration for the shared mini cluster.
     *
     * @return a configuration whose only entry sets {@code TaskManagerOptions.MANAGED_MEMORY_SIZE}
     *     to {@code 12m}
     */
    private static Configuration getConfiguration() {
        final MemorySize managedMemory = MemorySize.parse("12m");
        final Configuration clusterConfig = new Configuration();
        clusterConfig.set(TaskManagerOptions.MANAGED_MEMORY_SIZE, managedMemory);
        return clusterConfig;
    }

    /** Runs the job on a memory state backend constructed with its boolean flag set to {@code false} (sync). */
    @Test
    public void testWithMemoryBackendSync() throws Exception {
        testProgramWithBackend(new MemoryStateBackend(MAX_MEM_STATE_SIZE, false));
    }

    /** Runs the job on a memory state backend constructed with its boolean flag set to {@code true} (async). */
    @Test
    public void testWithMemoryBackendAsync() throws Exception {
        testProgramWithBackend(new MemoryStateBackend(MAX_MEM_STATE_SIZE, true));
    }

    /** Runs the job on a file-system state backend in a fresh temp folder, boolean flag {@code false} (sync). */
    @Test
    public void testWithFsBackendSync() throws Exception {
        final String checkpointUri = tmpFolder.newFolder().toURI().toString();
        testProgramWithBackend(new FsStateBackend(checkpointUri, false));
    }

    /** Runs the job on a file-system state backend in a fresh temp folder, boolean flag {@code true} (async). */
    @Test
    public void testWithFsBackendAsync() throws Exception {
        final String checkpointUri = tmpFolder.newFolder().toURI().toString();
        testProgramWithBackend(new FsStateBackend(checkpointUri, true));
    }

    /**
     * Runs the job on RocksDB, checkpointing through a memory backend, with the boolean flag set to
     * {@code false} (full checkpoints, per the test name). The local DB files live in a fresh temp
     * folder.
     */
    @Test
    public void testWithRocksDbBackendFull() throws Exception {
        final RocksDBStateBackend rocksDb =
                new RocksDBStateBackend(new MemoryStateBackend(MAX_MEM_STATE_SIZE), false);
        final String dbPath = tmpFolder.newFolder().getAbsolutePath();
        rocksDb.setDbStoragePath(dbPath);
        testProgramWithBackend(rocksDb);
    }

    /**
     * Runs the job on RocksDB, checkpointing through a memory backend, with the boolean flag set to
     * {@code true} (incremental checkpoints, per the test name). The local DB files live in a fresh
     * temp folder.
     */
    @Test
    public void testWithRocksDbBackendIncremental() throws Exception {
        final RocksDBStateBackend rocksDb =
                new RocksDBStateBackend(new MemoryStateBackend(MAX_MEM_STATE_SIZE), true);
        final String dbPath = tmpFolder.newFolder().getAbsolutePath();
        rocksDb.setDbStoragePath(dbPath);
        testProgramWithBackend(rocksDb);
    }

    /**
     * Runs the keyed-state job against {@code stateBackend} and verifies the recovered results.
     *
     * <p>Two sources each emit {@code NUM_STRINGS / 2} integers reduced modulo {@code NUM_KEYS}.
     * They are keyed by value and summed by a map that throws exactly once, which forces one
     * restart from a checkpoint. The sums are then keyed again and counted by a sink. After the job
     * finishes, every key must have its exact expected sum and count. Any record lost or duplicated
     * across the recovery would break these expectations.
     *
     * @param stateBackend the state backend under test
     * @throws Exception if the job fails or a result assertion does not hold
     */
    protected void testProgramWithBackend(AbstractStateBackend stateBackend) throws Exception {
        // Every key must occur equally often in each source, otherwise the expectations below are off.
        Assert.assertEquals("Broken test setup", 0, (NUM_STRINGS / 2) % NUM_KEYS);

        // The result maps are static and shared by all test methods of this class. Clear them so
        // that entries from an earlier run cannot hide keys missing from this one.
        CounterSink.ALL_COUNTS.clear();
        OnceFailingPartitionedSum.ALL_SUMS.clear();

        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(PARALLELISM);
        env.enableCheckpointing(500L);
        env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 0L));
        env.setStateBackend(stateBackend);

        // Fail between 60% and 80% of the average number of records per map subtask
        // (NUM_STRINGS / PARALLELISM). Integer arithmetic keeps the bounds exact.
        final int failurePosMin = 3 * NUM_STRINGS / (5 * PARALLELISM);
        final int failurePosMax = 4 * NUM_STRINGS / (5 * PARALLELISM);
        final int failurePos = new Random().nextInt(failurePosMax - failurePosMin) + failurePosMin;

        DataStream<Integer> stream1 =
                env.addSource(new IntGeneratingSourceFunction(NUM_STRINGS / 2, NUM_STRINGS / 4));
        DataStream<Integer> stream2 =
                env.addSource(new IntGeneratingSourceFunction(NUM_STRINGS / 2, NUM_STRINGS / 4));

        stream1.union(stream2)
                .keyBy(new IdentityKeySelector<Integer>())
                .map(new OnceFailingPartitionedSum(failurePos))
                .keyBy(value -> value.f0)
                .addSink(new CounterSink());

        env.execute();

        Assert.assertEquals(NUM_KEYS, CounterSink.ALL_COUNTS.size());
        Assert.assertEquals(NUM_KEYS, OnceFailingPartitionedSum.ALL_SUMS.size());

        // Key k occurs NUM_STRINGS / NUM_KEYS times in total, so its final sum is k times that.
        for (Map.Entry<Integer, Long> sum : OnceFailingPartitionedSum.ALL_SUMS.entrySet()) {
            final long key = sum.getKey();
            Assert.assertEquals(key * NUM_STRINGS / NUM_KEYS, sum.getValue().longValue());
        }
        for (long count : CounterSink.ALL_COUNTS.values()) {
            Assert.assertEquals(NUM_STRINGS / NUM_KEYS, count);
        }
    }

    /**
     * Mutable holder for a {@code long} that does not implement {@link java.io.Serializable}.
     * Instances are created via {@link #of(long)}. Equality and hashing depend only on
     * {@link #value}.
     */
    public static class NonSerializableLong {
        public long value;

        private NonSerializableLong(long value) {
            this.value = value;
        }

        /**
         * Creates a new holder.
         *
         * @param value the wrapped number
         * @return a fresh instance wrapping {@code value}
         */
        public static NonSerializableLong of(long value) {
            return new NonSerializableLong(value);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            final NonSerializableLong other = (NonSerializableLong) obj;
            return value == other.value;
        }

        @Override
        public int hashCode() {
            // Same folding of the high and low 32 bits as (int) (value ^ (value >>> 32)).
            return Long.hashCode(value);
        }
    }

    private static class IdentityKeySelector<T>
    implements KeySelector<T, T> {
        private IdentityKeySelector() {
        }

        public T getKey(T value) throws Exception {
            return value;
        }
    }

    private static class CounterSink
    extends RichSinkFunction<Tuple2<Integer, Long>> {
        private static final Map<Integer, Long> ALL_COUNTS = new ConcurrentHashMap<Integer, Long>();
        private transient ValueState<NonSerializableLong> aCounts;
        private transient ValueState<Long> bCounts;

        private CounterSink() {
        }

        public void open(Configuration parameters) throws IOException {
            this.aCounts = this.getRuntimeContext().getState(new ValueStateDescriptor("a", NonSerializableLong.class));
            this.bCounts = this.getRuntimeContext().getState(new ValueStateDescriptor("b", Long.class));
        }

        public void invoke(Tuple2<Integer, Long> value) throws Exception {
            NonSerializableLong acRaw = (NonSerializableLong)this.aCounts.value();
            Long bcRaw = (Long)this.bCounts.value();
            long ac = acRaw == null ? 0L : acRaw.value;
            long bc = bcRaw == null ? 0L : bcRaw;
            Assert.assertEquals((long)ac, (long)bc);
            long currentCount = ac + 1L;
            this.aCounts.update((Object)NonSerializableLong.of(currentCount));
            this.bCounts.update((Object)currentCount);
            ALL_COUNTS.put((Integer)value.f0, currentCount);
        }
    }

    private static class OnceFailingPartitionedSum
    extends RichMapFunction<Integer, Tuple2<Integer, Long>>
    implements ListCheckpointed<Integer> {
        private static final Map<Integer, Long> ALL_SUMS = new ConcurrentHashMap<Integer, Long>();
        private final int failurePos;
        private int count;
        private boolean shouldFail = true;
        private transient ValueState<Long> sum;

        OnceFailingPartitionedSum(int failurePos) {
            this.failurePos = failurePos;
        }

        public void open(Configuration parameters) throws IOException {
            this.sum = this.getRuntimeContext().getState(new ValueStateDescriptor("my_state", Long.class));
        }

        public Tuple2<Integer, Long> map(Integer value) throws Exception {
            if (this.shouldFail && this.count++ >= this.failurePos) {
                this.shouldFail = false;
                throw new Exception("Test Failure");
            }
            Long oldSum = (Long)this.sum.value();
            long currentSum = (oldSum == null ? 0L : oldSum) + (long)value.intValue();
            this.sum.update((Object)currentSum);
            ALL_SUMS.put(value, currentSum);
            return new Tuple2((Object)value, (Object)currentSum);
        }

        public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
            return Collections.singletonList(this.count);
        }

        public void restoreState(List<Integer> state) throws Exception {
            Assert.assertEquals((String)"Test failed due to unexpected recovered state size", (long)1L, (long)state.size());
            this.count = state.get(0);
            this.shouldFail = false;
        }

        public void close() throws Exception {
            if (this.shouldFail) {
                Assert.fail((String)"Test ineffective: Function cleanly finished without ever failing.");
            }
        }
    }

    private static class IntGeneratingSourceFunction
    extends RichParallelSourceFunction<Integer>
    implements ListCheckpointed<Integer>,
    CheckpointListener {
        private final int numElements;
        private final int checkpointLatestAt;
        private int lastEmitted = -1;
        private boolean checkpointHappened;
        private volatile boolean isRunning = true;

        IntGeneratingSourceFunction(int numElements, int checkpointLatestAt) {
            this.numElements = numElements;
            this.checkpointLatestAt = checkpointLatestAt;
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        public void run(SourceFunction.SourceContext<Integer> ctx) throws Exception {
            int nextElement;
            Object lockingObject = ctx.getCheckpointLock();
            int step = this.getRuntimeContext().getNumberOfParallelSubtasks();
            int n = nextElement = this.lastEmitted >= 0 ? this.lastEmitted + step : this.getRuntimeContext().getIndexOfThisSubtask();
            while (this.isRunning && nextElement < this.numElements) {
                Object object;
                if (!this.checkpointHappened) {
                    if (nextElement < this.checkpointLatestAt) {
                        Thread.sleep(1L);
                    } else {
                        object = this;
                        synchronized (object) {
                            while (!this.checkpointHappened) {
                                ((Object)((Object)this)).wait();
                            }
                        }
                    }
                }
                object = lockingObject;
                synchronized (object) {
                    ctx.collect((Object)(nextElement % 40));
                    this.lastEmitted = nextElement;
                }
                nextElement += step;
            }
        }

        public void cancel() {
            this.isRunning = false;
        }

        public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
            return Collections.singletonList(this.lastEmitted);
        }

        public void restoreState(List<Integer> state) throws Exception {
            Assert.assertEquals((String)"Test failed due to unexpected recovered state size", (long)1L, (long)state.size());
            this.lastEmitted = state.get(0);
            this.checkpointHappened = true;
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        public void notifyCheckpointComplete(long checkpointId) throws Exception {
            IntGeneratingSourceFunction intGeneratingSourceFunction = this;
            synchronized (intGeneratingSourceFunction) {
                this.checkpointHappened = true;
                ((Object)((Object)this)).notifyAll();
            }
        }

        public void notifyCheckpointAborted(long checkpointId) {
        }
    }
}

