未验证 提交 1615e948 编写于 作者: S shengjun.li 提交者: GitHub

fix hamming (#3342)

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
上级 2a6f797c
......@@ -56,61 +56,11 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, config);
auto ret_ds = std::make_shared<Dataset>();
if (index_->metric_type == faiss::METRIC_Hamming) {
auto pf_dist = (float*)malloc(p_dist_size);
int32_t* pi_dist = reinterpret_cast<int32_t*>(p_dist);
for (int i = 0; i < elems; i++) {
*(pf_dist + i) = (float)(*(pi_dist + i));
}
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, pf_dist);
free(p_dist);
} else {
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
}
return ret_ds;
}
#if 0
// Runs a top-k search through index_->search_by_id, using the ID array found under
// meta::IDS in dataset_ptr, and returns a Dataset carrying meta::IDS and
// meta::DISTANCE result arrays of rows * TOPK entries each.
// Throws via KNOWHERE_THROW_MSG when index_ has not been set.
// The result arrays are malloc'ed here and handed to the Dataset as raw pointers.
DatasetPtr
BinaryIDMAP::QueryById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
// NOTE(review): dim is read but never used below.
auto dim = dataset_ptr->Get<int64_t>(meta::DIM);
auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
int64_t k = config[meta::TOPK].get<int64_t>();
// One label and one distance slot per (query row, neighbour) pair.
auto elems = rows * k;
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
// search_by_id is given the float buffer through an int32_t* view, so it writes
// int32 values into it (relies on sizeof(float) == sizeof(int32_t)).
auto* pdistances = (int32_t*)p_dist;
index_->search_by_id(rows, p_data, k, pdistances, p_id, bitset_);
auto ret_ds = std::make_shared<Dataset>();
if (index_->metric_type == faiss::METRIC_Hamming) {
// For Hamming the buffer is treated as int32 values: convert each one into a
// separate float buffer, then release the original buffer.
auto pf_dist = (float*)malloc(p_dist_size);
int32_t* pi_dist = (int32_t*)p_dist;
for (int i = 0; i < elems; i++) {
*(pf_dist + i) = (float)(*(pi_dist + i));
}
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, pf_dist);
free(p_dist);
} else {
// Other metrics: the buffer is returned unchanged.
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
}
// NOTE(review): these two Set calls repeat what both branches above already did.
// On the Hamming branch p_dist has been freed by now, so if Set overwrites the
// earlier entry, DISTANCE would point at freed memory and pf_dist would leak.
// In this flattened diff view the lines may instead be context belonging to the
// post-change Query() — confirm against the real file before acting on this.
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
}
#endif
int64_t
BinaryIDMAP::Count() {
......@@ -190,29 +140,6 @@ BinaryIDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config)
index_->add_with_ids(rows, (uint8_t*)p_data, new_ids.data());
}
#if 0
// Retrieves a single vector from the index by ID.
//
// Reads the ID pointer (meta::IDS) and the result buffer size in bytes (meta::DIM)
// from dataset_ptr, calls index_->get_vector_by_id with a count of 1, and returns
// a Dataset whose meta::TENSOR entry is the malloc'ed result buffer. The buffer is
// handed to the Dataset as a raw pointer; presumably the consumer frees it —
// confirm against callers.
//
// Throws via KNOWHERE_THROW_MSG when index_ has not been set.
DatasetPtr
BinaryIDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
    if (!index_) {
        KNOWHERE_THROW_MSG("index not initialize");
    }

    auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
    auto elems = dataset_ptr->Get<int64_t>(meta::DIM);
    size_t p_x_size = sizeof(uint8_t) * elems;

    // Hold the buffer in a guard so it is freed if the lookup or the Dataset
    // setup throws; ownership is given up only after the hand-off succeeds.
    std::unique_ptr<uint8_t, void (*)(void*)> guard(static_cast<uint8_t*>(malloc(p_x_size)), free);
    index_->get_vector_by_id(1, p_data, guard.get(), bitset_);

    auto ret_ds = std::make_shared<Dataset>();
    uint8_t* p_x = guard.get();
    ret_ds->Set(meta::TENSOR, p_x);
    guard.release();
    return ret_ds;
}
#endif
void
BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
const Config& config) {
......@@ -220,8 +147,16 @@ BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distanc
auto bin_flat_index = dynamic_cast<faiss::IndexBinaryIDMap*>(index_.get())->index;
bin_flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
int32_t* pdistances = (int32_t*)distances;
index_->search(n, (uint8_t*)data, k, pdistances, labels, bitset_);
int32_t* i_distances = reinterpret_cast<int32_t*>(distances);
bin_flat_index->search(n, (uint8_t*)data, k, i_distances, labels, bitset_);
// For the Hamming metric, convert the int32 distances to float in place
if (bin_flat_index->metric_type == faiss::METRIC_Hamming) {
int64_t num = n * k;
for (int64_t i = 0; i < num; i++) {
distances[i] = static_cast<float>(i_distances[i]);
}
}
}
} // namespace knowhere
......
......@@ -50,11 +50,6 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
DatasetPtr
Query(const DatasetPtr&, const Config&) override;
#if 0
DatasetPtr
QueryById(const DatasetPtr& dataset_ptr, const Config& config) override;
#endif
int64_t
Count() override;
......@@ -66,11 +61,6 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
return Count() * Dim() / 8;
}
#if 0
DatasetPtr
GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) override;
#endif
virtual const uint8_t*
GetRawVectors();
......
......@@ -62,63 +62,8 @@ BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, config);
auto ret_ds = std::make_shared<Dataset>();
if (index_->metric_type == faiss::METRIC_Hamming) {
auto pf_dist = (float*)malloc(p_dist_size);
int32_t* pi_dist = reinterpret_cast<int32_t*>(p_dist);
for (int i = 0; i < elems; i++) {
*(pf_dist + i) = (float)(*(pi_dist + i));
}
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, pf_dist);
free(p_dist);
} else {
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
}
return ret_ds;
} catch (faiss::FaissException& e) {
KNOWHERE_THROW_MSG(e.what());
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
#if 0
DatasetPtr
BinaryIVF::QueryById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
auto rows = dataset_ptr->Get<int64_t>(meta::ROWS);
auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
try {
int64_t k = config[meta::TOPK].get<int64_t>();
auto elems = rows * k;
size_t p_id_size = sizeof(int64_t) * elems;
size_t p_dist_size = sizeof(float) * elems;
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
int32_t* pdistances = (int32_t*)p_dist;
index_->search_by_id(rows, p_data, k, pdistances, p_id, bitset_);
auto ret_ds = std::make_shared<Dataset>();
if (index_->metric_type == faiss::METRIC_Hamming) {
auto pf_dist = (float*)malloc(p_dist_size);
int32_t* pi_dist = (int32_t*)p_dist;
for (int i = 0; i < elems; i++) {
*(pf_dist + i) = (float)(*(pi_dist + i));
}
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, pf_dist);
free(p_dist);
} else {
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
}
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
} catch (faiss::FaissException& e) {
......@@ -127,7 +72,6 @@ BinaryIVF::QueryById(const DatasetPtr& dataset_ptr, const Config& config) {
KNOWHERE_THROW_MSG(e.what());
}
}
#endif
int64_t
BinaryIVF::Count() {
......@@ -172,35 +116,6 @@ BinaryIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
index_ = index;
}
#if 0
// Retrieves a single vector from the index by ID.
//
// Reads the ID pointer (meta::IDS) and the result buffer size in bytes (meta::DIM)
// from dataset_ptr, calls index_->get_vector_by_id with a count of 1, and returns
// a Dataset whose meta::TENSOR entry is the malloc'ed result buffer. The buffer is
// handed to the Dataset as a raw pointer; presumably the consumer frees it —
// confirm against callers.
//
// Throws via KNOWHERE_THROW_MSG when the index is missing or untrained, and
// rethrows any faiss/std exception from the lookup through the same macro.
DatasetPtr
BinaryIVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
    if (!index_ || !index_->is_trained) {
        KNOWHERE_THROW_MSG("index not initialize or trained");
    }

    auto p_data = dataset_ptr->Get<const int64_t*>(meta::IDS);
    auto elems = dataset_ptr->Get<int64_t>(meta::DIM);

    try {
        size_t p_x_size = sizeof(uint8_t) * elems;

        // Hold the buffer in a guard so it is freed when the lookup throws and
        // the handlers below rethrow; ownership is given up only after the
        // Dataset has accepted the pointer.
        std::unique_ptr<uint8_t, void (*)(void*)> guard(static_cast<uint8_t*>(malloc(p_x_size)), free);
        index_->get_vector_by_id(1, p_data, guard.get(), bitset_);

        auto ret_ds = std::make_shared<Dataset>();
        uint8_t* p_x = guard.get();
        ret_ds->Set(meta::TENSOR, p_x);
        guard.release();
        return ret_ds;
    } catch (const faiss::FaissException& e) {
        KNOWHERE_THROW_MSG(e.what());
    } catch (const std::exception& e) {
        KNOWHERE_THROW_MSG(e.what());
    }
}
#endif
std::shared_ptr<faiss::IVFSearchParameters>
BinaryIVF::GenParams(const Config& config) {
auto params = std::make_shared<faiss::IVFSearchParameters>();
......@@ -215,11 +130,10 @@ BinaryIVF::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances
auto params = GenParams(config);
auto ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
ivf_index->nprobe = params->nprobe;
int32_t* pdistances = (int32_t*)distances;
stdclock::time_point before = stdclock::now();
// todo: remove static cast (zhiru)
static_cast<faiss::IndexBinary*>(index_.get())->search(n, (uint8_t*)data, k, pdistances, labels, bitset_);
stdclock::time_point before = stdclock::now();
int32_t* i_distances = reinterpret_cast<int32_t*>(distances);
index_->search(n, (uint8_t*)data, k, i_distances, labels, bitset_);
stdclock::time_point after = stdclock::now();
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
......@@ -228,6 +142,14 @@ BinaryIVF::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances
<< ", data search cost: " << faiss::indexIVF_stats.search_time;
faiss::indexIVF_stats.quantization_time = 0;
faiss::indexIVF_stats.search_time = 0;
// For the Hamming metric, convert the int32 distances to float in place
if (ivf_index->metric_type == faiss::METRIC_Hamming) {
int64_t num = n * k;
for (int64_t i = 0; i < num; i++) {
distances[i] = static_cast<float>(i_distances[i]);
}
}
}
} // namespace knowhere
......
......@@ -62,11 +62,6 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
DatasetPtr
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
#if 0
DatasetPtr
QueryById(const DatasetPtr& dataset_ptr, const Config& config) override;
#endif
int64_t
Count() override;
......@@ -76,11 +71,6 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
void
UpdateIndexSize() override;
#if 0
DatasetPtr
GetVectorById(const DatasetPtr& dataset_ptr, const Config& config);
#endif
protected:
virtual std::shared_ptr<faiss::IVFSearchParameters>
GenParams(const Config& config);
......
......@@ -120,7 +120,7 @@ TestProcess(std::shared_ptr<milvus::Connection> connection, const milvus::Mappin
// std::string metric_type = "TANIMOTO";
nlohmann::json dsl_json, vector_param_json;
milvus_sdk::Utils::GenDSLJson(dsl_json, vector_param_json, metric_type);
milvus_sdk::Utils::GenBinaryDSLJson(dsl_json, vector_param_json, metric_type);
std::vector<milvus::VectorData> temp_entity_array;
for (auto& pair : search_entity_array) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册