/*-------------------------------------------------------------------------------
  Copyright (c) 2024 GRF Contributors.

  This file is part of generalized random forest (grf).

  grf is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  (at your option) any later version.

  grf is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with grf. If not, see <http://www.gnu.org/licenses/>.
 #-------------------------------------------------------------------------------*/

#include "TreeTraverser.h"
#include "commons/utility.h"

#include <future>
#include <iterator>
#include <thread>

namespace grf {

// Stores the thread count; get_leaf_nodes passes it to split_sequence to
// decide how the forest's trees are divided into batches.
TreeTraverser::TreeTraverser(uint num_threads) : num_threads(num_threads) {}

/**
 * Computes the leaf-node assignments for every tree in the forest.
 *
 * The tree indices are split into batches with split_sequence (using the
 * configured num_threads), and each batch is handed to get_leaf_node_batch
 * through std::async. While the batches run, this thread repeatedly calls the
 * runtime interrupt handler and refreshes the progress bar.
 *
 * @param forest         Forest whose trees are traversed.
 * @param data           Data passed on to each tree's find_leaf_nodes.
 * @param oob_prediction Forwarded to get_leaf_node_batch; when true, samples
 *                       drawn by a tree are marked invalid for that tree.
 * @return One entry per tree (the result of Tree::find_leaf_nodes),
 *         concatenated in the order of the batches. Empty for an empty forest.
 *
 * If the interrupt handler or the progress update throws, the workers are
 * signalled to stop, all futures are drained, and the exception is rethrown.
 */
std::vector<std::vector<size_t>> TreeTraverser::get_leaf_nodes(
    const Forest& forest,
    const Data& data,
    bool oob_prediction) const {
  size_t num_trees = forest.get_trees().size();

  std::vector<std::vector<size_t>> leaf_nodes_by_tree;
  if (num_trees == 0) {
    // Nothing to traverse. Returning here also keeps `num_trees - 1` below
    // from wrapping around to the largest unsigned value.
    return leaf_nodes_by_tree;
  }
  leaf_nodes_by_tree.reserve(num_trees);

  // Declared before `futures` so that both outlive the worker tasks, which
  // hold references to them.
  std::atomic<bool> user_interrupt_flag {false};
  ProgressBar progress_bar(num_trees, "prediction [traversal]: ");

  std::vector<uint> thread_ranges;
  split_sequence(thread_ranges, 0, static_cast<uint>(num_trees - 1), num_threads);

  std::vector<std::future<
      std::vector<std::vector<size_t>>>> futures;
  futures.reserve(thread_ranges.size());

  // Consecutive entries of thread_ranges delimit one batch each. Written as
  // `i + 1 < size()` so an empty range vector cannot underflow the bound.
  for (size_t i = 0; i + 1 < thread_ranges.size(); ++i) {
    size_t start_index = thread_ranges[i];
    size_t num_trees_batch = thread_ranges[i + 1] - start_index;
    futures.push_back(std::async(std::launch::async,
                                 &TreeTraverser::get_leaf_node_batch,
                                 this,
                                 start_index,
                                 num_trees_batch,
                                 std::ref(forest),
                                 std::ref(data),
                                 oob_prediction,
                                 std::ref(progress_bar),
                                 std::ref(user_interrupt_flag)));
  }

  // Periodically check for user interrupts + update progress bar while threads are working.
  bool working = true;
  while (working) {
    try {
      grf::runtime_context.interrupt_handler();
      progress_bar.update();
    } catch (...) {
      // Tell the workers to bail out at their next per-tree check.
      user_interrupt_flag = true;
      // Drain every future before rethrowing; worker exceptions are
      // deliberately ignored because the original exception takes precedence.
      for (auto& future : futures) {
        if (future.valid()) {
          try { future.get(); } catch (...) {}
        }
      }
      throw;
    }
    // Check if we can stop working
    working = false;
    for (const auto& future : futures) {
      if (future.wait_for(std::chrono::milliseconds(0)) != std::future_status::ready) {
        working = true;
        break;
      }
    }
    if (working) {
      std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
  }

  // Collect the final results, moving each per-tree vector instead of copying it.
  for (auto& future : futures) {
    std::vector<std::vector<size_t>> leaf_nodes = future.get();
    leaf_nodes_by_tree.insert(leaf_nodes_by_tree.end(),
                              std::make_move_iterator(leaf_nodes.begin()),
                              std::make_move_iterator(leaf_nodes.end()));
  }
  progress_bar.final_update();

  return leaf_nodes_by_tree;
}

/**
 * Builds a samples-by-trees mask of which trees may be used for each sample.
 *
 * @param forest         Forest supplying the trees and their drawn samples.
 * @param data           Data whose row count sets the number of samples.
 * @param oob_prediction When false, every entry is true. When true, entry
 *                       [sample][tree] is false for each sample listed in
 *                       that tree's get_drawn_samples().
 * @return A vector with one row per sample and one flag per tree.
 */
std::vector<std::vector<bool>> TreeTraverser::get_valid_trees_by_sample(const Forest& forest,
                                                                        const Data& data,
                                                                        bool oob_prediction) const {
  const auto& trees = forest.get_trees();
  const size_t tree_count = trees.size();
  const size_t row_count = data.get_num_rows();

  // Start with every tree valid for every sample.
  std::vector<std::vector<bool>> validity(row_count, std::vector<bool>(tree_count, true));
  if (!oob_prediction) {
    return validity;
  }

  // Out-of-bag mode: clear the flag for each sample a tree has drawn.
  for (size_t t = 0; t < tree_count; ++t) {
    for (size_t drawn : trees[t]->get_drawn_samples()) {
      validity[drawn][t] = false;
    }
  }
  return validity;
}

/**
 * Computes leaf-node assignments for the trees with indices
 * [start, start + num_trees) of the forest.
 *
 * @param start               Index of the first tree in the batch.
 * @param num_trees           Number of trees in the batch.
 * @param forest              Forest supplying the trees.
 * @param data                Data passed to each tree's find_leaf_nodes.
 * @param oob_prediction      Forwarded to get_valid_samples for each tree.
 * @param progress_bar        Incremented by one after each processed tree.
 * @param user_interrupt_flag Checked before each tree; when set, the batch
 *                            stops early.
 * @return One entry per tree in the batch, in batch order. If the interrupt
 *         flag is observed, an empty vector is returned and any partial
 *         results are discarded.
 */
std::vector<std::vector<size_t>> TreeTraverser::get_leaf_node_batch(
    size_t start,
    size_t num_trees,
    const Forest& forest,
    const Data& data,
    bool oob_prediction,
    ProgressBar& progress_bar,
    std::atomic<bool>& user_interrupt_flag) const {

  size_t num_samples = data.get_num_rows();
  std::vector<std::vector<size_t>> all_leaf_nodes(num_trees);

  for (size_t i = 0; i < num_trees; ++i) {
    if (user_interrupt_flag) {
      return std::vector<std::vector<size_t>>();
    }
    const std::unique_ptr<Tree>& tree = forest.get_trees()[start + i];

    std::vector<bool> valid_samples = get_valid_samples(num_samples, tree, oob_prediction);
    // Assign the call result directly rather than through a named local, so
    // the per-tree vector is not copied a second time.
    all_leaf_nodes[i] = tree->find_leaf_nodes(data, valid_samples);
    progress_bar.increment(1);
  }

  return all_leaf_nodes;
}

/**
 * Builds the per-sample validity mask for a single tree.
 *
 * @param num_samples    Length of the returned mask.
 * @param tree           Tree whose drawn samples are excluded in OOB mode.
 * @param oob_prediction When false, every entry is true. When true, entries
 *                       at the indices in tree->get_drawn_samples() are false.
 * @return A mask of num_samples flags.
 */
std::vector<bool> TreeTraverser::get_valid_samples(size_t num_samples,
                                                   const std::unique_ptr<Tree>& tree,
                                                   bool oob_prediction) const {
  std::vector<bool> mask(num_samples, true);
  if (!oob_prediction) {
    return mask;
  }

  // Out-of-bag mode: a sample the tree has drawn is not valid for it.
  for (size_t drawn : tree->get_drawn_samples()) {
    mask[drawn] = false;
  }
  return mask;
}

} // namespace grf
