#' Calculate Weights
#'
#' Calculate weights using three methods: IPW, Calibration, and Calibration+IPW
#'
#' Calculates weights intended to reduce the sampling bias present in All of Us. Three versions
#' of weights are calculated from different reweighting strategies: IPW, Calibration, and 
#' Calibration+IPW.
#'
#' @param sample_a data.frame with representative sample
#' @param sample_b data.frame with All of Us sample
#' @param method string or string vector specifying weighting method(s) to use: "ipw", "cal", and/or "cal+ipw"
#' @param aux_variables character vector with names of calibration variables
#' @param study_variables character vector with names of study variables
#' @param weight character vector with name of the weight variable in sample_a
#' @param strata character vector with name of the strata variable in sample_a
#' @param psu character vector with name of the primary sampling units variable in sample_a
#'
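#' @return data.frame: sample_b with the auxiliary and study variables converted to factors and one
#'   added (or replaced) weight column per requested method ("ipw_weight", "cal_weight", and/or
#'   "calipw_weight")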
#' 
#' @examples
#' # Prepare the NHIS data
#' calVars <- c(
#'   "SEX_A_R", "AGEP_A_R", "HISPALLP_A_R", "ORIENT_A_R", "HICOV_A_R", "EDUCP_A_R", "REGION_R",
#'   "EMPLASTWK_A_R", "HOUTENURE_A_R", "MARITAL_A_R"
#' )
#' stuVars <- "DIBTYPE_A_R"
#' vars_dummies <- c("AGEP_A_R","HISPALLP_A_R","EDUCP_A_R","REGION_R")
#' nhis_keep_vars <- c("PPSU","PSTRAT","WTFA_A")
#' nhis_imputed <- impute_data(nhis_processed, c(calVars, stuVars), nhis_keep_vars)
#' nhis_dummied <- dummies(nhis_imputed, vars=paste0(vars_dummies, '_I'))
#' factor_vars <- setdiff(names(nhis_dummied), nhis_keep_vars)
#' nhis_dummied[factor_vars] <- lapply(nhis_dummied[factor_vars], as.factor)
#' 
#' # Prepare the synthetic All of Us data
#' aou_imputed <- impute_data(aou_synthetic, c(calVars, stuVars))
#' aou_dummied <- dummies(aou_imputed, vars=paste0(vars_dummies, '_I'))
#' aou_dummied[] <- lapply(aou_dummied, as.factor)
#' 
#' # Calculate IPW weights using NHIS data and applied to All of Us
#' weights_df <- calculate_weights(
#'   nhis_dummied, 
#'   aou_dummied, 
#'   'ipw',
#'   paste0(calVars, '_I'), 
#'   paste0(stuVars, '_I'), 
#'   weight='WTFA_A',
#'   strata='PSTRAT',
#'   psu='PPSU'
#' )
#' 
#' @import dplyr
#' @importFrom stringr str_detect
#' @importFrom stats reformulate setNames
#' @importFrom purrr map_dfr
#' @importFrom survey svytotal svydesign svymean
#' @importFrom glue glue
#' @importFrom nonprobsvy nonprob control_sel
#' @export
calculate_weights <- function(sample_a, sample_b, method, aux_variables, study_variables, weight, strata, psu) {

  # Reject any method name outside the supported set before doing any work.
  if (length(setdiff(method, c('ipw', 'cal', 'cal+ipw'))) > 0)
    stop('method argument can only contain "ipw", "cal", and/or "cal+ipw"')

  all_vars <- c(aux_variables, study_variables)

  # Prepare samples.
  # Sample_B is the sample that receives the weights; its auxiliary and study
  # variables are converted to factors.
  Sample_B <- sample_b %>%
    mutate(across(all_of(all_vars), as.factor))

  # Two versions of Sample A are prepared:
  # - Sample_A1 keeps the study variables and is used to compute the
  #   calibration totals (factors are needed so svytotal returns per-level
  #   totals).
  # - Sample_A2 drops the study variables and is the design passed to the
  #   selection (propensity) model.
  Sample_A1 <- sample_a %>%
    mutate(across(all_of(all_vars), as.factor))
  # NOTE(review): Sample_A2 is not converted to factors here, unlike Sample_A1
  # and Sample_B -- presumably callers pass factor columns already (as in the
  # roxygen example); confirm.
  Sample_A2 <- sample_a %>% select(-all_of(study_variables))

  # Survey design of Sample A including the study variables (Sample_A1)
  survey_design_A1 <- svydesign(
    ids = reformulate(psu),
    strata = reformulate(strata),
    weights = reformulate(weight),
    data = Sample_A1,
    nest = TRUE
  )

  # Survey design of Sample A without the study variables (Sample_A2)
  survey_design_A2 <- svydesign(
    ids = reformulate(psu),
    strata = reformulate(strata),
    weights = reformulate(weight),
    data = Sample_A2,
    nest = TRUE
  )

  # Weights computed by each requested method, keyed by the column-name prefix
  weights <- list()

  ### IPW method
  if ('ipw' %in% method) {
    est_ipw <- nonprob(
      selection = reformulate(aux_variables),
      target = reformulate(study_variables[1]), ## Choose any study variable, does not affect weight calculation
      svydesign = survey_design_A2,
      data = Sample_B,
      method_selection = "logit" ## Other methods include "probit", or "cloglog"
    )

    weights[['ipw']] <- est_ipw$ipw_weights
  }

  ### Calibration totals (shared by "cal" and "cal+ipw")
  # `||` is used because both operands are length-one logicals.
  if ('cal' %in% method || 'cal+ipw' %in% method) {
    # Loop through calibration variables and compute the weighted total of
    # every level of each variable. stats:: is spelled out because only
    # reformulate and setNames are imported from stats.
    # NOTE(review): the left_join below adds no columns (its right-hand side
    # only holds the key); it looks like a no-op when aux_variables has unique
    # names -- confirm before removing.
    cal_totals <- map_dfr(aux_variables, ~{
      estm <- as.data.frame(svytotal(stats::as.formula(paste0("~", .x)), survey_design_A1))
      data.frame(
        VAR = .x,
        VARNAME = rownames(estm),
        CALTOT = estm$total,
        stringsAsFactors = FALSE
      )
    }) %>%
      left_join(tibble(VAR = aux_variables), by = "VAR")

    # Calibration intercept: sum the level totals within each variable
    cal_totals_grouped <- cal_totals %>%
      group_by(VAR) %>%
      summarise(
        CALTOT = sum(CALTOT),
        .groups = "drop"
      )

    # Finalize cal_totals so it includes calibration totals only:
    # rows whose level name ends in "0" are dropped, and so are variables whose
    # name ends in "1" (per the original comment: drop the first category from
    # each auxiliary variable -- presumably the first dummy column; confirm).
    cal_totals <- cal_totals %>%
      filter(!str_detect(VARNAME, "0$")) %>%
      filter(!str_detect(VAR, "1$")) %>%
      select(-VARNAME)

    # The intercept total is the summed total of the first variable in
    # group_by order; the remaining entries are the per-variable totals.
    pop_totals_A <- c(`(Intercept)` = cal_totals_grouped$CALTOT[1], setNames(cal_totals$CALTOT, cal_totals$VAR))
    model_formula_A <- reformulate(cal_totals$VAR)

    # Append 1 to the end of variables in pop_totals_A. Since they are factors, these are
    # added when converting from model_formula_A into the model.matrix. Exclude intercept.
    idx <- !grepl('Intercept', names(pop_totals_A))
    names(pop_totals_A)[idx] <- paste0(names(pop_totals_A[idx]), '1')
  }

  if ('cal' %in% method) {
    ## Calibration - Sample A is not needed, but calibration totals should be defined.
    ## The totals are passed directly; wrapping them in dput() would print the
    ## vector to the console on every call without changing the value.
    est_cal <- nonprob(
      selection = model_formula_A,
      target = reformulate(study_variables[1]),
      data = Sample_B,
      pop_totals = pop_totals_A,
      method_selection = "logit"
    )

    # read the calculated weights
    weights[['cal']] <- est_cal$ipw_weights
  }

  if ('cal+ipw' %in% method) {
    est_calipw <- nonprob(
      selection = model_formula_A,
      target = reformulate(study_variables[1]),
      svydesign = survey_design_A2,
      data = Sample_B,
      method_selection = "logit",
      control_selection = control_sel(gee_h_fun = 1, est_method = "gee")
    )

    # read the calculated weights (stored under "calipw", so the output
    # column is named "calipw_weight")
    weights[['calipw']] <- est_calipw$ipw_weights
  }

  # Add one "<method>_weight" column to Sample_B per computed weight vector,
  # replacing any existing column of the same name.
  for (weight_name in names(weights)) {
    Sample_B[[paste0(weight_name, '_weight')]] <- weights[[weight_name]]
  }

  Sample_B
}
