From 35a8e1fa3fa1bd1689119f437bd2e3a27d548332 Mon Sep 17 00:00:00 2001
From: Juha Reunanen <juha.reunanen@aiforia.com>
Date: Sun, 28 Sep 2025 17:19:52 +0300
Subject: [PATCH] Improve the somewhat flaky `test_loss_multibinary_log` by
 avoiding samples very close to class boundaries (#3112)

---
 dlib/test/dnn.cpp | 45 +++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 41 insertions(+), 4 deletions(-)

diff --git a/dlib/test/dnn.cpp b/dlib/test/dnn.cpp
index c564e277e1..cabb7d71e3 100644
--- a/dlib/test/dnn.cpp
+++ b/dlib/test/dnn.cpp
@@ -4457,13 +4457,50 @@ void test_multm_prev()
 
         for (size_t i = 0; i < labels.size(); ++i)
         {
-            matrix<float, 0, 1> x = matrix_cast<float>(randm(dims, 1)) * rnd.get_double_in_range(1, 9);
-            const auto norm = sqrt(sum(squared(x)));
-            if (norm < 3)
+            const double class_boundary_1 = 3.0;
+            const double class_boundary_2 = 6.0;
+
+            const double desired_margin = 0.1;
+
+            const auto get_random_matrix = [&rnd, dims]()
+            {
+                return matrix<float, 0, 1>(matrix_cast<float>(randm(dims, 1)) * rnd.get_double_in_range(1, 9));
+            };
+
+            const auto get_distance_from_nearest_class_boundary = [class_boundary_1, class_boundary_2](double norm)
+            {
+                return std::min(
+                    std::abs(norm - class_boundary_1),
+                    std::abs(norm - class_boundary_2)
+                );
+            };
+
+            auto x = get_random_matrix();
+            auto norm = sqrt(sum(squared(x)));
+            auto distance_from_nearest_class_boundary = get_distance_from_nearest_class_boundary(norm);
+
+            // Try again if the newly generated sample is very close to either of the class boundaries
+            int retry_counter = 0;
+            const int max_retry_counter = 10;
+            while (distance_from_nearest_class_boundary < desired_margin && ++retry_counter <= max_retry_counter)
+            {
+                const auto new_x = get_random_matrix();
+                const auto new_norm = sqrt(sum(squared(new_x)));
+                const auto new_distance_from_nearest_class_boundary = get_distance_from_nearest_class_boundary(new_norm);
+
+                if (new_distance_from_nearest_class_boundary > distance_from_nearest_class_boundary)
+                {
+                    x = new_x;
+                    norm = new_norm;
+                    distance_from_nearest_class_boundary = new_distance_from_nearest_class_boundary;
+                }
+            }
+
+            if (norm < class_boundary_1)
             {
                 labels[i][0] = 1.f;
             }
-            else if (3 <= norm && norm < 6)
+            else if (class_boundary_1 <= norm && norm < class_boundary_2)
             {
                 labels[i][0] = 1.f;
                 labels[i][1] = 1.f;
