未验证 提交 5fc875a0 编写于 作者: S shengjun.li 提交者: GitHub

#1603 modify substructure/superstructure to perfect match (#1718)

* modify substructure/superstructure to perfect match
Signed-off-by: Nshengjun.li <shengjun.li@zilliz.com>

* Update cases
Signed-off-by: Nzw <zw@zilliz.com>

* Fix case bug
Signed-off-by: Nzw <zw@zilliz.com>

* set invalid distance infinite
Signed-off-by: Nshengjun.li <shengjun.li@zilliz.com>

* Add distance cases
Signed-off-by: Nzw <zw@zilliz.com>

* Fix test cases
Signed-off-by: Nzw <zw@zilliz.com>

* Re-trigger ci
Signed-off-by: Nzw <zw@zilliz.com>

* fix wrong code
Signed-off-by: Nshengjun.li <shengjun.li@zilliz.com>

* Fix test case
Signed-off-by: Nzw <zw@zilliz.com>

* Fix case
Signed-off-by: Nzw <zw@zilliz.com>

* Fix cases
Signed-off-by: Nzw <zw@zilliz.com>
Co-authored-by: Nzw <zw@zilliz.com>
上级 02cfbf31
......@@ -41,8 +41,7 @@ void IndexBinaryFlat::reset() {
void IndexBinaryFlat::search(idx_t n, const uint8_t *x, idx_t k,
int32_t *distances, idx_t *labels, ConcurrentBitsetPtr bitset) const {
const idx_t block_size = query_batch_size;
if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto ||
metric_type == METRIC_Substructure || metric_type == METRIC_Superstructure) {
if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) {
float *D = reinterpret_cast<float*>(distances);
for (idx_t s = 0; s < n; s += block_size) {
idx_t nn = block_size;
......@@ -50,19 +49,14 @@ void IndexBinaryFlat::search(idx_t n, const uint8_t *x, idx_t k,
nn = n - s;
}
if (use_heap) {
// We see the distances and labels as heaps.
float_maxheap_array_t res = {
size_t(nn), size_t(k), labels + s * k, D + s * k
};
// We see the distances and labels as heaps.
float_maxheap_array_t res = {
size_t(nn), size_t(k), labels + s * k, D + s * k
};
binary_distence_knn_hc(metric_type, &res, x + s * code_size, xb.data(), ntotal, code_size,
/* ordered = */ true, bitset);
binary_distence_knn_hc(metric_type, &res, x + s * code_size, xb.data(), ntotal, code_size,
/* ordered = */ true, bitset);
} else {
FAISS_THROW_MSG("tanimoto_knn_mc not implemented");
}
}
if (metric_type == METRIC_Tanimoto) {
for (int i = 0; i < k * n; i++) {
......@@ -70,6 +64,19 @@ void IndexBinaryFlat::search(idx_t n, const uint8_t *x, idx_t k,
}
}
} else if (metric_type == METRIC_Substructure || metric_type == METRIC_Superstructure) {
float *D = reinterpret_cast<float*>(distances);
for (idx_t s = 0; s < n; s += block_size) {
idx_t nn = block_size;
if (s + block_size > n) {
nn = n - s;
}
// only match ids will be chosed, not to use heap
binary_distence_knn_mc(metric_type, x + s * code_size, xb.data(), nn, ntotal, k, code_size,
D + s * k, labels + s * k, bitset);
}
} else {
for (idx_t s = 0; s < n; s += block_size) {
idx_t nn = block_size;
......
......@@ -158,46 +158,176 @@ void binary_distence_knn_hc (
}
break;
default:
break;
}
}
template <class T>
static
void binary_distence_knn_mc(
int bytes_per_code,
const uint8_t * bs1,
const uint8_t * bs2,
size_t n1,
size_t n2,
size_t k,
float *distances,
int64_t *labels,
ConcurrentBitsetPtr bitset)
{
if ((bytes_per_code + sizeof(size_t) + k * sizeof(int64_t)) * n1 < size_1M) {
int thread_max_num = omp_get_max_threads();
size_t group_num = n1 * thread_max_num;
size_t *match_num = new size_t[group_num];
int64_t *match_data = new int64_t[group_num * k];
for (size_t i = 0; i < group_num; i++) {
match_num[i] = 0;
}
T *hc = new T[n1];
for (size_t i = 0; i < n1; i++) {
hc[i].set(bs1 + i * bytes_per_code, bytes_per_code);
}
#pragma omp parallel for
for (size_t j = 0; j < n2; j++) {
if(!bitset || !bitset->test(j)) {
int thread_no = omp_get_thread_num();
const uint8_t * bs2_ = bs2 + j * bytes_per_code;
for (size_t i = 0; i < n1; i++) {
if (hc[i].compute(bs2_)) {
size_t match_index = thread_no * n1 + i;
size_t &index = match_num[match_index];
if (index < k) {
match_data[match_index * k + index] = j;
index++;
}
}
}
}
}
for (size_t i = 0, ni = 0; i < n1; i++) {
size_t n_i = 0;
float *distances_i = distances + i * k;
int64_t *labels_i = labels + i * k;
for (size_t t = 0; t < thread_max_num && n_i < k; t++) {
size_t match_index = t * n1 + i;
size_t copy_num = std::min(k - n_i, match_num[match_index]);
memcpy(labels_i + n_i, match_data + match_index * k, copy_num * sizeof(int64_t));
memset(distances + n_i, 0, copy_num * sizeof(int32_t));
n_i += copy_num;
}
for (; n_i < k; n_i++) {
distances_i[n_i] = 1.0 / 0.0;
labels_i[n_i] = -1;
}
}
delete[] hc;
delete[] match_num;
delete[] match_data;
} else {
size_t *num = new size_t[n1];
for (size_t i = 0; i < n1; i++) {
num[i] = 0;
}
const size_t block_size = batch_size;
for (size_t j0 = 0; j0 < n2; j0 += block_size) {
const size_t j1 = std::min(j0 + block_size, n2);
#pragma omp parallel for
for (size_t i = 0; i < n1; i++) {
size_t num_i = num[i];
if (num_i == k) continue;
float * dis = distances + i * k;
int64_t * lab = labels + i * k;
T hc (bs1 + i * bytes_per_code, bytes_per_code);
const uint8_t * bs2_ = bs2 + j0 * bytes_per_code;
for (size_t j = j0; j < j1; j++, bs2_ += bytes_per_code) {
if(!bitset || !bitset->test(j)){
if (hc.compute (bs2_)) {
dis[num_i] = 0;
lab[num_i] = j;
if (++num_i == k) break;
}
}
}
num[i] = num_i;
}
}
for (size_t i = 0; i < n1; i++) {
float * dis = distances + i * k;
int64_t * lab = labels + i * k;
for (size_t num_i = num[i]; num_i < k; num_i++) {
dis[num_i] = 1.0 / 0.0;
lab[num_i] = -1;
}
}
delete[] num;
}
}
void binary_distence_knn_mc (
MetricType metric_type,
const uint8_t * a,
const uint8_t * b,
size_t na,
size_t nb,
size_t k,
size_t ncodes,
float *distances,
int64_t *labels,
ConcurrentBitsetPtr bitset) {
switch (metric_type) {
case METRIC_Substructure:
switch (ncodes) {
#define binary_distence_knn_hc_Substructure(ncodes) \
#define binary_distence_knn_mc_Substructure(ncodes) \
case ncodes: \
binary_distence_knn_hc<faiss::SubstructureComputer ## ncodes> \
(ncodes, ha, a, b, nb, order, true, bitset); \
binary_distence_knn_mc<faiss::SubstructureComputer ## ncodes> \
(ncodes, a, b, na, nb, k, distances, labels, bitset); \
break;
binary_distence_knn_hc_Substructure(8);
binary_distence_knn_hc_Substructure(16);
binary_distence_knn_hc_Substructure(32);
binary_distence_knn_hc_Substructure(64);
binary_distence_knn_hc_Substructure(128);
binary_distence_knn_hc_Substructure(256);
binary_distence_knn_hc_Substructure(512);
#undef binary_distence_knn_hc_Substructure
binary_distence_knn_mc_Substructure(8);
binary_distence_knn_mc_Substructure(16);
binary_distence_knn_mc_Substructure(32);
binary_distence_knn_mc_Substructure(64);
binary_distence_knn_mc_Substructure(128);
binary_distence_knn_mc_Substructure(256);
binary_distence_knn_mc_Substructure(512);
#undef binary_distence_knn_mc_Substructure
default:
binary_distence_knn_hc<faiss::SubstructureComputerDefault>
(ncodes, ha, a, b, nb, order, true, bitset);
binary_distence_knn_mc<faiss::SubstructureComputerDefault>
(ncodes, a, b, na, nb, k, distances, labels, bitset);
break;
}
break;
case METRIC_Superstructure:
switch (ncodes) {
#define binary_distence_knn_hc_Superstructure(ncodes) \
#define binary_distence_knn_mc_Superstructure(ncodes) \
case ncodes: \
binary_distence_knn_hc<faiss::SuperstructureComputer ## ncodes> \
(ncodes, ha, a, b, nb, order, true, bitset); \
binary_distence_knn_mc<faiss::SuperstructureComputer ## ncodes> \
(ncodes, a, b, na, nb, k, distances, labels, bitset); \
break;
binary_distence_knn_hc_Superstructure(8);
binary_distence_knn_hc_Superstructure(16);
binary_distence_knn_hc_Superstructure(32);
binary_distence_knn_hc_Superstructure(64);
binary_distence_knn_hc_Superstructure(128);
binary_distence_knn_hc_Superstructure(256);
binary_distence_knn_hc_Superstructure(512);
#undef binary_distence_knn_hc_Superstructure
binary_distence_knn_mc_Superstructure(8);
binary_distence_knn_mc_Superstructure(16);
binary_distence_knn_mc_Superstructure(32);
binary_distence_knn_mc_Superstructure(64);
binary_distence_knn_mc_Superstructure(128);
binary_distence_knn_mc_Superstructure(256);
binary_distence_knn_mc_Superstructure(512);
#undef binary_distence_knn_mc_Superstructure
default:
binary_distence_knn_hc<faiss::SuperstructureComputerDefault>
(ncodes, ha, a, b, nb, order, true, bitset);
binary_distence_knn_mc<faiss::SuperstructureComputerDefault>
(ncodes, a, b, na, nb, k, distances, labels, bitset);
break;
}
break;
......@@ -207,4 +337,4 @@ void binary_distence_knn_hc (
}
}
}
} // namespace faiss
......@@ -32,6 +32,27 @@ namespace faiss {
int ordered,
ConcurrentBitsetPtr bitset = nullptr);
/** Return the k matched distances for a set of binary query vectors,
* using a max heap.
* @param a queries, size ha->nh * ncodes
* @param b database, size nb * ncodes
* @param na number of queries vectors
* @param nb number of database vectors
* @param k number of the matched vectors to return
* @param ncodes size of the binary codes (bytes)
*/
void binary_distence_knn_mc (
MetricType metric_type,
const uint8_t * a,
const uint8_t * b,
size_t na,
size_t nb,
size_t k,
size_t ncodes,
float *distances,
int64_t *labels,
ConcurrentBitsetPtr bitset);
} // namespace faiss
#include <faiss/utils/jaccard-inl.h>
......
......@@ -15,13 +15,9 @@ namespace faiss {
a0 = a[0];
}
inline float compute (const uint8_t *b8) const {
inline bool compute (const uint8_t *b8) const {
const uint64_t *b = (uint64_t *)b8;
int accu_num = popcount64 (b[0] & a0);
int accu_den = popcount64 (b[0]);
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / (float)(accu_den);
return (a0 & b[0]) == a0;
}
};
......@@ -41,13 +37,9 @@ namespace faiss {
a0 = a[0]; a1 = a[1];
}
inline float compute (const uint8_t *b8) const {
inline bool compute (const uint8_t *b8) const {
const uint64_t *b = (uint64_t *)b8;
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1);
int accu_den = popcount64 (b[0]) + popcount64 (b[1]);
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / (float)(accu_den);
return (a0 & b[0]) == a0 && (a1 & b[1]) == a1;
}
};
......@@ -67,15 +59,10 @@ namespace faiss {
a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3];
}
inline float compute (const uint8_t *b8) const {
inline bool compute (const uint8_t *b8) const {
const uint64_t *b = (uint64_t *)b8;
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3);
int accu_den = popcount64 (b[0]) + popcount64 (b[1]) +
popcount64 (b[2]) + popcount64 (b[3]);
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / (float)(accu_den);
return (a0 & b[0]) == a0 && (a1 & b[1]) == a1 &&
(a2 & b[2]) == a2 && (a3 & b[3]) == a3;
}
};
......@@ -96,19 +83,12 @@ namespace faiss {
a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7];
}
inline float compute (const uint8_t *b8) const {
inline bool compute (const uint8_t *b8) const {
const uint64_t *b = (uint64_t *)b8;
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
popcount64 (b[6] & a6) + popcount64 (b[7] & a7);
int accu_den = popcount64 (b[0]) + popcount64 (b[1]) +
popcount64 (b[2]) + popcount64 (b[3]) +
popcount64 (b[4]) + popcount64 (b[5]) +
popcount64 (b[6]) + popcount64 (b[7]);
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / (float)(accu_den);
return (a0 & b[0]) == a0 && (a1 & b[1]) == a1 &&
(a2 & b[2]) == a2 && (a3 & b[3]) == a3 &&
(a4 & b[4]) == a4 && (a5 & b[5]) == a5 &&
(a6 & b[6]) == a6 && (a7 & b[7]) == a7;
}
};
......@@ -132,27 +112,16 @@ namespace faiss {
a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15];
}
inline float compute (const uint8_t *b16) const {
inline bool compute (const uint8_t *b16) const {
const uint64_t *b = (uint64_t *)b16;
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
popcount64 (b[6] & a6) + popcount64 (b[7] & a7) +
popcount64 (b[8] & a8) + popcount64 (b[9] & a9) +
popcount64 (b[10] & a10) + popcount64 (b[11] & a11) +
popcount64 (b[12] & a12) + popcount64 (b[13] & a13) +
popcount64 (b[14] & a14) + popcount64 (b[15] & a15);
int accu_den = popcount64 (b[0]) + popcount64 (b[1]) +
popcount64 (b[2]) + popcount64 (b[3]) +
popcount64 (b[4]) + popcount64 (b[5]) +
popcount64 (b[6]) + popcount64 (b[7]) +
popcount64 (b[8]) + popcount64 (b[9]) +
popcount64 (b[10]) + popcount64 (b[11]) +
popcount64 (b[12]) + popcount64 (b[13]) +
popcount64 (b[14]) + popcount64 (b[15]);
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / (float)(accu_den);
return (a0 & b[0]) == a0 && (a1 & b[1]) == a1 &&
(a2 & b[2]) == a2 && (a3 & b[3]) == a3 &&
(a4 & b[4]) == a4 && (a5 & b[5]) == a5 &&
(a6 & b[6]) == a6 && (a7 & b[7]) == a7 &&
(a8 & b[8]) == a8 && (a9 & b[9]) == a9 &&
(a10 & b[10]) == a10 && (a11 & b[11]) == a11 &&
(a12 & b[12]) == a12 && (a13 & b[13]) == a13 &&
(a14 & b[14]) == a14 && (a15 & b[15]) == a15;
}
};
......@@ -182,43 +151,24 @@ namespace faiss {
a28 = a[28]; a29 = a[29]; a30 = a[30]; a31 = a[31];
}
inline float compute (const uint8_t *b16) const {
inline bool compute (const uint8_t *b16) const {
const uint64_t *b = (uint64_t *)b16;
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
popcount64 (b[6] & a6) + popcount64 (b[7] & a7) +
popcount64 (b[8] & a8) + popcount64 (b[9] & a9) +
popcount64 (b[10] & a10) + popcount64 (b[11] & a11) +
popcount64 (b[12] & a12) + popcount64 (b[13] & a13) +
popcount64 (b[14] & a14) + popcount64 (b[15] & a15) +
popcount64 (b[16] & a16) + popcount64 (b[17] & a17) +
popcount64 (b[18] & a18) + popcount64 (b[19] & a19) +
popcount64 (b[20] & a20) + popcount64 (b[21] & a21) +
popcount64 (b[22] & a22) + popcount64 (b[23] & a23) +
popcount64 (b[24] & a24) + popcount64 (b[25] & a25) +
popcount64 (b[26] & a26) + popcount64 (b[27] & a27) +
popcount64 (b[28] & a28) + popcount64 (b[29] & a29) +
popcount64 (b[30] & a30) + popcount64 (b[31] & a31);
int accu_den = popcount64 (b[0]) + popcount64 (b[1]) +
popcount64 (b[2]) + popcount64 (b[3]) +
popcount64 (b[4]) + popcount64 (b[5]) +
popcount64 (b[6]) + popcount64 (b[7]) +
popcount64 (b[8]) + popcount64 (b[9]) +
popcount64 (b[10]) + popcount64 (b[11]) +
popcount64 (b[12]) + popcount64 (b[13]) +
popcount64 (b[14]) + popcount64 (b[15]) +
popcount64 (b[16]) + popcount64 (b[17]) +
popcount64 (b[18]) + popcount64 (b[19]) +
popcount64 (b[20]) + popcount64 (b[21]) +
popcount64 (b[22]) + popcount64 (b[23]) +
popcount64 (b[24]) + popcount64 (b[25]) +
popcount64 (b[26]) + popcount64 (b[27]) +
popcount64 (b[28]) + popcount64 (b[29]) +
popcount64 (b[30]) + popcount64 (b[31]);
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / (float)(accu_den);
return (a0 & b[0]) == a0 && (a1 & b[1]) == a1 &&
(a2 & b[2]) == a2 && (a3 & b[3]) == a3 &&
(a4 & b[4]) == a4 && (a5 & b[5]) == a5 &&
(a6 & b[6]) == a6 && (a7 & b[7]) == a7 &&
(a8 & b[8]) == a8 && (a9 & b[9]) == a9 &&
(a10 & b[10]) == a10 && (a11 & b[11]) == a11 &&
(a12 & b[12]) == a12 && (a13 & b[13]) == a13 &&
(a14 & b[14]) == a14 && (a15 & b[15]) == a15 &&
(a16 & b[16]) == a16 && (a17 & b[17]) == a17 &&
(a18 & b[18]) == a18 && (a19 & b[19]) == a19 &&
(a20 & b[20]) == a20 && (a21 & b[21]) == a21 &&
(a22 & b[22]) == a22 && (a23 & b[23]) == a23 &&
(a24 & b[24]) == a24 && (a25 & b[25]) == a25 &&
(a26 & b[26]) == a26 && (a27 & b[27]) == a27 &&
(a28 & b[28]) == a28 && (a29 & b[29]) == a29 &&
(a30 & b[30]) == a30 && (a31 & b[31]) == a31;
}
};
......@@ -260,76 +210,41 @@ namespace faiss {
a60 = a[60]; a61 = a[61]; a62 = a[62]; a63 = a[63];
}
inline float compute (const uint8_t *b16) const {
inline bool compute (const uint8_t *b16) const {
const uint64_t *b = (uint64_t *)b16;
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
popcount64 (b[6] & a6) + popcount64 (b[7] & a7) +
popcount64 (b[8] & a8) + popcount64 (b[9] & a9) +
popcount64 (b[10] & a10) + popcount64 (b[11] & a11) +
popcount64 (b[12] & a12) + popcount64 (b[13] & a13) +
popcount64 (b[14] & a14) + popcount64 (b[15] & a15) +
popcount64 (b[16] & a16) + popcount64 (b[17] & a17) +
popcount64 (b[18] & a18) + popcount64 (b[19] & a19) +
popcount64 (b[20] & a20) + popcount64 (b[21] & a21) +
popcount64 (b[22] & a22) + popcount64 (b[23] & a23) +
popcount64 (b[24] & a24) + popcount64 (b[25] & a25) +
popcount64 (b[26] & a26) + popcount64 (b[27] & a27) +
popcount64 (b[28] & a28) + popcount64 (b[29] & a29) +
popcount64 (b[30] & a30) + popcount64 (b[31] & a31) +
popcount64 (b[32] & a32) + popcount64 (b[33] & a33) +
popcount64 (b[34] & a34) + popcount64 (b[35] & a35) +
popcount64 (b[36] & a36) + popcount64 (b[37] & a37) +
popcount64 (b[38] & a38) + popcount64 (b[39] & a39) +
popcount64 (b[40] & a40) + popcount64 (b[41] & a41) +
popcount64 (b[42] & a42) + popcount64 (b[43] & a43) +
popcount64 (b[44] & a44) + popcount64 (b[45] & a45) +
popcount64 (b[46] & a46) + popcount64 (b[47] & a47) +
popcount64 (b[48] & a48) + popcount64 (b[49] & a49) +
popcount64 (b[50] & a50) + popcount64 (b[51] & a51) +
popcount64 (b[52] & a52) + popcount64 (b[53] & a53) +
popcount64 (b[54] & a54) + popcount64 (b[55] & a55) +
popcount64 (b[56] & a56) + popcount64 (b[57] & a57) +
popcount64 (b[58] & a58) + popcount64 (b[59] & a59) +
popcount64 (b[60] & a60) + popcount64 (b[61] & a61) +
popcount64 (b[62] & a62) + popcount64 (b[63] & a63);
int accu_den = popcount64 (b[0]) + popcount64 (b[1]) +
popcount64 (b[2]) + popcount64 (b[3]) +
popcount64 (b[4]) + popcount64 (b[5]) +
popcount64 (b[6]) + popcount64 (b[7]) +
popcount64 (b[8]) + popcount64 (b[9]) +
popcount64 (b[10]) + popcount64 (b[11]) +
popcount64 (b[12]) + popcount64 (b[13]) +
popcount64 (b[14]) + popcount64 (b[15]) +
popcount64 (b[16]) + popcount64 (b[17]) +
popcount64 (b[18]) + popcount64 (b[19]) +
popcount64 (b[20]) + popcount64 (b[21]) +
popcount64 (b[22]) + popcount64 (b[23]) +
popcount64 (b[24]) + popcount64 (b[25]) +
popcount64 (b[26]) + popcount64 (b[27]) +
popcount64 (b[28]) + popcount64 (b[29]) +
popcount64 (b[30]) + popcount64 (b[31]) +
popcount64 (b[32]) + popcount64 (b[33]) +
popcount64 (b[34]) + popcount64 (b[35]) +
popcount64 (b[36]) + popcount64 (b[37]) +
popcount64 (b[38]) + popcount64 (b[39]) +
popcount64 (b[40]) + popcount64 (b[41]) +
popcount64 (b[42]) + popcount64 (b[43]) +
popcount64 (b[44]) + popcount64 (b[45]) +
popcount64 (b[46]) + popcount64 (b[47]) +
popcount64 (b[48]) + popcount64 (b[49]) +
popcount64 (b[50]) + popcount64 (b[51]) +
popcount64 (b[52]) + popcount64 (b[53]) +
popcount64 (b[54]) + popcount64 (b[55]) +
popcount64 (b[56]) + popcount64 (b[57]) +
popcount64 (b[58]) + popcount64 (b[59]) +
popcount64 (b[60]) + popcount64 (b[61]) +
popcount64 (b[62]) + popcount64 (b[63]);
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / (float)(accu_den);
}
return (a0 & b[0]) == a0 && (a1 & b[1]) == a1 &&
(a2 & b[2]) == a2 && (a3 & b[3]) == a3 &&
(a4 & b[4]) == a4 && (a5 & b[5]) == a5 &&
(a6 & b[6]) == a6 && (a7 & b[7]) == a7 &&
(a8 & b[8]) == a8 && (a9 & b[9]) == a9 &&
(a10 & b[10]) == a10 && (a11 & b[11]) == a11 &&
(a12 & b[12]) == a12 && (a13 & b[13]) == a13 &&
(a14 & b[14]) == a14 && (a15 & b[15]) == a15 &&
(a16 & b[16]) == a16 && (a17 & b[17]) == a17 &&
(a18 & b[18]) == a18 && (a19 & b[19]) == a19 &&
(a20 & b[20]) == a20 && (a21 & b[21]) == a21 &&
(a22 & b[22]) == a22 && (a23 & b[23]) == a23 &&
(a24 & b[24]) == a24 && (a25 & b[25]) == a25 &&
(a26 & b[26]) == a26 && (a27 & b[27]) == a27 &&
(a28 & b[28]) == a28 && (a29 & b[29]) == a29 &&
(a30 & b[30]) == a30 && (a31 & b[31]) == a31 &&
(a32 & b[32]) == a32 && (a33 & b[33]) == a33 &&
(a34 & b[34]) == a34 && (a35 & b[35]) == a35 &&
(a36 & b[36]) == a36 && (a37 & b[37]) == a37 &&
(a38 & b[38]) == a38 && (a39 & b[39]) == a39 &&
(a40 & b[40]) == a40 && (a41 & b[41]) == a41 &&
(a42 & b[42]) == a42 && (a43 & b[43]) == a43 &&
(a44 & b[44]) == a44 && (a45 & b[45]) == a45 &&
(a46 & b[46]) == a46 && (a47 & b[47]) == a47 &&
(a48 & b[48]) == a48 && (a49 & b[49]) == a49 &&
(a50 & b[50]) == a50 && (a51 & b[51]) == a51 &&
(a52 & b[52]) == a52 && (a53 & b[53]) == a53 &&
(a54 & b[54]) == a54 && (a55 & b[55]) == a55 &&
(a56 & b[56]) == a56 && (a57 & b[57]) == a57 &&
(a58 & b[58]) == a58 && (a59 & b[59]) == a59 &&
(a60 & b[60]) == a60 && (a61 & b[61]) == a61 &&
(a62 & b[62]) == a62 && (a63 & b[63]) == a63;
}
};
......@@ -348,16 +263,14 @@ namespace faiss {
n = code_size;
}
float compute (const uint8_t *b8) const {
int accu_num = 0;
int accu_den = 0;
bool compute (const uint8_t *b8) const {
const uint64_t *b = (uint64_t *)b8;
for (int i = 0; i < n; i++) {
accu_num += popcount64(a[i] & b8[i]);
accu_den += popcount64(b8[i]);
if ((a[i] & b[i]) != a[i]) {
return false;
}
}
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / (float)(accu_den);
return true;
}
};
......
......@@ -2,7 +2,6 @@ namespace faiss {
struct SuperstructureComputer8 {
uint64_t a0;
float accu_den;
SuperstructureComputer8 () {}
......@@ -14,22 +13,17 @@ namespace faiss {
assert (code_size == 8);
const uint64_t *a = (uint64_t *)a8;
a0 = a[0];
accu_den = (float)(popcount64 (a0));
}
inline float compute (const uint8_t *b8) const {
inline bool compute (const uint8_t *b8) const {
const uint64_t *b = (uint64_t *)b8;
int accu_num = popcount64 (b[0] & a0);
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / accu_den;
return (a0 & b[0]) == b[0];
}
};
struct SuperstructureComputer16 {
uint64_t a0, a1;
float accu_den;
SuperstructureComputer16 () {}
......@@ -41,22 +35,17 @@ namespace faiss {
assert (code_size == 16);
const uint64_t *a = (uint64_t *)a8;
a0 = a[0]; a1 = a[1];
accu_den = (float)(popcount64 (a0) + popcount64 (a1));
}
inline float compute (const uint8_t *b8) const {
inline bool compute (const uint8_t *b8) const {
const uint64_t *b = (uint64_t *)b8;
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1);
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / accu_den;
return (a0 & b[0]) == b[0] && (a1 & b[1]) == b[1];
}
};
struct SuperstructureComputer32 {
uint64_t a0, a1, a2, a3;
float accu_den;
SuperstructureComputer32 () {}
......@@ -68,24 +57,18 @@ namespace faiss {
assert (code_size == 32);
const uint64_t *a = (uint64_t *)a8;
a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3];
accu_den = (float)(popcount64 (a0) + popcount64 (a1) +
popcount64 (a2) + popcount64 (a3));
}
inline float compute (const uint8_t *b8) const {
inline bool compute (const uint8_t *b8) const {
const uint64_t *b = (uint64_t *)b8;
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3);
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / accu_den;
return (a0 & b[0]) == b[0] && (a1 & b[1]) == b[1] &&
(a2 & b[2]) == b[2] && (a3 & b[3]) == b[3];
}
};
struct SuperstructureComputer64 {
uint64_t a0, a1, a2, a3, a4, a5, a6, a7;
float accu_den;
SuperstructureComputer64 () {}
......@@ -98,21 +81,14 @@ namespace faiss {
const uint64_t *a = (uint64_t *)a8;
a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3];
a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7];
accu_den = (float)(popcount64 (a0) + popcount64 (a1) +
popcount64 (a2) + popcount64 (a3) +
popcount64 (a4) + popcount64 (a5) +
popcount64 (a6) + popcount64 (a7));
}
inline float compute (const uint8_t *b8) const {
inline bool compute (const uint8_t *b8) const {
const uint64_t *b = (uint64_t *)b8;
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
popcount64 (b[6] & a6) + popcount64 (b[7] & a7);
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / accu_den;
return (a0 & b[0]) == b[0] && (a1 & b[1]) == b[1] &&
(a2 & b[2]) == b[2] && (a3 & b[3]) == b[3] &&
(a4 & b[4]) == b[4] && (a5 & b[5]) == b[5] &&
(a6 & b[6]) == b[6] && (a7 & b[7]) == b[7];
}
};
......@@ -120,7 +96,6 @@ namespace faiss {
struct SuperstructureComputer128 {
uint64_t a0, a1, a2, a3, a4, a5, a6, a7,
a8, a9, a10, a11, a12, a13, a14, a15;
float accu_den;
SuperstructureComputer128 () {}
......@@ -135,29 +110,18 @@ namespace faiss {
a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7];
a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11];
a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15];
accu_den = (float)(popcount64 (a0) + popcount64 (a1) +
popcount64 (a2) + popcount64 (a3) +
popcount64 (a4) + popcount64 (a5) +
popcount64 (a6) + popcount64 (a7) +
popcount64 (a8) + popcount64 (a9) +
popcount64 (a10) + popcount64 (a11) +
popcount64 (a12) + popcount64 (a13) +
popcount64 (a14) + popcount64 (a15));
}
inline float compute (const uint8_t *b16) const {
const uint64_t *b = (uint64_t *)b16;
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
popcount64 (b[6] & a6) + popcount64 (b[7] & a7) +
popcount64 (b[8] & a8) + popcount64 (b[9] & a9) +
popcount64 (b[10] & a10) + popcount64 (b[11] & a11) +
popcount64 (b[12] & a12) + popcount64 (b[13] & a13) +
popcount64 (b[14] & a14) + popcount64 (b[15] & a15);
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / accu_den;
inline float compute (const uint8_t *b8) const {
const uint64_t *b = (uint64_t *)b8;
return (a0 & b[0]) == b[0] && (a1 & b[1]) == b[1] &&
(a2 & b[2]) == b[2] && (a3 & b[3]) == b[3] &&
(a4 & b[4]) == b[4] && (a5 & b[5]) == b[5] &&
(a6 & b[6]) == b[6] && (a7 & b[7]) == b[7] &&
(a8 & b[8]) == b[8] && (a9 & b[9]) == b[9] &&
(a10 & b[10]) == b[10] && (a11 & b[11]) == b[11] &&
(a12 & b[12]) == b[12] && (a13 & b[13]) == b[13] &&
(a14 & b[14]) == b[14] && (a15 & b[15]) == b[15];
}
};
......@@ -167,7 +131,6 @@ namespace faiss {
a8,a9,a10,a11,a12,a13,a14,a15,
a16,a17,a18,a19,a20,a21,a22,a23,
a24,a25,a26,a27,a28,a29,a30,a31;
float accu_den;
SuperstructureComputer256 () {}
......@@ -186,45 +149,26 @@ namespace faiss {
a20 = a[20]; a21 = a[21]; a22 = a[22]; a23 = a[23];
a24 = a[24]; a25 = a[25]; a26 = a[26]; a27 = a[27];
a28 = a[28]; a29 = a[29]; a30 = a[30]; a31 = a[31];
accu_den = (float)(popcount64 (a0) + popcount64 (a1) +
popcount64 (a2) + popcount64 (a3) +
popcount64 (a4) + popcount64 (a5) +
popcount64 (a6) + popcount64 (a7) +
popcount64 (a8) + popcount64 (a9) +
popcount64 (a10) + popcount64 (a11) +
popcount64 (a12) + popcount64 (a13) +
popcount64 (a14) + popcount64 (a15) +
popcount64 (a16) + popcount64 (a17) +
popcount64 (a18) + popcount64 (a19) +
popcount64 (a20) + popcount64 (a21) +
popcount64 (a22) + popcount64 (a23) +
popcount64 (a24) + popcount64 (a25) +
popcount64 (a26) + popcount64 (a27) +
popcount64 (a28) + popcount64 (a29) +
popcount64 (a30) + popcount64 (a31));
}
inline float compute (const uint8_t *b16) const {
const uint64_t *b = (uint64_t *)b16;
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
popcount64 (b[6] & a6) + popcount64 (b[7] & a7) +
popcount64 (b[8] & a8) + popcount64 (b[9] & a9) +
popcount64 (b[10] & a10) + popcount64 (b[11] & a11) +
popcount64 (b[12] & a12) + popcount64 (b[13] & a13) +
popcount64 (b[14] & a14) + popcount64 (b[15] & a15) +
popcount64 (b[16] & a16) + popcount64 (b[17] & a17) +
popcount64 (b[18] & a18) + popcount64 (b[19] & a19) +
popcount64 (b[20] & a20) + popcount64 (b[21] & a21) +
popcount64 (b[22] & a22) + popcount64 (b[23] & a23) +
popcount64 (b[24] & a24) + popcount64 (b[25] & a25) +
popcount64 (b[26] & a26) + popcount64 (b[27] & a27) +
popcount64 (b[28] & a28) + popcount64 (b[29] & a29) +
popcount64 (b[30] & a30) + popcount64 (b[31] & a31);
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / accu_den;
inline float compute (const uint8_t *b8) const {
const uint64_t *b = (uint64_t *)b8;
return (a0 & b[0]) == b[0] && (a1 & b[1]) == b[1] &&
(a2 & b[2]) == b[2] && (a3 & b[3]) == b[3] &&
(a4 & b[4]) == b[4] && (a5 & b[5]) == b[5] &&
(a6 & b[6]) == b[6] && (a7 & b[7]) == b[7] &&
(a8 & b[8]) == b[8] && (a9 & b[9]) == b[9] &&
(a10 & b[10]) == b[10] && (a11 & b[11]) == b[11] &&
(a12 & b[12]) == b[12] && (a13 & b[13]) == b[13] &&
(a14 & b[14]) == b[14] && (a15 & b[15]) == b[15] &&
(a16 & b[16]) == b[16] && (a17 & b[17]) == b[17] &&
(a18 & b[18]) == b[18] && (a19 & b[19]) == b[19] &&
(a20 & b[20]) == b[20] && (a21 & b[21]) == b[21] &&
(a22 & b[22]) == b[22] && (a23 & b[23]) == b[23] &&
(a24 & b[24]) == b[24] && (a25 & b[25]) == b[25] &&
(a26 & b[26]) == b[26] && (a27 & b[27]) == b[27] &&
(a28 & b[28]) == b[28] && (a29 & b[29]) == b[29] &&
(a30 & b[30]) == b[30] && (a31 & b[31]) == b[31];
}
};
......@@ -238,7 +182,6 @@ namespace faiss {
a40,a41,a42,a43,a44,a45,a46,a47,
a48,a49,a50,a51,a52,a53,a54,a55,
a56,a57,a58,a59,a60,a61,a62,a63;
float accu_den;
SuperstructureComputer512 () {}
......@@ -265,85 +208,49 @@ namespace faiss {
a52 = a[52]; a53 = a[53]; a54 = a[54]; a55 = a[55];
a56 = a[56]; a57 = a[57]; a58 = a[58]; a59 = a[59];
a60 = a[60]; a61 = a[61]; a62 = a[62]; a63 = a[63];
accu_den = (float)(popcount64 (a0) + popcount64 (a1) +
popcount64 (a2) + popcount64 (a3) +
popcount64 (a4) + popcount64 (a5) +
popcount64 (a6) + popcount64 (a7) +
popcount64 (a8) + popcount64 (a9) +
popcount64 (a10) + popcount64 (a11) +
popcount64 (a12) + popcount64 (a13) +
popcount64 (a14) + popcount64 (a15) +
popcount64 (a16) + popcount64 (a17) +
popcount64 (a18) + popcount64 (a19) +
popcount64 (a20) + popcount64 (a21) +
popcount64 (a22) + popcount64 (a23) +
popcount64 (a24) + popcount64 (a25) +
popcount64 (a26) + popcount64 (a27) +
popcount64 (a28) + popcount64 (a29) +
popcount64 (a30) + popcount64 (a31) +
popcount64 (a32) + popcount64 (a33) +
popcount64 (a34) + popcount64 (a35) +
popcount64 (a36) + popcount64 (a37) +
popcount64 (a38) + popcount64 (a39) +
popcount64 (a40) + popcount64 (a41) +
popcount64 (a42) + popcount64 (a43) +
popcount64 (a44) + popcount64 (a45) +
popcount64 (a46) + popcount64 (a47) +
popcount64 (a48) + popcount64 (a49) +
popcount64 (a50) + popcount64 (a51) +
popcount64 (a52) + popcount64 (a53) +
popcount64 (a54) + popcount64 (a55) +
popcount64 (a56) + popcount64 (a57) +
popcount64 (a58) + popcount64 (a59) +
popcount64 (a60) + popcount64 (a61) +
popcount64 (a62) + popcount64 (a63));
}
inline float compute (const uint8_t *b16) const {
const uint64_t *b = (uint64_t *)b16;
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
popcount64 (b[6] & a6) + popcount64 (b[7] & a7) +
popcount64 (b[8] & a8) + popcount64 (b[9] & a9) +
popcount64 (b[10] & a10) + popcount64 (b[11] & a11) +
popcount64 (b[12] & a12) + popcount64 (b[13] & a13) +
popcount64 (b[14] & a14) + popcount64 (b[15] & a15) +
popcount64 (b[16] & a16) + popcount64 (b[17] & a17) +
popcount64 (b[18] & a18) + popcount64 (b[19] & a19) +
popcount64 (b[20] & a20) + popcount64 (b[21] & a21) +
popcount64 (b[22] & a22) + popcount64 (b[23] & a23) +
popcount64 (b[24] & a24) + popcount64 (b[25] & a25) +
popcount64 (b[26] & a26) + popcount64 (b[27] & a27) +
popcount64 (b[28] & a28) + popcount64 (b[29] & a29) +
popcount64 (b[30] & a30) + popcount64 (b[31] & a31) +
popcount64 (b[32] & a32) + popcount64 (b[33] & a33) +
popcount64 (b[34] & a34) + popcount64 (b[35] & a35) +
popcount64 (b[36] & a36) + popcount64 (b[37] & a37) +
popcount64 (b[38] & a38) + popcount64 (b[39] & a39) +
popcount64 (b[40] & a40) + popcount64 (b[41] & a41) +
popcount64 (b[42] & a42) + popcount64 (b[43] & a43) +
popcount64 (b[44] & a44) + popcount64 (b[45] & a45) +
popcount64 (b[46] & a46) + popcount64 (b[47] & a47) +
popcount64 (b[48] & a48) + popcount64 (b[49] & a49) +
popcount64 (b[50] & a50) + popcount64 (b[51] & a51) +
popcount64 (b[52] & a52) + popcount64 (b[53] & a53) +
popcount64 (b[54] & a54) + popcount64 (b[55] & a55) +
popcount64 (b[56] & a56) + popcount64 (b[57] & a57) +
popcount64 (b[58] & a58) + popcount64 (b[59] & a59) +
popcount64 (b[60] & a60) + popcount64 (b[61] & a61) +
popcount64 (b[62] & a62) + popcount64 (b[63] & a63);
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / accu_den;
}
inline bool compute (const uint8_t *b8) const {
const uint64_t *b = (uint64_t *)b8;
return (a0 & b[0]) == b[0] && (a1 & b[1]) == b[1] &&
(a2 & b[2]) == b[2] && (a3 & b[3]) == b[3] &&
(a4 & b[4]) == b[4] && (a5 & b[5]) == b[5] &&
(a6 & b[6]) == b[6] && (a7 & b[7]) == b[7] &&
(a8 & b[8]) == b[8] && (a9 & b[9]) == b[9] &&
(a10 & b[10]) == b[10] && (a11 & b[11]) == b[11] &&
(a12 & b[12]) == b[12] && (a13 & b[13]) == b[13] &&
(a14 & b[14]) == b[14] && (a15 & b[15]) == b[15] &&
(a16 & b[16]) == b[16] && (a17 & b[17]) == b[17] &&
(a18 & b[18]) == b[18] && (a19 & b[19]) == b[19] &&
(a20 & b[20]) == b[20] && (a21 & b[21]) == b[21] &&
(a22 & b[22]) == b[22] && (a23 & b[23]) == b[23] &&
(a24 & b[24]) == b[24] && (a25 & b[25]) == b[25] &&
(a26 & b[26]) == b[26] && (a27 & b[27]) == b[27] &&
(a28 & b[28]) == b[28] && (a29 & b[29]) == b[29] &&
(a30 & b[30]) == b[30] && (a31 & b[31]) == b[31] &&
(a32 & b[32]) == b[32] && (a33 & b[33]) == b[33] &&
(a34 & b[34]) == b[34] && (a35 & b[35]) == b[35] &&
(a36 & b[36]) == b[36] && (a37 & b[37]) == b[37] &&
(a38 & b[38]) == b[38] && (a39 & b[39]) == b[39] &&
(a40 & b[40]) == b[40] && (a41 & b[41]) == b[41] &&
(a42 & b[42]) == b[42] && (a43 & b[43]) == b[43] &&
(a44 & b[44]) == b[44] && (a45 & b[45]) == b[45] &&
(a46 & b[46]) == b[46] && (a47 & b[47]) == b[47] &&
(a48 & b[48]) == b[48] && (a49 & b[49]) == b[49] &&
(a50 & b[50]) == b[50] && (a51 & b[51]) == b[51] &&
(a52 & b[52]) == b[52] && (a53 & b[53]) == b[53] &&
(a54 & b[54]) == b[54] && (a55 & b[55]) == b[55] &&
(a56 & b[56]) == b[56] && (a57 & b[57]) == b[57] &&
(a58 & b[58]) == b[58] && (a59 & b[59]) == b[59] &&
(a60 & b[60]) == b[60] && (a61 & b[61]) == b[61] &&
(a62 & b[62]) == b[62] && (a63 & b[63]) == b[63];
}
};
struct SuperstructureComputerDefault {
const uint8_t *a;
int n;
float accu_den;
SuperstructureComputerDefault () {}
......@@ -354,21 +261,16 @@ namespace faiss {
void set (const uint8_t *a8, int code_size) {
a = a8;
n = code_size;
int i_accu_den = 0;
for (int i = 0; i < n; i++) {
i_accu_den += popcount64(a[i]);
}
accu_den = (float)i_accu_den;
}
float compute (const uint8_t *b8) const {
int accu_num = 0;
bool compute (const uint8_t *b8) const {
const uint64_t *b = (uint64_t *)b8;
for (int i = 0; i < n; i++) {
accu_num += popcount64(a[i] & b8[i]);
if ((a[i] & b[i]) != b[i]) {
return false;
}
}
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / accu_den;
return true;
}
};
......
import pdb
import copy
import struct
from random import sample
......@@ -675,7 +674,36 @@ class TestSearchBase:
status, result = connect.search_vectors(substructure_collection, top_k, query_vecs, params=search_param)
logging.getLogger().info(status)
logging.getLogger().info(result)
assert abs(result[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon
assert result[0][0].id == -1
def test_search_distance_substructure_flat_index_B(self, connect, substructure_collection):
'''
target: search ip_collection, and check the result: distance
method: compare the return distance value with value computed with SUB
expected: the return distance equals to the computed value
'''
# from scipy.spatial import distance
top_k = 3
nprobe = 512
int_vectors, vectors, ids = self.init_binary_data(connect, substructure_collection, nb=2)
index_type = IndexType.FLAT
index_param = {
"nlist": 16384
}
connect.create_index(substructure_collection, index_type, index_param)
logging.getLogger().info(connect.describe_collection(substructure_collection))
logging.getLogger().info(connect.describe_index(substructure_collection))
query_int_vectors, query_vecs = gen_binary_sub_vectors(int_vectors, 2)
search_param = get_search_param(index_type)
status, result = connect.search_vectors(substructure_collection, top_k, query_vecs, params=search_param)
logging.getLogger().info(status)
logging.getLogger().info(result)
assert result[0][0].distance <= epsilon
assert result[0][0].id == ids[0]
assert result[1][0].distance <= epsilon
assert result[1][0].id == ids[1]
assert result[0][1].id == -1
assert result[1][1].id == -1
def test_search_distance_superstructure_flat_index(self, connect, superstructure_collection):
'''
......@@ -701,7 +729,36 @@ class TestSearchBase:
status, result = connect.search_vectors(superstructure_collection, top_k, query_vecs, params=search_param)
logging.getLogger().info(status)
logging.getLogger().info(result)
assert abs(result[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon
assert result[0][0].id == -1
def test_search_distance_superstructure_flat_index_B(self, connect, superstructure_collection):
'''
target: search ip_collection, and check the result: distance
method: compare the return distance value with value computed with SUPER
expected: the return distance equals to the computed value
'''
# from scipy.spatial import distance
top_k = 3
nprobe = 512
int_vectors, vectors, ids = self.init_binary_data(connect, superstructure_collection, nb=2)
index_type = IndexType.FLAT
index_param = {
"nlist": 16384
}
connect.create_index(superstructure_collection, index_type, index_param)
logging.getLogger().info(connect.describe_collection(superstructure_collection))
logging.getLogger().info(connect.describe_index(superstructure_collection))
query_int_vectors, query_vecs = gen_binary_super_vectors(int_vectors, 2)
search_param = get_search_param(index_type)
status, result = connect.search_vectors(superstructure_collection, top_k, query_vecs, params=search_param)
logging.getLogger().info(status)
logging.getLogger().info(result)
assert result[0][0].id in ids
assert result[0][0].distance <= epsilon
assert result[1][0].id in ids
assert result[1][0].distance <= epsilon
assert result[0][2].id == -1
assert result[1][2].id == -1
def test_search_distance_tanimoto_flat_index(self, connect, tanimoto_collection):
'''
......
......@@ -67,6 +67,33 @@ def superstructure(x, y):
return 1 - np.double(np.bitwise_and(x, y).sum()) / np.count_nonzero(x)
def gen_binary_sub_vectors(vectors, length):
raw_vectors = []
binary_vectors = []
dim = len(vectors[0])
for i in range(length):
raw_vector = [0 for i in range(dim)]
vector = vectors[i]
for index, j in enumerate(vector):
if j == 1:
raw_vector[index] = 1
raw_vectors.append(raw_vector)
binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist()))
return raw_vectors, binary_vectors
def gen_binary_super_vectors(vectors, length):
raw_vectors = []
binary_vectors = []
dim = len(vectors[0])
for i in range(length):
cnt_1 = np.count_nonzero(vectors[i])
raw_vector = [1 for i in range(dim)]
raw_vectors.append(raw_vector)
binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist()))
return raw_vectors, binary_vectors
def gen_single_vector(dim):
return [[random.random() for _ in range(dim)]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册