##############################################
#### Functions for class("mb.predict") ####
##############################################



#' Plots predicted responses from a time-course MBNMA model
#'
#' @param x An object of class `"mb.predict"` generated by
#'   `predict("mbnma")`
#' @param disp.obs A boolean object to indicate whether to show shaded sections
#'   of the plot for where there is observed data (`TRUE`) or not (`FALSE`)
#' @param overlay.ref A boolean object indicating whether to overlay a line
#'   showing the median network reference treatment response over time on the
#'   plot (`TRUE`) or not (`FALSE`). The network reference treatment (treatment
#'   1) must be included in `predict` for this to display the network reference
#'   treatment properly.
#' @param overlay.nma Can be used to overlay the predicted results from a standard NMA model that
#'   "lumps" time-points together within the range specified in `overlay.nma`. Must be a numeric vector of length 2, or
#'   left as `NULL` (the default) to indicate no NMA should be performed. `overlay.nma` can only be specified if
#'   `overlay.ref==TRUE`. See Details for further information.
#' @param method Can take `"common"` or `"random"` to indicate the type of NMA model used to synthesise data points
#'   given in `overlay.nma`. The default is `"random"` since this assumes different
#'   time-points in `overlay.nma` have been lumped together to estimate the NMA.
#' @param col A character indicating the colour to use for shading if `disp.obs`
#'   is set to `TRUE`. Can be either `"blue"`, `"green"`, or `"red"`
#' @param max.col.scale Rarely requires adjustment. The maximum count of
#'   observations (therefore the darkest shaded color) only used if `disp.obs` is
#'   used. This allows consistency of shading between multiple plotted graphs.
#'   It should always be at least as high as the maximum count of observations
#'   plotted
#' @param ... Arguments for `ggplot()`
#' @inheritParams plot.mb.rank
#'
#' @details For the S3 method `plot()`, if `disp.obs` is set to `TRUE` it is
#'   advisable to ensure predictions in `predict` are estimated using an even
#'   sequence of time points to avoid misrepresentation of shaded densities.
#'   Shaded counts of observations will be relative to the treatment plotted in
#'   each panel rather than to the network reference treatment if `disp.obs` is
#'   set to `TRUE`.
#'
#'   `overlay.nma` can be useful to assess if the MBNMA predictions are in agreement with predictions from an NMA model
#'   for a specific range of time-points. This can be a general indicator of the fit of the time-course model. However, it
#'   is important to note that the wider the range specified in `overlay.nma`, the more likely it is that different time-points
#'   are included, and therefore that there is greater heterogeneity/inconsistency in the NMA model. If `overlay.nma` includes
#'   several follow-up times for any study then only a single time-point will be taken (the one closest to `mean(overlay.nma)`).
#'   The NMA predictions are plotted over the range specified in `overlay.nma` as a horizontal line, with the 95%CrI shown by a grey
#'   rectangle. The NMA predictions represent those for *any time-points within this range* since they lump together data at
#'   all these time-points. Predictions for treatments that are disconnected from
#'   the network reference treatment at data points specified within `overlay.nma` cannot be estimated so are not included.
#'
#'   It is important to note that the NMA model is not necessarily the "correct" model, since it "lumps" different time-points
#'   together and ignores potential differences in treatment effects that may arise from this. The wider the range specified in
#'   `overlay.nma`, the greater the effect of "lumping" and the stronger the assumption of similarity between studies.
#'
#' @examples
#' \donttest{
#' # Create an mb.network object from a dataset
#' copdnet <- mb.network(copd)
#'
#' # Run an MBNMA model with a log-linear time-course
#' loglin <- mb.run(copdnet,
#'   fun=tloglin(pool.rate="rel", method.rate="common"),
#'   rho="dunif(0,1)", covar="varadj")
#'
#' # Predict responses using the original dataset to estimate the network reference
#' #treatment response
#' df.ref <- copd[copd$treatment=="Placebo",]
#' predict <- predict(loglin, times=c(0:20), E0=0, ref.resp=df.ref)
#'
#' # Plot the predicted responses with observations displayed on plot as green shading
#' plot(predict, disp.obs=TRUE, overlay.ref=FALSE, col="green")
#'
#' # Plot the predicted responses with the median network reference treatment response overlayed
#' #on the plot
#' plot(predict, disp.obs=FALSE, overlay.ref=TRUE)
#'
#' # Plot predictions from an NMA calculated between different time-points
#' plot(predict, overlay.nma=c(5,10), n.iter=20000)
#' plot(predict, overlay.nma=c(15,20), n.iter=20000)
#' # Time-course fit may be less good at 15-20 weeks follow-up
#' }
#'
#' @export
plot.mb.predict <- function(x, disp.obs=FALSE, overlay.ref=TRUE,
                            overlay.nma=NULL, method="random",
                            col="blue", max.col.scale=NULL, treat.labs=NULL, ...) {

  # Run checks (NA is rejected here because the flags are used in `if` below)
  argcheck <- checkmate::makeAssertCollection()
  checkmate::assertClass(x, "mb.predict", add=argcheck)
  checkmate::assertLogical(disp.obs, len=1, any.missing=FALSE, add=argcheck)
  checkmate::assertLogical(overlay.ref, len=1, any.missing=FALSE, add=argcheck)
  checkmate::assertChoice(method, choices = c("common", "random"), add=argcheck)
  checkmate::reportAssertions(argcheck)

  # Named list with one summary data frame per predicted treatment
  pred <- x[["summary"]]

  # Stack the per-treatment summaries into a single long data frame, tagging
  # each row with the name of the treatment it belongs to
  parts <- lapply(seq_along(pred), FUN=function(i) {
    cut <- pred[[i]]
    cut[["treat"]] <- rep(names(pred)[i], nrow(cut))
    cut
  })
  data <- do.call(rbind, parts)

  # Keep only relevant columns
  keep <- c("time", "2.5%", "50%", "97.5%", "treat")
  data <- data[, names(data) %in% keep, drop=FALSE]

  # Remember the original treatment names: the reference rows must be
  # identified by name even after `treat` has been relabelled below
  raw.treat <- data$treat

  # Add treatment labels
  # NOTE(review): factor() sorts its levels alphabetically, so `treat.labs` is
  # matched against the sorted treatment names - confirm this is the intended
  # ordering of `treat.labs`
  if (!is.null(treat.labs)) {
    data$treat <- factor(data$treat, labels=treat.labs)
  }

  # Required for overlaying ref treatment effect
  if (isTRUE(overlay.ref)) {
    ref.treat <- x$network$treatments[1]

    if (!(ref.treat %in% names(pred))) {
      stop(paste0("Reference treatment (", ref.treat, ") must be included in `x` in order for it to be plotted"))
    }

    # The reference median is repeated once per treatment block; this relies on
    # every treatment in `pred` having been predicted at the same time points
    data[["ref.median"]] <- rep(pred[[ref.treat]][["50%"]], length(pred))

    # The reference is shown as an overlay, so it does not get its own panel
    data <- data[raw.treat!=ref.treat,]

    # Also drop it from `x`, which is what disp.obs() and overlay.nma() receive
    x[["summary"]][[ref.treat]] <- NULL
  }


  # Plot ggplot axes
  g <- ggplot2::ggplot(data, ggplot2::aes(x=time, y=`50%`, ymin=`2.5%`, ymax=`97.5%`), ...)

  # Add shaded regions for observations in original dataset.
  # `disp.obs` is both the logical argument and a helper function defined
  # elsewhere; in call position R skips the non-function local value
  if (isTRUE(disp.obs)) {
    g <- disp.obs(g, predict=x, col=col, max.col.scale=max.col.scale)
  }

  # Overlay reference treatment effect
  if (isTRUE(overlay.ref)) {
    g <- g + ggplot2::geom_line(ggplot2::aes(y=ref.median, colour="Predicted reference"), size=1)
    message(paste0("Reference treatment in plots is ", ref.treat))
  }
  colorvals <- c("Predicted reference"="red")

  # Add overlayed lines and legends
  g <- g + ggplot2::geom_line(ggplot2::aes(linetype="Predicted MBNMA")) +
    ggplot2::geom_line(ggplot2::aes(y=`2.5%`, linetype="MBNMA 95% CrI")) +
    ggplot2::geom_line(ggplot2::aes(y=`97.5%`, linetype="MBNMA 95% CrI"))


  if (!is.null(overlay.nma)) {
    # CHECKS
    checkmate::assertNumeric(overlay.nma, lower=0.0001, upper = max(x$times), len=2, sorted = TRUE)

    if (!isTRUE(overlay.ref)) {
      stop("'overlay.ref' must be TRUE if overlay.nma is used, to ensure prediction of reference treatment response is correct")
    }
    if ("classes" %in% names(x$network)) {
      # Compare the names of the predictions themselves with the class names.
      # (`pred` is already the summary list, so `pred$summary` would be NULL
      # and the guard could never trigger.)
      if (sum(names(pred) %in% x$network$classes)>2) {
        stop("'overlay.nma' cannot be used with predictions at class level")
      }
    }

    # Run split NMA (resolves to the helper function overlay.nma(), not the
    # numeric argument of the same name)
    nma <- overlay.nma(x, incl.range=overlay.nma, method=method, link=x$link, ...)

    predtrt <- nma$pred.df

    # Write caption
    capt <- paste0(" effects NMA model\nResDev = ", nma$totresdev,
                   "; Ndat = ", nma$ndat,
                   "; DIC = ", nma$dic)
    if (method=="common") {
      capt <- paste0("Common", capt)
    } else if (method=="random") {
      capt <- paste0("Random", capt, "\nBetween-study SD = ", nma$sd)
    }

    # Grey rectangle = NMA 95%CrI, horizontal segment = NMA median, both drawn
    # across the time range given in `overlay.nma`
    g <- g + ggplot2::geom_rect(ggplot2::aes(ymin=`2.5%`, ymax=`97.5%`, xmin=overlay.nma[1], xmax=overlay.nma[2],
                                             fill="NMA (95%CrI)"),
                                alpha=0.8, data=predtrt) +
      ggplot2::geom_segment(ggplot2::aes(y=`50%`, yend=`50%`, x=overlay.nma[1], xend=overlay.nma[2], color="Predicted NMA"),
                            data=predtrt, size=1) +
      ggplot2::labs(caption=capt) +
      ggplot2::scale_fill_manual(name="", values=c("NMA (95%CrI)"="grey"))

    colorvals <- c("Predicted reference"="red", "Predicted NMA"="gray0")

  }

  # One panel per treatment
  g <- g + ggplot2::facet_wrap(~factor(treat)) +
    ggplot2::labs(y="Predicted response", x="Time")

  linetypevals <- c("Predicted MBNMA"="solid",
                    "MBNMA 95% CrI"="dashed")
  g <- g + ggplot2::scale_linetype_manual(name="",
                                          values=linetypevals)

  g <- g + ggplot2::scale_color_manual(name="",
                                       values=colorvals) +
    theme_mbnma()

  return(g)
}





#' Print summary information from an mb.predict object
#'
#' @param x An object of `class("mb.predict")` generated by `predict.mbnma()`
#' @param ... further arguments passed to or from other methods
#'
#' @export
print.mb.predict <- function(x, ...) {

  # Matrix of mean predictions (one row per time, one column per treatment)
  pred.table <- summary(x)

  # Per-treatment summaries; only their names are needed here
  pred.names <- names(x[["summary"]])

  # Tell the user if nothing was predicted for the network reference. When the
  # network has classes, the note is only shown if the first class is missing
  # from the predictions as well.
  if (!(x$network$treatments[1] %in% pred.names)) {
    show.note <- TRUE
    if ("classes" %in% names(x$network)) {
      show.note <- !(x$network$classes[1] %in% pred.names)
    }
    if (show.note) {
      cat("Responses have not been predicted for the network reference treatment\n")
    }
  }

  # Header line describing what the table contains
  cat(paste0("Predicted responses at ", nrow(pred.table),
             " different follow-up times ", "for treatments: ",
             paste(pred.names, collapse=", "), "\n\n"))

  # `max` limits the printed entries to ten rows' worth of the table
  n.signif <- max(3, getOption("digits")-3)
  print(pred.table, digits = n.signif, max=ncol(pred.table)*10, ...)
}





#' Summarises an mb.predict object
#'
#' Returns a summary table of the mean of MCMC iterations at each time point
#' for each treatment
#'
#' @param object An object of class `"mb.predict"`
#' @param ... further arguments passed to or from other methods
#'
#' @return A matrix containing times at which responses have been predicted (`time`)
#' and an additional column for each treatment for which responses have been predicted.
#' Each row represents mean MCMC predicted responses for each treatment at a particular
#' time.
#'
#' @examples
#' \donttest{
#' # Define network
#' network <- mb.network(obesityBW_CFB, reference="plac")
#'
#' # Run an MBNMA with a quadratic time-course function
#' quad <- mb.run(network,
#'   fun=tpoly(degree=2, pool.1="rel", method.1="common",
#'     pool.2="rel", method.2="common"),
#'   intercept=TRUE)
#'
#' # Predict responses
#' pred <- predict(quad, times=c(0:50), treats=c(1:5),
#'   ref.resp = network$data.ab[network$data.ab$treatment==1,],
#'   E0=10)
#'
#' # Generate summary of predictions
#' summary(pred)
#' }
#' @export
summary.mb.predict <- function(object, ...) {
  checkmate::assertClass(object, "mb.predict")

  # One summary data frame per predicted treatment
  pred.list <- object[["summary"]]

  # Prediction times are read from the first treatment's summary
  times <- pred.list[[1]]$time

  # Start from the time column and append each treatment's posterior means as
  # a further column, in the order the treatments are stored
  out <- Reduce(
    function(acc, trt.summary) cbind(acc, trt.summary$mean),
    pred.list,
    times
  )
  colnames(out) <- c("time", names(pred.list))

  return(out)
}





#' Rank predictions at a specific time point
#'
#' @param x an object of `class("mb.predict")` that contains predictions from an MBNMA model
#' @param time a number indicating the time point at which predictions should be ranked. It must
#' be one of the time points for which predictions in `x` are available.
#' @param treats A character vector of treatment/class names for which responses have been predicted
#'   in `x`. As default, rankings will be calculated for all treatments/classes in `x`.
#' @inheritParams rank.mbnma
#' @param ... Arguments to be passed to methods
#'
#' @return Returns an object of `class("mb.rank")` containing ranked predictions
#'
#' @examples
#' \donttest{
#' # Create an mb.network object from a dataset
#' network <- mb.network(osteopain)
#'
#' # Run an MBNMA model with an Emax time-course
#' emax <- mb.run(network,
#'   fun=temax(pool.emax="rel", method.emax="common",
#'     pool.et50="abs", method.et50="common"))
#'
#' # Predict responses using a stochastic baseline (E0) and a distribution for the
#' #network reference treatment
#' preds <- predict(emax, E0=7,
#'   ref.resp=list(emax=~rnorm(n, -0.5, 0.05)))
#'
#' # Rank predictions at latest predicted time-point
#' rank(preds, lower_better=TRUE)
#'
#'
#' #### Rank predictions at 5 weeks follow-up ####
#'
#' # First ensure responses are predicted at 5 weeks
#' preds <- predict(emax, E0=7,
#'   ref.resp=list(emax=~rnorm(n, -0.5, 0.05)),
#'   times=c(0,5,10))
#'
#' # Rank predictions at 5 weeks follow-up
#' ranks <- rank(preds, lower_better=TRUE, time=5)
#'
#' # Plot ranks
#' plot(ranks)
#'
#' }
#' @export
rank.mb.predict <- function(x, time=max(x$summary[[1]]$time), lower_better=FALSE,
                            treats=names(x$summary), ...) {

  # Checks
  argcheck <- checkmate::makeAssertCollection()
  checkmate::assertClass(x, "mb.predict", add=argcheck)
  checkmate::assertNumeric(time, len=1, lower=0, add=argcheck)
  checkmate::assertLogical(lower_better, null.ok=FALSE, len=1, any.missing=FALSE, add=argcheck)
  checkmate::assertCharacter(treats, null.ok=FALSE, min.len=1, add=argcheck)
  checkmate::reportAssertions(argcheck)

  if (!all(treats %in% names(x$summary))) {
    stop("'treats' includes treatments/classes not included in 'x'")
  }
  if (!time %in% x$summary[[1]]$time) {
    stop("'time' is not a time point given in 'x'")
  }

  # Subset x for treats.
  # match() keeps the predictions in the same order as `treats`, so that the
  # column labels assigned below always refer to the right treatment (selecting
  # with `%in%` would keep the order of `x` instead and mislabel the columns
  # whenever `treats` is given in a different order).
  treats <- unique(treats)
  trt.index <- match(treats, names(x$summary))
  x$summary <- x$summary[trt.index]
  x$pred.mat <- x$pred.mat[trt.index]


  #### Compute rankings ####

  # Get time point of interest from x: one column of predictions per treatment
  time.index <- match(time, x$summary[[1]]$time)
  pred.at.time <- lapply(x$pred.mat, FUN=function(k) {k[,time.index]})
  pred.at.time <- do.call(cbind, pred.at.time)

  # Compute rankings for each iteration (row). order(order()) converts the
  # values into ranks; rank 1 is the largest value unless lower_better=TRUE.
  ranks <- apply(pred.at.time, MARGIN=1, FUN=function(k) {
    order(order(k, decreasing = !lower_better), decreasing=FALSE)
  })

  # apply() returns treatments x iterations, except for a single treatment
  # where it collapses to a plain vector - restore the matrix shape first
  if (is.null(dim(ranks))) {
    ranks <- matrix(ranks, nrow=1)
  }
  rank.mat <- t(ranks)
  colnames(rank.mat) <- treats

  # Store rankings
  rank.result <- list(temp=
                        list("summary"=sumrank(rank.mat),
                             "prob.matrix"=calcprob(rank.mat, treats=treats),
                             "rank.matrix"=rank.mat)
  )
  names(rank.result) <- paste0("Predictions at time = ", time)

  class(rank.result) <- "mb.rank"
  return(rank.result)
}
