Unverified commit 59dab6cb authored by Xiaohai Xu, committed by GitHub

#1653 IndexFlat performance improvement for NQ < thread_number (#1674)

* Optimize index flat L2/IP for SSE
Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* parallel optimization
Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* fix threshold
Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* add changelog
Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* add changelog
Signed-off-by: sahuang <xiaohai.xu@zilliz.com>
Co-authored-by: sahuang <xiaohai.xu@zilliz.com>
Parent: 3de34d38
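The patch targets brute-force (IndexFlat) search when the number of queries NQ (`nx` in the code) is smaller than the OpenMP thread count. The original loops parallelize over queries, so a request with one or two queries leaves most cores idle; the new branch parallelizes over the `ny` database vectors instead, gives each thread a private top-k heap, and merges the per-thread heaps into the caller's result heap after the scan. The block below is a minimal standalone sketch of that pattern, not the Faiss code: the plain dot product stands in for `fvec_inner_product`, `std::priority_queue` stands in for the Faiss heap helpers, and the bitset filtering is omitted.

```cpp
#include <algorithm>
#include <cstdint>
#include <functional>
#include <queue>
#include <utility>
#include <vector>
#include <omp.h>

using ScoredId = std::pair<float, int64_t>;  // (inner-product score, database id)

// Top-k by inner product for a single query, parallelized over the ny
// database vectors: each thread fills its own min-heap, and the per-thread
// heaps are merged serially at the end.
std::vector<ScoredId> topk_ip_parallel_over_ny(const float* x,   // one query of d floats
                                               const float* y,   // ny vectors, row-major
                                               size_t d, size_t ny, size_t k) {
    int nthread = omp_get_max_threads();
    // one private heap per thread, so the scan needs no locking
    std::vector<std::priority_queue<ScoredId, std::vector<ScoredId>,
                                    std::greater<ScoredId>>> heaps(nthread);

#pragma omp parallel for
    for (int64_t j = 0; j < (int64_t)ny; j++) {
        const float* y_j = y + j * d;
        float ip = 0.0f;
        for (size_t t = 0; t < d; t++) ip += x[t] * y_j[t];  // stand-in for fvec_inner_product
        auto& heap = heaps[omp_get_thread_num()];
        if (heap.size() < k) {
            heap.emplace(ip, j);
        } else if (ip > heap.top().first) {      // better than this thread's current k-th best
            heap.pop();
            heap.emplace(ip, j);
        }
    }

    // merge: keep the k best of at most nthread * k candidates
    std::vector<ScoredId> cand;
    for (auto& heap : heaps)
        for (; !heap.empty(); heap.pop()) cand.push_back(heap.top());
    std::sort(cand.begin(), cand.end(), std::greater<ScoredId>());
    if (cand.size() > k) cand.resize(k);
    return cand;
}
```

The only serial work is the merge of at most `threads × k` candidates per query, which is negligible next to scanning `ny` vectors, so the approach keeps all cores busy even for a single query.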
......@@ -19,8 +19,9 @@ Please mark all change in change log and use the issue from GitHub
 - \#1546 Move Config.cpp to config directory
 - \#1547 Rename storage/file to storage/disk and rename classes
 - \#1548 Move store/Directory to storage/Operation and add FSHandler
-- \#1649 Fix Milvus crash on old CPU
 - \#1619 Improve compact performance
+- \#1649 Fix Milvus crash on old CPU
+- \#1653 IndexFlat performance improvement for NQ < thread_number
 ## Task
......
......@@ -33,7 +33,7 @@ namespace faiss {
     if (init_heap) ha->heapify ();
     int thread_max_num = omp_get_max_threads();
-    if (ha->nh < thread_max_num) {
+    if (ha->nh < 4) {
         // omp for n2
         int all_hash_size = thread_max_num * k;
         float *value = new float[all_hash_size];
......
......@@ -152,39 +152,84 @@ static void knn_inner_product_sse (const float * x,
                         ConcurrentBitsetPtr bitset = nullptr)
 {
     size_t k = res->k;
-    size_t check_period = InterruptCallback::get_period_hint (ny * d);
-    check_period *= omp_get_max_threads();
-    for (size_t i0 = 0; i0 < nx; i0 += check_period) {
-        size_t i1 = std::min(i0 + check_period, nx);
-#pragma omp parallel for
-        for (size_t i = i0; i < i1; i++) {
-            const float * x_i = x + i * d;
-            const float * y_j = y;
-            float * __restrict simi = res->get_val(i);
-            int64_t * __restrict idxi = res->get_ids (i);
-            minheap_heapify (k, simi, idxi);
-            for (size_t j = 0; j < ny; j++) {
-                if(!bitset || !bitset->test(j)){
-                    float ip = fvec_inner_product (x_i, y_j, d);
-                    if (ip > simi[0]) {
-                        minheap_pop (k, simi, idxi);
-                        minheap_push (k, simi, idxi, ip, j);
-                    }
-                }
-                y_j += d;
-            }
-            minheap_reorder (k, simi, idxi);
-        }
-        InterruptCallback::check ();
-    }
+    size_t thread_max_num = omp_get_max_threads();
+    if (nx < 4) {
+        // omp for ny
+        size_t all_hash_size = thread_max_num * k;
+        float *value = new float[all_hash_size];
+        int64_t *labels = new int64_t[all_hash_size];
+        for (size_t i = 0; i < nx; i++) {
+            // init hash
+            for (size_t i = 0; i < all_hash_size; i++) {
+                value[i] = -1.0 / 0.0;
+            }
+            const float *x_i = x + i * d;
+#pragma omp parallel for
+            for (size_t j = 0; j < ny; j++) {
+                if(!bitset || !bitset->test(j)) {
+                    const float *y_j = y + j * d;
+                    float ip = fvec_inner_product (x_i, y_j, d);
+                    size_t thread_no = omp_get_thread_num();
+                    float * __restrict val_ = value + thread_no * k;
+                    int64_t * __restrict ids_ = labels + thread_no * k;
+                    if (ip > val_[0]) {
+                        minheap_pop (k, val_, ids_);
+                        minheap_push (k, val_, ids_, ip, j);
+                    }
+                }
+            }
+            // merge hash
+            float * __restrict simi = res->get_val(i);
+            int64_t * __restrict idxi = res->get_ids (i);
+            minheap_heapify (k, simi, idxi);
+            for (size_t i = 0; i < all_hash_size; i++) {
+                if (value[i] > simi[0]) {
+                    minheap_pop (k, simi, idxi);
+                    minheap_push (k, simi, idxi, value[i], labels[i]);
+                }
+            }
+            minheap_reorder (k, simi, idxi);
+        }
+        delete[] value;
+        delete[] labels;
+    } else {
+        size_t check_period = InterruptCallback::get_period_hint (ny * d);
+        check_period *= thread_max_num;
+        for (size_t i0 = 0; i0 < nx; i0 += check_period) {
+            size_t i1 = std::min(i0 + check_period, nx);
+#pragma omp parallel for
+            for (size_t i = i0; i < i1; i++) {
+                const float * x_i = x + i * d;
+                const float * y_j = y;
+                float * __restrict simi = res->get_val(i);
+                int64_t * __restrict idxi = res->get_ids (i);
+                minheap_heapify (k, simi, idxi);
+                for (size_t j = 0; j < ny; j++) {
+                    if(!bitset || !bitset->test(j)){
+                        float ip = fvec_inner_product (x_i, y_j, d);
+                        if (ip > simi[0]) {
+                            minheap_pop (k, simi, idxi);
+                            minheap_push (k, simi, idxi, ip, j);
+                        }
+                    }
+                    y_j += d;
+                }
+                minheap_reorder (k, simi, idxi);
+            }
+            InterruptCallback::check ();
+        }
+    }
 }
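In the inner-product hunk above, the new branch keeps the per-thread heaps in two flat arrays of `thread_max_num * k` entries; each thread works only on the `k`-entry window selected by `omp_get_thread_num()`, and every slot starts at negative infinity (`-1.0 / 0.0`) so any real score replaces it and unused slots drop out naturally during the merge. A small assumption-labelled sketch of that layout (the helper names are illustrative, not from the patch):

```cpp
#include <cstddef>
#include <limits>
#include <omp.h>

// Prime every per-thread slot with a sentinel that any real inner product beats.
inline void init_ip_slots(float* value, size_t thread_max_num, size_t k) {
    for (size_t i = 0; i < thread_max_num * k; i++) {
        value[i] = -std::numeric_limits<float>::infinity();  // same value as -1.0 / 0.0
    }
}

// Each thread's private k-entry heap lives in a disjoint window of the shared
// array, so no synchronization is needed while scanning the database.
inline float* my_heap_slot(float* value, size_t k) {
    return value + omp_get_thread_num() * k;
}
```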
static void knn_L2sqr_sse (
......@@ -196,37 +241,87 @@ static void knn_L2sqr_sse (
 {
     size_t k = res->k;
-    size_t check_period = InterruptCallback::get_period_hint (ny * d);
-    check_period *= omp_get_max_threads();
-    for (size_t i0 = 0; i0 < nx; i0 += check_period) {
-        size_t i1 = std::min(i0 + check_period, nx);
-#pragma omp parallel for
-        for (size_t i = i0; i < i1; i++) {
-            const float * x_i = x + i * d;
-            const float * y_j = y;
-            size_t j;
-            float * simi = res->get_val(i);
-            int64_t * idxi = res->get_ids (i);
-            maxheap_heapify (k, simi, idxi);
-            for (j = 0; j < ny; j++) {
-                if(!bitset || !bitset->test(j)){
-                    float disij = fvec_L2sqr (x_i, y_j, d);
-                    if (disij < simi[0]) {
-                        maxheap_pop (k, simi, idxi);
-                        maxheap_push (k, simi, idxi, disij, j);
-                    }
-                }
-                y_j += d;
-            }
-            maxheap_reorder (k, simi, idxi);
-        }
-        InterruptCallback::check ();
-    }
+    size_t thread_max_num = omp_get_max_threads();
+    if (nx < 4) {
+        // omp for ny
+        size_t all_hash_size = thread_max_num * k;
+        float *value = new float[all_hash_size];
+        int64_t *labels = new int64_t[all_hash_size];
+        for (size_t i = 0; i < nx; i++) {
+            // init hash
+            for (size_t i = 0; i < all_hash_size; i++) {
+                value[i] = 1.0 / 0.0;
+            }
+            for (size_t i = 0; i < k; i++) {
+                labels[i] = -1;
+            }
+            const float *x_i = x + i * d;
+#pragma omp parallel for
+            for (size_t j = 0; j < ny; j++) {
+                if(!bitset || !bitset->test(j)) {
+                    const float *y_j = y + j * d;
+                    float disij = fvec_L2sqr (x_i, y_j, d);
+                    size_t thread_no = omp_get_thread_num();
+                    float * __restrict val_ = value + thread_no * k;
+                    int64_t * __restrict ids_ = labels + thread_no * k;
+                    if (disij < val_[0]) {
+                        maxheap_pop (k, val_, ids_);
+                        maxheap_push (k, val_, ids_, disij, j);
+                    }
+                }
+            }
+            // merge hash
+            float * __restrict simi = res->get_val(i);
+            int64_t * __restrict idxi = res->get_ids (i);
+            memcpy(simi, value, k * sizeof(float));
+            memcpy(idxi, labels, k * sizeof(int64_t));
+            maxheap_heapify (k, simi, idxi, value, labels, k);
+            for (size_t i = k; i < all_hash_size; i++) {
+                if (value[i] < simi[0]) {
+                    maxheap_pop (k, simi, idxi);
+                    maxheap_push (k, simi, idxi, value[i], labels[i]);
+                }
+            }
+            maxheap_reorder (k, simi, idxi);
+        }
+        delete[] value;
+        delete[] labels;
+    } else {
+        size_t check_period = InterruptCallback::get_period_hint (ny * d);
+        check_period *= thread_max_num;
+        for (size_t i0 = 0; i0 < nx; i0 += check_period) {
+            size_t i1 = std::min(i0 + check_period, nx);
+#pragma omp parallel for
+            for (size_t i = i0; i < i1; i++) {
+                const float * x_i = x + i * d;
+                const float * y_j = y;
+                float * simi = res->get_val(i);
+                int64_t * idxi = res->get_ids (i);
+                maxheap_heapify (k, simi, idxi);
+                for (size_t j = 0; j < ny; j++) {
+                    if(!bitset || !bitset->test(j)){
+                        float disij = fvec_L2sqr (x_i, y_j, d);
+                        if (disij < simi[0]) {
+                            maxheap_pop (k, simi, idxi);
+                            maxheap_push (k, simi, idxi, disij, j);
+                        }
+                    }
+                    y_j += d;
+                }
+                maxheap_reorder (k, simi, idxi);
+            }
+            InterruptCallback::check ();
+        }
+    }
 }
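The L2 branch mirrors the inner-product one with the comparisons reversed: the per-thread slots start at positive infinity (`1.0 / 0.0`), a max-heap keeps the k smallest distances, and the merge first seeds the result buffer from the first k candidates (the `memcpy` calls plus `maxheap_heapify (k, simi, idxi, value, labels, k)`) before testing the remaining `thread_max_num * k - k` candidates against the heap top. The sketch below expresses the same reduction with the standard library instead of the Faiss heap helpers; it is an equivalent way to read the merge, not the code from the patch.

```cpp
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Reduce all per-thread candidates (distance, id) to the k nearest for one
// query. Entries still holding the +inf sentinel sort to the back and are
// dropped automatically.
inline std::vector<std::pair<float, int64_t>>
merge_l2_candidates(std::vector<std::pair<float, int64_t>> cand, size_t k) {
    if (cand.size() > k) {
        std::partial_sort(cand.begin(), cand.begin() + k, cand.end());  // k smallest distances first
        cand.resize(k);
    } else {
        std::sort(cand.begin(), cand.end());
    }
    return cand;
}
```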
......@@ -899,4 +994,4 @@ void pairwise_L2sqr (int64_t d,
 }
-} // namespace faiss
+} // namespace faiss
\ No newline at end of file
......@@ -281,7 +281,7 @@ void hammings_knn_hc (
     if (init_heap) ha->heapify ();
     int thread_max_num = omp_get_max_threads();
-    if (ha->nh < thread_max_num) {
+    if (ha->nh < 4) {
         // omp for n2
         int all_hash_size = thread_max_num * k;
         hamdis_t *value = new hamdis_t[all_hash_size];
......@@ -432,7 +432,7 @@ void hammings_knn_hc_1 (
     }
     int thread_max_num = omp_get_max_threads();
-    if (ha->nh < thread_max_num) {
+    if (ha->nh < 4) {
         // omp for n2
         int all_hash_size = thread_max_num * k;
         hamdis_t *value = new hamdis_t[all_hash_size];
......
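The existing binary (Hamming) branches in the last two hunks change their trigger from `ha->nh < thread_max_num` to `ha->nh < 4`, matching the new float paths; this is the "fix threshold" item in the commit message. The ny-parallel path allocates and merges one k-entry heap per thread for every query, so its bookkeeping grows with the core count; a small fixed cutoff keeps that overhead confined to the case where query-level parallelism genuinely cannot fill the machine. A tiny illustration of the per-query merge cost (not part of the patch):

```cpp
#include <cstddef>
#include <cstdio>
#include <omp.h>

int main() {
    size_t k = 16;                                   // heap size per thread (example value)
    size_t thread_max_num = omp_get_max_threads();
    size_t merged_per_query = thread_max_num * k;    // candidates merged serially per query
    std::printf("threads=%zu -> %zu candidates merged per query\n",
                thread_max_num, merged_per_query);
    return 0;
}
```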