#' @importFrom stats approx quantile spline
#' @importFrom graphics lines
#' @importFrom utils combn
#'
#' @title Density Equality Testing
#'
#' @description This is the main function of the `denstest` package. It performs statistical tests
#' for the equality between groups of estimated density functions using FDET, DET, or MDET.
#'
#' @param L A list of estimated density objects, where each element is a list with numeric vectors \code{x} and \code{y}.
#' \code{x} contains the evaluation points, and \code{y} the corresponding estimated density values for a single observation.
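#' @param group_sizes A vector giving the number of densities in each group. The elements of \code{L} are assigned to
#' groups consecutively in this order, so \code{sum(group_sizes)} should equal \code{length(L)}.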
#' @param N.max Maximum number of permutations for the test (default: 10000).
#' @param a,b Endpoints of the common evaluation grid; if NULL, \code{a} (\code{b}) is the smallest (largest)
#' \code{x} value over all densities in \code{L}.
#' @param m Number of evaluation points (default: 100).
#' @param seed Optional random seed, passed to \code{set.seed()} before the permutations are drawn, for reproducibility.
#' @param density.weights Optional weights for the densities in \code{L}; if NULL, every density receives weight
#' \code{1/length(L)}.
#' @param test The test to use. One of "FDET.regular", "FDET.residual",
#' "FDET.regular.real.imag", "FDET.residual.real.imag", "DET.regular", "DET.residual",
#' "MDET.regular", or "MDET.residual".
#' @param distance The distance measure to use for FDET or DET. One of "LP", "Hellinger", or "TF".
#' @param moment Moment type used for MDET. One of "expectation", "variance", "skewness", "kurtosis", or "combined".
#' @param interpolation Method for interpolating densities. One of "linear" or "spline".
#' @param p Parameter for \eqn{L^p} distances (default: 2).
#' @param eps Cut-off parameter for the Fourier transforms.
#' @param tau Step size parameter of the Fourier transforms.
#' @param Lmax Maximum size of the vectors containing the values of the individual Fourier transforms.
#' @param ft.lp.weight Weights for the Fourier transforms. One of "none" or "AbsRoot".
#' @param real.imag.weights Weights for the real and imaginary parts of the Fourier transforms.
#' @param moment.weights Weights used by MDET to combine the expectation, variance, skewness and kurtosis distances
#' (validated together with \code{moment}).
#' @param plot Logical. If TRUE, plots the density functions in \code{L}.
#' @param legend Position of the legend in the plot.
#'
#' @return A \eqn{p}-value indicating the significance of group differences.
#'
#' @author Akin Anarat \email{akin.anarat@hhu.de}
#'
#' @references
#' Anarat A., Krutmann, J., and Schwender, H. (2025). Testing for Differences in Extrinsic Skin Aging Based on Density
#' Functions. Submitted.
#'
#' Delicado, P. (2007). Functional k-sample problem when data are density functions. Computational Statistics, 22, 391–410. \doi{10.1007/s00180-007-0047-y}
#'
#' @export
#'
#' @examples
#' n1 <- 5; n2 <- 5; n3 <- 5
#' group_sizes <- c(n1, n2, n3)
#' sample_size <- 500
#'
#' densities_group1 <- lapply(1:n1, function(i) {
#'   data <- rnorm(sample_size, 0, 0.3)
#'   d <- density(data)
#'   list(x = d$x, y = d$y)
#' })
#'
#' densities_group2 <- lapply(1:n2, function(i) {
#'   data <- rnorm(sample_size, 0, 0.32)
#'   d <- density(data)
#'   list(x = d$x, y = d$y)
#' })
#'
#' densities_group3 <- lapply(1:n3, function(i) {
#'   data <- rnorm(sample_size, 0.02, 0.28)
#'   d <- density(data)
#'   list(x = d$x, y = d$y)
#' })
#'
#' L <- c(densities_group1, densities_group2, densities_group3)
#'
#' denscomp(L, group_sizes, ft.lp.weight = "AbsRoot")

denscomp <- function(L, group_sizes, N.max = 10000, a = NULL, b = NULL, m = 100,
                     seed = NULL, density.weights = NULL,
                     test = c("FDET.regular", "FDET.residual",
                              "FDET.regular.real.imag", "FDET.residual.real.imag",
                              "DET.regular", "DET.residual",
                              "MDET.regular", "MDET.residual"),
                     distance = c("LP", "Hellinger", "TF"),
                     moment = c("expectation", "variance", "skewness", "kurtosis", "combined"),
                     interpolation = c("linear", "spline"), p = 2,
                     eps = 0.01, tau = 0.01, Lmax = 5000, ft.lp.weight = c("none", "AbsRoot"),
                     real.imag.weights = c(0.5, 0.5), moment.weights = rep(0.25, 4), plot = FALSE,
                     legend = c("topright", "topleft", "bottomright", "bottomleft", "top", "bottom", "left", "right", "center")){
  # Pipeline: validate arguments -> interpolate every density in L onto a
  # common grid of m points on [a, b] -> optionally plot -> build a distance
  # object D (DET: densities, FDET: Fourier transforms, MDET: moments) ->
  # permutation test on D. Returns the permutation p-value.

  # ---- Argument validation ---------------------------------------------------
  # match.arg() picks the first choice by default and errors on any value not
  # in the listed choices, so no separate membership checks are required.
  test <- match.arg(test)
  distance <- match.arg(distance)
  interpolation <- match.arg(interpolation)
  ft.lp.weight <- match.arg(ft.lp.weight)
  moment <- match.arg(moment)
  legend <- match.arg(legend)

  is_fdet <- startsWith(test, "FDET")
  is_det <- startsWith(test, "DET")
  is_mdet <- startsWith(test, "MDET")
  is_real_imag <- endsWith(test, ".real.imag")
  is_residual <- grepl("residual", test, fixed = TRUE)

  # NA-safe: a missing p must fail validation rather than break the `if`.
  p_ok <- vapply(p, function(v) is.numeric(v) && length(v) == 1L && !is.na(v) && v > 0,
                 logical(1))
  if (length(p) == 0L || !all(p_ok)) {
    stop("The p value(s) must be numeric and positive.")
  }

  if (!is.numeric(group_sizes) || length(group_sizes) == 0L || anyNA(group_sizes) ||
      any(group_sizes < 1) || any(group_sizes != round(group_sizes))) {
    stop("group_sizes must be a vector of positive whole numbers.")
  }
  # Densities are assigned to groups consecutively, so the counts must match.
  if (sum(group_sizes) != length(L)) {
    stop("sum(group_sizes) (", sum(group_sizes), ") must equal length(L) (",
         length(L), ").")
  }

  if (is.null(density.weights)) {
    density.weights <- rep(1, length(L)) / length(L)
  } else {
    density.weights <- validate_density.weights(density.weights)
  }

  if (is_fdet && is_real_imag) {
    real.imag.weights <- validate_real.imag.weights(real.imag.weights)
  }

  if (is_mdet) {
    moment.weights <- validate_moment.weights(moment, moment.weights)
  }

  if (is_residual) {
    warning("Residual-based tests may lead to inflated type I error rates in certain settings. Use with caution.")
  }

  # ---- Interpolate all densities onto a common grid --------------------------
  if (is.null(a)) a <- min(vapply(L, function(d) min(d$x), numeric(1)))
  if (is.null(b)) b <- max(vapply(L, function(d) max(d$x), numeric(1)))

  x <- seq(a, b, length.out = m)

  k <- length(group_sizes)
  group_id <- rep(seq_len(k), times = group_sizes)

  interpolate <- switch(interpolation,
    linear = function(d) approx(d$x, d$y, xout = x)$y,
    spline = function(d) spline(d$x, d$y, xout = x)$y
  )

  # One (group size x m) matrix per group. NA values (approx() returns NA
  # outside a density's x range) and negative values (e.g. spline overshoot)
  # are set to 0.
  densities <- lapply(seq_len(k), function(g) {
    members <- L[group_id == g]
    mat <- matrix(0, nrow = length(members), ncol = m)
    for (i in seq_along(members)) {
      mat[i, ] <- interpolate(members[[i]])
    }
    mat[is.na(mat) | mat < 0] <- 0
    mat
  })

  # ---- Optional plot ---------------------------------------------------------
  if (plot) {
    max_density <- max(0, vapply(densities, max, numeric(1), na.rm = TRUE))

    # graphics:: is explicit because the `plot` and `legend` arguments shadow
    # the function names.
    graphics::plot(NULL, xlim = c(a, b), ylim = c(0, max_density), xlab = "x", ylab = "Density")

    # Evenly spaced hues give one visually distinct colour per group.
    colors <- grDevices::hcl(h = seq(0, 1, length.out = k + 1)[-1] * 360, c = 100, l = 65)

    for (g in seq_len(k)) {
      for (i in seq_len(nrow(densities[[g]]))) {
        lines(x, densities[[g]][i, ], col = colors[g])
      }
    }
    graphics::legend(legend, legend = paste("Group", seq_len(k)),
                     col = colors, lty = 1, cex = 0.8)
  }

  # ---- Fourier transforms (FDET) ---------------------------------------------
  if (is_fdet) {
    # Slots not filled by compute_fourier() keep the sentinel value 1e-10.
    fourier_parts <- lapply(densities, function(dg) {
      re <- im <- matrix(1e-10, nrow = nrow(dg), ncol = Lmax + 1)
      for (i in seq_len(nrow(dg))) {
        ft <- compute_fourier(dg[i, ], x, a, b, eps, tau, Lmax)
        re[i, seq_along(ft$real)] <- ft$real
        im[i, seq_along(ft$imag)] <- ft$imag
      }
      list(real = re, imag = im)
    })
    fourier_real <- lapply(fourier_parts, `[[`, "real")
    fourier_imag <- lapply(fourier_parts, `[[`, "imag")

    # Keep only the leading columns that are filled for every density.
    max_cols <- min(vapply(fourier_real, .count_full_columns, numeric(1)))
    if (max_cols < 2) {
      stop("Not enough Fourier coefficients (need at least 2 filled columns ",
           "for every density); check eps, tau and Lmax.")
    }

    # Extend to negative frequencies by mirroring: the real part is mirrored
    # as is, the imaginary part with its sign flipped.
    mirror <- function(mat, sign) {
      mat <- mat[, seq_len(max_cols), drop = FALSE]
      cbind(sign * mat[, max_cols:2, drop = FALSE], mat)
    }
    fourier_real <- lapply(fourier_real, mirror, sign = 1)
    fourier_imag <- lapply(fourier_imag, mirror, sign = -1)
    fourier <- Map(function(re, im) re + 1i * im, fourier_real, fourier_imag)

    # Always defined, so `w` never falls back to an unrelated global
    # (previously undefined when ft.lp.weight = "AbsRoot").
    lp_weights <- rep(1, ncol(fourier[[1]]))
  }

  if (is_det) {
    lp_weights <- rep(1, m)
    # ft.lp.weight weights Fourier transforms only; DET works on densities.
    ft.lp.weight <- "none"
  }

  # ---- Distances -------------------------------------------------------------
  if (is_det || (is_fdet && !is_real_imag)) {
    groups <- if (is_det) densities else fourier

    if (distance == "TF") {
      sigma2_hat <- .pooled_within_variance(groups)
      D <- calculate_tf_distances(groups, sigma2_hat)
    } else if (distance == "Hellinger") {
      D <- calculate_hellinger_distances(groups)
    } else {
      D <- calculate_lp_distances(groups, p.p = p, w = lp_weights, ft.lp.weight = ft.lp.weight)
    }
  }

  if (is_fdet && is_real_imag) {
    # Distances computed separately on real and imaginary parts, then combined.
    part_distances <- function(groups) {
      switch(distance,
        TF = calculate_tf_distances(groups, .pooled_within_variance(groups)),
        Hellinger = calculate_hellinger_distances(groups),
        LP = calculate_lp_distances(groups, p.p = p, w = lp_weights, ft.lp.weight = ft.lp.weight)
      )
    }
    D <- combine_distances(list(part_distances(fourier_real), part_distances(fourier_imag)),
                           real.imag.weights)
  }

  if (is_mdet) {
    D_moments <- list(
      calculate_distances_expectation(densities, x = x),
      calculate_distances_variance(densities, x = x),
      calculate_distances_skewness(densities, x = x),
      calculate_distances_kurtosis(densities, x = x)
    )
    if (moment == "combined") {
      D_moments <- lapply(D_moments, normalize_moments)
    }
    D <- combine_distances(D_moments, moment.weights)
  }

  # ---- Permutation test --------------------------------------------------------
  B_W <- compute_b_w_ratio(D, group_sizes, weights = density.weights)

  # Only reseed on request: set.seed(NULL) would discard RNG state the caller
  # may have set up for reproducibility.
  if (!is.null(seed)) set.seed(seed)

  # NOTE(review): sigma2_hat is only created for non-real.imag tests with
  # distance = "TF"; otherwise the promise below is unresolved and only fails
  # if compute_D_residual() forces `sigma2` - confirm intended behaviour.
  p_value <- permutation_test(
    if (is_residual) {
      compute_D_residual(D, group_sizes, weights = density.weights, sigma2 = sigma2_hat)
    } else {
      D
    },
    group_sizes, permute_groups(group_sizes, N.max), B_W, density.weights
  )

  p_value
}

# Number of leading columns of `mat` in which every row has an absolute value
# above 1e-10 (the sentinel used for unfilled Fourier slots in denscomp()).
# Returns ncol(mat) when every column qualifies.
.count_full_columns <- function(mat) {
  full <- colSums(abs(mat) > 1e-10) == nrow(mat)
  match(FALSE, full, nomatch = length(full) + 1L) - 1L
}

# Pooled within-group variance per column: squared (modulus) deviations from
# each group's column mean, summed over all groups and divided by (N - k),
# where N is the total number of rows and k the number of groups. Exact zeros
# are replaced by 1e-16 (presumably to avoid division by zero downstream).
.pooled_within_variance <- function(groups) {
  n_total <- sum(vapply(groups, nrow, integer(1)))
  ss <- Reduce(`+`, lapply(groups, function(grp) {
    colSums(abs(sweep(grp, 2, colMeans(grp)))^2)
  }))
  sigma2 <- ss / (n_total - length(groups))
  sigma2[sigma2 == 0] <- 10^-16
  sigma2
}

