/*-------------------------------------------------------------------------------
  This file is part of generalized-random-forest.

  drf is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  (at your option) any later version.

  drf is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with drf. If not, see <http://www.gnu.org/licenses/>.
 #-------------------------------------------------------------------------------*/

#include <Rcpp.h>
#include <queue>
#include <vector>

#include "analysis/SplitFrequencyComputer.h"
#include "commons/globals.h"
#include "forest/Forest.h"
#include "prediction/collector/SampleWeightComputer.h"
#include "prediction/collector/TreeTraverser.h"

#include "RcppUtilities.h"

using namespace drf;

// Tabulates the split frequencies of a serialized forest as an R matrix.
//
// forest_object: serialized forest, deserialized via RcppUtilities.
// max_depth:     forwarded to SplitFrequencyComputer::compute and used as the
//                number of rows of the returned matrix.
//
// Returns a (max_depth x num_variables) matrix in which entry (d, v) is the
// value SplitFrequencyComputer reported in row d for variable v
// (presumably the number of splits on variable v at depth d -- confirm
// against SplitFrequencyComputer).
// [[Rcpp::export]]
Rcpp::NumericMatrix compute_split_frequencies(Rcpp::List forest_object,
                                              size_t max_depth) {
  Forest forest = RcppUtilities::deserialize_forest(forest_object);

  SplitFrequencyComputer computer;
  std::vector<std::vector<size_t>> split_frequencies = computer.compute(forest, max_depth);

  size_t num_variables = forest.get_num_variables();
  Rcpp::NumericMatrix result(max_depth, num_variables);

  // Copy the nested vectors row by row, converting the counts to doubles.
  size_t depth = 0;
  for (const std::vector<size_t>& frequencies : split_frequencies) {
    for (size_t var = 0; var < num_variables; ++var) {
      result(depth, var) = static_cast<double>(frequencies[var]);
    }
    ++depth;
  }
  return result;
}

// Builds the sparse matrix of forest weights between test and training rows.
//
// forest_object:                     serialized forest.
// train_matrix / sparse_train_matrix: training data (dense or sparse form,
//                                     combined by RcppUtilities::convert_data).
// test_matrix / sparse_test_matrix:   data to compute weights for.
// num_threads:                        validated through ForestOptions before use.
// oob_prediction:                     forwarded to the tree traverser.
//
// Returns a (test rows x training rows) sparse matrix; entry (s, n) is the
// weight SampleWeightComputer assigns to neighbor n for test sample s, computed
// over all trees of the forest.
Eigen::SparseMatrix<double> compute_sample_weights(Rcpp::List forest_object,
                                                   Rcpp::NumericMatrix train_matrix,
                                                   Eigen::SparseMatrix<double> sparse_train_matrix,
                                                   Rcpp::NumericMatrix test_matrix,
                                                   Eigen::SparseMatrix<double> sparse_test_matrix,
                                                   unsigned int num_threads,
                                                   bool oob_prediction) {
  std::unique_ptr<Data> train_data = RcppUtilities::convert_data(train_matrix, sparse_train_matrix);
  std::unique_ptr<Data> data = RcppUtilities::convert_data(test_matrix, sparse_test_matrix);
  Forest forest = RcppUtilities::deserialize_forest(forest_object);
  num_threads = ForestOptions::validate_num_threads(num_threads);

  TreeTraverser tree_traverser(num_threads);
  SampleWeightComputer weight_computer;

  // Per-tree leaf assignment of every test sample, and which trees may be used
  // for each sample.
  std::vector<std::vector<size_t>> leaf_nodes_by_tree = tree_traverser.get_leaf_nodes(forest, *data, oob_prediction);
  std::vector<std::vector<bool>> trees_by_sample = tree_traverser.get_valid_trees_by_sample(forest, *data, oob_prediction);

  const size_t num_samples = data->get_num_rows();
  const size_t num_neighbors = train_data->get_num_rows();
  const size_t num_trees = forest.get_trees().size();

  // Collect (row, column, value) triplets first and assemble the matrix in one
  // go, as recommended by the Eigen sparse tutorial
  // (http://eigen.tuxfamily.org/dox/group__TutorialSparse.html).
  std::vector<Eigen::Triplet<double>> triplet_list;
  triplet_list.reserve(num_neighbors);

  for (size_t sample = 0; sample < num_samples; ++sample) {
    // Weights over the full forest, i.e. the tree range [0, num_trees).
    std::unordered_map<size_t, double> weights = weight_computer.compute_weights(
        sample, forest, leaf_nodes_by_tree, trees_by_sample,
        0, num_trees);
    for (const auto& entry : weights) {
      triplet_list.emplace_back(sample, entry.first, entry.second);
    }
  }

  Eigen::SparseMatrix<double> result(num_samples, num_neighbors);
  result.setFromTriplets(triplet_list.begin(), triplet_list.end());
  return result;
}

// R entry point: forest weights of the test rows with respect to the training
// rows, without out-of-bag filtering. Delegates to compute_sample_weights().
// [[Rcpp::export]]
Eigen::SparseMatrix<double> compute_weights(Rcpp::List forest_object,
                                            Rcpp::NumericMatrix train_matrix,
                                            Eigen::SparseMatrix<double> sparse_train_matrix,
                                            Rcpp::NumericMatrix test_matrix,
                                            Eigen::SparseMatrix<double> sparse_test_matrix,
                                            unsigned int num_threads) {
  const bool oob_prediction = false;
  return compute_sample_weights(forest_object,
                                train_matrix, sparse_train_matrix,
                                test_matrix, sparse_test_matrix,
                                num_threads, oob_prediction);
}

// R entry point: out-of-bag forest weights. The supplied data is passed to
// compute_sample_weights() both as the training data and as the test data,
// with the OOB flag switched on.
// [[Rcpp::export]]
Eigen::SparseMatrix<double> compute_weights_oob(Rcpp::List forest_object,
                                                Rcpp::NumericMatrix test_matrix,
                                                Eigen::SparseMatrix<double> sparse_test_matrix,
                                                unsigned int num_threads) {
  const bool oob_prediction = true;
  return compute_sample_weights(forest_object,
                                test_matrix, sparse_test_matrix,   // training data
                                test_matrix, sparse_test_matrix,   // test data
                                num_threads, oob_prediction);
}


// Computes, for every test sample, the forest weights separately on each
// CI group of trees.
//
// The setup mirrors compute_sample_weights(); only the weight loop differs: the
// trees are partitioned into consecutive groups of `ci_group_size` trees and
// SampleWeightComputer is run once per group instead of once over all trees.
//
// forest_object:                      serialized forest; must have been trained
//                                     with a CI group size greater than 1.
// train_matrix / sparse_train_matrix: training data (dense or sparse form).
// test_matrix / sparse_test_matrix:   data to compute weights for.
// num_threads:                        validated through ForestOptions before use.
//
// Returns one sparse matrix per test row, in test-row order. Each matrix has
// one row per CI group and one column per training row; entry (g, n) is the
// weight of neighbor n computed on the trees of group g only.
//
// Throws std::runtime_error if the forest's CI group size is <= 1.
std::vector<Eigen::SparseMatrix<double>> compute_sample_weights_uncertainty(Rcpp::List forest_object,
                                                                            Rcpp::NumericMatrix train_matrix,
                                                                            Eigen::SparseMatrix<double> sparse_train_matrix,
                                                                            Rcpp::NumericMatrix test_matrix,
                                                                            Eigen::SparseMatrix<double> sparse_test_matrix,
                                                                            unsigned int num_threads) {
  const bool oob_prediction = false; // uncertainty weights not defined for OOB

  std::unique_ptr<Data> train_data = RcppUtilities::convert_data(train_matrix, sparse_train_matrix);
  std::unique_ptr<Data> data = RcppUtilities::convert_data(test_matrix, sparse_test_matrix);
  Forest forest = RcppUtilities::deserialize_forest(forest_object);
  num_threads = ForestOptions::validate_num_threads(num_threads);
  const size_t ci_group_size = forest.get_ci_group_size();
  if (ci_group_size <= 1) {
    throw std::runtime_error("To estimate uncertainty, the forest must be trained with ci.group.size greater than 1.");
  }

  TreeTraverser tree_traverser(num_threads);
  SampleWeightComputer weight_computer;

  std::vector<std::vector<size_t>> leaf_nodes_by_tree = tree_traverser.get_leaf_nodes(forest, *data, oob_prediction);
  std::vector<std::vector<bool>> trees_by_sample = tree_traverser.get_valid_trees_by_sample(forest, *data, oob_prediction);

  const size_t num_samples = data->get_num_rows();
  const size_t num_neighbors = train_data->get_num_rows();
  const size_t num_trees = forest.get_trees().size();

  // ceil(num_trees / ci_group_size) via integer arithmetic, so that a trailing
  // partial group (if any) still gets its own row.
  const size_t num_ci_groups = (num_trees + ci_group_size - 1) / ci_group_size;

  std::vector<Eigen::SparseMatrix<double>> results;
  results.reserve(num_samples);

  // Triplet buffer for the "fill a sparse matrix from triplets" pattern of the
  // Eigen sparse tutorial (http://eigen.tuxfamily.org/dox/group__TutorialSparse.html).
  // It is cleared for every sample (each sample needs its own triplets) but its
  // capacity is reused.
  std::vector<Eigen::Triplet<double>> triplet_list;
  triplet_list.reserve(num_neighbors);

  for (size_t sample = 0; sample < num_samples; ++sample) {
    triplet_list.clear();

    // Iterating by group index keeps the row index below num_ci_groups by
    // construction, so no runtime bounds guard on the row is required.
    for (size_t group = 0; group < num_ci_groups; ++group) {
      // Trees of this group: [first_tree, end_tree). The end is clamped to the
      // number of trees so that a partial last group never addresses trees
      // that do not exist.
      const size_t first_tree = group * ci_group_size;
      const size_t unclamped_end = first_tree + ci_group_size;
      const size_t end_tree = unclamped_end < num_trees ? unclamped_end : num_trees;

      std::unordered_map<size_t, double> weights = weight_computer.compute_weights(
          sample, forest, leaf_nodes_by_tree, trees_by_sample, first_tree, end_tree);

      // Weights of each CI group go on their own row of the sample's matrix.
      for (const auto& entry : weights) {
        triplet_list.emplace_back(group, entry.first, entry.second);
      }
    }

    // Construct this sample's matrix directly inside the result vector and
    // fill it in place, avoiding a copy of the assembled matrix.
    results.emplace_back(num_ci_groups, num_neighbors);
    results.back().setFromTriplets(triplet_list.begin(), triplet_list.end());
  }

  return results;
}

// R entry point for the per-CI-group sample weights. Delegates unchanged to
// compute_sample_weights_uncertainty().
// [[Rcpp::export]]
std::vector<Eigen::SparseMatrix<double>> compute_weights_uncertainty(Rcpp::List forest_object,
                                                                     Rcpp::NumericMatrix train_matrix,
                                                                     Eigen::SparseMatrix<double> sparse_train_matrix,
                                                                     Rcpp::NumericMatrix test_matrix,
                                                                     Eigen::SparseMatrix<double> sparse_test_matrix,
                                                                     unsigned int num_threads) {
  // Kept as a separate wrapper to follow the same pattern as `compute_weights`,
  // even though there is no OOB variant of the uncertainty weights.
  return compute_sample_weights_uncertainty(forest_object,
                                            train_matrix, sparse_train_matrix,
                                            test_matrix, sparse_test_matrix,
                                            num_threads);
}

// Merges several serialized forests into a single serialized forest.
//
// forest_objects: list whose elements are serialized forests.
//
// Each element is deserialized, the resulting forests are combined with
// Forest::merge, and the merged forest is serialized again for R.
// [[Rcpp::export]]
Rcpp::List merge(const Rcpp::List forest_objects) {
  std::vector<Forest> forests;
  // The number of forests is known up front; reserving avoids relocating
  // already-deserialized forests while the vector grows.
  forests.reserve(static_cast<size_t>(forest_objects.size()));

  for (auto& forest_obj : forest_objects) {
    forests.push_back(RcppUtilities::deserialize_forest(forest_obj));
  }

  Forest big_forest = Forest::merge(forests);
  return RcppUtilities::serialize_forest(big_forest);
}
