提交 941e4dc8 编写于 作者: J jinhai

Merge branch 'update_unittest' into 'branch-0.4.0'

MS-538 1. update kdt unittest

See merge request megasearch/milvus!539

Former-commit-id: eca8fb73f491f73ddc2775a26d0b03db9dc92fe2
#pragma once
#include <memory>
#include "preprocessor.h"
namespace zilliz {
namespace knowhere {
class NormalizePreprocessor : public Preprocessor {
public:
DatasetPtr
Preprocess(const DatasetPtr &input) override;
private:
void
Normalize(float *arr, int64_t dimension);
};
using NormalizePreprocessorPtr = std::shared_ptr<NormalizePreprocessor>;
} // namespace knowhere
} // namespace zilliz
//#pragma once
//
//#include <memory>
//#include "preprocessor.h"
//
//
//namespace zilliz {
//namespace knowhere {
//
//class NormalizePreprocessor : public Preprocessor {
// public:
// DatasetPtr
// Preprocess(const DatasetPtr &input) override;
//
// private:
//
// void
// Normalize(float *arr, int64_t dimension);
//};
//
//
//using NormalizePreprocessorPtr = std::shared_ptr<NormalizePreprocessor>;
//
//
//} // namespace knowhere
//} // namespace zilliz
......@@ -27,8 +27,8 @@ class CPUKDTRNG : public VectorIndex {
Load(const BinarySet &index_array) override;
public:
PreprocessorPtr
BuildPreprocessor(const DatasetPtr &dataset, const Config &config) override;
//PreprocessorPtr
//BuildPreprocessor(const DatasetPtr &dataset, const Config &config) override;
int64_t Count() override;
int64_t Dimension() override;
......
#include "knowhere/index/vector_index/definitions.h"
#include "knowhere/common/config.h"
#include "knowhere/index/preprocessor/normalize.h"
namespace zilliz {
namespace knowhere {
DatasetPtr
NormalizePreprocessor::Preprocess(const DatasetPtr &dataset) {
//
//#include "knowhere/index/vector_index/definitions.h"
//#include "knowhere/common/config.h"
//#include "knowhere/index/preprocessor/normalize.h"
//
//
//namespace zilliz {
//namespace knowhere {
//
//DatasetPtr
//NormalizePreprocessor::Preprocess(const DatasetPtr &dataset) {
// // TODO: wrap dataset->tensor
// auto tensor = dataset->tensor()[0];
// auto p_data = (float *)tensor->raw_mutable_data();
......@@ -19,24 +19,24 @@ NormalizePreprocessor::Preprocess(const DatasetPtr &dataset) {
// for (auto i = 0; i < rows; ++i) {
// Normalize(&(p_data[i * dimension]), dimension);
// }
}
void
NormalizePreprocessor::Normalize(float *arr, int64_t dimension) {
//double vector_length = 0;
//for (auto j = 0; j < dimension; j++) {
// double val = arr[j];
// vector_length += val * val;
//}
//vector_length = std::sqrt(vector_length);
//if (vector_length < 1e-6) {
// auto val = (float) (1.0 / std::sqrt((double) dimension));
// for (int j = 0; j < dimension; j++) arr[j] = val;
//} else {
// for (int j = 0; j < dimension; j++) arr[j] = (float) (arr[j] / vector_length);
//}
}
} // namespace knowhere
} // namespace zilliz
//}
//
//void
//NormalizePreprocessor::Normalize(float *arr, int64_t dimension) {
// double vector_length = 0;
// for (auto j = 0; j < dimension; j++) {
// double val = arr[j];
// vector_length += val * val;
// }
// vector_length = std::sqrt(vector_length);
// if (vector_length < 1e-6) {
// auto val = (float) (1.0 / std::sqrt((double) dimension));
// for (int j = 0; j < dimension; j++) arr[j] = val;
// } else {
// for (int j = 0; j < dimension; j++) arr[j] = (float) (arr[j] / vector_length);
// }
//}
//
//} // namespace knowhere
//} // namespace zilliz
......@@ -9,7 +9,7 @@
#include "knowhere/index/vector_index/cpu_kdt_rng.h"
#include "knowhere/index/vector_index/definitions.h"
#include "knowhere/index/preprocessor/normalize.h"
//#include "knowhere/index/preprocessor/normalize.h"
#include "knowhere/index/vector_index/kdt_parameters.h"
#include "knowhere/adapter/sptag.h"
#include "knowhere/common/exception.h"
......@@ -60,10 +60,10 @@ CPUKDTRNG::Load(const BinarySet &binary_set) {
index_ptr_->LoadIndexFromMemory(index_blobs);
}
PreprocessorPtr
CPUKDTRNG::BuildPreprocessor(const DatasetPtr &dataset, const Config &config) {
return std::make_shared<NormalizePreprocessor>();
}
//PreprocessorPtr
//CPUKDTRNG::BuildPreprocessor(const DatasetPtr &dataset, const Config &config) {
// return std::make_shared<NormalizePreprocessor>();
//}
IndexModelPtr
CPUKDTRNG::Train(const DatasetPtr &origin, const Config &train_config) {
......@@ -72,7 +72,7 @@ CPUKDTRNG::Train(const DatasetPtr &origin, const Config &train_config) {
//if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
// && preprocessor_) {
preprocessor_->Preprocess(dataset);
// preprocessor_->Preprocess(dataset);
//}
auto vectorset = ConvertToVectorSet(dataset);
......@@ -90,7 +90,7 @@ CPUKDTRNG::Add(const DatasetPtr &origin, const Config &add_config) {
//if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
// && preprocessor_) {
preprocessor_->Preprocess(dataset);
// preprocessor_->Preprocess(dataset);
//}
auto vectorset = ConvertToVectorSet(dataset);
......
......@@ -8,6 +8,7 @@
#include <iostream>
#include <sstream>
#include "knowhere/common/exception.h"
#include "knowhere/index/vector_index/cpu_kdt_rng.h"
#include "knowhere/index/vector_index/definitions.h"
......@@ -125,6 +126,10 @@ TEST_P(KDTTest, kdt_serialize) {
auto result = new_index->Search(query_dataset, search_cfg);
AssertAnns(result, nq, k);
PrintResult(result, nq, k);
ASSERT_EQ(new_index->Count(), nb);
ASSERT_EQ(new_index->Dimension(), dim);
ASSERT_THROW({new_index->Clone();}, zilliz::knowhere::KnowhereException);
ASSERT_NO_THROW({new_index->Seal();});
{
int fileno = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册