/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.matrix.data;

import java.util.ArrayList;
import java.util.concurrent.Callable;
import org.apache.commons.math3.util.FastMath;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.matrix.data.DnnParameters;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
import org.apache.sysds.runtime.util.UtilFunctions;

public class LibMatrixDNNLSTM {
    /** Number of batch rows processed together by one pass of the tiled forward kernel. */
    private static final int row_tile_size = 4;
    // NOTE(review): not referenced anywhere in this class; kept for compatibility.
    private static final boolean kahan = false;
    // NOTE(review): not referenced anywhere in this class; kept for compatibility.
    private static final boolean optimized = true;

    /**
     * Splits the N batch rows into contiguous row ranges and creates one {@link LSTMExecutor}
     * per range for the tiled forward kernel.
     *
     * @param params DNN parameters (N, D, M, T, numThreads and the input/output blocks)
     * @return list of tasks; each returns the number of non-zeros it produced in its rows of
     *         {@code params.output}
     */
    public static ArrayList<Callable<Long>> getLSTMWorkers(DnnParameters params) {
        ArrayList<Callable<Long>> ret = new ArrayList<>();
        int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        int taskSize = (int) Math.ceil((double) params.N / k);
        // For small per-row workloads, give every task at least one full row tile.
        if (taskSize < row_tile_size && (params.D + params.M) * params.T < 6400) {
            taskSize = row_tile_size;
        }
        for (int i = 0; i * taskSize < params.N; i++) {
            ret.add(new LSTMExecutor(i * taskSize, Math.min((i + 1) * taskSize, params.N), params));
        }
        return ret;
    }

    /**
     * Tiled, dense forward LSTM pass over batch rows [start, end).
     *
     * <p>Layouts, as established by the indexing below:
     * <ul>
     *   <li>x: row i, column t*d + j (feature j of time step t)</li>
     *   <li>w: rows 0..d-1 hold input weights, rows d..d+m-1 recurrent weights; each row is 4*m
     *       wide, gate order [i|f|o|g]</li>
     *   <li>out: i*T*m + t*m + j if return_sequences, otherwise i*m + j</li>
     *   <li>cache_out / cache_c: t*n*m + i*m + j; cache_ifog: t*4*n*m + i*4*m + gate*m + j</li>
     * </ul>
     * All offsets are int; callers must ensure (see {@link #checkLSTMInputForOptimisation})
     * that every buffer fits into a single dense array.
     *
     * @param n batch size N
     * @param d number of input features per time step
     * @param T number of time steps
     * @param m hidden size
     * @param start first batch row (inclusive)
     * @param end last batch row (exclusive)
     * @param x input sequences
     * @param w combined input/recurrent weights
     * @param bias gate bias (1 x 4m); treated as zero if not allocated
     * @param out0 initial hidden state; treated as zero if empty
     * @param c0 initial cell state; treated as zero if empty
     * @param return_sequences whether all time steps or only the last one are written to out
     * @param out output hidden states
     * @param cout final cell state
     * @param cache_out per-step hidden states
     * @param cache_c per-step cell states
     * @param cache_ifog per-step gate activations
     */
    public static void lstmTile(int n, int d, int T, int m, int start, int end, MatrixBlock x, MatrixBlock w, MatrixBlock bias, MatrixBlock out0, MatrixBlock c0, boolean return_sequences, MatrixBlock out, MatrixBlock cout, MatrixBlock cache_out, MatrixBlock cache_c, MatrixBlock cache_ifog) {
        double[] c_0_values = c0.getDenseBlockValues();
        double[] bias_values = bias.getDenseBlockValues();
        double[] out0_values = out0.getDenseBlockValues();
        double[] w_values = w.getDenseBlockValues();
        double[] x_values = x.getDenseBlockValues();
        double[] out_values = out.getDenseBlockValues();
        double[] cout_values = cout.getDenseBlockValues();
        double[] cache_out_values = cache_out.getDenseBlockValues();
        double[] cache_c_values = cache_c.getDenseBlockValues();
        double[] cache_ifog_values = cache_ifog.getDenseBlockValues();
        boolean biasAllocated = bias.isAllocated();
        boolean xAllocated = x.isAllocated();
        boolean wAllocated = w.isAllocated();

        // cache-blocking sizes over the reduction (j) and gate (k) dimensions
        final int tile_size_j = 32;
        final int tile_size_k = 1024;
        int m_4 = 4 * m;
        int m_T = T * m;
        int x_cols = x.getNumColumns();

        int[] pos_in_x = new int[row_tile_size];
        // gate pre-activations of one row tile: row_tile_size rows, each [i|f|o|g] (4*m wide)
        double[] ifog = new double[row_tile_size * m_4];
        double[] out_prev_values = null;
        double[] c_prev_values = null;

        for (int bi = start; bi < end; bi += row_tile_size) {
            int bimin = Math.min(end, bi + row_tile_size);
            int rows = bimin - bi;

            // hidden state entering step 0 for this tile (fresh zero buffer if out0 is empty)
            if (out0_values != null) {
                if (out_prev_values == null) {
                    out_prev_values = new double[m * row_tile_size];
                }
                for (int i = bi, ii = 0; i < bimin; i++, ii++) {
                    System.arraycopy(out0_values, i * m, out_prev_values, ii * m, m);
                }
            } else {
                out_prev_values = new double[m * row_tile_size];
            }
            // cell state entering step 0 for this tile (fresh zero buffer if c0 is empty)
            if (c_0_values != null) {
                if (c_prev_values == null) {
                    c_prev_values = new double[m * row_tile_size];
                }
                for (int i = bi, ii = 0; i < bimin; i++, ii++) {
                    System.arraycopy(c_0_values, i * m, c_prev_values, ii * m, m);
                }
            } else {
                c_prev_values = new double[m * row_tile_size];
            }
            // start offset of each tile row in x
            for (int i = bi, ii = 0; i < bimin; i++, ii++) {
                pos_in_x[ii] = i * x_cols;
            }

            for (int t = 0; t < T; ++t) {
                int pos_in_sequence = t * d;
                int offset_t_internal = t * m;
                int offset_t = offset_t_internal * n;
                int offset_t2 = offset_t * 4;

                // 1) initialize pre-activations with the bias (or zero)
                for (int ii = 0; ii < rows; ii++) {
                    int off = ii * m_4;
                    for (int j = 0; j < m_4; j++) {
                        ifog[off + j] = biasAllocated ? bias_values[j] : 0.0;
                    }
                }

                // 2) ifog += x_t %*% W[0:d, ]
                if (xAllocated && wAllocated) {
                    for (int bj = 0; bj < d; bj += tile_size_j) {
                        int bjmin = Math.min(d, bj + tile_size_j);
                        for (int bk = 0; bk < m_4; bk += tile_size_k) {
                            int bkmin = Math.min(m_4, bk + tile_size_k);
                            for (int ii = 0; ii < rows; ii++) {
                                int pos_ifog = ii * m_4;
                                int pos = pos_in_x[ii] + pos_in_sequence;
                                for (int j = bj; j < bjmin; j++) {
                                    double x_val = x_values[pos + j];
                                    int offset_w = j * m_4;
                                    for (int k = bk; k < bkmin; k++) {
                                        ifog[pos_ifog + k] += x_val * w_values[offset_w + k];
                                    }
                                }
                            }
                        }
                    }
                }

                // 3) ifog += out_prev %*% W[d:d+m, ]
                if (wAllocated) {
                    for (int bj = 0; bj < m; bj += tile_size_j) {
                        int bjmin = Math.min(m, bj + tile_size_j);
                        for (int bk = 0; bk < m_4; bk += tile_size_k) {
                            int bkmin = Math.min(m_4, bk + tile_size_k);
                            for (int ii = 0; ii < rows; ii++) {
                                int offset_out_prev = ii * m;
                                int pos_ifog = offset_out_prev * 4;
                                for (int j = bj; j < bjmin; j++) {
                                    double h_val = out_prev_values[offset_out_prev + j];
                                    int offset_w = (j + d) * m_4;
                                    for (int k = bk; k < bkmin; k++) {
                                        ifog[pos_ifog + k] += h_val * w_values[offset_w + k];
                                    }
                                }
                            }
                        }
                    }
                }

                // 4) gate activations, state update and caching
                for (int i = bi, ii = 0; i < bimin; i++, ii++) {
                    int offset_internal_i = ii * m_4;
                    int offset_internal_f = offset_internal_i + m;
                    int offset_internal_o = offset_internal_f + m;
                    int offset_internal_g = offset_internal_o + m;
                    int offset_c_internal = ii * m;
                    int offset_out = i * m_T + offset_t_internal;
                    int offset_i = i * m;
                    int offset_cache = offset_t + offset_i;
                    int offset_cache_i = offset_t2 + offset_i * 4;
                    int offset_cache_f = offset_cache_i + m;
                    int offset_cache_o = offset_cache_f + m;
                    int offset_cache_g = offset_cache_o + m;
                    for (int j = 0; j < m; j++) {
                        // sigmoid for i, f, o; tanh for g
                        double ig = 1.0 / (FastMath.exp(-ifog[offset_internal_i + j]) + 1.0);
                        double fg = 1.0 / (FastMath.exp(-ifog[offset_internal_f + j]) + 1.0);
                        double og = 1.0 / (FastMath.exp(-ifog[offset_internal_o + j]) + 1.0);
                        double gg = FastMath.tanh(ifog[offset_internal_g + j]);
                        double c = c_prev_values[offset_c_internal + j] * fg + ig * gg;
                        double o = FastMath.tanh(c) * og;
                        if (return_sequences) {
                            out_values[offset_out + j] = o;
                        }
                        cache_out_values[offset_cache + j] = o;
                        cache_c_values[offset_cache + j] = c;
                        cache_ifog_values[offset_cache_i + j] = ig;
                        cache_ifog_values[offset_cache_f + j] = fg;
                        cache_ifog_values[offset_cache_o + j] = og;
                        cache_ifog_values[offset_cache_g + j] = gg;
                        // carry states into the next time step
                        c_prev_values[offset_c_internal + j] = c;
                        out_prev_values[offset_c_internal + j] = o;
                    }
                }
            }

            // final cell state (always) and final hidden state (only if not returning sequences)
            for (int i = bi, ii = 0; i < bimin; i++, ii++) {
                int offset_i = i * m;
                int offset_internal = ii * m;
                System.arraycopy(c_prev_values, offset_internal, cout_values, offset_i, m);
                if (!return_sequences) {
                    System.arraycopy(out_prev_values, offset_internal, out_values, offset_i, m);
                }
            }
        }
    }

    /**
     * Generic (operator-based) forward LSTM used when the tiled kernel is not applicable.
     *
     * <p>Reads x (input1), w (input2), bias, out0 (input3) and c0 (input4). Writes out (output),
     * the final cell state (output2), and per-step caches of hidden states (output3), cell
     * states (output4) and gate activations [i|f|o|g] (output5), one row per time step.
     *
     * @param params DNN parameters
     * @return number of non-zeros in {@code params.output}
     */
    public static long lstmGeneric(DnnParameters params) {
        MatrixBlock x = params.input1;
        MatrixBlock w = params.input2;
        MatrixBlock bias = params.bias;
        MatrixBlock out = params.input3;
        MatrixBlock c = params.input4;
        MatrixBlock cache_out = params.output3;
        MatrixBlock cache_c = params.output4;
        MatrixBlock cache_ifog = params.output5;
        int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        int M = params.M;
        BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString(), k);
        BinaryOperator emult = InstructionUtils.parseBinaryOperator(Opcodes.MULT.toString(), k);
        UnaryOperator tanh = InstructionUtils.parseUnaryOperator(Opcodes.TANH.toString(), k);
        UnaryOperator sigmoid = InstructionUtils.parseUnaryOperator(Opcodes.SIGMOID.toString(), k);
        AggregateBinaryOperator mmult = InstructionUtils.getMatMultOperator(k);
        for (int t = 0; t < params.T; ++t) {
            // ifog = [x_t, out_prev] %*% W + b
            MatrixBlock x_t = x.slice(0, x.rlen - 1, t * params.D, (t + 1) * params.D - 1);
            MatrixBlock ifog = x_t.append(out, true);
            ifog = ifog.aggregateBinaryOperations(ifog, w, mmult);
            ifog = ifog.binaryOperations(plus, bias);
            // i, f, o use sigmoid; g uses tanh
            MatrixBlock ifo = ifog.slice(0, ifog.rlen - 1, 0, 3 * M - 1).unaryOperations(sigmoid);
            MatrixBlock i = ifo.slice(0, ifog.rlen - 1, 0, M - 1);
            MatrixBlock f = ifo.slice(0, ifog.rlen - 1, M, 2 * M - 1);
            MatrixBlock o = ifo.slice(0, ifog.rlen - 1, 2 * M, 3 * M - 1);
            MatrixBlock g = ifog.slice(0, ifog.rlen - 1, 3 * M, 4 * M - 1).unaryOperations(tanh);
            // c = f * c_prev + i * g (written into output2 at the last step)
            MatrixBlock tmp = i.binaryOperations(emult, g);
            c = f.binaryOperations(emult, c).binaryOperations(plus, tmp, t == params.T - 1 ? params.output2 : null);
            // out = o * tanh(c)
            tmp = c.unaryOperations(tanh);
            if (params.return_sequences) {
                out = o.binaryOperations(emult, tmp);
                params.output.leftIndexingOperations(out, 0, out.rlen - 1, t * M, (t + 1) * M - 1, null, MatrixObject.UpdateType.INPLACE);
            } else {
                out = o.binaryOperations(emult, tmp, t == params.T - 1 ? params.output : null);
            }
            // cache activations of step t as row t of each cache
            ifog = ifo.append(g, true);
            MatrixBlock cache_out_t = out.reshape(1, cache_out.clen, true);
            cache_out.leftIndexingOperations(cache_out_t, t, t, 0, cache_out.clen - 1, null, MatrixObject.UpdateType.INPLACE);
            MatrixBlock cache_c_t = c.reshape(1, cache_c.clen, true);
            cache_c.leftIndexingOperations(cache_c_t, t, t, 0, cache_c.clen - 1, null, MatrixObject.UpdateType.INPLACE);
            MatrixBlock cache_ifog_t = ifog.reshape(1, cache_ifog.clen, true);
            cache_ifog.leftIndexingOperations(cache_ifog_t, t, t, 0, cache_ifog.clen - 1, null, MatrixObject.UpdateType.INPLACE);
        }
        return params.output.recomputeNonZeros();
    }

    /**
     * Generic (operator-based) backward LSTM pass, iterating time steps from T-1 down to 0.
     *
     * <p>Reads x, w, bias, out0, c0, dout, dc and the forward caches (input1..input9). Writes
     * dX (output), dW (output2), db (output3), the gradient w.r.t. out0 (output4) and the
     * gradient w.r.t. c0 (output5).
     *
     * @param params DNN parameters
     * @return number of non-zeros in {@code params.output}
     */
    public static long lstmBackwardGeneric(DnnParameters params) {
        MatrixBlock x = params.input1;
        MatrixBlock w = params.input2;
        MatrixBlock bias = params.bias;
        MatrixBlock out0 = params.input3;
        MatrixBlock c0 = params.input4;
        MatrixBlock dout = params.input5;
        MatrixBlock dc = params.input6;
        MatrixBlock cache_out = params.input7;
        MatrixBlock cache_c = params.input8;
        MatrixBlock cache_ifog = params.input9;
        MatrixBlock dX = params.output;
        MatrixBlock dW = null;
        MatrixBlock db = null;
        int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        int M = params.M;
        BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString(), k);
        BinaryOperator emult = InstructionUtils.parseBinaryOperator(Opcodes.MULT.toString(), k);
        ScalarOperator exp2 = InstructionUtils.parseScalarBinaryOperator(Opcodes.POW2.toString(), false, 0.0, k);
        ScalarOperator minus = InstructionUtils.parseScalarBinaryOperator(Opcodes.MINUS.toString(), true, 1.0, k);
        UnaryOperator tanh = InstructionUtils.parseUnaryOperator(Opcodes.TANH.toString(), k);
        UnaryOperator sprop = InstructionUtils.parseUnaryOperator(Opcodes.SPROP.toString(), k);
        AggregateUnaryOperator colsum = InstructionUtils.parseBasicAggregateUnaryOperator(Opcodes.UACKP.toString(), k);
        ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k);
        AggregateBinaryOperator mmult = InstructionUtils.getMatMultOperator(k);
        // without return_sequences, dout only feeds the last time step
        MatrixBlock dout_prev = params.return_sequences ? null : dout;
        w = w.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
        for (int t = params.T - 1; t >= 0; --t) {
            if (params.return_sequences) {
                dout_prev = t == params.T - 1 ? dout.slice(0, dout.rlen - 1, t * M, (t + 1) * M - 1) : dout.slice(0, dout.rlen - 1, t * M, (t + 1) * M - 1).binaryOperations(plus, dout_prev);
            }
            // restore forward state of step t from the caches
            MatrixBlock c_t = cache_c.slice(t, t, 0, cache_c.clen - 1).reshape(params.N, M, true);
            MatrixBlock c_prev = t == 0 ? c0 : cache_c.slice(t - 1, t - 1, 0, cache_c.clen - 1).reshape(params.N, M, true);
            MatrixBlock ifog = cache_ifog.slice(t, t, 0, cache_ifog.clen - 1).reshape(params.N, 4 * M, true);
            MatrixBlock i = ifog.slice(0, ifog.rlen - 1, 0, M - 1);
            MatrixBlock f = ifog.slice(0, ifog.rlen - 1, M, 2 * M - 1);
            MatrixBlock o = ifog.slice(0, ifog.rlen - 1, 2 * M, 3 * M - 1);
            MatrixBlock g = ifog.slice(0, ifog.rlen - 1, 3 * M, ifog.clen - 1);
            // dc += o * (1 - tanh(c)^2) * dout_t
            MatrixBlock tanh_forward = c_t.unaryOperations(tanh);
            MatrixBlock tanh_back = tanh_forward.scalarOperations(exp2, new MatrixBlock()).scalarOperations(minus, new MatrixBlock());
            tanh_back = tanh_back.binaryOperations(emult, dout_prev);
            MatrixBlock tmp = o.binaryOperations(emult, tanh_back);
            dc = dc.binaryOperations(plus, tmp);
            MatrixBlock d_o = tanh_forward.binaryOperations(emult, dout_prev);
            MatrixBlock d_f = c_prev.binaryOperations(emult, dc);
            MatrixBlock d_i = g.binaryOperations(emult, dc);
            MatrixBlock d_g = i.binaryOperations(emult, dc);
            // gradients w.r.t. the gate pre-activations, assembled as [i|f|o|g]
            MatrixBlock difog_raw = new MatrixBlock(params.N, 4 * M, false);
            MatrixBlock di_raw = i.unaryOperations(sprop, new MatrixBlock()).binaryOperations(emult, d_i);
            difog_raw.leftIndexingOperations(di_raw, 0, difog_raw.rlen - 1, 0, M - 1, null, MatrixObject.UpdateType.INPLACE);
            MatrixBlock df_raw = f.unaryOperations(sprop, new MatrixBlock()).binaryOperations(emult, d_f);
            difog_raw.leftIndexingOperations(df_raw, 0, difog_raw.rlen - 1, M, 2 * M - 1, null, MatrixObject.UpdateType.INPLACE);
            MatrixBlock do_raw = o.unaryOperations(sprop, new MatrixBlock()).binaryOperations(emult, d_o);
            difog_raw.leftIndexingOperations(do_raw, 0, difog_raw.rlen - 1, 2 * M, 3 * M - 1, null, MatrixObject.UpdateType.INPLACE);
            MatrixBlock dg_raw = g.scalarOperations(exp2, new MatrixBlock()).scalarOperations(minus, new MatrixBlock()).binaryOperations(emult, d_g);
            difog_raw.leftIndexingOperations(dg_raw, 0, difog_raw.rlen - 1, 3 * M, 4 * M - 1, null, MatrixObject.UpdateType.INPLACE);
            // dW += t([x_t, out_prev]) %*% difog_raw; db += colSums(difog_raw)
            MatrixBlock x_t = x.slice(0, x.rlen - 1, t * params.D, (t + 1) * params.D - 1);
            MatrixBlock out_prev = t == 0 ? out0 : cache_out.slice(t - 1, t - 1, 0, cache_out.clen - 1).reshape(params.N, M, true);
            MatrixBlock in_t = x_t.append(out_prev, true).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
            tmp = in_t.aggregateBinaryOperations(in_t, difog_raw, params.T == 1 ? params.output2 : null, mmult);
            dW = t == params.T - 1 ? tmp : dW.binaryOperations(plus, tmp, t == 0 ? params.output2 : null);
            tmp = difog_raw.aggregateUnaryOperations(colsum, params.T == 1 ? params.output3 : null, difog_raw.rlen, new MatrixIndexes(1L, 1L), true);
            db = t == params.T - 1 ? tmp : db.binaryOperations(plus, tmp, t == 0 ? params.output3 : null);
            // dinput = difog_raw %*% t(W), split into dX_t and the gradient w.r.t. out_prev
            MatrixBlock dinput = difog_raw.aggregateBinaryOperations(difog_raw, w, mmult);
            dX.leftIndexingOperations(dinput.slice(0, dinput.rlen - 1, 0, params.D - 1), 0, dX.rlen - 1, t * params.D, (t + 1) * params.D - 1, null, MatrixObject.UpdateType.INPLACE);
            dout_prev = dinput.slice(0, dinput.rlen - 1, params.D, dinput.clen - 1, t == 0 ? params.output4 : null);
            // gradient w.r.t. c_prev
            dc = f.binaryOperations(emult, dc, t == 0 ? params.output5 : null);
        }
        return params.output.recomputeNonZeros();
    }

    /**
     * Decides whether the tiled dense kernel ({@link #lstmTile}) can be used.
     *
     * <p>All relevant inputs must be empty or a single dense block. The largest buffer the kernel
     * indexes with int offsets is cache_ifog (T rows of 4*N*M values); it must fit into a single
     * FP64 dense array. This also covers out, cache_out and cache_c.
     *
     * @param params DNN parameters
     * @return true if the optimized forward kernel may be used
     */
    public static boolean checkLSTMInputForOptimisation(DnnParameters params) {
        // Computed in double precision so the check itself cannot overflow (an int product could
        // wrap negative and wrongly pass); values below 2^31 are represented exactly.
        boolean fits_FP64 = (double) params.T * params.N * 4.0 * params.M < Integer.MAX_VALUE;
        return isEmptyOrSingleDenseBlock(params.input1)
            && isEmptyOrSingleDenseBlock(params.input2)
            && isEmptyOrSingleDenseBlock(params.bias)
            && isEmptyOrSingleDenseBlock(params.input4)
            && isEmptyOrSingleDenseBlock(params.input3)
            && fits_FP64;
    }

    /** Returns true if the block is unallocated or backed by exactly one dense block. */
    private static boolean isEmptyOrSingleDenseBlock(MatrixBlock mb) {
        return !mb.isAllocated() || (!mb.sparse && mb.denseBlock.numBlocks() == 1);
    }

    /**
     * Decides whether an optimized backward kernel can be used.
     *
     * @param params DNN parameters
     * @return always false; only the generic backward pass exists in this class
     */
    public static boolean checkLSTMBackwardInputForOptimisation(DnnParameters params) {
        return false;
    }

    /** Task running {@link #lstmTile} over the batch rows [_rl, _ru). */
    private static final class LSTMExecutor
    implements Callable<Long> {
        protected final int _rl;
        protected final int _ru;
        protected final DnnParameters _params;

        public LSTMExecutor(int rl, int ru, DnnParameters params) {
            this._rl = rl;
            this._ru = ru;
            this._params = params;
        }

        @Override
        public Long call() throws Exception {
            LibMatrixDNNLSTM.lstmTile(this._params.N, this._params.D, this._params.T, this._params.M, this._rl, this._ru, this._params.input1, this._params.input2, this._params.bias, this._params.input3, this._params.input4, this._params.return_sequences, this._params.output, this._params.output2, this._params.output3, this._params.output4, this._params.output5);
            // non-zeros of the output rows owned by this task
            return this._params.output.recomputeNonZeros(this._rl, this._ru - 1);
        }
    }
}

