#' Plot causal distillation tree object
#'
#' @description Visualize the subgroups (i.e., the student tree) from a causal
#'   distillation tree object. Inner nodes are labeled with their splitting
#'   variable, edges with their split rule, and terminal nodes (subgroups) with
#'   their estimated subgroup average treatment effect (ATE).
#'
#' @param cdt A causal distillation tree object, typically the output of
#'   \code{\link{causalDT}}. It must contain the fitted student tree in
#'   \code{cdt$student_fit$fit} (convertible via \code{partykit::as.party()})
#'   and a data frame \code{cdt$estimate} with columns \code{leaf_id} and
#'   \code{estimate}.
#' @param show_digits Controls the length of the split (edge) labels: each
#'   label is truncated to \code{12 + show_digits} characters, which roughly
#'   limits the number of digits shown for numeric split points. Must be a
#'   single non-negative number. Default is 2. Subgroup ATE labels are always
#'   shown with 3 decimal places.
#'
#' @return A \code{ggplot} object (built with \pkg{ggparty}) displaying the
#'   student tree of the causal distillation tree.
#'
#' @examples
#' \donttest{
#' n <- 200
#' p <- 10
#' X <- matrix(rnorm(n * p), nrow = n, ncol = p)
#' Z <- rbinom(n, 1, 0.5)
#' Y <- 2 * Z * (X[, 1] > 0) + X[, 2] + rnorm(n, 0.1)
#'
#' cdt <- causalDT(X, Y, Z)
#' plot_cdt(cdt)
#' }
#'
#' @export
plot_cdt <- function(cdt, show_digits = 2) {
  # Validate eagerly: `show_digits` is only used inside a lazily evaluated
  # aes(), so a bad value would otherwise surface only when the plot renders.
  if (!is.numeric(show_digits) || length(show_digits) != 1L ||
      is.na(show_digits) || show_digits < 0) {
    stop("`show_digits` must be a single non-negative number.", call. = FALSE)
  }

  # to avoid "no visible binding" error from R CMD check
  breaks_label <- NULL
  splitvar <- NULL
  estimate <- NULL
  label <- NULL
  info <- NULL

  # Maximum number of characters kept in each edge (split) label.
  max_label_chars <- 12 + show_digits

  party_obj <- partykit::as.party(cdt$student_fit$fit)
  plt <- ggparty::ggparty(party_obj) +
    ggparty::geom_edge() +
    ggparty::geom_edge_label(
      ggplot2::aes(
        label = substr(breaks_label, start = 1, stop = max_label_chars)
      )
    ) +
    ggparty::geom_node_label(ggplot2::aes(label = splitvar), ids = "inner")

  # Build one label per row of the plot data (i.e. per node, in plot order).
  # Nodes without a matching `leaf_id` get an NA estimate; only the
  # terminal-node labels are drawn below.
  subgroup_ates <- data.frame(id = plt$data$id) |>
    dplyr::left_join(cdt$estimate, by = c("id" = "leaf_id")) |>
    dplyr::mutate(
      label = sprintf("Subgroup ATE\n= %.3f", estimate)
    ) |>
    dplyr::pull(label)
  plt$data$info <- subgroup_ates

  plt +
    ggparty::geom_node_label(ggplot2::aes(label = info), ids = "terminal")
}


#' Plot Jaccard subgroup similarity index (SSI) for causal distillation tree objects
#'
#' @description The Jaccard subgroup similarity index (SSI) is a measure of the
#'   similarity between two candidate partitions of subgroups. To select an
#'   appropriate teacher model in CDT, the Jaccard SSI can be used to select the
#'   teacher model that recovers the most stable subgroups.
#'
#' @param ... One or more (typically two or more, to compare teacher models)
#'   causal distillation tree objects, each typically the output of
#'   \code{\link{causalDT}}. Each object must contain
#'   \code{stability_diagnostics$jaccard_mean}. Arguments should be named (so
#'   that they are properly labeled in the resulting plot); unnamed arguments
#'   are labeled \code{"Model<i>"}, where \code{i} is their position.
#'
#' @return A \code{ggplot} object showing the mean Jaccard SSI against tree
#'   depth, with one line per model.
#'
#' @examples
#' \donttest{
#' n <- 50
#' p <- 2
#' X <- matrix(rnorm(n * p), nrow = n, ncol = p)
#' Z <- rbinom(n, 1, 0.5)
#' Y <- 2 * Z * (X[, 1] > 0) + X[, 2] + rnorm(n, 0.1)
#'
#' # number of bootstraps for stability diagnostics (setting to small value for faster example)
#' B <- 10
#'
#' # run CDT with default causal forest teacher model
#' cdt1 <- causalDT(X, Y, Z, B_stability = B)
#'
#' # run CDT with custom BCF teacher model
#' cdt2 <- causalDT(
#'   X, Y, Z,
#'   # set BCF training parameters to be small for faster example
#'   teacher_model = purrr::partial(bcf, nsim = 100, nburn = 10),
#'   teacher_predict = predict_bcf,
#'   # set number of cross-fitting replications to be small for faster example
#'   nreps_crossfit = 5,
#'   B_stability = B
#' )
#' plot_jaccard(`Causal Forest` = cdt1, `BCF` = cdt2)
#' }
#'
#' @export
plot_jaccard <- function(...) {
  dots_ls <- rlang::dots_list(...)
  # to avoid "no visible binding" error from R CMD check
  tree_depth <- NULL
  jaccard_ssi <- NULL
  teacher_model <- NULL

  n_models <- length(dots_ls)
  if (n_models == 0L) {
    stop(
      "`plot_jaccard()` requires at least one causal distillation tree object.",
      call. = FALSE
    )
  }

  # Fall back to positional labels ("Model1", "Model2", ...) for any argument
  # the caller left unnamed.
  default_names <- paste0("Model", seq_len(n_models))
  model_names <- names(dots_ls)
  if (is.null(model_names)) {
    model_names <- default_names
  } else {
    unnamed <- is.na(model_names) | model_names == ""
    model_names[unnamed] <- default_names[unnamed]
  }
  names(dots_ls) <- model_names

  # One row per (model, tree depth); depth is the position in `jaccard_mean`.
  ssi_df <- purrr::imap(
    dots_ls,
    function(cdt, model_name) {
      jaccard_mean <- cdt$stability_diagnostics$jaccard_mean
      if (is.null(jaccard_mean)) {
        stop(
          sprintf(
            "Model '%s' has no `stability_diagnostics$jaccard_mean` to plot.",
            model_name
          ),
          call. = FALSE
        )
      }
      tibble::tibble(
        tree_depth = seq_along(jaccard_mean),
        jaccard_ssi = jaccard_mean
      )
    }
  ) |>
    dplyr::bind_rows(.id = "teacher_model")

  ggplot2::ggplot(ssi_df) +
    ggplot2::aes(x = tree_depth, y = jaccard_ssi, color = teacher_model) +
    ggplot2::geom_line() +
    ggplot2::geom_point() +
    ggplot2::labs(
      x = "Tree Depth", y = "Jaccard SSI", color = "Teacher Model"
    ) +
    ggplot2::theme_classic()
}
