未验证 提交 b4bddc08 编写于 作者: S shengjun.li 提交者: GitHub

add clustering configuration item (#3519)

Signed-off-by: Nshengjun.li <shengjun.li@zilliz.com>
上级 a8925ca6
......@@ -177,6 +177,9 @@ ConfigMgr::ConfigMgr() {
{"engine.omp_thread_num",
CreateIntegerConfig("engine.omp_thread_num", true, 0, std::numeric_limits<int64_t>::max(),
&config.engine.omp_thread_num.value, 0, nullptr, nullptr)},
{"engine.clustering_type",
CreateEnumConfig("engine.clustering_type", false, &ClusteringMap, &config.engine.clustering_type.value,
ClusteringType::K_MEANS, nullptr, nullptr)},
{"engine.simd_type", CreateEnumConfig("engine.simd_type", false, &SimdMap, &config.engine.simd_type.value,
SimdType::AUTO, nullptr, nullptr)},
......
......@@ -63,6 +63,16 @@ const configEnum SimdMap{
{"avx512", SimdType::AVX512},
};
// Clustering algorithms that can be selected through the
// "engine.clustering_type" configuration item.
//
// Kept as an unscoped enum because the value is stored in an integer config
// field and compared against plain integers. Numbering starts at 1; the
// faiss-side ClusteringType enum declares its enumerators without explicit
// values, so translate between the two explicitly instead of casting.
enum ClusteringType {
    K_MEANS = 1,            // classic k-means ("k-means" in the config)
    K_MEANS_PLUS_PLUS = 2,  // k-means++ ("k-means++" in the config)
};
// Lookup table from the textual values accepted for the
// "engine.clustering_type" configuration item ("k-means", "k-means++") to the
// corresponding ClusteringType enumerator. Passed to CreateEnumConfig when the
// item is registered.
const configEnum ClusteringMap{
{"k-means", ClusteringType::K_MEANS},
{"k-means++", ClusteringType::K_MEANS_PLUS_PLUS},
};
struct ServerConfig {
using String = ConfigValue<std::string>;
using Bool = ConfigValue<bool>;
......@@ -116,6 +126,7 @@ struct ServerConfig {
Integer search_combine_nq{0};
Integer use_blas_threshold{0};
Integer omp_thread_num{0};
Integer clustering_type{0};
Integer simd_type{0};
} engine;
......
......@@ -260,7 +260,8 @@ int split_clusters (size_t d, size_t k, size_t n,
}
};
KmeansType kmeans_type = KmeansType::KMEANS;
// Global algorithm selector (declared extern in the Clustering header).
// Clustering::train_encoded reads it to choose between kmeans_algorithm and
// kmeans_plus_plus_algorithm; it starts out as classic k-means.
ClusteringType clustering_type = ClusteringType::K_MEANS;
void Clustering::kmeans_algorithm(std::vector<int>& centroids_index, int64_t random_seed,
size_t n_input_centroids, size_t d, size_t k,
......@@ -328,18 +329,17 @@ void Clustering::kmeans_plus_plus_algorithm(std::vector<int>& centroids_index, i
//calculate P(x)
#pragma omp parallel for
for (size_t point_it = 0; point_it < thread_max_num; point_it++) {
size_t left = point_it == 0 ? 0 : task[point_it - 1];
size_t right = task[point_it];
// cout <<"Thread = "<< omp_get_thread_num() <<" left = "<<left<<" right = "<<right << endl;
for (size_t task_i = 0; task_i < thread_max_num; task_i++) {
size_t left = (task_i == 0) ? 0 : task[task_i - 1];
size_t right = task[task_i];
pre_sum[left] = dx_distance[left];
for (size_t j = left + 1; j < right; j++) {
pre_sum[j] = pre_sum[j - 1] + dx_distance[j];
}
}
float sum = 0.0;
for (size_t point_it = 0; point_it < thread_max_num; point_it++) {
sum += pre_sum[task[point_it] - 1];
for (size_t task_i = 0; task_i < thread_max_num; task_i++) {
sum += pre_sum[task[task_i] - 1];
}
// the random num is [0,sum]
......@@ -493,12 +493,14 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
int64_t random_seed = seed + 1 + redo * 15486557L;
std::vector<int> centroids_index(nx);
if (KmeansType::KMEANS == kmeans_type) {
if (ClusteringType::K_MEANS == clustering_type) {
//Use classic kmeans algorithm
kmeans_algorithm(centroids_index, random_seed, n_input_centroids, d, k, nx, x_in);
} else if (KmeansType::KMEANS_PLUSPLUS == kmeans_type) {
} else if (ClusteringType::K_MEANS_PLUS_PLUS == clustering_type) {
//Use kmeans++ algorithm
kmeans_plus_plus_algorithm(centroids_index, random_seed, n_input_centroids, d, k, nx, x_in);
} else {
FAISS_THROW_FMT ("Clustering Type is knonws: %d", (int)clustering_type);
}
centroids.resize(d * k);
......
......@@ -16,17 +16,17 @@
namespace faiss {
/**
* The algorithm of Kmeans Type
* The algorithm of clustering
*/
enum KmeansType
enum ClusteringType
{
KMEANS,
KMEANS_PLUSPLUS,
KMEANS_TWO,
K_MEANS,
K_MEANS_PLUS_PLUS,
K_MEANS_TWO,
};
//The default algorithm use the KMEANS_PLUSPLUS
extern KmeansType kmeans_type;
//The default algorithm is K_MEANS
extern ClusteringType clustering_type;
/** Class for the clustering parameters. Can be passed to the
......
......@@ -16,6 +16,7 @@
#include <string>
#include <vector>
#include <faiss/Clustering.h>
#include <faiss/utils/distances.h>
#include "config/ServerConfig.h"
......@@ -78,6 +79,17 @@ DBWrapper::StartService() {
int64_t use_blas_threshold = config.engine.use_blas_threshold();
faiss::distance_compute_blas_threshold = use_blas_threshold;
int64_t clustering_type = config.engine.clustering_type();
switch (clustering_type) {
case ClusteringType::K_MEANS:
default:
faiss::clustering_type = faiss::ClusteringType::K_MEANS;
break;
case ClusteringType::K_MEANS_PLUS_PLUS:
faiss::clustering_type = faiss::ClusteringType::K_MEANS_PLUS_PLUS;
break;
}
// create db root folder
s = CommonUtil::CreateDirectory(opt.meta_.path_);
if (!s.ok()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册