# Multi-View Symbolic Regression with leaf

# Load the package.
library(leaf)

# Install the Python backend only when it is missing. Installation only needs
# to happen once, so this guard keeps re-runs of the script from reinstalling
# the backend every time.
# NOTE(review): confirm whether a freshly installed backend is usable in the
# same R session or needs a restart.
if (!backend_available()) {
  message(
    "leaf Python backend not found; ",
    "installing with leaf::install_leaf()"
  )
  leaf::install_leaf()
}
# Simulated training data ----
# Two groups share the predictors x1 and x2 but follow different laws, each
# with Gaussian noise of sd 1e-4:
#   group "A": y = x1^2            (x2 does not enter y)
#   group "B": y = 2 * x1^2 - 3 * x2
# The fixed seed makes the draws, and the head() printed below, reproducible.
set.seed(42)

n_a <- 50L # rows in group A
n_b <- 30L # rows in group B

# Group A: x1 ~ U(0, 2), x2 ~ U(0, 3).
x1_a <- runif(n_a, min = 0, max = 2)
x2_a <- runif(n_a, min = 0, max = 3)
y_a <- x1_a^2 + rnorm(n_a, sd = 1e-4)

# Group B: x1 ~ U(0, 5), x2 ~ U(0, 1).
x1_b <- runif(n_b, min = 0, max = 5)
x2_b <- runif(n_b, min = 0, max = 1)
y_b <- 2 * x1_b^2 - 3 * x2_b + rnorm(n_b, sd = 1e-4)

# Stack the two groups row-wise; `group` records which view each row is from.
train_data <- data.frame(
  x1 = c(x1_a, x1_b),
  x2 = c(x2_a, x2_b),
  y = c(y_a, y_b),
  group = c(rep("A", times = n_a), rep("B", times = n_b))
)
head(train_data)
#>          x1        x2         y group
#> 1 1.8296121 1.0002816 3.3475126     A
#> 2 1.8741508 1.0402447 3.5123629     A
#> 3 0.5722791 1.1954562 0.3276609     A
#> 4 1.6608953 2.3540783 2.7586373     A
#> 5 1.2834910 0.1168095 1.6473582     A
#> 6 1.0381919 2.2463862 1.0778701     A
# Initialize the symbolic regressor: the "rsrm" engine with an MSE loss.
# NOTE(review): the meaning of each tuning argument (max_params, threshold,
# mcts/gp `times`, ...) is not shown here -- see the
# leaf::SymbolicRegressor documentation before changing them.
regressor <- leaf::SymbolicRegressor$new(
  engine = "rsrm",
  loss = "MSE",
  max_params = 3L,
  num_iterations = 3L,
  threshold = 1e-10,
  base = list(verbose = FALSE),
  mcts = list(times = 8L),
  gp = list(times = 8L)
)
# Stage 1: discover equation skeletons. Constants in a skeleton are free
# parameters (u1, u2, ...). With the `| group` term, groups A and B share one
# skeleton while its parameters are fitted per group, as the per-group
# output of show_equation() further down illustrates.
search_results <- regressor$search_equations(
  data = train_data,
  formula = "y ~ f(x1, x2 | group)"
)
#> 1. Processing data for equation search based on formula...
#> 2. Running engine 'rsrm' over 1 folds using up to 1 processes...
#> -- FINAL RESULTS --
#> Episode: 3/3
#> time: 20.96s
#> loss: 7.798843366570674e-09
#> form: X2+F
#> HOF:
#>                                   equation  complexity                                                                                                   loss
#> 0                                        0           0 999999999999999967336168804116691273849533185806555472917961779471295845921727862608739868455469056.00
#> 1                                   1.7959           1                                                                                                 208.56
#> 2                                1.1097*X1           2                                                                                                  42.00
#> 3                             0.4656*X1**2           3                                                                                                   1.31
#> 4                    0.5621*X1**2 - 0.3723           4                                                                                                   0.70
#> 5                 0.4656*X1**2 + 0.0000*X2           5                                                                                                   0.00
#> 6        0.4656*X1**2 + 0.0000*X2 - 0.0000           6                                                                                                   0.00
#> 7     0.4656*X1**2 + 0.0000*X2 - 0.0000/X1           8                                                                                                   0.00
#> 8  0.4656*X1**2 + 0.0000*X2 - 0.0000/X1**2           9                                                                                                   0.00
#> 9  0.4656*X1**2 + 0.0000*X2 - 0.0000/X1**2          12                                                                                                   0.00
#> ---
#> 
#> task:dataset_aa19bbdd-5194-4694-a36e-e06f64ac155a expr:X1**2+X2+(-0.534387536908588*X1**2 + -0.9999837943962612*X2 + -2.0685022306810226e-08/X1**2) Loss_MSE:0.00 Test 0/1.
#> final result:
#> success rate : 0%
#> average discovery time is 20.971 seconds
#> Number of equations looked at (per test) [Total, Timed out, Successful]:  [[1270, 0, 1270]]
#> 3. Found 10 raw skeletons. Deduplicating...
# Show the deduplicated skeletons; u1, u2, ... are the free parameters that
# Stage 2 fits. cat() prints the header as plain text (print() would wrap it
# as `[1] "..."`).
cat("=== Search results ===\n")
#> === Search results ===
print(search_results)
#>                          Equation Complexity
#> 0                              u1          1
#> 1                           u1⋅x1          2
#> 2                         u1⋅x1^2          3
#> 3                 u1⋅x1^2 + -1⋅u2          4
#> 4                 u1⋅x1^2 + u2⋅x2          5
#> 5         u1⋅x1^2 + u2⋅x2 + -1⋅u3          6
#> 6   u1⋅x1^2 + u2⋅x2 + -1⋅u3⋅x1^-1          8
#> 7 u1⋅x1^2 + u2⋅x2 + -1⋅u3⋅x1^2^-1          9
# Stage 2: fit the parameters of every skeleton found in Stage 1 and compute
# its loss.
# NOTE(review): the original comment said "(with CV)"; the search log reports
# a single fold -- confirm how fit() validates.
fit_results <- regressor$fit(
  data = train_data
)
#> Fitting parameters for 8 equations...
#> Parameter fitting complete.
# cat() turns "\n" into a real blank line; print() would show a literal "\n".
cat("\n=== Fit results ===\n")
#> 
#> === Fit results ===
print(fit_results)
#>                          Equation Complexity         Loss
#> 0                              u1          1 7.930425e+01
#> 1                           u1⋅x1          2 1.588952e+01
#> 2                         u1⋅x1^2          3 4.898860e-01
#> 3                 u1⋅x1^2 + -1⋅u2          4 2.643394e-01
#> 4                 u1⋅x1^2 + u2⋅x2          5 7.573504e-09
#> 5         u1⋅x1^2 + u2⋅x2 + -1⋅u3          6 7.556009e-09
#> 6   u1⋅x1^2 + u2⋅x2 + -1⋅u3⋅x1^-1          8 7.284390e-09
#> 7 u1⋅x1^2 + u2⋅x2 + -1⋅u3⋅x1^2^-1          9 7.245014e-09

# Inspect the last equation in the results. Rows here are ordered by
# increasing complexity, so this is the most complex skeleton. Equations are
# identified by their row name in fit_results.
last_id <- tail(rownames(fit_results), n = 1L)
cat("=== Last equation ===\n")
#> === Last equation ===
regressor$show_equation(last_id)
#> Details for Equation ID: 7
#> ----------------------------------------
#> Equation: u1⋅x1^2 + u2⋅x2 + -1⋅u3⋅x1^2^-1
#> ----------------------------------------
#> Full Expression per Group:
#> 
#> Group 'A':
#>                     2                            9.63124936535987e-9
#> 0.999995292847239⋅x₁  + 1.67023510898235e-5⋅x₂ - ───────────────────
#>                                                            2        
#>                                                          x₁         
#> 
#> Group 'B':
#>                    2                         2.75633799071188e-6
#> 1.99999904083213⋅x₁  - 2.99998954617359⋅x₂ - ───────────────────
#>                                                        2        
#>                                                      x₁
# Stage 3: evaluate additional metrics for every fitted equation. The returned
# table keeps the Loss column and adds the requested metrics (plus AIC_w
# alongside AIC).
eval_table <- regressor$evaluate(
  metrics = c("RMSE", "AIC", "Elbow")
)
# cat() turns "\n" into a real blank line; print() would show a literal "\n".
cat("\n=== Additional metrics ===\n")
#> 
#> === Additional metrics ===
print(eval_table)
#>                          Equation Complexity         Loss       AIC
#> 0                              u1          1 7.930425e+01  578.8935
#> 1                           u1⋅x1          2 1.588952e+01  450.2830
#> 2                         u1⋅x1^2          3 4.898860e-01  171.9436
#> 3                 u1⋅x1^2 + -1⋅u2          4 2.643394e-01  124.5885
#> 4                 u1⋅x1^2 + u2⋅x2          5 7.573504e-09 -206.0486
#> 5         u1⋅x1^2 + u2⋅x2 + -1⋅u3          6 7.556009e-09 -204.0486
#> 6   u1⋅x1^2 + u2⋅x2 + -1⋅u3⋅x1^-1          8 7.284390e-09 -204.0486
#> 7 u1⋅x1^2 + u2⋅x2 + -1⋅u3⋅x1^2^-1          9 7.245014e-09 -204.0486
#>           AIC_w         RMSE      Elbow
#> 0 1.694368e-171 8.905294e+00 0.01615145
#> 1 1.433616e-143 3.986166e+00 0.07100753
#> 2  3.954274e-83 6.999186e-01 0.84262426
#> 3  7.587431e-73 5.141395e-01 0.83529735
#> 4  4.753667e-01 8.702588e-05 0.16452063
#> 5  1.748776e-01 8.692530e-05 0.08365697
#> 6  1.748778e-01 8.534864e-05 0.00000000
#> 7  1.748778e-01 8.511765e-05 0.00000000