From b4bddc080d79e0f14356dcf0f7d5903c0a7261b9 Mon Sep 17 00:00:00 2001 From: "shengjun.li" Date: Sat, 29 Aug 2020 19:02:15 +0800 Subject: [PATCH] add clustering configuration item (#3519) Signed-off-by: shengjun.li --- core/src/config/ConfigMgr.cpp | 3 +++ core/src/config/ServerConfig.h | 11 ++++++++++ .../src/index/thirdparty/faiss/Clustering.cpp | 20 ++++++++++--------- core/src/index/thirdparty/faiss/Clustering.h | 14 ++++++------- core/src/server/DBWrapper.cpp | 12 +++++++++++ 5 files changed, 44 insertions(+), 16 deletions(-) diff --git a/core/src/config/ConfigMgr.cpp b/core/src/config/ConfigMgr.cpp index 43175401..41a586e1 100644 --- a/core/src/config/ConfigMgr.cpp +++ b/core/src/config/ConfigMgr.cpp @@ -177,6 +177,9 @@ ConfigMgr::ConfigMgr() { {"engine.omp_thread_num", CreateIntegerConfig("engine.omp_thread_num", true, 0, std::numeric_limits::max(), &config.engine.omp_thread_num.value, 0, nullptr, nullptr)}, + {"engine.clustering_type", + CreateEnumConfig("engine.clustering_type", false, &ClusteringMap, &config.engine.clustering_type.value, + ClusteringType::K_MEANS, nullptr, nullptr)}, {"engine.simd_type", CreateEnumConfig("engine.simd_type", false, &SimdMap, &config.engine.simd_type.value, SimdType::AUTO, nullptr, nullptr)}, diff --git a/core/src/config/ServerConfig.h b/core/src/config/ServerConfig.h index 2d9cd117..78c9bf77 100644 --- a/core/src/config/ServerConfig.h +++ b/core/src/config/ServerConfig.h @@ -63,6 +63,16 @@ const configEnum SimdMap{ {"avx512", SimdType::AVX512}, }; +enum ClusteringType { + K_MEANS = 1, + K_MEANS_PLUS_PLUS, +}; + +const configEnum ClusteringMap{ + {"k-means", ClusteringType::K_MEANS}, + {"k-means++", ClusteringType::K_MEANS_PLUS_PLUS}, +}; + struct ServerConfig { using String = ConfigValue; using Bool = ConfigValue; @@ -116,6 +126,7 @@ struct ServerConfig { Integer search_combine_nq{0}; Integer use_blas_threshold{0}; Integer omp_thread_num{0}; + Integer clustering_type{0}; Integer simd_type{0}; } engine; diff --git a/core/src/index/thirdparty/faiss/Clustering.cpp b/core/src/index/thirdparty/faiss/Clustering.cpp index 017fb199..43df9b5e 100755 --- a/core/src/index/thirdparty/faiss/Clustering.cpp +++ b/core/src/index/thirdparty/faiss/Clustering.cpp @@ -260,7 +260,8 @@ int split_clusters (size_t d, size_t k, size_t n, } }; -KmeansType kmeans_type = KmeansType::KMEANS; + +ClusteringType clustering_type = ClusteringType::K_MEANS; void Clustering::kmeans_algorithm(std::vector& centroids_index, int64_t random_seed, size_t n_input_centroids, size_t d, size_t k, @@ -328,18 +329,17 @@ void Clustering::kmeans_plus_plus_algorithm(std::vector& centroids_index, i //calculate P(x) #pragma omp parallel for - for (size_t point_it = 0; point_it < thread_max_num; point_it++) { - size_t left = point_it == 0 ? 0 : task[point_it - 1]; - size_t right = task[point_it]; - // cout <<"Thread = "<< omp_get_thread_num() <<" left = "< #include +#include #include #include "config/ServerConfig.h" @@ -78,6 +79,17 @@ DBWrapper::StartService() { int64_t use_blas_threshold = config.engine.use_blas_threshold(); faiss::distance_compute_blas_threshold = use_blas_threshold; + int64_t clustering_type = config.engine.clustering_type(); + switch (clustering_type) { + case ClusteringType::K_MEANS: + default: + faiss::clustering_type = faiss::ClusteringType::K_MEANS; + break; + case ClusteringType::K_MEANS_PLUS_PLUS: + faiss::clustering_type = faiss::ClusteringType::K_MEANS_PLUS_PLUS; + break; + } + // create db root folder s = CommonUtil::CreateDirectory(opt.meta_.path_); if (!s.ok()) { -- GitLab