/*-------------------------------------------------------------------------------
 This file is part of unityForest.

 Copyright (c) [2014-2018] [Marvin N. Wright]
 Modifications and extensions by Roman Hornung

 This software may be modified and distributed under the terms of the MIT license.

 Please note that the C++ core of divfor is distributed under MIT license and the
 R package "unityForest" under GPL3 license.
 #-------------------------------------------------------------------------------*/

#include <Rcpp.h>

#include <algorithm>
#include <iostream>
#include <iterator>

#include <ctime>

#include "utility.h"
#include "TreeRegression.h"
#include "Data.h"

namespace unityForest
{

  // Builds a regression tree from given child node IDs, split variable IDs and
  // split values (forwarded to Tree). The split buffers `counter` and `sums`
  // are constructed with size 0.
  TreeRegression::TreeRegression(std::vector<std::vector<size_t>> &child_nodeIDs,
                                 std::vector<size_t> &split_varIDs,
                                 std::vector<double> &split_values)
      : Tree(child_nodeIDs, split_varIDs, split_values),
        counter(0),
        sums(0)
  {
    // Nothing further to initialise.
  }

  // Constructor used in repr_tree_mode. The tree structure and the data pointer are
  // forwarded to Tree; nodeID_in_root, inbag_counts and repr_vars are copied into
  // the members of the same name (presumably inherited from Tree, which is why they
  // are assigned in the body rather than in the initialiser list).
  TreeRegression::TreeRegression(std::vector<std::vector<size_t>> &child_nodeIDs, std::vector<size_t> &split_varIDs,
                                 std::vector<double> &split_values,
                                 std::vector<size_t> &nodeID_in_root, std::vector<size_t> &inbag_counts,
                                 std::vector<size_t> &repr_vars, const Data *data_ptr)
      : Tree(child_nodeIDs, split_varIDs, split_values, data_ptr),
        counter(0),
        sums(0)
  {
    this->inbag_counts = inbag_counts;
    this->repr_vars = repr_vars;
    this->nodeID_in_root = nodeID_in_root;
  }

  // Polymorphic copy: returns a copy-constructed TreeRegression owned through a
  // Tree pointer.
  std::unique_ptr<Tree> TreeRegression::clone() const
  {
    std::unique_ptr<Tree> duplicate = std::make_unique<TreeRegression>(*this);
    return duplicate;
  }

  // Sizes the reusable split buffers `counter` and `sums`. Nothing is allocated in
  // memory-saving mode, where the split routines use local buffers instead. The
  // buffers must hold one slot per candidate split: the maximum number of unique
  // values of any variable, or num_random_splits for extratrees if that is larger.
  void TreeRegression::allocateMemory()
  {
    if (memory_saving_splitting)
    {
      return;
    }

    const size_t max_unique = data->getMaxNumUniqueValues();
    const bool random_splits_dominate = (splitrule == EXTRATREES) && (num_random_splits > max_unique);
    const size_t buffer_size = random_splits_dominate ? num_random_splits : max_unique;

    counter.resize(buffer_size);
    sums.resize(buffer_size);
  }

  double TreeRegression::estimate(size_t nodeID)
  {

    // Mean of responses of samples in node
    double sum_responses_in_node = 0;
    size_t num_samples_in_node = end_pos[nodeID] - start_pos[nodeID];
    for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
    {
      size_t sampleID = sampleIDs[pos];
      sum_responses_in_node += data->get(sampleID, dependent_varID);
    }
    return (sum_responses_in_node / (double)num_samples_in_node);
  }

  // Variance reduction obtained by splitting a parent node into the given left and
  // right children (used in the unity VIM computation). The parent sample set is the
  // left samples followed by the right samples; children are weighted by their share
  // of the parent size. Returns 0 when both children are empty.
  double TreeRegression::computeSplitCriterion(std::vector<size_t> sampleIDs_left_child,
                                               std::vector<size_t> sampleIDs_right_child)
  {
    const size_t n_left = sampleIDs_left_child.size();
    const size_t n_right = sampleIDs_right_child.size();
    if (n_left + n_right == 0)
    {
      return 0.0;
    }

    // Parent node = left samples, then right samples.
    std::vector<size_t> parent_sampleIDs(sampleIDs_left_child);
    parent_sampleIDs.insert(parent_sampleIDs.end(), sampleIDs_right_child.begin(), sampleIDs_right_child.end());

    const double parent_var = computeVariance(parent_sampleIDs);
    const double left_var = (n_left > 0) ? computeVariance(sampleIDs_left_child) : 0.0;
    const double right_var = (n_right > 0) ? computeVariance(sampleIDs_right_child) : 0.0;

    const double n_parent = static_cast<double>(n_left + n_right);
    const double left_share = static_cast<double>(n_left) / n_parent;
    const double right_share = static_cast<double>(n_right) / n_parent;

    return parent_var - left_share * left_var - right_share * right_var;
  }

  // Population variance (divisor n, not n - 1) of the response over the samples in
  // sampleIDs_node; used by the unity VIM computation and the CRTR analysis.
  // Returns 0 for an empty sample set instead of the NaN that 0/0 would produce.
  double TreeRegression::computeVariance(std::vector<size_t> sampleIDs_node)
  {
    const size_t num_samples_node = sampleIDs_node.size();
    if (num_samples_node == 0)
    {
      return 0.0;
    }

    // Two-pass algorithm: mean first, then the sum of squared deviations
    // (more accurate than accumulating sum and sum of squares in one pass).
    double sum = 0.0;
    for (const size_t sampleID : sampleIDs_node)
    {
      sum += data->get(sampleID, dependent_varID);
    }
    const double mean = sum / static_cast<double>(num_samples_node);

    double ssq = 0.0;
    for (const size_t sampleID : sampleIDs_node)
    {
      const double d = data->get(sampleID, dependent_varID) - mean;
      ssq += d * d;
    }

    return ssq / static_cast<double>(num_samples_node);
  }

  // Compute the variance reduction for a split in the CRTR analysis.
  double TreeRegression::computeOOBSplitCriterionValue(size_t nodeID,
                                                       std::vector<size_t> oob_sampleIDs_nodeID)
  {
    // Compute the variance of the parent node:
    double var_parent = computeVariance(oob_sampleIDs_nodeID);

    // Determine the OOB observations that are assigned to the left and right child nodes:
    std::vector<size_t> oob_sampleIDs_left_child;
    oob_sampleIDs_left_child.reserve(oob_sampleIDs_nodeID.size());

    std::vector<size_t> oob_sampleIDs_right_child;
    oob_sampleIDs_right_child.reserve(oob_sampleIDs_nodeID.size());

    for (size_t i = 0; i < oob_sampleIDs_nodeID.size(); ++i)
    {
      size_t sampleID = oob_sampleIDs_nodeID[i];
      double value = data->get(sampleID, split_varIDs[nodeID]);

      if (value <= split_values[nodeID])
      {
        oob_sampleIDs_left_child.push_back(sampleID);
      }
      else
      {
        oob_sampleIDs_right_child.push_back(sampleID);
      }
    }

    // Variance of left child
    double var_left = 0.0;
    if (!oob_sampleIDs_left_child.empty())
    {
      var_left = computeVariance(oob_sampleIDs_left_child);
    }

    // Variance of right child
    double var_right = 0.0;
    if (!oob_sampleIDs_right_child.empty())
    {
      var_right = computeVariance(oob_sampleIDs_right_child);
    }

    // Compute the variance reduction
    const double n_parent = static_cast<double>(oob_sampleIDs_nodeID.size());

    double var_reduction =
        var_parent - (static_cast<double>(oob_sampleIDs_left_child.size()) / n_parent) * var_left - (static_cast<double>(oob_sampleIDs_right_child.size()) / n_parent) * var_right;

    return var_reduction;
  }

  // Compute the OOB split criterion value for the node after permuting the OOB observations (unity VIM).
  double TreeRegression::computeOOBSplitCriterionValuePermuted(size_t nodeID,
                                                               std::vector<size_t> oob_sampleIDs_nodeID,
                                                               std::vector<size_t> permutations)
  {
    // Compute the variance of the parent node:
    double var_parent = computeVariance(oob_sampleIDs_nodeID);

    // Determine the OOB observations that are assigned to the left and right child nodes
    // after permuting the values of the split variable:
    std::vector<size_t> oob_sampleIDs_left_child;
    oob_sampleIDs_left_child.reserve(oob_sampleIDs_nodeID.size());
    std::vector<size_t> oob_sampleIDs_right_child;
    oob_sampleIDs_right_child.reserve(oob_sampleIDs_nodeID.size());

    for (size_t i = 0; i < oob_sampleIDs_nodeID.size(); ++i)
    {
      // Use permuted sampleID for the split variable value,
      // but assign the *original* OOB sample (oob_sampleIDs_nodeID[i]) to left/right.
      size_t sampleID_perm = permutations[i];
      double value = data->get(sampleID_perm, split_varIDs[nodeID]);

      if (value <= split_values[nodeID])
      {
        oob_sampleIDs_left_child.push_back(oob_sampleIDs_nodeID[i]);
      }
      else
      {
        oob_sampleIDs_right_child.push_back(oob_sampleIDs_nodeID[i]);
      }
    }

    // Variance of left child:
    double var_left = 0.0;
    if (!oob_sampleIDs_left_child.empty())
    {
      var_left = computeVariance(oob_sampleIDs_left_child);
    }

    // Variance of right child:
    double var_right = 0.0;
    if (!oob_sampleIDs_right_child.empty())
    {
      var_right = computeVariance(oob_sampleIDs_right_child);
    }

    // Compute the variance reduction:
    const double n_parent = static_cast<double>(oob_sampleIDs_nodeID.size());

    double var_reduction =
        var_parent - (static_cast<double>(oob_sampleIDs_left_child.size()) / n_parent) * var_left - (static_cast<double>(oob_sampleIDs_right_child.size()) / n_parent) * var_right;

    return var_reduction;
  }

  // Evaluate a random candidate tree root.
  double TreeRegression::evaluateRandomTree(const std::vector<size_t> &terminal_nodes)
  {

    // We compute the partition criterion in the "equivalent" form:
    // score = (1/N) * sum_{m=1..M} ( (sum_{i in leaf m} y_i)^2 / N_m )
    // Maximizing this score is equivalent (up to an additive constant) to maximizing
    // V - sum_m (N_m/N) * V_m.

    double score = 0.0;
    size_t total_num_samples = 0;

    // Cache dependent variable ID if you store it as a member; otherwise keep your existing access pattern.
    // Here we assume you have a member `dependent_varID` and `data` (like in your split code).
    // If your implementation stores responses differently, adapt the "response" line accordingly.

    for (size_t nodeID : terminal_nodes)
    {

      const size_t start_pos = start_pos_loop[nodeID];
      const size_t end_pos = end_pos_loop[nodeID];

      const size_t node_sample_size = end_pos - start_pos;
      if (node_sample_size == 0)
      {
        continue; // should not happen, but avoids division by zero
      }

      // Sum of responses in the terminal node
      double sum_y = 0.0;
      for (size_t pos = start_pos; pos < end_pos; ++pos)
      {
        const size_t sampleID = sampleIDs[pos];
        const double y = data->get(sampleID, dependent_varID);
        sum_y += y;
      }

      // Add leaf contribution
      score += (sum_y * sum_y) / static_cast<double>(node_sample_size);

      // Total samples
      total_num_samples += node_sample_size;
    }

    if (total_num_samples == 0)
    {
      return 0.0;
    }

    // Normalize by total sample size (not strictly necessary)
    score /= static_cast<double>(total_num_samples);

    return score;
  }

  // Split in a tree sprout.
  bool TreeRegression::splitNodeInternal(size_t nodeID, std::vector<size_t> &possible_split_varIDs)
  {

    size_t num_samples_node = end_pos[nodeID] - start_pos[nodeID];

    // Stop if maximum node size or depth reached
    if (num_samples_node <= min_node_size || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth))
    {
      split_values[nodeID] = estimate(nodeID);
      return true;
    }

    // Check if node is pure and set split_value to estimate and stop if pure
    bool pure = true;
    double pure_value = 0;
    for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
    {
      size_t sampleID = sampleIDs[pos];
      double value = data->get(sampleID, dependent_varID);
      if (pos != start_pos[nodeID] && value != pure_value)
      {
        pure = false;
        break;
      }
      pure_value = value;
    }
    if (pure)
    {
      split_values[nodeID] = pure_value;
      return true;
    }

    // Find best split, stop if no decrease of impurity
    bool stop;
    if (splitrule == MAXSTAT)
    {
      stop = findBestSplitMaxstat(nodeID, possible_split_varIDs);
    }
    else if (splitrule == EXTRATREES)
    {
      stop = findBestSplitExtraTrees(nodeID, possible_split_varIDs);
    }
    else
    {
      stop = findBestSplit(nodeID, possible_split_varIDs);
    }

    if (stop)
    {
      split_values[nodeID] = estimate(nodeID);
      return true;
    }

    return false;
  }

  // Check whether the current node in a random candidate tree root is final.
  bool TreeRegression::checkWhetherFinalRandom(size_t nodeID)
  {

    // Stop if maximum node size or depth reached
    size_t num_samples_node = end_pos_loop[nodeID] - start_pos_loop[nodeID];
    if (num_samples_node <= min_node_size_root || (nodeID >= last_left_nodeID_loop && max_depth_root > 0 && depth >= max_depth_root))
    {
      return true;
    }

    // Check if node is pure and set split_value to estimate and stop if pure
    bool pure = true;
    size_t pos = start_pos_loop[nodeID];
    size_t sampleID = sampleIDs[pos];
    double pure_value = data->get(sampleID, dependent_varID);
    ++pos;
    for (; pos < end_pos_loop[nodeID]; ++pos)
    {
      sampleID = sampleIDs[pos];
      double value = data->get(sampleID, dependent_varID);
      if (value != pure_value)
      {
        pure = false;
        break;
      }
    }

    if (pure)
    {
      return true;
    }

    return false;
  }

  // Regression-specific part of creating an empty node in a random candidate tree
  // root. TreeRegression has nothing to add, so this is intentionally a no-op.
  void TreeRegression::createEmptyNodeRandomTreeInternal()
  {
    // Intentionally left blank.
  }

  // Regression-specific part of creating an empty node in a tree sprout.
  // TreeRegression has nothing to add, so this is intentionally a no-op.
  void TreeRegression::createEmptyNodeFullTreeInternal()
  {
    // Intentionally left blank.
  }

  // Regression-specific cleanup of the candidate random tree roots.
  // TreeRegression has nothing extra to clear, so this is intentionally a no-op.
  void TreeRegression::clearRandomTreeInternal()
  {
    // Intentionally left blank.
  }

  // Find the best split for a node in the tree sprout. 
  bool TreeRegression::findBestSplit(size_t nodeID, std::vector<size_t> &possible_split_varIDs)
  {

    size_t num_samples_node = end_pos[nodeID] - start_pos[nodeID];
    double best_decrease = -1;
    size_t best_varID = 0;
    double best_value = 0;

    // Compute sum of responses in node
    double sum_node = 0;
    for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
    {
      size_t sampleID = sampleIDs[pos];
      sum_node += data->get(sampleID, dependent_varID);
    }

    // For all possible split variables
    for (auto &varID : possible_split_varIDs)
    {

      // Find best split value, if ordered consider all values as split values, else all 2-partitions
      if (data->isOrderedVariable(varID))
      {

        // Use memory saving method if option set
        if (memory_saving_splitting)
        {
          findBestSplitValueSmallQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease);
        }
        else
        {
          // Use faster method for both cases
          double q = (double)num_samples_node / (double)data->getNumUniqueDataValues(varID);
          if (q < Q_THRESHOLD)
          {
            findBestSplitValueSmallQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease);
          }
          else
          {
            findBestSplitValueLargeQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease);
          }
        }
      }
      else
      {
        findBestSplitValueUnordered(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease);
      }
    }

    // Stop if no good split found
    if (best_decrease < 0)
    {
      return true;
    }

    // Save best values
    split_varIDs[nodeID] = best_varID;
    split_values[nodeID] = best_value;

    return false;
  }

  // Candidate generation for the SmallQ strategy: collects the values of varID in the
  // node (via getAllValues — presumably ascending and de-duplicated, which the core
  // routine relies on) and evaluates a split between each pair of neighbouring values.
  // Uses the member buffers sums/counter unless memory-saving splitting is enabled.
  void TreeRegression::findBestSplitValueSmallQ(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node,
                                                double &best_value, size_t &best_varID, double &best_decrease)
  {
    std::vector<double> possible_split_values;
    data->getAllValues(possible_split_values, sampleIDs, varID, start_pos[nodeID], end_pos[nodeID]);

    // A variable constant in this node cannot be split on.
    if (possible_split_values.size() < 2)
    {
      return;
    }

    // No split above the largest value, hence one split fewer than values.
    const size_t num_splits = possible_split_values.size() - 1;

    if (!memory_saving_splitting)
    {
      // Zero just the prefix of the preallocated buffers that will be used.
      std::fill(sums.begin(), sums.begin() + num_splits, 0.0);
      std::fill(counter.begin(), counter.begin() + num_splits, 0);
      findBestSplitValueSmallQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease,
                               possible_split_values, sums, counter);
      return;
    }

    std::vector<double> local_sums_right(num_splits);
    std::vector<size_t> local_n_right(num_splits);
    findBestSplitValueSmallQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease,
                             possible_split_values, local_sums_right, local_n_right);
  }

  // Core of the SmallQ strategy. For each candidate i (between possible_split_values[i]
  // and [i + 1]), n_right[i] / sums_right[i] accumulate the count and response sum of
  // node samples with value > possible_split_values[i]; both must be zero on entry.
  // The best candidate (strictly larger decrease) updates best_value (midpoint, or the
  // lower value if the midpoint rounds up to the upper one), best_varID and best_decrease.
  void TreeRegression::findBestSplitValueSmallQ(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node,
                                                double &best_value, size_t &best_varID, double &best_decrease, std::vector<double> possible_split_values,
                                                std::vector<double> &sums_right, std::vector<size_t> &n_right)
  {
    const size_t num_splits = possible_split_values.size() - 1;

    // With ascending candidates, a value not above candidate i is not above any later one.
    for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
    {
      const size_t sampleID = sampleIDs[pos];
      const double x = data->get(sampleID, varID);
      const double y = data->get(sampleID, dependent_varID);
      for (size_t i = 0; i < num_splits && x > possible_split_values[i]; ++i)
      {
        ++n_right[i];
        sums_right[i] += y;
      }
    }

    for (size_t i = 0; i < num_splits; ++i)
    {
      const size_t n_r = n_right[i];
      const size_t n_l = num_samples_node - n_r;
      if (n_l == 0 || n_r == 0)
      {
        continue; // one child would be empty
      }

      const double sum_r = sums_right[i];
      const double sum_l = sum_node - sum_r;
      const double decrease = sum_l * sum_l / static_cast<double>(n_l) + sum_r * sum_r / static_cast<double>(n_r);

      if (decrease > best_decrease)
      {
        const double lower = possible_split_values[i];
        const double upper = possible_split_values[i + 1];
        const double midpoint = (lower + upper) / 2;
        best_value = (midpoint == upper) ? lower : midpoint;
        best_varID = varID;
        best_decrease = decrease;
      }
    }
  }

  // LargeQ strategy: builds a histogram of the node's samples over the global unique
  // values of varID (counter = sample counts, sums = response sums, indexed by
  // data->getIndex), then sweeps the split point from left to right. The best split
  // (strictly larger decrease) is placed midway to the next value present in the node.
  void TreeRegression::findBestSplitValueLargeQ(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node,
                                                double &best_value, size_t &best_varID, double &best_decrease)
  {
    const size_t num_unique = data->getNumUniqueDataValues(varID);
    std::fill(counter.begin(), counter.begin() + num_unique, 0);
    std::fill(sums.begin(), sums.begin() + num_unique, 0.0);

    for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
    {
      const size_t sampleID = sampleIDs[pos];
      const size_t bucket = data->getIndex(sampleID, varID);
      sums[bucket] += data->get(sampleID, dependent_varID);
      ++counter[bucket];
    }

    size_t n_left = 0;
    double sum_left = 0;

    for (size_t i = 0; i < num_unique - 1; ++i)
    {
      if (counter[i] == 0)
      {
        continue; // value absent from this node
      }

      // Move bucket i into the left child.
      n_left += counter[i];
      sum_left += sums[i];

      const size_t n_right = num_samples_node - n_left;
      if (n_right == 0)
      {
        break; // everything is on the left already
      }

      const double sum_right = sum_node - sum_left;
      const double decrease = sum_left * sum_left / static_cast<double>(n_left) + sum_right * sum_right / static_cast<double>(n_right);

      if (decrease > best_decrease)
      {
        // Next unique value that actually occurs in the node.
        size_t next = i + 1;
        for (; next < num_unique && counter[next] == 0; ++next)
        {
        }

        const double lower = data->getUniqueDataValue(varID, i);
        const double upper = data->getUniqueDataValue(varID, next);
        const double midpoint = (lower + upper) / 2;
        best_value = (midpoint == upper) ? lower : midpoint;
        best_varID = varID;
        best_decrease = decrease;
      }
    }
  }

  // Exhaustively evaluates all 2-partitions of the levels of the unordered factor
  // varID that occur in the node. A partition is encoded as a bitmask over factor IDs
  // (factor value - 1); samples whose bit is set go to the right child. The best mask
  // (strictly larger decrease) is stored in best_value.
  // Shifts use 64-bit unsigned operands: shifting the int literal 1 by 31 or more
  // positions is undefined behaviour. The encoding supports at most 64 levels (and a
  // double holds such masks exactly only up to 2^53).
  void TreeRegression::findBestSplitValueUnordered(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node,
                                                   double &best_value, size_t &best_varID, double &best_decrease)
  {
    // Factor levels present in this node.
    std::vector<double> factor_levels;
    data->getAllValues(factor_levels, sampleIDs, varID, start_pos[nodeID], end_pos[nodeID]);

    // A single level cannot be split.
    if (factor_levels.size() < 2)
    {
      return;
    }

    // 2^k subsets of the k levels in the node.
    const size_t num_splits = (1ULL << factor_levels.size());

    // Skip 0 (all left); the second half mirrors the first with left/right swapped.
    // Since the highest local bit is never set, both children are always non-empty.
    for (size_t local_splitID = 1; local_splitID < num_splits / 2; ++local_splitID)
    {
      // Translate local level positions into global factor-ID bits.
      size_t splitID = 0;
      for (size_t j = 0; j < factor_levels.size(); ++j)
      {
        if (local_splitID & (1ULL << j))
        {
          const size_t factorID = floor(factor_levels[j]) - 1;
          splitID |= (1ULL << factorID);
        }
      }

      double sum_right = 0;
      size_t n_right = 0;
      for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
      {
        const size_t sampleID = sampleIDs[pos];
        const double response = data->get(sampleID, dependent_varID);
        const size_t factorID = floor(data->get(sampleID, varID)) - 1;

        // Right child iff the sample's factor bit is set in splitID.
        if (splitID & (1ULL << factorID))
        {
          ++n_right;
          sum_right += response;
        }
      }
      const size_t n_left = num_samples_node - n_right;

      const double sum_left = sum_node - sum_right;
      const double decrease = sum_left * sum_left / static_cast<double>(n_left) + sum_right * sum_right / static_cast<double>(n_right);

      if (decrease > best_decrease)
      {
        best_value = static_cast<double>(splitID);
        best_varID = varID;
        best_decrease = decrease;
      }
    }
  }

  bool TreeRegression::findBestSplitMaxstat(size_t nodeID, std::vector<size_t> &possible_split_varIDs)
  {

    size_t num_samples_node = end_pos[nodeID] - start_pos[nodeID];

    // Compute ranks
    std::vector<double> response;
    response.reserve(num_samples_node);
    for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
    {
      size_t sampleID = sampleIDs[pos];
      response.push_back(data->get(sampleID, dependent_varID));
    }
    std::vector<double> ranks = rank(response);

    // Save split stats
    std::vector<double> pvalues;
    pvalues.reserve(possible_split_varIDs.size());
    std::vector<double> values;
    values.reserve(possible_split_varIDs.size());
    std::vector<double> candidate_varIDs;
    candidate_varIDs.reserve(possible_split_varIDs.size());
    std::vector<double> test_statistics;
    test_statistics.reserve(possible_split_varIDs.size());

    // Compute p-values
    for (auto &varID : possible_split_varIDs)
    {

      // Get all observations
      std::vector<double> x;
      x.reserve(num_samples_node);
      for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
      {
        size_t sampleID = sampleIDs[pos];
        x.push_back(data->get(sampleID, varID));
      }

      // Order by x
      std::vector<size_t> indices = order(x, false);
      // std::vector<size_t> indices = orderInData(data, sampleIDs[nodeID], varID, false);

      // Compute maximally selected rank statistics
      double best_maxstat;
      double best_split_value;
      maxstat(ranks, x, indices, best_maxstat, best_split_value, minprop, 1 - minprop);
      // maxstatInData(scores, data, sampleIDs[nodeID], varID, indices, best_maxstat, best_split_value, minprop, 1 - minprop);

      if (best_maxstat > -1)
      {
        // Compute number of samples left of cutpoints
        std::vector<size_t> num_samples_left = numSamplesLeftOfCutpoint(x, indices);
        // std::vector<size_t> num_samples_left = numSamplesLeftOfCutpointInData(data, sampleIDs[nodeID], varID, indices);

        // Compute p-values
        double pvalue_lau92 = maxstatPValueLau92(best_maxstat, minprop, 1 - minprop);
        double pvalue_lau94 = maxstatPValueLau94(best_maxstat, minprop, 1 - minprop, num_samples_node, num_samples_left);

        // Use minimum of Lau92 and Lau94
        double pvalue = std::min(pvalue_lau92, pvalue_lau94);

        // Save split stats
        pvalues.push_back(pvalue);
        values.push_back(best_split_value);
        candidate_varIDs.push_back(varID);
        test_statistics.push_back(best_maxstat);
      }
    }

    double adjusted_best_pvalue = std::numeric_limits<double>::max();
    size_t best_varID = 0;
    double best_value = 0;
    double best_maxstat = 0;

    if (pvalues.size() > 0)
    {
      // Adjust p-values with Benjamini/Hochberg
      std::vector<double> adjusted_pvalues = adjustPvalues(pvalues);

      // Use smallest p-value
      double min_pvalue = std::numeric_limits<double>::max();
      for (size_t i = 0; i < pvalues.size(); ++i)
      {
        if (pvalues[i] < min_pvalue)
        {
          min_pvalue = pvalues[i];
          best_varID = candidate_varIDs[i];
          best_value = values[i];
          adjusted_best_pvalue = adjusted_pvalues[i];
          best_maxstat = test_statistics[i];
        }
      }
    }

    // Stop if no good split found (this is terminal node).
    if (adjusted_best_pvalue > alpha)
    {
      return true;
    }
    else
    {
      // If not terminal node save best values
      split_varIDs[nodeID] = best_varID;
      split_values[nodeID] = best_value;

      return false;
    }
  }

  bool TreeRegression::findBestSplitExtraTrees(size_t nodeID, std::vector<size_t> &possible_split_varIDs)
  {

    size_t num_samples_node = end_pos[nodeID] - start_pos[nodeID];
    double best_decrease = -1;
    size_t best_varID = 0;
    double best_value = 0;

    // Compute sum of responses in node
    double sum_node = 0;
    for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
    {
      size_t sampleID = sampleIDs[pos];
      sum_node += data->get(sampleID, dependent_varID);
    }

    // For all possible split variables
    for (auto &varID : possible_split_varIDs)
    {

      // Find best split value, if ordered consider all values as split values, else all 2-partitions
      if (data->isOrderedVariable(varID))
      {
        findBestSplitValueExtraTrees(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease);
      }
      else
      {
        findBestSplitValueExtraTreesUnordered(nodeID, varID, sum_node, num_samples_node, best_value, best_varID,
                                              best_decrease);
      }
    }

    // Stop if no good split found
    if (best_decrease < 0)
    {
      return true;
    }

    // Save best values
    split_varIDs[nodeID] = best_varID;
    split_values[nodeID] = best_value;

    return false;
  }

  // Candidate generation for extratrees on an ordered variable: draws num_random_splits
  // thresholds uniformly between the node's minimum and maximum of varID and evaluates
  // them with the core routine, using the member buffers sums/counter unless
  // memory-saving splitting is enabled.
  // The thresholds are sorted ascending because the core routine stops scanning
  // thresholds at the first one a value does not exceed; with unsorted thresholds the
  // right-child counts and sums would be wrong whenever num_random_splits > 1.
  void TreeRegression::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node,
                                                    double &best_value, size_t &best_varID, double &best_decrease)
  {
    // Range of the covariate in the node (initialised in case nothing is written).
    double min_value = 0;
    double max_value = 0;
    data->getMinMaxValues(min_value, max_value, sampleIDs, varID, start_pos[nodeID], end_pos[nodeID]);

    // A variable constant in this node cannot be split on.
    if (min_value == max_value)
    {
      return;
    }

    std::vector<double> possible_split_values;
    std::uniform_real_distribution<double> udist(min_value, max_value);
    possible_split_values.reserve(num_random_splits);
    for (size_t i = 0; i < num_random_splits; ++i)
    {
      possible_split_values.push_back(udist(random_number_generator));
    }
    std::sort(possible_split_values.begin(), possible_split_values.end());

    const size_t num_splits = possible_split_values.size();
    if (memory_saving_splitting)
    {
      std::vector<double> sums_right(num_splits);
      std::vector<size_t> n_right(num_splits);
      findBestSplitValueExtraTrees(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease,
                                   possible_split_values, sums_right, n_right);
    }
    else
    {
      // Zero just the prefix of the preallocated buffers that will be used.
      std::fill_n(sums.begin(), num_splits, 0);
      std::fill_n(counter.begin(), num_splits, 0);
      findBestSplitValueExtraTrees(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease,
                                   possible_split_values, sums, counter);
    }
  }

  // Core of the extratrees evaluation for an ordered variable. Precondition:
  // possible_split_values is sorted ascending, and sums_right / n_right are zero on
  // entry. For threshold i they receive the response sum and count of node samples with
  // value > possible_split_values[i]. The threshold with the strictly largest decrease
  // updates best_value (the threshold itself), best_varID and best_decrease.
  void TreeRegression::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node,
                                                    double &best_value, size_t &best_varID, double &best_decrease, std::vector<double> possible_split_values,
                                                    std::vector<double> &sums_right, std::vector<size_t> &n_right)
  {
    const size_t num_splits = possible_split_values.size();

    for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
    {
      const size_t sampleID = sampleIDs[pos];
      const double x = data->get(sampleID, varID);
      const double y = data->get(sampleID, dependent_varID);
      // Early exit is only valid for ascending thresholds (see precondition).
      for (size_t i = 0; i < num_splits && x > possible_split_values[i]; ++i)
      {
        ++n_right[i];
        sums_right[i] += y;
      }
    }

    for (size_t i = 0; i < num_splits; ++i)
    {
      const size_t n_r = n_right[i];
      const size_t n_l = num_samples_node - n_r;
      if (n_l == 0 || n_r == 0)
      {
        continue; // one child would be empty
      }

      const double sum_r = sums_right[i];
      const double sum_l = sum_node - sum_r;
      const double decrease = sum_l * sum_l / static_cast<double>(n_l) + sum_r * sum_r / static_cast<double>(n_r);

      if (decrease > best_decrease)
      {
        best_value = possible_split_values[i];
        best_varID = varID;
        best_decrease = decrease;
      }
    }
  }

  void TreeRegression::findBestSplitValueExtraTreesUnordered(size_t nodeID, size_t varID, double sum_node,
                                                             size_t num_samples_node, double &best_value, size_t &best_varID, double &best_decrease)
  {

    size_t num_unique_values = data->getNumUniqueDataValues(varID);

    // Get all factor indices in node
    std::vector<bool> factor_in_node(num_unique_values, false);
    for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
    {
      size_t sampleID = sampleIDs[pos];
      size_t index = data->getIndex(sampleID, varID);
      factor_in_node[index] = true;
    }

    // Vector of indices in and out of node
    std::vector<size_t> indices_in_node;
    std::vector<size_t> indices_out_node;
    indices_in_node.reserve(num_unique_values);
    indices_out_node.reserve(num_unique_values);
    for (size_t i = 0; i < num_unique_values; ++i)
    {
      if (factor_in_node[i])
      {
        indices_in_node.push_back(i);
      }
      else
      {
        indices_out_node.push_back(i);
      }
    }

    // Generate num_random_splits splits
    for (size_t i = 0; i < num_random_splits; ++i)
    {
      std::vector<size_t> split_subset;
      split_subset.reserve(num_unique_values);

      // Draw random subsets, sample all partitions with equal probability
      if (indices_in_node.size() > 1)
      {
        size_t num_partitions = (2 << (indices_in_node.size() - 1)) - 2; // 2^n-2 (don't allow full or empty)
        std::uniform_int_distribution<size_t> udist(1, num_partitions);
        size_t splitID_in_node = udist(random_number_generator);
        for (size_t j = 0; j < indices_in_node.size(); ++j)
        {
          if ((splitID_in_node & (1 << j)) > 0)
          {
            split_subset.push_back(indices_in_node[j]);
          }
        }
      }
      if (indices_out_node.size() > 1)
      {
        size_t num_partitions = (2 << (indices_out_node.size() - 1)) - 1; // 2^n-1 (allow full or empty)
        std::uniform_int_distribution<size_t> udist(0, num_partitions);
        size_t splitID_out_node = udist(random_number_generator);
        for (size_t j = 0; j < indices_out_node.size(); ++j)
        {
          if ((splitID_out_node & (1 << j)) > 0)
          {
            split_subset.push_back(indices_out_node[j]);
          }
        }
      }

      // Assign union of the two subsets to right child
      size_t splitID = 0;
      for (auto &idx : split_subset)
      {
        splitID |= 1 << idx;
      }

      // Initialize
      double sum_right = 0;
      size_t n_right = 0;

      // Sum in right child
      for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
      {
        size_t sampleID = sampleIDs[pos];
        double response = data->get(sampleID, dependent_varID);
        double value = data->get(sampleID, varID);
        size_t factorID = floor(value) - 1;

        // If in right child, count
        // In right child, if bitwise splitID at position factorID is 1
        if ((splitID & (1 << factorID)))
        {
          ++n_right;
          sum_right += response;
        }
      }
      size_t n_left = num_samples_node - n_right;

      // Sum of squares
      double sum_left = sum_node - sum_right;
      double decrease = sum_left * sum_left / (double)n_left + sum_right * sum_right / (double)n_right;

      // If better than before, use this
      if (decrease > best_decrease)
      {
        best_value = splitID;
        best_varID = varID;
        best_decrease = decrease;
      }
    }
  }

} // namespace unityForest
