---
title: "Introduction to fda.vi"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Introduction to fda.vi}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include = FALSE}
knitr::opts_chunk$set(
  collapse  = TRUE,
  comment   = "#>",
  cache     = FALSE,
  fig.width = 7,
  fig.height = 4,
  out.width = "100%"
)
```

## Overview

**fda.vi** implements the Bayesian basis function selection method of
da Cruz, de Souza, and Sousa (2024) for functional data analysis. The method
smooths one or several functional curves simultaneously while accounting for
within-curve correlation. Each curve is expressed as a linear combination of
basis functions (B-splines or Fourier), and a variational
expectation-maximization (VEM) algorithm, combined with sparsity-inducing
priors, keeps the basis functions each curve needs and discards the rest.

Key method features:

- **Automatic selection of the optimal set of basis functions** via posterior
  inclusion probabilities (PIPs) under a sparsity-inducing prior
- **Uncertainty quantification** via 95% credible bands, constructed from
  the variational posteriors of the basis coefficients and inclusion
  indicators
- **Correlated errors** modelled as a Gaussian process with an
  Ornstein-Uhlenbeck covariance function; the decay parameter $w$, which
  controls how quickly the correlation fades with distance, is estimated in
  the M-step of the VEM algorithm
- **Automatic $K$ selection** via generalized cross-validation (GCV) over
  a user-supplied grid of candidate basis sizes
- **Multiple basis types**: cubic B-splines and Fourier bases
- **Fast**: the VEM algorithm typically converges in tens of iterations

## Installation

```{r install, eval = FALSE}
# install.packages("devtools")
devtools::install_github("steviek16/fda.vi")
```

## The Model

Let $y_{ij}$ denote observation $j$ of curve $i$, for $i = 1, \ldots, m$
curves each measured at evaluation points $t_{ij}$, $j = 1, \ldots, n_i$.
Each curve is modelled as

$$
y_{ij} = g_i(t_{ij}) + \varepsilon_i(t_{ij}), \qquad
g_i(t_{ij}) = \sum_{k=1}^{K} Z_{ki}\,\beta_{ki}\,B_k(t_{ij})
$$

where $B_k(\cdot)$ are $K$ known basis functions (cubic B-splines or
Fourier), $\beta_{ki}$ are the basis coefficients, and $Z_{ki} \in \{0,1\}$
are inclusion indicators with independent Bernoulli (sparsity-inducing) priors.
The errors $\varepsilon_i(t)$ follow a zero-mean Gaussian process with
Ornstein-Uhlenbeck covariance function:

$$
\psi(t, t') = \sigma^2 \exp\!\left(-w\,|t - t'|\right)
$$

with decay parameter $w > 0$ and variance $\sigma^2 > 0$.
The variance parameter $\tau^2$ enters the prior on the basis coefficients
and controls how strongly they are regularized.
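
To make the covariance concrete, the sketch below evaluates $\psi(t, t')$ on
50 equally spaced points using base R only. The helper `ou_cov()` is
hypothetical (it is not part of fda.vi), and $\sigma = 0.1$, $w = 6$ are
illustrative values; the last two lines compare the correlation at the same
distance for two decay values.

```r
# Hypothetical helper: Ornstein-Uhlenbeck covariance matrix
# psi(t, t') = sigma2 * exp(-w * |t - t'|) on a grid of points tt
ou_cov <- function(tt, sigma2, w) {
  sigma2 * exp(-w * abs(outer(tt, tt, "-")))
}

tt  <- seq(0, 1, length.out = 50)         # 50 equally spaced points on [0, 1]
Psi <- ou_cov(tt, sigma2 = 0.1^2, w = 6)  # illustrative sigma = 0.1, w = 6

# Correlation between errors 0.1 apart: a larger w decorrelates faster
exp(-6 * 0.1)   # w = 6:  about 0.549
exp(-20 * 0.1)  # w = 20: about 0.135
```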

The VEM algorithm iterates between:

- **E-step**: updating the variational posteriors for $\beta_{ki}$,
  $Z_{ki}$, $\sigma^2$, and $\tau^2$ via coordinate ascent variational
  inference (CAVI)
- **M-step**: maximizing the ELBO over the decay parameter $w$ using
  L-BFGS-B, while holding the variational distributions updated in the
  E-step fixed

until convergence or until the maximum number of iterations is reached.
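
The M-step can be illustrated with a deliberately simplified stand-in: below,
$w$ is chosen by maximizing the Gaussian log-likelihood of fixed residual
curves under the Ornstein-Uhlenbeck covariance, with $\sigma^2$ held at an
assumed known value and a single $w$ shared by all curves, using
`optim(method = "L-BFGS-B")`. The actual ELBO also involves expectations under
the variational posteriors, so this sketches the idea rather than the
package's objective; `ou_cov()` and `neg_loglik_w()` are hypothetical helpers.

```r
# Hypothetical helper: OU covariance on the grid tt
ou_cov <- function(tt, sigma2, w) sigma2 * exp(-w * abs(outer(tt, tt, "-")))

# Simplified M-step objective (NOT the package's ELBO): negative Gaussian
# log-likelihood, up to a constant, of residual matrix R (one column per
# curve), with sigma2 fixed and one decay w shared by all curves
neg_loglik_w <- function(w, R, tt, sigma2) {
  U <- chol(ou_cov(tt, sigma2, w))          # Psi = t(U) %*% U
  z <- backsolve(U, R, transpose = TRUE)    # solves t(U) %*% z = R
  ncol(R) * sum(log(diag(U))) + 0.5 * sum(z^2)
}

# Simulate three residual curves with OU(w = 6) correlation, sigma = 0.1
set.seed(1)
tt     <- seq(0, 1, length.out = 200)
sigma2 <- 0.1^2
U_true <- chol(ou_cov(tt, sigma2, w = 6))
R      <- crossprod(U_true, matrix(rnorm(200 * 3), 200, 3))  # t(U) %*% noise

# Bounded quasi-Newton maximization over w
opt <- optim(par = 1, fn = neg_loglik_w, R = R, tt = tt, sigma2 = sigma2,
             method = "L-BFGS-B", lower = 0.1, upper = 100)
opt$par  # estimated decay; typically close to the true w = 6
```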

## Quick Start

```{r quickstart}
library(fda.vi)

data(toy_curves)

# Fit at a single K
fit <- vem_fit(
  y      = toy_curves$y,
  Xt     = toy_curves$Xt,
  K      = 8,
  center = FALSE,
  scale  = FALSE
)

summary(fit)
```

The `toy_curves` dataset contains three simulated curves generated from
$K = 8$ cubic B-spline basis functions with known coefficients (basis
functions 2 and 5 are irrelevant: their coefficients are set to zero) and
Ornstein-Uhlenbeck correlated errors with $\sigma = 0.1$ (so
$\sigma^2 = 0.01$) and $w = 6$.
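
A sketch of how data with this structure could be simulated, using
`splines::bs()` for the basis and a Cholesky factor of the
Ornstein-Uhlenbeck covariance for the errors. This is an assumption about the
generating process, not the code used to build `toy_curves`: the knot
placement (the `bs()` default, which gives 4 equally spaced interior knots on
this grid), the seed, and sharing `true_coef` across all three curves are
guesses. A zero coefficient plays the role of $Z_{ki} = 0$.

```r
library(splines)

set.seed(123)                                   # arbitrary seed (assumption)
tt   <- seq(0, 1, length.out = 50)
beta <- c(1.5, 0, -1, 0.8, 0, -0.5, 1.2, -0.9)  # true_coef; B2 and B5 inactive

# K = 8 cubic B-splines on [0, 1]; bs() default knot placement (assumption)
B <- bs(tt, df = 8, degree = 3, intercept = TRUE)
g <- drop(B %*% beta)                           # noise-free mean curve

# OU errors with sigma = 0.1 and w = 6, independent across curves
Psi   <- 0.1^2 * exp(-6 * abs(outer(tt, tt, "-")))
U     <- chol(Psi)                              # Psi = t(U) %*% U
y_sim <- lapply(1:3, function(i) g + drop(crossprod(U, rnorm(50))))
```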

## The `toy_curves` Dataset

```{r data}
data(toy_curves)
str(toy_curves)
```

The dataset is a named list with three elements:

- `y`: a list of 3 numeric vectors, each of length 50, containing the
  observed noisy curve values
- `Xt`: a numeric vector of 50 equally spaced evaluation points on $[0, 1]$
- `true_coef`: the true basis coefficients used to generate the data,
  `c(1.5, 0, -1, 0.8, 0, -0.5, 1.2, -0.9)`; basis functions 2 and 5 are
  not relevant, so their coefficients are zero

```{r plot-toy}
plot(toy_curves$Xt, toy_curves$y[[1]],
     type = "p", pch = 16, cex = 0.6, col = "steelblue",
     xlab = "t", ylab = "y(t)", main = "Toy Curves Dataset")
for (i in 2:3) {
  points(toy_curves$Xt, toy_curves$y[[i]],
         pch = 16, cex = 0.6,
         col = c("firebrick", "forestgreen")[i - 1])
}
legend("topright", legend = paste("Curve", 1:3),
       col = c("steelblue", "firebrick", "forestgreen"),
       pch = 16, bty = "n")
```

## Fitting a Model

### Single $K$

When a single integer is passed to `K`, `vem_fit` fits the model directly
at that basis size without GCV tuning:

```{r single-k}
fit <- vem_fit(
  y      = toy_curves$y,
  Xt     = toy_curves$Xt,
  K      = 8,
  center = FALSE,
  scale  = FALSE
)
```

### Automatic $K$ Selection via GCV

When a vector of candidate values is passed to `K`, `vem_fit` fits the
model at each candidate and selects the $K$ minimizing the mean
generalized cross-validation (GCV) score across all curves:

```{r gcv-k}
fit_gcv <- vem_fit(
  y    = toy_curves$y,
  Xt   = toy_curves$Xt,
  K    = c(6, 8, 10, 15)
)

fit_gcv$best_K
fit_gcv$tuning$gcv_matrix
```
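
For intuition about the selection criterion, the sketch below scores each
candidate $K$ with the classical GCV formula for a linear smoother,
$\mathrm{GCV}(K) = n\,\mathrm{RSS} / (n - \mathrm{df})^2$, using an ordinary
least-squares B-spline fit (for which df equals $K$) as a stand-in for the VEM
fit, and then picks the $K$ with the smallest mean score across curves. How
fda.vi computes residuals and effective degrees of freedom inside its own GCV
score is not shown here, so `gcv_bspline()` is a hypothetical helper.

```r
library(splines)

# Hypothetical helper: GCV of a least-squares fit on K cubic B-splines.
# For least squares, trace(hat matrix) = rank of the design matrix.
gcv_bspline <- function(y, tt, K) {
  fit <- lm.fit(bs(tt, df = K, intercept = TRUE), y)
  n   <- length(y)
  n * sum(fit$residuals^2) / (n - fit$rank)^2
}

# Three simulated curves sharing one mean (iid noise here, for brevity)
set.seed(42)
tt <- seq(0, 1, length.out = 50)
g  <- drop(bs(tt, df = 8, intercept = TRUE) %*%
             c(1.5, 0, -1, 0.8, 0, -0.5, 1.2, -0.9))
y  <- lapply(1:3, function(i) g + rnorm(50, sd = 0.1))

K_grid <- c(6, 8, 10, 15)
# rows = curves, columns = candidate K
gcv <- sapply(K_grid, function(K) sapply(y, gcv_bspline, tt = tt, K = K))
colnames(gcv) <- K_grid
best_K <- K_grid[which.min(colMeans(gcv))]  # smallest mean GCV across curves
```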

### Per-Curve $K$ Selection

Setting `selection_metric = "per_curve"` selects the best $K$ independently
for each curve and returns a composite fit that combines, for each curve, the
results obtained at that curve's own best $K$:

```{r per-curve}
fit_pc <- vem_fit(
  y                = toy_curves$y,
  Xt               = toy_curves$Xt,
  K                = c(6, 8, 10),
  selection_metric = "per_curve"
)

fit_pc$selected_K
fit_pc$is_composite
```
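
Per-curve selection replaces the mean over curves with a row-wise minimum. In
this sketch (the same hypothetical least-squares GCV helper, applied to two
simulated curves of different smoothness) each curve gets its own $K$; which
values are chosen depends on the simulated noise.

```r
library(splines)

# Hypothetical helper: GCV of a least-squares fit on K cubic B-splines
gcv_bspline <- function(y, tt, K) {
  fit <- lm.fit(bs(tt, df = K, intercept = TRUE), y)
  n   <- length(y)
  n * sum(fit$residuals^2) / (n - fit$rank)^2
}

set.seed(7)
tt <- seq(0, 1, length.out = 50)
y  <- list(sin(2 * pi * tt) + rnorm(50, sd = 0.1),   # smooth curve
           sin(6 * pi * tt) + rnorm(50, sd = 0.1))   # wigglier curve

K_grid <- c(6, 8, 10)
gcv <- sapply(K_grid, function(K) sapply(y, gcv_bspline, tt = tt, K = K))
selected_K <- K_grid[apply(gcv, 1, which.min)]       # one K per curve
```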

### Fourier Basis

For periodic functional data, a Fourier basis can be used by setting
`basis_type = "fourier"` (`toy_curves` is not periodic; it is used here only
to illustrate the syntax):

```{r fourier}
fit_f <- vem_fit(
  y          = toy_curves$y,
  Xt         = toy_curves$Xt,
  K          = 10,
  basis_type = "fourier"
)

summary(fit_f)
```
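
As a sketch of what a Fourier basis looks like, the function below builds the
usual design matrix over one period: a constant column followed by sine/cosine
pairs of increasing frequency. The helper `fourier_basis()` is hypothetical
and assumes an odd number of columns; how fda.vi orders its Fourier columns,
or handles an even `K` such as 10, is not covered here. On an equally spaced
grid spanning exactly one period the columns are mutually orthogonal, which
the last line checks numerically.

```r
# Hypothetical helper: constant column + (K - 1) / 2 sine/cosine pairs
fourier_basis <- function(tt, K, period = 1) {
  stopifnot(K %% 2 == 1)                       # odd K assumed in this sketch
  X <- matrix(1, nrow = length(tt), ncol = K)  # column 1: constant
  for (j in seq_len((K - 1) / 2)) {
    X[, 2 * j]     <- sin(2 * pi * j * tt / period)
    X[, 2 * j + 1] <- cos(2 * pi * j * tt / period)
  }
  X
}

tt <- (0:99) / 100            # equally spaced over one full period
X  <- fourier_basis(tt, K = 9)
G  <- crossprod(X)            # t(X) %*% X
max(abs(G[upper.tri(G)]))     # off-diagonal entries are zero up to rounding
```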

## Interpreting the Output

### Summary

```{r summary}
summary(fit)
```

The summary reports:

- **Basis Type** and **K**: the basis type and the number of basis functions
  used in the fit
- **Active Bases (per curve)**: the number of basis functions with
  $\hat{p}_{ki} > 0.5$ for each curve
- **Point estimate for decay parameter ($w$)**: larger $w$ implies
  shorter-range correlation (errors decorrelate faster)
- **Posterior $q(\sigma^2) \sim \mathrm{IG}(\delta_1^*, \delta_2^*)$**: the
  shape ($\delta_1^*$) and scale ($\delta_2^*$) of the variational
  Inverse-Gamma posterior for the error variance
- **Posterior $q(\tau^2) \sim \mathrm{IG}(\lambda_1^*, \lambda_2^*)$**: the
  shape ($\lambda_1^*$) and scale ($\lambda_2^*$) of the variational
  Inverse-Gamma posterior for the regularization parameter
- **GCV Tuning Results**: the mean GCV score at each candidate $K$
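
The summary reports Inverse-Gamma parameters rather than point estimates.
Assuming the usual shape/scale parametrization (density proportional to
$x^{-a-1} e^{-b/x}$), the posterior mean is $b/(a-1)$ for $a > 1$, the mode is
$b/(a+1)$, and quantiles follow from $1/X \sim \mathrm{Gamma}(a,
\text{rate} = b)$. The helpers below are hypothetical and `a`, `b` are
illustrative values; substitute the $\delta^*$ or $\lambda^*$ values printed
by `summary()`.

```r
# Hypothetical helpers for an Inverse-Gamma(shape = a, scale = b) posterior
ig_mean <- function(a, b) if (a > 1) b / (a - 1) else Inf
ig_mode <- function(a, b) b / (a + 1)
# If X ~ IG(a, b) then 1 / X ~ Gamma(shape = a, rate = b), so the
# p-quantile of X is 1 / (the (1 - p)-quantile of that Gamma)
ig_quantile <- function(p, a, b) 1 / qgamma(1 - p, shape = a, rate = b)

a <- 30; b <- 0.3                    # illustrative values only
ig_mean(a, b)                        # 0.3 / 29, about 0.0103
ig_quantile(c(0.025, 0.975), a, b)   # central 95% interval
```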

### Coefficient Matrix

```{r coef}
coef(fit)
```

`coef()` returns a $K \times m$ matrix of estimated basis coefficients, one
column per curve. Basis functions with PIP $\leq 0.5$ are treated as inactive
and their coefficients are reported as zero. For `toy_curves`, the true zeros
at positions 2 and 5 are expected to be recovered:

```{r coef-check}
coefs <- coef(fit)
coefs[c(2, 5), ]  # should be zero
```
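
The thresholding rule can be sketched in base R. The matrices below are
made-up numbers, not output from `vem_fit()`, and the sketch assumes the
reported coefficient is the coefficient estimate masked by the same
PIP $> 0.5$ rule that the summary uses; the mask also yields the
**Active Bases** count per curve.

```r
# Made-up posterior coefficient estimates and PIPs: K = 4 bases, m = 2 curves
beta_hat <- matrix(c(1.2,  0.04, -0.9, 0.6,
                     1.1, -0.02, -1.0, 0.5), nrow = 4)
pip      <- matrix(c(0.99, 0.08, 0.97, 0.88,
                     0.98, 0.11, 0.99, 0.45), nrow = 4)

active   <- pip > 0.5            # inclusion rule: PIP above 0.5
coef_thr <- beta_hat * active    # inactive coefficients reported as 0
colSums(active)                  # active bases per curve: 3 and 2
```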

### Posterior Inclusion Probabilities

The posterior inclusion probabilities (PIPs) for all basis functions
across all curves are stored in `fit$model$prob` and can be inspected
directly:

```{r pips}
K  <- fit$best_K
m  <- length(toy_curves$y)
pip_mat <- matrix(fit$model$prob, nrow = K, ncol = m)
rownames(pip_mat) <- paste0("B", 1:K)
colnames(pip_mat) <- paste0("Curve_", 1:m)
round(pip_mat, 3)
```

### Predictions

`predict()` evaluates the fitted curves of a `vem_fit` object. It returns a
list of length $m$ (one numeric vector per curve), each with one value per
evaluation point: the original points `Xt` by default, or the grid passed as
`newdata`.

```{r predict}
# Predictions at original evaluation points
preds <- predict(fit)
length(preds)       # one vector per curve
length(preds[[1]])  # same length as Xt

# Predictions at a denser grid
Xt_new    <- seq(0, 1, length.out = 200)
preds_new <- predict(fit, newdata = Xt_new)
```
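
Conceptually, predicting at new points means evaluating the same basis
functions (same knots) at the new locations and multiplying by the estimated
coefficients. The sketch below illustrates this with `splines::bs()` and its
`predict()` method, using plain least-squares coefficients on a simulated
curve as a stand-in for the VEM estimates; it is not necessarily how fda.vi's
`predict()` method works internally.

```r
library(splines)

set.seed(1)
Xt <- seq(0, 1, length.out = 50)
y  <- sin(2 * pi * Xt) + rnorm(50, sd = 0.1)   # simulated curve

B    <- bs(Xt, df = 8, intercept = TRUE)       # basis at the observed points
beta <- qr.solve(B, y)                         # least-squares stand-in

Xt_new <- seq(0, 1, length.out = 200)
B_new  <- predict(B, Xt_new)                   # same knots, new points
y_new  <- drop(B_new %*% beta)                 # curve on the dense grid
```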

### Plot

`plot()` draws the estimated curve selected by `curve_idx` together with its
data. The shaded region is a 95% credible band; use `show_CI = FALSE` to
suppress it.

```{r plot}
# Fitted curve with 95% credible band for curve 1
plot(fit, curve_idx = 1)
```

```{r plot-all}
# All three curves
for (i in 1:3) plot(fit, curve_idx = i)
```
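
One way to build a pointwise band from the posteriors of the coefficients and
inclusion indicators is Monte Carlo: draw $Z_k$ and $\beta_k$, form the curve,
and take pointwise 2.5% and 97.5% quantiles. The sketch below assumes a
mean-field posterior with independent $Z_k \sim \mathrm{Bernoulli}(p_k)$ and
$\beta_k \sim \mathrm{N}(\mu_k, s_k^2)$, and uses made-up values for $p_k$,
$\mu_k$ and $s_k$; fda.vi may construct its band differently.

```r
library(splines)

# Made-up mean-field posterior for one curve with K = 8 (assumption)
p  <- c(0.99, 0.05, 0.98, 0.95, 0.04, 0.90, 0.99, 0.97)  # PIPs
mu <- c(1.5, 0, -1, 0.8, 0, -0.5, 1.2, -0.9)            # means of beta_k
s  <- rep(0.05, 8)                                       # sds of beta_k

tt <- seq(0, 1, length.out = 100)
B  <- bs(tt, df = 8, intercept = TRUE)

set.seed(2024)
draws <- replicate(2000, {
  z    <- rbinom(8, 1, p)          # inclusion indicators
  beta <- rnorm(8, mu, s)          # coefficients
  drop(B %*% (z * beta))           # one posterior draw of the curve
})                                 # 100 x 2000 matrix
band <- apply(draws, 1, quantile, probs = c(0.025, 0.975))  # 2 x 100
```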

## Reference

da Cruz, A. C., de Souza, C. P. E., & Sousa, P. H. T. O. (2024).
Fast Bayesian basis selection for functional data representation with
correlated errors. *arXiv:2405.20758*.
<https://arxiv.org/abs/2405.20758>

## Citation

```{r citation}
citation("fda.vi")
```

```bibtex
@misc{dacruz2024vem,
  title  = {Fast {Bayesian} basis selection for functional data
             representation with correlated errors},
  author = {da Cruz, Ana Carolina and de Souza, Camila P. E. and
             Sousa, Pedro H. T. O.},
  year   = {2024},
  note   = {arXiv:2405.20758},
  url    = {https://arxiv.org/abs/2405.20758}
}
```
