未验证 提交 b4bddc08 编写于 作者: S shengjun.li 提交者: GitHub

add clustering configuration item (#3519)

Signed-off-by: Nshengjun.li <shengjun.li@zilliz.com>
上级 a8925ca6
...@@ -177,6 +177,9 @@ ConfigMgr::ConfigMgr() { ...@@ -177,6 +177,9 @@ ConfigMgr::ConfigMgr() {
{"engine.omp_thread_num", {"engine.omp_thread_num",
CreateIntegerConfig("engine.omp_thread_num", true, 0, std::numeric_limits<int64_t>::max(), CreateIntegerConfig("engine.omp_thread_num", true, 0, std::numeric_limits<int64_t>::max(),
&config.engine.omp_thread_num.value, 0, nullptr, nullptr)}, &config.engine.omp_thread_num.value, 0, nullptr, nullptr)},
{"engine.clustering_type",
CreateEnumConfig("engine.clustering_type", false, &ClusteringMap, &config.engine.clustering_type.value,
ClusteringType::K_MEANS, nullptr, nullptr)},
{"engine.simd_type", CreateEnumConfig("engine.simd_type", false, &SimdMap, &config.engine.simd_type.value, {"engine.simd_type", CreateEnumConfig("engine.simd_type", false, &SimdMap, &config.engine.simd_type.value,
SimdType::AUTO, nullptr, nullptr)}, SimdType::AUTO, nullptr, nullptr)},
......
...@@ -63,6 +63,16 @@ const configEnum SimdMap{ ...@@ -63,6 +63,16 @@ const configEnum SimdMap{
{"avx512", SimdType::AVX512}, {"avx512", SimdType::AVX512},
}; };
enum ClusteringType {
K_MEANS = 1,
K_MEANS_PLUS_PLUS,
};
const configEnum ClusteringMap{
{"k-means", ClusteringType::K_MEANS},
{"k-means++", ClusteringType::K_MEANS_PLUS_PLUS},
};
struct ServerConfig { struct ServerConfig {
using String = ConfigValue<std::string>; using String = ConfigValue<std::string>;
using Bool = ConfigValue<bool>; using Bool = ConfigValue<bool>;
...@@ -116,6 +126,7 @@ struct ServerConfig { ...@@ -116,6 +126,7 @@ struct ServerConfig {
Integer search_combine_nq{0}; Integer search_combine_nq{0};
Integer use_blas_threshold{0}; Integer use_blas_threshold{0};
Integer omp_thread_num{0}; Integer omp_thread_num{0};
Integer clustering_type{0};
Integer simd_type{0}; Integer simd_type{0};
} engine; } engine;
......
...@@ -260,7 +260,8 @@ int split_clusters (size_t d, size_t k, size_t n, ...@@ -260,7 +260,8 @@ int split_clusters (size_t d, size_t k, size_t n,
} }
}; };
KmeansType kmeans_type = KmeansType::KMEANS;
ClusteringType clustering_type = ClusteringType::K_MEANS;
void Clustering::kmeans_algorithm(std::vector<int>& centroids_index, int64_t random_seed, void Clustering::kmeans_algorithm(std::vector<int>& centroids_index, int64_t random_seed,
size_t n_input_centroids, size_t d, size_t k, size_t n_input_centroids, size_t d, size_t k,
...@@ -328,18 +329,17 @@ void Clustering::kmeans_plus_plus_algorithm(std::vector<int>& centroids_index, i ...@@ -328,18 +329,17 @@ void Clustering::kmeans_plus_plus_algorithm(std::vector<int>& centroids_index, i
//calculate P(x) //calculate P(x)
#pragma omp parallel for #pragma omp parallel for
for (size_t point_it = 0; point_it < thread_max_num; point_it++) { for (size_t task_i = 0; task_i < thread_max_num; task_i++) {
size_t left = point_it == 0 ? 0 : task[point_it - 1]; size_t left = (task_i == 0) ? 0 : task[task_i - 1];
size_t right = task[point_it]; size_t right = task[task_i];
// cout <<"Thread = "<< omp_get_thread_num() <<" left = "<<left<<" right = "<<right << endl;
pre_sum[left] = dx_distance[left]; pre_sum[left] = dx_distance[left];
for (size_t j = left + 1; j < right; j++) { for (size_t j = left + 1; j < right; j++) {
pre_sum[j] = pre_sum[j - 1] + dx_distance[j]; pre_sum[j] = pre_sum[j - 1] + dx_distance[j];
} }
} }
float sum = 0.0; float sum = 0.0;
for (size_t point_it = 0; point_it < thread_max_num; point_it++) { for (size_t task_i = 0; task_i < thread_max_num; task_i++) {
sum += pre_sum[task[point_it] - 1]; sum += pre_sum[task[task_i] - 1];
} }
// the random num is [0,sum] // the random num is [0,sum]
...@@ -493,12 +493,14 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in, ...@@ -493,12 +493,14 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
int64_t random_seed = seed + 1 + redo * 15486557L; int64_t random_seed = seed + 1 + redo * 15486557L;
std::vector<int> centroids_index(nx); std::vector<int> centroids_index(nx);
if (KmeansType::KMEANS == kmeans_type) { if (ClusteringType::K_MEANS == clustering_type) {
//Use classic kmeans algorithm //Use classic kmeans algorithm
kmeans_algorithm(centroids_index, random_seed, n_input_centroids, d, k, nx, x_in); kmeans_algorithm(centroids_index, random_seed, n_input_centroids, d, k, nx, x_in);
} else if (KmeansType::KMEANS_PLUSPLUS == kmeans_type) { } else if (ClusteringType::K_MEANS_PLUS_PLUS == clustering_type) {
//Use kmeans++ algorithm //Use kmeans++ algorithm
kmeans_plus_plus_algorithm(centroids_index, random_seed, n_input_centroids, d, k, nx, x_in); kmeans_plus_plus_algorithm(centroids_index, random_seed, n_input_centroids, d, k, nx, x_in);
} else {
FAISS_THROW_FMT ("Clustering Type is knonws: %d", (int)clustering_type);
} }
centroids.resize(d * k); centroids.resize(d * k);
......
...@@ -16,17 +16,17 @@ ...@@ -16,17 +16,17 @@
namespace faiss { namespace faiss {
/** /**
* The algorithm of Kmeans Type * The algorithm of clustering
*/ */
enum KmeansType enum ClusteringType
{ {
KMEANS, K_MEANS,
KMEANS_PLUSPLUS, K_MEANS_PLUS_PLUS,
KMEANS_TWO, K_MEANS_TWO,
}; };
//The default algorithm use the KMEANS_PLUSPLUS //The default algorithm use the K_MEANS
extern KmeansType kmeans_type; extern ClusteringType clustering_type;
/** Class for the clustering parameters. Can be passed to the /** Class for the clustering parameters. Can be passed to the
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <faiss/Clustering.h>
#include <faiss/utils/distances.h> #include <faiss/utils/distances.h>
#include "config/ServerConfig.h" #include "config/ServerConfig.h"
...@@ -78,6 +79,17 @@ DBWrapper::StartService() { ...@@ -78,6 +79,17 @@ DBWrapper::StartService() {
int64_t use_blas_threshold = config.engine.use_blas_threshold(); int64_t use_blas_threshold = config.engine.use_blas_threshold();
faiss::distance_compute_blas_threshold = use_blas_threshold; faiss::distance_compute_blas_threshold = use_blas_threshold;
int64_t clustering_type = config.engine.clustering_type();
switch (clustering_type) {
case ClusteringType::K_MEANS:
default:
faiss::clustering_type = faiss::ClusteringType::K_MEANS;
break;
case ClusteringType::K_MEANS_PLUS_PLUS:
faiss::clustering_type = faiss::ClusteringType::K_MEANS_PLUS_PLUS;
break;
}
// create db root folder // create db root folder
s = CommonUtil::CreateDirectory(opt.meta_.path_); s = CommonUtil::CreateDirectory(opt.meta_.path_);
if (!s.ok()) { if (!s.ok()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册