提交 941e4dc8 编写于 作者: J jinhai

Merge branch 'update_unittest' into 'branch-0.4.0'

MS-538 1. update kdt unittest

See merge request megasearch/milvus!539

Former-commit-id: eca8fb73f491f73ddc2775a26d0b03db9dc92fe2
#pragma once //#pragma once
//
#include <memory> //#include <memory>
#include "preprocessor.h" //#include "preprocessor.h"
//
//
namespace zilliz { //namespace zilliz {
namespace knowhere { //namespace knowhere {
//
class NormalizePreprocessor : public Preprocessor { //class NormalizePreprocessor : public Preprocessor {
public: // public:
DatasetPtr // DatasetPtr
Preprocess(const DatasetPtr &input) override; // Preprocess(const DatasetPtr &input) override;
//
private: // private:
//
void // void
Normalize(float *arr, int64_t dimension); // Normalize(float *arr, int64_t dimension);
}; //};
//
//
using NormalizePreprocessorPtr = std::shared_ptr<NormalizePreprocessor>; //using NormalizePreprocessorPtr = std::shared_ptr<NormalizePreprocessor>;
//
//
} // namespace knowhere //} // namespace knowhere
} // namespace zilliz //} // namespace zilliz
...@@ -27,8 +27,8 @@ class CPUKDTRNG : public VectorIndex { ...@@ -27,8 +27,8 @@ class CPUKDTRNG : public VectorIndex {
Load(const BinarySet &index_array) override; Load(const BinarySet &index_array) override;
public: public:
PreprocessorPtr //PreprocessorPtr
BuildPreprocessor(const DatasetPtr &dataset, const Config &config) override; //BuildPreprocessor(const DatasetPtr &dataset, const Config &config) override;
int64_t Count() override; int64_t Count() override;
int64_t Dimension() override; int64_t Dimension() override;
......
//
#include "knowhere/index/vector_index/definitions.h" //#include "knowhere/index/vector_index/definitions.h"
#include "knowhere/common/config.h" //#include "knowhere/common/config.h"
#include "knowhere/index/preprocessor/normalize.h" //#include "knowhere/index/preprocessor/normalize.h"
//
//
namespace zilliz { //namespace zilliz {
namespace knowhere { //namespace knowhere {
//
DatasetPtr //DatasetPtr
NormalizePreprocessor::Preprocess(const DatasetPtr &dataset) { //NormalizePreprocessor::Preprocess(const DatasetPtr &dataset) {
// // TODO: wrap dataset->tensor // // TODO: wrap dataset->tensor
// auto tensor = dataset->tensor()[0]; // auto tensor = dataset->tensor()[0];
// auto p_data = (float *)tensor->raw_mutable_data(); // auto p_data = (float *)tensor->raw_mutable_data();
...@@ -19,24 +19,24 @@ NormalizePreprocessor::Preprocess(const DatasetPtr &dataset) { ...@@ -19,24 +19,24 @@ NormalizePreprocessor::Preprocess(const DatasetPtr &dataset) {
// for (auto i = 0; i < rows; ++i) { // for (auto i = 0; i < rows; ++i) {
// Normalize(&(p_data[i * dimension]), dimension); // Normalize(&(p_data[i * dimension]), dimension);
// } // }
} //}
//
void //void
NormalizePreprocessor::Normalize(float *arr, int64_t dimension) { //NormalizePreprocessor::Normalize(float *arr, int64_t dimension) {
//double vector_length = 0; // double vector_length = 0;
//for (auto j = 0; j < dimension; j++) { // for (auto j = 0; j < dimension; j++) {
// double val = arr[j]; // double val = arr[j];
// vector_length += val * val; // vector_length += val * val;
//} // }
//vector_length = std::sqrt(vector_length); // vector_length = std::sqrt(vector_length);
//if (vector_length < 1e-6) { // if (vector_length < 1e-6) {
// auto val = (float) (1.0 / std::sqrt((double) dimension)); // auto val = (float) (1.0 / std::sqrt((double) dimension));
// for (int j = 0; j < dimension; j++) arr[j] = val; // for (int j = 0; j < dimension; j++) arr[j] = val;
//} else { // } else {
// for (int j = 0; j < dimension; j++) arr[j] = (float) (arr[j] / vector_length); // for (int j = 0; j < dimension; j++) arr[j] = (float) (arr[j] / vector_length);
//} // }
} //}
//
} // namespace knowhere //} // namespace knowhere
} // namespace zilliz //} // namespace zilliz
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "knowhere/index/vector_index/cpu_kdt_rng.h" #include "knowhere/index/vector_index/cpu_kdt_rng.h"
#include "knowhere/index/vector_index/definitions.h" #include "knowhere/index/vector_index/definitions.h"
#include "knowhere/index/preprocessor/normalize.h" //#include "knowhere/index/preprocessor/normalize.h"
#include "knowhere/index/vector_index/kdt_parameters.h" #include "knowhere/index/vector_index/kdt_parameters.h"
#include "knowhere/adapter/sptag.h" #include "knowhere/adapter/sptag.h"
#include "knowhere/common/exception.h" #include "knowhere/common/exception.h"
...@@ -60,10 +60,10 @@ CPUKDTRNG::Load(const BinarySet &binary_set) { ...@@ -60,10 +60,10 @@ CPUKDTRNG::Load(const BinarySet &binary_set) {
index_ptr_->LoadIndexFromMemory(index_blobs); index_ptr_->LoadIndexFromMemory(index_blobs);
} }
PreprocessorPtr //PreprocessorPtr
CPUKDTRNG::BuildPreprocessor(const DatasetPtr &dataset, const Config &config) { //CPUKDTRNG::BuildPreprocessor(const DatasetPtr &dataset, const Config &config) {
return std::make_shared<NormalizePreprocessor>(); // return std::make_shared<NormalizePreprocessor>();
} //}
IndexModelPtr IndexModelPtr
CPUKDTRNG::Train(const DatasetPtr &origin, const Config &train_config) { CPUKDTRNG::Train(const DatasetPtr &origin, const Config &train_config) {
...@@ -72,7 +72,7 @@ CPUKDTRNG::Train(const DatasetPtr &origin, const Config &train_config) { ...@@ -72,7 +72,7 @@ CPUKDTRNG::Train(const DatasetPtr &origin, const Config &train_config) {
//if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine //if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
// && preprocessor_) { // && preprocessor_) {
preprocessor_->Preprocess(dataset); // preprocessor_->Preprocess(dataset);
//} //}
auto vectorset = ConvertToVectorSet(dataset); auto vectorset = ConvertToVectorSet(dataset);
...@@ -90,7 +90,7 @@ CPUKDTRNG::Add(const DatasetPtr &origin, const Config &add_config) { ...@@ -90,7 +90,7 @@ CPUKDTRNG::Add(const DatasetPtr &origin, const Config &add_config) {
//if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine //if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
// && preprocessor_) { // && preprocessor_) {
preprocessor_->Preprocess(dataset); // preprocessor_->Preprocess(dataset);
//} //}
auto vectorset = ConvertToVectorSet(dataset); auto vectorset = ConvertToVectorSet(dataset);
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "knowhere/common/exception.h"
#include "knowhere/index/vector_index/cpu_kdt_rng.h" #include "knowhere/index/vector_index/cpu_kdt_rng.h"
#include "knowhere/index/vector_index/definitions.h" #include "knowhere/index/vector_index/definitions.h"
...@@ -125,6 +126,10 @@ TEST_P(KDTTest, kdt_serialize) { ...@@ -125,6 +126,10 @@ TEST_P(KDTTest, kdt_serialize) {
auto result = new_index->Search(query_dataset, search_cfg); auto result = new_index->Search(query_dataset, search_cfg);
AssertAnns(result, nq, k); AssertAnns(result, nq, k);
PrintResult(result, nq, k); PrintResult(result, nq, k);
ASSERT_EQ(new_index->Count(), nb);
ASSERT_EQ(new_index->Dimension(), dim);
ASSERT_THROW({new_index->Clone();}, zilliz::knowhere::KnowhereException);
ASSERT_NO_THROW({new_index->Seal();});
{ {
int fileno = 0; int fileno = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册