#' Plot a Join Plan as a Flowchart
#'
#' Takes a plan generated by `create_join_plan()` and creates a flowchart
#' visualizing the sequence of aggregations and merges.
#'
#' Each row of the plan contributes edges pointing at the row's `target`.
#' Input tables are parsed from the row's `code`:
#' * for `AGGREGATE` and `SELECT` steps, the table subset on the right-hand
#'   side of the assignment (`target <- input[...]`);
#' * for `MERGE` steps, the tables passed as the named `x =` and `y =`
#'   arguments of `merge()`.
#'
#' Code that does not match these forms contributes no edges. Steps with any
#' other `operation` are skipped with a warning.
#'
#' Tables that no recognised step produces are drawn as light-blue boxes
#' (sources). Targets of `AGGREGATE`/`MERGE` steps are drawn as grey ellipses,
#' and targets of `SELECT` steps as yellow diamonds. An empty plan produces an
#' empty graph.
#'
#' @param join_plan A `data.table` created by `create_join_plan()`, with at
#'   least the columns `step`, `operation`, `target` and `code`.
#' @return The object returned by `DiagrammeR::grViz()` (an htmlwidget), which
#'   can be printed to the RStudio Viewer pane.
#' @importFrom data.table as.data.table
#' @export
plot_join_plan <- function(join_plan) {
  if (!requireNamespace("DiagrammeR", quietly = TRUE)) {
    stop("The 'DiagrammeR' package is required. Please install it: install.packages('DiagrammeR')", call. = FALSE)
  }
  # inherits() is exactly what data.table::is.data.table() tests. Unlike a
  # bare is.data.table() call, it does not rely on that function having been
  # imported into the package namespace.
  if (!inherits(join_plan, "data.table") ||
      !all(c("step", "operation", "target", "code") %in% names(join_plan))) {
    stop("'join_plan' must be a data.table generated by create_join_plan().",
         call. = FALSE)
  }

  # Node type given to each step's target, keyed by operation.
  target_types <- c(
    AGGREGATE = "intermediate",
    MERGE     = "intermediate",
    SELECT    = "final"
  )
  # DOT default-attribute statement for each node type, in drawing order.
  type_styles <- c(
    source       = "node [shape=box, style=filled, fillcolor=lightblue]",
    intermediate = "node [shape=ellipse, style=filled, fillcolor=lightgray]",
    final        = "node [shape=diamond, style=filled, fillcolor=yellow]"
  )

  # Double-quote a DOT identifier. Backslashes and embedded quotes are
  # escaped so that an unusual table name cannot break the graph source.
  q_name <- function(x) {
    x <- gsub("\\", "\\\\", x, fixed = TRUE)
    paste0('"', gsub('"', '\\"', x, fixed = TRUE), '"')
  }

  # Characters allowed in a table name; backticks admit quoted names.
  ident <- "[A-Za-z0-9_.`]+"

  # Return the first capture group of `pattern` in `code` with surrounding
  # backticks removed, or character(0) when nothing matches. sub() would
  # instead return the unmodified input, which is how unparsable code used
  # to end up as a bogus node.
  first_capture <- function(code, pattern) {
    if (is.na(code)) {
      return(character(0))
    }
    m <- regmatches(code, regexec(pattern, code, perl = TRUE))[[1L]]
    if (length(m) < 2L) {
      return(character(0))
    }
    gsub("^`|`$", "", m[[2L]])
  }

  # Return the table subset on the right-hand side of `target <- input[...]`.
  extract_rhs_table <- function(code) {
    first_capture(
      code,
      paste0("^\\s*", ident, "\\s*<-\\s*(", ident, ")\\s*\\[")
    )
  }

  # Return the tables passed as the named `x =` and `y =` arguments of
  # merge(). Each name must directly follow "(" or ",", so arguments such
  # as `all.x =`, `all.y =` or `by.x =` are not mistaken for them.
  extract_merge_xy <- function(code) {
    if (is.na(code)) {
      return(character(0))
    }
    call_start <- regexpr("merge\\s*\\(", code, perl = TRUE)
    if (call_start == -1L) {
      return(character(0))
    }
    # Keep the text from the opening parenthesis of the call onward.
    args <- substring(code, call_start + attr(call_start, "match.length") - 1L)
    c(
      first_capture(args, paste0("[(,]\\s*x\\s*=\\s*(", ident, ")")),
      first_capture(args, paste0("[(,]\\s*y\\s*=\\s*(", ident, ")"))
    )
  }

  n_steps <- nrow(join_plan)
  # Coerce to character so that factor columns cannot index by integer code.
  operations <- as.character(join_plan[["operation"]])
  targets <- as.character(join_plan[["target"]])
  codes <- as.character(join_plan[["code"]])

  node_types <- character(0)  # named: table name -> node type, first-seen order
  edge_defs <- vector("list", n_steps)

  for (i in seq_len(n_steps)) {
    operation <- operations[[i]]
    if (!(operation %in% names(target_types))) {
      warning(
        sprintf("Skipping step %s: unsupported operation '%s'.",
                join_plan[["step"]][[i]], operation),
        call. = FALSE
      )
      next
    }

    target <- targets[[i]]
    # If a table is targeted again (or was first seen as an input), its
    # latest role as a target wins.
    node_types[[target]] <- target_types[[operation]]

    inputs <- if (operation == "MERGE") {
      extract_merge_xy(codes[[i]])
    } else {
      extract_rhs_table(codes[[i]])
    }
    # Tables not seen before start as sources. A later step that targets
    # them overwrites this above.
    node_types[setdiff(inputs, names(node_types))] <- "source"

    if (length(inputs) > 0L) {
      # With splines=ortho, dot cannot place ordinary edge labels; Graphviz
      # recommends xlabel instead.
      edge_defs[[i]] <- sprintf(
        '  %s -> %s [xlabel="%s"];',
        q_name(inputs), q_name(target), operation
      )
    }
  }
  edge_defs <- unique(unlist(edge_defs, use.names = FALSE))

  # For each type present, emit one default-attribute statement followed by
  # its nodes. Every node is created here, before any edge refers to it.
  node_defs <- unlist(lapply(names(type_styles), function(type) {
    members <- names(node_types)[node_types == type]
    if (length(members) == 0L) {
      return(NULL)
    }
    c(
      paste0("  ", type_styles[[type]], ";"),
      paste0("  ", q_name(members), ";")
    )
  }), use.names = FALSE)

  dot_string <- paste(
    c(
      "digraph join_plan {",
      "  rankdir = LR; splines=ortho;",
      "  graph [compound=true];",
      "  node [fontname = Helvetica];",
      "  edge [fontname = Helvetica];",
      node_defs,
      edge_defs,
      "}"
    ),
    collapse = "\n"
  )

  DiagrammeR::grViz(dot_string)
}