未验证 提交 77a12391 编写于 作者: M Markus Vieth 提交者: GitHub

Merge pull request #4570 from Yuki992411/ml-kmeans

[ml] Fix un-initialized centroids bug (k-means)
......@@ -97,6 +97,10 @@ pcl::Kmeans::computeCentroids()
for (Centroids::value_type& centroid : centroids_) {
PointId num_points_in_cluster = 0;
// Set zero
for (Point::value_type& c : centroid)
c = 0.0;
// For each PointId in this set
for (const auto& pid : clusters_to_points_[cid]) {
Point p = data_[pid];
......
......@@ -43,6 +43,7 @@ add_subdirectory(gpu)
add_subdirectory(io)
add_subdirectory(kdtree)
add_subdirectory(keypoints)
add_subdirectory(ml)
add_subdirectory(people)
add_subdirectory(octree)
add_subdirectory(outofcore)
......
set(SUBSYS_NAME tests_ml)
set(SUBSYS_DESC "Point cloud library ml module unit tests")
PCL_SET_TEST_DEPENDENCIES(SUBSYS_DEPS ml)
set(DEFAULT ON)
set(build TRUE)
PCL_SUBSYS_OPTION(build "${SUBSYS_NAME}" "${SUBSYS_DESC}" ${DEFAULT} "${REASON}")
PCL_SUBSYS_DEPEND(build "${SUBSYS_NAME}" DEPS ${SUBSYS_DEPS} OPT_DEPS ${OPT_DEPS})
if(NOT build)
return()
endif()
PCL_ADD_TEST(ml_kmeans test_ml_kmeans FILES test_kmeans.cpp LINK_WITH pcl_gtest pcl_common pcl_ml)
/*
* SPDX-License-Identifier: BSD-3-Clause
*
* Point Cloud Library (PCL) - www.pointclouds.org
* Copyright (c) 2020-, Open Perception
*
* All rights reserved
*/
#include <pcl/test/gtest.h>
#include <pcl/common/random.h>
#include <pcl/ml/kmeans.h>
using namespace pcl;
using namespace pcl::common;
using Point = std::vector<float>;
// Prepare random number generator in PCL
UniformGenerator<float> engine (-100000.0, 100000.0, 2021);
class SampleDataChecker
{
public:
int data_size_;
int dim_;
int cluster_size_;
std::vector<Point> data_sequence_;
std::vector<Point> answer_centroids_;
// Create sample data
void
createDataSequence ()
{
for (int data_id = 0; data_id < data_size_; ++data_id)
{
Point data;
for (int dim_i = 0; dim_i < dim_; ++dim_i)
data.push_back (engine.run ());
data_sequence_.push_back (data);
}
}
void
testKmeans (Kmeans& k_means)
{
k_means.setClusterSize (cluster_size_);
k_means.setInputData (data_sequence_);
k_means.initialClusterPoints ();
k_means.computeCentroids ();
// Input centroids that should be the correct answer
answer_centroids_ = k_means.get_centroids ();
// If centroids_ was initialized before calculating it,
// then it should not change
// no matter how many times this class method is called.
k_means.computeCentroids ();
k_means.computeCentroids ();
}
};
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////
TEST (ComputeCentroids, Case1)
{
// Create sample data sequence
SampleDataChecker sdc;
sdc.data_size_ = 20;
sdc.dim_ = 21;
sdc.cluster_size_ = 9;
sdc.createDataSequence ();
// Compute centroids with K-means
Kmeans k_means (sdc.data_size_, sdc.dim_);
sdc.testKmeans (k_means);
// Evaluate if the two centroids are the same
EXPECT_EQ (sdc.cluster_size_, k_means.get_centroids ().size ());
EXPECT_EQ (sdc.answer_centroids_, k_means.get_centroids ());
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////
TEST (ComputeCentroids, Case2)
{
// Create sample data sequence
SampleDataChecker sdc;
sdc.data_size_ = 1;
sdc.dim_ = 1;
sdc.cluster_size_ = 1;
sdc.createDataSequence ();
// Compute centroids with K-means
Kmeans k_means (sdc.data_size_, sdc.dim_);
sdc.testKmeans (k_means);
// Evaluate if the two centroids are the same
EXPECT_EQ (sdc.cluster_size_, k_means.get_centroids ().size ());
EXPECT_EQ (sdc.answer_centroids_, k_means.get_centroids ());
}
/* ---[ */
int
main (int argc, char** argv)
{
testing::InitGoogleTest (&argc, argv);
return (RUN_ALL_TESTS ());
}
/* ]--- */
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册