#' Estimate Parameters in a Two-Component Gaussian Mixture Using Study-Level Summaries
#' 
#' Estimates group-specific means and standard deviations \eqn{(\mu_1, \mu_0, \sigma_1, \sigma_0)} in a two-component
#' normal mixture model based on aggregate data across multiple studies. The continuous variable \eqn{X}
#' is assumed to follow a Gaussian mixture conditional on a binary group indicator \eqn{Y \in \{0,1\}},
#' with each study reporting only summary-level statistics.
#'
#' Two estimation methods are available:
#' \itemize{
#'   \item \strong{"naive"}: Likelihood-based estimator using only sample means.
#'   \item \strong{"gmm"}: Generalized method of moments (GMM) estimator using sample means and variances.
#' }
#'
#' @param ni Integer vector of sample sizes per study.
#' @param xbar Numeric vector of sample means per study.
#' @param s2 Numeric vector of sample variances per study. Required if \code{method = "gmm"}.
#' @param mi Integer vector of group 1 counts per study.
#' @param method Estimation method to use. One of \code{"naive"} or \code{"gmm"}. Default is \code{"gmm"}.
#'
#' @return A named list containing:
#' \describe{
#'   \item{\code{mu1_hat, mu0_hat}}{Estimated means of the two groups.}
#'   \item{\code{sigma1_hat, sigma0_hat}}{Estimated standard deviations.}
#'   \item{\code{se}}{Standard errors of the parameter estimates (NA if \code{method = "naive"}).}
#'   \item{\code{ci}}{List of 95\% confidence intervals for each parameter (NULL if \code{method = "naive"}).}
#'   \item{\code{method}}{A character string indicating the method used.}
#' }
#'
#' @examples
#' # Load example dataset included in the package
#' data(mixture_example)
#'
#' # Estimate using GMM (recommended) with full summary statistics
#' est_mixture(
#'   ni = mixture_example$ni,
#'   xbar = mixture_example$xbar,
#'   s2 = mixture_example$s2,
#'   mi = mixture_example$mi,
#'   method = "gmm"
#' )
#'
#' # Estimate using naive likelihood method (only means used)
#' est_mixture(
#'   ni = mixture_example$ni,
#'   xbar = mixture_example$xbar,
#'   mi = mixture_example$mi,
#'   method = "naive"
#' )
#'
#' @export
#' @importFrom stats optim qnorm sd cov dnorm




est_mixture <- function(ni, xbar, mi, s2 = NULL, method = c("gmm", "naive")) {
  # Purpose: estimate (mu1, mu0, sigma1, sigma0) of a two-component Gaussian
  # mixture from per-study summaries. `method = "naive"` maximises a normal
  # likelihood for the study means only; `method = "gmm"` additionally matches
  # moments of the study variances and reports sandwich standard errors and
  # Wald-type 95% intervals.
  #
  # Returns a list with mu1_hat, mu0_hat, sigma1_hat, sigma0_hat, se, ci and
  # method. If the naive optimisation fails, the naive method returns NA
  # estimates and the GMM method starts from the default initial values.

  method <- match.arg(method)

  ##############################################################################
  ## Input validation
  ##############################################################################
  if (missing(ni) || missing(xbar) || missing(mi)) {
    stop("`ni`, `xbar`, and `mi` are required.")
  }
  if (!is.numeric(ni) || !is.numeric(xbar) || !is.numeric(mi)) {
    stop("All inputs must be numeric vectors.")
  }
  if (length(ni) != length(xbar) || length(ni) != length(mi)) {
    stop("`ni`, `xbar`, and `mi` must have the same length.")
  }
  if (length(ni) == 0L) {
    stop("At least one study is required.")
  }
  # Without this check an NA would reach `if (any(...))` below and fail with
  # an uninformative "missing value where TRUE/FALSE needed".
  if (anyNA(ni) || anyNA(xbar) || anyNA(mi)) {
    stop("`ni`, `xbar`, and `mi` must not contain missing values.")
  }
  if (any(ni <= 0)) {
    stop("All `ni` must be positive.")
  }
  if (any(mi < 0 | mi > ni)) {
    stop("All `mi` must satisfy 0 <= mi <= n.")
  }

  if (method == "gmm") {
    if (is.null(s2) || any(is.na(s2))) {
      stop("`s2` must be provided for method = 'gmm'.")
    }
    if (!is.numeric(s2)) {
      stop("`s2` must be a numeric vector.")
    }
    # data.frame() would silently recycle a shorter `s2` whose length divides
    # length(ni), pairing variances with the wrong studies.
    if (length(s2) != length(ni)) {
      stop("`s2` must have the same length as `ni`.")
    }
    # The moment conditions only use studies with 1 < mi < ni - 1, and the
    # moment scaling needs a standard deviation across studies.
    if (sum(mi > 1 & mi < ni - 1) < 2L) {
      stop("method = 'gmm' requires at least two studies with ",
           "1 < mi < ni - 1.")
    }
    inputdata <- data.frame(ni = ni, xbar = xbar, s2 = s2, mi = mi,
                            pi_est_individual = mi / ni)
  } else {
    inputdata <- data.frame(ni = ni, xbar = xbar, mi = mi,
                            pi_est_individual = mi / ni)
  }

  ##############################################################################
  ## Naive estimator (likelihood of the study means only)
  ##############################################################################
  # params = (mu1, mu0, log sigma1, log sigma0); the log scale keeps the
  # standard deviations positive during optimisation.
  baseline_neg_loglik <- function(params, data) {
    mu1 <- params[1]
    mu0 <- params[2]
    sigma1 <- exp(params[3])
    sigma0 <- exp(params[4])

    pi_i <- data$pi_est_individual
    mean_i <- pi_i * mu1 + (1 - pi_i) * mu0
    var_i <- (data$mi / data$ni^2) * sigma1^2 +
      ((data$ni - data$mi) / data$ni^2) * sigma0^2
    -sum(dnorm(data$xbar, mean = mean_i, sd = sqrt(var_i), log = TRUE))
  }

  ##############################################################################
  ## GMM moment conditions
  ##############################################################################
  # Per-study moment contributions at theta = (mu1, mu0, sigma1^2, sigma0^2).
  # Returns a k x 5 matrix (one row per retained study) with columns
  #   g1: xbar - E[xbar]            g2: s2 - E[s2]
  #   g3: (xbar - E[xbar])^2 - Var[xbar]
  #   g4: (s2 - E[s2])^2 - Var[s2]
  #   g5: (xbar - E[xbar]) * (s2 - E[s2]) - Cov[xbar, s2]
  # Studies with mi <= 1, mi >= ni - 1, or a non-positive model variance are
  # dropped.
  compute_moment_matrix <- function(theta, data) {
    mu1 <- theta[[1]]
    mu0 <- theta[[2]]
    sigma1_sq <- theta[[3]]
    sigma0_sq <- theta[[4]]

    n <- data$ni
    m <- data$mi
    p1 <- m / n
    nu <- mu1 - mu0

    # E[Xbar] and Var[Xbar]
    mu_bar <- p1 * mu1 + (1 - p1) * mu0
    var_bar <- sigma1_sq * m / n^2 + sigma0_sq * (n - m) / n^2

    # E[S2]
    e_s2 <- (m * sigma1_sq + (n - m) * sigma0_sq) / n +
      (m * (n - m)) / (n * (n - 1)) * (mu1 - mu0)^2

    # Var[S2]
    tau2 <- sigma1_sq / m + sigma0_sq / (n - m)
    var1 <- 2 * (m - 1) * sigma1_sq^2
    var2 <- 2 * (n - m - 1) * sigma0_sq^2
    var3 <- ((m * (n - m)) / n)^2 * (2 * tau2^2 + 4 * tau2 * nu^2)
    var_s2 <- (var1 + var2 + var3) / (n - 1)^2

    # Cov[Xbar, S2]
    model_cov <- (2 * m * (n - m)) / (n^2 * (n - 1)) *
      (mu1 - mu0) * (sigma1_sq - sigma0_sq)

    dx <- data$xbar - mu_bar
    ds <- data$s2 - e_s2
    g_all <- cbind(
      g1 = dx,
      g2 = ds,
      g3 = dx^2 - var_bar,
      g4 = ds^2 - var_s2,
      g5 = dx * ds - model_cov
    )

    # Rows outside 1 < mi < ni - 1 may hold Inf/NaN from divisions by zero
    # above; which() discards them together with any NA comparison.
    keep <- which(m > 1 & m < n - 1 & var_bar > 0 & var_s2 > 0)
    g_all[keep, , drop = FALSE]
  }

  # Averaged moment conditions, raw and divided by their across-study
  # standard deviations so that all five moments are on a comparable scale.
  moment_conditions_scale <- function(theta, data) {
    g_mat <- compute_moment_matrix(theta, data)
    if (nrow(g_mat) < 2L) {
      stop("Fewer than two studies yield valid moment conditions.",
           call. = FALSE)
    }
    g_bar <- apply(g_mat, 2, mean)
    sd_g <- apply(g_mat, 2, sd, na.rm = TRUE)
    list(g_bar = g_bar / sd_g, raw_bar = g_bar, scales = sd_g)
  }

  # Minimise the sum of squared scaled moments (identity weighting) subject
  # to positive variances.
  fit_gmm_scale <- function(data, theta_init) {
    gmm_objective <- function(theta) {
      sum(moment_conditions_scale(theta, data)$g_bar^2)
    }
    optim(
      par = theta_init,
      fn = gmm_objective,
      method = "L-BFGS-B",
      lower = c(-Inf, -Inf, 1e-6, 1e-6)
    )
  }

  ##############################################################################
  ## Standard errors
  ##############################################################################
  # Forward-difference Jacobian of the scaled averaged moments. The scales
  # are frozen at their value at `theta` so that the derivative reflects the
  # raw moments only.
  compute_jacobian <- function(theta, data, moment_fn, eps = 1e-6) {
    base <- moment_fn(theta, data)
    scales <- base$scales
    g0_scaled <- base$raw_bar / scales
    p <- length(theta)
    q <- length(g0_scaled)
    J <- matrix(NA_real_, q, p)
    for (j in seq_len(p)) {
      theta_eps <- theta
      theta_eps[j] <- theta_eps[j] + eps
      bump <- moment_fn(theta_eps, data)
      # Same frozen scales as the base evaluation.
      J[, j] <- (bump$raw_bar / scales - g0_scaled) / eps
    }
    J
  }

  # Sandwich standard errors for theta_hat on the
  # (mu1, mu0, sigma1^2, sigma0^2) scale.
  compute_se_gmm <- function(theta_hat, data, moment_fn) {
    # 1. Scale the per-study moments with the scales produced at theta_hat.
    scales_hat <- moment_fn(theta_hat, data)$scales
    g_mat_raw <- compute_moment_matrix(theta_hat, data)
    g_mat_scaled <- sweep(g_mat_raw, 2, scales_hat, "/")

    k <- nrow(g_mat_scaled)
    S <- cov(g_mat_scaled, use = "complete.obs") * (k - 1) / k
    W <- diag(1, ncol(S))

    # 2. Jacobian of the scaled moments
    G <- compute_jacobian(theta_hat, data, moment_fn)

    A <- t(G) %*% W %*% G
    B <- t(G) %*% W %*% S %*% W %*% G
    V <- solve(A, B) %*% solve(A) # solve(A) %*% B %*% solve(A)

    sqrt(diag(V) / k)
  }

  ##############################################################################
  ## Estimation
  ##############################################################################
  xbar_center <- mean(inputdata$xbar)
  init_params <- c(mu1 = xbar_center + 1,
                   mu0 = xbar_center - 1,
                   log_sigma1 = log(3),
                   log_sigma0 = log(3))

  fit_baseline <- tryCatch(
    optim(
      par = init_params,
      fn = baseline_neg_loglik,
      data = inputdata,
      method = "L-BFGS-B",
      lower = c(-Inf, -Inf, log(1e-3), log(1e-3)),
      upper = c(Inf, Inf, log(1e+20), log(1e+20))
    ),
    error = function(e) {
      message("Naive Estimator's optimization failed", ": ",
              conditionMessage(e))
      NULL
    }
  )

  naive_failed <- is.null(fit_baseline)

  if (method == "naive") {
    # Report NA rather than zero-length values when the optimiser failed.
    est <- if (naive_failed) {
      rep(NA_real_, 4)
    } else {
      c(fit_baseline$par[1:2], exp(fit_baseline$par[3:4]))
    }
    return(list(
      mu1_hat = as.numeric(est[1]),
      mu0_hat = as.numeric(est[2]),
      sigma1_hat = as.numeric(est[3]),
      sigma0_hat = as.numeric(est[4]),
      se = rep(NA, 4),
      ci = NULL,
      method = "Naive Estimator"
    ))
  }

  # GMM: start from the naive estimates, or from the default starting values
  # when the naive fit is unavailable.
  par_bl <- if (naive_failed) init_params else fit_baseline$par
  theta_init <- as.numeric(c(par_bl[1], par_bl[2],
                             exp(par_bl[3])^2, exp(par_bl[4])^2))

  fit_gmm <- fit_gmm_scale(inputdata, theta_init)
  if (fit_gmm$convergence != 0) {
    message("GMM optimization did not converge (code ",
            fit_gmm$convergence, ")")
  }
  theta_hat <- fit_gmm$par

  se_theta <- tryCatch(
    compute_se_gmm(
      theta_hat = theta_hat,
      data = inputdata,
      moment_fn = moment_conditions_scale
    ),
    error = function(e) {
      message("SE computation failed", ": ", e$message)
      rep(NA, 4)
    }
  )

  # Point estimates on the standard-deviation scale
  mu1_hat <- as.numeric(theta_hat[1])
  mu0_hat <- as.numeric(theta_hat[2])
  sigma1_hat <- as.numeric(sqrt(theta_hat[3]))
  sigma0_hat <- as.numeric(sqrt(theta_hat[4]))

  # Delta method: se(sigma) = se(sigma^2) / (2 * sigma)
  se_result <- c(
    mu1 = as.numeric(se_theta[1]),
    mu0 = as.numeric(se_theta[2]),
    sigma1 = as.numeric(0.5 / sigma1_hat * se_theta[3]),
    sigma0 = as.numeric(0.5 / sigma0_hat * se_theta[4])
  )

  # Wald-type 95% confidence intervals
  wald <- function(est, se) est + c(-1, 1) * 1.96 * se
  ci <- list(
    mu1 = wald(mu1_hat, se_result[["mu1"]]),
    mu0 = wald(mu0_hat, se_result[["mu0"]]),
    sigma1 = wald(sigma1_hat, se_result[["sigma1"]]),
    sigma0 = wald(sigma0_hat, se_result[["sigma0"]])
  )

  list(
    mu1_hat = mu1_hat,
    mu0_hat = mu0_hat,
    sigma1_hat = sigma1_hat,
    sigma0_hat = sigma0_hat,
    se = se_result,
    ci = ci,
    method = "GMM Estimator"
  )
}
