/*
 * Decompiled with CFR 0.152.
 */
package keel.Algorithms.Decision_Trees.CART;

import java.util.ArrayList;
import java.util.Arrays;
import keel.Algorithms.Decision_Trees.CART.impurities.IImpurityFunction;
import keel.Algorithms.Decision_Trees.CART.tree.DecisionTree;
import keel.Algorithms.Decision_Trees.CART.tree.TreeNode;
import keel.Algorithms.Neural_Networks.NNEP_Common.data.DoubleTransposedDataSet;

/**
 * CART-style binary decision tree learner for classification or regression.
 *
 * <p>The tree is grown greedily from a root that holds every observation of the training
 * {@link DoubleTransposedDataSet}. Each internal node tests one input variable against a
 * threshold: patterns whose value is {@code <= threshold} go to the left son, the rest to the
 * right son. Candidate thresholds are midpoints between consecutive sorted values of the
 * variable within the node. The candidate with the largest impurity decrease, as scored by the
 * configured {@link IImpurityFunction}, is chosen.
 *
 * <p>Every node is labelled with either the majority class (classification) or the mean of the
 * first output (regression). {@link #prune_tree()} is currently a no-op.
 */
public class CART {
    /** Tree produced by the last {@link #build_tree()} call; {@code null} before that. */
    private DecisionTree tree;
    /** Depth limit used by the stopping rule; defaults to 0 unless set. */
    private int maxDepth;
    /** {@code true} to predict a numeric mean, {@code false} to predict a class index. */
    private boolean regression;
    /** Impurity measure used to score splits; required whenever a node is split. */
    private IImpurityFunction impurityFunction;
    /** Training data; inputs are read as {@code getAllInputs()[variable][pattern]}. */
    private DoubleTransposedDataSet dataset;

    /**
     * Creates a learner for {@code dataset} without an impurity function. One must be supplied
     * through {@link #setImpurityFunction} before any node can be split.
     *
     * @param dataset training data
     */
    public CART(DoubleTransposedDataSet dataset) {
        this.dataset = dataset;
    }

    /**
     * Creates a learner for {@code dataset} and binds {@code impurityFunction} to it.
     *
     * @param dataset training data
     * @param impurityFunction impurity measure; its dataset is set to {@code dataset}
     */
    public CART(DoubleTransposedDataSet dataset, IImpurityFunction impurityFunction) {
        this.dataset = dataset;
        this.impurityFunction = impurityFunction;
        this.impurityFunction.setDataset(dataset);
    }

    /** @return the tree built by the last {@link #build_tree()} call, or {@code null} */
    public DecisionTree getTree() {
        return this.tree;
    }

    /** @return the impurity measure used to score splits */
    public IImpurityFunction getImpurityFunction() {
        return this.impurityFunction;
    }

    /**
     * Replaces the impurity measure and binds it to this learner's training dataset.
     *
     * @param impurityFunction new impurity measure (must not be {@code null})
     */
    public void setImpurityFunction(IImpurityFunction impurityFunction) {
        this.impurityFunction = impurityFunction;
        this.impurityFunction.setDataset(this.dataset);
    }

    /** @return the depth limit used by the stopping rule */
    public int getMaxDepth() {
        return this.maxDepth;
    }

    /** @param maxDepth depth limit used by the stopping rule */
    public void setMaxDepth(int maxDepth) {
        this.maxDepth = maxDepth;
    }

    /** @return {@code true} if the tree predicts numeric values instead of classes */
    public boolean isRegression() {
        return this.regression;
    }

    /** @param regression {@code true} to predict numeric values, {@code false} for classes */
    public void setRegression(boolean regression) {
        this.regression = regression;
    }

    /**
     * Computes candidate thresholds for every input variable over the given patterns.
     *
     * @param patterns indices of the observations in the node (at least two)
     * @return {@code [variable][k]}: the midpoint between the k-th and (k+1)-th smallest values
     *     of that variable among {@code patterns}; each row has {@code patterns.length - 1}
     *     entries, including repeats when values are duplicated
     */
    private double[][] splittingValues(int[] patterns) {
        int ninputs = this.dataset.getNofinputs();
        int npatterns = patterns.length;
        double[][] splittingValues = new double[ninputs][npatterns - 1];
        for (int j = 0; j < ninputs; ++j) {
            double[] aux = this.dataset.getObservationsOf(j);
            double[] sorted = new double[npatterns];
            for (int i = 0; i < npatterns; ++i) {
                sorted[i] = aux[patterns[i]];
            }
            Arrays.sort(sorted);
            for (int i = 0; i < npatterns - 1; ++i) {
                splittingValues[j][i] = (sorted[i] + sorted[i + 1]) / 2.0;
            }
        }
        return splittingValues;
    }

    /**
     * Grows a new tree from scratch over every observation of the training dataset. Any
     * previously built tree is discarded.
     */
    public void build_tree() {
        this.tree = new DecisionTree();
        int[] patterns = new int[this.dataset.getNofobservations()];
        for (int i = 0; i < patterns.length; ++i) {
            patterns[i] = i;
        }
        TreeNode root = new TreeNode(null, patterns);
        this.tree.setRoot(root);
        // Depth of a tree holding only the root; each level below it adds one.
        // NOTE(review): assumes DecisionTree.depth() grows by exactly one per level - confirm.
        this.grow(root, this.tree.depth());
    }

    /**
     * Recursively expands {@code node}. The node is split unless a stopping rule fires or no
     * split separates its patterns. It is then labelled, and its sons (if any) are grown.
     *
     * @param node node to expand; {@code null} is ignored
     * @param depth depth the whole tree would have if {@code node} were its deepest node,
     *     on the same scale as {@code DecisionTree.depth()}
     */
    private void grow(TreeNode node, int depth) {
        if (node == null) {
            return;
        }
        // Use the node's own depth rather than the tree's current depth. Growth is depth-first,
        // so once the leftmost branch reaches maxDepth the whole-tree depth would turn every
        // node still waiting to be expanded into a leaf.
        boolean split = !this.stopCriteria(node, depth) && this.splitNode(node);
        this.assignOutput(node);
        if (split) {
            this.grow(node.getLeftSon(), depth + 1);
            this.grow(node.getRightSon(), depth + 1);
        }
    }

    /** Labels {@code node} with the mean output (regression) or majority class (otherwise). */
    private void assignOutput(TreeNode node) {
        if (this.regression) {
            this.assignMean(node);
        } else {
            this.assignClass(node);
        }
    }

    /**
     * Chooses the best (variable, threshold) split for {@code node} and attaches two sons.
     * Candidates that leave one side empty, or whose gain is NaN, are ignored. On equal gains
     * the candidate examined last wins.
     *
     * @param node node with at least two patterns
     * @return {@code true} if the node was split; {@code false} if no candidate separates its
     *     patterns, in which case the node is left without sons
     */
    private boolean splitNode(TreeNode node) {
        int[] patterns = node.getPatterns();
        try {
            node.setImpurities(this.impurityFunction.impurities(patterns, 1.0));
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        double[][] candidates = this.splittingValues(patterns);
        double[][] inputs = this.dataset.getAllInputs();
        boolean found = false;
        double bestGain = 0.0;
        int bestVariable = 0;
        double bestValue = 0.0;
        int[][] bestSides = null;
        for (int j = 0; j < candidates.length; ++j) {
            double[] column = inputs[j];
            for (int i = 0; i < candidates[j].length; ++i) {
                double limitValue = candidates[j][i];
                int[][] sides = partition(patterns, column, limitValue);
                // A threshold equal to the largest value sends every pattern left. Such a split
                // separates nothing and would create an empty son.
                if (sides[0].length == 0 || sides[1].length == 0) {
                    continue;
                }
                double gain = this.computeImpuritiesGain(node, sides[0], sides[1]);
                if (Double.isNaN(gain) || (found && gain < bestGain)) {
                    continue;
                }
                // Reached when gain >= bestGain, so later candidates win ties.
                found = true;
                bestGain = gain;
                bestVariable = j;
                bestValue = limitValue;
                bestSides = sides;
            }
        }
        if (!found) {
            return false;
        }
        node.setVariable(bestVariable);
        node.setValue(bestValue);
        TreeNode leftSon = new TreeNode(node, bestSides[0]);
        TreeNode rightSon = new TreeNode(node, bestSides[1]);
        node.setLeftSon(leftSon);
        node.setRightSon(rightSon);
        return true;
    }

    /**
     * Labels {@code node} with the class that has the most patterns. A pattern counts for
     * class {@code c} when {@code getAllOutputs()[c][pattern] == 1.0} (presumably one-hot
     * outputs). Ties go to the lowest class index.
     */
    private void assignClass(TreeNode node) {
        int[] patterns = node.getPatterns();
        double[][] outputs = this.dataset.getAllOutputs();
        int[] counts = new int[outputs.length];
        for (int c = 0; c < outputs.length; ++c) {
            for (int p : patterns) {
                if (outputs[c][p] == 1.0) {
                    ++counts[c];
                }
            }
        }
        int majorityClass = 0;
        for (int c = 1; c < counts.length; ++c) {
            if (counts[c] > counts[majorityClass]) {
                majorityClass = c;
            }
        }
        node.setOutputClass(majorityClass);
    }

    /**
     * Labels {@code node} with the mean of output 0 over its patterns. A node without
     * patterns gets NaN (0/0).
     */
    private void assignMean(TreeNode node) {
        int[] patterns = node.getPatterns();
        double[] outputs = this.dataset.getOutput(0);
        double sum = 0.0;
        for (int p : patterns) {
            sum += outputs[p];
        }
        node.setOutputValue(sum / (double) patterns.length);
    }

    /**
     * Splits {@code patterns} by comparing each pattern's value in {@code column} with
     * {@code limitValue}, preserving the original order on each side.
     *
     * @return {@code {left, right}} where left holds the patterns with value
     *     {@code <= limitValue}; NaN values go right
     */
    private static int[][] partition(int[] patterns, double[] column, double limitValue) {
        int[] left = new int[patterns.length];
        int[] right = new int[patterns.length];
        int nLeft = 0;
        int nRight = 0;
        for (int p : patterns) {
            if (column[p] <= limitValue) {
                left[nLeft++] = p;
            } else {
                right[nRight++] = p;
            }
        }
        return new int[][] {Arrays.copyOf(left, nLeft), Arrays.copyOf(right, nRight)};
    }

    /**
     * Computes the impurity decrease of splitting {@code node} into the given sides:
     * {@code parent - P_l * impurity(left) - P_r * impurity(right)}, with {@code P_l} and
     * {@code P_r} the fractions of patterns on each side. If the impurity function throws,
     * the stack trace is printed and the values not yet computed stay 0.
     */
    private double computeImpuritiesGain(TreeNode node, int[] leftPatterns, int[] rightPatterns) {
        double parentImpurities = 0.0;
        double leftImpurities = 0.0;
        double rightImpurities = 0.0;
        try {
            leftImpurities = this.impurityFunction.impurities(leftPatterns, 1.0);
            rightImpurities = this.impurityFunction.impurities(rightPatterns, 1.0);
            parentImpurities = node.getImpurities();
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        double total = leftPatterns.length + rightPatterns.length;
        double pLeft = leftPatterns.length / total;
        double pRight = rightPatterns.length / total;
        return parentImpurities - pLeft * leftImpurities - pRight * rightImpurities;
    }

    /** Pruning is not implemented; this method does nothing. */
    public void prune_tree() {
    }

    /**
     * Returns whether {@code node} should become a leaf. This overload uses the current depth
     * of the whole tree as the depth measure. Tree growth itself uses each node's own depth.
     *
     * @param node candidate node
     * @return {@code true} if the node has fewer than two patterns, the tree depth has reached
     *     {@code maxDepth}, or all of its patterns share the same output vector
     */
    public boolean stopCriteria(TreeNode node) {
        if (node.getPatterns().length < 2) {
            return true;
        }
        return this.stopCriteria(node, this.tree.depth());
    }

    /**
     * Stopping rule evaluated against an explicit depth.
     *
     * @param node candidate node
     * @param depth depth measure compared against {@code maxDepth}
     * @return {@code true} if the node has fewer than two patterns, {@code depth >= maxDepth},
     *     or all of its patterns share the same output vector
     */
    private boolean stopCriteria(TreeNode node, int depth) {
        int[] patterns = node.getPatterns();
        if (patterns.length < 2) {
            return true;
        }
        if (depth >= this.maxDepth) {
            return true;
        }
        // Arrays.equals on double[] is an equivalence relation, so comparing every pattern
        // with the first one is the same as comparing neighbours.
        double[] first = this.dataset.getOutputs(patterns[0]);
        for (int k = 1; k < patterns.length; ++k) {
            if (!Arrays.equals(first, this.dataset.getOutputs(patterns[k]))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Classifies every observation of {@code dataset} with the built tree.
     *
     * @param dataset observations to classify
     * @return {@code predicted[class][pattern]}: 1 for the predicted class, 0 elsewhere
     */
    public byte[][] getClassificationResults(DoubleTransposedDataSet dataset) {
        double[][] inputs = this.transposedMatrix(dataset.getAllInputs());
        int noutputs = dataset.getNofoutputs();
        int npatterns = dataset.getNofobservations();
        // New Java arrays are zero-filled, so only the predicted entry needs to be set.
        byte[][] predicted = new byte[noutputs][npatterns];
        TreeNode root = this.tree.getRoot();
        for (int i = 0; i < npatterns; ++i) {
            int predictedClass = (int) root.evaluate(inputs[i], this.regression);
            predicted[predictedClass][i] = 1;
        }
        return predicted;
    }

    /**
     * Predicts a numeric value for every observation of {@code dataset} with the built tree.
     *
     * @param dataset observations to evaluate
     * @return one prediction per observation, in order
     */
    public double[] getRegressionResults(DoubleTransposedDataSet dataset) {
        double[][] inputs = this.transposedMatrix(dataset.getAllInputs());
        int npatterns = dataset.getNofobservations();
        double[] predicted = new double[npatterns];
        TreeNode root = this.tree.getRoot();
        for (int i = 0; i < npatterns; ++i) {
            predicted[i] = root.evaluate(inputs[i], this.regression);
        }
        return predicted;
    }

    /**
     * Returns the transpose of rectangular matrix {@code a}. Turns {@code [variable][pattern]}
     * into {@code [pattern][variable]}. An empty matrix yields an empty matrix.
     */
    private double[][] transposedMatrix(double[][] a) {
        int rows = a.length;
        if (rows == 0) {
            return new double[0][0];
        }
        int cols = a[0].length;
        double[][] b = new double[cols][rows];
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                b[j][i] = a[i][j];
            }
        }
        return b;
    }
}

