// poLCAParallel
// Copyright (C) 2024 Sherman Lo

// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License along
// with this program; if not, write to the Free Software Foundation, Inc.,
// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

#ifndef POLCAPARALLEL_INCLUDE_EM_ALGORITHM_NAN_H_
#define POLCAPARALLEL_INCLUDE_EM_ALGORITHM_NAN_H_

#include <cstddef>
#include <span>
#include <vector>

#include "arma.h"
#include "em_algorithm.h"
#include "em_algorithm_regress.h"
#include "util.h"

namespace polca_parallel {

/**
 * Template class for EmAlgorithmNan and EmAlgorithmNanRegress
 *
 * This template class is a direct subclass of either EmAlgorithm or
 * EmAlgorithmRegress and reimplements methods for NaN handling. Provide either
 * of the superclasses EmAlgorithm or EmAlgorithmRegress via the template
 * parameter
 *
 * For NaN handling, those values (also known as missing values) in the
 * responses are encoded as zeros
 *
 * @tparam T The superclass to derive from, either EmAlgorithm or
 * EmAlgorithmRegress
 */
template <typename T>
class EmAlgorithmNanTemplate : public T {
 protected:
  /**
   * Temporary variable for summing posteriors over categories
   *
   * Modified by WeightedSumProb() and then used by NormalWeightedSumProb() to
   * estimate the response probabilities
   */
  std::vector<double> posterior_sum_;

 public:
  /**
   * EM algorithm with NaN handling. NaNs are encoded as zeros in responses
   *
   * Please use EmAlgorithmNan or EmAlgorithmNanRegress rather than this
   * template class directly
   *
   * @copydoc EmAlgorithm::EmAlgorithm
   */
  EmAlgorithmNanTemplate(std::span<const double> features,
                         std::span<const int> responses,
                         std::span<const double> initial_prob,
                         std::size_t n_data, std::size_t n_feature,
                         NOutcomes n_outcomes, std::size_t n_cluster,
                         unsigned int max_iter, double tolerance,
                         std::span<double> posterior, std::span<double> prior,
                         std::span<double> estimated_prob,
                         std::span<double> regress_coeff);

  ~EmAlgorithmNanTemplate() override = default;

 protected:
  /**
   * Overridden to handle and ignore response zero
   *
   * Overridden to handle and ignore response zero (a missing value) and
   * modifies the member variable EmAlgorithmNanTemplate::posterior_sum_
   *
   * @copydoc EmAlgorithm::WeightedSumProb
   */
  void WeightedSumProb(const std::size_t cluster_index) override;

  /**
   * Overridden to estimate probabilities using <code>posterior_sum</code>
   *
   * Overridden to estimate probabilities using
   * EmAlgorithmNanTemplate::posterior_sum_, which is filled in by
   * WeightedSumProb()
   *
   * @copydoc EmAlgorithm::NormalWeightedSumProb
   */
  void NormalWeightedSumProb(const std::size_t cluster_index) override;

  /**
   * Overridden for NaN handling, where missing responses are encoded as zero
   *
   * NOTE(review): presumably skips responses encoded as zero (missing) when
   * evaluating the likelihood; confirm against the out-of-line definition
   *
   * @copydoc EmAlgorithm::Likelihood
   */
  [[nodiscard]] double Likelihood(
      std::span<const int> responses_i,
      const arma::Col<double>& estimated_prob) const override;
};

/**
 * EM algorithm with NaN handling
 *
 * Subclass of EmAlgorithmNanTemplate instantiated with EmAlgorithm (no
 * regression). Missing (NaN) responses are encoded as zeros in the responses
 * and the methods responsible for probability estimation are overridden to
 * handle them.
 */
class EmAlgorithmNan : public EmAlgorithmNanTemplate<EmAlgorithm> {
 public:
  /**
   * EM algorithm with NaN handling. NaNs are encoded as zeros in responses
   *
   * @copydoc EmAlgorithm::EmAlgorithm(std::span<const double> features,
                 std::span<const int> responses,
                 std::span<const double> initial_prob, std::size_t n_data,
                 std::size_t n_feature, NOutcomes n_outcomes,
                 std::size_t n_cluster, unsigned int max_iter, double tolerance,
                 std::span<double> posterior, std::span<double> prior,
                 std::span<double> estimated_prob,
                 std::span<double> regress_coeff)
   */
  EmAlgorithmNan(std::span<const double> features,
                 std::span<const int> responses,
                 std::span<const double> initial_prob,
                 std::size_t n_data,
                 std::size_t n_feature,
                 NOutcomes n_outcomes,
                 std::size_t n_cluster,
                 unsigned int max_iter,
                 double tolerance,
                 std::span<double> posterior,
                 std::span<double> prior,
                 std::span<double> estimated_prob,
                 std::span<double> regress_coeff);

  /**
   * EM algorithm with NaN handling, without features. NaNs are encoded as
   * zeros in responses
   *
   * @copydoc EmAlgorithm::EmAlgorithm(std::span<const int> responses,
                 std::span<const double> initial_prob, std::size_t n_data,
                 NOutcomes n_outcomes, std::size_t n_cluster,
                 unsigned int max_iter, double tolerance,
                 std::span<double> posterior, std::span<double> prior,
                 std::span<double> estimated_prob)
   */
  EmAlgorithmNan(std::span<const int> responses,
                 std::span<const double> initial_prob,
                 std::size_t n_data,
                 NOutcomes n_outcomes,
                 std::size_t n_cluster,
                 unsigned int max_iter,
                 double tolerance,
                 std::span<double> posterior,
                 std::span<double> prior,
                 std::span<double> estimated_prob);

  ~EmAlgorithmNan() override = default;
};

/**
 * EM algorithm for regression with NaN handling
 *
 * Subclass of EmAlgorithmNanTemplate instantiated with EmAlgorithmRegress.
 * Missing (NaN) responses are encoded as zeros in the responses and the methods
 * responsible for probability estimation are overridden to handle them.
 */
class EmAlgorithmNanRegress
    : public EmAlgorithmNanTemplate<EmAlgorithmRegress> {
 public:
  /**
   * EM algorithm for regression with NaN handling. NaNs are encoded as zeros
   * in responses
   *
   * @copydoc EmAlgorithmRegress::EmAlgorithmRegress
   */
  EmAlgorithmNanRegress(std::span<const double> features,
                        std::span<const int> responses,
                        std::span<const double> initial_prob,
                        std::size_t n_data,
                        std::size_t n_feature,
                        NOutcomes n_outcomes,
                        std::size_t n_cluster,
                        unsigned int max_iter,
                        double tolerance,
                        std::span<double> posterior,
                        std::span<double> prior,
                        std::span<double> estimated_prob,
                        std::span<double> regress_coeff);

  ~EmAlgorithmNanRegress() override = default;
};

/**
 * Static version of EmAlgorithmNanTemplate::WeightedSumProb()
 *
 * Free-function counterpart of EmAlgorithmNanTemplate::WeightedSumProb() and
 * the NaN-aware variant of EmAlgorithm::WeightedSumProb(). The member
 * EmAlgorithmNanTemplate::WeightedSumProb() uses it to override
 * EmAlgorithm::WeightedSumProb().
 *
 * Responses equal to zero (missing values) are skipped when accumulating the
 * posteriors over categories into <code>posterior_sum</code>, which backs
 * EmAlgorithmNanTemplate::posterior_sum_
 *
 * @param cluster_index Index of the cluster to process
 * @param responses <b>Transposed</b> design matrix of responses, holding the
 * outcome for each category as the integers 1, 2, 3, .... Missing values may
 * be encoded as 0. Its dimensions are
 * <ul>
 *   <li>dim 0: for each category</li>
 *   <li>dim 1: for each data point</li>
 * </ul>
 * @param n_outcomes Number of outcomes for each category
 * @param posterior Design matrix of posterior probabilities (also known as
 * responsibilities), i.e. the probability that a data point belongs to cluster
 * m given its responses. Its dimensions are
 * <ul>
 *   <li>dim 0: for each data</li>
 *   <li>dim 1: for each cluster</li>
 * </ul>
 * @param estimated_prob <b>Modified</b> to hold, for each outcome, the sum of
 * posteriors conditioned on the category and cluster. Flattened in the order
 * <ul>
 *   <li>
 *     dim 0: for each outcome | category (inner), for each category (outer)
 *   </li>
 *   <li>dim 1: for each cluster</li>
 * </ul>
 * @param posterior_sum <b>Modified</b> to hold the cumulative posterior sum
 * over categories
 */
void NanWeightedSumProb(
    const std::size_t cluster_index, std::span<const int> responses,
    std::span<const std::size_t> n_outcomes, const arma::Mat<double>& posterior,
    arma::Mat<double>& estimated_prob, std::vector<double>& posterior_sum);

/**
 * Static version of EmAlgorithmNanTemplate::NormalWeightedSumProb()
 *
 * Free-function counterpart of EmAlgorithmNanTemplate::NormalWeightedSumProb()
 * and the NaN-aware variant of EmAlgorithm::NormalWeightedSumProb(). The member
 * EmAlgorithmNanTemplate::NormalWeightedSumProb() uses it to override
 * EmAlgorithm::NormalWeightedSumProb().
 *
 * It estimates probabilities using
 * <code>EmAlgorithmNanTemplate::posterior_sum_</code>, as computed by
 * NanWeightedSumProb() or EmAlgorithmNanTemplate::WeightedSumProb()
 *
 * @param cluster_index Index of the cluster to process
 * @param n_outcomes Number of outcomes for each category
 * @param posterior_sum Cumulative posterior sum over categories, as produced
 * by NanWeightedSumProb(). NOTE(review): taken by non-const reference; whether
 * this function modifies it cannot be determined from this header
 * @param estimated_prob <b>Modified</b> to hold the estimated response
 * probabilities for each outcome, conditioned on the category and cluster.
 * Flattened in the order
 * <ul>
 *   <li>
 *     dim 0: for each outcome | category (inner), for each category (outer)
 *   </li>
 *   <li>dim 1: for each cluster</li>
 * </ul>
 */
void NanNormalWeightedSumProb(
    const std::size_t cluster_index, std::span<const std::size_t> n_outcomes,
    std::vector<double>& posterior_sum, arma::Mat<double>& estimated_prob);

}  // namespace polca_parallel

#endif  // POLCAPARALLEL_INCLUDE_EM_ALGORITHM_NAN_H_
