#' Bootstrap variance calculation
#'
#' Helper for `fui`. Bootstrapped variance calculation.
#'
#' @param fmm Object of class "fastFMM".
#' @param mum Massively univariate model output of class "massmm"
#' @param nknots_min Integer passed from `fui`.
#' @param nknots_min_cov Integer passed from `fui`.
#' @param nknots_fpca Integer passed from `fui`.
#' @param betaHat Numeric matrix of smoothed coefficients
#' @param data Data frame of values to fit
#' @param L integer, number of points on functional domain
#' @param n_boots Integer, number of bootstrap replications.
#' @param boot_type Character, bootstrapping protocol.
#' @param seed Integer, random seed for reproducibility.
#' @param parallel Logical, whether to use parallel processing
#' @param n_cores Integer, number of cores for parallelization.
#' @param smooth_method Character, passed from `fui`
#' @param splines Character, passed from `fui`
#' @param silent Logical, suppresses messages when `TRUE`. Passed from `fui`.
#'
#' @return List of final outputs of `fui`
#'
#' @import Matrix
#' @importFrom stats smooth.spline quantile
#' @importFrom Rfast spdinv rowMaxs
#' @importFrom parallel mclapply
#' @importFrom methods new
#' @importFrom mvtnorm rmvnorm
#' @importFrom stringr str_remove
#' @keywords internal
var_bootstrap <- function(
  fmm,
  mum,
  nknots_min,
  nknots_min_cov,
  nknots_fpca,
  betaHat,
  data,
  L,
  n_boots,
  boot_type,
  seed,
  parallel,
  n_cores,
  smooth_method,
  splines,
  silent
) {
  # Overview:
  #   1. resolve the cluster ID column and the bootstrap type,
  #   2. build `n_boots` bootstrap replicates of the smoothed coefficients
  #      (either by refitting `fui` on resampled clusters, or by bootstrapping
  #      each pointwise `lmer` fit and smoothing afterwards),
  #   3. take the covariance of the replicates across the functional domain,
  #   4. simulate from an FPCA fit of the replicates to obtain the joint-CI
  #      multipliers `qn`.

  if (!silent) message("Step 3: Inference (Bootstrap)")

  # 0 Warnings #################################################################

  if (fmm$concurrent) {
    if (!silent) {
      warning(
        "Bootstrap CI coverage has not been tested with concurrent models.", "\n",
        "Proceed with caution for variance/CI estimates."
      )
    }

    # Concurrent models always use the cluster bootstrap. `boot_type` may still
    # be NULL at this point (its default is resolved further down), so guard
    # the comparison; only warn when the user asked for something else.
    if (is.null(boot_type) || boot_type != "cluster") {
      if (!is.null(boot_type) && !silent) {
        warning(
          "Estimation is only compatible with `boot_type == 'cluster'`.", "\n",
          "Calculation will proceed with `cluster`."
        )
      }

      boot_type <- "cluster"
    }
  }

  # 1 Generate resamplings #####################################################

  group <- mum$group
  subj_id <- mum$subj_id
  argvals <- fmm$argvals

  # A ":" in `group` indicates a hierarchical structure, in which case the
  # cluster ID column has to be pinned down explicitly.
  if (grepl(":", group, fixed = TRUE)) {
    group <- if (is.null(subj_id)) {
      # assumes the ID name is to the right of the ":"
      str_remove(group, ".*:")
    } else {
      subj_id # use user specified if it exists
    }
  }

  group_vals <- data[[group]]
  if (is.null(group_vals)) {
    stop("Grouping variable `", group, "` not found in `data`.")
  }

  ids <- unique(group_vals)
  n_ids <- length(ids)
  B <- n_boots
  n_coef <- nrow(betaHat)

  # One row per replicate, one column per resampled cluster position. Built
  # explicitly as B x n_ids so the orientation is also right for n_ids == 1.
  id_perms <- matrix(
    as.integer(
      unlist(
        replicate(
          B,
          sample.int(n_ids, n_ids, replace = TRUE),
          simplify = FALSE
        ),
        use.names = FALSE
      )
    ),
    nrow = B,
    ncol = n_ids,
    byrow = TRUE
  )

  # Smoothed bootstrap coefficients: coefficient x domain point x replicate
  betaHat_boot <- array(NA_real_, dim = c(n_coef, ncol(betaHat), B))

  family <- fmm$family
  if (is.null(boot_type)) {
    # default bootstrap type if not specified
    boot_type <- if (family != "gaussian") {
      "cluster"
    } else if (n_ids <= 10) {
      "reb"
    } else {
      "wild"
    }
  }

  if (family != "gaussian" && boot_type %in% c("wild", "reb")) {
    stop(
      'Non-gaussian outcomes only supported for some bootstrap procedures.',
      '\n',
      'Set argument `boot_type` to one of the following:', '\n',
      '"parametric", "semiparametric", "cluster", "case", "residual"'
    )
  }

  if (!silent) {
    message(
      "Bootstrapping procedure: ",
      as.character(boot_type)
    )
  }

  # 2 Bootstrap resampling #####################################################

  if (!silent)
    message("Step 3.1: Bootstrap resampling")

  if (boot_type == "cluster") {

    # 2A Cluster bootstrap =====================================================

    # Column that receives the pseudo-IDs. Without a user-supplied `subj_id`,
    # fall back to the column that was resampled so that repeated draws of the
    # same cluster are kept apart in the refit.
    id_col <- if (is.null(subj_id)) group else subj_id

    # Rows belonging to each cluster do not change between replicates
    rows_by_id <- lapply(
      seq_along(ids),
      function(k) which(group_vals == ids[k])
    )

    pb <- progress_bar$new(total = B)
    for (b in seq_len(B)) {

      pb$tick()
      # rows of the clusters drawn for this replicate, in draw order
      picked <- rows_by_id[id_perms[b, ]]
      dat_idx <- unlist(picked, use.names = FALSE)

      df2 <- data[dat_idx, ] # copy dataset with subset of rows we want
      # the draw position becomes the pseudo-ID, replacing the old IDs
      df2[[id_col]] <- rep(seq_along(picked), lengths(picked))

      # Fastest fit is analytic = FALSE, var = FALSE
      fit_boot <- fui(
        formula = fmm$formula,
        data = df2,
        family = fmm$family,
        argvals = argvals,
        var = FALSE,
        analytic = FALSE,
        parallel = parallel,
        silent = TRUE,
        nknots_min = nknots_min,
        nknots_min_cov = nknots_min_cov,
        smooth_method = smooth_method,
        splines = splines,
        residuals = fmm$residuals,
        subj_id = mum$subj_id,
        n_cores = n_cores,
        concurrent = fmm$concurrent
      )
      # Save fixed coefficients
      betaHat_boot[, , b] <- fit_boot$betaHat
    }
    # objects are absent when B == 0, hence the suppressed warning
    suppressWarnings(rm(fit_boot, df2, dat_idx, picked))

    if (!silent)
      message('Step 3.2: (Smoothing skipped for boot_type = "cluster")')

  } else {

    # 2B lmeresampler bootstrap ================================================

    # Use original amount. Do not constrain by number of unique resampled
    # types here because we cannot construct rows to resample.

    # Unsmoothed pointwise bootstrap estimates, same shape as `betaHat_boot`
    betaTilde_boot <- betaHat_boot
    coef_cols <- seq_len(n_coef)

    model_formula <- as.character(fmm$formula)

    pb <- progress_bar$new(total = L)
    for (l in seq_len(L)) {
      pb$tick()
      # outcome at the l-th point of the functional domain
      data$Yl <- unclass(data[, fmm$out_index][, fmm$argvals[l]])
      fit_uni <- suppressMessages(
        lmer(
          formula = stats::as.formula(paste0("Yl ~ ", model_formula[3])),
          data = data,
          control = lmerControl(
            optimizer = "bobyqa", optCtrl = list(maxfun = 5000))
        )
      )

      # set seed to make sure bootstrap replicate (draws) are correlated
      # across functional domains
      set.seed(seed)

      if (boot_type == "residual") {
        # for residual bootstrap to avoid singularity problems
        boot_sample <- lmeresampler::bootstrap(
          model = fit_uni,
          B = B,
          type = boot_type,
          rbootnoise = 0.0001
        )$replicates
        betaTilde_boot[, l, ] <- t(as.matrix(boot_sample[, coef_cols]))

      } else if (boot_type %in% c("wild", "reb", "case")) {
        # for case
        flist <- lme4::getME(fit_uni, "flist")
        re_names <- names(flist)
        clusters_vec <- c(rev(re_names), ".id")

        # for case bootstrap, we only resample at first (subject level)
        # because doesn't make sense to resample within-cluster for
        # longitudinal data
        resample_vec <- c(TRUE, rep(FALSE, length(clusters_vec) - 1))
        boot_sample <- lmeresampler::bootstrap(
          model = fit_uni,
          B = B,
          type = boot_type,
          resample = resample_vec, # only matters for type = "case"
          hccme = "hc2", # wild bootstrap
          aux.dist = "mammen", # wild bootstrap
          reb_type = 0 # for reb bootstrap only
        )$replicates
        betaTilde_boot[, l, ] <- t(as.matrix(boot_sample[, coef_cols]))
      } else {
        # remaining types are handed to lme4::bootMer
        use.u <- boot_type == "semiparametric"
        betaTilde_boot[, l, ] <- t(
          lme4::bootMer(
            x = fit_uni, FUN = function(.) {fixef(.)},
            nsim = B,
            seed = seed,
            type = boot_type,
            use.u = use.u
          )$t
        )
      }
    }

    suppressWarnings(rm(boot_sample, fit_uni))

    # 2B.1 Smooth bootstrap estimates ------------------------------------------

    if (!silent)
      message("Step 3.2: Smooth Bootstrap estimates")

    # smooth across functional domain
    nknots <- min(round(L / 2), nknots_min)
    for (b in seq_len(B)) {
      # re-wrap the slice so it stays a matrix with a single coefficient row
      tilde_b <- matrix(
        betaTilde_boot[, , b],
        nrow = n_coef,
        ncol = ncol(betaHat)
      )
      betaHat_boot[, , b] <- t(
        apply(
          tilde_b,
          1,
          function(x)
            gam(
              x ~ s(argvals, bs = splines, k = (nknots + 1)),
              method = smooth_method
            )$fitted.values
        )
      )
    }

    rm(betaTilde_boot)
  }

  # Obtain bootstrap variance: one L x L covariance per coefficient
  betaHat_var <- array(NA_real_, dim = c(L, L, n_coef))
  ## account for within-subject correlation
  # NOTE(review): the 1.2 inflation factor is carried over unchanged; its
  # derivation is not visible in this file.
  for (r in seq_len(n_coef)) {
    betaHat_var[, , r] <- 1.2 * stats::var(t(betaHat_boot[r, , ]))
  }

  # 3 Joint CIs ################################################################

  # Obtain qn to construct joint CI using the fast approach
  if (!silent)
    message("Step 3.3: Joint confidence interval construction")

  qn <- numeric(n_coef)
  ## sample size in simulation-based approach
  N <- 10000
  # set seed to make sure bootstrap replicate (draws) are correlated across
  # functional domains
  set.seed(seed)

  for (i in seq_along(qn)) {
    # replicates x domain points for coefficient i
    est_bs <- t(betaHat_boot[i, , ])
    # suppress sqrt(Eigen$values) NaNs
    fit_fpca <- suppressWarnings(
      refund::fpca.face(est_bs, knots = nknots_fpca)
    )
    ## extract estimated eigenfunctions/eigenvalues
    phi <- fit_fpca$efunctions
    lambda <- fit_fpca$evalues
    K <- length(lambda)

    ## simulate random coefficients
    # generate independent standard normals
    theta <- matrix(stats::rnorm(N * K), nrow = N, ncol = K)
    # scale to have appropriate variance; `nrow = K` keeps diag() from
    # building an identity matrix when there is a single eigenvalue
    theta <- theta %*% diag(sqrt(lambda), nrow = K)
    # simulate new functions
    X_new <- tcrossprod(theta, phi)

    # add back in the mean function
    x_sample <- sweep(X_new, 2, fit_fpca$mu, "+")
    # standard deviation: apply(x_sample, 2, sd)
    Sigma_sd <- Rfast::colVars(x_sample, std = TRUE, na.rm = FALSE)
    x_mean <- colMeans(est_bs)

    # largest absolute standardised deviation of each simulated curve
    z <- sweep(sweep(x_sample, 2, x_mean, "-"), 2, Sigma_sd, "/")
    un <- apply(abs(z), 1, max)
    qn[i] <- stats::quantile(un, 0.95)
  }

  list(
    betaHat = betaHat,
    betaHat_var = betaHat_var,
    qn = qn,
    aic = mum$AIC_mat,
    betaTilde = mum$betaTilde,
    var_random = mum$var_random,
    designmat = mum$designmat,
    residuals = mum$residuals,
    H = NULL,
    R = NULL,
    G = NULL,
    GHat = NULL,
    Z = mum$ztlist,
    argvals = fmm$argvals,
    randeffs = mum$randeffs,
    se_mat = mum$se_mat
  )
}
