/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import java.util.ArrayList;
import java.util.Collections;
import java.util.concurrent.Future;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
import org.apache.sysds.runtime.codegen.SpoofMultiAggregate;
import org.apache.sysds.runtime.codegen.SpoofOperator;
import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.SpoofCPInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.spark.SpoofSPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;

/**
 * Federated execution of a code-generated ("spoof") fused operator.
 *
 * <p>The instruction locates the federation map of its first federated (non-broadcast) matrix
 * input and ships every remaining input to the federated workers: as a full broadcast, as a sliced
 * broadcast, or as a scalar broadcast. It then re-issues the fused instruction on each worker.
 * Afterwards it either keeps the result federated or fetches and aggregates the partial results.
 * Which of the two happens depends on the template: cellwise, rowwise, multi-aggregate, or outer
 * product.
 */
public class SpoofFEDInstruction extends FEDInstruction {
    /** The fused operator executed on the federated workers. */
    private final SpoofOperator _op;
    /** Operands of the fused operator, in instruction order. */
    private final CPOperand[] _inputs;
    /** Output operand of the fused operator. */
    private final CPOperand _output;

    private SpoofFEDInstruction(SpoofOperator op, CPOperand[] in, CPOperand out, String opcode, String instStr) {
        super(FEDInstruction.FEDType.SpoofFused, opcode, instStr);
        _op = op;
        _inputs = in;
        _output = out;
    }

    /**
     * Creates the federated counterpart of a CP spoof instruction, if applicable.
     *
     * @param inst the CP spoof instruction
     * @param ec   execution context used to inspect the instruction inputs
     * @return the federated instruction, or {@code null} if the template is not supported or the
     *         inputs do not satisfy the federation/alignment requirements
     */
    public static SpoofFEDInstruction parseInstruction(SpoofCPInstruction inst, ExecutionContext ec) {
        Class<?> scla = inst.getOperatorClass().getSuperclass();
        return isSupportedAndFederated(ec, inst.getInputs(), scla) ? parseInstruction(inst) : null;
    }

    /**
     * Creates the federated counterpart of a Spark spoof instruction, if applicable.
     *
     * @param inst the Spark spoof instruction
     * @param ec   execution context used to inspect the instruction inputs
     * @return the federated instruction, or {@code null} if the template is not supported or the
     *         inputs do not satisfy the federation/alignment requirements
     */
    public static SpoofFEDInstruction parseInstruction(SpoofSPInstruction inst, ExecutionContext ec) {
        Class<?> scla = inst.getOperatorClass().getSuperclass();
        return isSupportedAndFederated(ec, inst.getInputs(), scla) ? parseInstruction(inst) : null;
    }

    /**
     * Checks whether the template class is one that has a federated implementation and whether
     * the inputs are federated as that template requires. Rowwise templates require row-federated
     * inputs. The other supported templates accept any federation type.
     */
    private static boolean isSupportedAndFederated(ExecutionContext ec, CPOperand[] inputs, Class<?> scla) {
        if (scla == SpoofCellwise.class || scla == SpoofMultiAggregate.class || scla == SpoofOuterProduct.class) {
            return isFederated(ec, inputs, scla);
        }
        if (scla == SpoofRowwise.class) {
            return isFederated(ec, FTypes.FType.ROW, inputs, scla);
        }
        return false;
    }

    private static SpoofFEDInstruction parseInstruction(SpoofCPInstruction instr) {
        return new SpoofFEDInstruction(instr.getSpoofOperator(), instr.getInputs(), instr.getOutput(),
            instr.getOpcode(), instr.getInstructionString());
    }

    private static SpoofFEDInstruction parseInstruction(SpoofSPInstruction instr) {
        // The Spark instruction only carries the operator class, so instantiate the operator here.
        SpoofOperator op = CodegenUtils.createInstance(instr.getOperatorClass());
        return new SpoofFEDInstruction(op, instr.getInputs(), instr.getOutput(),
            instr.getOpcode(), instr.getInstructionString());
    }

    /**
     * Parses a federated spoof instruction from its string form.
     *
     * <p>Expected part layout (as consumed below): {@code parts[0]} opcode prefix, {@code parts[2]}
     * generated operator class name, {@code parts[3 .. length-3]} inputs, {@code parts[length-2]}
     * output. {@code parts[1]} and {@code parts[length-1]} are not used here.
     *
     * @param str the instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the instruction has too few parts
     */
    public static SpoofFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        if (parts.length < 5) {
            throw new DMLRuntimeException("Invalid federated spoof instruction: " + str);
        }
        Class<?> cla = CodegenUtils.getClass(parts[2]);
        SpoofOperator op = CodegenUtils.createInstance(cla);
        String opcode = parts[0] + op.getSpoofType();
        CPOperand[] inputCpo = new CPOperand[parts.length - 5];
        for (int i = 0; i < inputCpo.length; i++) {
            inputCpo[i] = new CPOperand(parts[i + 3]);
        }
        CPOperand out = new CPOperand(parts[parts.length - 2]);
        return new SpoofFEDInstruction(op, inputCpo, out, opcode, str);
    }

    /**
     * Executes the fused operator on the federated workers.
     *
     * @param ec the execution context holding the operands
     * @throws DMLRuntimeException if no input is a federated (non-broadcast) matrix, or if the
     *         operator's template has no federated implementation
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        FederationMap fedMap = findFederationMap(ec);
        if (fedMap == null) {
            throw new DMLRuntimeException(
                "Federated code generation requires at least one federated (non-broadcast) matrix input.");
        }
        processRequest(ec, fedMap, createSpoofType(fedMap.getType()));
    }

    /** Returns the federation map of the first federated (non-broadcast) matrix input, or null. */
    private FederationMap findFederationMap(ExecutionContext ec) {
        for (CPOperand cpo : _inputs) {
            Data data = ec.getVariable(cpo);
            if (data instanceof MatrixObject && ((MatrixObject) data).isFederatedExcept(FTypes.FType.BROADCAST)) {
                return ((MatrixObject) data).getFedMapping();
            }
        }
        return null;
    }

    /** Creates the template-specific strategy for the operator's generated superclass. */
    private SpoofFEDType createSpoofType(FTypes.FType fedType) {
        Class<?> scla = _op.getClass().getSuperclass();
        if (scla == SpoofCellwise.class) {
            return new SpoofFEDCellwise(_op, _output, fedType);
        }
        if (scla == SpoofRowwise.class) {
            return new SpoofFEDRowwise(_op, _output, fedType);
        }
        if (scla == SpoofMultiAggregate.class) {
            return new SpoofFEDMultiAgg(_op, _output, fedType);
        }
        if (scla == SpoofOuterProduct.class) {
            return new SpoofFEDOuterProduct(_op, _output, fedType, _inputs);
        }
        throw new DMLRuntimeException("Federated code generation only supported for cellwise, rowwise, multiaggregate, and outerproduct templates.");
    }

    /**
     * Distributes the inputs, issues the fused instruction on the workers and stores the output.
     *
     * <p>For each input:
     * <ul>
     *   <li>federated (non-broadcast) matrices reuse their existing federation ID;</li>
     *   <li>matrices for which the template reports {@code needsBroadcastSliced} are broadcast in
     *       slices;</li>
     *   <li>all other matrices and all scalars are broadcast in full.</li>
     * </ul>
     */
    private void processRequest(ExecutionContext ec, FederationMap fedMap, SpoofFEDType spoofType) {
        ArrayList<FederatedRequest> frBroadcast = new ArrayList<>();
        ArrayList<FederatedRequest[]> frBroadcastSliced = new ArrayList<>();
        long[] frIds = new long[_inputs.length];
        int index = 0;
        for (CPOperand cpo : _inputs) {
            Data data = ec.getVariable(cpo);
            if (data instanceof MatrixObject) {
                MatrixLineagePair mo = MatrixLineagePair.of((MatrixObject) data,
                    DMLScript.LINEAGE ? ec.getLineageItem(cpo) : null);
                if (mo.isFederatedExcept(FTypes.FType.BROADCAST)) {
                    frIds[index++] = mo.getFedMapping().getID();
                } else if (spoofType.needsBroadcastSliced(fedMap, mo.getNumRows(), mo.getNumColumns(), index)) {
                    FederatedRequest[] slices = spoofType.broadcastSliced(mo, fedMap);
                    // all slices share one variable ID on the workers; use the first one
                    frIds[index++] = slices[0].getID();
                    frBroadcastSliced.add(slices);
                } else {
                    FederatedRequest fr = fedMap.broadcast(mo);
                    frIds[index++] = fr.getID();
                    frBroadcast.add(fr);
                }
            } else if (data instanceof ScalarObject) {
                FederatedRequest fr = fedMap.broadcast((ScalarObject) data);
                frIds[index++] = fr.getID();
                frBroadcast.add(fr);
            }
            // NOTE(review): inputs that are neither matrices nor scalars are skipped without
            // advancing `index`, which shifts the IDs of later inputs - confirm this cannot occur.
        }

        // NOTE(review): presumably flips literal flags of operands that are now broadcast
        // variables; this replaces EVERY "true" substring in the instruction string - confirm
        // no other token (e.g. a variable name) can contain it.
        instString = instString.replace("true", "false");
        FederatedRequest frCompute = FederationUtils.callInstruction(instString, _output, _inputs, frIds);

        FederatedRequest[] frTail;
        if (spoofType.isFedOutput()) {
            frTail = new FederatedRequest[] {frCompute};
        } else {
            // fetch the partial results, then clean up the workers' intermediate variable
            FederatedRequest frGet = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, frCompute.getID());
            FederatedRequest frCleanup = fedMap.cleanup(getTID(), frCompute.getID());
            frTail = new FederatedRequest[] {frCompute, frGet, frCleanup};
        }
        FederatedRequest[] frAll = ArrayUtils.addAll(frBroadcast.toArray(new FederatedRequest[0]), frTail);

        Future<FederatedResponse>[] response = fedMap.executeMultipleSlices(getTID(), true,
            frBroadcastSliced.toArray(new FederatedRequest[0][]), frAll);
        spoofType.setOutput(ec, response, fedMap, frCompute.getID());
    }

    /**
     * Same as {@link #isFederated(ExecutionContext, FTypes.FType, CPOperand[], Class)} with a
     * {@code null} federation type (no specific type requested).
     */
    public static boolean isFederated(ExecutionContext ec, CPOperand[] inputs, Class<?> scla) {
        return isFederated(ec, null, inputs, scla);
    }

    /**
     * Checks that at least one input is a federated matrix of the requested type (and not a
     * broadcast), and that all such federated inputs are aligned with the first one.
     *
     * @param ec     execution context holding the inputs
     * @param type   federation type to check, passed through to {@code MatrixObject.isFederated}
     * @param inputs operands to inspect
     * @param scla   generated template superclass; outer products additionally accept transposed
     *               alignment
     * @return {@code true} if a federated input exists and all federated inputs are aligned
     */
    public static boolean isFederated(ExecutionContext ec, FTypes.FType type, CPOperand[] inputs, Class<?> scla) {
        FederationMap fedMap = null;
        FTypes.AlignType[] alignTypes = null;
        for (CPOperand input : inputs) {
            Data data = ec.getVariable(input);
            if (!(data instanceof MatrixObject)) {
                continue;
            }
            MatrixObject mo = (MatrixObject) data;
            if (!mo.isFederated(type) || mo.isFederated(FTypes.FType.BROADCAST)) {
                continue;
            }
            if (fedMap == null) {
                fedMap = mo.getFedMapping();
                ArrayList<FTypes.AlignType> alignmentTypes = new ArrayList<>();
                alignmentTypes.add(mo.isFederated(FTypes.FType.ROW) ? FTypes.AlignType.ROW : FTypes.AlignType.COL);
                if (scla == SpoofOuterProduct.class) {
                    Collections.addAll(alignmentTypes, FTypes.AlignType.ROW_T, FTypes.AlignType.COL_T);
                }
                alignTypes = alignmentTypes.toArray(new FTypes.AlignType[0]);
            } else if (!fedMap.isAligned(mo.getFedMapping(), alignTypes)) {
                return false;
            }
        }
        return fedMap != null;
    }

    /**
     * Template-specific behavior: which inputs are broadcast in slices, whether the output stays
     * federated, and how partial results are aggregated.
     */
    private static abstract class SpoofFEDType {
        protected final CPOperand _output;
        protected final FTypes.FType _fedType;

        protected SpoofFEDType(CPOperand out, FTypes.FType fedType) {
            _output = out;
            _fedType = fedType;
        }

        /** Broadcasts {@code mo} in slices following the federation map (non-transposed). */
        protected FederatedRequest[] broadcastSliced(MatrixLineagePair mo, FederationMap fedMap) {
            return fedMap.broadcastSliced(mo, false);
        }

        /**
         * Decides whether an input of the given shape is sliced along the federated dimension
         * instead of being broadcast in full.
         *
         * @throws DMLRuntimeException if the federation type is neither ROW nor COL
         */
        protected boolean needsBroadcastSliced(FederationMap fedMap, long rowNum, long colNum, int inputIndex) {
            long fedRows = fedMap.getMaxIndexInRange(0);
            long fedCols = fedMap.getMaxIndexInRange(1);
            // same shape as the federated matrix
            boolean retVal = rowNum == fedRows && colNum == fedCols;
            if (_fedType == FTypes.FType.ROW) {
                // same number of rows and either a column vector or a single-column federated matrix
                retVal |= rowNum == fedRows && (colNum == 1L || fedCols == 1L);
            } else if (_fedType == FTypes.FType.COL) {
                // same number of columns and either a row vector or a single-row federated matrix
                retVal |= colNum == fedCols && (rowNum == 1L || fedRows == 1L);
            } else {
                throw new DMLRuntimeException("Only row partitioned or column partitioned federated input supported yet.");
            }
            return retVal;
        }

        /** Either keeps the result federated or aggregates the fetched partial results. */
        protected void setOutput(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap, long frComputeID) {
            if (isFedOutput()) {
                setFedOutput(ec, fedMap, frComputeID);
            } else {
                aggResult(ec, response, fedMap);
            }
        }

        /** Returns whether the output stays federated on the workers. */
        protected abstract boolean isFedOutput();

        /** Binds the output to a federation map referencing the workers' result variable. */
        protected abstract void setFedOutput(ExecutionContext ec, FederationMap fedMap, long frComputeID);

        /** Aggregates the workers' partial results into a local output. */
        protected abstract void aggResult(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap);
    }

    /** Federated handling of the outer-product template. */
    private static class SpoofFEDOuterProduct extends SpoofFEDType {
        private final SpoofOuterProduct _op;
        private final SpoofOuterProduct.OutProdType _outProdType;
        private final CPOperand[] _inputs;

        SpoofFEDOuterProduct(SpoofOperator op, CPOperand out, FTypes.FType fedType, CPOperand[] inputs) {
            super(out, fedType);
            _op = (SpoofOuterProduct) op;
            _outProdType = _op.getOuterProdType();
            _inputs = inputs;
        }

        /** Slices transposed when the federated input is column-partitioned. */
        @Override
        protected FederatedRequest[] broadcastSliced(MatrixLineagePair mo, FederationMap fedMap) {
            return fedMap.broadcastSliced(mo, _fedType == FTypes.FType.COL);
        }

        @Override
        protected boolean needsBroadcastSliced(FederationMap fedMap, long rowNum, long colNum, int inputIndex) {
            long fedRows = fedMap.getMaxIndexInRange(0);
            long fedCols = fedMap.getMaxIndexInRange(1);
            boolean retVal = rowNum == fedRows && colNum == fedCols;
            // NOTE(review): judging from setFedOutput, input 1 is the left factor and input 2 the
            // right factor; the excluded index is the factor not partitioned along the fed dim.
            if (_fedType == FTypes.FType.ROW) {
                retVal |= rowNum == fedRows && inputIndex != 2;
            } else if (_fedType == FTypes.FType.COL) {
                retVal |= rowNum == fedCols && inputIndex != 1;
            } else {
                throw new DMLRuntimeException("Only row partitioned or column partitioned federated input supported yet.");
            }
            return retVal;
        }

        @Override
        protected boolean isFedOutput() {
            return (_outProdType == SpoofOuterProduct.OutProdType.LEFT_OUTER_PRODUCT && _fedType == FTypes.FType.COL)
                || (_outProdType == SpoofOuterProduct.OutProdType.RIGHT_OUTER_PRODUCT && _fedType == FTypes.FType.ROW)
                || _outProdType == SpoofOuterProduct.OutProdType.CELLWISE_OUTER_PRODUCT;
        }

        @Override
        protected void setFedOutput(ExecutionContext ec, FederationMap fedMap, long frComputeID) {
            FederationMap newFedMap = fedMap.copyWithNewID(frComputeID);
            long[] outDims = new long[2];
            MatrixObject X = ec.getMatrixObject(_inputs[0]);
            switch (_outProdType) {
                case LEFT_OUTER_PRODUCT:
                    newFedMap = newFedMap.transpose();
                    outDims[0] = X.getNumColumns();
                    outDims[1] = ec.getMatrixObject(_inputs[1]).getNumColumns();
                    break;
                case RIGHT_OUTER_PRODUCT:
                    outDims[0] = X.getNumRows();
                    outDims[1] = ec.getMatrixObject(_inputs[2]).getNumColumns();
                    break;
                case CELLWISE_OUTER_PRODUCT:
                    outDims[0] = X.getNumRows();
                    outDims[1] = X.getNumColumns();
                    break;
                default:
                    throw new DMLRuntimeException("Outer Product Type " + _outProdType + " not supported yet.");
            }
            MatrixObject out = ec.getMatrixObject(_output);
            // adjust the non-partitioned dimension of each federated range to the output size
            int dim = newFedMap.getType() == FTypes.FType.ROW ? 1 : 0;
            newFedMap.modifyFedRanges(outDims[dim], dim);
            out.setFedMapping(newFedMap);
        }

        @Override
        protected void aggResult(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap) {
            AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
            switch (_outProdType) {
                case LEFT_OUTER_PRODUCT:
                case RIGHT_OUTER_PRODUCT:
                    ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
                    break;
                case AGG_OUTER_PRODUCT:
                    ec.setVariable(_output.getName(), FederationUtils.aggScalar(aop, response));
                    break;
                default:
                    throw new DMLRuntimeException("Outer Product Type " + _outProdType + " not supported yet.");
            }
        }
    }

    /** Federated handling of the multi-aggregate template; output is always local. */
    private static class SpoofFEDMultiAgg extends SpoofFEDType {
        private final SpoofMultiAggregate _op;

        SpoofFEDMultiAgg(SpoofOperator op, CPOperand out, FTypes.FType fedType) {
            super(out, fedType);
            _op = (SpoofMultiAggregate) op;
        }

        @Override
        protected boolean isFedOutput() {
            return false;
        }

        @Override
        protected void setFedOutput(ExecutionContext ec, FederationMap fedMap, long frComputeID) {
            throw new DMLRuntimeException("SpoofFEDMultiAgg cannot create a federated output.");
        }

        /** Folds all partial results into the first one using the operator's aggregation ops. */
        @Override
        protected void aggResult(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap) {
            MatrixBlock[] partRes = FederationUtils.getResults(response);
            SpoofCellwise.AggOp[] aggOps = _op.getAggOps();
            for (int i = 1; i < partRes.length; i++) {
                SpoofMultiAggregate.aggregatePartialResults(aggOps, partRes[0], partRes[i]);
            }
            ec.setMatrixOutput(_output.getName(), partRes[0]);
        }
    }

    /** Federated handling of the rowwise template. */
    private static class SpoofFEDRowwise extends SpoofFEDType {
        private final SpoofRowwise _op;
        private final SpoofRowwise.RowType _rowType;

        SpoofFEDRowwise(SpoofOperator op, CPOperand out, FTypes.FType fedType) {
            super(out, fedType);
            _op = (SpoofRowwise) op;
            _rowType = _op.getRowType();
        }

        /** Non-aggregating row types keep a row-federated output federated. */
        @Override
        protected boolean isFedOutput() {
            boolean noAgg = _rowType == SpoofRowwise.RowType.NO_AGG
                || _rowType == SpoofRowwise.RowType.NO_AGG_B1
                || _rowType == SpoofRowwise.RowType.NO_AGG_CONST;
            return noAgg && _fedType == FTypes.FType.ROW;
        }

        @Override
        protected void setFedOutput(ExecutionContext ec, FederationMap fedMap, long frComputeID) {
            MatrixObject out = ec.getMatrixObject(_output);
            FederationMap newFedMap = fedMap.copyWithNewID(frComputeID).modifyFedRanges(out.getNumColumns(), 1);
            out.setFedMapping(newFedMap);
        }

        /** Sums partial results: full, row-wise, or column-wise depending on the row type. */
        @Override
        protected void aggResult(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap) {
            if (_fedType != FTypes.FType.ROW) {
                throw new DMLRuntimeException("Only row partitioned federated matrices supported yet.");
            }
            String aggInst;
            if (_rowType == SpoofRowwise.RowType.FULL_AGG) {
                aggInst = "uak+";
            } else if (_rowType == SpoofRowwise.RowType.ROW_AGG) {
                aggInst = "uark+";
            } else if (_rowType.isColumnAgg()) {
                aggInst = "uack+";
            } else {
                throw new DMLRuntimeException("AggregationType not supported yet.");
            }
            AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator(aggInst);
            if (_rowType == SpoofRowwise.RowType.FULL_AGG) {
                ec.setVariable(_output.getName(), FederationUtils.aggScalar(aop, response));
            } else {
                ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
            }
        }
    }

    /** Federated handling of the cellwise template. */
    private static class SpoofFEDCellwise extends SpoofFEDType {
        private final SpoofCellwise _op;
        private final SpoofCellwise.CellType _cellType;

        SpoofFEDCellwise(SpoofOperator op, CPOperand out, FTypes.FType fedType) {
            super(out, fedType);
            _op = (SpoofCellwise) op;
            _cellType = _op.getCellType();
        }

        /**
         * The output stays federated for NO_AGG, for row aggregation over row partitions, and for
         * column aggregation over column partitions.
         */
        @Override
        protected boolean isFedOutput() {
            return (_cellType == SpoofCellwise.CellType.ROW_AGG && _fedType == FTypes.FType.ROW)
                || (_cellType == SpoofCellwise.CellType.COL_AGG && _fedType == FTypes.FType.COL)
                || _cellType == SpoofCellwise.CellType.NO_AGG;
        }

        @Override
        protected void setFedOutput(ExecutionContext ec, FederationMap fedMap, long frComputeID) {
            MatrixObject out = ec.getMatrixObject(_output);
            FederationMap newFedMap = modifyFedRanges(fedMap.copyWithNewID(frComputeID));
            out.setFedMapping(newFedMap);
        }

        /** Collapses the aggregated dimension to size 1 for row/column aggregations. */
        private FederationMap modifyFedRanges(FederationMap fedMap) {
            if (_cellType == SpoofCellwise.CellType.ROW_AGG || _cellType == SpoofCellwise.CellType.COL_AGG) {
                int dim = _cellType == SpoofCellwise.CellType.COL_AGG ? 0 : 1;
                fedMap.modifyFedRanges(1L, dim);
            }
            return fedMap;
        }

        /**
         * Builds the aggregate opcode from the direction (full/row/col) and the operation. SUM and
         * SUM_SQ partials are both combined by summation.
         */
        @Override
        protected void aggResult(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap) {
            String aggInst;
            switch (_cellType) {
                case FULL_AGG:
                    aggInst = "ua";
                    break;
                case ROW_AGG:
                    aggInst = "uar";
                    break;
                case COL_AGG:
                    aggInst = "uac";
                    break;
                default:
                    throw new DMLRuntimeException("Aggregation type not supported yet.");
            }
            switch (_op.getAggOp()) {
                case SUM:
                case SUM_SQ:
                    aggInst += "k+";
                    break;
                case MIN:
                    aggInst += "min";
                    break;
                case MAX:
                    aggInst += "max";
                    break;
                default:
                    throw new DMLRuntimeException("Aggregation operation not supported yet.");
            }
            AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator(aggInst);
            if (_cellType == SpoofCellwise.CellType.FULL_AGG) {
                ec.setVariable(_output.getName(), FederationUtils.aggScalar(aop, response));
            } else {
                ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
            }
        }
    }
}

