/*-------------------------------------------------------------------------------
 This file is part of unityForest.

 Copyright (c) [2014-2018] [Marvin N. Wright]
 Modifications and extensions by Roman Hornung

 This software may be modified and distributed under the terms of the MIT license.

 Please note that the C++ core of divfor is distributed under MIT license and the
 R package "unityForest" under GPL3 license.
 #-------------------------------------------------------------------------------*/

#include <algorithm>
#include <cmath>
#include <stdexcept>

#include "utility.h"
#include "ForestProbability.h"
#include "TreeProbability.h"
#include "Data.h"

namespace unityForest
{

  /**
   * Restore a probability forest from previously saved per-tree structures.
   *
   * Copies the dependent variable ID, the tree count and the class values into
   * the forest, forwards the ordered-variable flags to the data object, builds
   * one TreeProbability per saved tree and finally asks equalSplit() to fill
   * thread_ranges for the tree indices 0 .. num_trees - 1 and num_threads.
   *
   * @param dependent_varID               Stored as this->dependent_varID.
   * @param num_trees                     Number of trees to create; every
   *                                      forest_* vector is indexed 0 .. num_trees - 1
   *                                      without bounds checking.
   * @param forest_child_nodeIDs          Per-tree argument passed to the tree constructor.
   * @param forest_split_varIDs           Per-tree argument passed to the tree constructor.
   * @param forest_split_values           Per-tree argument passed to the tree constructor.
   * @param class_values                  Copied into this->class_values; the trees
   *                                      receive a pointer to that member copy.
   * @param forest_terminal_class_counts  Per-tree argument passed to the tree constructor.
   * @param is_ordered_variable           Forwarded to data->setIsOrderedVariable().
   *
   * NOTE(review): num_trees is unsigned, so `num_trees - 1` wraps around when
   * num_trees == 0 - confirm callers never load an empty forest.
   */
  void ForestProbability::loadForest(size_t dependent_varID, size_t num_trees,
                                     std::vector<std::vector<std::vector<size_t>>> &forest_child_nodeIDs,
                                     std::vector<std::vector<size_t>> &forest_split_varIDs, std::vector<std::vector<double>> &forest_split_values,
                                     std::vector<double> &class_values, std::vector<std::vector<std::vector<double>>> &forest_terminal_class_counts,
                                     std::vector<bool> &is_ordered_variable)
  {
    // Forest-level state first; the trees below point into these members.
    this->dependent_varID = dependent_varID;
    this->num_trees = num_trees;
    this->class_values = class_values;
    data->setIsOrderedVariable(is_ordered_variable);

    // Re-create every tree from its saved pieces.
    trees.reserve(num_trees);
    for (size_t tree_idx = 0; tree_idx != num_trees; ++tree_idx)
    {
      auto &child_nodeIDs = forest_child_nodeIDs[tree_idx];
      auto &split_varIDs = forest_split_varIDs[tree_idx];
      auto &split_values = forest_split_values[tree_idx];
      auto &terminal_class_counts = forest_terminal_class_counts[tree_idx];

      trees.emplace_back(
          std::make_unique<TreeProbability>(child_nodeIDs, split_varIDs, split_values,
                                            &this->class_values, &response_classIDs, terminal_class_counts));
    }

    // Distribute the tree indices over the worker threads.
    const size_t last_tree_idx = num_trees - 1;
    equalSplit(thread_ranges, 0, last_tree_idx, num_threads);
  }

  /**
   * Load a saved forest for the CRTR analysis.
   *
   * Copies the forest-level settings into the members, rebuilds
   * response_classIDs from the dependent-variable column of the attached data,
   * creates one TreeProbability per saved tree and finally asks equalSplit()
   * to fill thread_ranges for the tree indices 0 .. num_trees - 1.
   *
   * Class IDs are looked up in - and unseen values appended to - the member
   * this->class_values, because the trees receive a pointer to that member.
   * The caller's class_values argument is only read.
   *
   * @param dependent_varID               Stored as this->dependent_varID; also the
   *                                      column read from the data for each row.
   * @param num_trees                     Number of trees to create; every
   *                                      forest_* vector is indexed 0 .. num_trees - 1
   *                                      without bounds checking.
   * @param forest_child_nodeIDs          Per-tree argument passed to the tree constructor.
   * @param forest_split_varIDs           Per-tree argument passed to the tree constructor.
   * @param forest_split_values           Per-tree argument passed to the tree constructor.
   * @param class_values                  Copied into this->class_values.
   * @param class_weights                 Copied into this->class_weights.
   * @param forest_terminal_class_counts  Per-tree argument passed to the tree constructor.
   * @param forest_nodeID_in_root         Per-tree argument passed to the tree constructor.
   * @param forest_inbag_counts           Per-tree argument passed to the tree constructor.
   * @param is_ordered_variable           Forwarded to data->setIsOrderedVariable().
   *
   * NOTE(review): num_trees is unsigned, so `num_trees - 1` wraps around when
   * num_trees == 0 - confirm callers never load an empty forest.
   */
  void ForestProbability::loadForestRepr(size_t dependent_varID, size_t num_trees,
                                         std::vector<std::vector<std::vector<size_t>>> &forest_child_nodeIDs,
                                         std::vector<std::vector<size_t>> &forest_split_varIDs, std::vector<std::vector<double>> &forest_split_values,
                                         std::vector<double> &class_values, std::vector<double> &class_weights, std::vector<std::vector<std::vector<double>>> &forest_terminal_class_counts, std::vector<std::vector<size_t>> &forest_nodeID_in_root,
                                         std::vector<std::vector<size_t>> &forest_inbag_counts,
                                         std::vector<bool> &is_ordered_variable)
  {

    this->dependent_varID = dependent_varID;
    this->num_trees = num_trees;
    this->class_values = class_values;
    this->class_weights = class_weights;
    data->setIsOrderedVariable(is_ordered_variable);

    /* build response_classIDs exactly like initInternal ------------------- */
    // The parameter `class_values` shadows the member of the same name. Work on
    // the member explicitly so that a value missing from the saved class list
    // is appended to the vector the trees actually point at.
    // NOTE(review): class_weights is not extended when a new class value is
    // appended here - confirm whether unseen classes can occur in this path.
    std::vector<double> &known_classes = this->class_values;
    const size_t num_rows = data->getNumRows();

    response_classIDs.clear();
    response_classIDs.reserve(num_rows);

    for (size_t row = 0; row < num_rows; ++row)
    {
      const double y = data->get(row, dependent_varID);

      // Position of y in the class list; equals size() when y is not present.
      const uint classID = static_cast<uint>(
          std::find(known_classes.begin(), known_classes.end(), y) - known_classes.begin());

      // Unknown value: append it, so classID becomes the index of the new entry.
      if (classID == known_classes.size())
      {
        known_classes.push_back(y);
      }

      response_classIDs.push_back(classID);
    }
    /* -------------------------------------------------------------------- */

    // Create trees
    trees.reserve(num_trees);
    for (size_t i = 0; i < num_trees; ++i)
    {
      trees.push_back(
          std::make_unique<TreeProbability>(forest_child_nodeIDs[i], forest_split_varIDs[i], forest_split_values[i],
                                            &this->class_values, &this->class_weights, &response_classIDs, forest_terminal_class_counts[i], forest_nodeID_in_root[i], forest_inbag_counts[i], this->repr_vars, data.get()));
    }

    // Create thread ranges
    equalSplit(thread_ranges, 0, num_trees - 1, num_threads);
  }

  std::vector<std::vector<std::vector<double>>> ForestProbability::getTerminalClassCounts() const
  {
    std::vector<std::vector<std::vector<double>>> result;
    result.reserve(num_trees);
    for (const auto &tree : trees)
    {
      const auto &temp = dynamic_cast<const TreeProbability &>(*tree);
      result.push_back(temp.getTerminalClassCounts());
    }
    return result;
  }

  /**
   * Forest-type specific initialisation for probability forests.
   *
   * - Chooses default values for mtry and min_node_size when they are 0.
   * - Outside prediction mode, derives class_values and response_classIDs from
   *   the dependent-variable column of the data.
   * - Builds sampleIDs_per_class when more than one sample fraction is given.
   * - Resets all class weights to 1.
   * - Pre-sorts the data unless memory-saving splitting is enabled.
   *
   * @param status_variable_name  Not used by this implementation.
   */
  void ForestProbability::initInternal(std::string status_variable_name)
  {

    // If mtry not set, use floored square root of number of independent variables.
    if (mtry == 0)
    {
      const unsigned long floored_sqrt =
          static_cast<unsigned long>(std::sqrt(static_cast<double>(num_variables - 1)));
      mtry = std::max<unsigned long>(1, floored_sqrt);
    }

    // Set minimal node size
    if (min_node_size == 0)
    {
      min_node_size = DEFAULT_MIN_NODE_SIZE_PROBABILITY;
    }

    // Create class_values and response_classIDs
    if (!prediction_mode)
    {
      response_classIDs.reserve(response_classIDs.size() + num_samples);
      for (size_t i = 0; i < num_samples; ++i)
      {
        const double value = data->get(i, dependent_varID);

        // Index of the value in class_values; equals size() if it is not there yet.
        const uint classID = static_cast<uint>(
            std::find(class_values.begin(), class_values.end(), value) - class_values.begin());

        // First occurrence of this value: register it as a new class.
        if (classID == class_values.size())
        {
          class_values.push_back(value);
        }
        response_classIDs.push_back(classID);
      }
    }

    // Create sampleIDs_per_class if required
    // NOTE(review): indexes sampleIDs_per_class by class ID, so this presumably
    // relies on sample_fraction having one entry per class - verify upstream.
    if (sample_fraction.size() > 1)
    {
      sampleIDs_per_class.resize(sample_fraction.size());
      for (auto &v : sampleIDs_per_class)
      {
        v.reserve(num_samples);
      }
      for (size_t i = 0; i < num_samples; ++i)
      {
        const size_t classID = response_classIDs[i];
        sampleIDs_per_class[classID].push_back(i);
      }
    }

    // Set class weights all to 1
    class_weights = std::vector<double>(class_values.size(), 1.0);

    // Sort data unless memory saving splitting is requested
    if (!memory_saving_splitting)
    {
      data->sort();
    }
  }

  /**
   * Create num_trees empty probability trees ready for growing.
   *
   * Every tree is constructed with pointers to the forest's class values,
   * response class IDs, per-class sample IDs and class weights, and is then
   * handed a pointer to the forest's allowedVarIDs_ list.
   */
  void ForestProbability::growInternal()
  {
    trees.reserve(num_trees);

    size_t trees_left = num_trees;
    while (trees_left-- > 0)
    {
      trees.emplace_back(
          std::make_unique<TreeProbability>(&class_values, &response_classIDs, &sampleIDs_per_class, &class_weights));

      // Hand the pre-computed variable list to the tree that was just added.
      trees.back()->setAllowedVarIDs(&allowedVarIDs_);
    }
  }

  /**
   * Allocate and zero-initialise the `predictions` container.
   *
   * The layout depends on the prediction settings:
   *  - predict_all:                     [sample][class][tree]
   *  - prediction_type == TERMINALNODES: [0][sample][tree]
   *  - otherwise:                       [0][sample][class]
   *
   * The number of samples is taken from data->getNumRows().
   */
  void ForestProbability::allocatePredictMemory()
  {
    using Row = std::vector<double>;
    using Matrix = std::vector<Row>;
    using Cube = std::vector<Matrix>;

    const size_t num_prediction_samples = data->getNumRows();

    if (predict_all)
    {
      // One value per sample, class and tree.
      predictions = Cube(num_prediction_samples, Matrix(class_values.size(), Row(num_trees, 0)));
      return;
    }

    if (prediction_type == TERMINALNODES)
    {
      // One terminal node ID per sample and tree.
      predictions = Cube(1, Matrix(num_prediction_samples, Row(num_trees, 0)));
      return;
    }

    // One value per sample and class.
    predictions = Cube(1, Matrix(num_prediction_samples, Row(class_values.size(), 0)));
  }

  /**
   * Fill the prediction slots of a single sample from all trees.
   *
   *  - predict_all: adds each tree's per-class values to
   *    predictions[sample_idx][class][tree].
   *  - prediction_type == TERMINALNODES: stores each tree's terminal node ID in
   *    predictions[0][sample_idx][tree].
   *  - otherwise: sums the per-class values over all trees into
   *    predictions[0][sample_idx][class] and then divides by num_trees.
   *
   * @param sample_idx  Index of the sample to predict; used to index
   *                    `predictions` and passed on to the tree accessors.
   */
  void ForestProbability::predictInternal(size_t sample_idx)
  {
    // For each sample compute proportions in each tree
    for (size_t tree_idx = 0; tree_idx < num_trees; ++tree_idx)
    {
      if (predict_all)
      {
        // Bind by reference: getTreePrediction() returns a const reference, so
        // copying the vector for every tree and sample is unnecessary.
        const std::vector<double> &counts = getTreePrediction(tree_idx, sample_idx);

        for (size_t class_idx = 0; class_idx < counts.size(); ++class_idx)
        {
          predictions[sample_idx][class_idx][tree_idx] += counts[class_idx];
        }
      }
      else if (prediction_type == TERMINALNODES)
      {
        predictions[0][sample_idx][tree_idx] = getTreePredictionTerminalNodeID(tree_idx, sample_idx);
      }
      else
      {
        const std::vector<double> &counts = getTreePrediction(tree_idx, sample_idx);

        for (size_t class_idx = 0; class_idx < counts.size(); ++class_idx)
        {
          predictions[0][sample_idx][class_idx] += counts[class_idx];
        }
      }
    }

    // Average over trees
    if (!predict_all && prediction_type != TERMINALNODES)
    {
      for (size_t class_idx = 0; class_idx < predictions[0][sample_idx].size(); ++class_idx)
      {
        predictions[0][sample_idx][class_idx] /= num_trees;
      }
    }
  }

  /**
   * Compute out-of-bag predictions and the overall prediction error.
   *
   * For every tree, the per-class values of its out-of-bag samples are added to
   * predictions[0][sampleID][class]. Each sample that was out-of-bag in at least
   * one tree is then averaged over those trees and contributes (1 - p)^2 to the
   * error, where p is the averaged value of the sample's own class. Samples that
   * were never out-of-bag get NAN in every class slot and are not counted.
   *
   * The result is written to overall_prediction_error as the mean of these
   * contributions. If no sample was out-of-bag at all, the final division is
   * 0 / 0.
   */
  void ForestProbability::computePredictionErrorInternal()
  {
    // For each sample sum over trees where sample is OOB
    std::vector<size_t> samples_oob_count(num_samples, 0);
    predictions = std::vector<std::vector<std::vector<double>>>(1,
                                                                std::vector<std::vector<double>>(num_samples, std::vector<double>(class_values.size(), 0)));

    for (size_t tree_idx = 0; tree_idx < num_trees; ++tree_idx)
    {
      // Loop-invariant per tree: fetch the OOB bookkeeping once.
      const size_t num_oob = trees[tree_idx]->getNumSamplesOob();
      const auto &oob_sampleIDs = trees[tree_idx]->getOobSampleIDs();

      for (size_t sample_idx = 0; sample_idx < num_oob; ++sample_idx)
      {
        const size_t sampleID = oob_sampleIDs[sample_idx];

        // The tree is queried with the position in its OOB list, while the
        // result is stored under the sample's ID.
        // Bound by reference to avoid copying the vector for every OOB sample.
        const std::vector<double> &counts = getTreePrediction(tree_idx, sample_idx);

        for (size_t class_idx = 0; class_idx < counts.size(); ++class_idx)
        {
          predictions[0][sampleID][class_idx] += counts[class_idx];
        }
        ++samples_oob_count[sampleID];
      }
    }

    // Mean squared difference between 1 and the averaged value of the true class
    size_t num_predictions = 0;
    overall_prediction_error = 0;
    for (size_t i = 0; i < predictions[0].size(); ++i)
    {
      auto &class_probs = predictions[0][i];

      if (samples_oob_count[i] > 0)
      {
        ++num_predictions;

        // Turn the sums into averages over the trees where the sample was OOB.
        const double oob_count = static_cast<double>(samples_oob_count[i]);
        for (auto &prob : class_probs)
        {
          prob /= oob_count;
        }

        const size_t real_classID = response_classIDs[i];
        const double predicted_value = class_probs[real_classID];
        overall_prediction_error += (1 - predicted_value) * (1 - predicted_value);
      }
      else
      {
        // Never out-of-bag: no prediction available for this sample.
        for (auto &prob : class_probs)
        {
          prob = NAN;
        }
      }
    }

    overall_prediction_error /= static_cast<double>(num_predictions);
  }

  // #nocov start
  const std::vector<double> &ForestProbability::getTreePrediction(size_t tree_idx, size_t sample_idx) const
  {
    const auto &tree = dynamic_cast<const TreeProbability &>(*trees[tree_idx]);
    return tree.getPrediction(sample_idx);
  }

  size_t ForestProbability::getTreePredictionTerminalNodeID(size_t tree_idx, size_t sample_idx) const
  {
    const auto &tree = dynamic_cast<const TreeProbability &>(*trees[tree_idx]);
    return tree.getPredictionTerminalNodeID(sample_idx);
  }

  // #nocov end

} // namespace unityForest
