#' Fit the Topic Testlet Model (TTM)
#'
#' Calibrates the TTM using score data and pre-computed topic proportions.
#' The fit alternates two steps until the item log-likelihood changes by less
#' than `tol` (or `max_iter` is reached):
#' 1. Person step: for every person, the MAP estimate of ability (theta) and
#'    topic penalties (lambda) is found under independent N(0, 1) priors,
#'    holding the item parameters fixed.
#' 2. Item step: for every item, the step difficulties (b) are estimated by
#'    maximum likelihood, holding theta and the mean-centered testlet effects
#'    (gamma) fixed.
#'
#' @param scores An N x J numeric matrix of item scores (0, 1, ...). Missing
#'   responses (`NA`) are skipped in every likelihood.
#' @param delta An N x K numeric matrix of topic proportions (from ttm_lda).
#'   Must have one row per row of `scores`.
#' @param max_iter Maximum number of EM iterations (must be at least 1).
#' @param tol Convergence tolerance on the absolute change of the item
#'   log-likelihood between two consecutive iterations.
#'
#' @return A list containing:
#' \item{theta}{Vector of estimated student abilities.}
#' \item{lambda}{Matrix of estimated topic penalties.}
#' \item{gamma}{Vector of person-specific testlet effects, centered to have
#'   mean zero. The centering shift is not propagated back into `lambda`.}
#' \item{item_params}{List of step difficulties for each item.}
#' \item{AIC}{Akaike Information Criterion (counts item parameters only).}
#' \item{BIC}{Bayesian Information Criterion (counts item parameters only).}
#' \item{loglik}{Item log-likelihood (without priors) at the last iteration.}
#' @importFrom stats optim dnorm
#' @export
ttm_est <- function(scores, delta, max_iter = 100, tol = 1e-4) {

  # --- 0. Input checks ---

  # Coercing once keeps row/column extraction below returning plain vectors,
  # also when a data frame is supplied.
  scores <- as.matrix(scores)
  delta <- as.matrix(delta)

  N <- nrow(scores)
  J <- ncol(scores)
  K <- ncol(delta)

  if (N < 1L || J < 1L || K < 1L) {
    stop("`scores` and `delta` must each have at least one row and column.",
         call. = FALSE)
  }
  if (nrow(delta) != N) {
    stop("`delta` must have one row per row of `scores`.", call. = FALSE)
  }
  if (length(max_iter) != 1L || is.na(max_iter) || max_iter < 1) {
    stop("`max_iter` must be a single number >= 1.", call. = FALSE)
  }

  # --- 1. Initialization ---

  # Initialize Theta from the standardized raw score. scale() returns NaN/NA
  # when all raw scores are equal (or N == 1); start those at the prior mean.
  raw_scores <- rowSums(scores, na.rm = TRUE)
  theta <- as.vector(scale(raw_scores))
  theta[!is.finite(theta)] <- 0

  # Initialize Lambda (Topic Penalties) to zeros
  lambda <- matrix(0, nrow = N, ncol = K)

  # Initialize Item Parameters (Step difficulties)
  # Structure: List where beta[[j]] contains thresholds for item j.
  # The number of steps is the highest observed category; an item without any
  # observed response gets zero steps instead of max() returning -Inf.
  max_cats <- vapply(seq_len(J), function(j) {
    col_j <- scores[, j]
    if (all(is.na(col_j))) 0 else max(col_j, na.rm = TRUE)
  }, numeric(1))
  beta <- lapply(seq_len(J), function(j) {
    if (max_cats[j] < 1) numeric(0) else seq(-1, 1, length.out = max_cats[j])
  })

  # Helper: Partial Credit Model log-probabilities
  # Returns log P(X = k | theta, beta, gamma) for k = 0, ..., M, where
  #   logit_k = sum_{h=0}^k (theta - b_h - gamma)
  #           = k * (theta - gamma) - sum_{h=0}^k b_h,   with b_0 = 0.
  # th: scalar theta; b_vec: step difficulties (h = 1..M); gam: scalar
  # testlet effect.
  pcm_log_probs <- function(th, b_vec, gam) {
    M <- length(b_vec)
    logits <- (0:M) * (th - gam) - cumsum(c(0, b_vec))
    # Log-sum-exp: subtracting the largest logit prevents exp() overflow and
    # avoids log(0) for categories with a very small probability.
    top <- max(logits)
    logits - (top + log(sum(exp(logits - top))))
  }

  # --- EM Loop ---

  prev_loglik <- -Inf
  total_loglik <- NA_real_
  gamma_vec <- numeric(N)

  for (iter in seq_len(max_iter)) {

    # --- E-Step: Update Person Parameters (Theta and Lambda) ---
    # Simplification of the variational step: the MAP estimate of Theta
    # (1 dim) and Lambda (K dims) is found jointly for each person, given the
    # current item parameters.

    gamma_vec <- numeric(N)

    for (n in seq_len(N)) {

      # Hoisted out of the objective: these do not depend on the parameters.
      delta_n <- delta[n, ]
      obs_items <- which(!is.na(scores[n, ]))
      x_n <- scores[n, obs_items]

      # Negative log-posterior for student n: log-likelihood + priors.
      # Prior Theta ~ N(0, 1); prior Lambda ~ N(0, 1) (ridge-type penalty
      # for stability).
      fn_student <- function(params) {
        cur_theta <- params[1]
        cur_lambda <- params[-1]
        cur_gamma <- sum(cur_lambda * delta_n)

        log_lik <- 0
        for (i in seq_along(obs_items)) {
          log_p <- pcm_log_probs(cur_theta, beta[[obs_items[i]]], cur_gamma)
          # Scores are 0-based (0...M), R vectors are 1-based.
          log_lik <- log_lik + log_p[x_n[i] + 1]
        }

        log_prior_theta <- dnorm(cur_theta, 0, 1, log = TRUE)
        log_prior_lambda <- sum(dnorm(cur_lambda, 0, 1, log = TRUE))

        -(log_lik + log_prior_theta + log_prior_lambda)
      }

      # Warm start from the previous iteration's estimates
      opt <- optim(c(theta[n], lambda[n, ]), fn_student, method = "BFGS")

      theta[n] <- opt$par[1]
      lambda[n, ] <- opt$par[-1]
      gamma_vec[n] <- sum(lambda[n, ] * delta_n)
    }

    # Identify/Center Gamma (Eq 5: sum Gamma = 0) by post-hoc centering.
    # Note: Lambda is not shifted to reflect this, so the next E-step starts
    # from the uncentered Lambda; the centered Gamma is what the item step
    # and the returned `gamma` use.
    gamma_vec <- gamma_vec - mean(gamma_vec)

    # --- M-Step: Update Item Parameters (Beta) ---
    # Maximize the likelihood of each item given current Theta and Gamma

    total_loglik <- 0

    for (j in seq_len(J)) {

      # An item with no step parameters has a single category with
      # probability 1: nothing to estimate and a log-likelihood of 0.
      if (length(beta[[j]]) == 0L) next

      obs_persons <- which(!is.na(scores[, j]))
      x_j <- scores[obs_persons, j]

      fn_item <- function(b_params) {
        ll_item <- 0
        for (i in seq_along(obs_persons)) {
          p <- obs_persons[i]
          log_p <- pcm_log_probs(theta[p], b_params, gamma_vec[p])
          ll_item <- ll_item + log_p[x_j[i] + 1]
        }
        -ll_item
      }

      # optimize step parameters for item j
      opt_item <- optim(beta[[j]], fn_item, method = "BFGS")
      beta[[j]] <- opt_item$par
      total_loglik <- total_loglik - opt_item$value
    }

    # --- Convergence Check ---
    ll_change <- abs(total_loglik - prev_loglik)
    message(sprintf("Iter: %d | LogLik: %.4f | Diff: %.4f", iter, total_loglik, ll_change))

    if (ll_change < tol) break
    prev_loglik <- total_loglik
  }

  # --- Calculate Fit Statistics ---
  # Only the item step parameters are counted. The person parameters
  # (N thetas and N * K lambdas) are deliberately left out, so AIC/BIC are
  # an approximation.
  n_params <- sum(lengths(beta))
  aic <- 2 * n_params - 2 * total_loglik
  bic <- n_params * log(N) - 2 * total_loglik

  list(
    theta = theta,
    lambda = lambda,
    gamma = gamma_vec,
    item_params = beta,
    AIC = aic,
    BIC = bic,
    loglik = total_loglik
  )
}