# diagcounts: Core implementation

#' Derive Unreported Diagnostic Test Counts
#'
#' Recovers unreported true positive (TP), false negative (FN),
#' false positive (FP), and true negative (TN) counts from the total sample
#' size and at least three reported accuracy measures, using a system of
#' linear equations.
#'
#' @details
#' With unknowns \code{x = (TP, FN, FP, TN)}, each input contributes one
#' linear equation:
#' \itemize{
#'   \item total: \code{TP + FN + FP + TN = n}
#'   \item sensitivity: \code{(1 - sensitivity) * TP - sensitivity * FN = 0}
#'   \item specificity: \code{(1 - specificity) * TN - specificity * FP = 0}
#'   \item ppv: \code{(1 - ppv) * TP - ppv * FP = 0}
#'   \item npv: \code{(1 - npv) * TN - npv * FN = 0}
#'   \item prevalence: \code{TP + FN = prevalence * n}
#' }
#' The total-count equation is always used, together with the first three
#' of the remaining equations (in the order listed above) for which the
#' 4 x 4 coefficient matrix has an absolute determinant greater than
#' \code{tol}. Equations outside the selected set are not checked for
#' consistency with the solution. The solution is rounded to whole counts;
#' an error is raised if any rounded count is negative or if the rounded
#' counts do not sum to \code{n}.
#'
#' @param n Total sample size: a single positive whole number.
#' @param sensitivity Test sensitivity: a single number between 0 and 1, or
#'   \code{NULL} (the default) if not reported.
#' @param specificity Test specificity (same format as \code{sensitivity}).
#' @param ppv Positive predictive value (same format as \code{sensitivity}).
#' @param npv Negative predictive value (same format as \code{sensitivity}).
#' @param prevalence Pretest probability (same format as
#'   \code{sensitivity}).
#' @param tol Determinant threshold. A candidate 4 x 4 system whose absolute
#'   determinant is not greater than \code{tol} is treated as singular.
#'
#' @return An object of class \code{diagcounts}: a list with integer
#'   elements \code{TP}, \code{FN}, \code{FP}, and \code{TN}.
#'
#' @references
#' Xie X, Wang M, Antony J, Vandersluis S, Kabali CB (2025).
#' System of Linear Equations to Derive Unreported Test Accuracy Counts.
#' Statistics in Medicine. https://doi.org/10.1002/sim.70336
#'
#' @examples
#' # Recover unreported diagnostic counts from published accuracy measures
#' derive_counts(
#'   n = 105,
#'   sensitivity = 0.6,
#'   specificity = 0.893,
#'   prevalence = 0.733
#' )
#'
#' # Recover counts using predictive values
#' derive_counts(
#'   n = 160,
#'   sensitivity = 0.75,
#'   ppv = 0.75,
#'   npv = 0.75
#' )
#' @importFrom utils combn
#' @export
derive_counts <- function(
    n,
    sensitivity = NULL,
    specificity = NULL,
    ppv = NULL,
    npv = NULL,
    prevalence = NULL,
    tol = 1e-6
) {
  # `isTRUE()` keeps the condition a plain TRUE/FALSE even when `n` is
  # NA/NaN; a bare `n <= 0` would make `if` itself fail on NA.
  if (missing(n) || !is.numeric(n) || length(n) != 1 ||
      !isTRUE(n > 0) || !is.finite(n)) {
    stop("`n` must be a positive scalar.")
  }
  # Counts are integers, so a fractional `n` can never be matched by the
  # rounded solution; reject it up front with a clear message.
  if (n != round(n)) {
    stop("`n` must be a whole number.")
  }

  params <- list(
    sensitivity = sensitivity,
    specificity = specificity,
    ppv = ppv,
    npv = npv,
    prevalence = prevalence
  )

  provided <- !vapply(params, is.null, logical(1))
  if (sum(provided) < 3) {
    stop("At least three accuracy measures must be provided.")
  }

  # Every supplied measure is a proportion; reject NA, vectors and
  # out-of-range values (e.g. percentages) before they reach the solver.
  for (nm in names(params)[provided]) {
    value <- params[[nm]]
    if (!is.numeric(value) || length(value) != 1 || is.na(value) ||
        value < 0 || value > 1) {
      stop(sprintf("`%s` must be a single number between 0 and 1.", nm))
    }
  }

  # Build equations A %*% x = b
  # x = (TP, FN, FP, TN)
  A <- matrix(0, nrow = 0, ncol = 4)
  b <- numeric(0)

  # Total count (always row 1; it anchors the solution's scale to n)
  A <- rbind(A, c(1, 1, 1, 1))
  b <- c(b, n)

  if (!is.null(sensitivity)) {
    # TP / (TP + FN) = sens
    A <- rbind(A, c(1 - sensitivity, -sensitivity, 0, 0))
    b <- c(b, 0)
  }

  if (!is.null(specificity)) {
    # TN / (FP + TN) = spec
    A <- rbind(A, c(0, 0, -specificity, 1 - specificity))
    b <- c(b, 0)
  }

  if (!is.null(ppv)) {
    # TP / (TP + FP) = ppv
    A <- rbind(A, c(1 - ppv, 0, -ppv, 0))
    b <- c(b, 0)
  }

  if (!is.null(npv)) {
    # TN / (FN + TN) = npv
    A <- rbind(A, c(0, -npv, 0, 1 - npv))
    b <- c(b, 0)
  }

  if (!is.null(prevalence)) {
    # (TP + FN) / n = prev
    A <- rbind(A, c(1, 1, 0, 0))
    b <- c(b, prevalence * n)
  }

  # Defensive: unreachable while at least three measures are required.
  if (nrow(A) < 4) {
    stop("Infeasible system: insufficient equations.")
  }

  # Always keep the total-count row: the ratio equations are homogeneous
  # (b = 0), so a subset without it would give either the trivial solution
  # or one not constrained to sum to n. Pair it with the first three other
  # rows (in combn order) that make the 4 x 4 system non-singular.
  # `others` has length >= 3 here, so combn() does not hit its
  # scalar-argument special case.
  others <- seq_len(nrow(A))[-1]
  candidates <- combn(others, 3, simplify = FALSE)

  x <- NULL
  for (rest in candidates) {
    i <- c(1L, rest)
    Ai <- A[i, , drop = FALSE]
    if (abs(det(Ai)) > tol) {
      x <- solve(Ai, b[i])
      break
    }
  }

  if (is.null(x)) {
    stop("Infeasible system: equations are linearly dependent.")
  }

  x_round <- round(x)

  if (any(x_round < 0)) {
    stop("Infeasible system: negative counts after rounding.")
  }

  # Independent rounding of the four cells can break the total; report it
  # rather than silently returning a table that does not match n.
  if (sum(x_round) != n) {
    stop("Infeasible system: rounded counts do not sum to n.")
  }

  structure(
    list(TP = as.integer(x_round[1]),
         FN = as.integer(x_round[2]),
         FP = as.integer(x_round[3]),
         TN = as.integer(x_round[4])),
    class = "diagcounts"
  )
}

# Print method for `diagcounts` objects: writes a header line followed by
# the four cell counts on a single line, then returns `x` invisibly so the
# call can be used in pipelines without double printing.
#' @export
print.diagcounts <- function(x, ...) {
  header <- "Derived diagnostic counts:\n"
  cat(header)
  cell_fmt <- "TP: %d  FN: %d  FP: %d  TN: %d\n"
  cat(sprintf(cell_fmt, x$TP, x$FN, x$FP, x$TN))
  invisible(x)
}

# Convert a `diagcounts` object to a 2 x 2 contingency table with the test
# result in rows (Positive, Negative) and the true disease status in
# columns (Disease, No disease):
#
#                 Disease   No disease
#   Positive        TP          FP
#   Negative        FN          TN
#
# The result carries class "table", as the `as.table()` generic promises.
#' @export
as.table.diagcounts <- function(x, ...) {
  counts <- matrix(
    c(x$TP, x$FP,
      x$FN, x$TN),
    nrow = 2,
    byrow = TRUE,
    dimnames = list(
      Test = c("Positive", "Negative"),
      Truth = c("Disease", "No disease")
    )
  )
  as.table(counts)
}

