Cross-Validation for Symbolic Regression
library(leaf)
if (!backend_available()) {
message("Install backend with leaf::install_leaf()")
}
set.seed(42)
N <- 50L # Number of data points
# Features
x1 <- runif(N, min = 0, max = 5)
x2 <- runif(N, min = 0, max = 1)
# Target: y = 2 * x1^2 - 3 * x2 + noise
y_b <- 2 * x1^2 - 3 * x2 + rnorm(N, mean = 0, sd = 1e-2)
train_data <- data.frame(x1 = x1, x2 = x2, y = y_b)
head(train_data)
#> x1 x2 y
#> 1 4.574030 0.33342721 40.846442
#> 2 4.685377 0.34674825 42.857433
#> 3 1.430698 0.39848541 2.914093
#> 4 4.152238 0.78469278 32.134514
#> 5 3.208728 0.03893649 20.475954
#> 6 2.595480 0.74879539 11.229410
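Since we generated the data ourselves, we can sanity-check that the residual noise matches the intended `sd = 1e-2` before running any search. This check only works here because the ground-truth equation is known:

# Sanity check: residuals against the known ground truth should have sd ~ 0.01
sd(y_b - (2 * x1^2 - 3 * x2))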
# Initialize the symbolic regressor
regressor <- leaf::SymbolicRegressor$new(
  engine = "rsrm",               # search engine backend
  num_iterations = 4L,           # search episodes per run
  loss = "MSE",                  # optimize mean squared error
  max_params = 3,                # cap on fitted constants per skeleton
  threshold = 1e-10,             # loss threshold for the search
  base = list(verbose = FALSE),
  mcts = list(times = 8),        # engine-specific search settings
  gp = list(times = 8)
)
# Stage 1: Discover equation skeletons
search_results <- regressor$search_equations(
data = train_data,
formula = "y ~ f(x1, x2)"
)
#> 1. Processing data for equation search based on formula...
#> 2. Running engine 'rsrm' over 1 folds using up to 1 processes...
#> -- FINAL RESULTS --
#> Episode: 4/4
#> time: 18.18s
#> loss: 7.798843366550043e-05
#> form: F
#> HOF:
#> equation complexity loss
#> 0 0 0 999999999999999967336168804116691273849533185806555472917961779471295845921727862608739868455469056.00
#> 1 21.0992 1 273.35
#> 2 13.2467*X1 2 41.72
#> 3 4.4235*X1**2.25 3 1.04
#> 4 5.8143*X1**2 - 1.3268 4 0.71
#> 5 5.8201*X1**2 - 0.9692*X2 5 0.00
#> 6 5.8202*X1**2 - 0.9690*X2 - 0.0007 6 0.00
#> 7 5.8200*X1**2 - 0.9686*X2 - 0.0002/X1 8 0.00
#> 8 5.8200*X1**2 - 0.9686*X2 - 0.0000/X1**2 9 0.00
#> ---
#>
#> task:dataset_28e9753b-c27d-43d3-9c6d-481d58f40509 expr:5.819963910966452*X1**2 + -0.9686370971634156*X2 + -2.06850213849703e-06/X1**2 Loss_MSE:0.00 Test 0/1.
#> final result:
#> success rate : 0%
#> average discovery time is 18.191 seconds
#> Number of equations looked at (per test) [Total, Timed out, Successful]: [[1478, 0, 1478]]
#> 3. Found 9 raw skeletons. Deduplicating...
print("=== Search results ===")
#> [1] "=== Search results ==="
print(search_results)
#> Equation Complexity
#> 0 β1 1
#> 1 β1⋅x1 2
#> 2 β1⋅x1^2.25 3
#> 3 β1⋅x1^2 + -1⋅β2 4
#> 4 β1⋅x1^2 + -1⋅β2⋅x2 5
#> 5 β1⋅x1^2 + -1⋅β2⋅x2 + -1⋅β3 6
#> 6 β1⋅x1^2 + -1⋅β2⋅x2 + -1⋅β3⋅x1^-1 8
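Each row is an equation skeleton: a functional form whose numeric constants have been abstracted into free β parameters, to be re-estimated in stage 2. Deduplication merges hall-of-fame entries that share a form once constants are abstracted, which is how the 9 raw skeletons collapse into the 7 printed above. As a minimal illustration (not leaf's internal representation), the most promising skeleton can be written as a plain R function:

# Illustrative only: skeleton 4 as a plain R function with free constants
skeleton_4 <- function(x1, x2, beta) beta[1] * x1^2 - beta[2] * x2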
# Stage 2: Fit parameters and compute loss (with CV)
fit_results <- regressor$fit(
data = train_data,
do_cv = TRUE,
n_cv_folds = 5
)
#> Fitting parameters for 7 equations...
#> - Performing CV for Equation with ID=0...
#> - Performing CV for Equation with ID=1...
#> - Performing CV for Equation with ID=2...
#> - Performing CV for Equation with ID=3...
#> - Performing CV for Equation with ID=4...
#> - Performing CV for Equation with ID=5...
#> - Performing CV for Equation with ID=6...
#> Parameter fitting complete.
print("\n=== Fit & CV results ===")
#> [1] "\n=== Fit & CV results ==="
print(fit_results)
#> Equation Complexity Loss Loss (cv)
#> 0 β1 1 2.733529e+02 2.802430e+02
#> 1 β1⋅x1 2 4.172247e+01 4.281000e+01
#> 2 β1⋅x1^2.25 3 1.040410e+00 1.084084e+00
#> 3 β1⋅x1^2 + -1⋅β2 4 7.098531e-01 7.611947e-01
#> 4 β1⋅x1^2 + -1⋅β2⋅x2 5 8.269706e-05 9.661740e-05
#> 5 β1⋅x1^2 + -1⋅β2⋅x2 + -1⋅β3 6 8.261101e-05 9.799049e-05
#> 6 β1⋅x1^2 + -1⋅β2⋅x2 + -1⋅β3⋅x1^-1 8 7.868977e-05 9.192478e-05
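The `Loss (cv)` column is the loss averaged over held-out folds, which penalizes skeletons that only fit the training data. As a rough sketch of what 5-fold CV computes for a single skeleton, here is a manual version for skeleton 4 using base R's `nls()`; leaf's internal fitting procedure may differ:

# Sketch: manual 5-fold CV for skeleton 4 (illustrative, not leaf internals)
k <- 5L
folds <- sample(rep(seq_len(k), length.out = nrow(train_data)))
cv_mse <- vapply(seq_len(k), function(i) {
  # Fit the skeleton's free constants on the training folds
  fit <- nls(y ~ b1 * x1^2 - b2 * x2,
             data = train_data[folds != i, ],
             start = list(b1 = 1, b2 = 1))
  # Score on the held-out fold
  held_out <- train_data[folds == i, ]
  mean((held_out$y - predict(fit, newdata = held_out))^2)
}, numeric(1))
mean(cv_mse)  # comparable in spirit to the 'Loss (cv)' column above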
# Stage 3: Evaluate additional metrics
eval_table <- regressor$evaluate(
  metrics = c("AIC", "Elbow"),
  metrics_cv = c("R2", "Elbow")
)
print("\n=== Additional metrics ===")
#> [1] "\n=== Additional metrics ==="
print(eval_table)
#> Equation Complexity Loss Loss (cv)
#> 0 β1 1 2.733529e+02 2.802430e+02
#> 1 β1⋅x1 2 4.172247e+01 4.281000e+01
#> 2 β1⋅x1^2.25 3 1.040410e+00 1.084084e+00
#> 3 β1⋅x1^2 + -1⋅β2 4 7.098531e-01 7.611947e-01
#> 4 β1⋅x1^2 + -1⋅β2⋅x2 5 8.269706e-05 9.661740e-05
#> 6 β1⋅x1^2 + -1⋅β2⋅x2 + -1⋅β3⋅x1^-1 8 7.868977e-05 9.192478e-05
#> AIC AIC_w Elbow R2 (cv) Elbow (cv)
#> 0 424.43203 2.343930e-111 4.675572e-03 -0.02520608 0.004560906
#> 1 330.44585 6.008769e-91 2.827208e-02 0.84338925 0.027575991
#> 2 145.87460 7.209807e-51 7.920838e-01 0.99603412 0.796609238
#> 3 128.75899 3.754305e-47 7.518792e-01 0.99721534 0.850040401
#> 4 -83.93372 5.756593e-01 3.929556e-01 0.99999965 0.414162284
#> 6 -81.94105 2.125508e-01 1.248170e-06 0.99999966 0.119702364
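The `AIC_w` column appears to be Akaike weights, which rescale AIC differences into relative model probabilities, and `R2 (cv)` is the coefficient of determination on held-out folds. For reference, the standard definitions look like this (assumed here; not necessarily leaf's exact internals):

# Standard definitions, for reference (assumed, not leaf's exact internals)
akaike_weights <- function(aic) {
  d <- aic - min(aic)
  exp(-d / 2) / sum(exp(-d / 2))   # sums to 1 across candidate models
}
r_squared <- function(y, y_hat) {
  1 - sum((y - y_hat)^2) / sum((y - mean(y))^2)
}

Both the in-sample Akaike weights and the CV metrics concentrate on the β1⋅x1^2 + -1⋅β2⋅x2 family, which matches the structure of the generating equation.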