#include <RcppEigen.h>
using namespace Rcpp;

// Split matched-pair data into a "concordant/reservoir" set and a set of
// within-pair differences for discordant pairs.
//
// Inputs (all must have the same number of rows):
//   strata    - stratum id per row: 0 marks a reservoir (unmatched) row, a
//               positive id marks a member of a matched pair. Rows with a
//               negative id are ignored entirely.
//   y         - response per row.
//   X         - covariate matrix, one row per observation.
//   treatment - treatment indicator per row (compared against 0).
//
// Processing:
//   * Reservoir rows (strata == 0) are copied through unchanged, keeping
//     stratum id 0.
//   * Pairs whose two responses are equal (concordant) are appended to the
//     reservoir output, both rows sharing a fresh consecutive id 1, 2, 3, ...
//     assigned in increasing order of the original stratum id.
//   * Pairs whose responses differ (discordant) are reduced to one row of
//     differences. The row whose treatment is non-zero is the minuend; more
//     precisely, the first row of the pair is the minuend unless its
//     treatment equals 0, in which case the second row is.
//
// Returns a list with X_concordant, y_concordant, treatment_concordant,
// strata_concordant, X_diffs_discordant, y_diffs_discordant,
// treatment_diffs_discordant (each NULL when it would be empty) and
// discordant_idx, the 0-based row indices of the discordant pairs laid out as
// (minuend, subtrahend, minuend, subtrahend, ...).
//
// Raises an R error if the inputs disagree on the number of rows or if any
// positive stratum id does not have exactly two rows.
// [[Rcpp::depends(RcppEigen)]]
// [[Rcpp::export(rng = false)]]
List process_matched_pairs_cpp(const Eigen::VectorXi &strata,
                               const Eigen::VectorXd &y,
                               const Eigen::MatrixXd &X,
                               const Eigen::VectorXd &treatment) {
  const int n = static_cast<int>(strata.size());
  const int p = static_cast<int>(X.cols());

  // Every row index derived from strata is used to index y, X and treatment,
  // so a size mismatch would read out of bounds.
  if (y.size() != n || X.rows() != n || treatment.size() != n) {
    stop("strata, y, X and treatment must all have the same number of rows.");
  }

  // Single counting pass: sizes the output buffers and the pair lookup table.
  int n_reservoir = 0;
  int n_paired = 0;
  int max_strata = 0;
  for (int i = 0; i < n; i++) {
    const int s = strata[i];
    if (s == 0) {
      n_reservoir++;
    } else if (s > 0) {
      n_paired++;
      if (s > max_strata) {
        max_strata = s;
      }
    }
  }

  // Upper bounds: at worst every paired row is concordant and joins the
  // reservoir, and each discordant pair consumes two paired rows.
  const int reservoir_capacity = n_reservoir + n_paired;
  const int diffs_capacity = n_paired / 2;

  Eigen::MatrixXd reservoir_X(reservoir_capacity, p);
  Eigen::VectorXd reservoir_y(reservoir_capacity);
  Eigen::VectorXd reservoir_treat(reservoir_capacity);
  Eigen::VectorXi reservoir_strata(reservoir_capacity);

  Eigen::MatrixXd diffs_X(diffs_capacity, p);
  Eigen::VectorXd diffs_y(diffs_capacity);
  Eigen::VectorXd diffs_treat(diffs_capacity);
  std::vector<int> discordant_idx;
  discordant_idx.reserve(static_cast<std::size_t>(n_paired));

  int res_idx = 0;
  int diff_idx = 0;

  // One pass over the rows: copy reservoir rows straight through and bucket
  // paired rows by stratum id so each pair can be looked up in O(1).
  std::vector<std::vector<int>> pair_indices(
      static_cast<std::size_t>(max_strata) + 1);
  for (int i = 0; i < n; i++) {
    const int s = strata[i];
    if (s == 0) {
      reservoir_X.row(res_idx) = X.row(i);
      reservoir_y[res_idx] = y[i];
      reservoir_treat[res_idx] = treatment[i];
      reservoir_strata[res_idx] = 0;
      res_idx++;
    } else if (s > 0) {
      pair_indices[s].push_back(i);
    }
  }

  // Next id handed to a concordant pair. This must live outside the loop so
  // that each concordant pair receives its own id instead of all sharing 1.
  int current_stratum = 1;

  // Process all pairs in increasing order of stratum id.
  for (int pair_num = 1; pair_num <= max_strata; pair_num++) {
    const std::vector<int> &pair = pair_indices[pair_num];

    if (pair.empty()) {
      continue; // gaps in the stratum numbering are allowed
    }
    if (pair.size() != 2) {
      stop("Each nonzero stratum must have exactly 2 rows.");
    }

    const int i = pair[0];
    const int j = pair[1];
    const double yi = y[i];
    const double yj = y[j];

    if (yi == yj) {
      // Concordant pair: both rows go to the reservoir under a shared new id.
      reservoir_X.row(res_idx) = X.row(i);
      reservoir_y[res_idx] = yi;
      reservoir_treat[res_idx] = treatment[i];
      reservoir_strata[res_idx] = current_stratum;
      res_idx++;

      reservoir_X.row(res_idx) = X.row(j);
      reservoir_y[res_idx] = yj;
      reservoir_treat[res_idx] = treatment[j];
      reservoir_strata[res_idx] = current_stratum;
      res_idx++;

      current_stratum++;
    } else {
      // Discordant pair: store (minuend - subtrahend), where the minuend is
      // row i unless its treatment is 0, in which case it is row j.
      const bool i_is_control = (treatment[i] == 0);
      const int minuend = i_is_control ? j : i;
      const int subtrahend = i_is_control ? i : j;

      discordant_idx.push_back(minuend);
      discordant_idx.push_back(subtrahend);
      diffs_X.row(diff_idx) = X.row(minuend) - X.row(subtrahend);
      diffs_y[diff_idx] = y[minuend] - y[subtrahend];
      diffs_treat[diff_idx] = treatment[minuend] - treatment[subtrahend];
      diff_idx++;
    }
  }

  // Shrink the buffers to the rows actually written.
  if (res_idx > 0) {
    reservoir_X.conservativeResize(res_idx, p);
    reservoir_y.conservativeResize(res_idx);
    reservoir_treat.conservativeResize(res_idx);
    reservoir_strata.conservativeResize(res_idx);
  }

  if (diff_idx > 0) {
    diffs_X.conservativeResize(diff_idx, p);
    diffs_y.conservativeResize(diff_idx);
    diffs_treat.conservativeResize(diff_idx);
  }

  // Empty components are returned as NULL rather than zero-row objects.
  return List::create(
      _["X_concordant"] = res_idx > 0 ? wrap(reservoir_X) : R_NilValue,
      _["y_concordant"] = res_idx > 0 ? wrap(reservoir_y) : R_NilValue,
      _["treatment_concordant"] =
          (res_idx > 0) ? wrap(reservoir_treat) : R_NilValue,
      _["strata_concordant"] =
          res_idx > 0 ? wrap(reservoir_strata) : R_NilValue,
      _["X_diffs_discordant"] = diff_idx > 0 ? wrap(diffs_X) : R_NilValue,
      _["y_diffs_discordant"] = diff_idx > 0 ? wrap(diffs_y) : R_NilValue,
      _["treatment_diffs_discordant"] =
          (diff_idx > 0) ? wrap(diffs_treat) : R_NilValue,
      _["discordant_idx"] = discordant_idx);
}
