/*-------------------------------------------------------------------------------
 This file is part of unityForest.

 Copyright (c) [2014-2018] [Marvin N. Wright]
 Modifications and extensions by Roman Hornung

 This software may be modified and distributed under the terms of the MIT license.

 Please note that the C++ core of divfor is distributed under MIT license and the
 R package "unityForest" under GPL3 license.
 #-------------------------------------------------------------------------------*/

#ifndef FORESTCLASSIFICATION_H_
#define FORESTCLASSIFICATION_H_

#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "globals.h"
#include "Forest.h"
#include "TreeClassification.h"

namespace unityForest
{

  class ForestClassification : public Forest
  {
  public:
    ForestClassification() = default;

    ForestClassification(const ForestClassification &) = delete;
    ForestClassification &operator=(const ForestClassification &) = delete;

    virtual ~ForestClassification() override = default;

    void loadForest(size_t dependent_varID, size_t num_trees,
                    std::vector<std::vector<std::vector<size_t>>> &forest_child_nodeIDs,
                    std::vector<std::vector<size_t>> &forest_split_varIDs, std::vector<std::vector<double>> &forest_split_values,
                    std::vector<double> &class_values, std::vector<bool> &is_ordered_variable);

    void loadForestRepr(size_t dependent_varID, size_t num_trees,
                        std::vector<std::vector<std::vector<size_t>>> &forest_child_nodeIDs,
                        std::vector<std::vector<size_t>> &forest_split_varIDs, std::vector<std::vector<double>> &forest_split_values,
                        std::vector<double> &class_values, std::vector<double> &class_weights, std::vector<std::vector<size_t>> &forest_nodeID_in_root,
                        std::vector<std::vector<size_t>> &forest_inbag_counts,
                        std::vector<bool> &is_ordered_variable);

    const std::vector<double> &getClassValues() const
    {
      return class_values;
    }

    const std::vector<double> &getClassWeights() const
    {
      return class_weights;
    }

    void setClassWeights(std::vector<double> &class_weights)
    {
      this->class_weights = class_weights;
    }

  protected:
    void initInternal(std::string status_variable_name) override;
    void growInternal() override;
    void allocatePredictMemory() override;
    void predictInternal(size_t sample_idx) override;
    void computePredictionErrorInternal() override;

    // Classes of the dependent variable and classIDs for responses
    std::vector<double> class_values;
    std::vector<uint> response_classIDs;
    std::vector<std::vector<size_t>> sampleIDs_per_class;

    // Splitting weights
    std::vector<double> class_weights;

    // Table with classifications and true classes
    std::map<std::pair<double, double>, size_t> classification_table;

  private:
    double getTreePrediction(size_t tree_idx, size_t sample_idx) const;
    size_t getTreePredictionTerminalNodeID(size_t tree_idx, size_t sample_idx) const;
  };

} // namespace unityForest

#endif /* FORESTCLASSIFICATION_H_ */
