# This long script generates tree sequences from the following simple slendr models:
#  - constant N
#  - step reduction of N
#  - step increase of N
#  - exponential increase of N
#  - exponential decrease of N
#
# Those models are configured in both forward and backward time specifications,
# so each model is specified twice.
#
# It then runs those slendr model configurations through both SLiM and msprime
# backends bundled with the slendr R package. Finally, an allele frequency
# spectrum is estimated from tree sequence files saved by both backends for
# each model variant (forward and backward). The AFS are then compared for each
# of the two backends and we make sure that forward and backward configurations
# of the same model give *exactly* the same AFS (it has to be exactly the same
# because it is, in fact, generated from the same theoretical model). Then, we
# also compare the AFS between SLiM and msprime runs for the same model
# configuration. These won't be exactly the same (the tree sequences are
# generated by two completely different pieces of software after all - SLiM
# and msprime Python library), but they should be *nearly* the same.

# Skip everything below unless check_dependencies(python = TRUE) succeeds, then
# call init_env() quietly (presumably activating the slendr Python environment
# -- confirm against the slendr docs) before any simulation is run.
skip_if(!check_dependencies(python = TRUE))
init_env(quiet = TRUE)

# NOTE(review): while RERUN is TRUE the AFS plot is re-saved and the reference
# AFS table is rewritten immediately before it is read back for comparison, so
# the final expect_equal() in this file cannot detect a regression. Confirm
# whether this flag should be FALSE for regular test runs.
RERUN <- TRUE

# Parameters shared by every model configuration in this file:
#  - seed:      passed to run_slim_msprime() and load_tree_sequence()
#  - N:         baseline population size
#  - N_factor:  factor by which N is multiplied or divided in the resize models
#  - n_samples: number of samples requested in each schedule_sampling() call
#  - seq_len:   passed to run_slim_msprime() (presumably the sequence length in
#               bp -- confirm); note that this variable shares its name with
#               base::seq_len(), although calls to seq_len() still resolve to
#               the function
#  - rec_rate:  passed to run_slim_msprime() and load_tree_sequence()
#  - mut_rate:  passed to load_tree_sequence()
seed <- 42
N <- 1000
N_factor <- 5
n_samples <- 50
seq_len <- 100e6
rec_rate <- 1e-8
mut_rate <- 1e-8

# Constant population size -- the same model is specified once in forward time
# and once in backward time, and both are run through SLiM and msprime.

# forward time: population created at time 1, sampled at time 5001
forward_const_dir <- file.path(tempdir(), "forward_const")
forward_const_pop <- population(
  "forward_const_pop",
  time = 1, N = N, map = FALSE
)
forward_const_model <- compile_model(
  forward_const_pop,
  path = forward_const_dir,
  generation_time = 1,
  overwrite = TRUE, force = TRUE,
  direction = "forward",
  simulation_length = 5000
)
forward_const_samples <- schedule_sampling(
  forward_const_model,
  times = 5001,
  list(forward_const_pop, n_samples)
)

# backward time: population created at time 5000, sampled at time 0
backward_const_dir <- file.path(tempdir(), "backward_const")
backward_const_pop <- population(
  "backward_const_pop",
  time = 5000, N = N, map = FALSE
)
backward_const_model <- compile_model(
  backward_const_pop,
  path = backward_const_dir,
  generation_time = 1,
  overwrite = TRUE, force = TRUE,
  direction = "backward"
)
backward_const_samples <- schedule_sampling(
  backward_const_model,
  times = 0,
  list(backward_const_pop, n_samples)
)

# run both model variants through both backends
const_ts <- run_slim_msprime(
  forward_const_model,
  backward_const_model,
  forward_const_samples,
  backward_const_samples,
  seq_len,
  rec_rate,
  seed,
  verbose = FALSE
)

# Step contraction of population size (N -> N / N_factor) -- specified in
# forward and in backward time, run through SLiM and msprime.

# forward time: contraction happens at time 2001, sampling at time 5001
forward_contr_dir <- file.path(tempdir(), "forward_contr")
forward_contr_pop <-
  population("forward_contr_pop", time = 1, N = N, map = FALSE) %>%
  resize(
    time = 2001,
    N = N / N_factor,
    how = "step"
  )
forward_contr_model <- compile_model(
  forward_contr_pop,
  path = forward_contr_dir,
  generation_time = 1,
  overwrite = TRUE, force = TRUE,
  direction = "forward",
  simulation_length = 5000
)
forward_contr_samples <- schedule_sampling(
  forward_contr_model,
  times = 5001,
  list(forward_contr_pop, n_samples)
)

# backward time: contraction happens at time 3000, sampling at time 0
backward_contr_dir <- file.path(tempdir(), "backward_contr")
backward_contr_pop <-
  population("backward_contr_pop", time = 5000, N = N, map = FALSE) %>%
  resize(
    time = 3000,
    N = N / N_factor,
    how = "step"
  )
backward_contr_model <- compile_model(
  backward_contr_pop,
  path = backward_contr_dir,
  generation_time = 1,
  overwrite = TRUE, force = TRUE,
  direction = "backward"
)
backward_contr_samples <- schedule_sampling(
  backward_contr_model,
  times = 0,
  list(backward_contr_pop, n_samples)
)

# run both model variants through both backends
contr_ts <- run_slim_msprime(
  forward_contr_model,
  backward_contr_model,
  forward_contr_samples,
  backward_contr_samples,
  seq_len,
  rec_rate,
  seed,
  verbose = FALSE
)

# Step expansion of population size (N -> N * N_factor) -- specified in
# forward and in backward time, run through SLiM and msprime.

# forward time: expansion happens at time 2001, sampling at time 5001
forward_expansion_dir <- file.path(tempdir(), "forward_expansion")
forward_expansion_pop <-
  population("forward_expansion_pop", time = 1, N = N, map = FALSE) %>%
  resize(
    time = 2001,
    N = N * N_factor,
    how = "step"
  )
forward_expansion_model <- compile_model(
  forward_expansion_pop,
  path = forward_expansion_dir,
  generation_time = 1,
  overwrite = TRUE, force = TRUE,
  direction = "forward",
  simulation_length = 5000
)
forward_expansion_samples <- schedule_sampling(
  forward_expansion_model,
  times = 5001,
  list(forward_expansion_pop, n_samples)
)

# backward time: expansion happens at time 3000, sampling at time 0
backward_expansion_dir <- file.path(tempdir(), "backward_expansion")
backward_expansion_pop <-
  population("backward_expansion_pop", time = 5000, N = N, map = FALSE) %>%
  resize(
    time = 3000,
    N = N * N_factor,
    how = "step"
  )
backward_expansion_model <- compile_model(
  backward_expansion_pop,
  path = backward_expansion_dir,
  generation_time = 1,
  overwrite = TRUE, force = TRUE,
  direction = "backward"
)
backward_expansion_samples <- schedule_sampling(
  backward_expansion_model,
  times = 0,
  list(backward_expansion_pop, n_samples)
)

# run both model variants through both backends
expansion_ts <- run_slim_msprime(
  forward_expansion_model,
  backward_expansion_model,
  forward_expansion_samples,
  backward_expansion_samples,
  seq_len,
  rec_rate,
  seed,
  verbose = FALSE
)

# Exponential increase of population size (N / N_factor -> N) -- specified in
# forward and in backward time, run through SLiM and msprime.

# forward time: growth runs from time 2001 to time 3001, sampling at time 5001
forward_exp_inc_dir <- file.path(tempdir(), "forward_exp_inc")
forward_exp_inc_pop <- population("forward_exp_inc_pop", time = 1, N = N / N_factor, map = FALSE) %>%
  resize(time = 2001, end = 3001, N = N, how = "exponential")
forward_exp_inc_model <- compile_model(forward_exp_inc_pop, path = forward_exp_inc_dir, generation_time = 1,
                                       overwrite = TRUE, force = TRUE, direction = "forward", simulation_length = 5000)
forward_exp_inc_samples <- schedule_sampling(forward_exp_inc_model, times = 5001, list(forward_exp_inc_pop, n_samples))

# backward time: the same model with every time point t mapped to 5001 - t,
# which is the same mapping the step-resize models in this file use (forward
# 1 / 2001 / 5001 correspond to backward 5000 / 3000 / 0). The growth therefore
# runs from time 3000 to time 2000 and sampling happens at time 0. Previously
# this variant was a verbatim copy of the forward configuration, so the
# forward-vs-backward comparison for this model did not test anything.
backward_exp_inc_dir <- file.path(tempdir(), "backward_exp_inc")
backward_exp_inc_pop <- population("backward_exp_inc_pop", time = 5000, N = N / N_factor, map = FALSE) %>%
  resize(time = 3000, end = 2000, N = N, how = "exponential")
backward_exp_inc_model <- compile_model(backward_exp_inc_pop, path = backward_exp_inc_dir, generation_time = 1,
                                        overwrite = TRUE, force = TRUE, direction = "backward")
backward_exp_inc_samples <- schedule_sampling(backward_exp_inc_model, times = 0, list(backward_exp_inc_pop, n_samples))

# run both model variants through both backends
exp_inc_ts <- run_slim_msprime(
  forward_exp_inc_model, backward_exp_inc_model,
  forward_exp_inc_samples, backward_exp_inc_samples,
  seq_len, rec_rate, seed, verbose = FALSE
)

# Exponential decrease of population size (N -> N / N_factor) -- specified in
# forward and in backward time, run through SLiM and msprime.

# forward time: decline runs from time 2001 to time 3001, sampling at time 5001
forward_exp_decr_dir <- file.path(tempdir(), "forward_exp_decr")
forward_exp_decr_pop <- population("forward_exp_decr_pop", time = 1, N = N, map = FALSE) %>%
  resize(time = 2001, end = 3001, N = N / N_factor, how = "exponential")
forward_exp_decr_model <- compile_model(forward_exp_decr_pop, path = forward_exp_decr_dir, generation_time = 1,
                                        overwrite = TRUE, force = TRUE, direction = "forward", simulation_length = 5000)
forward_exp_decr_samples <- schedule_sampling(forward_exp_decr_model, times = 5001, list(forward_exp_decr_pop, n_samples))

# backward time: the same model with every time point t mapped to 5001 - t,
# which is the same mapping the step-resize models in this file use (forward
# 1 / 2001 / 5001 correspond to backward 5000 / 3000 / 0). The decline therefore
# runs from time 3000 to time 2000 and sampling happens at time 0. Previously
# this variant was a verbatim copy of the forward configuration, so the
# forward-vs-backward comparison for this model did not test anything.
backward_exp_decr_dir <- file.path(tempdir(), "backward_exp_decr")
backward_exp_decr_pop <- population("backward_exp_decr_pop", time = 5000, N = N, map = FALSE) %>%
  resize(time = 3000, end = 2000, N = N / N_factor, how = "exponential")
backward_exp_decr_model <- compile_model(backward_exp_decr_pop, path = backward_exp_decr_dir, generation_time = 1,
                                         overwrite = TRUE, force = TRUE, direction = "backward")
backward_exp_decr_samples <- schedule_sampling(backward_exp_decr_model, times = 0, list(backward_exp_decr_pop, n_samples))

# run both model variants through both backends
exp_decr_ts <- run_slim_msprime(
  forward_exp_decr_model, backward_exp_decr_model,
  forward_exp_decr_samples, backward_exp_decr_samples,
  seq_len, rec_rate, seed, verbose = FALSE
)

# Load the msprime tree sequences: one per demographic model and per direction
# of time, each paired with the model object it was simulated from.

# constant size
msprime_forward_const_ts <- load_tree_sequence(
  "msprime", "forward", const_ts, forward_const_model,
  N, rec_rate, mut_rate, seed
)
msprime_backward_const_ts <- load_tree_sequence(
  "msprime", "backward", const_ts, backward_const_model,
  N, rec_rate, mut_rate, seed
)

# step contraction
msprime_forward_contr_ts <- load_tree_sequence(
  "msprime", "forward", contr_ts, forward_contr_model,
  N, rec_rate, mut_rate, seed
)
msprime_backward_contr_ts <- load_tree_sequence(
  "msprime", "backward", contr_ts, backward_contr_model,
  N, rec_rate, mut_rate, seed
)

# step expansion
msprime_forward_expansion_ts <- load_tree_sequence(
  "msprime", "forward", expansion_ts, forward_expansion_model,
  N, rec_rate, mut_rate, seed
)
msprime_backward_expansion_ts <- load_tree_sequence(
  "msprime", "backward", expansion_ts, backward_expansion_model,
  N, rec_rate, mut_rate, seed
)

# exponential increase
msprime_forward_exp_inc_ts <- load_tree_sequence(
  "msprime", "forward", exp_inc_ts, forward_exp_inc_model,
  N, rec_rate, mut_rate, seed
)
msprime_backward_exp_inc_ts <- load_tree_sequence(
  "msprime", "backward", exp_inc_ts, backward_exp_inc_model,
  N, rec_rate, mut_rate, seed
)

# exponential decrease
msprime_forward_exp_decr_ts <- load_tree_sequence(
  "msprime", "forward", exp_decr_ts, forward_exp_decr_model,
  N, rec_rate, mut_rate, seed
)
msprime_backward_exp_decr_ts <- load_tree_sequence(
  "msprime", "backward", exp_decr_ts, backward_exp_decr_model,
  N, rec_rate, mut_rate, seed
)

# Load the SLiM tree sequences: one per demographic model and per direction of
# time. Each tree sequence is paired with the model object of the matching
# direction (the backward exponential tree sequences used to be loaded with the
# forward model objects, unlike their msprime counterparts).
slim_forward_const_ts <- load_tree_sequence("SLiM", "forward", const_ts, forward_const_model, N, rec_rate, mut_rate, seed)
slim_backward_const_ts <- load_tree_sequence("SLiM", "backward", const_ts, backward_const_model, N, rec_rate, mut_rate, seed)

slim_forward_contr_ts <- load_tree_sequence("SLiM", "forward", contr_ts, forward_contr_model, N, rec_rate, mut_rate, seed)
slim_backward_contr_ts <- load_tree_sequence("SLiM", "backward", contr_ts, backward_contr_model, N, rec_rate, mut_rate, seed)

slim_forward_expansion_ts <- load_tree_sequence("SLiM", "forward", expansion_ts, forward_expansion_model, N, rec_rate, mut_rate, seed)
slim_backward_expansion_ts <- load_tree_sequence("SLiM", "backward", expansion_ts, backward_expansion_model, N, rec_rate, mut_rate, seed)

slim_forward_exp_inc_ts <- load_tree_sequence("SLiM", "forward", exp_inc_ts, forward_exp_inc_model, N, rec_rate, mut_rate, seed)
slim_backward_exp_inc_ts <- load_tree_sequence("SLiM", "backward", exp_inc_ts, backward_exp_inc_model, N, rec_rate, mut_rate, seed)

slim_forward_exp_decr_ts <- load_tree_sequence("SLiM", "forward", exp_decr_ts, forward_exp_decr_model, N, rec_rate, mut_rate, seed)
slim_backward_exp_decr_ts <- load_tree_sequence("SLiM", "backward", exp_decr_ts, backward_exp_decr_model, N, rec_rate, mut_rate, seed)

# Compute the polarised allele frequency spectrum of every tree sequence.
# The first element of each spectrum is dropped (presumably the class of sites
# with zero derived alleles -- confirm against the ts_afs() documentation).
afs_without_first <- function(ts) {
  spectrum <- ts_afs(ts, polarised = TRUE)
  spectrum[-1]
}

# msprime tree sequences
msprime_forward_const_afs      <- afs_without_first(msprime_forward_const_ts)
msprime_backward_const_afs     <- afs_without_first(msprime_backward_const_ts)
msprime_forward_contr_afs      <- afs_without_first(msprime_forward_contr_ts)
msprime_backward_contr_afs     <- afs_without_first(msprime_backward_contr_ts)
msprime_forward_expansion_afs  <- afs_without_first(msprime_forward_expansion_ts)
msprime_backward_expansion_afs <- afs_without_first(msprime_backward_expansion_ts)
msprime_forward_exp_inc_afs    <- afs_without_first(msprime_forward_exp_inc_ts)
msprime_backward_exp_inc_afs   <- afs_without_first(msprime_backward_exp_inc_ts)
msprime_forward_exp_decr_afs   <- afs_without_first(msprime_forward_exp_decr_ts)
msprime_backward_exp_decr_afs  <- afs_without_first(msprime_backward_exp_decr_ts)

# SLiM tree sequences
slim_forward_const_afs      <- afs_without_first(slim_forward_const_ts)
slim_backward_const_afs     <- afs_without_first(slim_backward_const_ts)
slim_forward_contr_afs      <- afs_without_first(slim_forward_contr_ts)
slim_backward_contr_afs     <- afs_without_first(slim_backward_contr_ts)
slim_forward_expansion_afs  <- afs_without_first(slim_forward_expansion_ts)
slim_backward_expansion_afs <- afs_without_first(slim_backward_expansion_ts)
slim_forward_exp_inc_afs    <- afs_without_first(slim_forward_exp_inc_ts)
slim_backward_exp_inc_afs   <- afs_without_first(slim_backward_exp_inc_ts)
slim_forward_exp_decr_afs   <- afs_without_first(slim_forward_exp_decr_ts)
slim_backward_exp_decr_afs  <- afs_without_first(slim_backward_exp_decr_ts)

# Bind all allele frequency spectra into one long table with columns:
#   f         - value of the spectrum at a given allele count
#   n         - allele count, running from 1 to 2 * n_samples
#   sim       - backend which produced the tree sequence ("msprime" or "slim")
#   direction - direction of time of the slendr model specification
#   model     - human-readable label of the demographic model

# labels of the five demographic models (also used as factor levels below)
label_const     <- sprintf("constant %d", N)
label_contr     <- sprintf("step contraction %d to %d", N, N / N_factor)
label_expansion <- sprintf("step expansion %d to %d", N, N * N_factor)
label_exp_inc   <- sprintf("exponential increase %d to %d", N / N_factor, N)
label_exp_decr  <- sprintf("exponential decrease %d to %d", N, N / N_factor)

# allele counts shared by every spectrum
allele_counts <- 1:(2 * n_samples)

# build the rows of the final table for a single spectrum
afs_rows <- function(f, sim, direction, model) {
  dplyr::tibble(f = f, n = allele_counts, sim = sim, direction = direction, model = model)
}

afs <- dplyr::bind_rows(
  # msprime
  afs_rows(msprime_forward_const_afs, "msprime", "forward", label_const),
  afs_rows(msprime_backward_const_afs, "msprime", "backward", label_const),
  afs_rows(msprime_forward_contr_afs, "msprime", "forward", label_contr),
  afs_rows(msprime_backward_contr_afs, "msprime", "backward", label_contr),
  afs_rows(msprime_forward_expansion_afs, "msprime", "forward", label_expansion),
  afs_rows(msprime_backward_expansion_afs, "msprime", "backward", label_expansion),
  afs_rows(msprime_forward_exp_inc_afs, "msprime", "forward", label_exp_inc),
  afs_rows(msprime_backward_exp_inc_afs, "msprime", "backward", label_exp_inc),
  afs_rows(msprime_forward_exp_decr_afs, "msprime", "forward", label_exp_decr),
  afs_rows(msprime_backward_exp_decr_afs, "msprime", "backward", label_exp_decr),
  # SLiM
  afs_rows(slim_forward_const_afs, "slim", "forward", label_const),
  afs_rows(slim_backward_const_afs, "slim", "backward", label_const),
  afs_rows(slim_forward_contr_afs, "slim", "forward", label_contr),
  afs_rows(slim_backward_contr_afs, "slim", "backward", label_contr),
  afs_rows(slim_forward_expansion_afs, "slim", "forward", label_expansion),
  afs_rows(slim_backward_expansion_afs, "slim", "backward", label_expansion),
  afs_rows(slim_forward_exp_inc_afs, "slim", "forward", label_exp_inc),
  afs_rows(slim_backward_exp_inc_afs, "slim", "backward", label_exp_inc),
  afs_rows(slim_forward_exp_decr_afs, "slim", "forward", label_exp_decr),
  afs_rows(slim_backward_exp_decr_afs, "slim", "backward", label_exp_decr)
) %>%
  dplyr::mutate(
    f = as.vector(f),
    sim = factor(sim, levels = c("msprime", "slim")),
    model = factor(
      model,
      levels = c(label_const, label_contr, label_expansion, label_exp_inc, label_exp_decr)
    )
  )

# For every demographic model, the msprime spectrum of the forward-time
# specification has to be element-wise equal to that of the backward-time one.
test_that("msprime forward/backward sims are exactly the same", {
  msprime_afs <- afs[afs$sim == "msprime", ]
  model_patterns <- c(
    "constant",
    "step contraction",
    "step expansion",
    "exponential increase",
    "exponential decrease"
  )

  for (pattern in model_patterns) {
    df <- msprime_afs[grepl(pattern, msprime_afs$model), ]
    forward_f <- df[df$direction == "forward", "f"]
    backward_f <- df[df$direction == "backward", "f"]
    expect_true(all(forward_f == backward_f))
  }
})

# For every demographic model, the SLiM spectrum of the forward-time
# specification has to be element-wise equal to that of the backward-time one.
test_that("SLiM forward/backward sims are exactly the same", {
  slim_afs <- afs[afs$sim == "slim", ]
  model_patterns <- c(
    "constant",
    "step contraction",
    "step expansion",
    "exponential increase",
    "exponential decrease"
  )

  for (pattern in model_patterns) {
    df <- slim_afs[grepl(pattern, slim_afs$model), ]
    forward_f <- df[df$direction == "forward", "f"]
    backward_f <- df[df$direction == "backward", "f"]
    expect_true(all(forward_f == backward_f))
  }
})

if (RERUN) {
  # SLiM and msprime simulations from the same model should give the same
  # result; this is checked visually on a plot of all spectra (one facet per
  # demographic model) which is saved as a PNG file named after the current OS.
  library(ggplot2)

  max_allele_count <- 2 * n_samples
  plot_subtitle <- "Each model was specified in forward or backward direction of time and executed by
two different backend scripts in slendr (implemented in SLiM and msprime)"

  p <- ggplot(afs, aes(n, f, color = direction, linetype = sim)) +
    geom_line(stat = "identity", alpha = 0.5) +
    facet_wrap(~ model) +
    labs(
      x = "number of derived alleles",
      y = "count",
      title = "Site frequency spectra obtained from five demographic models",
      subtitle = plot_subtitle
    ) +
    guides(
      color = guide_legend("direction of\ntime in slendr"),
      linetype = guide_legend("slendr backend\nengine used")
    ) +
    scale_x_continuous(
      breaks = c(1, seq(20, max_allele_count, 20)),
      limits = c(1, max_allele_count)
    )

  png_file <- sprintf("afs_%s.png", Sys.info()["sysname"])
  ggsave(filename = png_file, plot = p, width = 8, height = 5)
}

# Make sure that the distributions, as they were originally inspected and
# verified visually, match the newly computed distributions -- this is obviously
# not a rigorous test, but the allele frequency spectra from msprime and SLiM
# simulations match almost perfectly assuming we simulate large enough data to
# eliminate most of the noise.
#
# The reference table is stored in a per-OS file "afs_<sysname>.tsv.gz" in the
# current working directory.
#
# NOTE(review): when RERUN is TRUE the reference file is overwritten with the
# current results right before it is read back, so the comparison below can
# only catch write/read round-trip differences, not changes in the simulation
# results themselves.
test_that("AFS distributions from SLiM and msprime simulations match", {
  # factors are converted to character vectors so that the table is comparable
  # with what is read back from the TSV file
  afs <- afs %>% dplyr::mutate(sim = as.character(sim), model = as.character(model))

  original_tsv <- sprintf("afs_%s.tsv.gz", Sys.info()["sysname"])
  if (RERUN) {
    # regenerate the reference table from the current results (the former extra
    # copy written to a temporary file was never read nor removed, so it is no
    # longer produced)
    readr::write_tsv(afs, original_tsv, progress = FALSE)
  }
  orig_afs <- readr::read_tsv(original_tsv, show_col_types = FALSE, progress = FALSE)

  expect_equal(afs, orig_afs)
})
