#' @title Prepare Data for Weighted CIF
#' @description Prepares merged competing-risks and longitudinal severity data for
#'   weighted restricted mean analyses. The routine removes patients with zero
#'   follow-up or missing baseline severity, handles discharge-to-die cases, merges
#'   the longitudinal trajectory, and computes user-specified weighted time summaries
#'   for death-focused and discharge-focused analyses.
#'
#' @details
#' Processing steps, in order:
#' \enumerate{
#'   \item Patients with both a recovery and a death event (both censoring
#'     indicators equal to 0) are treated as deaths: their recovery indicator is
#'     set to 1 (censored) and their recovery time to the largest recovery time
#'     in \code{data_main}.
#'   \item \code{etime} is the smaller of the recovery and death times,
#'     \code{estatus} is 0 only when both indicators equal 1, and \code{etype2}
#'     is 1 for recovery, 2 for death and 0 otherwise.
#'   \item Rows with a recovery or death time of exactly 0, or with a missing
#'     baseline score, are dropped.
#'   \item For each remaining patient, the longitudinal records dated on or
#'     before \code{etime} are sorted by day. The gap between two consecutive
#'     records is credited to the state observed at the start of the gap; the
#'     last record therefore contributes no time. A patient with exactly one
#'     usable record is credited one time unit in that record's state. Patients
#'     without usable records get a weight of 0.
#'   \item \code{wU} is the weighted sum of the time credited to each state.
#' }
#' Days that cannot be converted to a number are treated as day 0. Patient IDs
#' are matched between the two datasets by their character representation, so
#' the ID columns need not share a type; the returned ID column keeps the type
#' it has in \code{data_main}. If an ID occurs more than once in
#' \code{data_main}, the first occurrence supplies \code{etime} for the
#' longitudinal records.
#'
#' @param data_main A data.frame with ID, TTR, TTD, RECCNSR, DTHCNSR, baseline score, trt, etc.
#' @param data_long A data.frame with repeated clinical scores over time
#'   (e.g. ADYC, ORDSCOR).
#' @param wID_main Name of the patient ID column in the main dataset (default "USUBJID").
#' @param wTimeToRecovery_main Name of the time-to-recovery column (default "TTRECOV").
#' @param wTimeToDeath_main Name of the time-to-death column (default "TTDEATH").
#' @param wRecov_Censoring_main Name of the recovery-censor column (default "RECCNSR").
#' @param wDeath_Censoring_main Name of the death-censor column (default "DTHCNSR").
#' @param wBaselineScore_main Name of the baseline ordinal column (default "ordscr_bs").
#' @param wTreatment_main Name of the treatment indicator column (0=control,1=treatment). Default "trt".
#'
#' @param wID_long Name of the patient ID column in the long dataset (default "USUBJID").
#' @param wADY_long Name of the day-since-treatment column in the long dataset (default "ADYC").
#' @param wScore_long Name of the ordinal score column in the long dataset (default "ORDSCOR").
#'
#' @param wStates_death Vector of ordinal states for death weighting (default c(4,5,6,7)).
#' @param wWeights_death Numeric weights, same length as wStates_death (default c(2,1.5,1,0.5)).
#' @param wStates_discharge Vector of states for discharge weighting (default c(4,5,6,7)).
#' @param wWeights_discharge Numeric weights, same length as wStates_discharge
#'   (default c(0.5,1,1.5,2)).
#' @importFrom survival survfit Surv
#' @importFrom stats var cov pnorm
#' @importFrom rlang sym
#' @return A list containing:
#'   \itemize{
#'     \item \code{data.ws.death} and \code{data.ws.discharge}: Full merged datasets
#'           (columns \code{USUBJID}, \code{etime}, \code{estatus}, \code{etype2},
#'           \code{trt}) with an added \code{wU} column for (death) or (discharge).
#'     \item \code{Treatment.death} and \code{Control.death}: Subsets for weighted WRMLT2 (death-focused).
#'     \item \code{Treatment.discharge} and \code{Control.discharge}: Subsets for weighted WRMLT1 (recovery-focused).
#'   }
#'   The four subsets have columns \code{cn}, \code{D_time}, \code{D_status},
#'   \code{etype} and \code{wU}; rows with a missing treatment indicator appear
#'   in neither arm.
#' @export
prep_data_weighted_cif2 <- function(
    data_main,
    data_long,

    wID_main              = "USUBJID",
    wTimeToRecovery_main  = "TTRECOV",
    wTimeToDeath_main     = "TTDEATH",
    wRecov_Censoring_main = "RECCNSR",
    wDeath_Censoring_main = "DTHCNSR",
    wBaselineScore_main   = "ordscr_bs",
    wTreatment_main       = "trt",

    wID_long              = "USUBJID",
    wADY_long             = "ADYC",
    wScore_long           = "ORDSCOR",

    wStates_death         = c(4,5,6,7),
    wWeights_death        = c(2,1.5,1,0.5),
    wStates_discharge     = c(4,5,6,7),
    wWeights_discharge    = c(0.5,1,1.5,2)
){
  if (length(wWeights_death) != length(wStates_death)) {
    stop("wWeights_death must be same length as wStates_death")
  }
  if (length(wWeights_discharge) != length(wStates_discharge)) {
    stop("wWeights_discharge must be same length as wStates_discharge")
  }

  # Fail early, with a readable message, on columns whose absence would
  # otherwise surface as an obscure error further down. The baseline column is
  # deliberately not checked: when it is absent, no row is dropped for a
  # missing baseline, exactly as before.
  absent_main <- setdiff(
    c(wID_main, wTimeToRecovery_main, wTimeToDeath_main,
      wRecov_Censoring_main, wDeath_Censoring_main, wTreatment_main),
    names(data_main)
  )
  if (length(absent_main) > 0) {
    stop("data_main is missing column(s): ",
         paste(absent_main, collapse = ", "))
  }
  absent_long <- setdiff(c(wID_long, wADY_long, wScore_long), names(data_long))
  if (length(absent_long) > 0) {
    stop("data_long is missing column(s): ",
         paste(absent_long, collapse = ", "))
  }

  # ---- Main dataset: exclusions and event coding ----

  # Row positions to discard. which() ignores NA comparisons, so rows with a
  # missing time are not discarded here.
  cn.t0 <- which(
    data_main[[wTimeToRecovery_main]] == 0 |
      data_main[[wTimeToDeath_main]] == 0
  )
  cn.noob <- which(is.na(data_main[[wBaselineScore_main]]))

  # Discharge-to-die: both events observed. Recode recovery as censored at the
  # largest recovery time so that death becomes the first event for these IDs.
  cn.dtd <- which(
    data_main[[wRecov_Censoring_main]] == 0 &
      data_main[[wDeath_Censoring_main]] == 0
  )
  if (length(cn.dtd) > 0) {
    dtd_rows <- data_main[[wID_main]] %in% data_main[[wID_main]][cn.dtd]
    max_recovery_time <- max(data_main[[wTimeToRecovery_main]], na.rm = TRUE)
    data_main[dtd_rows, wRecov_Censoring_main] <- 1
    data_main[dtd_rows, wTimeToRecovery_main]  <- max_recovery_time
  }

  data_main$etime <- pmin(
    data_main[[wTimeToRecovery_main]],
    data_main[[wTimeToDeath_main]]
  )

  # The positions above refer to the original row order, which is still intact.
  cn.discard <- c(cn.t0, cn.noob)
  if (length(cn.discard) > 0) {
    data_main <- data_main[-cn.discard, ]
  }

  data_main$estatus <- 1 - (
    (data_main[[wRecov_Censoring_main]] == 1) *
      (data_main[[wDeath_Censoring_main]] == 1)
  )

  data_main$etype2 <-
    1 * (data_main[[wRecov_Censoring_main]] == 0 &
           data_main[[wDeath_Censoring_main]] == 1) +
    2 * (data_main[[wDeath_Censoring_main]] == 0)

  data.w <- data_main[
    , c(wID_main, "etime", "estatus", "etype2", wTreatment_main),
    drop = FALSE
  ]
  names(data.w) <- c("USUBJID", "etime", "estatus", "etype2", "trt")
  rownames(data.w) <- NULL

  # ---- Longitudinal dataset: usable records per patient ----

  # Match IDs on their character form so that numeric, factor and character
  # ID columns can be combined without a key-type clash.
  long_key   <- as.character(data_long[[wID_long]])
  long_day   <- suppressWarnings(as.numeric(data_long[[wADY_long]]))
  long_day[is.na(long_day)] <- 0
  long_score <- data_long[[wScore_long]]

  # Event time of the patient each record belongs to (NA when the patient is
  # not in the analysis set).
  record_etime <- data.w$etime[match(long_key, as.character(data.w$USUBJID))]

  # which() drops records with an NA comparison (unknown patient or missing
  # etime) instead of turning them into all-NA rows. order() is stable, so
  # same-day records keep their input order.
  usable <- which(long_day <= record_etime)
  usable <- usable[order(long_day[usable])]
  obs_by_id <- split(usable, long_key[usable])

  # Time credited to each entry of `states`, given the state observed at the
  # start of every interval and the interval lengths.
  time_in_states <- function(scores, durations, states) {
    state_pos <- match(scores, states)
    vapply(
      seq_along(states),
      function(s) sum(durations[which(state_pos == s)]),
      numeric(1)
    )
  }

  # ---- Weighted time summaries ----

  idList  <- unique(data.w$USUBJID)
  id_keys <- as.character(idList)
  death_w <- numeric(length(idList))
  disc_w  <- numeric(length(idList))

  for (i in seq_along(idList)) {
    rows  <- obs_by_id[[id_keys[i]]]
    n_obs <- length(rows)
    if (n_obs == 0) {
      next
    }

    if (n_obs == 1) {
      # A single record carries no interval; credit one time unit to its state.
      scores    <- long_score[rows]
      durations <- 1
    } else {
      # Each gap is credited to the state at its start; the last record only
      # closes the final gap.
      scores    <- long_score[rows[-n_obs]]
      durations <- pmax(diff(long_day[rows]), 0)
    }

    death_w[i] <- sum(
      wWeights_death * time_in_states(scores, durations, wStates_death)
    )
    disc_w[i] <- sum(
      wWeights_discharge * time_in_states(scores, durations, wStates_discharge)
    )
  }

  # ---- Assemble outputs ----

  wu_pos <- match(data.w$USUBJID, idList)
  with_wU <- function(values) {
    out <- data.w
    out$wU <- values[wu_pos]
    out$wU[is.na(out$wU)] <- 0
    out
  }
  data.ws.death     <- with_wU(death_w)
  data.ws.discharge <- with_wU(disc_w)

  # which() keeps rows with a missing treatment indicator out of both arms
  # rather than emitting them as all-NA rows.
  arm_subset <- function(dat, arm) {
    out <- dat[
      which(dat$trt == arm),
      c("USUBJID", "etime", "estatus", "etype2", "wU"),
      drop = FALSE
    ]
    names(out) <- c("cn", "D_time", "D_status", "etype", "wU")
    out
  }

  list(
    data.ws.death     = data.ws.death,
    data.ws.discharge = data.ws.discharge,

    Treatment.death    = arm_subset(data.ws.death, 1),
    Control.death      = arm_subset(data.ws.death, 0),
    Treatment.discharge = arm_subset(data.ws.discharge, 1),
    Control.discharge   = arm_subset(data.ws.discharge, 0)
  )
}

# Internal convenience wrapper, kept so existing analysis scripts keep working.
# Not exported.
#
# Runs prep_data_weighted_cif2() on the raw inputs (all column-name, state and
# weight arguments are forwarded unchanged) and then calls table_weighted()
# once per horizon in `tau` for each of the two weighted analyses.
#
# Returns a list with one element per horizon, named "tau_<value>". Each
# element is a list with components WRMLT1 (discharge arms, eta = 1) and
# WRMLT2 (death arms, eta = 2).
# NOTE(review): table_weighted() is defined outside this view; the structure
# of its return value is not described here.
do_weighted_cif_from_raw <- function(
    data_main,
    data_long,

    wID_main              = "USUBJID",
    wTimeToRecovery_main  = "TTRECOV",
    wTimeToDeath_main     = "TTDEATH",
    wRecov_Censoring_main = "RECCNSR",
    wDeath_Censoring_main = "DTHCNSR",
    wBaselineScore_main   = "ordscr_bs",
    wTreatment_main       = "trt",
    wID_long              = "USUBJID",
    wADY_long             = "ADYC",
    wScore_long           = "ORDSCOR",

    wStates_death         = c(4,5,6,7),
    wWeights_death        = c(2,1.5,1,0.5),
    wStates_discharge     = c(4,5,6,7),
    wWeights_discharge    = c(0.5,1,1.5,2),

    tau                   = c(15, 29)
){
  prepared <- prep_data_weighted_cif2(
    data_main             = data_main,
    data_long             = data_long,
    wID_main              = wID_main,
    wTimeToRecovery_main  = wTimeToRecovery_main,
    wTimeToDeath_main     = wTimeToDeath_main,
    wRecov_Censoring_main = wRecov_Censoring_main,
    wDeath_Censoring_main = wDeath_Censoring_main,
    wBaselineScore_main   = wBaselineScore_main,
    wTreatment_main       = wTreatment_main,
    wID_long              = wID_long,
    wADY_long             = wADY_long,
    wScore_long           = wScore_long,
    wStates_death         = wStates_death,
    wWeights_death        = wWeights_death,
    wStates_discharge     = wStates_discharge,
    wWeights_discharge    = wWeights_discharge
  )

  # Both analyses at a single horizon; discharge first, then death.
  analyse_horizon <- function(horizon) {
    # eta = 1 selects the recovery/discharge event type
    discharge_table <- table_weighted(
      prepared$Treatment.discharge,
      prepared$Control.discharge,
      eta = 1, tau = horizon
    )
    # eta = 2 selects the death event type
    death_table <- table_weighted(
      prepared$Treatment.death,
      prepared$Control.death,
      eta = 2, tau = horizon
    )
    list(WRMLT1 = discharge_table, WRMLT2 = death_table)
  }

  by_horizon <- lapply(tau, analyse_horizon)
  names(by_horizon) <- paste0("tau_", tau)
  by_horizon
}


# Build a weight vector with one entry per element of `states`.
#
# `pattern` selects the shape:
#   "decreasing"      evenly spaced from 2 down to 0.5 (first state weighted most)
#   "increasing"      evenly spaced from 0.5 up to 2 (last state weighted most)
#   "equal"           every state weighted 1
#   "binary_high_low" c(2, 0.5); requires exactly two states
#   "binary_low_high" c(0.5, 2); requires exactly two states
# Any other pattern name raises an error, as does a binary pattern with a
# number of states other than two.
create_weights <- function(states, pattern = "decreasing") {
  n_states <- length(states)

  switch(
    pattern,
    decreasing = seq(2, 0.5, length.out = n_states),
    increasing = seq(0.5, 2, length.out = n_states),
    equal = rep(1, n_states),
    binary_high_low = {
      if (n_states == 2) {
        c(2, 0.5)
      } else {
        stop("binary_high_low only for 2 states")
      }
    },
    binary_low_high = {
      if (n_states == 2) {
        c(0.5, 2)
      } else {
        stop("binary_low_high only for 2 states")
      }
    },
    stop("Unknown pattern. Use: decreasing, increasing, equal, binary_high_low, binary_low_high")
  )
}

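# Example usage with non-default column names (kept commented out so that it
# is not executed when this file is sourced):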
# prepped_w <- prep_data_weighted_cif2(
#   data_main = main_df,
#   data_long = long_df,
#
#   wID_main              = "ID",
#   wTimeToRecovery_main  = "TimeToRecovery",
#   wTimeToDeath_main     = "TimeToDeath",
#   wRecov_Censoring_main = "RecoveryCensoringIndicator",
#   wDeath_Censoring_main = "DeathCensoringIndicator",
#   wTreatment_main       = "Treatment",
#   wBaselineScore_main   = "BaselineScore",
#
#   wID_long              = "PersonID",
#   wADY_long             = "RelativeDay",
#   wScore_long           = "OrdinalScore",
#
#   wStates_death         = c(4,5,6,7),
#   wWeights_death        = c(2,1.5,1,0.5),
#   wStates_discharge     = c(4,5,6,7),
#   wWeights_discharge    = c(0.5,1,1.5,2)
# )
