/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.paramserv.dp;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionSparkScheme;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import scala.Tuple2;

/**
 * Spark data-partitioning scheme in which every worker receives a copy of every row
 * of the given row block. Each emitted row is tagged with a target position read from
 * that worker's entry in {@code _globalPerms}, presumably so that each worker sees the
 * same rows in its own order.
 *
 * <p>NOTE(review): "OR" presumably abbreviates "overlap reshuffle" — confirm against the
 * scheme selector that instantiates this class.
 */
public class ORSparkScheme
extends DataPartitionSparkScheme {
    private static final long serialVersionUID = 6867567406403580311L;

    /**
     * Creates the scheme. This class declares no state of its own; {@code _globalPerms}
     * comes from {@link DataPartitionSparkScheme}.
     */
    protected ORSparkScheme() {
        // Intentionally empty. NOTE(review): presumably instantiated reflectively or via a
        // factory in this package — confirm before changing visibility.
    }

    /**
     * Partitions one row block of features and the matching block of labels.
     *
     * <p>Both matrices are tagged from the same per-worker permutation block. For a given
     * worker, row {@code r} of {@code features} and row {@code r} of {@code labels}
     * therefore receive the same target position.
     *
     * @param numWorkers number of workers; every worker receives every row
     * @param rblkID     row-block index used to fetch the matching permutation block
     * @param features   feature rows of this block
     * @param labels     label rows of this block (assumed to have the same row count as
     *                   {@code features}; not checked here)
     * @return the per-worker feature tuples and label tuples
     */
    @Override
    public DataPartitionSparkScheme.Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) {
        List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pfs = partition(numWorkers, rblkID, features);
        List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pls = partition(numWorkers, rblkID, labels);
        return new DataPartitionSparkScheme.Result(pfs, pls);
    }

    /**
     * Emits one {@code (workerID, (position, row))} tuple for every worker and every row of
     * {@code mb}. Here {@code position} is the value in column 0, at that row's index, of
     * the worker's permutation block for {@code rblkID}.
     *
     * @param numWorkers number of workers to emit tuples for
     * @param rblkID     row-block index of {@code mb}
     * @param mb         the block whose rows are emitted (assumes the permutation block has
     *                   at least {@code mb.getNumRows()} rows — TODO confirm)
     * @return {@code numWorkers * mb.getNumRows()} tuples, grouped by worker
     */
    private List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> partition(int numWorkers, int rblkID, MatrixBlock mb) {
        return IntStream.range(0, numWorkers).boxed().flatMap(workerID -> {
            // Fetched once per worker and reused for all rows of this block.
            MatrixBlock partialPerm = _globalPerms.get(workerID).getBlock(rblkID, 1);
            return IntStream.range(0, mb.getNumRows()).mapToObj(r -> {
                // Bounds r + 1 .. r + 1 select the single row r (the helper appears to be
                // 1-based and inclusive).
                MatrixBlock rowMB = ParamservUtils.sliceMatrixBlock(mb, r + 1, r + 1);
                long shiftedPosition = (long) partialPerm.get(r, 0);
                return new Tuple2<>(workerID, new Tuple2<>(shiftedPosition, rowMB));
            });
        }).collect(Collectors.toList());
    }
}

