---
title: "Cross-Validation and Tuning"
format: html
vignette: >
  %\VignetteIndexEntry{Cross-Validation and Tuning}
  %\VignetteEngine{quarto::html}
  %\VignetteEncoding{UTF-8}
---

```{r}
#| label: setup
#| message: false
library(tidyeof)
library(stars)
library(dplyr)
library(ggplot2)
```

## Overview

EOF and CCA analyses require choosing how many modes to retain (`k`). Too few modes discard useful signal; too many introduce noise. `tidyeof` provides two cross-validation functions for data-driven mode selection:

- `tune_eof()` -- optimize EOF truncation for field reconstruction
- `tune_cca()` -- jointly optimize predictor EOFs, response EOFs, and CCA modes for downscaling

Both use temporal cross-validation: the time series is split into folds, patterns are estimated on training folds, and reconstruction or prediction skill is evaluated on held-out folds.

## Setup

```{r}
#| label: load-data
fine <- system.file("testdata/prism_test.RDS", package = "tidyeof") |>
  readRDS()

coarse <- fine |>
  st_warp(cellsize = 0.2, method = "average", use_gdal = TRUE, no_data_value = -99999) |>
  setNames(names(fine)) |>
  st_set_dimensions("band",
    values = st_get_dimension_values(fine, "time"),
    names = "time")
```

## Tuning EOF truncation

`tune_eof()` evaluates reconstruction skill across a range of `k` values. For each held-out fold it uses a *speckled holdout*: a random scatter of grid cells is hidden, the mode amplitudes are estimated from the visible cells only, and the hidden cells are then predicted. Because the hidden cells play no part in estimating the amplitudes, the prediction is genuinely out-of-sample. Once `k` exceeds the number of modes the data actually support, the extra modes only fit noise in the visible cells, so skill stops improving and can get worse. This is what gives the RMSE curve a true minimum.

Simply projecting the whole held-out field onto the EOFs would not work: the projection is least-squares optimal for the very data being scored, so error would fall monotonically with `k` and the "best" `k` would always be the largest. The `hidden_fraction` and `n_reps` arguments control the size and number of the random masks.
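The difference between the two scoring schemes can be seen in a small base-R sketch on a synthetic field. It does not call `tidyeof` and is not its implementation: the 20% hidden fraction, the single mask per held-out time step, and the `qr.solve()` least-squares fit are our own assumptions. Row (a) projects the whole held-out field; row (b) fits amplitudes on visible cells and scores only the hidden ones.

```{r}
#| label: sketch-speckled-holdout
set.seed(42)
n_time <- 60; n_cell <- 200; true_k <- 3

# Synthetic field (rows = time, columns = cells): 3 orthonormal patterns plus noise
patterns <- qr.Q(qr(matrix(rnorm(n_cell * true_k), n_cell, true_k)))
amps <- matrix(rnorm(n_time * true_k), n_time, true_k) %*% diag(c(10, 6, 3))
X <- amps %*% t(patterns) + matrix(rnorm(n_time * n_cell, sd = 0.5), n_time, n_cell)

# EOFs from a training block; the last 20 time steps are held out
train <- 1:40; test <- 41:60
mu <- colMeans(X[train, ])
eofs <- svd(sweep(X[train, ], 2, mu), nu = 0, nv = 8)$v
anom_test <- sweep(X[test, ], 2, mu)

# One random mask per held-out time step (20% of cells hidden), reused for every k
masks <- lapply(seq_along(test), function(i) sample(n_cell, round(0.2 * n_cell)))
ks <- 1:8

# (a) Project the whole held-out field and score it: error can only fall with k
full_rmse <- sapply(ks, function(k) {
  E <- eofs[, 1:k, drop = FALSE]
  sqrt(mean((anom_test - anom_test %*% E %*% t(E))^2))
})

# (b) Speckled holdout: fit amplitudes on visible cells, score hidden cells only
speckled_rmse <- sapply(ks, function(k) {
  E <- eofs[, 1:k, drop = FALSE]
  errs <- unlist(lapply(seq_along(test), function(i) {
    hidden <- masks[[i]]
    visible <- setdiff(seq_len(n_cell), hidden)
    a <- qr.solve(E[visible, , drop = FALSE], anom_test[i, visible])  # least squares
    anom_test[i, hidden] - drop(E[hidden, , drop = FALSE] %*% a)
  }))
  sqrt(mean(errs^2))
})

round(rbind(k = ks, full = full_rmse, speckled = speckled_rmse), 3)
```

The `full` row is non-increasing in `k` by construction, because each added EOF enlarges the projection subspace. The `speckled` row should flatten once `k` reaches the three modes built into the field, since further modes can only fit noise in the visible cells.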

```{r}
#| label: tune-eof
eof_results <- tune_eof(fine, k = 1:8, kfolds = 3)
eof_results
```

Three metrics are computed by default, all on the hidden cells:

- **RMSE** -- root mean square error (lower is better)
- **cor_spatial** -- mean spatial *anomaly* correlation per time step, i.e. computed after removing each cell's temporal mean so it reflects pattern skill rather than the static climatology (higher is better)
- **cor_temporal** -- mean temporal correlation per grid cell (higher is better)
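The three metrics can be sketched as a hypothetical helper `cv_metrics()` over time-by-cell matrices. This is not `tidyeof`'s code: it scores every cell rather than only the hidden ones, and it assumes the spatial anomaly correlation subtracts each cell's observed temporal mean from both fields.

```{r}
#| label: sketch-cv-metrics
# Hypothetical helper; obs and pred are matrices with rows = time, columns = cells
cv_metrics <- function(obs, pred) {
  clim <- colMeans(obs)                  # each cell's temporal mean (assumed reference)
  obs_anom <- sweep(obs, 2, clim)
  pred_anom <- sweep(pred, 2, clim)
  # Spatial anomaly correlation: correlate across cells at each time step, then average
  cor_spatial <- mean(sapply(seq_len(nrow(obs)), function(i)
    cor(obs_anom[i, ], pred_anom[i, ])))
  # Temporal correlation: correlate across time at each cell, then average
  cor_temporal <- mean(sapply(seq_len(ncol(obs)), function(j)
    cor(obs[, j], pred[, j])))
  c(rmse = sqrt(mean((obs - pred)^2)),
    cor_spatial = cor_spatial, cor_temporal = cor_temporal)
}

set.seed(1)
obs <- matrix(rnorm(20 * 50), 20, 50)
shifted <- sweep(obs, 2, runif(50, 0, 5), "+")   # add a constant offset at each cell
rbind(perfect = cv_metrics(obs, obs), shifted = cv_metrics(obs, shifted))
```

The shifted prediction keeps a temporal correlation of 1 despite a nonzero RMSE, because a per-cell correlation ignores a constant offset at that cell.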

Summarize across folds to find the optimal `k`:

```{r}
#| label: summarize-eof
eof_summary <- summarize_eof_cv(eof_results, metric = "rmse")
eof_summary
```

```{r}
#| label: plot-eof-cv
#| fig-width: 6
#| fig-height: 4
eof_results |>
  group_by(k) |>
  summarize(
    rmse_mean = mean(rmse),
    rmse_se = sd(rmse) / sqrt(n()),
    .groups = "drop"
  ) |>
  ggplot(aes(k, rmse_mean)) +
  geom_ribbon(aes(ymin = rmse_mean - rmse_se, ymax = rmse_mean + rmse_se), alpha = 0.2) +
  geom_line() +
  geom_point() +
  labs(x = "Number of EOFs (k)", y = "RMSE", title = "EOF reconstruction skill") +
  theme_bw()
```

## Tuning CCA downscaling

`tune_cca()` jointly searches over three dimensions:

- `k_pred` -- number of predictor EOFs
- `k_resp` -- number of response EOFs
- `k_cca` -- number of CCA modes (optional; defaults to `min(k_pred, k_resp)`)
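How the three settings interact can be sketched with one common formulation of EOF-truncated CCA prediction (often called Barnett-Preisendorfer CCA). Each field is truncated to its leading EOFs, CCA is run between the two sets of principal components, and only the leading `k_cca` canonical pairs are kept when predicting. The helper `cca_predict()` and the synthetic fields are hypothetical, and `tidyeof`'s internals (for example, how `weight = TRUE` weights grid cells) may differ.

```{r}
#| label: sketch-cca-predict
# Hypothetical helper: EOF-truncated CCA prediction (rows = time, columns = cells)
cca_predict <- function(X_train, Y_train, X_new, k_pred, k_resp,
                        k_cca = min(k_pred, k_resp)) {
  stopifnot(k_cca <= min(k_pred, k_resp))
  mx <- colMeans(X_train); my <- colMeans(Y_train)
  Xa <- sweep(X_train, 2, mx); Ya <- sweep(Y_train, 2, my)
  # Leading EOFs of each anomaly field (columns = spatial patterns)
  Ex <- svd(Xa, nu = 0, nv = k_pred)$v
  Ey <- svd(Ya, nu = 0, nv = k_resp)$v
  # CCA between predictor PCs and response PCs
  cc <- cancor(Xa %*% Ex, Ya %*% Ey)
  # Canonical variates of the new predictor data, leading k_cca pairs only
  Px_new <- sweep(X_new, 2, mx) %*% Ex
  u_new <- sweep(Px_new, 2, cc$xcenter) %*% cc$xcoef[, seq_len(k_cca), drop = FALSE]
  # Predict each response variate as canonical correlation * predictor variate
  v_new <- sweep(u_new, 2, cc$cor[seq_len(k_cca)], "*")
  # Back to response PCs, then to the response field
  Py_new <- v_new %*% solve(cc$ycoef)[seq_len(k_cca), , drop = FALSE]
  Py_new <- sweep(Py_new, 2, cc$ycenter, "+")
  sweep(Py_new %*% t(Ey), 2, my, "+")
}

# Synthetic "coarse" (30 cells) and "fine" (60 cells) fields sharing two signals
set.seed(7)
n <- 80; p <- 30; q <- 60
latent <- matrix(rnorm(n * 2), n, 2)
X <- latent %*% matrix(rnorm(2 * p), 2, p) + matrix(rnorm(n * p, sd = 0.3), n, p)
Y <- latent %*% matrix(rnorm(2 * q), 2, q) + matrix(rnorm(n * q, sd = 0.3), n, q)
train <- 1:60; test <- 61:80

cca_pred <- cca_predict(X[train, ], Y[train, ], X[test, ],
                        k_pred = 3, k_resp = 3, k_cca = 2)
cca_rmse <- sqrt(mean((Y[test, ] - cca_pred)^2))
clim_rmse <- sqrt(mean(sweep(Y[test, ], 2, colMeans(Y[train, ]))^2))
c(cca = cca_rmse, climatology = clim_rmse)
```

In this formulation, setting `k_cca` below `min(k_pred, k_resp)` simply drops the weakest canonical pairs from the prediction.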

### Preparing folds

First, prepare cross-validation folds using `prep_cv_folds()`. This pre-computes EOF patterns at the maximum requested `k` for each fold, so the grid search over truncation is fast (it just subsets the pre-computed patterns).

```{r}
#| label: prep-folds
cv_folds <- prep_cv_folds(
  coarse, fine,
  kfolds = 3,
  max_k_pred = 5,
  max_k_resp = 5,
  weight = TRUE
)

cv_folds
```
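Our reading of why computing once at the maximum `k` is enough: unrotated EOFs are nested, so the leading `k` patterns from one decomposition span the same subspace as a decomposition truncated at `k` directly (individual patterns may differ in sign). A quick base-R check on random data, independent of `tidyeof`:

```{r}
#| label: sketch-nested-eofs
set.seed(3)
Xa <- scale(matrix(rnorm(50 * 40), 50, 40), scale = FALSE)   # centered anomalies
max_k <- 5
E_max <- svd(Xa, nu = 0, nv = max_k)$v    # decompose once at the largest k

nested <- sapply(seq_len(max_k), function(k) {
  E_sub <- E_max[, seq_len(k), drop = FALSE]   # subset the cached patterns
  E_k <- svd(Xa, nu = 0, nv = k)$v             # recompute directly at this k
  # Compare projection matrices so arbitrary sign flips of EOFs don't matter
  max(abs(tcrossprod(E_sub) - tcrossprod(E_k)))
})
signif(nested, 2)
```

Every entry is at floating-point noise level, so subsetting the cached patterns is equivalent to refitting at each `k`.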

### Grid search

```{r}
#| label: tune-cca
cca_results <- tune_cca(
  cv_folds,
  k_pred = 1:5,
  k_resp = 1:5
)

cca_results
```

Summarize to find the best parameter combination:

```{r}
#| label: summarize-cca
cca_summary <- summarize_cv(cca_results, metric = "rmse")
head(cca_summary)
```

```{r}
#| label: plot-cca-cv
#| fig-width: 7
#| fig-height: 5
cca_results |>
  group_by(k_pred, k_resp) |>
  summarize(rmse = mean(rmse), .groups = "drop") |>
  ggplot(aes(k_pred, k_resp, fill = rmse)) +
  geom_tile() +
  scale_fill_viridis_c(direction = -1) +
  labs(x = "Predictor EOFs", y = "Response EOFs", fill = "RMSE",
       title = "CCA downscaling skill") +
  theme_bw()
```

### Tuning k_cca separately

Using fewer CCA modes than `min(k_pred, k_resp)` acts as additional regularization. To search over `k_cca` explicitly:

```{r}
#| label: tune-kcca
cca_results_3d <- tune_cca(
  cv_folds,
  k_pred = 2:4,
  k_resp = 2:4,
  k_cca = 1:3
)

cca_results_3d |>
  group_by(k_pred, k_resp, k_cca) |>
  summarize(rmse = mean(rmse), .groups = "drop") |>
  arrange(rmse) |>
  head(10)
```

## Fold construction

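`prep_folds()` creates balanced temporal folds for cross-validation. It assigns contiguous blocks of time steps to folds rather than sampling time steps at random. This suits autocorrelated time series: with random sampling, most held-out steps would have highly correlated neighbors in the training set, which inflates the apparent skill.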

```{r}
#| label: prep-folds-demo
times <- st_get_dimension_values(fine, "time")
folds <- prep_folds(times, kfolds = 5)

# Each fold is a vector of held-out times
lengths(folds)
```
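The idea behind contiguous folds can be sketched with a hypothetical `block_folds()` helper working on time indices. We have not checked how `prep_folds()` places its boundaries, so the rounding rule below is an assumption.

```{r}
#| label: sketch-block-folds
# Hypothetical helper: split time indices 1..n_time into contiguous, balanced blocks
block_folds <- function(n_time, kfolds) {
  bounds <- round(seq(0, n_time, length.out = kfolds + 1))   # fold boundaries
  fold_id <- rep(seq_len(kfolds), times = diff(bounds))      # 1,1,...,2,2,...
  split(seq_len(n_time), fold_id)
}

folds_demo <- block_folds(23, kfolds = 5)
lengths(folds_demo)   # 5 4 5 4 5

# Training set for fold 2: every time step outside its block
setdiff(seq_len(23), folds_demo[[2]])
```

Fold sizes differ by at most one, and each held-out block is a run of consecutive time steps.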

## Parallel execution

For large grids or many parameter combinations, `tune_cca()` supports parallel execution via `furrr`:

```{r}
#| label: parallel
#| eval: false
library(furrr)
plan(multisession, workers = 4)

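# The pre-computed folds must cover the largest k being searched
cv_folds_10 <- prep_cv_folds(coarse, fine, kfolds = 3,
                             max_k_pred = 10, max_k_resp = 10, weight = TRUE)
results <- tune_cca(cv_folds_10, k_pred = 1:10, k_resp = 1:10, parallel = TRUE)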

plan(sequential) # clean up
```

## Best practices

- **Start coarse, then refine.** Begin with a wide search (e.g., `k_pred = 1:10`) with few folds, then narrow the range with more folds.
- **Use 3-5 folds.** With short time series (fewer than 50 time steps), 3 folds leave enough data in each training set. With longer series, 5 folds give more stable estimates.
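- **Rotation and CV don't mix.** Varimax rotation mixes the retained modes: the rotated patterns are no longer ordered by explained variance, and the rotated solution changes with the number of modes rotated. Subsetting `pat[1:k]` from a rotated set therefore no longer gives the best rank-k approximation. The CV functions require `rotate = FALSE` and will error if rotation was applied.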
- **Inspect multiple metrics.** RMSE penalizes large errors; spatial anomaly correlation rewards getting the pattern right at each time step; temporal correlation rewards tracking each cell's variations through time and ignores a constant offset at that cell. The best `k` may differ across metrics, so choose based on your application.
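The rotation caveat can be checked with base R's `stats::varimax()` on a synthetic field. This does not use `tidyeof`'s own rotation code. Varimax applies an orthogonal rotation to the leading three EOFs, so all three rotated patterns span the same space. Their first two, however, no longer span the leading two-dimensional subspace, so reconstructing with them cannot beat the first two unrotated EOFs.

```{r}
#| label: sketch-rotation-nesting
set.seed(11)
n_time <- 60; n_cell <- 40
# Three dense random patterns with decreasing amplitude, plus noise
pats <- qr.Q(qr(matrix(rnorm(n_cell * 3), n_cell, 3)))
amps <- matrix(rnorm(n_time * 3), n_time, 3) %*% diag(c(5, 3, 2))
Xa <- scale(amps %*% t(pats) + matrix(rnorm(n_time * n_cell, sd = 0.5), n_time, n_cell),
            scale = FALSE)

E3 <- svd(Xa, nu = 0, nv = 3)$v          # leading 3 unrotated EOFs
R3 <- E3 %*% varimax(E3)$rotmat          # varimax-rotated patterns (still orthonormal)

# RMSE of reconstructing Xa by orthogonal projection onto span(P)
proj_rmse <- function(P) sqrt(mean((Xa - Xa %*% tcrossprod(P))^2))

c(eof_first2 = proj_rmse(E3[, 1:2]), rotated_first2 = proj_rmse(R3[, 1:2]),
  eof_all3 = proj_rmse(E3), rotated_all3 = proj_rmse(R3))
```

Rotating all three modes leaves the rank-3 reconstruction unchanged. Only truncating the rotated set to fewer modes loses skill.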
