###############################################################################
##  jmSurface Package — Complete Validation & Testing Script
##  
##  Run this script in R after installing the package:
##    install.packages("jmSurface_0.1.0.tar.gz", repos = NULL, type = "source")
##  
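##  Or source the package files directly (tests that need the installed
##  package, i.e. library(jmSurface) and the system.file() lookups, will
##  fail unless the package is also installed):
##    for (f in list.files("jmSurface/R", full.names = TRUE)) source(f)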
##
##  This script tests ALL exported functions and reports PASS/FAIL for each.
###############################################################################

# Opening banner of the console report, written with a single cat() call.
cat(
  "\n",
  "================================================================\n",
  "  jmSurface Package — Comprehensive Validation Suite\n",
  "================================================================\n\n",
  sep = ""
)

## ── Helper ──
## Running tallies shared by every run_test() call. run_test() updates them
## with `<<-`, so they must live in the environment where it is defined.
test_count <- 0
pass_count <- 0
fail_count <- 0
results <- list()  # test name -> "PASS" or "FAIL: <message>"

## Run one named test and record its outcome.
##
## name: label printed in the report and used as the key in `results`.
## expr: test code, typically a `{ ... }` block. It is a lazily evaluated
##       argument, so it runs in the caller's environment at the moment it
##       is forced below; any error it raises marks the test as failed.
##
## Returns TRUE if `expr` finished without error, FALSE otherwise. Also
## prints one "[nn] name  PASS/FAIL (time)" line and updates the tallies.
run_test <- function(name, expr) {
  test_count <<- test_count + 1
  cat(sprintf("  [%02d] %-50s ", test_count, name))
  t0 <- proc.time()[["elapsed"]]
  # Seconds since the test started, rounded the same way for both outcomes.
  elapsed <- function() round(proc.time()[["elapsed"]] - t0, 2)
  tryCatch({
    # force() runs the block exactly once. eval(expr) would additionally
    # evaluate the block's *value*, executing it again whenever that value
    # happens to be a call or expression object.
    force(expr)
    pass_count <<- pass_count + 1
    cat(sprintf("\u2705 PASS  (%.1fs)\n", elapsed()))
    results[[name]] <<- "PASS"
    TRUE
  }, error = function(e) {
    msg <- conditionMessage(e)
    fail_count <<- fail_count + 1
    cat(sprintf("\u274C FAIL  (%.1fs)\n", elapsed()))
    cat(sprintf("        Error: %s\n", msg))
    results[[name]] <<- paste("FAIL:", msg)
    FALSE
  })
}


## ══════════════════════════════════════════════════════════════════
##  SECTION 1: Package Loading
## ══════════════════════════════════════════════════════════════════
cat("--- Section 1: Package Loading ---\n")

# Attaching the package must succeed: the later sections call its functions
# unqualified.
run_test("Load jmSurface package", {
  library(jmSurface)
})

# requireNamespace() returns FALSE instead of erroring when a package cannot
# be loaded, so each result is asserted explicitly.
run_test("Check required dependencies loaded", {
  stopifnot(requireNamespace("nlme", quietly = TRUE))
  stopifnot(requireNamespace("survival", quietly = TRUE))
  stopifnot(requireNamespace("mgcv", quietly = TRUE))
})

run_test("All exported functions exist", {
  # load_example_data is listed because Section 2B calls it unqualified.
  fns <- c("jmSurf", "fit_longitudinal", "compute_blup_eta",
           "fit_gam_cox", "edf_diagnostics", "dynPred",
           "plot_surface", "contour_heatmap", "marginal_slices",
           "simulate_jmSurface", "run_shiny_app", "load_example_data")
  # Check against the export list: a lookup in asNamespace() also finds
  # internal (non-exported) objects and so cannot detect a missing export.
  not_exported <- setdiff(fns, getNamespaceExports("jmSurface"))
  if (length(not_exported) > 0) {
    stop("Not exported: ", paste(not_exported, collapse = ", "))
  }
  for (fn in fns) {
    stopifnot(is.function(getExportedValue("jmSurface", fn)))
  }
})


## ══════════════════════════════════════════════════════════════════
##  SECTION 2: Data Simulation
## ══════════════════════════════════════════════════════════════════
cat("\n--- Section 2: Data Simulation ---\n")

# The simulator should hand back a list carrying both data tables.
run_test("simulate_jmSurface() runs (N=200)", {
  sim <- simulate_jmSurface(n_patients = 200, seed = 42)
  stopifnot(
    is.list(sim),
    all(c("long_data", "surv_data") %in% names(sim))
  )
})

# Column contract of the longitudinal table.
run_test("Simulated long_data has correct columns", {
  req_cols <- c("patient_id", "visit_time_years", "biomarker", "value", "unit")
  sim <- simulate_jmSurface(n_patients = 100, seed = 1)
  stopifnot(all(req_cols %in% names(sim$long_data)))
})

# Column contract of the survival / transition table.
run_test("Simulated surv_data has correct columns", {
  req_cols <- c(
    "patient_id", "start_time", "stop_time", "status",
    "state_from", "state_to", "transition"
  )
  sim <- simulate_jmSurface(n_patients = 100, seed = 1)
  stopifnot(all(req_cols %in% names(sim$surv_data)))
})

# Exactly three distinct biomarkers, and they are the expected ones.
run_test("Simulated data has 3 biomarkers", {
  sim <- simulate_jmSurface(n_patients = 100, seed = 1)
  bms <- unique(sim$long_data$biomarker)
  stopifnot(
    length(bms) == 3,
    all(c("eGFR", "BNP", "HbA1c") %in% bms)
  )
})

# Rows with status == 1 must exist and span at least two transition labels.
run_test("Simulated data has transitions with events", {
  sim <- simulate_jmSurface(n_patients = 200, seed = 1)
  events <- sim$surv_data[sim$surv_data$status == 1, ]
  stopifnot(nrow(events) > 0)
  n_trans <- length(unique(events$transition))
  cat(sprintf("(%d transitions) ", n_trans))
  stopifnot(n_trans >= 2)
})

# Two runs with the same seed must produce identical tables.
run_test("Simulated data seed reproducibility", {
  sim1 <- simulate_jmSurface(n_patients = 50, seed = 999)
  sim2 <- simulate_jmSurface(n_patients = 50, seed = 999)
  stopifnot(
    identical(sim1$long_data, sim2$long_data),
    identical(sim1$surv_data, sim2$surv_data)
  )
})


## ══════════════════════════════════════════════════════════════════
##  SECTION 2B: Bundled Example Data
## ══════════════════════════════════════════════════════════════════
cat("\n--- Section 2B: Bundled Example Data ---\n")

# The loader returns a list that carries both tables.
run_test("load_example_data() loads successfully", {
  dat <- load_example_data()
  stopifnot(
    is.list(dat),
    all(c("long_data", "surv_data") %in% names(dat))
  )
})

# Count distinct patient ids; the counts are printed before the assertion.
run_test("Bundled long_data has 2000 patients", {
  dat <- load_example_data()
  n_pat <- length(unique(dat$long_data$patient_id))
  cat(sprintf(
    "(%d patients, %d rows) ",
    n_pat, nrow(dat$long_data)
  ))
  stopifnot(n_pat == 2000)
})

# At least six distinct transition labels among rows with status == 1.
run_test("Bundled surv_data has correct transitions", {
  dat <- load_example_data()
  events <- dat$surv_data[dat$surv_data$status == 1, ]
  trans <- unique(events$transition)
  cat(sprintf("(%d transition types) ", length(trans)))
  stopifnot(length(trans) >= 6)
})

# The three expected biomarkers are all present.
run_test("Bundled data has all 3 biomarkers", {
  dat <- load_example_data()
  bms <- unique(dat$long_data$biomarker)
  stopifnot(all(c("eGFR", "BNP", "HbA1c") %in% bms))
})

# system.file() yields "" for a missing file, hence the nchar() guard.
run_test("CSV files exist in inst/extdata", {
  lp <- system.file(
    "extdata", "longitudinal_biomarkers.csv",
    package = "jmSurface"
  )
  sp <- system.file(
    "extdata", "survival_events.csv",
    package = "jmSurface"
  )
  stopifnot(
    nchar(lp) > 0 && file.exists(lp),
    nchar(sp) > 0 && file.exists(sp)
  )
})


## ══════════════════════════════════════════════════════════════════
##  SECTION 3: Stage 1 — Longitudinal Model Fitting
## ══════════════════════════════════════════════════════════════════
cat("\n--- Section 3: Stage 1 — Longitudinal Models ---\n")

# One simulated dataset (300 patients) shared by the tests from here on.
sim <- simulate_jmSurface(n_patients = 300, seed = 42)

# The result is a list with one entry per expected biomarker name.
run_test("fit_longitudinal() returns named list", {
  lme_fits <- fit_longitudinal(sim$long_data, verbose = FALSE)
  stopifnot(
    is.list(lme_fits),
    all(c("eGFR", "BNP", "HbA1c") %in% names(lme_fits))
  )
})

# Every element must inherit from class "lme".
run_test("LME fits are valid nlme::lme objects", {
  lme_fits <- fit_longitudinal(sim$long_data, verbose = FALSE)
  for (mk in names(lme_fits)) stopifnot(inherits(lme_fits[[mk]], "lme"))
})

# The second fixed-effect coefficient is read as the slope:
# negative for eGFR (decline), positive for BNP (increase).
run_test("Fixed effects have expected signs", {
  lme_fits <- fit_longitudinal(sim$long_data, verbose = FALSE)
  fe_egfr <- nlme::fixef(lme_fits[["eGFR"]])
  fe_bnp <- nlme::fixef(lme_fits[["BNP"]])
  stopifnot(
    fe_egfr[2] < 0,
    fe_bnp[2] > 0
  )
  cat(sprintf(
    "(eGFR slope=%.2f, BNP slope=%.2f) ",
    fe_egfr[2], fe_bnp[2]
  ))
})

# 5 patients x 3 time points -> 15 rows, with three eta_* columns.
run_test("compute_blup_eta() returns correct structure", {
  lme_fits <- fit_longitudinal(sim$long_data, verbose = FALSE)
  eta <- compute_blup_eta(lme_fits, patient_ids = 1:5, times = c(0, 2, 5))
  stopifnot(
    is.data.frame(eta),
    nrow(eta) == 15
  )
  eta_cols <- grep("^eta_", names(eta), value = TRUE)
  stopifnot(length(eta_cols) == 3)
})


## ══════════════════════════════════════════════════════════════════
##  SECTION 4: Full Model Fitting (jmSurf)
## ══════════════════════════════════════════════════════════════════
cat("\n--- Section 4: Full Model Fitting (jmSurf) ---\n")

# Fit on the shared simulated data and check the class of the result.
run_test("jmSurf() runs successfully", {
  fit <- jmSurf(long_data = sim$long_data, surv_data = sim$surv_data,
                covariates = c("age_baseline", "sex"), verbose = FALSE)
  stopifnot(inherits(fit, "jmSurface"))
})

# Fitted again outside run_test(): this is the `fit` object that the
# remaining tests inspect.
fit <- jmSurf(long_data = sim$long_data, surv_data = sim$surv_data,
              covariates = c("age_baseline", "sex"), verbose = FALSE)

# Every documented component name must be present on the fitted object.
run_test("jmSurface object has required components", {
  required <- c(
    "lme_fits", "gam_fits", "eta_data", "transitions",
    "biomarkers", "covariates", "edf", "deviance_explained",
    "n_patients", "call"
  )
  stopifnot(all(required %in% names(fit)))
})

run_test("At least 1 transition fitted", {
  n_tr <- length(fit$transitions)
  cat(sprintf("(%d transitions) ", n_tr))
  stopifnot(n_tr >= 1)
})

# Each per-transition fit must inherit from class "gam".
run_test("GAM fits are valid mgcv::gam objects", {
  for (tr in names(fit$gam_fits)) stopifnot(inherits(fit$gam_fits[[tr]], "gam"))
})

# EDF values: finite first, then strictly positive; the range is printed last.
run_test("EDF values are positive and finite", {
  stopifnot(
    all(is.finite(fit$edf)),
    all(fit$edf > 0)
  )
  cat(sprintf("(range %.1f - %.1f) ", min(fit$edf), max(fit$edf)))
})

# Deviance explained is a proportion.
run_test("Deviance explained is between 0 and 1", {
  stopifnot(
    all(fit$deviance_explained >= 0),
    all(fit$deviance_explained <= 1)
  )
})


## ══════════════════════════════════════════════════════════════════
##  SECTION 5: EDF Diagnostics
## ══════════════════════════════════════════════════════════════════
cat("\n--- Section 5: EDF Diagnostics ---\n")

# The diagnostics table is a data frame with the expected columns.
run_test("edf_diagnostics() returns correct data frame", {
  edf_df <- edf_diagnostics(fit)
  stopifnot(is.data.frame(edf_df))
  req_cols <- c(
    "transition", "edf", "deviance_explained", "n_obs",
    "n_events", "complexity", "p_value"
  )
  stopifnot(all(req_cols %in% names(edf_df)))
})

# Only these four complexity labels are accepted.
run_test("Complexity labels are valid", {
  valid_labels <- c("Linear", "Moderate", "Nonlinear", "Unknown")
  edf_df <- edf_diagnostics(fit)
  stopifnot(all(edf_df$complexity %in% valid_labels))
  cat(sprintf("(%s) ", paste(edf_df$complexity, collapse = ", ")))
})

# Row by row, the tabulated EDF must agree with fit$edf rounded to 2 decimals.
run_test("EDF diagnostics match fit$edf", {
  edf_df <- edf_diagnostics(fit)
  for (i in seq_len(nrow(edf_df))) {
    tr <- edf_df$transition[i]
    stopifnot(abs(edf_df$edf[i] - round(fit$edf[tr], 2)) < 0.01)
  }
})


## ══════════════════════════════════════════════════════════════════
##  SECTION 6: Dynamic Prediction
## ══════════════════════════════════════════════════════════════════
cat("\n--- Section 6: Dynamic Prediction ---\n")

# Test subject: the first patient id that has a row with status == 1.
event_patients <- with(sim$surv_data, unique(patient_id[status == 1]))
test_pid <- event_patients[1]

# Basic call: a non-empty data frame comes back.
run_test("dynPred() runs for a single patient", {
  pred <- dynPred(fit, patient_id = test_pid, landmark = 0, horizon = 3)
  stopifnot(
    is.data.frame(pred),
    nrow(pred) > 0
  )
})

run_test("dynPred() output has correct columns", {
  req_cols <- c(
    "time", "risk", "hazard", "transition", "to_state",
    "patient_id", "landmark"
  )
  pred <- dynPred(fit, patient_id = test_pid, landmark = 0, horizon = 3)
  stopifnot(all(req_cols %in% names(pred)))
})

# Risks are probabilities that never reach 1.
run_test("Predicted risks are in [0, 1)", {
  pred <- dynPred(fit, patient_id = test_pid, landmark = 0, horizon = 5)
  stopifnot(
    all(pred$risk >= 0),
    all(pred$risk < 1)
  )
})

# Within each transition, successive rows must not drop in risk
# (decreases down to 1e-10 are tolerated as numerical noise).
run_test("Predicted risks are monotonically non-decreasing", {
  pred <- dynPred(fit, patient_id = test_pid, landmark = 0, horizon = 5)
  for (tr in unique(pred$transition)) {
    r <- pred$risk[pred$transition == tr]
    diffs <- diff(r)
    stopifnot(all(diffs >= -1e-10))
  }
})

# A later landmark must shift the earliest predicted time forward.
run_test("dynPred() with different landmark", {
  pred1 <- dynPred(fit, patient_id = test_pid, landmark = 0, horizon = 3)
  pred2 <- dynPred(fit, patient_id = test_pid, landmark = 2, horizon = 3)
  stopifnot(
    nrow(pred2) > 0,
    min(pred2$time) > min(pred1$time)
  )
})

# The peak risk over 5 years may not fall more than 0.01 below the 1-year peak.
run_test("dynPred() with different horizons", {
  pred_short <- dynPred(fit, patient_id = test_pid, landmark = 0, horizon = 1)
  pred_long <- dynPred(fit, patient_id = test_pid, landmark = 0, horizon = 5)
  max_short <- max(pred_short$risk)
  max_long <- max(pred_long$risk)
  stopifnot(max_long >= max_short - 0.01)
})

# An unknown patient id is expected to raise an error; the condition is
# captured and its class asserted.
run_test("dynPred() errors for non-existent patient", {
  err <- tryCatch(
    dynPred(fit, patient_id = 999999, landmark = 0, horizon = 3),
    error = function(e) e
  )
  stopifnot(inherits(err, "error"))
})


## ══════════════════════════════════════════════════════════════════
##  SECTION 7: Visualization Functions
## ══════════════════════════════════════════════════════════════════
cat("\n--- Section 7: Visualization Functions ---\n")

# First fitted transition; every plot test below draws this one.
tr1 <- fit$transitions[1]

## Evaluate `code` while a throw-away PDF device is open.
##
## code: plotting expression(s), evaluated lazily in the caller's environment.
##
## Returns the value of `code`. The device is closed and the temporary file
## removed on exit even when `code` raises an error, so a failing plot test
## cannot leave a device open for the tests that follow.
with_pdf_device <- function(code) {
  pdf_path <- tempfile(fileext = ".pdf")
  grDevices::pdf(pdf_path)
  dev_id <- grDevices::dev.cur()
  on.exit({
    # Close only the device opened here, and only if it is still open.
    if (dev_id %in% grDevices::dev.list()) grDevices::dev.off(dev_id)
    unlink(pdf_path)
  }, add = TRUE)
  force(code)
}

run_test("plot_surface() runs without error", {
  with_pdf_device(plot_surface(fit, transition = tr1))
})

run_test("contour_heatmap() runs without error", {
  with_pdf_device(contour_heatmap(fit, transition = tr1))
})

run_test("marginal_slices() runs without error", {
  with_pdf_device(marginal_slices(fit, transition = tr1))
})

# The generic plot() must accept each of the three `type` values.
run_test("plot.jmSurface() dispatches correctly", {
  with_pdf_device({
    plot(fit, type = "surface")
    plot(fit, type = "heatmap")
    plot(fit, type = "slices")
  })
})

# The return value must be a data frame that has a "z" column.
run_test("plot_surface() returns grid invisibly", {
  g <- with_pdf_device(plot_surface(fit, transition = tr1))
  stopifnot(is.data.frame(g))
  stopifnot("z" %in% names(g))
})

# An unknown transition label should raise an error. The call still runs
# inside a throw-away device so that, if no error is raised, the plot does
# not open a stray default device.
run_test("Visualization errors on bad transition name", {
  err <- tryCatch(
    with_pdf_device(plot_surface(fit, transition = "FAKE -> FAKE")),
    error = function(e) e
  )
  stopifnot(inherits(err, "error"))
})


## ══════════════════════════════════════════════════════════════════
##  SECTION 8: S3 Methods
## ══════════════════════════════════════════════════════════════════
cat("\n--- Section 8: S3 Methods ---\n")

# print() must emit some text that mentions "jmSurface".
run_test("print.jmSurface() works", {
  out <- capture.output(print(fit))
  stopifnot(
    length(out) > 0,
    any(grepl("jmSurface", out))
  )
})

# summary() must print more than five lines and return a list that has an
# "edf_diagnostics" element.
run_test("summary.jmSurface() works", {
  out <- capture.output(sm <- summary(fit))
  stopifnot(
    length(out) > 5,
    is.list(sm),
    "edf_diagnostics" %in% names(sm)
  )
})


## ══════════════════════════════════════════════════════════════════
##  SECTION 9: Edge Cases & Robustness
## ══════════════════════════════════════════════════════════════════
cat("\n--- Section 9: Edge Cases & Robustness ---\n")

# Longitudinal data without the expected column names must be rejected.
run_test("jmSurf() errors on bad column names", {
  bad_data <- data.frame(id = 1, time = 0, marker = "X", val = 1)
  err <- tryCatch(jmSurf(bad_data, sim$surv_data, verbose = FALSE),
                  error = function(e) e)
  stopifnot(inherits(err, "error"))
})

# Disjoint id ranges (<= 5 vs > 200) leave no patient in both tables.
run_test("jmSurf() errors on too few shared IDs", {
  tiny_long <- sim$long_data[sim$long_data$patient_id <= 5, ]
  tiny_surv <- sim$surv_data[sim$surv_data$patient_id > 200, ]
  err <- tryCatch(jmSurf(tiny_long, tiny_surv, verbose = FALSE),
                  error = function(e) e)
  stopifnot(inherits(err, "error"))
})

# Non-default basis dimensions must still give a jmSurface object.
run_test("jmSurf() with custom k_marginal", {
  fit2 <- jmSurf(
    long_data = sim$long_data,
    surv_data = sim$surv_data,
    k_marginal = c(4, 4),
    k_additive = 5,
    verbose = FALSE
  )
  stopifnot(inherits(fit2, "jmSurface"))
})

# The label now matches the 80 patients actually simulated.
run_test("Small sample (N=80) still works", {
  sim_small <- simulate_jmSurface(n_patients = 80, seed = 7)
  fit_small <- tryCatch(
    jmSurf(sim_small$long_data, sim_small$surv_data, verbose = FALSE),
    error = function(e) e
  )
  # An error on a sample this small is tolerated (e.g. too few events), but
  # the actual message is reported instead of an assumed cause, so that an
  # unrelated failure is not silently hidden.
  if (inherits(fit_small, "error")) {
    reason <- gsub("\\s+", " ", conditionMessage(fit_small))
    cat(sprintf("(graceful skip: %s) ", reason))
  } else {
    stopifnot(inherits(fit_small, "jmSurface"))
  }
})


## ══════════════════════════════════════════════════════════════════
##  SECTION 10: Shiny App File Existence
## ══════════════════════════════════════════════════════════════════
cat("\n--- Section 10: Shiny App ---\n")

# system.file() yields "" when the directory is missing, so the path is
# checked first and app.R inside it second.
run_test("Shiny app file exists in inst/shiny", {
  app_dir <- system.file("shiny", package = "jmSurface")
  stopifnot(
    nchar(app_dir) > 0,
    file.exists(file.path(app_dir, "app.R"))
  )
})


## ══════════════════════════════════════════════════════════════════
##  SECTION 11: Full Pipeline with Bundled Data (2000 patients)
## ══════════════════════════════════════════════════════════════════
cat("\n--- Section 11: Full Pipeline with Bundled Data ---\n")

# Each test below loads the bundled data and refits the model itself, so the
# tests stay independent of one another.
run_test("jmSurf() on bundled data (2000 patients)", {
  dat <- load_example_data()
  fit_real <- jmSurf(
    long_data = dat$long_data,
    surv_data = dat$surv_data,
    covariates = c("age_baseline", "sex"),
    verbose = FALSE
  )
  stopifnot(inherits(fit_real, "jmSurface"))
  n_tr <- length(fit_real$transitions)
  cat(sprintf("(%d transitions fitted) ", n_tr))
  stopifnot(n_tr >= 3)
})

run_test("EDF diagnostics on bundled data", {
  dat <- load_example_data()
  fit_real <- jmSurf(dat$long_data, dat$surv_data,
                     covariates = c("age_baseline", "sex"), verbose = FALSE)
  edf_df <- edf_diagnostics(fit_real)
  cat(sprintf("(EDF range: %.1f - %.1f) ", min(edf_df$edf), max(edf_df$edf)))
  stopifnot(nrow(edf_df) >= 3)
  stopifnot(all(edf_df$edf > 0))
})

run_test("dynPred on bundled data patient", {
  dat <- load_example_data()
  fit_real <- jmSurf(dat$long_data, dat$surv_data,
                     covariates = c("age_baseline", "sex"), verbose = FALSE)
  ev_pats <- unique(dat$surv_data$patient_id[dat$surv_data$status == 1])
  pred <- dynPred(fit_real, patient_id = ev_pats[1], landmark = 1, horizon = 3)
  stopifnot(nrow(pred) > 0)
  stopifnot(all(pred$risk >= 0 & pred$risk < 1))
  # The id is formatted as text: "%d" raises an error for character ids,
  # which would turn an otherwise passing test into a FAIL.
  cat(sprintf("(patient %s, %d predictions) ",
              format(ev_pats[1], scientific = FALSE), nrow(pred)))
})

run_test("All 3 visualizations on bundled data", {
  dat <- load_example_data()
  fit_real <- jmSurf(dat$long_data, dat$surv_data,
                     covariates = c("age_baseline", "sex"), verbose = FALSE)
  tr <- fit_real$transitions[1]
  pdf(tempfile(fileext = ".pdf"))
  # `finally` closes the device even when one of the plot calls errors, so a
  # failure here cannot leave a graphics device open.
  tryCatch({
    plot_surface(fit_real, transition = tr)
    contour_heatmap(fit_real, transition = tr)
    marginal_slices(fit_real, transition = tr)
  }, finally = dev.off())
  cat(sprintf("(transition: %s) ", tr))
})


## ══════════════════════════════════════════════════════════════════
##  FINAL REPORT
## ══════════════════════════════════════════════════════════════════
# Summary banner: totals, pass/fail counts and the overall pass rate.
cat(
  "\n",
  "================================================================\n",
  "  VALIDATION REPORT\n",
  "================================================================\n",
  sprintf("  Total tests:  %d\n", test_count),
  sprintf("  Passed:       %d  \u2705\n", pass_count),
  sprintf("  Failed:       %d  %s\n", fail_count,
          if (fail_count == 0) "\u2705" else "\u274C"),
  sprintf("  Pass rate:    %.1f%%\n", 100 * pass_count / test_count),
  "================================================================\n",
  sep = ""
)

# List each failed test with its recorded "FAIL: <message>" entry.
if (fail_count > 0) {
  cat("\nFailed tests:\n")
  for (nm in names(results)) {
    outcome <- results[[nm]]
    if (!startsWith(outcome, "FAIL")) next
    cat(sprintf("  \u274C %s: %s\n", nm, outcome))
  }
}

# Closing one-line verdict.
cat("\n")
cat(
  if (fail_count == 0) {
    "\u2705 ALL TESTS PASSED — jmSurface is working correctly!\n\n"
  } else {
    "\u26A0\uFE0F Some tests failed. Review errors above.\n\n"
  }
)
