未验证 提交 4088f5e9 编写于 作者: S shengjun.li 提交者: GitHub

#1603 BinaryFlat add 2 Metrics: Substructure and Superstructure (#1647)

* add substructure & superstructure
Signed-off-by: Nshengjun.li <shengjun.li@zilliz.com>

* support c++sdk by lin.xiaojun
Signed-off-by: Nshengjun.li <shengjun.li@zilliz.com>

* optimize omp for BinaryFlat; add tutorial
Signed-off-by: Nshengjun.li <shengjun.li@zilliz.com>
上级 b1f0e51d
......@@ -10,6 +10,7 @@ Please mark all change in change log and use the issue from GitHub
- \#1635 Vectors can be returned by searching after vectors deleted if `cache_insert_data` set true
## Feature
- \#1603 BinaryFlat add 2 Metric: Substructure and Superstructure
## Improvement
- \#1537 Optimize raw vector and uids read/write
......
......@@ -12,6 +12,7 @@
#include "db/Utils.h"
#include <fiu-local.h>
#include <boost/filesystem.hpp>
#include <chrono>
#include <mutex>
......@@ -230,6 +231,8 @@ bool
IsBinaryMetricType(int32_t metric_type) {
return (metric_type == (int32_t)engine::MetricType::HAMMING) ||
(metric_type == (int32_t)engine::MetricType::JACCARD) ||
(metric_type == (int32_t)engine::MetricType::SUBSTRUCTURE) ||
(metric_type == (int32_t)engine::MetricType::SUPERSTRUCTURE) ||
(metric_type == (int32_t)engine::MetricType::TANIMOTO);
}
......
......@@ -39,12 +39,14 @@ enum class EngineType {
};
enum class MetricType {
L2 = 1, // Euclidean Distance
IP = 2, // Cosine Similarity
HAMMING = 3, // Hamming Distance
JACCARD = 4, // Jaccard Distance
TANIMOTO = 5, // Tanimoto Distance
MAX_VALUE = TANIMOTO,
L2 = 1, // Euclidean Distance
IP = 2, // Cosine Similarity
HAMMING = 3, // Hamming Distance
JACCARD = 4, // Jaccard Distance
TANIMOTO = 5, // Tanimoto Distance
SUBSTRUCTURE = 6, // Substructure Distance
SUPERSTRUCTURE = 7, // Superstructure Distance
MAX_VALUE = SUPERSTRUCTURE
};
class ExecutionEngine {
......
......@@ -60,6 +60,12 @@ MappingMetricType(MetricType metric_type, milvus::json& conf) {
case MetricType::TANIMOTO:
conf[knowhere::Metric::TYPE] = knowhere::Metric::TANIMOTO;
break;
case MetricType::SUBSTRUCTURE:
conf[knowhere::Metric::TYPE] = knowhere::Metric::SUBSTRUCTURE;
break;
case MetricType::SUPERSTRUCTURE:
conf[knowhere::Metric::TYPE] = knowhere::Metric::SUPERSTRUCTURE;
break;
default:
return Status(DB_ERROR, "Unsupported metric type");
}
......
......@@ -33,6 +33,13 @@ GetMetricType(const std::string& type) {
if (type == Metric::HAMMING) {
return faiss::METRIC_Hamming;
}
if (type == Metric::SUBSTRUCTURE) {
return faiss::METRIC_Substructure;
}
if (type == Metric::SUPERSTRUCTURE) {
return faiss::METRIC_Superstructure;
}
KNOWHERE_THROW_MSG("Metric type is invalid");
}
......
......@@ -52,6 +52,8 @@ constexpr const char* L2 = "L2";
constexpr const char* HAMMING = "HAMMING";
constexpr const char* JACCARD = "JACCARD";
constexpr const char* TANIMOTO = "TANIMOTO";
constexpr const char* SUBSTRUCTURE = "SUBSTRUCTURE";
constexpr const char* SUPERSTRUCTURE = "SUPERSTRUCTURE";
} // namespace Metric
extern faiss::MetricType
......
......@@ -52,6 +52,8 @@ enum MetricType {
METRIC_Jaccard,
METRIC_Tanimoto,
METRIC_Hamming,
METRIC_Substructure, ///< Tversky case alpha = 0, beta = 1
METRIC_Superstructure, ///< Tversky case alpha = 1, beta = 0
/// some additional metrics defined in scipy.spatial.distance
METRIC_Canberra = 20,
......
......@@ -13,8 +13,8 @@
#include <cmath>
#include <cstring>
#include <faiss/utils/BinaryDistance.h>
#include <faiss/utils/hamming.h>
#include <faiss/utils/jaccard.h>
#include <faiss/utils/utils.h>
#include <faiss/utils/Heap.h>
#include <faiss/impl/FaissAssert.h>
......@@ -41,8 +41,9 @@ void IndexBinaryFlat::reset() {
void IndexBinaryFlat::search(idx_t n, const uint8_t *x, idx_t k,
int32_t *distances, idx_t *labels, ConcurrentBitsetPtr bitset) const {
const idx_t block_size = query_batch_size;
if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) {
float *D = new float[k * n];
if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto ||
metric_type == METRIC_Substructure || metric_type == METRIC_Superstructure) {
float *D = reinterpret_cast<float*>(distances);
for (idx_t s = 0; s < n; s += block_size) {
idx_t nn = block_size;
if (s + block_size > n) {
......@@ -56,7 +57,7 @@ void IndexBinaryFlat::search(idx_t n, const uint8_t *x, idx_t k,
size_t(nn), size_t(k), labels + s * k, D + s * k
};
jaccard_knn_hc(&res, x + s * code_size, xb.data(), ntotal, code_size,
binary_distence_knn_hc(metric_type, &res, x + s * code_size, xb.data(), ntotal, code_size,
/* ordered = */ true, bitset);
} else {
......@@ -68,8 +69,7 @@ void IndexBinaryFlat::search(idx_t n, const uint8_t *x, idx_t k,
D[i] = -log2(1-D[i]);
}
}
memcpy(distances, D, sizeof(float) * n * k);
delete [] D;
} else {
for (idx_t s = 0; s < n; s += block_size) {
idx_t nn = block_size;
......
......@@ -16,8 +16,8 @@
#include <memory>
#include <cmath>
#include <faiss/utils/BinaryDistance.h>
#include <faiss/utils/hamming.h>
#include <faiss/utils/jaccard.h>
#include <faiss/utils/utils.h>
#include <faiss/utils/Heap.h>
......@@ -326,6 +326,9 @@ void IndexBinaryIVF::train(idx_t n, const uint8_t *x) {
if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) {
index_tmp = IndexFlat(d, METRIC_Jaccard);
} else if (metric_type == METRIC_Substructure || metric_type == METRIC_Superstructure) {
// unsupported
FAISS_THROW_MSG("IVF not to support Substructure and Superstructure.");
} else {
index_tmp = IndexFlat(d, METRIC_L2);
}
......@@ -431,10 +434,10 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner {
};
template<class JaccardComputer, bool store_pairs>
template<class DistanceComputer, bool store_pairs>
struct IVFBinaryScannerJaccard: BinaryInvertedListScanner {
JaccardComputer hc;
DistanceComputer hc;
size_t code_size;
IVFBinaryScannerJaccard (size_t code_size): code_size (code_size)
......@@ -464,7 +467,7 @@ struct IVFBinaryScannerJaccard: BinaryInvertedListScanner {
size_t nup = 0;
for (size_t j = 0; j < n; j++) {
if(!bitset || !bitset->test(ids[j])){
float dis = hc.jaccard (codes);
float dis = hc.compute (codes);
if (dis < psimi[0]) {
heap_pop<C> (k, psimi, idxi);
......@@ -518,6 +521,8 @@ BinaryInvertedListScanner *select_IVFBinaryScannerJaccard (size_t code_size) {
HANDLE_CS(32)
HANDLE_CS(64)
HANDLE_CS(128)
HANDLE_CS(256)
HANDLE_CS(512)
#undef HANDLE_CS
default:
return new IVFBinaryScannerJaccard<JaccardComputerDefault,
......@@ -525,7 +530,6 @@ BinaryInvertedListScanner *select_IVFBinaryScannerJaccard (size_t code_size) {
}
}
void search_knn_hamming_heap(const IndexBinaryIVF& ivf,
size_t n,
const uint8_t *x,
......@@ -619,16 +623,17 @@ void search_knn_hamming_heap(const IndexBinaryIVF& ivf,
}
void search_knn_jaccard_heap(const IndexBinaryIVF& ivf,
size_t n,
const uint8_t *x,
idx_t k,
const idx_t *keys,
const float * coarse_dis,
float *distances, idx_t *labels,
bool store_pairs,
const IVFSearchParameters *params,
ConcurrentBitsetPtr bitset = nullptr)
void search_knn_binary_dis_heap(const IndexBinaryIVF& ivf,
size_t n,
const uint8_t *x,
idx_t k,
const idx_t *keys,
const float * coarse_dis,
float *distances,
idx_t *labels,
bool store_pairs,
const IVFSearchParameters *params,
ConcurrentBitsetPtr bitset = nullptr)
{
long nprobe = params ? params->nprobe : ivf.nprobe;
long max_codes = params ? params->max_codes : ivf.max_codes;
......@@ -642,7 +647,7 @@ void search_knn_jaccard_heap(const IndexBinaryIVF& ivf,
#pragma omp parallel if(n > 1) reduction(+: nlistv, ndis, nheap)
{
std::unique_ptr<BinaryInvertedListScanner> scanner
(ivf.get_InvertedListScannerJaccard (store_pairs));
(ivf.get_InvertedListScanner(store_pairs));
#pragma omp for
for (size_t i = 0; i < n; i++) {
......@@ -843,20 +848,24 @@ void search_knn_hamming_count_1 (
BinaryInvertedListScanner *IndexBinaryIVF::get_InvertedListScanner
(bool store_pairs) const
{
if (store_pairs) {
return select_IVFBinaryScannerL2<true> (code_size);
} else {
return select_IVFBinaryScannerL2<false> (code_size);
}
}
BinaryInvertedListScanner *IndexBinaryIVF::get_InvertedListScannerJaccard
(bool store_pairs) const
{
if (store_pairs) {
return select_IVFBinaryScannerJaccard<true> (code_size);
} else {
return select_IVFBinaryScannerJaccard<false> (code_size);
switch (metric_type) {
case METRIC_Jaccard:
case METRIC_Tanimoto:
if (store_pairs) {
return select_IVFBinaryScannerJaccard<true> (code_size);
} else {
return select_IVFBinaryScannerJaccard<false> (code_size);
}
case METRIC_Substructure:
case METRIC_Superstructure:
// unsupported
return nullptr;
default:
if (store_pairs) {
return select_IVFBinaryScannerL2<true>(code_size);
} else {
return select_IVFBinaryScannerL2<false>(code_size);
}
}
}
......@@ -874,9 +883,9 @@ void IndexBinaryIVF::search_preassigned(idx_t n, const uint8_t *x, idx_t k,
float *D = new float[k * n];
float *c_dis = new float [n * nprobe];
memcpy(c_dis, coarse_dis, sizeof(float) * n * nprobe);
search_knn_jaccard_heap (*this, n, x, k, idx, c_dis ,
D, labels, store_pairs,
params, bitset);
search_knn_binary_dis_heap(*this, n, x, k, idx, c_dis ,
D, labels, store_pairs,
params, bitset);
if (metric_type == METRIC_Tanimoto) {
for (int i = 0; i < k * n; i++) {
D[i] = -log2(1-D[i]);
......@@ -888,6 +897,8 @@ void IndexBinaryIVF::search_preassigned(idx_t n, const uint8_t *x, idx_t k,
} else {
//not implemented
}
} else if (metric_type == METRIC_Substructure || metric_type == METRIC_Superstructure) {
// unsupported
} else {
if (use_heap) {
search_knn_hamming_heap (*this, n, x, k, idx, coarse_dis,
......
......@@ -112,9 +112,6 @@ struct IndexBinaryIVF : IndexBinary {
virtual BinaryInvertedListScanner *get_InvertedListScanner (
bool store_pairs=false) const;
virtual BinaryInvertedListScanner *get_InvertedListScannerJaccard (
bool store_pairs=false) const;
/** assign the vectors, then call search_preassign */
void search(idx_t n, const uint8_t *x, idx_t k, int32_t *distances, idx_t *labels,
ConcurrentBitsetPtr bitset = nullptr) const override;
......
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <cstdio>
#include <cstdlib>
#include <faiss/IndexBinaryFlat.h>
#include <sys/time.h>
#include <unistd.h>
// #define TEST_HAMMING
/**
 * Elapsed time between two timestamps, in milliseconds.
 *
 * @param end    later timestamp
 * @param start  earlier timestamp
 * @return whole-second difference scaled to ms plus the microsecond
 *         difference divided (integer division) by 1000
 */
long int getTime(timeval end, timeval start) {
    const long int seconds_part_ms = 1000 * (end.tv_sec - start.tv_sec);
    const long int micros_part_ms = (end.tv_usec - start.tv_usec) / 1000;
    return seconds_part_ms + micros_part_ms;
}
/**
 * Tutorial / micro-benchmark for faiss::IndexBinaryFlat.
 *
 * Fills a database and a query set with pseudo-random bits, adds the database
 * to a flat binary index (Hamming when TEST_HAMMING is defined, Jaccard
 * otherwise) and prints the wall-clock time of searching the first
 * 1, 2, ..., nq queries.
 *
 * @return 0 on completion
 */
int main() {
    // freopen("0.txt", "w", stdout);

    const size_t d = 128;          // dimension, in bits
    const size_t nb = 40000000;    // database size
    const size_t nq = 10;          // nb of queries

    // A binary vector of d bits is stored in d / 8 bytes. Sizing the buffers
    // with d bytes per vector would allocate and randomise 8x more memory
    // than the index consumes.
    const size_t code_size = d / 8;

    uint8_t *xb = new uint8_t[code_size * nb];
    uint8_t *xq = new uint8_t[code_size * nq];

    // skip 0
    lrand48();

    // Both buffers are filled one 32-bit word at a time, so the word counts
    // are derived from sizeof(int32_t) for the database and the queries alike.
    const size_t xb_words = code_size * nb / sizeof(int32_t);
    for (size_t i = 0; i < xb_words; i++) {
        ((int32_t*)xb)[i] = lrand48();
    }

    const size_t xq_words = code_size * nq / sizeof(int32_t);
    for (size_t i = 0; i < xq_words; i++) {
        ((int32_t*)xq)[i] = lrand48();
    }

#ifdef TEST_HAMMING
    printf("test haming\n");
    faiss::IndexBinaryFlat index(d, faiss::MetricType::METRIC_Hamming);
#else
    faiss::IndexBinaryFlat index(d, faiss::MetricType::METRIC_Jaccard);
#endif
    index.add(nb, xb);
    printf("ntotal = %ld d = %d\n", index.ntotal, index.d);

    int k = 10;

#if 0
    {   // sanity check: search 5 first vectors of xb
        int64_t *I = new int64_t[k * 5];
        int32_t *D = new int32_t[k * 5];
        float *d_float = reinterpret_cast<float*>(D);

        index.search(5, xb, k, D, I);

        // print results
        for(int i = 0; i < 5; i++) {
            for(int j = 0; j < k; j++)
#ifdef TEST_HAMMING
                printf("%8ld %d\n", I[i * k + j], D[i * k + j]);
#else
                printf("%8ld %.08f\n", I[i * k + j], d_float[i * k + j]);
#endif
            printf("\n");
        }

        delete [] I;
        delete [] D;
    }
#endif

    {   // search xq
        int64_t *I = new int64_t[k * nq];
        int32_t *D = new int32_t[k * nq];
        // Float view of the distance buffer, used only by the disabled
        // printing code below for the non-Hamming metric.
        float *d_float = reinterpret_cast<float*>(D);

        // Time a search over the first `loop` queries, for loop = 1..nq.
        const int max_queries = static_cast<int>(nq);
        for (int loop = 1; loop <= max_queries; loop ++) {
            timeval t0;
            gettimeofday(&t0, 0);

            index.search(loop, xq, k, D, I);

            timeval t1;
            gettimeofday(&t1, 0);
            printf("search nq %d time %ldms\n", loop, getTime(t1,t0));
#if 0
            for (int i = 0; i < loop; i++) {
                for(int j = 0; j < k; j++)
#ifdef TEST_HAMMING
                    printf("%8ld %d\n", I[i * k + j], D[i * k + j]);
#else
                    printf("%8ld %.08f\n", I[j + i * k], d_float[j + i * k]);
#endif
                printf("\n");
            }
#endif
        }

        delete [] I;
        delete [] D;
    }

    delete [] xb;
    delete [] xq;

    return 0;
}
......@@ -5,7 +5,7 @@
-include ../../makefile.inc
CPU_TARGETS = 1-Flat 2-IVFFlat 3-IVFPQ
CPU_TARGETS = 1-Flat 2-IVFFlat 3-IVFPQ 9-BinaryFlat
GPU_TARGETS = 6-RUN 7-GPU 8-GPU
default: cpu
......@@ -17,7 +17,7 @@ cpu: $(CPU_TARGETS)
gpu: $(GPU_TARGETS)
%: %.cpp ../../libfaiss.a
$(CXX) $(CXXFLAGS) $(CPPFLAGS) -o $@ $^ $(LDFLAGS) -I../../include $(LIBS)
$(CXX) $(CXXFLAGS) $(CPPFLAGS) -o $@ $^ $(LDFLAGS) -I../../ $(LIBS)
clean:
rm -f $(CPU_TARGETS) $(GPU_TARGETS)
......
#include <faiss/utils/BinaryDistance.h>
#include <vector>
#include <memory>
#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <assert.h>
#include <limits.h>
#include <omp.h>
#include <faiss/utils/Heap.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/utils.h>
namespace faiss {
size_t batch_size = 65536;
template <class T>
static
void binary_distence_knn_hc(
int bytes_per_code,
float_maxheap_array_t * ha,
const uint8_t * bs1,
const uint8_t * bs2,
size_t n2,
bool order = true,
bool init_heap = true,
ConcurrentBitsetPtr bitset = nullptr)
{
size_t k = ha->k;
if (init_heap) ha->heapify ();
int thread_max_num = omp_get_max_threads();
if (ha->nh < thread_max_num) {
// omp for n2
int all_hash_size = thread_max_num * k;
float *value = new float[all_hash_size];
int64_t *labels = new int64_t[all_hash_size];
for (int i = 0; i < ha->nh; i++) {
T hc (bs1 + i * bytes_per_code, bytes_per_code);
// init hash
for (int i = 0; i < all_hash_size; i++) {
value[i] = 1.0 / 0.0;
}
#pragma omp parallel for
for (size_t j = 0; j < n2; j++) {
if(!bitset || !bitset->test(j)) {
const uint8_t * bs2_ = bs2 + j * bytes_per_code;
tadis_t dis = hc.compute (bs2_);
int thread_no = omp_get_thread_num();
float * __restrict val_ = value + thread_no * k;
int64_t * __restrict ids_ = labels + thread_no * k;
if (dis < val_[0]) {
faiss::maxheap_pop<tadis_t> (k, val_, ids_);
faiss::maxheap_push<tadis_t> (k, val_, ids_, dis, j);
}
}
}
// merge hash
tadis_t * __restrict bh_val_ = ha->val + i * k;
int64_t * __restrict bh_ids_ = ha->ids + i * k;
for (int i = 0; i < all_hash_size; i++) {
if (value[i] < bh_val_[0]) {
faiss::maxheap_pop<tadis_t> (k, bh_val_, bh_ids_);
faiss::maxheap_push<tadis_t> (k, bh_val_, bh_ids_, value[i], labels[i]);
}
}
}
delete[] value;
delete[] labels;
} else {
const size_t block_size = batch_size;
for (size_t j0 = 0; j0 < n2; j0 += block_size) {
const size_t j1 = std::min(j0 + block_size, n2);
#pragma omp parallel for
for (size_t i = 0; i < ha->nh; i++) {
T hc (bs1 + i * bytes_per_code, bytes_per_code);
const uint8_t * bs2_ = bs2 + j0 * bytes_per_code;
tadis_t dis;
tadis_t * __restrict bh_val_ = ha->val + i * k;
int64_t * __restrict bh_ids_ = ha->ids + i * k;
size_t j;
for (j = j0; j < j1; j++, bs2_+= bytes_per_code) {
if(!bitset || !bitset->test(j)){
dis = hc.compute (bs2_);
if (dis < bh_val_[0]) {
faiss::maxheap_pop<tadis_t> (k, bh_val_, bh_ids_);
faiss::maxheap_push<tadis_t> (k, bh_val_, bh_ids_, dis, j);
}
}
}
}
}
}
if (order) ha->reorder ();
}
/**
 * Dispatch a binary k-NN heap search to the distance computer that matches
 * the metric and the code size.
 *
 * Jaccard and Tanimoto share the Jaccard computers; Substructure and
 * Superstructure have their own families. Code sizes 8, 16, 32, 64, 128, 256
 * and 512 select a size-specific computer, any other size selects the
 * family's Default computer. Any other metric type is ignored and the heaps
 * are left untouched.
 *
 * @param metric_type  metric selecting the computer family
 * @param ha           result heaps
 * @param a            queries
 * @param b            database
 * @param nb           number of database codes
 * @param ncodes       size of one code, in bytes
 * @param order        forwarded as the `order` flag of the search
 * @param bitset       forwarded to the search unchanged
 */
void binary_distence_knn_hc (
        MetricType metric_type,
        float_maxheap_array_t * ha,
        const uint8_t * a,
        const uint8_t * b,
        size_t nb,
        size_t ncodes,
        int order,
        ConcurrentBitsetPtr bitset)
{
// One `case` running the search with the size-specific computer
// faiss::<Family>Computer<N>.
#define KNN_HC_FIXED_SIZE(Family, N) \
    case N: \
        binary_distence_knn_hc<faiss::Family ## Computer ## N> \
            (N, ha, a, b, nb, order, true, bitset); \
        break;

// Complete switch over the code size for one computer family, falling back
// to faiss::<Family>ComputerDefault for sizes without a dedicated computer.
#define KNN_HC_BY_CODE_SIZE(Family) \
    switch (ncodes) { \
        KNN_HC_FIXED_SIZE(Family, 8) \
        KNN_HC_FIXED_SIZE(Family, 16) \
        KNN_HC_FIXED_SIZE(Family, 32) \
        KNN_HC_FIXED_SIZE(Family, 64) \
        KNN_HC_FIXED_SIZE(Family, 128) \
        KNN_HC_FIXED_SIZE(Family, 256) \
        KNN_HC_FIXED_SIZE(Family, 512) \
        default: \
            binary_distence_knn_hc<faiss::Family ## ComputerDefault> \
                (ncodes, ha, a, b, nb, order, true, bitset); \
            break; \
    }

    switch (metric_type) {
    case METRIC_Jaccard:
    case METRIC_Tanimoto:
        KNN_HC_BY_CODE_SIZE(Jaccard)
        break;

    case METRIC_Substructure:
        KNN_HC_BY_CODE_SIZE(Substructure)
        break;

    case METRIC_Superstructure:
        KNN_HC_BY_CODE_SIZE(Superstructure)
        break;

    default:
        break;
    }

#undef KNN_HC_BY_CODE_SIZE
#undef KNN_HC_FIXED_SIZE
}
}
#ifndef FAISS_JACCARD_H
#define FAISS_JACCARD_H
#ifndef FAISS_BINARY_DISTANCE_H
#define FAISS_BINARY_DISTANCE_H
#include "faiss/Index.h"
#include <faiss/utils/hamming.h>
......@@ -7,14 +9,12 @@
#include <faiss/utils/Heap.h>
/* The Jaccard distance type */
/* The binary distance type */
typedef float tadis_t;
namespace faiss {
extern size_t jaccard_batch_size;
/** Return the k smallest Jaccard distances for a set of binary query vectors,
/** Return the k smallest distances for a set of binary query vectors,
* using a max heap.
* @param a queries, size ha->nh * ncodes
* @param b database, size nb * ncodes
......@@ -22,7 +22,8 @@ namespace faiss {
* @param ncodes size of the binary codes (bytes)
* @param ordered if != 0: order the results by decreasing distance
* (may be bottleneck for k/n > 0.01) */
void jaccard_knn_hc (
void binary_distence_knn_hc (
MetricType metric_type,
float_maxheap_array_t * ha,
const uint8_t * a,
const uint8_t * b,
......@@ -31,8 +32,10 @@ namespace faiss {
int ordered,
ConcurrentBitsetPtr bitset = nullptr);
} //namespace faiss
} // namespace faiss
#include <faiss/utils/jaccard-inl.h>
#include <faiss/utils/substructure-inl.h>
#include <faiss/utils/superstructure-inl.h>
#endif //FAISS_JACCARD_H
#endif // FAISS_BINARY_DISTANCE_H
......@@ -33,6 +33,7 @@
#include <math.h>
#include <assert.h>
#include <limits.h>
#include <omp.h>
#include <faiss/utils/Heap.h>
#include <faiss/impl/FaissAssert.h>
......@@ -279,28 +280,71 @@ void hammings_knn_hc (
size_t k = ha->k;
if (init_heap) ha->heapify ();
const size_t block_size = hamming_batch_size;
for (size_t j0 = 0; j0 < n2; j0 += block_size) {
const size_t j1 = std::min(j0 + block_size, n2);
int thread_max_num = omp_get_max_threads();
if (ha->nh < thread_max_num) {
// omp for n2
int all_hash_size = thread_max_num * k;
hamdis_t *value = new hamdis_t[all_hash_size];
int64_t *labels = new int64_t[all_hash_size];
for (int i = 0; i < ha->nh; i++) {
HammingComputer hc (bs1 + i * bytes_per_code, bytes_per_code);
// init hash
for (int i = 0; i < all_hash_size; i++) {
value[i] = 0x7fffffff;
}
#pragma omp parallel for
for (size_t i = 0; i < ha->nh; i++) {
HammingComputer hc (bs1 + i * bytes_per_code, bytes_per_code);
const uint8_t * bs2_ = bs2 + j0 * bytes_per_code;
hamdis_t dis;
hamdis_t * __restrict bh_val_ = ha->val + i * k;
int64_t * __restrict bh_ids_ = ha->ids + i * k;
size_t j;
for (j = j0; j < j1; j++, bs2_+= bytes_per_code) {
if(!bitset || !bitset->test(j)){
dis = hc.hamming (bs2_);
if (dis < bh_val_[0]) {
for (size_t j = 0; j < n2; j++) {
if(!bitset || !bitset->test(j)) {
const uint8_t * bs2_ = bs2 + j * bytes_per_code;
hamdis_t dis = hc.hamming (bs2_);
int thread_no = omp_get_thread_num();
hamdis_t * __restrict val_ = value + thread_no * k;
int64_t * __restrict ids_ = labels + thread_no * k;
if (dis < val_[0]) {
faiss::maxheap_pop<hamdis_t> (k, val_, ids_);
faiss::maxheap_push<hamdis_t> (k, val_, ids_, dis, j);
}
}
}
// merge hash
hamdis_t * __restrict bh_val_ = ha->val + i * k;
int64_t * __restrict bh_ids_ = ha->ids + i * k;
for (int i = 0; i < all_hash_size; i++) {
if (value[i] < bh_val_[0]) {
faiss::maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, value[i], labels[i]);
}
}
}
delete[] value;
delete[] labels;
} else {
const size_t block_size = hamming_batch_size;
for (size_t j0 = 0; j0 < n2; j0 += block_size) {
const size_t j1 = std::min(j0 + block_size, n2);
#pragma omp parallel for
for (size_t i = 0; i < ha->nh; i++) {
HammingComputer hc (bs1 + i * bytes_per_code, bytes_per_code);
const uint8_t * bs2_ = bs2 + j0 * bytes_per_code;
hamdis_t dis;
hamdis_t * __restrict bh_val_ = ha->val + i * k;
int64_t * __restrict bh_ids_ = ha->ids + i * k;
size_t j;
for (j = j0; j < j1; j++, bs2_+= bytes_per_code) {
if(!bitset || !bitset->test(j)){
dis = hc.hamming (bs2_);
if (dis < bh_val_[0]) {
faiss::maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
}
}
}
}
}
}
}
if (order) ha->reorder ();
}
......@@ -387,22 +431,64 @@ void hammings_knn_hc_1 (
ha->heapify ();
}
int thread_max_num = omp_get_max_threads();
if (ha->nh < thread_max_num) {
// omp for n2
int all_hash_size = thread_max_num * k;
hamdis_t *value = new hamdis_t[all_hash_size];
int64_t *labels = new int64_t[all_hash_size];
for (int i = 0; i < ha->nh; i++) {
// init hash
for (int i = 0; i < all_hash_size; i++) {
value[i] = 0x7fffffff;
}
const uint64_t bs1_ = bs1 [i];
#pragma omp parallel for
for (size_t i = 0; i < ha->nh; i++) {
const uint64_t bs1_ = bs1 [i];
const uint64_t * bs2_ = bs2;
hamdis_t dis;
hamdis_t * bh_val_ = ha->val + i * k;
hamdis_t bh_val_0 = bh_val_[0];
int64_t * bh_ids_ = ha->ids + i * k;
size_t j;
for (j = 0; j < n2; j++, bs2_+= nwords) {
if(!bitset || !bitset->test(j)){
dis = popcount64 (bs1_ ^ *bs2_);
if (dis < bh_val_0) {
for (size_t j = 0; j < n2; j++) {
if(!bitset || !bitset->test(j)) {
hamdis_t dis = popcount64 (bs1_ ^ bs2[j]);
int thread_no = omp_get_thread_num();
hamdis_t * __restrict val_ = value + thread_no * k;
int64_t * __restrict ids_ = labels + thread_no * k;
if (dis < val_[0]) {
faiss::maxheap_pop<hamdis_t> (k, val_, ids_);
faiss::maxheap_push<hamdis_t> (k, val_, ids_, dis, j);
}
}
}
// merge hash
hamdis_t * __restrict bh_val_ = ha->val + i * k;
int64_t * __restrict bh_ids_ = ha->ids + i * k;
for (int i = 0; i < all_hash_size; i++) {
if (value[i] < bh_val_[0]) {
faiss::maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
bh_val_0 = bh_val_[0];
faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, value[i], labels[i]);
}
}
}
delete[] value;
delete[] labels;
} else {
#pragma omp parallel for
for (size_t i = 0; i < ha->nh; i++) {
const uint64_t bs1_ = bs1 [i];
const uint64_t * bs2_ = bs2;
hamdis_t dis;
hamdis_t * bh_val_ = ha->val + i * k;
hamdis_t bh_val_0 = bh_val_[0];
int64_t * bh_ids_ = ha->ids + i * k;
size_t j;
for (j = 0; j < n2; j++, bs2_+= nwords) {
if(!bitset || !bitset->test(j)){
dis = popcount64 (bs1_ ^ *bs2_);
if (dis < bh_val_0) {
faiss::maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
bh_val_0 = bh_val_[0];
}
}
}
}
......
......@@ -4,6 +4,32 @@
namespace faiss {
/**
 * Jaccard distance computer specialised for 8-byte codes.
 *
 * Keeps the query as a single 64-bit word and compares it against candidate
 * codes of the same size.
 */
struct JaccardComputer8 {
    uint64_t a0;

    JaccardComputer8 () {}

    /// Builds a computer for the query at `a8`; `code_size` must be 8.
    JaccardComputer8 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    /// Loads the 8-byte query at `a8` into the computer.
    void set (const uint8_t *a8, int code_size) {
        assert (code_size == 8);
        const uint64_t *query_words = (uint64_t *)a8;
        a0 = query_words[0];
    }

    /**
     * Distance between the stored query and the 8-byte code at `b8`:
     * 1 - popcount64(AND) / popcount64(OR), or 1 when the AND has no bits.
     */
    inline float compute (const uint8_t *b8) const {
        const uint64_t *code_words = (uint64_t *)b8;
        const uint64_t word = code_words[0];

        const int shared_bits = popcount64 (word & a0);
        const int union_bits = popcount64 (word | a0);

        if (shared_bits != 0) {
            const float similarity = (float)(shared_bits) / (float)(union_bits);
            return 1.0 - similarity;
        }
        return 1.0;
    }
};
struct JaccardComputer16 {
uint64_t a0, a1;
......@@ -19,12 +45,10 @@ namespace faiss {
a0 = a[0]; a1 = a[1];
}
inline float jaccard (const uint8_t *b8) const {
inline float compute (const uint8_t *b8) const {
const uint64_t *b = (uint64_t *)b8;
int accu_num = 0;
int accu_den = 0;
accu_num += popcount64 (b[0] & a0) + popcount64 (b[1] & a1);
accu_den += popcount64 (b[0] | a0) + popcount64 (b[1] | a1);
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1);
int accu_den = popcount64 (b[0] | a0) + popcount64 (b[1] | a1);
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / (float)(accu_den);
......@@ -47,14 +71,12 @@ namespace faiss {
a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3];
}
inline float jaccard (const uint8_t *b8) const {
inline float compute (const uint8_t *b8) const {
const uint64_t *b = (uint64_t *)b8;
int accu_num = 0;
int accu_den = 0;
accu_num += popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3);
accu_den += popcount64 (b[0] | a0) + popcount64 (b[1] | a1) +
popcount64 (b[2] | a2) + popcount64 (b[3] | a3);
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3);
int accu_den = popcount64 (b[0] | a0) + popcount64 (b[1] | a1) +
popcount64 (b[2] | a2) + popcount64 (b[3] | a3);
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / (float)(accu_den);
......@@ -78,18 +100,16 @@ namespace faiss {
a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7];
}
inline float jaccard (const uint8_t *b8) const {
inline float compute (const uint8_t *b8) const {
const uint64_t *b = (uint64_t *)b8;
int accu_num = 0;
int accu_den = 0;
accu_num += popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
popcount64 (b[6] & a6) + popcount64 (b[7] & a7);
accu_den += popcount64 (b[0] | a0) + popcount64 (b[1] | a1) +
popcount64 (b[2] | a2) + popcount64 (b[3] | a3) +
popcount64 (b[4] | a4) + popcount64 (b[5] | a5) +
popcount64 (b[6] | a6) + popcount64 (b[7] | a7);
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
popcount64 (b[6] & a6) + popcount64 (b[7] & a7);
int accu_den = popcount64 (b[0] | a0) + popcount64 (b[1] | a1) +
popcount64 (b[2] | a2) + popcount64 (b[3] | a3) +
popcount64 (b[4] | a4) + popcount64 (b[5] | a5) +
popcount64 (b[6] | a6) + popcount64 (b[7] | a7);
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / (float)(accu_den);
......@@ -107,35 +127,209 @@ namespace faiss {
set (a8, code_size);
}
void set (const uint8_t *a16, int code_size) {
assert (code_size == 128 );
const uint64_t *a = (uint64_t *)a16;
void set (const uint8_t *au8, int code_size) {
assert (code_size == 128);
const uint64_t *a = (uint64_t *)au8;
a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3];
a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7];
a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11];
a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15];
}
inline float jaccard (const uint8_t *b16) const {
inline float compute (const uint8_t *b16) const {
const uint64_t *b = (uint64_t *)b16;
int accu_num = 0;
int accu_den = 0;
accu_num += popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
popcount64 (b[6] & a6) + popcount64 (b[7] & a7) +
popcount64 (b[8] & a8) + popcount64 (b[9] & a9) +
popcount64 (b[10] & a10) + popcount64 (b[11] & a11) +
popcount64 (b[12] & a12) + popcount64 (b[13] & a13) +
popcount64 (b[14] & a14) + popcount64 (b[15] & a15);
accu_den += popcount64 (b[0] | a0) + popcount64 (b[1] | a1) +
popcount64 (b[2] | a2) + popcount64 (b[3] | a3) +
popcount64 (b[4] | a4) + popcount64 (b[5] | a5) +
popcount64 (b[6] | a6) + popcount64 (b[7] | a7) +
popcount64 (b[8] | a8) + popcount64 (b[9] | a9) +
popcount64 (b[10] | a10) + popcount64 (b[11] | a11) +
popcount64 (b[12] | a12) + popcount64 (b[13] | a13) +
popcount64 (b[14] | a14) + popcount64 (b[15] | a15);
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
popcount64 (b[6] & a6) + popcount64 (b[7] & a7) +
popcount64 (b[8] & a8) + popcount64 (b[9] & a9) +
popcount64 (b[10] & a10) + popcount64 (b[11] & a11) +
popcount64 (b[12] & a12) + popcount64 (b[13] & a13) +
popcount64 (b[14] & a14) + popcount64 (b[15] & a15);
int accu_den = popcount64 (b[0] | a0) + popcount64 (b[1] | a1) +
popcount64 (b[2] | a2) + popcount64 (b[3] | a3) +
popcount64 (b[4] | a4) + popcount64 (b[5] | a5) +
popcount64 (b[6] | a6) + popcount64 (b[7] | a7) +
popcount64 (b[8] | a8) + popcount64 (b[9] | a9) +
popcount64 (b[10] | a10) + popcount64 (b[11] | a11) +
popcount64 (b[12] | a12) + popcount64 (b[13] | a13) +
popcount64 (b[14] | a14) + popcount64 (b[15] | a15);
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / (float)(accu_den);
}
};
/**
 * Jaccard distance computer specialised for 256-byte codes.
 *
 * Keeps the query as thirty-two 64-bit words (a0..a31) and compares it
 * against candidate codes of the same size.
 */
struct JaccardComputer256 {
    uint64_t a0,a1,a2,a3,a4,a5,a6,a7,
             a8,a9,a10,a11,a12,a13,a14,a15,
             a16,a17,a18,a19,a20,a21,a22,a23,
             a24,a25,a26,a27,a28,a29,a30,a31;

    JaccardComputer256 () {}

    /// Builds a computer for the query at `a8`; `code_size` must be 256.
    JaccardComputer256 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    /// Loads the 256-byte query at `au8` into the thirty-two word members.
    void set (const uint8_t *au8, int code_size) {
        assert (code_size == 256);
        const uint64_t *q = (uint64_t *)au8;
        a0  = q[0];
        a1  = q[1];
        a2  = q[2];
        a3  = q[3];
        a4  = q[4];
        a5  = q[5];
        a6  = q[6];
        a7  = q[7];
        a8  = q[8];
        a9  = q[9];
        a10 = q[10];
        a11 = q[11];
        a12 = q[12];
        a13 = q[13];
        a14 = q[14];
        a15 = q[15];
        a16 = q[16];
        a17 = q[17];
        a18 = q[18];
        a19 = q[19];
        a20 = q[20];
        a21 = q[21];
        a22 = q[22];
        a23 = q[23];
        a24 = q[24];
        a25 = q[25];
        a26 = q[26];
        a27 = q[27];
        a28 = q[28];
        a29 = q[29];
        a30 = q[30];
        a31 = q[31];
    }

    /**
     * Distance between the stored query and the 256-byte code at `b16`:
     * 1 - (sum of popcount64 over the word-wise ANDs) /
     *     (sum of popcount64 over the word-wise ORs),
     * or 1 when the AND sum is zero.
     */
    inline float compute (const uint8_t *b16) const {
        const uint64_t *w = (uint64_t *)b16;

        // bits set in both the query and the candidate, four words per step
        int shared_bits = 0;
        shared_bits += popcount64 (w[0] & a0) + popcount64 (w[1] & a1) +
                       popcount64 (w[2] & a2) + popcount64 (w[3] & a3);
        shared_bits += popcount64 (w[4] & a4) + popcount64 (w[5] & a5) +
                       popcount64 (w[6] & a6) + popcount64 (w[7] & a7);
        shared_bits += popcount64 (w[8] & a8) + popcount64 (w[9] & a9) +
                       popcount64 (w[10] & a10) + popcount64 (w[11] & a11);
        shared_bits += popcount64 (w[12] & a12) + popcount64 (w[13] & a13) +
                       popcount64 (w[14] & a14) + popcount64 (w[15] & a15);
        shared_bits += popcount64 (w[16] & a16) + popcount64 (w[17] & a17) +
                       popcount64 (w[18] & a18) + popcount64 (w[19] & a19);
        shared_bits += popcount64 (w[20] & a20) + popcount64 (w[21] & a21) +
                       popcount64 (w[22] & a22) + popcount64 (w[23] & a23);
        shared_bits += popcount64 (w[24] & a24) + popcount64 (w[25] & a25) +
                       popcount64 (w[26] & a26) + popcount64 (w[27] & a27);
        shared_bits += popcount64 (w[28] & a28) + popcount64 (w[29] & a29) +
                       popcount64 (w[30] & a30) + popcount64 (w[31] & a31);

        // bits set in either the query or the candidate, four words per step
        int union_bits = 0;
        union_bits += popcount64 (w[0] | a0) + popcount64 (w[1] | a1) +
                      popcount64 (w[2] | a2) + popcount64 (w[3] | a3);
        union_bits += popcount64 (w[4] | a4) + popcount64 (w[5] | a5) +
                      popcount64 (w[6] | a6) + popcount64 (w[7] | a7);
        union_bits += popcount64 (w[8] | a8) + popcount64 (w[9] | a9) +
                      popcount64 (w[10] | a10) + popcount64 (w[11] | a11);
        union_bits += popcount64 (w[12] | a12) + popcount64 (w[13] | a13) +
                      popcount64 (w[14] | a14) + popcount64 (w[15] | a15);
        union_bits += popcount64 (w[16] | a16) + popcount64 (w[17] | a17) +
                      popcount64 (w[18] | a18) + popcount64 (w[19] | a19);
        union_bits += popcount64 (w[20] | a20) + popcount64 (w[21] | a21) +
                      popcount64 (w[22] | a22) + popcount64 (w[23] | a23);
        union_bits += popcount64 (w[24] | a24) + popcount64 (w[25] | a25) +
                      popcount64 (w[26] | a26) + popcount64 (w[27] | a27);
        union_bits += popcount64 (w[28] | a28) + popcount64 (w[29] | a29) +
                      popcount64 (w[30] | a30) + popcount64 (w[31] | a31);

        if (shared_bits != 0) {
            const float similarity = (float)(shared_bits) / (float)(union_bits);
            return 1.0 - similarity;
        }
        return 1.0;
    }
};
struct JaccardComputer512 {
uint64_t a0,a1,a2,a3,a4,a5,a6,a7,
a8,a9,a10,a11,a12,a13,a14,a15,
a16,a17,a18,a19,a20,a21,a22,a23,
a24,a25,a26,a27,a28,a29,a30,a31,
a32,a33,a34,a35,a36,a37,a38,a39,
a40,a41,a42,a43,a44,a45,a46,a47,
a48,a49,a50,a51,a52,a53,a54,a55,
a56,a57,a58,a59,a60,a61,a62,a63;
JaccardComputer512 () {}
JaccardComputer512 (const uint8_t *a8, int code_size) {
set (a8, code_size);
}
void set (const uint8_t *au8, int code_size) {
assert (code_size == 512);
const uint64_t *a = (uint64_t *)au8;
a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3];
a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7];
a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11];
a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15];
a16 = a[16]; a17 = a[17]; a18 = a[18]; a19 = a[19];
a20 = a[20]; a21 = a[21]; a22 = a[22]; a23 = a[23];
a24 = a[24]; a25 = a[25]; a26 = a[26]; a27 = a[27];
a28 = a[28]; a29 = a[29]; a30 = a[30]; a31 = a[31];
a32 = a[32]; a33 = a[33]; a34 = a[34]; a35 = a[35];
a36 = a[36]; a37 = a[37]; a38 = a[38]; a39 = a[39];
a40 = a[40]; a41 = a[41]; a42 = a[42]; a43 = a[43];
a44 = a[44]; a45 = a[45]; a46 = a[46]; a47 = a[47];
a48 = a[48]; a49 = a[49]; a50 = a[50]; a51 = a[51];
a52 = a[52]; a53 = a[53]; a54 = a[54]; a55 = a[55];
a56 = a[56]; a57 = a[57]; a58 = a[58]; a59 = a[59];
a60 = a[60]; a61 = a[61]; a62 = a[62]; a63 = a[63];
}
inline float compute (const uint8_t *b16) const {
const uint64_t *b = (uint64_t *)b16;
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
popcount64 (b[6] & a6) + popcount64 (b[7] & a7) +
popcount64 (b[8] & a8) + popcount64 (b[9] & a9) +
popcount64 (b[10] & a10) + popcount64 (b[11] & a11) +
popcount64 (b[12] & a12) + popcount64 (b[13] & a13) +
popcount64 (b[14] & a14) + popcount64 (b[15] & a15) +
popcount64 (b[16] & a16) + popcount64 (b[17] & a17) +
popcount64 (b[18] & a18) + popcount64 (b[19] & a19) +
popcount64 (b[20] & a20) + popcount64 (b[21] & a21) +
popcount64 (b[22] & a22) + popcount64 (b[23] & a23) +
popcount64 (b[24] & a24) + popcount64 (b[25] & a25) +
popcount64 (b[26] & a26) + popcount64 (b[27] & a27) +
popcount64 (b[28] & a28) + popcount64 (b[29] & a29) +
popcount64 (b[30] & a30) + popcount64 (b[31] & a31) +
popcount64 (b[32] & a32) + popcount64 (b[33] & a33) +
popcount64 (b[34] & a34) + popcount64 (b[35] & a35) +
popcount64 (b[36] & a36) + popcount64 (b[37] & a37) +
popcount64 (b[38] & a38) + popcount64 (b[39] & a39) +
popcount64 (b[40] & a40) + popcount64 (b[41] & a41) +
popcount64 (b[42] & a42) + popcount64 (b[43] & a43) +
popcount64 (b[44] & a44) + popcount64 (b[45] & a45) +
popcount64 (b[46] & a46) + popcount64 (b[47] & a47) +
popcount64 (b[48] & a48) + popcount64 (b[49] & a49) +
popcount64 (b[50] & a50) + popcount64 (b[51] & a51) +
popcount64 (b[52] & a52) + popcount64 (b[53] & a53) +
popcount64 (b[54] & a54) + popcount64 (b[55] & a55) +
popcount64 (b[56] & a56) + popcount64 (b[57] & a57) +
popcount64 (b[58] & a58) + popcount64 (b[59] & a59) +
popcount64 (b[60] & a60) + popcount64 (b[61] & a61) +
popcount64 (b[62] & a62) + popcount64 (b[63] & a63);
int accu_den = popcount64 (b[0] | a0) + popcount64 (b[1] | a1) +
popcount64 (b[2] | a2) + popcount64 (b[3] | a3) +
popcount64 (b[4] | a4) + popcount64 (b[5] | a5) +
popcount64 (b[6] | a6) + popcount64 (b[7] | a7) +
popcount64 (b[8] | a8) + popcount64 (b[9] | a9) +
popcount64 (b[10] | a10) + popcount64 (b[11] | a11) +
popcount64 (b[12] | a12) + popcount64 (b[13] | a13) +
popcount64 (b[14] | a14) + popcount64 (b[15] | a15) +
popcount64 (b[16] | a16) + popcount64 (b[17] | a17) +
popcount64 (b[18] | a18) + popcount64 (b[19] | a19) +
popcount64 (b[20] | a20) + popcount64 (b[21] | a21) +
popcount64 (b[22] | a22) + popcount64 (b[23] | a23) +
popcount64 (b[24] | a24) + popcount64 (b[25] | a25) +
popcount64 (b[26] | a26) + popcount64 (b[27] | a27) +
popcount64 (b[28] | a28) + popcount64 (b[29] | a29) +
popcount64 (b[30] | a30) + popcount64 (b[31] | a31) +
popcount64 (b[32] | a32) + popcount64 (b[33] | a33) +
popcount64 (b[34] | a34) + popcount64 (b[35] | a35) +
popcount64 (b[36] | a36) + popcount64 (b[37] | a37) +
popcount64 (b[38] | a38) + popcount64 (b[39] | a39) +
popcount64 (b[40] | a40) + popcount64 (b[41] | a41) +
popcount64 (b[42] | a42) + popcount64 (b[43] | a43) +
popcount64 (b[44] | a44) + popcount64 (b[45] | a45) +
popcount64 (b[46] | a46) + popcount64 (b[47] | a47) +
popcount64 (b[48] | a48) + popcount64 (b[49] | a49) +
popcount64 (b[50] | a50) + popcount64 (b[51] | a51) +
popcount64 (b[52] | a52) + popcount64 (b[53] | a53) +
popcount64 (b[54] | a54) + popcount64 (b[55] | a55) +
popcount64 (b[56] | a56) + popcount64 (b[57] | a57) +
popcount64 (b[58] | a58) + popcount64 (b[59] | a59) +
popcount64 (b[60] | a60) + popcount64 (b[61] | a61) +
popcount64 (b[62] | a62) + popcount64 (b[63] | a63);
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / (float)(accu_den);
......@@ -158,7 +352,7 @@ namespace faiss {
n = code_size;
}
float jaccard (const uint8_t *b8) const {
float compute (const uint8_t *b8) const {
int accu_num = 0;
int accu_den = 0;
for (int i = 0; i < n; i++) {
......@@ -186,10 +380,13 @@ namespace faiss {
JaccardComputer ## CODE_SIZE(a, CODE_SIZE) {} \
}
SPECIALIZED_HC(8);
SPECIALIZED_HC(16);
SPECIALIZED_HC(32);
SPECIALIZED_HC(64);
SPECIALIZED_HC(128);
SPECIALIZED_HC(256);
SPECIALIZED_HC(512);
#undef SPECIALIZED_HC
......
#include <faiss/utils/jaccard.h>
#include <vector>
#include <memory>
#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <assert.h>
#include <limits.h>
#include <faiss/utils/Heap.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/utils.h>
namespace faiss {
// Number of codes from the second code array that jaccard_knn_hc scans per
// outer-loop block.
size_t jaccard_batch_size = 65536;
template <class JaccardComputer>
static
void jaccard_knn_hc(
int bytes_per_code,
float_maxheap_array_t * ha,
const uint8_t * bs1,
const uint8_t * bs2,
size_t n2,
bool order = true,
bool init_heap = true,
ConcurrentBitsetPtr bitset = nullptr)
{
size_t k = ha->k;
if (init_heap) ha->heapify ();
const size_t block_size = jaccard_batch_size;
for (size_t j0 = 0; j0 < n2; j0 += block_size) {
const size_t j1 = std::min(j0 + block_size, n2);
#pragma omp parallel for
for (size_t i = 0; i < ha->nh; i++) {
JaccardComputer hc (bs1 + i * bytes_per_code, bytes_per_code);
const uint8_t * bs2_ = bs2 + j0 * bytes_per_code;
tadis_t dis;
tadis_t * __restrict bh_val_ = ha->val + i * k;
int64_t * __restrict bh_ids_ = ha->ids + i * k;
size_t j;
for (j = j0; j < j1; j++, bs2_+= bytes_per_code) {
if(!bitset || !bitset->test(j)){
dis = hc.jaccard (bs2_);
if (dis < bh_val_[0]) {
faiss::maxheap_pop<tadis_t> (k, bh_val_, bh_ids_);
faiss::maxheap_push<tadis_t> (k, bh_val_, bh_ids_, dis, j);
}
}
}
}
}
if (order) ha->reorder ();
}
/**
 * Picks the JaccardComputer matching the code size and runs the templated
 * jaccard_knn_hc scan with it.
 *
 * Code sizes 16, 32, 64 and 128 use their dedicated computer; any other size
 * goes through JaccardComputerDefault. The heap is always initialised
 * (init_heap is passed as true).
 *
 * @param ha      heap array receiving the results
 * @param a       codes the heaps are built for
 * @param b       codes that are scanned
 * @param nb      number of codes in b
 * @param ncodes  size in bytes of one code
 * @param order   forwarded as the `order` flag of the templated scan
 * @param bitset  forwarded filter bitset
 */
void jaccard_knn_hc (
        float_maxheap_array_t * ha,
        const uint8_t * a,
        const uint8_t * b,
        size_t nb,
        size_t ncodes,
        int order,
        ConcurrentBitsetPtr bitset)
{
    if (ncodes == 16) {
        jaccard_knn_hc<faiss::JaccardComputer16>
            (16, ha, a, b, nb, order, true, bitset);
    } else if (ncodes == 32) {
        jaccard_knn_hc<faiss::JaccardComputer32>
            (32, ha, a, b, nb, order, true, bitset);
    } else if (ncodes == 64) {
        jaccard_knn_hc<faiss::JaccardComputer64>
            (64, ha, a, b, nb, order, true, bitset);
    } else if (ncodes == 128) {
        jaccard_knn_hc<faiss::JaccardComputer128>
            (128, ha, a, b, nb, order, true, bitset);
    } else {
        // No dedicated computer for this size: use the generic one.
        jaccard_knn_hc<faiss::JaccardComputerDefault>
            (ncodes, ha, a, b, nb, order, true, bitset);
    }
}
}
namespace faiss {
/**
 * Substructure distance computer for 8-byte binary codes, held as a single
 * 64-bit word.
 */
struct SubstructureComputer8 {
    /// The stored code, filled in by set().
    uint64_t a0;

    /// Leaves a0 uninitialized; call set() before compute().
    SubstructureComputer8 () {}

    /// Builds a computer from the code at a8; code_size is forwarded to set().
    SubstructureComputer8 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    /// Copies the 8-byte code at a8 into a0; asserts code_size == 8.
    void set (const uint8_t *a8, int code_size) {
        assert (code_size == 8);
        const uint64_t *words = (uint64_t *)a8;
        a0 = words[0];
    }

    /**
     * Returns 1.0 when the stored code and b8 have no set bit in common,
     * otherwise 1 - popcount(stored & b8) / popcount(b8).
     */
    inline float compute (const uint8_t *b8) const {
        const uint64_t *other = (uint64_t *)b8;
        const int shared_bits = popcount64 (other[0] & a0);
        const int other_bits = popcount64 (other[0]);
        if (shared_bits == 0) {
            return 1.0;
        }
        return 1.0 - (float)(shared_bits) / (float)(other_bits);
    }
};
/**
 * Substructure distance computer for 16-byte binary codes, held as two
 * 64-bit words.
 */
struct SubstructureComputer16 {
    /// Words of the stored code, filled in by set().
    uint64_t a0, a1;

    /// Leaves the words uninitialized; call set() before compute().
    SubstructureComputer16 () {}

    /// Builds a computer from the code at a8; code_size is forwarded to set().
    SubstructureComputer16 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    /// Copies the 16-byte code at a8 into a0..a1; asserts code_size == 16.
    void set (const uint8_t *a8, int code_size) {
        assert (code_size == 16);
        const uint64_t *words = (uint64_t *)a8;
        a0 = words[0];
        a1 = words[1];
    }

    /**
     * Returns 1.0 when the stored code and b8 have no set bit in common,
     * otherwise 1 - popcount(stored & b8) / popcount(b8).
     */
    inline float compute (const uint8_t *b8) const {
        const uint64_t *other = (uint64_t *)b8;

        int shared_bits = popcount64 (other[0] & a0);
        shared_bits += popcount64 (other[1] & a1);

        int other_bits = popcount64 (other[0]);
        other_bits += popcount64 (other[1]);

        if (shared_bits == 0) {
            return 1.0;
        }
        return 1.0 - (float)(shared_bits) / (float)(other_bits);
    }
};
/**
 * Substructure distance computer for 32-byte binary codes, held as four
 * 64-bit words.
 */
struct SubstructureComputer32 {
    /// Words of the stored code, filled in by set().
    uint64_t a0, a1, a2, a3;

    /// Leaves the words uninitialized; call set() before compute().
    SubstructureComputer32 () {}

    /// Builds a computer from the code at a8; code_size is forwarded to set().
    SubstructureComputer32 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    /// Copies the 32-byte code at a8 into a0..a3; asserts code_size == 32.
    void set (const uint8_t *a8, int code_size) {
        assert (code_size == 32);
        const uint64_t *words = (uint64_t *)a8;
        a0 = words[0];
        a1 = words[1];
        a2 = words[2];
        a3 = words[3];
    }

    /**
     * Returns 1.0 when the stored code and b8 have no set bit in common,
     * otherwise 1 - popcount(stored & b8) / popcount(b8).
     */
    inline float compute (const uint8_t *b8) const {
        const uint64_t *other = (uint64_t *)b8;

        int shared_bits = popcount64 (other[0] & a0);
        shared_bits += popcount64 (other[1] & a1);
        shared_bits += popcount64 (other[2] & a2);
        shared_bits += popcount64 (other[3] & a3);

        int other_bits = popcount64 (other[0]);
        other_bits += popcount64 (other[1]);
        other_bits += popcount64 (other[2]);
        other_bits += popcount64 (other[3]);

        if (shared_bits == 0) {
            return 1.0;
        }
        return 1.0 - (float)(shared_bits) / (float)(other_bits);
    }
};
/**
 * Substructure distance computer for 64-byte binary codes, held as eight
 * 64-bit words.
 */
struct SubstructureComputer64 {
    /// Words of the stored code, filled in by set().
    uint64_t a0, a1, a2, a3, a4, a5, a6, a7;

    /// Leaves the words uninitialized; call set() before compute().
    SubstructureComputer64 () {}

    /// Builds a computer from the code at a8; code_size is forwarded to set().
    SubstructureComputer64 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    /// Copies the 64-byte code at a8 into a0..a7; asserts code_size == 64.
    void set (const uint8_t *a8, int code_size) {
        assert (code_size == 64);
        const uint64_t *words = (uint64_t *)a8;
        a0 = words[0];
        a1 = words[1];
        a2 = words[2];
        a3 = words[3];
        a4 = words[4];
        a5 = words[5];
        a6 = words[6];
        a7 = words[7];
    }

    /**
     * Returns 1.0 when the stored code and b8 have no set bit in common,
     * otherwise 1 - popcount(stored & b8) / popcount(b8).
     */
    inline float compute (const uint8_t *b8) const {
        const uint64_t *other = (uint64_t *)b8;

        int shared_bits = popcount64 (other[0] & a0);
        shared_bits += popcount64 (other[1] & a1);
        shared_bits += popcount64 (other[2] & a2);
        shared_bits += popcount64 (other[3] & a3);
        shared_bits += popcount64 (other[4] & a4);
        shared_bits += popcount64 (other[5] & a5);
        shared_bits += popcount64 (other[6] & a6);
        shared_bits += popcount64 (other[7] & a7);

        int other_bits = popcount64 (other[0]);
        other_bits += popcount64 (other[1]);
        other_bits += popcount64 (other[2]);
        other_bits += popcount64 (other[3]);
        other_bits += popcount64 (other[4]);
        other_bits += popcount64 (other[5]);
        other_bits += popcount64 (other[6]);
        other_bits += popcount64 (other[7]);

        if (shared_bits == 0) {
            return 1.0;
        }
        return 1.0 - (float)(shared_bits) / (float)(other_bits);
    }
};
// Substructure distance computer for 128-byte binary codes, held as sixteen
// 64-bit words (a0..a15).
struct SubstructureComputer128 {
// Words of the stored code, filled in by set().
uint64_t a0, a1, a2, a3, a4, a5, a6, a7,
a8, a9, a10, a11, a12, a13, a14, a15;
// Leaves the words uninitialized; call set() before compute().
SubstructureComputer128 () {}
// Builds a computer from the code at a8; code_size is forwarded to set().
SubstructureComputer128 (const uint8_t *a8, int code_size) {
set (a8, code_size);
}
// Copies the 128-byte code at au8 into a0..a15; asserts code_size == 128.
// NOTE(review): au8 is read through a uint64_t pointer, which presumably
// needs 8-byte alignment on some targets -- confirm with callers.
void set (const uint8_t *au8, int code_size) {
assert (code_size == 128);
const uint64_t *a = (uint64_t *)au8;
a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3];
a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7];
a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11];
a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15];
}
// Returns 1.0 when the stored code and b16 have no set bit in common;
// otherwise 1 - popcount(stored & b16) / popcount(b16).
inline float compute (const uint8_t *b16) const {
const uint64_t *b = (uint64_t *)b16;
// Number of bits set in both codes.
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
popcount64 (b[6] & a6) + popcount64 (b[7] & a7) +
popcount64 (b[8] & a8) + popcount64 (b[9] & a9) +
popcount64 (b[10] & a10) + popcount64 (b[11] & a11) +
popcount64 (b[12] & a12) + popcount64 (b[13] & a13) +
popcount64 (b[14] & a14) + popcount64 (b[15] & a15);
// Number of bits set in b16 alone.
int accu_den = popcount64 (b[0]) + popcount64 (b[1]) +
popcount64 (b[2]) + popcount64 (b[3]) +
popcount64 (b[4]) + popcount64 (b[5]) +
popcount64 (b[6]) + popcount64 (b[7]) +
popcount64 (b[8]) + popcount64 (b[9]) +
popcount64 (b[10]) + popcount64 (b[11]) +
popcount64 (b[12]) + popcount64 (b[13]) +
popcount64 (b[14]) + popcount64 (b[15]);
// The early return also keeps the division away from accu_den == 0.
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / (float)(accu_den);
}
};
// Substructure distance computer for 256-byte binary codes, held as
// thirty-two 64-bit words (a0..a31).
struct SubstructureComputer256 {
// Words of the stored code, filled in by set().
uint64_t a0,a1,a2,a3,a4,a5,a6,a7,
a8,a9,a10,a11,a12,a13,a14,a15,
a16,a17,a18,a19,a20,a21,a22,a23,
a24,a25,a26,a27,a28,a29,a30,a31;
// Leaves the words uninitialized; call set() before compute().
SubstructureComputer256 () {}
// Builds a computer from the code at a8; code_size is forwarded to set().
SubstructureComputer256 (const uint8_t *a8, int code_size) {
set (a8, code_size);
}
// Copies the 256-byte code at au8 into a0..a31; asserts code_size == 256.
// NOTE(review): au8 is read through a uint64_t pointer, which presumably
// needs 8-byte alignment on some targets -- confirm with callers.
void set (const uint8_t *au8, int code_size) {
assert (code_size == 256);
const uint64_t *a = (uint64_t *)au8;
a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3];
a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7];
a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11];
a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15];
a16 = a[16]; a17 = a[17]; a18 = a[18]; a19 = a[19];
a20 = a[20]; a21 = a[21]; a22 = a[22]; a23 = a[23];
a24 = a[24]; a25 = a[25]; a26 = a[26]; a27 = a[27];
a28 = a[28]; a29 = a[29]; a30 = a[30]; a31 = a[31];
}
// Returns 1.0 when the stored code and b16 have no set bit in common;
// otherwise 1 - popcount(stored & b16) / popcount(b16).
inline float compute (const uint8_t *b16) const {
const uint64_t *b = (uint64_t *)b16;
// Number of bits set in both codes.
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
popcount64 (b[6] & a6) + popcount64 (b[7] & a7) +
popcount64 (b[8] & a8) + popcount64 (b[9] & a9) +
popcount64 (b[10] & a10) + popcount64 (b[11] & a11) +
popcount64 (b[12] & a12) + popcount64 (b[13] & a13) +
popcount64 (b[14] & a14) + popcount64 (b[15] & a15) +
popcount64 (b[16] & a16) + popcount64 (b[17] & a17) +
popcount64 (b[18] & a18) + popcount64 (b[19] & a19) +
popcount64 (b[20] & a20) + popcount64 (b[21] & a21) +
popcount64 (b[22] & a22) + popcount64 (b[23] & a23) +
popcount64 (b[24] & a24) + popcount64 (b[25] & a25) +
popcount64 (b[26] & a26) + popcount64 (b[27] & a27) +
popcount64 (b[28] & a28) + popcount64 (b[29] & a29) +
popcount64 (b[30] & a30) + popcount64 (b[31] & a31);
// Number of bits set in b16 alone.
int accu_den = popcount64 (b[0]) + popcount64 (b[1]) +
popcount64 (b[2]) + popcount64 (b[3]) +
popcount64 (b[4]) + popcount64 (b[5]) +
popcount64 (b[6]) + popcount64 (b[7]) +
popcount64 (b[8]) + popcount64 (b[9]) +
popcount64 (b[10]) + popcount64 (b[11]) +
popcount64 (b[12]) + popcount64 (b[13]) +
popcount64 (b[14]) + popcount64 (b[15]) +
popcount64 (b[16]) + popcount64 (b[17]) +
popcount64 (b[18]) + popcount64 (b[19]) +
popcount64 (b[20]) + popcount64 (b[21]) +
popcount64 (b[22]) + popcount64 (b[23]) +
popcount64 (b[24]) + popcount64 (b[25]) +
popcount64 (b[26]) + popcount64 (b[27]) +
popcount64 (b[28]) + popcount64 (b[29]) +
popcount64 (b[30]) + popcount64 (b[31]);
// The early return also keeps the division away from accu_den == 0.
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / (float)(accu_den);
}
};
// Substructure distance computer for 512-byte binary codes, held as
// sixty-four 64-bit words (a0..a63).
struct SubstructureComputer512 {
// Words of the stored code, filled in by set().
uint64_t a0,a1,a2,a3,a4,a5,a6,a7,
a8,a9,a10,a11,a12,a13,a14,a15,
a16,a17,a18,a19,a20,a21,a22,a23,
a24,a25,a26,a27,a28,a29,a30,a31,
a32,a33,a34,a35,a36,a37,a38,a39,
a40,a41,a42,a43,a44,a45,a46,a47,
a48,a49,a50,a51,a52,a53,a54,a55,
a56,a57,a58,a59,a60,a61,a62,a63;
// Leaves the words uninitialized; call set() before compute().
SubstructureComputer512 () {}
// Builds a computer from the code at a8; code_size is forwarded to set().
SubstructureComputer512 (const uint8_t *a8, int code_size) {
set (a8, code_size);
}
// Copies the 512-byte code at au8 into a0..a63; asserts code_size == 512.
// NOTE(review): au8 is read through a uint64_t pointer, which presumably
// needs 8-byte alignment on some targets -- confirm with callers.
void set (const uint8_t *au8, int code_size) {
assert (code_size == 512);
const uint64_t *a = (uint64_t *)au8;
a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3];
a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7];
a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11];
a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15];
a16 = a[16]; a17 = a[17]; a18 = a[18]; a19 = a[19];
a20 = a[20]; a21 = a[21]; a22 = a[22]; a23 = a[23];
a24 = a[24]; a25 = a[25]; a26 = a[26]; a27 = a[27];
a28 = a[28]; a29 = a[29]; a30 = a[30]; a31 = a[31];
a32 = a[32]; a33 = a[33]; a34 = a[34]; a35 = a[35];
a36 = a[36]; a37 = a[37]; a38 = a[38]; a39 = a[39];
a40 = a[40]; a41 = a[41]; a42 = a[42]; a43 = a[43];
a44 = a[44]; a45 = a[45]; a46 = a[46]; a47 = a[47];
a48 = a[48]; a49 = a[49]; a50 = a[50]; a51 = a[51];
a52 = a[52]; a53 = a[53]; a54 = a[54]; a55 = a[55];
a56 = a[56]; a57 = a[57]; a58 = a[58]; a59 = a[59];
a60 = a[60]; a61 = a[61]; a62 = a[62]; a63 = a[63];
}
// Returns 1.0 when the stored code and b16 have no set bit in common;
// otherwise 1 - popcount(stored & b16) / popcount(b16).
inline float compute (const uint8_t *b16) const {
const uint64_t *b = (uint64_t *)b16;
// Number of bits set in both codes.
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
popcount64 (b[6] & a6) + popcount64 (b[7] & a7) +
popcount64 (b[8] & a8) + popcount64 (b[9] & a9) +
popcount64 (b[10] & a10) + popcount64 (b[11] & a11) +
popcount64 (b[12] & a12) + popcount64 (b[13] & a13) +
popcount64 (b[14] & a14) + popcount64 (b[15] & a15) +
popcount64 (b[16] & a16) + popcount64 (b[17] & a17) +
popcount64 (b[18] & a18) + popcount64 (b[19] & a19) +
popcount64 (b[20] & a20) + popcount64 (b[21] & a21) +
popcount64 (b[22] & a22) + popcount64 (b[23] & a23) +
popcount64 (b[24] & a24) + popcount64 (b[25] & a25) +
popcount64 (b[26] & a26) + popcount64 (b[27] & a27) +
popcount64 (b[28] & a28) + popcount64 (b[29] & a29) +
popcount64 (b[30] & a30) + popcount64 (b[31] & a31) +
popcount64 (b[32] & a32) + popcount64 (b[33] & a33) +
popcount64 (b[34] & a34) + popcount64 (b[35] & a35) +
popcount64 (b[36] & a36) + popcount64 (b[37] & a37) +
popcount64 (b[38] & a38) + popcount64 (b[39] & a39) +
popcount64 (b[40] & a40) + popcount64 (b[41] & a41) +
popcount64 (b[42] & a42) + popcount64 (b[43] & a43) +
popcount64 (b[44] & a44) + popcount64 (b[45] & a45) +
popcount64 (b[46] & a46) + popcount64 (b[47] & a47) +
popcount64 (b[48] & a48) + popcount64 (b[49] & a49) +
popcount64 (b[50] & a50) + popcount64 (b[51] & a51) +
popcount64 (b[52] & a52) + popcount64 (b[53] & a53) +
popcount64 (b[54] & a54) + popcount64 (b[55] & a55) +
popcount64 (b[56] & a56) + popcount64 (b[57] & a57) +
popcount64 (b[58] & a58) + popcount64 (b[59] & a59) +
popcount64 (b[60] & a60) + popcount64 (b[61] & a61) +
popcount64 (b[62] & a62) + popcount64 (b[63] & a63);
// Number of bits set in b16 alone.
int accu_den = popcount64 (b[0]) + popcount64 (b[1]) +
popcount64 (b[2]) + popcount64 (b[3]) +
popcount64 (b[4]) + popcount64 (b[5]) +
popcount64 (b[6]) + popcount64 (b[7]) +
popcount64 (b[8]) + popcount64 (b[9]) +
popcount64 (b[10]) + popcount64 (b[11]) +
popcount64 (b[12]) + popcount64 (b[13]) +
popcount64 (b[14]) + popcount64 (b[15]) +
popcount64 (b[16]) + popcount64 (b[17]) +
popcount64 (b[18]) + popcount64 (b[19]) +
popcount64 (b[20]) + popcount64 (b[21]) +
popcount64 (b[22]) + popcount64 (b[23]) +
popcount64 (b[24]) + popcount64 (b[25]) +
popcount64 (b[26]) + popcount64 (b[27]) +
popcount64 (b[28]) + popcount64 (b[29]) +
popcount64 (b[30]) + popcount64 (b[31]) +
popcount64 (b[32]) + popcount64 (b[33]) +
popcount64 (b[34]) + popcount64 (b[35]) +
popcount64 (b[36]) + popcount64 (b[37]) +
popcount64 (b[38]) + popcount64 (b[39]) +
popcount64 (b[40]) + popcount64 (b[41]) +
popcount64 (b[42]) + popcount64 (b[43]) +
popcount64 (b[44]) + popcount64 (b[45]) +
popcount64 (b[46]) + popcount64 (b[47]) +
popcount64 (b[48]) + popcount64 (b[49]) +
popcount64 (b[50]) + popcount64 (b[51]) +
popcount64 (b[52]) + popcount64 (b[53]) +
popcount64 (b[54]) + popcount64 (b[55]) +
popcount64 (b[56]) + popcount64 (b[57]) +
popcount64 (b[58]) + popcount64 (b[59]) +
popcount64 (b[60]) + popcount64 (b[61]) +
popcount64 (b[62]) + popcount64 (b[63]);
// The early return also keeps the division away from accu_den == 0.
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / (float)(accu_den);
}
};
/**
 * Generic substructure distance computer for binary codes of any byte
 * length. It keeps a pointer to the stored code instead of copying it and
 * processes the codes one byte at a time.
 */
struct SubstructureComputerDefault {
    /// Pointer to the stored code (not owned, not copied).
    const uint8_t *a;
    /// Length of the codes in bytes.
    int n;

    /// Leaves a and n uninitialized; call set() before compute().
    SubstructureComputerDefault () {}

    /// Builds a computer for the code at a8 of code_size bytes.
    SubstructureComputerDefault (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    /// Records the code pointer and its length in bytes.
    void set (const uint8_t *a8, int code_size) {
        n = code_size;
        a = a8;
    }

    /**
     * Returns 1.0 when the stored code and b8 have no set bit in common,
     * otherwise 1 - popcount(stored & b8) / popcount(b8).
     */
    float compute (const uint8_t *b8) const {
        int shared_bits = 0;
        int other_bits = 0;
        const uint8_t *lhs = a;
        const uint8_t *rhs = b8;
        for (int remaining = n; remaining > 0; --remaining, ++lhs, ++rhs) {
            shared_bits += popcount64(*lhs & *rhs);
            other_bits += popcount64(*rhs);
        }
        if (shared_bits == 0) {
            return 1.0;
        }
        return 1.0 - (float)(shared_bits) / (float)(other_bits);
    }
};
// Primary template: any CODE_SIZE without an explicit specialization below
// uses SubstructureComputerDefault and takes the code size at run time.
template<int CODE_SIZE>
struct SubstructureComputer: SubstructureComputerDefault {
SubstructureComputer (const uint8_t *a, int code_size):
SubstructureComputerDefault(a, code_size) {}
};
// Generates the specialization SubstructureComputer<CODE_SIZE>, derived from
// the fixed-size struct SubstructureComputer##CODE_SIZE. Unlike the primary
// template, its constructor takes only the code pointer and passes CODE_SIZE
// itself as the code size.
#define SPECIALIZED_HC(CODE_SIZE) \
template<> struct SubstructureComputer<CODE_SIZE>: \
SubstructureComputer ## CODE_SIZE { \
SubstructureComputer (const uint8_t *a): \
SubstructureComputer ## CODE_SIZE(a, CODE_SIZE) {} \
}
// Code sizes (in bytes) that have a dedicated fixed-size computer.
SPECIALIZED_HC(8);
SPECIALIZED_HC(16);
SPECIALIZED_HC(32);
SPECIALIZED_HC(64);
SPECIALIZED_HC(128);
SPECIALIZED_HC(256);
SPECIALIZED_HC(512);
// The helper macro is local to this list of specializations.
#undef SPECIALIZED_HC
}
namespace faiss {
/**
 * Superstructure distance computer for 8-byte binary codes, held as a single
 * 64-bit word together with its cached bit count.
 */
struct SuperstructureComputer8 {
    /// The stored code, filled in by set().
    uint64_t a0;
    /// Number of set bits in the stored code, as a float; filled in by set().
    float accu_den;

    /// Leaves the members uninitialized; call set() before compute().
    SuperstructureComputer8 () {}

    /// Builds a computer from the code at a8; code_size is forwarded to set().
    SuperstructureComputer8 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    /**
     * Copies the 8-byte code at a8 into a0 and caches its bit count in
     * accu_den; asserts code_size == 8.
     *
     * The bit count is computed here rather than in the constructor so that
     * a default-constructed computer followed by set() is fully initialized
     * before compute() divides by accu_den.
     */
    void set (const uint8_t *a8, int code_size) {
        assert (code_size == 8);
        const uint64_t *a = (uint64_t *)a8;
        a0 = a[0];
        accu_den = (float)(popcount64 (a0));
    }

    /**
     * Returns 1.0 when the stored code and b8 have no set bit in common,
     * otherwise 1 - popcount(stored & b8) / popcount(stored).
     */
    inline float compute (const uint8_t *b8) const {
        const uint64_t *b = (uint64_t *)b8;
        int accu_num = popcount64 (b[0] & a0);
        // The early return also keeps the division away from accu_den == 0.
        if (accu_num == 0)
            return 1.0;
        return 1.0 - (float)(accu_num) / accu_den;
    }
};
/**
 * Superstructure distance computer for 16-byte binary codes, held as two
 * 64-bit words together with their cached bit count.
 */
struct SuperstructureComputer16 {
    /// Words of the stored code, filled in by set().
    uint64_t a0, a1;
    /// Number of set bits in the stored code, as a float; filled in by set().
    float accu_den;

    /// Leaves the members uninitialized; call set() before compute().
    SuperstructureComputer16 () {}

    /// Builds a computer from the code at a8; code_size is forwarded to set().
    SuperstructureComputer16 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    /**
     * Copies the 16-byte code at a8 into a0..a1 and caches its bit count in
     * accu_den; asserts code_size == 16.
     *
     * The bit count is computed here rather than in the constructor so that
     * a default-constructed computer followed by set() is fully initialized
     * before compute() divides by accu_den.
     */
    void set (const uint8_t *a8, int code_size) {
        assert (code_size == 16);
        const uint64_t *a = (uint64_t *)a8;
        a0 = a[0]; a1 = a[1];
        accu_den = (float)(popcount64 (a0) + popcount64 (a1));
    }

    /**
     * Returns 1.0 when the stored code and b8 have no set bit in common,
     * otherwise 1 - popcount(stored & b8) / popcount(stored).
     */
    inline float compute (const uint8_t *b8) const {
        const uint64_t *b = (uint64_t *)b8;
        int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1);
        // The early return also keeps the division away from accu_den == 0.
        if (accu_num == 0)
            return 1.0;
        return 1.0 - (float)(accu_num) / accu_den;
    }
};
/**
 * Superstructure distance computer for 32-byte binary codes, held as four
 * 64-bit words together with their cached bit count.
 */
struct SuperstructureComputer32 {
    /// Words of the stored code, filled in by set().
    uint64_t a0, a1, a2, a3;
    /// Number of set bits in the stored code, as a float; filled in by set().
    float accu_den;

    /// Leaves the members uninitialized; call set() before compute().
    SuperstructureComputer32 () {}

    /// Builds a computer from the code at a8; code_size is forwarded to set().
    SuperstructureComputer32 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    /// Copies the 32-byte code at a8 into a0..a3 and caches its bit count;
    /// asserts code_size == 32.
    void set (const uint8_t *a8, int code_size) {
        assert (code_size == 32);
        const uint64_t *words = (uint64_t *)a8;
        a0 = words[0];
        a1 = words[1];
        a2 = words[2];
        a3 = words[3];

        const int stored_bits = popcount64 (a0) + popcount64 (a1) +
                                popcount64 (a2) + popcount64 (a3);
        accu_den = (float)(stored_bits);
    }

    /**
     * Returns 1.0 when the stored code and b8 have no set bit in common,
     * otherwise 1 - popcount(stored & b8) / popcount(stored).
     */
    inline float compute (const uint8_t *b8) const {
        const uint64_t *other = (uint64_t *)b8;

        int shared_bits = popcount64 (other[0] & a0);
        shared_bits += popcount64 (other[1] & a1);
        shared_bits += popcount64 (other[2] & a2);
        shared_bits += popcount64 (other[3] & a3);

        if (shared_bits == 0) {
            return 1.0;
        }
        return 1.0 - (float)(shared_bits) / accu_den;
    }
};
/**
 * Superstructure distance computer for 64-byte binary codes, held as eight
 * 64-bit words together with their cached bit count.
 */
struct SuperstructureComputer64 {
    /// Words of the stored code, filled in by set().
    uint64_t a0, a1, a2, a3, a4, a5, a6, a7;
    /// Number of set bits in the stored code, as a float; filled in by set().
    float accu_den;

    /// Leaves the members uninitialized; call set() before compute().
    SuperstructureComputer64 () {}

    /// Builds a computer from the code at a8; code_size is forwarded to set().
    SuperstructureComputer64 (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    /// Copies the 64-byte code at a8 into a0..a7 and caches its bit count;
    /// asserts code_size == 64.
    void set (const uint8_t *a8, int code_size) {
        assert (code_size == 64);
        const uint64_t *words = (uint64_t *)a8;
        a0 = words[0];
        a1 = words[1];
        a2 = words[2];
        a3 = words[3];
        a4 = words[4];
        a5 = words[5];
        a6 = words[6];
        a7 = words[7];

        const int stored_bits = popcount64 (a0) + popcount64 (a1) +
                                popcount64 (a2) + popcount64 (a3) +
                                popcount64 (a4) + popcount64 (a5) +
                                popcount64 (a6) + popcount64 (a7);
        accu_den = (float)(stored_bits);
    }

    /**
     * Returns 1.0 when the stored code and b8 have no set bit in common,
     * otherwise 1 - popcount(stored & b8) / popcount(stored).
     */
    inline float compute (const uint8_t *b8) const {
        const uint64_t *other = (uint64_t *)b8;

        int shared_bits = popcount64 (other[0] & a0);
        shared_bits += popcount64 (other[1] & a1);
        shared_bits += popcount64 (other[2] & a2);
        shared_bits += popcount64 (other[3] & a3);
        shared_bits += popcount64 (other[4] & a4);
        shared_bits += popcount64 (other[5] & a5);
        shared_bits += popcount64 (other[6] & a6);
        shared_bits += popcount64 (other[7] & a7);

        if (shared_bits == 0) {
            return 1.0;
        }
        return 1.0 - (float)(shared_bits) / accu_den;
    }
};
// Superstructure distance computer for 128-byte binary codes, held as
// sixteen 64-bit words (a0..a15) together with their cached bit count.
struct SuperstructureComputer128 {
// Words of the stored code, filled in by set().
uint64_t a0, a1, a2, a3, a4, a5, a6, a7,
a8, a9, a10, a11, a12, a13, a14, a15;
// Number of set bits in the stored code, as a float; filled in by set().
float accu_den;
// Leaves the members uninitialized; call set() before compute().
SuperstructureComputer128 () {}
// Builds a computer from the code at a8; code_size is forwarded to set().
SuperstructureComputer128 (const uint8_t *a8, int code_size) {
set (a8, code_size);
}
// Copies the 128-byte code at au8 into a0..a15 and caches its bit count in
// accu_den; asserts code_size == 128.
// NOTE(review): au8 is read through a uint64_t pointer, which presumably
// needs 8-byte alignment on some targets -- confirm with callers.
void set (const uint8_t *au8, int code_size) {
assert (code_size == 128);
const uint64_t *a = (uint64_t *)au8;
a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3];
a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7];
a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11];
a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15];
accu_den = (float)(popcount64 (a0) + popcount64 (a1) +
popcount64 (a2) + popcount64 (a3) +
popcount64 (a4) + popcount64 (a5) +
popcount64 (a6) + popcount64 (a7) +
popcount64 (a8) + popcount64 (a9) +
popcount64 (a10) + popcount64 (a11) +
popcount64 (a12) + popcount64 (a13) +
popcount64 (a14) + popcount64 (a15));
}
// Returns 1.0 when the stored code and b16 have no set bit in common;
// otherwise 1 - popcount(stored & b16) / accu_den.
inline float compute (const uint8_t *b16) const {
const uint64_t *b = (uint64_t *)b16;
// Number of bits set in both codes.
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
popcount64 (b[6] & a6) + popcount64 (b[7] & a7) +
popcount64 (b[8] & a8) + popcount64 (b[9] & a9) +
popcount64 (b[10] & a10) + popcount64 (b[11] & a11) +
popcount64 (b[12] & a12) + popcount64 (b[13] & a13) +
popcount64 (b[14] & a14) + popcount64 (b[15] & a15);
// The early return also keeps the division away from accu_den == 0.
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / accu_den;
}
};
// Superstructure distance computer for 256-byte binary codes, held as
// thirty-two 64-bit words (a0..a31) together with their cached bit count.
struct SuperstructureComputer256 {
// Words of the stored code, filled in by set().
uint64_t a0,a1,a2,a3,a4,a5,a6,a7,
a8,a9,a10,a11,a12,a13,a14,a15,
a16,a17,a18,a19,a20,a21,a22,a23,
a24,a25,a26,a27,a28,a29,a30,a31;
// Number of set bits in the stored code, as a float; filled in by set().
float accu_den;
// Leaves the members uninitialized; call set() before compute().
SuperstructureComputer256 () {}
// Builds a computer from the code at a8; code_size is forwarded to set().
SuperstructureComputer256 (const uint8_t *a8, int code_size) {
set (a8, code_size);
}
// Copies the 256-byte code at au8 into a0..a31 and caches its bit count in
// accu_den; asserts code_size == 256.
// NOTE(review): au8 is read through a uint64_t pointer, which presumably
// needs 8-byte alignment on some targets -- confirm with callers.
void set (const uint8_t *au8, int code_size) {
assert (code_size == 256);
const uint64_t *a = (uint64_t *)au8;
a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3];
a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7];
a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11];
a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15];
a16 = a[16]; a17 = a[17]; a18 = a[18]; a19 = a[19];
a20 = a[20]; a21 = a[21]; a22 = a[22]; a23 = a[23];
a24 = a[24]; a25 = a[25]; a26 = a[26]; a27 = a[27];
a28 = a[28]; a29 = a[29]; a30 = a[30]; a31 = a[31];
accu_den = (float)(popcount64 (a0) + popcount64 (a1) +
popcount64 (a2) + popcount64 (a3) +
popcount64 (a4) + popcount64 (a5) +
popcount64 (a6) + popcount64 (a7) +
popcount64 (a8) + popcount64 (a9) +
popcount64 (a10) + popcount64 (a11) +
popcount64 (a12) + popcount64 (a13) +
popcount64 (a14) + popcount64 (a15) +
popcount64 (a16) + popcount64 (a17) +
popcount64 (a18) + popcount64 (a19) +
popcount64 (a20) + popcount64 (a21) +
popcount64 (a22) + popcount64 (a23) +
popcount64 (a24) + popcount64 (a25) +
popcount64 (a26) + popcount64 (a27) +
popcount64 (a28) + popcount64 (a29) +
popcount64 (a30) + popcount64 (a31));
}
// Returns 1.0 when the stored code and b16 have no set bit in common;
// otherwise 1 - popcount(stored & b16) / accu_den.
inline float compute (const uint8_t *b16) const {
const uint64_t *b = (uint64_t *)b16;
// Number of bits set in both codes.
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
popcount64 (b[6] & a6) + popcount64 (b[7] & a7) +
popcount64 (b[8] & a8) + popcount64 (b[9] & a9) +
popcount64 (b[10] & a10) + popcount64 (b[11] & a11) +
popcount64 (b[12] & a12) + popcount64 (b[13] & a13) +
popcount64 (b[14] & a14) + popcount64 (b[15] & a15) +
popcount64 (b[16] & a16) + popcount64 (b[17] & a17) +
popcount64 (b[18] & a18) + popcount64 (b[19] & a19) +
popcount64 (b[20] & a20) + popcount64 (b[21] & a21) +
popcount64 (b[22] & a22) + popcount64 (b[23] & a23) +
popcount64 (b[24] & a24) + popcount64 (b[25] & a25) +
popcount64 (b[26] & a26) + popcount64 (b[27] & a27) +
popcount64 (b[28] & a28) + popcount64 (b[29] & a29) +
popcount64 (b[30] & a30) + popcount64 (b[31] & a31);
// The early return also keeps the division away from accu_den == 0.
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / accu_den;
}
};
// Superstructure distance computer for 512-byte binary codes, held as
// sixty-four 64-bit words (a0..a63) together with their cached bit count.
struct SuperstructureComputer512 {
// Words of the stored code, filled in by set().
uint64_t a0,a1,a2,a3,a4,a5,a6,a7,
a8,a9,a10,a11,a12,a13,a14,a15,
a16,a17,a18,a19,a20,a21,a22,a23,
a24,a25,a26,a27,a28,a29,a30,a31,
a32,a33,a34,a35,a36,a37,a38,a39,
a40,a41,a42,a43,a44,a45,a46,a47,
a48,a49,a50,a51,a52,a53,a54,a55,
a56,a57,a58,a59,a60,a61,a62,a63;
// Number of set bits in the stored code, as a float; filled in by set().
float accu_den;
// Leaves the members uninitialized; call set() before compute().
SuperstructureComputer512 () {}
// Builds a computer from the code at a8; code_size is forwarded to set().
SuperstructureComputer512 (const uint8_t *a8, int code_size) {
set (a8, code_size);
}
// Copies the 512-byte code at au8 into a0..a63 and caches its bit count in
// accu_den; asserts code_size == 512.
// NOTE(review): au8 is read through a uint64_t pointer, which presumably
// needs 8-byte alignment on some targets -- confirm with callers.
void set (const uint8_t *au8, int code_size) {
assert (code_size == 512);
const uint64_t *a = (uint64_t *)au8;
a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3];
a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7];
a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11];
a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15];
a16 = a[16]; a17 = a[17]; a18 = a[18]; a19 = a[19];
a20 = a[20]; a21 = a[21]; a22 = a[22]; a23 = a[23];
a24 = a[24]; a25 = a[25]; a26 = a[26]; a27 = a[27];
a28 = a[28]; a29 = a[29]; a30 = a[30]; a31 = a[31];
a32 = a[32]; a33 = a[33]; a34 = a[34]; a35 = a[35];
a36 = a[36]; a37 = a[37]; a38 = a[38]; a39 = a[39];
a40 = a[40]; a41 = a[41]; a42 = a[42]; a43 = a[43];
a44 = a[44]; a45 = a[45]; a46 = a[46]; a47 = a[47];
a48 = a[48]; a49 = a[49]; a50 = a[50]; a51 = a[51];
a52 = a[52]; a53 = a[53]; a54 = a[54]; a55 = a[55];
a56 = a[56]; a57 = a[57]; a58 = a[58]; a59 = a[59];
a60 = a[60]; a61 = a[61]; a62 = a[62]; a63 = a[63];
accu_den = (float)(popcount64 (a0) + popcount64 (a1) +
popcount64 (a2) + popcount64 (a3) +
popcount64 (a4) + popcount64 (a5) +
popcount64 (a6) + popcount64 (a7) +
popcount64 (a8) + popcount64 (a9) +
popcount64 (a10) + popcount64 (a11) +
popcount64 (a12) + popcount64 (a13) +
popcount64 (a14) + popcount64 (a15) +
popcount64 (a16) + popcount64 (a17) +
popcount64 (a18) + popcount64 (a19) +
popcount64 (a20) + popcount64 (a21) +
popcount64 (a22) + popcount64 (a23) +
popcount64 (a24) + popcount64 (a25) +
popcount64 (a26) + popcount64 (a27) +
popcount64 (a28) + popcount64 (a29) +
popcount64 (a30) + popcount64 (a31) +
popcount64 (a32) + popcount64 (a33) +
popcount64 (a34) + popcount64 (a35) +
popcount64 (a36) + popcount64 (a37) +
popcount64 (a38) + popcount64 (a39) +
popcount64 (a40) + popcount64 (a41) +
popcount64 (a42) + popcount64 (a43) +
popcount64 (a44) + popcount64 (a45) +
popcount64 (a46) + popcount64 (a47) +
popcount64 (a48) + popcount64 (a49) +
popcount64 (a50) + popcount64 (a51) +
popcount64 (a52) + popcount64 (a53) +
popcount64 (a54) + popcount64 (a55) +
popcount64 (a56) + popcount64 (a57) +
popcount64 (a58) + popcount64 (a59) +
popcount64 (a60) + popcount64 (a61) +
popcount64 (a62) + popcount64 (a63));
}
// Returns 1.0 when the stored code and b16 have no set bit in common;
// otherwise 1 - popcount(stored & b16) / accu_den.
inline float compute (const uint8_t *b16) const {
const uint64_t *b = (uint64_t *)b16;
// Number of bits set in both codes.
int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) +
popcount64 (b[2] & a2) + popcount64 (b[3] & a3) +
popcount64 (b[4] & a4) + popcount64 (b[5] & a5) +
popcount64 (b[6] & a6) + popcount64 (b[7] & a7) +
popcount64 (b[8] & a8) + popcount64 (b[9] & a9) +
popcount64 (b[10] & a10) + popcount64 (b[11] & a11) +
popcount64 (b[12] & a12) + popcount64 (b[13] & a13) +
popcount64 (b[14] & a14) + popcount64 (b[15] & a15) +
popcount64 (b[16] & a16) + popcount64 (b[17] & a17) +
popcount64 (b[18] & a18) + popcount64 (b[19] & a19) +
popcount64 (b[20] & a20) + popcount64 (b[21] & a21) +
popcount64 (b[22] & a22) + popcount64 (b[23] & a23) +
popcount64 (b[24] & a24) + popcount64 (b[25] & a25) +
popcount64 (b[26] & a26) + popcount64 (b[27] & a27) +
popcount64 (b[28] & a28) + popcount64 (b[29] & a29) +
popcount64 (b[30] & a30) + popcount64 (b[31] & a31) +
popcount64 (b[32] & a32) + popcount64 (b[33] & a33) +
popcount64 (b[34] & a34) + popcount64 (b[35] & a35) +
popcount64 (b[36] & a36) + popcount64 (b[37] & a37) +
popcount64 (b[38] & a38) + popcount64 (b[39] & a39) +
popcount64 (b[40] & a40) + popcount64 (b[41] & a41) +
popcount64 (b[42] & a42) + popcount64 (b[43] & a43) +
popcount64 (b[44] & a44) + popcount64 (b[45] & a45) +
popcount64 (b[46] & a46) + popcount64 (b[47] & a47) +
popcount64 (b[48] & a48) + popcount64 (b[49] & a49) +
popcount64 (b[50] & a50) + popcount64 (b[51] & a51) +
popcount64 (b[52] & a52) + popcount64 (b[53] & a53) +
popcount64 (b[54] & a54) + popcount64 (b[55] & a55) +
popcount64 (b[56] & a56) + popcount64 (b[57] & a57) +
popcount64 (b[58] & a58) + popcount64 (b[59] & a59) +
popcount64 (b[60] & a60) + popcount64 (b[61] & a61) +
popcount64 (b[62] & a62) + popcount64 (b[63] & a63);
// The early return also keeps the division away from accu_den == 0.
if (accu_num == 0)
return 1.0;
return 1.0 - (float)(accu_num) / accu_den;
}
};
/**
 * Generic superstructure distance computer for binary codes of any byte
 * length. It keeps a pointer to the stored code instead of copying it,
 * caches the stored code's bit count, and processes codes one byte at a time.
 */
struct SuperstructureComputerDefault {
    /// Pointer to the stored code (not owned, not copied).
    const uint8_t *a;
    /// Length of the codes in bytes.
    int n;
    /// Number of set bits in the stored code, as a float; filled in by set().
    float accu_den;

    /// Leaves the members uninitialized; call set() before compute().
    SuperstructureComputerDefault () {}

    /// Builds a computer for the code at a8 of code_size bytes.
    SuperstructureComputerDefault (const uint8_t *a8, int code_size) {
        set (a8, code_size);
    }

    /// Records the code pointer and length, and caches the code's bit count.
    void set (const uint8_t *a8, int code_size) {
        a = a8;
        n = code_size;

        int stored_bits = 0;
        const uint8_t *cursor = a;
        for (int remaining = n; remaining > 0; --remaining, ++cursor) {
            stored_bits += popcount64(*cursor);
        }
        accu_den = (float)stored_bits;
    }

    /**
     * Returns 1.0 when the stored code and b8 have no set bit in common,
     * otherwise 1 - popcount(stored & b8) / popcount(stored).
     */
    float compute (const uint8_t *b8) const {
        int shared_bits = 0;
        const uint8_t *lhs = a;
        const uint8_t *rhs = b8;
        for (int remaining = n; remaining > 0; --remaining, ++lhs, ++rhs) {
            shared_bits += popcount64(*lhs & *rhs);
        }
        if (shared_bits == 0) {
            return 1.0;
        }
        return 1.0 - (float)(shared_bits) / accu_den;
    }
};
// Primary template: any CODE_SIZE without an explicit specialization below
// uses SuperstructureComputerDefault and takes the code size at run time.
template<int CODE_SIZE>
struct SuperstructureComputer: SuperstructureComputerDefault {
SuperstructureComputer (const uint8_t *a, int code_size):
SuperstructureComputerDefault(a, code_size) {}
};
// Generates the specialization SuperstructureComputer<CODE_SIZE>, derived
// from the fixed-size struct SuperstructureComputer##CODE_SIZE. Unlike the
// primary template, its constructor takes only the code pointer and passes
// CODE_SIZE itself as the code size.
#define SPECIALIZED_HC(CODE_SIZE) \
template<> struct SuperstructureComputer<CODE_SIZE>: \
SuperstructureComputer ## CODE_SIZE { \
SuperstructureComputer (const uint8_t *a): \
SuperstructureComputer ## CODE_SIZE(a, CODE_SIZE) {} \
}
// Code sizes (in bytes) that have a dedicated fixed-size computer.
SPECIALIZED_HC(8);
SPECIALIZED_HC(16);
SPECIALIZED_HC(32);
SPECIALIZED_HC(64);
SPECIALIZED_HC(128);
SPECIALIZED_HC(256);
SPECIALIZED_HC(512);
// The helper macro is local to this list of specializations.
#undef SPECIALIZED_HC
}
......@@ -268,7 +268,8 @@ HNSWConfAdapter::CheckSearch(milvus::json& oricfg, const IndexType& type) {
bool
BinIDMAPConfAdapter::CheckTrain(milvus::json& oricfg) {
static std::vector<std::string> METRICS{knowhere::Metric::HAMMING, knowhere::Metric::JACCARD,
knowhere::Metric::TANIMOTO};
knowhere::Metric::TANIMOTO, knowhere::Metric::SUBSTRUCTURE,
knowhere::Metric::SUPERSTRUCTURE};
CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM);
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
......
......@@ -77,6 +77,8 @@ Utils::MetricTypeName(const milvus::MetricType& metric_type) {
case milvus::MetricType::HAMMING:return "Hamming distance";
case milvus::MetricType::JACCARD:return "Jaccard distance";
case milvus::MetricType::TANIMOTO:return "Tanimoto distance";
case milvus::MetricType::SUBSTRUCTURE:return "Substructure distance";
case milvus::MetricType::SUPERSTRUCTURE:return "Superstructure distance";
default:return "Unknown metric type";
}
}
......
......@@ -43,6 +43,8 @@ enum class MetricType {
HAMMING = 3, // Hamming Distance
JACCARD = 4, // Jaccard Distance
TANIMOTO = 5, // Tanimoto Distance
SUBSTRUCTURE = 6, // Substructure Distance
SUPERSTRUCTURE = 7, // Superstructure Distance
};
/**
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册