args <- commandArgs(trailingOnly = TRUE)

# Expected invocation: Rscript <this script> <num_simulations> <slurm_id>
# (slurm_id is presumably the SLURM array task index -- confirm with the job
# script). Fail fast with a clear message: otherwise missing or garbled
# arguments only surface later as an obscure "missing value where TRUE/FALSE
# needed" error.
if (length(args) < 2) {
  stop("expected two arguments: <num_simulations> <slurm_id>", call. = FALSE)
}

num_simulations <- as.numeric(args[1])
slurm_id <- as.numeric(args[2])

if (is.na(num_simulations) || num_simulations < 1 ||
    num_simulations != round(num_simulations)) {
  stop("num_simulations must be a positive integer", call. = FALSE)
}
if (is.na(slurm_id)) {
  stop("slurm_id must be numeric", call. = FALSE)
}

# on local computer
#source("R/bpcp2sample.R")
#source("R/delta2samp.R")
#source("R/kmciFunctions.R")

# on HPC
source("bpcp2sample.R")
source("delta2samp.R")
source("kmciFunctions.R")

library(tidyverse)
library(survival)
library(survminer)
library(scales)



# Survival curves of both arms on a fine time grid ----
# Each arm is a mixture of delayed exponentials:
#   S(t) = sum_j prop_j * (1 - pexp(t - delay_j, rate_j))
# pexp() is 0 for negative arguments, so component j contributes its full
# proportion prop_j until its delay has elapsed.
times <- seq(from = 0, to = 10, by = 0.001)

# Arm "Lev" (used as the control arm by default in iteration())
prop_Lev <- c(.5, .2, .3)
lambda_Lev <- c(0.5, 5, 0.0001)
delay_Lev <- c(1.3, 5.8, 10)
S_Lev <- vapply(times, function(at_time) {
  sum((1 - pexp(at_time - delay_Lev, rate = lambda_Lev)) * prop_Lev)
}, numeric(1))

# Arm "Lev5FU" (used as the treatment arm by default in iteration())
prop_Lev5FU <- c(.26, 1 - .26)
lambda_Lev5FU <- c(4, .0001)
delay_Lev5FU <- c(1.3, 0)
S_Lev5FU <- vapply(times, function(at_time) {
  sum((1 - pexp(at_time - delay_Lev5FU, rate = lambda_Lev5FU)) * prop_Lev5FU)
}, numeric(1))

# find point of intersection


## Difference in survival probability between the two arms at time t:
## S_Lev(t) - S_Lev5FU(t), both mixtures of delayed exponentials built from
## the global prop_*, lambda_* and delay_* parameter vectors. Used as the
## objective for uniroot(), so it is called with a single time value.
diff_surv_probs <- function(t) {
  mixture_survival <- function(weights, rates, shifts) {
    sum((1 - pexp(t - shifts, rate = rates)) * weights)
  }

  surv_lev <- mixture_survival(prop_Lev, lambda_Lev, delay_Lev)
  surv_lev5fu <- mixture_survival(prop_Lev5FU, lambda_Lev5FU, delay_Lev5FU)
  surv_lev - surv_lev5fu
}

## Locate the time where the two survival curves cross. uniroot() needs an
## interval over which diff_surv_probs() changes sign (it stops with an error
## otherwise): the difference is positive at t = 2 and negative at t = 5.
## The root is found to uniroot's default tolerance (.Machine$double.eps^0.25).
intersection <- uniroot(diff_surv_probs, interval = c(2, 5))
intersection_time <- intersection$root


# Simulate from a mixture of (optionally delayed) exponentials (single arm)
mixexp <- function(n, prop, rate, delay = 0) {
  # n     = number of observations in the simulated data set
  # prop  = vector of mixing proportions, one per component
  # rate  = vector of exponential rates, one per component
  # delay = per-component shift added to each exponential draw; a single
  #         value (such as the default 0) is applied to every component
  #
  # Returns a numeric vector of length n, ordered by component (all draws
  # from component 1 first, then component 2, ...).
  k <- length(prop)
  # A scalar delay must be recycled: indexing delay[i] for i > 1 would
  # otherwise return NA and make every such observation NA.
  if (length(delay) == 1L) {
    delay <- rep(delay, k)
  }
  stopifnot(length(rate) == k, length(delay) == k)

  # Number of observations falling in each component.
  counts <- as.vector(rmultinom(1, size = n, prob = prop))
  component <- rep(seq_len(k), times = counts)

  # rexp() recycles `rate` element-wise and draws sequentially, so this
  # consumes the RNG stream exactly like drawing component by component.
  rexp(n, rate = rate[component]) + delay[component]
}


# Simulate a two-arm censored data set from mixtures of exponentials ----
# Wrapper around mixexp(): the control arm (group 0) is drawn first, then
# the treatment arm (group 1), then independent censoring times that are
# Uniform(censorRange[1], censorRange[2]).
# Returns a data.frame with columns time (observed time = min(event,
# censoring)), status (1 = event observed, 0 = censored) and group.
mixexp2 <- function(nC, propC, rateC, delayC,
                    nT, propT, rateT, delayT,
                    censorRange = c(0.5, 1.2)) {
  # Keep the draw order (control, treatment, censoring) fixed so a given
  # seed always reproduces the same data set.
  control_times <- mixexp(nC, propC, rateC, delayC)
  treatment_times <- mixexp(nT, propT, rateT, delayT)
  event_times <- c(control_times, treatment_times)
  n_total <- nC + nT

  censor_times <- runif(n_total, min = censorRange[1], max = censorRange[2])
  observed <- pmin(event_times, censor_times)

  # An event counts as observed when event time <= censoring time.
  event_flag <- numeric(n_total)
  event_flag[which(event_times == observed)] <- 1

  data.frame(time = observed,
             status = event_flag,
             group = rep(c(0, 1), times = c(nC, nT)))
}


# Run one simulated trial and compute confidence limits at each time point.
#
# iteration                : seed passed to set.seed(), so each replicate is
#                            reproducible
# nC, propC, rateC, delayC : control-arm size and mixture parameters
# nT, propT, rateT, delayT : treatment-arm size and mixture parameters
# censorRange              : range of the uniform censoring distribution
# time_points              : times at which the two arms are compared
#
# Returns a numeric vector of length 16 * length(time_points). For each time
# point, in order: 8 limits for parmtype "difference", then 8 for "ratio".
# Each group of 8 is (lower, upper) for bpcp2samp with midp = FALSE,
# bpcp2samp with midp = TRUE, delta2samp with method "standard", and
# delta2samp with method "adj_hybrid".
iteration <- function(iteration = 1, nC = 10, propC = prop_Lev,
                      rateC = lambda_Lev, delayC = delay_Lev, nT = 8,
                      propT = prop_Lev5FU, rateT = lambda_Lev5FU,
                      delayT = delay_Lev5FU, censorRange = c(4, 8),
                      time_points = c(2.5, intersection_time, 5, 6, 7)) {
  set.seed(iteration)

  # Generate survival times for both arms.
  sim_data <- mixexp2(nC, propC, rateC, delayC,
                      nT, propT, rateT, delayT, censorRange)

  # Call one interval function twice at conf.level = 0.975:
  # alternative = "greater" supplies the lower limit (conf.int[1]) and
  # alternative = "less" the upper limit (conf.int[2]). `...` forwards the
  # method-specific options. The first formal is deliberately not named
  # with a "method" prefix, so that a forwarded `method = ...` argument
  # cannot be partially matched to it.
  one_sided_limits <- function(ci_fun, testtime, parmtype, ...) {
    lower <- ci_fun(time = sim_data$time, status = sim_data$status,
                    group = sim_data$group, testtime = testtime,
                    parmtype = parmtype, alternative = "greater",
                    conf.level = 0.975, ...)
    upper <- ci_fun(time = sim_data$time, status = sim_data$status,
                    group = sim_data$group, testtime = testtime,
                    parmtype = parmtype, alternative = "less",
                    conf.level = 0.975, ...)
    c(lower$conf.int[1], upper$conf.int[2])
  }

  # The 8 limits for one (testtime, parmtype) pair. Calls run in the same
  # order as before. This order must match the method/type labels attached
  # when the results are saved.
  limits_for <- function(testtime, parmtype) {
    c(
      # bpcp (melding), standard version
      one_sided_limits(bpcp2samp, testtime, parmtype, midp = FALSE),
      # bpcp (melding), mid-p version
      one_sided_limits(bpcp2samp, testtime, parmtype, midp = TRUE),
      # delta method with Greenwood variance and zero/one adjustment
      one_sided_limits(delta2samp, testtime, parmtype,
                       zero.one.adjustment = TRUE, method = "standard"),
      # delta method with adjusted hybrid (Borkowf), zero/one adjustment
      one_sided_limits(delta2samp, testtime, parmtype,
                       zero.one.adjustment = TRUE, method = "adj_hybrid")
    )
  }

  limits <- unlist(lapply(time_points, function(testtime) {
    c(limits_for(testtime, "difference"), limits_for(testtime, "ratio"))
  }))
  # Callers size and label their result matrix assuming 16 limits per point.
  stopifnot(length(limits) == 16 * length(time_points))
  as.vector(limits)
}

# Choose the time point(s) evaluated by this job ----
# slurm_id 1..4 selects one entry of time_points_options; slurm_id 5 selects
# intersection_time, where the two survival curves cross.
time_points_options <- c(2, 5, 6, 7)
if (!(slurm_id %in% seq_len(length(time_points_options) + 1))) {
  # An out-of-range or missing id would otherwise give time_points = NA and
  # run the whole simulation on a missing time point.
  stop("slurm_id must be an integer between 1 and ",
       length(time_points_options) + 1, call. = FALSE)
}
if (slurm_id == length(time_points_options) + 1) {
  time_points <- c(intersection_time)
} else {
  time_points <- time_points_options[slurm_id]
}


# Run the simulation ----
# One column per replicate; 16 confidence limits per time point, in the
# order returned by iteration().
sim_res <- matrix(NA_real_, nrow = 16 * length(time_points),
                  ncol = num_simulations)

# Report progress roughly every 10% of replicates. The step is a whole
# number: num_simulations / 10 could be fractional, which makes a floating
# point modulus test unreliable (progress might never be printed).
progress_step <- max(1, floor(num_simulations / 10))

start.time <- Sys.time()
for (i in seq_len(num_simulations)) {
  sim_res[, i] <- iteration(iteration = i, time_points = time_points)
  if (i %% progress_step == 0) {
    cat(paste0(100 * i / num_simulations, "% complete\n-------------------\n"))
  }
}
end.time <- Sys.time()

# Base difftime() rather than lubridate's interval()/hours(): lubridate is
# only attached by library(tidyverse) from tidyverse 2.0.0 onward.
time.taken <- round(as.numeric(difftime(end.time, start.time,
                                        units = "hours")), 2)
cat(paste0("Simulation completed in ", time.taken, " hours."))
cat("\n")

# Save simulation results ----
# Label every row of sim_res, then reshape to long format with one row per
# (confidence limit, iteration). Within each time point the 16 rows are:
# parmtype "difference" then "ratio"; within each, 4 methods x (lower, upper).
sim_res_df <- as.data.frame(sim_res)
colnames(sim_res_df) <- paste0("Iteration", seq_len(ncol(sim_res)))

sim_res_df$Timepoint <- rep(time_points, each = 16)
sim_res_df$method <- rep(c("Melding", "Melding (mid-p version)",
                           "Delta Method (Greenwood)",
                           "Delta Method (Borkowf Adjusted Hybrid)"),
                         each = 2, length.out = nrow(sim_res_df))
sim_res_df$parmtype <- rep(c("difference", "ratio"),
                           each = 8, length.out = nrow(sim_res_df))
sim_res_df$type <- rep(c("lower", "upper"), length.out = nrow(sim_res_df))

sim_res_df_longer <- pivot_longer(
  sim_res_df,
  cols = contains("Iteration"),
  names_to = "Iteration",
  names_prefix = "Iteration",
  values_to = "ConfidenceLimit"
)

saveRDS(sim_res_df_longer,
        file = file.path(getwd(),
                         paste0("results_cancer_sim_", slurm_id, ".RDS")))
#writexl::write_xlsx(sim_res_df_longer,path=paste0(getwd(),"/results_cancer_sim_",slurm_id,".xlsx"))
