#' Plot predictions from several xhaz models (two-panel display)
#'
#' This method plots a \code{predxhaz_list} object, i.e. multiple prediction
#' objects (typically one per model and per subgroup), using two panels
#' (left/right) for \code{"survival"} and/or \code{"hazard"}. Colors
#' identify models; line types identify subgroups. The graphical parameters
#' and the device layout are restored when the function exits.
#'
#' @param x A \code{predxhaz_list}.
#' @param left,right Which metric to display in each panel: \code{"hazard"}
#'   or \code{"survival"} (defaults: hazard on the left, survival on the
#'   right).
#' @param col Colors for models. Auto (HCL palette \code{"Dark 3"}) if NULL.
#'   An unnamed vector is recycled over the models in order; a named vector
#'   is matched by model name, and models it does not name get the default
#'   palette color.
#' @param lty_by_group Line types for subgroups (recycled). Defaults to
#'   \code{c(1, 2, 3, 4)} if NULL.
#' @param legend_model_pos \code{"top"} (outside, shared) or \code{"none"}
#'   (no model legend; the two panels use the full device height).
#' @param legend_group_pos Inside-panel legend position (e.g.
#'   \code{"topright"}) or \code{"none"}. Only drawn when there is more than
#'   one subgroup.
#' @param ylim_left,ylim_right Numeric y-limits. If NULL, a \code{ylim}
#'   passed through \code{...} is used; otherwise the range of the finite
#'   values of the panel's metric over all series.
#' @param grid Logical; add \code{graphics::grid()} in panels.
#' @param step NULL = auto (step for hazard when baseline == "constant"),
#'   TRUE/FALSE to force.
#' @param ... Passed to the initial \code{graphics::plot()} in each panel
#'   (e.g., xlab, ylab, xlim, main). Defaults: \code{xlab = "Time (years)"},
#'   \code{ylab} depends on the panel's metric, and \code{xlim} spans the
#'   time range of all series. A \code{ylim} given here is used for both
#'   panels unless \code{ylim_left}/\code{ylim_right} is set. \code{y} and
#'   \code{type} are ignored.
#' @return Invisibly returns \code{x}.
#'
#' @method plot predxhaz_list
#'
#' @examples
#' \dontrun{
#' ## --- Single-model predictions (predxhaz) ---
#' data("simuData", package = "xhaz")
#' library(survival)
#' fit1 <- xhaz(Surv(time_year, status) ~ agec + race,
#'              data = simuData,
#'              ratetable = survexp.us,
#'              interval = c(0, NA, NA, NA, NA, NA, max(simuData$time_year)),
#'              rmap = list(age = "age", sex = "sex", year = "date"),
#'              baseline = "constant", pophaz = "classic")
#'
#' fit2 <- xhaz(Surv(time_year, status) ~ agec + race,
#'              data = simuData,
#'              ratetable = survexp.us,
#'              interval = c(0, NA, NA, max(simuData$time_year)),
#'              rmap = list(age = "age", sex = "sex", year = "date"),
#'              baseline = "bsplines", pophaz = "classic")
#'
#' tgrid <- seq(0, 4, 0.1)
#' d1 <- simuData[1,]
#' px1 <- predict(fit1, new.data = d1, times.pts = tgrid, baseline = FALSE)
#' px2 <- predict(fit2, new.data = d1, times.pts = tgrid, baseline = FALSE)
#'
#' # Build a predxhaz_list (recommended constructor for plot())
#' pl <- pred_list(constant = px1, bsplines = px2)
#'
#' # Two-panel plot (left/right)
#' plot(pl,
#'      left = "survival", right = "hazard",
#'      xlab = "Time (years)",
#'      ylim_left = c(0, 1))
#' }
#' @export

plot.predxhaz_list <- function(x,
                               left = c("hazard", "survival"),
                               right = c("survival", "hazard"),
                               col = NULL,
                               lty_by_group = NULL,
                               legend_model_pos = c("top", "none"),
                               legend_group_pos = "topright",
                               ylim_left = NULL, ylim_right = NULL,
                               grid = TRUE,
                               step = NULL,
                               ...) {

  stopifnot(inherits(x, "predxhaz_list"))

  left  <- match.arg(left)
  right <- match.arg(right)
  legend_model_pos <- match.arg(legend_model_pos)

  items        <- x$items
  model_names  <- x$model_names
  group_ids    <- x$group_ids
  group_labels <- x$group_labels
  nmod         <- length(model_names)

  # Arguments forwarded to graphics::plot(). Anything this method supplies
  # itself is removed (or captured first) so that do.call() never receives
  # the same formal argument twice.
  dots <- list(...)
  dots$what  <- NULL
  dots$left  <- NULL
  dots$right <- NULL
  dots$y     <- NULL
  dots$type  <- NULL
  shared_ylim <- dots[["ylim"]]
  dots$ylim   <- NULL
  user_xlim   <- dots[["xlim"]]
  dots$xlim   <- NULL

  .check_px <- function(px, metric) {
    if (is.null(px$time)) stop("Each predxhaz must contain $time.")
    if (is.null(px[[metric]])) stop("Each predxhaz must contain $", metric, ".")
    invisible(TRUE)
  }

  # Range of the finite values of v; c(0, 1) when there are none, so that
  # plot.window() always receives finite limits.
  finite_range <- function(v) {
    v <- v[is.finite(v)]
    if (length(v) == 0L) c(0, 1) else range(v)
  }

  # Flatten items into one record per (model, subgroup) series, in model
  # order then subgroup order. Series are found through names(items[[nm]]);
  # presumably those names are the group ids (they key lty_map below).
  series <- unlist(lapply(model_names, function(nm) {
    lapply(names(items[[nm]]), function(g) {
      list(model = nm, group = g, px = items[[nm]][[g]])
    })
  }), recursive = FALSE)
  # Fail before touching the device so no half-drawn layout is left behind.
  if (length(series) == 0L) stop("No data to plot.")
  for (s in series) {
    for (metric in unique(c(left, right))) .check_px(s$px, metric)
  }

  # Model colors: recycle an unnamed vector BEFORE naming it (names<-
  # rejects more names than elements); complete a named vector with the
  # default palette for any model it does not mention.
  default_col <- stats::setNames(
    grDevices::hcl.colors(nmod, palette = "Dark 3"), model_names
  )
  if (is.null(col)) {
    col <- default_col
  } else if (is.null(names(col))) {
    col <- stats::setNames(rep(col, length.out = nmod), model_names)
  } else {
    missing_models <- setdiff(model_names, names(col))
    col <- c(col, default_col[missing_models])
  }

  if (is.null(lty_by_group)) lty_by_group <- c(1, 2, 3, 4)
  lty_map <- stats::setNames(
    rep(lty_by_group, length.out = length(group_ids)), group_ids
  )

  metric_values <- function(metric) {
    unlist(lapply(series, function(s) s$px[[metric]]), use.names = FALSE)
  }
  # The x-range covers every series, not just the first one, so curves
  # with longer follow-up are not clipped.
  xlim <- if (is.null(user_xlim)) {
    finite_range(unlist(lapply(series, function(s) s$px$time),
                        use.names = FALSE))
  } else {
    user_xlim
  }
  if (is.null(ylim_left)) {
    ylim_left <- if (is.null(shared_ylim)) {
      finite_range(metric_values(left))
    } else {
      shared_ylim
    }
  }
  if (is.null(ylim_right)) {
    ylim_right <- if (is.null(shared_ylim)) {
      finite_range(metric_values(right))
    } else {
      shared_ylim
    }
  }

  old_par <- graphics::par(no.readonly = TRUE)
  on.exit({
    graphics::layout(1)
    graphics::par(old_par)
  }, add = TRUE)

  if (legend_model_pos == "top") {
    # Shared header row for the model legend, panels below. The header is
    # made taller when there are more than six models.
    nrow_legend <- ceiling(nmod / 6)
    graphics::layout(matrix(c(1, 1, 2, 3), nrow = 2, byrow = TRUE),
                     heights = c(0.9 + 0.35 * nrow_legend, 8))
    graphics::par(mar = c(0, 0, 0, 0))
    graphics::plot.new()
    # Avoid warning: do not set ncol when horiz=TRUE
    graphics::legend("center", bty = "n", horiz = TRUE, title = "Model",
                     legend = model_names, lty = 1, lwd = 2,
                     col = unname(col[model_names]),
                     cex = if (nmod > 6) 0.9 else 1)
  } else {
    # No model legend: the two panels take the full device height.
    graphics::layout(matrix(c(1, 2), nrow = 1))
  }

  # Draw one panel: empty frame, optional grid, every series, then the
  # subgroup legend.
  draw_panel <- function(metric, ylim) {
    graphics::par(mar = c(4, 4, 2, 1))

    plot_args <- dots
    if (is.null(plot_args[["xlab"]])) plot_args$xlab <- "Time (years)"
    if (is.null(plot_args[["ylab"]])) {
      plot_args$ylab <- if (metric == "hazard") "Excess hazard" else "Net survival"
    }

    do.call(graphics::plot, c(list(x = xlim, y = ylim, type = "n",
                                   xlim = xlim, ylim = ylim),
                              plot_args))

    if (isTRUE(grid)) graphics::grid()

    for (s in series) {
      px <- s$px
      # Explicit `step` wins; otherwise step only hazards whose "baseline"
      # attribute is "constant".
      step_here <- if (!is.null(step)) {
        isTRUE(step)
      } else {
        metric == "hazard" &&
          identical(attr(px, "baseline", TRUE), "constant")
      }
      # Subgroups absent from group_ids fall back to a solid line.
      lty_here <- if (s$group %in% names(lty_map)) lty_map[[s$group]] else 1

      graphics::lines(px$time, px[[metric]],
                      type = if (step_here) "s" else "l",
                      col = col[[s$model]], lwd = 2, lty = lty_here)
    }

    if (!identical(legend_group_pos, "none") && length(group_ids) > 1) {
      graphics::legend(legend_group_pos, bty = "n", title = "Subgroup",
                       legend = unname(group_labels),
                       lty = unname(lty_map[group_ids]),
                       lwd = 2, col = "black", cex = 0.9)
    }
  }

  draw_panel(left,  ylim_left)
  draw_panel(right, ylim_right)

  invisible(x)
}
