//
//  r_robust_utils.cc
//  pense
//
//  Created by David Kepplinger on 2019-04-03.
//  Copyright © 2019 David Kepplinger. All rights reserved.
//

#include "r_robust_utils.hpp"

#include "constants.hpp"
#include "rcpp_integration.hpp"
#include "r_interface_utils.hpp"
#include "rcpp_utils.hpp"
#include "robust_scale_location.hpp"

using Rcpp::as;
using pense::Mscale;

namespace {
//! Iteration limit passed as the fallback for the `max_it` option when computing
//! the M-location in `r_interface::MLocation()`.
constexpr int kDefaultMLocationMaxIt{100};
}  // namespace

namespace pense {
namespace r_interface {
//! Evaluate the Rho Function
//!
//! @param x numeric values
//! @param deriv whether to evaluate the rho function (deriv=0), the psi function (deriv=1),
//!              or the 2nd derivative rho'' (deriv=2). Any other value evaluates the
//!              rho function, same as deriv=0.
//! @param std whether to evaluate the standardized rho function (sup rho(x) = 1) or not
//! @param scale the scale of the values in x
//! @param rho_opts a list of options for the rho function
//! @return the result of the selected rho-function method applied to `x`, wrapped for R.
SEXP RhoFun(SEXP r_x, SEXP r_deriv, SEXP r_std, SEXP r_scale, SEXP r_rho_opts) noexcept {
  BEGIN_RCPP
  auto x = MakeVectorView(r_x);
  const int deriv = as<int>(r_deriv);
  const double scale = as<double>(r_scale);
  // Deliberately not called `std`: a local of that name hides namespace `std`
  // for the remainder of the function.
  const bool standardize = as<bool>(r_std);
  auto rho_opts = as<Rcpp::List>(r_rho_opts);
  // Build the rho function from the already-converted options list.
  auto rho = RhoFactory(rho_opts);

  switch (deriv) {
  case 2:
    if (standardize) {
      return Rcpp::wrap(rho->SecondDerivativeStd(*x, scale));
    }
    return Rcpp::wrap(rho->SecondDerivative(*x, scale));
  case 1:
    if (standardize) {
      return Rcpp::wrap(rho->DerivativeStd(*x, scale));
    }
    return Rcpp::wrap(rho->Derivative(*x, scale));
  case 0:
  default:
    if (standardize) {
      return Rcpp::wrap(rho->EvaluateStd(*x, scale));
    }
    return Rcpp::wrap(rho->operator()(*x, scale));
  }

  END_RCPP;
}

//! Tau-scale of values that are already centered.
//!
//! @param x numeric values (assumed to be centered by the caller).
//! @return the value of `pense::TauSize()` for `x`, wrapped for R.
SEXP TauSize(SEXP r_x) noexcept {
  BEGIN_RCPP
  auto values_view = MakeVectorView(r_x);
  const auto tau_size = pense::TauSize(*values_view);
  return Rcpp::wrap(tau_size);
  END_RCPP;
}

//! M-scale of values that are already centered.
//!
//! @param x numeric values (assumed to be centered by the caller).
//! @param mscale_opts a list of options used to construct the `Mscale` object.
//! @return the M-scale of `x`, wrapped for R.
SEXP MScale(SEXP r_x, SEXP r_mscale_opts) noexcept {
  BEGIN_RCPP
  auto values_view = MakeVectorView(r_x);
  auto opts = as<Rcpp::List>(r_mscale_opts);
  Mscale m_scale(opts);
  const auto scale_estimate = m_scale(*values_view);
  return Rcpp::wrap(scale_estimate);
  END_RCPP;
}

//! Derivatives of the M-scale function with respect to each coordinate.
//!
//! @param x numeric values.
//! @param mscale_opts a list of options for the M-scale equation.
//! @param order the derivative order: 2 returns the result of `Mscale::GradientHessian()`;
//!              any other value (normally 1) returns the result of `Mscale::Derivative()`.
//! @return the requested derivative(s), wrapped for R.
SEXP MScaleDerivative(SEXP r_x, SEXP r_mscale_opts, SEXP r_order) noexcept {
  BEGIN_RCPP
  auto values_view = MakeVectorView(r_x);
  auto opts = as<Rcpp::List>(r_mscale_opts);
  const int order = as<int>(r_order);
  Mscale m_scale(opts);

  if (order == 2) {
    return Rcpp::wrap(m_scale.GradientHessian(*values_view));
  }
  return Rcpp::wrap(m_scale.Derivative(*values_view));
  END_RCPP;
}

//! Compute the maximum absolute derivative of the M-scale function over a grid of values
//!
//! The M-scale derivative is evaluated at the original `x` and at every vector obtained
//! by replacing the first `change` elements of `x` with a tuple of values from `grid`
//! (all |grid|^change tuples, values may repeat). The largest absolute entry over all
//! these evaluations is returned.
//!
//! @param x original numeric values.
//! @param grid grid of values to look for maximal derivative.
//! @param change number of leading elements in `x` to change. If `change < 1` or `grid`
//!               is empty, only the original `x` is evaluated.
//! @param mscale_opts a list of options for the M-scale equation.
//! @return the maximum absolute entry of the M-scale derivative (0 if every evaluated
//!         derivative is empty).
//! Raises an R error if `change` exceeds the number of elements in `x`.
SEXP MaxMScaleDerivative(SEXP r_x, SEXP r_grid, SEXP r_change, SEXP r_mscale_opts) noexcept {
  BEGIN_RCPP
  auto x = as<arma::vec>(r_x);
  auto grid = MakeVectorView(r_grid);
  const int change = as<int>(r_change);
  auto mscale_opts = as<Rcpp::List>(r_mscale_opts);
  auto mscale = Mscale(mscale_opts);

  // Largest absolute entry of the M-scale derivative at `values`. `arma::max()` cannot
  // be applied to an empty vector, so an empty derivative contributes 0.
  const auto max_abs_derivative = [&mscale](const arma::vec& values) {
    const auto derivatives = mscale.Derivative(values);
    double md = 0;
    if (derivatives.n_elem > 0) {
      md = arma::max(arma::abs(derivatives));
    }
    return md;
  };

  double max_md = max_abs_derivative(x);

  // Nothing to vary. Returning here also avoids indexing into an empty grid
  // (`grid->at()` is not bounds-checked).
  if (change < 1 || grid->n_elem == 0) {
    return Rcpp::wrap(max_md);
  }
  const arma::uword n_change = static_cast<arma::uword>(change);
  // `x[i]` below is not bounds-checked either: reject requests that would write past `x`.
  if (n_change > x.n_elem) {
    Rcpp::stop("`change` must not exceed the number of elements in `x`.");
  }

  // `counters[i]` is the grid index currently placed at `x[i]`. `advance()` steps through
  // all grid tuples like an odometer (last position varies fastest) and returns false
  // once the final tuple has been visited.
  arma::uvec counters(n_change, arma::fill::zeros);
  const auto advance = [&counters, &grid]() {
    for (arma::uword pos = counters.n_elem; pos > 0; --pos) {
      if (++counters[pos - 1] < grid->n_elem) {
        return true;
      }
      counters[pos - 1] = 0;
    }
    return false;
  };

  do {
    for (arma::uword i = 0; i < n_change; ++i) {
      x[i] = grid->at(counters[i]);
    }
    const double md = max_abs_derivative(x);
    if (md > max_md) {
      max_md = md;
    }
  } while (advance());

  return Rcpp::wrap(max_md);
  END_RCPP;
}

//! Compute the maximum entry in the gradient and Hessian of the M-scale
//! function over a grid of values
//!
//! `Mscale::MaxGradientHessian()` is evaluated at the original `x` and at every vector
//! obtained by replacing the first `change` elements of `x` with a tuple of values from
//! `grid` (all |grid|^change tuples, values may repeat). Going by the mapping below, its
//! 3-element result is read as [M-scale, max. gradient entry, max. Hessian entry].
//!
//! @param x original numeric values.
//! @param grid grid of values to look for maximal derivative.
//! @param change number of leading elements in `x` to change. If `change < 1` or `grid`
//!               is empty, only the original `x` is evaluated.
//! @param mscale_opts a list of options for the M-scale equation.
//! @return a vector with 4 elements:
//!   [0] the maximum element of the gradient,
//!   [1] the maximum element of the Hessian,
//!   [2] the M-scale associated with the maximum gradient,
//!   [3] the M-scale associated with the maximum Hessian.
//!   If the evaluation at the original `x` yields fewer than 3 values, that result
//!   is returned unchanged instead.
//! Raises an R error if `change` exceeds the number of elements in `x`.
SEXP MaxMScaleGradientHessian(SEXP r_x, SEXP r_grid, SEXP r_change,
                              SEXP r_mscale_opts) noexcept {
  BEGIN_RCPP
  auto x = as<arma::vec>(r_x);
  auto grid = MakeVectorView(r_grid);
  const int change = as<int>(r_change);
  auto mscale_opts = as<Rcpp::List>(r_mscale_opts);
  auto mscale = Mscale(mscale_opts);
  const auto initial = mscale.MaxGradientHessian(x);

  // Elements [1] and [2] are read below; anything shorter carries no maxima
  // and is passed back unchanged instead of being read out of bounds.
  if (initial.n_elem < 3) {
    return Rcpp::wrap(initial);
  }

  arma::vec::fixed<4> maxima = { initial[1], initial[2],
                                  initial[0], initial[0] };

  // Nothing to vary. Returning here also avoids indexing into an empty grid
  // (`grid->at()` is not bounds-checked).
  if (change < 1 || grid->n_elem == 0) {
    return Rcpp::wrap(maxima);
  }
  const arma::uword n_change = static_cast<arma::uword>(change);
  // `x[i]` below is not bounds-checked either: reject requests that would write past `x`.
  if (n_change > x.n_elem) {
    Rcpp::stop("`change` must not exceed the number of elements in `x`.");
  }

  // `counters[i]` is the grid index currently placed at `x[i]`. `advance()` steps through
  // all grid tuples like an odometer (last position varies fastest) and returns false
  // once the final tuple has been visited.
  arma::uvec counters(n_change, arma::fill::zeros);
  const auto advance = [&counters, &grid]() {
    for (arma::uword pos = counters.n_elem; pos > 0; --pos) {
      if (++counters[pos - 1] < grid->n_elem) {
        return true;
      }
      counters[pos - 1] = 0;
    }
    return false;
  };

  do {
    for (arma::uword i = 0; i < n_change; ++i) {
      x[i] = grid->at(counters[i]);
    }
    const auto candidate = mscale.MaxGradientHessian(x);
    if (candidate.n_elem == 3) {
      if (candidate[1] > maxima[0]) {
        maxima[0] = candidate[1];
        maxima[2] = candidate[0];
      }
      if (candidate[2] > maxima[1]) {
        maxima[1] = candidate[2];
        maxima[3] = candidate[0];
      }
    }
  } while (advance());

  return Rcpp::wrap(maxima);
  END_RCPP;
}


//! Compute the M-location
//!
//! @param x numeric values.
//! @param scale the scale of the values (a single number; integer input is accepted).
//! @param opts a list of options for the M-estimating equation, also used to build the
//!             rho function. The entries `max_it` and `eps` are looked up with
//!             `GetFallback()`, falling back to `kDefaultMLocationMaxIt` and
//!             `kDefaultConvergenceTolerance`.
//! @return the M-estimate of location.
SEXP MLocation(SEXP r_x, SEXP r_scale, SEXP r_opts) noexcept {
  BEGIN_RCPP
  auto x = MakeVectorView(r_x);
  auto opts = as<Rcpp::List>(r_opts);
  // Convert through Rcpp (as RhoFun does) rather than dereferencing REAL(r_scale):
  // REAL() on a non-double vector makes R longjmp through these C++ frames, and
  // dereferencing it for a zero-length vector reads out of bounds. `as<double>`
  // instead raises a catchable error for anything that is not a single number.
  const double scale = as<double>(r_scale);
  const int max_it = GetFallback(opts, "max_it", kDefaultMLocationMaxIt);
  const double convergence_tol = GetFallback(opts, "eps", kDefaultConvergenceTolerance);

  return Rcpp::wrap(pense::MLocation(*x, *RhoFactory(opts), scale, convergence_tol, max_it));
  END_RCPP;
}

//! Compute the M-estimate of the Location and Scale
//!
//! @param x numeric values.
//! @param mscale_opts a list of options for the scale rho-function and the M-estimating equation.
//! @param location_opts a list of options for the location rho-function
//! @return a named numeric vector with 2 elements: `location` (the location estimate)
//!         and `scale` (the scale estimate).
SEXP MLocationScale(SEXP r_x, SEXP r_mscale_opts, SEXP r_location_opts) noexcept {
  BEGIN_RCPP
  auto x = MakeVectorView(r_x);
  auto mscale_opts = as<Rcpp::List>(r_mscale_opts);
  auto location_opts = as<Rcpp::List>(r_location_opts);

  Mscale mscale(mscale_opts);
  // Qualified (like pense::MLocation) to make explicit that this calls the estimator,
  // not this R wrapper of the same name.
  const auto m_loc_scale = pense::MLocationScale(*x, mscale, *RhoFactory(location_opts));

  // Build the named result in one step instead of growing an empty, unnamed vector
  // through name-proxy assignment.
  const Rcpp::NumericVector ret_vec = Rcpp::NumericVector::create(
      Rcpp::Named("location") = m_loc_scale.location,
      Rcpp::Named("scale") = m_loc_scale.scale);
  return Rcpp::wrap(ret_vec);
  END_RCPP;
}
}  // namespace r_interface
}  // namespace pense
