# Cross-Validation for Symbolic Regression

library("leaf")

# Tell the user how to install the leaf backend when it is not available.
# NOTE(review): this only emits a message and execution continues, so the
# backend-dependent steps further down will presumably fail if it is missing.
if (!backend_available()) {
  message("Install backend with leaf::install_leaf()")
}

# Fix the RNG state so the simulated data below is reproducible.
set.seed(seed = 42)

N <- 50L # sample size (number of simulated rows)

# Predictors, drawn in this fixed order so the seed above reproduces them:
# x1 ~ Uniform(0, 5) and x2 ~ Uniform(0, 1).
x1 <- runif(N, min = 0, max = 5)
x2 <- runif(N, min = 0, max = 1)

# Response: ground truth y = 2 * x1^2 - 3 * x2, plus Gaussian noise (sd = 0.01).
y_b <- 2 * x1^2 - 3 * x2 +
  rnorm(N, mean = 0, sd = 1e-2)

# Assemble the training frame (columns x1, x2, y) and preview its first rows.
train_data <- data.frame(x1, x2, y = y_b)
head(train_data, n = 6L)
#>         x1         x2         y
#> 1 4.574030 0.33342721 40.846442
#> 2 4.685377 0.34674825 42.857433
#> 3 1.430698 0.39848541  2.914093
#> 4 4.152238 0.78469278 32.134514
#> 5 3.208728 0.03893649 20.475954
#> 6 2.595480 0.74879539 11.229410
# Initialize the symbolic regressor: RSRM engine, MSE loss, 4 iterations.
# Engine-specific settings are passed as the named lists `base`, `mcts`
# and `gp`.
regressor <- leaf::SymbolicRegressor$new(
  engine = "rsrm",
  num_iterations = 4L,
  loss = "MSE",
  max_params = 3,
  threshold = 1e-10,
  base = list(verbose = FALSE),
  mcts = list(times = 8),
  gp = list(times = 8)
)
# Stage 1: Discover equation skeletons
# The formula asks the engine to search for y as a function of x1 and x2.
# The returned skeletons use β placeholders for constants; those are
# fitted in Stage 2.
search_results <- regressor$search_equations(
  data = train_data,
  formula = "y ~ f(x1, x2)"
)
#> 1. Processing data for equation search based on formula...
#> 2. Running engine 'rsrm' over 1 folds using up to 1 processes...
#> -- FINAL RESULTS --
#> Episode: 4/4
#> time: 18.18s
#> loss: 7.798843366550043e-05
#> form: F
#> HOF:
#>                                   equation  complexity                                                                                                   loss
#> 0                                        0           0 999999999999999967336168804116691273849533185806555472917961779471295845921727862608739868455469056.00
#> 1                                  21.0992           1                                                                                                 273.35
#> 2                               13.2467*X1           2                                                                                                  41.72
#> 3                          4.4235*X1**2.25           3                                                                                                   1.04
#> 4                    5.8143*X1**2 - 1.3268           4                                                                                                   0.71
#> 5                 5.8201*X1**2 - 0.9692*X2           5                                                                                                   0.00
#> 6        5.8202*X1**2 - 0.9690*X2 - 0.0007           6                                                                                                   0.00
#> 7     5.8200*X1**2 - 0.9686*X2 - 0.0002/X1           8                                                                                                   0.00
#> 8  5.8200*X1**2 - 0.9686*X2 - 0.0000/X1**2           9                                                                                                   0.00
#> ---
#> 
#> task:dataset_28e9753b-c27d-43d3-9c6d-481d58f40509 expr:5.819963910966452*X1**2 + -0.9686370971634156*X2 + -2.06850213849703e-06/X1**2 Loss_MSE:0.00 Test 0/1.
#> final result:
#> success rate : 0%
#> average discovery time is 18.191 seconds
#> Number of equations looked at (per test) [Total, Timed out, Successful]:  [[1478, 0, 1478]]
#> 3. Found 9 raw skeletons. Deduplicating...

# cat() writes the header as plain text, without print()'s `[1] "..."` prefix.
cat("=== Search results ===\n")
#> === Search results ===
print(search_results)
#>                           Equation Complexity
#> 0                               β1          1
#> 1                            β1⋅x1          2
#> 2                       β1⋅x1^2.25          3
#> 3                  β1⋅x1^2 + -1⋅β2          4
#> 4               β1⋅x1^2 + -1⋅β2⋅x2          5
#> 5       β1⋅x1^2 + -1⋅β2⋅x2 + -1⋅β3          6
#> 6 β1⋅x1^2 + -1⋅β2⋅x2 + -1⋅β3⋅x1^-1          8
# Stage 2: Fit parameters and compute loss (with CV)
# Fits the β constants of the skeletons found in Stage 1. With do_cv = TRUE
# the result also carries a "Loss (cv)" column, computed over n_cv_folds folds.
fit_results <- regressor$fit(
  data = train_data,
  do_cv = TRUE,
  n_cv_folds = 5
)
#> Fitting parameters for 7 equations...
#>   - Performing CV for Equation with ID=0...
#>   - Performing CV for Equation with ID=1...
#>   - Performing CV for Equation with ID=2...
#>   - Performing CV for Equation with ID=3...
#>   - Performing CV for Equation with ID=4...
#>   - Performing CV for Equation with ID=5...
#>   - Performing CV for Equation with ID=6...
#> Parameter fitting complete.

# cat() interprets "\n" as a line break; print() would show it escaped.
cat("\n=== Fit & CV results ===\n")
#> 
#> === Fit & CV results ===
print(fit_results)
#>                           Equation Complexity         Loss    Loss (cv)
#> 0                               β1          1 2.733529e+02 2.802430e+02
#> 1                            β1⋅x1          2 4.172247e+01 4.281000e+01
#> 2                       β1⋅x1^2.25          3 1.040410e+00 1.084084e+00
#> 3                  β1⋅x1^2 + -1⋅β2          4 7.098531e-01 7.611947e-01
#> 4               β1⋅x1^2 + -1⋅β2⋅x2          5 8.269706e-05 9.661740e-05
#> 5       β1⋅x1^2 + -1⋅β2⋅x2 + -1⋅β3          6 8.261101e-05 9.799049e-05
#> 6 β1⋅x1^2 + -1⋅β2⋅x2 + -1⋅β3⋅x1^-1          8 7.868977e-05 9.192478e-05
# Stage 3: Evaluate additional metrics
  eval_table = regressor$evaluate(
      metrics = c('AIC', 'Elbow'),
      metrics_cv = c('R2','Elbow')
  )
  print("\n=== Additional metrics ===")
#> [1] "\n=== Additional metrics ==="
  print(eval_table)
#>                           Equation Complexity         Loss    Loss (cv)
#> 0                               β1          1 2.733529e+02 2.802430e+02
#> 1                            β1⋅x1          2 4.172247e+01 4.281000e+01
#> 2                       β1⋅x1^2.25          3 1.040410e+00 1.084084e+00
#> 3                  β1⋅x1^2 + -1⋅β2          4 7.098531e-01 7.611947e-01
#> 4               β1⋅x1^2 + -1⋅β2⋅x2          5 8.269706e-05 9.661740e-05
#> 6 β1⋅x1^2 + -1⋅β2⋅x2 + -1⋅β3⋅x1^-1          8 7.868977e-05 9.192478e-05
#>         AIC         AIC_w        Elbow     R2 (cv)  Elbow (cv)
#> 0 424.43203 2.343930e-111 4.675572e-03 -0.02520608 0.004560906
#> 1 330.44585  6.008769e-91 2.827208e-02  0.84338925 0.027575991
#> 2 145.87460  7.209807e-51 7.920838e-01  0.99603412 0.796609238
#> 3 128.75899  3.754305e-47 7.518792e-01  0.99721534 0.850040401
#> 4 -83.93372  5.756593e-01 3.929556e-01  0.99999965 0.414162284
#> 6 -81.94105  2.125508e-01 1.248170e-06  0.99999966 0.119702364