/*
 * Decompiled with CFR 0.152.
 */
package org.extratrees;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;
import org.extratrees.AbstractTrees;
import org.extratrees.Aggregator;
import org.extratrees.FactorBinaryTree;
import org.extratrees.TaskCutResult;
import org.extratrees.data.Array2D;

/**
 * Extra-trees variant whose outputs are non-negative factor (class) indices.
 *
 * <p>Leaves store the majority factor of their rows. Cuts are scored with the Gini impurity,
 * which is weighted by {@code weights} when {@code useWeights} is set. Multitask cuts
 * ({@link #getTaskCut}) are only supported for at most two factors.
 */
public class FactorExtraTrees
extends AbstractTrees<FactorBinaryTree, Integer>
implements Serializable {
    private static final long serialVersionUID = 3625952360819832098L;

    /**
     * Gini impurity below which one side of a cut is flagged as constant (pure). The literal
     * is kept exactly as in the original code so that purity decisions do not change.
     */
    private static final double PURITY_TOLERANCE = 9.999999999999998E-15;

    /** Factor index of each training row; the data constructor rejects negative values. */
    transient int[] output;
    /** Number of factors; when built from data this is {@code max(output) + 1}, and at least 1. */
    int nFactors;

    /**
     * Creates a model that only knows its number of factors (used by {@link #selectTrees}).
     *
     * @param n number of factors
     */
    public FactorExtraTrees(int n) {
        this.nFactors = n;
    }

    /**
     * Creates a single-task model over the given training data.
     *
     * @param array2D input matrix, one row per data point
     * @param nArray  factor index of each row
     */
    public FactorExtraTrees(Array2D array2D, int[] nArray) {
        this(array2D, nArray, null);
    }

    /**
     * Creates a model over the given training data, optionally with a task id per row.
     *
     * @param array2D input matrix, one row per data point
     * @param nArray  factor index of each row; must be non-negative
     * @param nArray2 task id of each row, or {@code null} (passed as-is to {@code setTasks})
     * @throws IllegalArgumentException if the lengths disagree with {@code array2D.nrows()}
     * @throws RuntimeException         if any factor value is negative
     */
    public FactorExtraTrees(Array2D array2D, int[] nArray, int[] nArray2) {
        if (array2D.nrows() != nArray.length) {
            throw new IllegalArgumentException("Input and output do not have same length.");
        }
        if (nArray2 != null && array2D.nrows() != nArray2.length) {
            throw new IllegalArgumentException("Input and tasks do not have the same number of data points.");
        }
        this.setInput(array2D);
        this.output = nArray;
        this.nFactors = 1;
        for (int factor : nArray) {
            if (factor < 0) {
                throw new RuntimeException("Bug: negative output (factor) values.");
            }
            this.nFactors = Math.max(this.nFactors, factor + 1);
        }
        this.setTasks(nArray2);
    }

    public int getnFactors() {
        return this.nFactors;
    }

    public void setnFactors(int n) {
        this.nFactors = n;
    }

    /**
     * Returns a new model holding only the trees whose flag is {@code true}.
     *
     * <p>Only {@code nFactors} and the tree list are carried over. An entry of {@code blArray}
     * past the number of trees makes {@code trees.get} throw when that entry is {@code true}.
     *
     * @param blArray selection flag per tree index
     * @return a model sharing the selected tree instances
     */
    public FactorExtraTrees selectTrees(boolean[] blArray) {
        FactorExtraTrees selected = new FactorExtraTrees(this.nFactors);
        selected.trees = new ArrayList<FactorBinaryTree>();
        for (int i = 0; i < blArray.length; ++i) {
            if (blArray[i]) {
                selected.trees.add(this.trees.get(i));
            }
        }
        return selected;
    }

    @Override
    Aggregator<Integer> getNewAggregator() {
        return new MajorityVote();
    }

    /** Maps a predicted factor to a double; negative values (e.g. -1 from an empty vote) become NaN. */
    @Override
    double convertToDouble(Integer n) {
        return n >= 0 ? (double) n.intValue() : Double.NaN;
    }

    /**
     * Returns the index of the first strictly largest positive entry.
     *
     * @return that index, or -1 if no entry is positive (including an empty array)
     */
    public static int getMaxIndex(int[] nArray) {
        int bestIndex = -1;
        int bestValue = 0;
        for (int i = 0; i < nArray.length; ++i) {
            if (nArray[i] > bestValue) {
                bestValue = nArray[i];
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    /**
     * Returns the index of the first maximal entry. NaN and negative-infinity entries never win.
     *
     * @return that index, or -1 if no entry is greater than negative infinity
     */
    public static int getMaxIndex(double[] dArray) {
        int bestIndex = -1;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < dArray.length; ++i) {
            if (dArray[i] > bestValue) {
                bestValue = dArray[i];
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    /** Unboxes a list of integers into a new array (a null element throws NullPointerException). */
    private static int[] list2array(ArrayList<Integer> arrayList) {
        int[] result = new int[arrayList.size()];
        for (int i = 0; i < result.length; ++i) {
            result[i] = arrayList.get(i);
        }
        return result;
    }

    /** Predicts the factor of every row of {@code array2D}. */
    public int[] getValues(Array2D array2D) {
        return FactorExtraTrees.list2array(this.getValuesD(array2D));
    }

    /** Multitask prediction: predicts the factor of every row of {@code array2D} given its task id. */
    public int[] getValuesMT(Array2D array2D, int[] nArray) {
        return FactorExtraTrees.list2array(this.getValuesMTD(array2D, nArray));
    }

    /**
     * Gini impurity {@code 1 - sum(w_i^2) / (sum w_i)^2} of (weighted) factor totals.
     * Returns NaN when the totals sum to zero (0/0).
     */
    public static double getGiniIndex(double[] dArray) {
        double sumOfSquares = 0.0;
        double total = 0.0;
        for (double w : dArray) {
            sumOfSquares += w * w;
            total += w;
        }
        return 1.0 - sumOfSquares / (total * total);
    }

    /**
     * Gini impurity {@code 1 - sum(c_i^2) / (sum c_i)^2} of integer factor counts.
     * Returns NaN when all counts are zero (0/0).
     */
    public static double getGiniIndex(int[] nArray) {
        long sumOfSquares = 0L;
        long total = 0L;
        for (int count : nArray) {
            // Widen before multiplying: count * count overflows int for count > 46340.
            sumOfSquares += (long) count * count;
            total += count;
        }
        // Square in double so a huge total cannot overflow long; it matches the long square
        // bit-for-bit whenever that square did not overflow.
        return 1.0 - (double) sumOfSquares / ((double) total * (double) total);
    }

    @Override
    protected FactorBinaryTree makeFilledTree(FactorBinaryTree factorBinaryTree, FactorBinaryTree factorBinaryTree2, int n, double d, int n2) {
        FactorBinaryTree node = new FactorBinaryTree();
        node.column = n;
        node.threshold = d;
        node.nSuccessors = n2;
        node.left = factorBinaryTree;
        node.right = factorBinaryTree2;
        return node;
    }

    /**
     * Tries {@code numRandomTaskCuts} random thresholds over the per-task scores and returns
     * the best task cut whose score is strictly below {@code d}.
     *
     * @param nArray row ids in the current node
     * @param set    task ids present in the current node
     * @param d      score to beat
     * @param n      forwarded to {@code getRandom}
     * @return the best improving cut, or {@code null} if none improves, fewer than two factors
     *         exist, or fewer than two tasks have data
     * @throws RuntimeException if there are more than two factors
     */
    @Override
    protected TaskCutResult getTaskCut(int[] nArray, Set<Integer> set, double d, int n) {
        if (this.nFactors > 2) {
            throw new RuntimeException("Multitask learning is not implemented 3 or more factors (classes).");
        }
        // With fewer than two factors the count table has no row for factor 1 (the helpers below
        // would index out of bounds), and a single class cannot be improved by any cut anyway.
        if (this.nFactors < 2 || set.size() <= 1) {
            return null;
        }
        int[][] factorTaskCounts = this.getFactorTaskTable(nArray);
        if (!this.hasAtLeast2Tasks(set, factorTaskCounts)) {
            return null;
        }
        double[] taskScores = this.getTaskScores(factorTaskCounts);
        double[] scoreRange = FactorExtraTrees.getRange(taskScores);
        TaskCutResult best = null;
        double bestScore = d;
        for (int i = 0; i < this.numRandomTaskCuts; ++i) {
            double threshold = this.getRandom(scoreRange[0], scoreRange[1], n);
            TaskCutResult candidate = new TaskCutResult();
            this.calculateTaskCutScore(taskScores, factorTaskCounts, threshold, candidate);
            // A NaN score (e.g. all tasks on one side) never compares as smaller.
            if (candidate.score < bestScore) {
                best = candidate;
                bestScore = candidate.score;
            }
        }
        return best;
    }

    /**
     * Returns true if at least two tasks in {@code set} have a positive count for factor 0 or 1.
     * Requires {@code nArray} to have at least two rows.
     */
    protected boolean hasAtLeast2Tasks(Set<Integer> set, int[][] nArray) {
        boolean seenOne = false;
        for (int task : set) {
            boolean hasData = nArray[0][task] > 0 || nArray[1][task] > 0;
            if (!hasData) {
                continue;
            }
            if (seenOne) {
                return true;
            }
            seenOne = true;
        }
        return false;
    }

    /** Counts rows per (factor, task): {@code table[output[row]][tasks[row]]}. */
    private int[][] getFactorTaskTable(int[] nArray) {
        int[][] table = new int[this.nFactors][this.nTasks];
        for (int row : nArray) {
            table[this.output[row]][this.tasks[row]]++;
        }
        return table;
    }

    /**
     * Per-task score {@code (c0 + prior) / (c0 + c1 + priorWeight)}: the task's fraction of
     * factor 0, shrunk toward {@code prior}, an add-one-smoothed overall fraction of factor 0.
     */
    private double[] getTaskScores(int[][] nArray) {
        final double priorWeight = 1.0;
        double[] totals = FactorExtraTrees.sumAlong2nd(nArray);
        double prior = (totals[0] + 1.0) / (totals[0] + totals[1] + 2.0) * priorWeight;
        double[] scores = new double[this.nTasks];
        for (int task = 0; task < this.nTasks; ++task) {
            scores[task] = ((double) nArray[0][task] + prior)
                    / ((double) (nArray[0][task] + nArray[1][task]) + priorWeight);
        }
        return scores;
    }

    /** Sums rows 0 and 1 of {@code nArray} over all columns of row 0. */
    public static double[] sumAlong2nd(int[][] nArray) {
        double[] sums = new double[2];
        for (int i = 0; i < nArray[0].length; ++i) {
            sums[0] += (double) nArray[0][i];
            sums[1] += (double) nArray[1][i];
        }
        return sums;
    }

    /** Gini impurity of the (optionally weighted) factor totals of the given rows. */
    @Override
    protected double get1NaNScore(int[] nArray) {
        double[] factorTotals = new double[this.nFactors];
        for (int row : nArray) {
            factorTotals[this.output[row]] += this.useWeights ? this.weights[row] : 1.0;
        }
        return FactorExtraTrees.getGiniIndex(factorTotals);
    }

    /**
     * Scores the cut {@code input[row][n] < d} (left) versus the rest (right) as the
     * size-weighted mean Gini impurity of the two sides.
     *
     * <p>In the weighted/NaN path, rows with a NaN input add their weight to
     * {@code cutResult.nanWeigth} and are excluded from both sides. An empty side has NaN
     * impurity, which makes the score NaN (NaN * 0 is NaN).
     */
    @Override
    protected void calculateCutScore(int[] nArray, int n, double d, AbstractTrees.CutResult cutResult) {
        if (!this.useWeights && !this.hasNaN) {
            // Fast path: plain integer factor counts on each side.
            int[][] counts = new int[2][this.nFactors];
            for (int row : nArray) {
                int side = this.input.get(row, n) < d ? 0 : 1;
                counts[side][this.output[row]]++;
            }
            cutResult.countLeft = FactorExtraTrees.sum(counts[0]);
            cutResult.countRight = FactorExtraTrees.sum(counts[1]);
            double giniLeft = FactorExtraTrees.getGiniIndex(counts[0]);
            double giniRight = FactorExtraTrees.getGiniIndex(counts[1]);
            cutResult.score = (giniLeft * (double) cutResult.countLeft + giniRight * (double) cutResult.countRight)
                    / (double) (cutResult.countLeft + cutResult.countRight);
            cutResult.leftConst = giniLeft < PURITY_TOLERANCE;
            cutResult.rightConst = giniRight < PURITY_TOLERANCE;
            return;
        }
        double[][] weightTotals = new double[2][this.nFactors];
        int[] sideCounts = new int[2];
        for (int row : nArray) {
            double value = this.input.get(row, n);
            double weight = this.useWeights ? this.weights[row] : 1.0;
            if (this.hasNaN && Double.isNaN(value)) {
                cutResult.nanWeigth += weight;
                continue;
            }
            int side = value < d ? 0 : 1;
            weightTotals[side][this.output[row]] += weight;
            sideCounts[side]++;
        }
        cutResult.countLeft = sideCounts[0];
        cutResult.countRight = sideCounts[1];
        double giniLeft = FactorExtraTrees.getGiniIndex(weightTotals[0]);
        double giniRight = FactorExtraTrees.getGiniIndex(weightTotals[1]);
        double weightLeft = FactorExtraTrees.sum(weightTotals[0]);
        double weightRight = FactorExtraTrees.sum(weightTotals[1]);
        cutResult.score = (giniLeft * weightLeft + giniRight * weightRight) / (weightLeft + weightRight);
        cutResult.leftConst = giniLeft < PURITY_TOLERANCE;
        cutResult.rightConst = giniRight < PURITY_TOLERANCE;
    }

    /**
     * Sends each task left when its score is below {@code d} (otherwise right), pools the
     * per-factor counts of each side and scores the split by Gini impurity.
     */
    private void calculateTaskCutScore(double[] dArray, int[][] nArray, double d, TaskCutResult taskCutResult) {
        double[] leftTotals = new double[this.nFactors];
        double[] rightTotals = new double[this.nFactors];
        taskCutResult.leftTasks = new HashSet<Integer>();
        taskCutResult.rightTasks = new HashSet<Integer>();
        for (int task = 0; task < nArray[0].length; ++task) {
            boolean goesLeft = dArray[task] < d;
            double[] totals = goesLeft ? leftTotals : rightTotals;
            for (int factor = 0; factor < this.nFactors; ++factor) {
                totals[factor] += (double) nArray[factor][task];
            }
            if (goesLeft) {
                taskCutResult.leftTasks.add(task);
            } else {
                taskCutResult.rightTasks.add(task);
            }
        }
        taskCutResult.countLeft = (int) FactorExtraTrees.sum(leftTotals);
        taskCutResult.countRight = (int) FactorExtraTrees.sum(rightTotals);
        this.cutResultFromCounts(taskCutResult, leftTotals, rightTotals);
    }

    /** Fills score and constness from per-factor totals; the counts must already be set. */
    private void cutResultFromCounts(AbstractTrees.CutResult cutResult, double[] dArray, double[] dArray2) {
        double giniLeft = FactorExtraTrees.getGiniIndex(dArray);
        double giniRight = FactorExtraTrees.getGiniIndex(dArray2);
        cutResult.score = (giniLeft * (double) cutResult.countLeft + giniRight * (double) cutResult.countRight)
                / (double) (cutResult.countLeft + cutResult.countRight);
        cutResult.leftConst = giniLeft < PURITY_TOLERANCE;
        cutResult.rightConst = giniRight < PURITY_TOLERANCE;
    }

    /**
     * Builds a leaf predicting the (weighted) majority factor of the given rows.
     *
     * <p>Without weights, an empty row set yields value -1. With weights, the first maximal
     * weighted total wins.
     */
    @Override
    public FactorBinaryTree makeLeaf(int[] nArray, Set<Integer> set) {
        FactorBinaryTree leaf = new FactorBinaryTree();
        leaf.nSuccessors = nArray.length;
        leaf.tasks = set;
        if (!this.useWeights) {
            int[] counts = new int[this.nFactors];
            for (int row : nArray) {
                counts[this.output[row]]++;
            }
            leaf.value = FactorExtraTrees.getMaxIndex(counts);
        } else {
            double[] totals = new double[this.nFactors];
            for (int row : nArray) {
                totals[this.output[row]] += this.weights[row];
            }
            leaf.value = FactorExtraTrees.getMaxIndex(totals);
        }
        return leaf;
    }

    /**
     * Counts leaf predictions and returns the most frequent factor (first on ties), or -1 if
     * no positive count was recorded.
     */
    public class MajorityVote
    implements Aggregator<Integer> {
        int[] counts;

        public MajorityVote() {
            this.counts = new int[FactorExtraTrees.this.nFactors];
        }

        @Override
        public void processLeaf(Integer n) {
            this.counts[n]++;
        }

        @Override
        public Integer getResult() {
            return FactorExtraTrees.getMaxIndex(this.counts);
        }
    }
}

