% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/vae.r
\name{train_vae}
\alias{train_vae}
\title{Train VAE for CAR Prior}
\usage{
train_vae(
  W,
  GEOID,
  model_name,
  save_dir,
  n_samples = 10000,
  batch_size = 256,
  epoch = 10000,
  lr_init = 0.001,
  lr_min = 1e-07,
  verbose = TRUE,
  use_gpu = TRUE
)
}
\arguments{
\item{W}{Matrix. A proximity or adjacency matrix representing spatial relationships.}

\item{GEOID}{Character vector. Identifiers for spatial units (e.g., region or area codes).}

\item{model_name}{Character. Name to assign to the trained VAE model.}

\item{save_dir}{Character. Directory in which to save the trained VAE model and associated metadata. This argument has no default and must be supplied.}

\item{n_samples}{Integer. Number of samples to draw from the prior for training. Default is \code{10000}.}

\item{batch_size}{Integer. Batch size for VAE training. Default is \code{256}.}

\item{epoch}{Integer. Number of training epochs. Default is \code{10000}.}

\item{lr_init}{Numeric. Initial learning rate. Default is \code{0.001}.}

\item{lr_min}{Numeric. Minimum learning rate at the final epoch. Default is \code{1e-7}.}

\item{verbose}{Logical. If \code{TRUE}, prints training progress. Default is \code{TRUE}.}

\item{use_gpu}{Logical. If \code{TRUE}, uses a GPU when one is available. Default is \code{TRUE}.}
}
\value{
A named list containing:
\item{loss}{Total training loss.}
\item{RCL}{Reconstruction error.}
\item{KLD}{Kullback-Leibler divergence.}
}
\description{
Trains a Variational Autoencoder (VAE) to learn the spatial structure implied by the
Conditional Autoregressive (CAR) prior. The trained VAE parameters are saved and can
later be used as a generator within Hamiltonian Monte Carlo (HMC) sampling.
}
\details{
The function requires a Python environment configured through the \pkg{reticulate} interface;
the VAE training itself is implemented in Python. It calls \code{py$train_vae()}, which is defined
in the sourced Python modules (see \code{\link{load_environment}}).
}
\examples{
\dontrun{
library(vmsae)
library(sf)
# install_environment() is time-consuming on the first run
install_environment()
load_environment()

acs_data <- read_sf(system.file("example", "mo_county.shp", package = "vmsae"))
W <- readRDS(system.file("example", "W.Rds", package = "vmsae"))

loss <- train_vae(W = W,
  GEOID = acs_data$GEOID,
  model_name = "test",
  save_dir = tempdir(),
  n_samples = 1000, # set to larger values in practice, e.g. 10000.
  batch_size = 256,
  epoch = 1000)     # set to larger values in practice, e.g. 10000.
}

}
