From 20b23da8e2b38f26904b6dd60ba46d6575f6ad61 Mon Sep 17 00:00:00 2001 From: Danny <33044223+danielenricocahall@users.noreply.github.com> Date: Fri, 4 Sep 2020 13:01:05 -0400 Subject: [PATCH] Merge pull request #18061 from danielenricocahall:fix-kd-tree Fix KD Tree kNN Implementation * Make KDTree mode in kNN functional remove docs and revert change Make KDTree mode in kNN functional spacing Make KDTree mode in kNN functional fix window compilations warnings Make KDTree mode in kNN functional fix window compilations warnings Make KDTree mode in kNN functional casting Make KDTree mode in kNN functional formatting Make KDTree mode in kNN functional * test coding style --- modules/ml/src/kdtree.cpp | 3 +-- modules/ml/src/knearest.cpp | 17 ++---------- modules/ml/test/test_knearest.cpp | 45 +++++++++++++++++++++++++++---- 3 files changed, 43 insertions(+), 22 deletions(-) diff --git a/modules/ml/src/kdtree.cpp b/modules/ml/src/kdtree.cpp index 1ab8400936..a80e12964a 100644 --- a/modules/ml/src/kdtree.cpp +++ b/modules/ml/src/kdtree.cpp @@ -101,7 +101,7 @@ medianPartition( size_t* ofs, int a, int b, const float* vals ) int i0 = a, i1 = (a+b)/2, i2 = b; float v0 = vals[ofs[i0]], v1 = vals[ofs[i1]], v2 = vals[ofs[i2]]; int ip = v0 < v1 ? (v1 < v2 ? i1 : v0 < v2 ? i2 : i0) : - v0 < v2 ? i0 : (v1 < v2 ? i2 : i1); + v0 < v2 ? (v1 == v0 ? i2 : i0): (v1 < v2 ? i2 : i1); float pivot = vals[ofs[ip]]; std::swap(ofs[ip], ofs[i2]); @@ -131,7 +131,6 @@ medianPartition( size_t* ofs, int a, int b, const float* vals ) CV_Assert(vals[ofs[k]] >= pivot); more += vals[ofs[k]] > pivot; } - CV_Assert(std::abs(more - less) <= 1); return vals[ofs[middle]]; } diff --git a/modules/ml/src/knearest.cpp b/modules/ml/src/knearest.cpp index ca23d0f4d6..3d8f9b5d2e 100644 --- a/modules/ml/src/knearest.cpp +++ b/modules/ml/src/knearest.cpp @@ -381,36 +381,23 @@ public: Mat res, nr, d; if( _results.needed() ) { - _results.create(testcount, 1, CV_32F); res = _results.getMat(); } if( _neighborResponses.needed() ) { - _neighborResponses.create(testcount, k, CV_32F); nr = _neighborResponses.getMat(); } if( _dists.needed() ) { - _dists.create(testcount, k, CV_32F); d = _dists.getMat(); } for (int i=0; ii) - { - _res = res.row(i); - } - if (nr.rows>i) - { - _nr = nr.row(i); - } - if (d.rows>i) - { - _d = d.row(i); - } tr.findNearest(test_samples.row(i), k, Emax, _res, _nr, _d, noArray()); + res.push_back(_res.t()); + _results.assign(res); } return result; // currently always 0 diff --git a/modules/ml/test/test_knearest.cpp b/modules/ml/test/test_knearest.cpp index 49e6b0d12a..80baed9626 100644 --- a/modules/ml/test/test_knearest.cpp +++ b/modules/ml/test/test_knearest.cpp @@ -37,18 +37,31 @@ TEST(ML_KNearest, accuracy) EXPECT_LE(err, 0.01f); } { - // TODO: broken -#if 0 SCOPED_TRACE("KDTree"); - Mat bestLabels; + Mat neighborIndexes; float err = 1000; Ptr knn = KNearest::create(); knn->setAlgorithmType(KNearest::KDTREE); knn->train(trainData, ml::ROW_SAMPLE, trainLabels); - knn->findNearest(testData, 4, bestLabels); + knn->findNearest(testData, 4, neighborIndexes); + Mat bestLabels; + // The output of the KDTree are the neighbor indexes, not actual class labels + // so we need to do some extra work to get actual predictions + for(int row_num = 0; row_num < neighborIndexes.rows; ++row_num){ + vector labels; + for(int index = 0; index < neighborIndexes.row(row_num).cols; ++index) { + labels.push_back(trainLabels.at(neighborIndexes.row(row_num).at(0, index) , 0)); + } + // computing the mode of the output class predictions to determine overall prediction + std::vector histogram(3,0); + for( int i=0; i<3; ++i ) + ++histogram[ static_cast(labels[i]) ]; + int bestLabel = static_cast(std::max_element( histogram.begin(), histogram.end() ) - histogram.begin()); + bestLabels.push_back(bestLabel); + } + bestLabels.convertTo(bestLabels, testLabels.type()); EXPECT_TRUE(calcErr( bestLabels, testLabels, sizes, err, true )); EXPECT_LE(err, 0.01f); -#endif } } @@ -74,4 +87,26 @@ TEST(ML_KNearest, regression_12347) EXPECT_EQ(2, zBestLabels.at(1,0)); } +TEST(ML_KNearest, bug_11877) +{ + Mat trainData = (Mat_(5,2) << 3, 3, 3, 3, 4, 4, 4, 4, 4, 4); + Mat trainLabels = (Mat_(5,1) << 0, 0, 1, 1, 1); + + Ptr knnKdt = KNearest::create(); + knnKdt->setAlgorithmType(KNearest::KDTREE); + knnKdt->setIsClassifier(true); + + knnKdt->train(trainData, ml::ROW_SAMPLE, trainLabels); + + Mat testData = (Mat_(2,2) << 3.1, 3.1, 4, 4.1); + Mat testLabels = (Mat_(2,1) << 0, 1); + Mat result; + + knnKdt->findNearest(testData, 1, result); + + EXPECT_EQ(1, int(result.at(0, 0))); + EXPECT_EQ(2, int(result.at(1, 0))); + EXPECT_EQ(0, trainLabels.at(result.at(0, 0), 0)); +} + }} // namespace -- GitLab