提交 48e296ca 编写于 作者: T test

Merge branch 'branch-0.5.0' of http://192.168.1.105:6060/megasearch/milvus into branch-0.5.0


Former-commit-id: dc9986a7dd5d09aabf5bc16f4840a2cac4538848
......@@ -20,9 +20,9 @@ Before you make any contributions, make sure you follow this list.
Contributions to Milvus fall into the following categories.
1. To report a bug or a problem with documentation, please file an [issue](https://github.com/milvus-io/milvus/issues/new) providing the details of the problem. If you believe the issue needs priority attention, please comment on the issue to notify the team.
2. To propose a new feature, please file a new feature request [issue](https://github.com/milvus-io/milvus/issues/new). Describe the intended feature and discuss the design and implementation with the team and community. Once the team agrees that the plan looks good, go ahead and implement it, following the [Contributing code].
3. To implement a feature or bug-fix for an existing outstanding issue, follow the [Contributing code]. If you need more context on a particular issue, comment on the issue to let people know.
1. To report a bug or a problem with documentation, please file an [issue](https://github.com/milvus-io/milvus/issues/new/choose) providing the details of the problem. If you believe the issue needs priority attention, please comment on the issue to notify the team.
2. To propose a new feature, please file a new feature request [issue](https://github.com/milvus-io/milvus/issues/new/choose). Describe the intended feature and discuss the design and implementation with the team and community. Once the team agrees that the plan looks good, go ahead and implement it, following the [Contributing code](CONTRIBUTING.md#contributing-code).
3. To implement a feature or bug-fix for an existing outstanding issue, follow the [Contributing code](CONTRIBUTING.md#contributing-code). If you need more context on a particular issue, comment on the issue to let people know.
## How can I contribute?
......@@ -44,6 +44,7 @@ Before sending your pull requests for review, make sure your changes are consist
## Coding Style
## Run unit test
```shell
......
......@@ -99,11 +99,8 @@ VecIndexPtr ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
}
Status ExecutionEngineImpl::AddWithIds(long n, const float *xdata, const long *xids) {
auto ec = index_->Add(n, xdata, xids);
if (ec != KNOWHERE_SUCCESS) {
return Status(DB_ERROR, "Add error");
}
return Status::OK();
auto status = index_->Add(n, xdata, xids);
return status;
}
size_t ExecutionEngineImpl::Count() const {
......@@ -131,11 +128,8 @@ size_t ExecutionEngineImpl::PhysicalSize() const {
}
Status ExecutionEngineImpl::Serialize() {
auto ec = write_index(index_, location_);
if (ec != KNOWHERE_SUCCESS) {
return Status(DB_ERROR, "Serialize: write to disk error");
}
return Status::OK();
auto status = write_index(index_, location_);
return status;
}
Status ExecutionEngineImpl::Load(bool to_cache) {
......@@ -254,12 +248,11 @@ Status ExecutionEngineImpl::Merge(const std::string &location) {
}
if (auto file_index = std::dynamic_pointer_cast<BFIndex>(to_merge)) {
auto ec = index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds());
if (ec != KNOWHERE_SUCCESS) {
auto status = index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds());
if (!status.ok()) {
ENGINE_LOG_ERROR << "Merge: Add Error";
return Status(DB_ERROR, "Merge: Add Error");
}
return Status::OK();
return status;
} else {
return Status(DB_ERROR, "file index type is not idmap");
}
......@@ -287,11 +280,11 @@ ExecutionEngineImpl::BuildIndex(const std::string &location, EngineType engine_t
build_cfg["nlist"] = nlist_;
AutoGenParams(to_index->GetType(), Count(), build_cfg);
auto ec = to_index->BuildAll(Count(),
auto status = to_index->BuildAll(Count(),
from_index->GetRawVectors(),
from_index->GetRawIds(),
build_cfg);
if (ec != KNOWHERE_SUCCESS) { throw Exception(DB_ERROR, "Build index error"); }
if (!status.ok()) { throw Exception(DB_ERROR, status.message()); }
return std::make_shared<ExecutionEngineImpl>(to_index, location, engine_type, metric_type_, nlist_);
}
......@@ -309,12 +302,11 @@ Status ExecutionEngineImpl::Search(long n,
ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe;
auto cfg = Config::object{{"k", k}, {"nprobe", nprobe}};
auto ec = index_->Search(n, data, distances, labels, cfg);
if (ec != KNOWHERE_SUCCESS) {
auto status = index_->Search(n, data, distances, labels, cfg);
if (!status.ok()) {
ENGINE_LOG_ERROR << "Search error";
return Status(DB_ERROR, "Search: Search Error");
}
return Status::OK();
return status;
}
Status ExecutionEngineImpl::Cache() {
......
......@@ -28,7 +28,8 @@ namespace engine {
constexpr int64_t M_BYTE = 1024 * 1024;
ErrorCode KnowhereResource::Initialize() {
Status
KnowhereResource::Initialize() {
struct GpuResourceSetting {
int64_t pinned_memory = 300*M_BYTE;
int64_t temp_memory = 300*M_BYTE;
......@@ -65,12 +66,13 @@ ErrorCode KnowhereResource::Initialize() {
iter->second.resource_num);
}
return KNOWHERE_SUCCESS;
return Status::OK();
}
ErrorCode KnowhereResource::Finalize() {
Status
KnowhereResource::Finalize() {
knowhere::FaissGpuResourceMgr::GetInstance().Free(); // free gpu resource.
return KNOWHERE_SUCCESS;
return Status::OK();
}
}
......
......@@ -18,7 +18,7 @@
#pragma once
#include "utils/Error.h"
#include "utils/Status.h"
namespace zilliz {
namespace milvus {
......@@ -26,8 +26,11 @@ namespace engine {
class KnowhereResource {
public:
static ErrorCode Initialize();
static ErrorCode Finalize();
static Status
Initialize();
static Status
Finalize();
};
......
......@@ -21,7 +21,6 @@
#include "knowhere/index/vector_index/IndexGPUIVF.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/helpers/Cloner.h"
#include "vec_impl.h"
#include "data_transfer.h"
......@@ -32,12 +31,13 @@ namespace engine {
using namespace zilliz::knowhere;
ErrorCode VecIndexImpl::BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) {
Status
VecIndexImpl::BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) {
try {
dim = cfg["dim"].as<int>();
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
......@@ -49,36 +49,38 @@ ErrorCode VecIndexImpl::BuildAll(const long &nb,
index_->Add(dataset, cfg);
} catch (KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_UNEXPECTED_ERROR;
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (jsoncons::json_exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_INVALID_ARGUMENT;
return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
} catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_ERROR;
return Status(KNOWHERE_ERROR, e.what());
}
return KNOWHERE_SUCCESS;
return Status::OK();
}
ErrorCode VecIndexImpl::Add(const long &nb, const float *xb, const long *ids, const Config &cfg) {
Status
VecIndexImpl::Add(const long &nb, const float *xb, const long *ids, const Config &cfg) {
try {
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
index_->Add(dataset, cfg);
} catch (KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_UNEXPECTED_ERROR;
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (jsoncons::json_exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_INVALID_ARGUMENT;
return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
} catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_ERROR;
return Status(KNOWHERE_ERROR, e.what());
}
return KNOWHERE_SUCCESS;
return Status::OK();
}
ErrorCode VecIndexImpl::Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) {
Status
VecIndexImpl::Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) {
try {
auto k = cfg["k"].as<int>();
auto dataset = GenDataset(nq, dim, xq);
......@@ -117,41 +119,47 @@ ErrorCode VecIndexImpl::Search(const long &nq, const float *xq, float *dist, lon
} catch (KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_UNEXPECTED_ERROR;
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (jsoncons::json_exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_INVALID_ARGUMENT;
return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
} catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_ERROR;
return Status(KNOWHERE_ERROR, e.what());
}
return KNOWHERE_SUCCESS;
return Status::OK();
}
zilliz::knowhere::BinarySet VecIndexImpl::Serialize() {
zilliz::knowhere::BinarySet
VecIndexImpl::Serialize() {
type = ConvertToCpuIndexType(type);
return index_->Serialize();
}
ErrorCode VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) {
Status
VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) {
index_->Load(index_binary);
dim = Dimension();
return KNOWHERE_SUCCESS;
return Status::OK();
}
int64_t VecIndexImpl::Dimension() {
int64_t
VecIndexImpl::Dimension() {
return index_->Dimension();
}
int64_t VecIndexImpl::Count() {
int64_t
VecIndexImpl::Count() {
return index_->Count();
}
IndexType VecIndexImpl::GetType() {
IndexType
VecIndexImpl::GetType() {
return type;
}
VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) {
VecIndexPtr
VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) {
// TODO(linxj): exception handle
auto gpu_index = zilliz::knowhere::cloner::CopyCpuToGpu(index_, device_id, cfg);
auto new_index = std::make_shared<VecIndexImpl>(gpu_index, ConvertToGpuIndexType(type));
......@@ -159,7 +167,8 @@ VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg)
return new_index;
}
VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) {
VecIndexPtr
VecIndexImpl::CopyToCpu(const Config &cfg) {
// TODO(linxj): exception handle
auto cpu_index = zilliz::knowhere::cloner::CopyGpuToCpu(index_, cfg);
auto new_index = std::make_shared<VecIndexImpl>(cpu_index, ConvertToCpuIndexType(type));
......@@ -167,32 +176,37 @@ VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) {
return new_index;
}
VecIndexPtr VecIndexImpl::Clone() {
VecIndexPtr
VecIndexImpl::Clone() {
// TODO(linxj): exception handle
auto clone_index = std::make_shared<VecIndexImpl>(index_->Clone(), type);
clone_index->dim = dim;
return clone_index;
}
int64_t VecIndexImpl::GetDeviceId() {
if (auto device_idx = std::dynamic_pointer_cast<GPUIndex>(index_)){
int64_t
VecIndexImpl::GetDeviceId() {
if (auto device_idx = std::dynamic_pointer_cast<GPUIndex>(index_)) {
return device_idx->GetGpuDevice();
}
// else
return -1; // -1 == cpu
}
float *BFIndex::GetRawVectors() {
float *
BFIndex::GetRawVectors() {
auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_);
if (raw_index) { return raw_index->GetRawVectors(); }
return nullptr;
}
int64_t *BFIndex::GetRawIds() {
int64_t *
BFIndex::GetRawIds() {
return std::static_pointer_cast<IDMAP>(index_)->GetRawIds();
}
ErrorCode BFIndex::Build(const Config &cfg) {
ErrorCode
BFIndex::Build(const Config &cfg) {
try {
dim = cfg["dim"].as<int>();
std::static_pointer_cast<IDMAP>(index_)->Train(cfg);
......@@ -209,12 +223,13 @@ ErrorCode BFIndex::Build(const Config &cfg) {
return KNOWHERE_SUCCESS;
}
ErrorCode BFIndex::BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) {
Status
BFIndex::BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) {
try {
dim = cfg["dim"].as<int>();
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
......@@ -223,24 +238,25 @@ ErrorCode BFIndex::BuildAll(const long &nb,
index_->Add(dataset, cfg);
} catch (KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_UNEXPECTED_ERROR;
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (jsoncons::json_exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_INVALID_ARGUMENT;
return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
} catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_ERROR;
return Status(KNOWHERE_ERROR, e.what());
}
return KNOWHERE_SUCCESS;
return Status::OK();
}
// TODO(linxj): add lock here.
ErrorCode IVFMixIndex::BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) {
Status
IVFMixIndex::BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) {
try {
dim = cfg["dim"].as<int>();
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
......@@ -257,26 +273,27 @@ ErrorCode IVFMixIndex::BuildAll(const long &nb,
type = ConvertToCpuIndexType(type);
} else {
WRAPPER_LOG_ERROR << "Build IVFMIXIndex Failed";
return KNOWHERE_ERROR;
return Status(KNOWHERE_ERROR, "Build IVFMIXIndex Failed");
}
} catch (KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_UNEXPECTED_ERROR;
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (jsoncons::json_exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_INVALID_ARGUMENT;
return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
} catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_ERROR;
return Status(KNOWHERE_ERROR, e.what());
}
return KNOWHERE_SUCCESS;
return Status::OK();
}
ErrorCode IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) {
Status
IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) {
//index_ = std::make_shared<IVF>();
index_->Load(index_binary);
dim = Dimension();
return KNOWHERE_SUCCESS;
return Status::OK();
}
}
......
......@@ -19,7 +19,6 @@
#pragma once
#include "knowhere/index/vector_index/VectorIndex.h"
#include "vec_index.h"
......@@ -31,27 +30,53 @@ class VecIndexImpl : public VecIndex {
public:
explicit VecIndexImpl(std::shared_ptr<zilliz::knowhere::VectorIndex> index, const IndexType &type)
: index_(std::move(index)), type(type) {};
ErrorCode BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) override;
VecIndexPtr CopyToGpu(const int64_t &device_id, const Config &cfg) override;
VecIndexPtr CopyToCpu(const Config &cfg) override;
IndexType GetType() override;
int64_t Dimension() override;
int64_t Count() override;
ErrorCode Add(const long &nb, const float *xb, const long *ids, const Config &cfg) override;
zilliz::knowhere::BinarySet Serialize() override;
ErrorCode Load(const zilliz::knowhere::BinarySet &index_binary) override;
VecIndexPtr Clone() override;
int64_t GetDeviceId() override;
ErrorCode Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) override;
Status
BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) override;
VecIndexPtr
CopyToGpu(const int64_t &device_id, const Config &cfg) override;
VecIndexPtr
CopyToCpu(const Config &cfg) override;
IndexType
GetType() override;
int64_t
Dimension() override;
int64_t
Count() override;
Status
Add(const long &nb, const float *xb, const long *ids, const Config &cfg) override;
zilliz::knowhere::BinarySet
Serialize() override;
Status
Load(const zilliz::knowhere::BinarySet &index_binary) override;
VecIndexPtr
Clone() override;
int64_t
GetDeviceId() override;
Status
Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) override;
protected:
int64_t dim = 0;
IndexType type = IndexType::INVALID;
std::shared_ptr<zilliz::knowhere::VectorIndex> index_ = nullptr;
};
......@@ -60,28 +85,39 @@ class IVFMixIndex : public VecIndexImpl {
explicit IVFMixIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index, const IndexType &type)
: VecIndexImpl(std::move(index), type) {};
ErrorCode BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) override;
ErrorCode Load(const zilliz::knowhere::BinarySet &index_binary) override;
Status
BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) override;
Status
Load(const zilliz::knowhere::BinarySet &index_binary) override;
};
class BFIndex : public VecIndexImpl {
public:
explicit BFIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : VecIndexImpl(std::move(index),
IndexType::FAISS_IDMAP) {};
ErrorCode Build(const Config& cfg);
float *GetRawVectors();
ErrorCode BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) override;
int64_t *GetRawIds();
ErrorCode
Build(const Config &cfg);
float *
GetRawVectors();
Status
BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) override;
int64_t *
GetRawIds();
};
}
......
......@@ -25,7 +25,6 @@
#include "knowhere/index/vector_index/IndexKDT.h"
#include "knowhere/index/vector_index/IndexNSG.h"
#include "knowhere/common/Exception.h"
#include "vec_index.h"
#include "vec_impl.h"
#include "utils/Log.h"
......@@ -39,23 +38,19 @@ namespace engine {
static constexpr float TYPICAL_COUNT = 1000000.0;
struct FileIOWriter {
std::fstream fs;
std::string name;
FileIOWriter(const std::string &fname);
~FileIOWriter();
size_t operator()(void *ptr, size_t size);
};
struct FileIOReader {
std::fstream fs;
std::string name;
FileIOReader(const std::string &fname);
~FileIOReader();
size_t operator()(void *ptr, size_t size);
size_t operator()(void *ptr, size_t size, size_t pos);
size_t
operator()(void *ptr, size_t size);
size_t
operator()(void *ptr, size_t size, size_t pos);
};
FileIOReader::FileIOReader(const std::string &fname) {
......@@ -67,14 +62,27 @@ FileIOReader::~FileIOReader() {
fs.close();
}
size_t FileIOReader::operator()(void *ptr, size_t size) {
size_t
FileIOReader::operator()(void *ptr, size_t size) {
fs.read(reinterpret_cast<char *>(ptr), size);
}
size_t FileIOReader::operator()(void *ptr, size_t size, size_t pos) {
size_t
FileIOReader::operator()(void *ptr, size_t size, size_t pos) {
return 0;
}
struct FileIOWriter {
std::fstream fs;
std::string name;
FileIOWriter(const std::string &fname);
~FileIOWriter();
size_t operator()(void *ptr, size_t size);
};
FileIOWriter::FileIOWriter(const std::string &fname) {
name = fname;
fs = std::fstream(name, std::ios::out | std::ios::binary);
......@@ -84,12 +92,14 @@ FileIOWriter::~FileIOWriter() {
fs.close();
}
size_t FileIOWriter::operator()(void *ptr, size_t size) {
size_t
FileIOWriter::operator()(void *ptr, size_t size) {
fs.write(reinterpret_cast<char *>(ptr), size);
}
VecIndexPtr GetVecIndexFactory(const IndexType &type, const Config &cfg) {
VecIndexPtr
GetVecIndexFactory(const IndexType &type, const Config &cfg) {
std::shared_ptr<zilliz::knowhere::VectorIndex> index;
auto gpu_device = cfg.get_with_default("gpu_id", 0);
switch (type) {
......@@ -145,13 +155,15 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type, const Config &cfg) {
return std::make_shared<VecIndexImpl>(index, type);
}
VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary) {
VecIndexPtr
LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary) {
auto index = GetVecIndexFactory(index_type);
index->Load(index_binary);
return index;
}
VecIndexPtr read_index(const std::string &location) {
VecIndexPtr
read_index(const std::string &location) {
knowhere::BinarySet load_data_list;
FileIOReader reader(location);
reader.fs.seekg(0, reader.fs.end);
......@@ -195,7 +207,8 @@ VecIndexPtr read_index(const std::string &location) {
return LoadVecIndex(current_type, load_data_list);
}
ErrorCode write_index(VecIndexPtr index, const std::string &location) {
Status
write_index(VecIndexPtr index, const std::string &location) {
try {
auto binaryset = index->Serialize();
auto index_type = index->GetType();
......@@ -215,28 +228,29 @@ ErrorCode write_index(VecIndexPtr index, const std::string &location) {
}
} catch (knowhere::KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_UNEXPECTED_ERROR;
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what();
std::string estring(e.what());
if (estring.find("No space left on device") != estring.npos) {
WRAPPER_LOG_ERROR << "No space left on the device";
return KNOWHERE_NO_SPACE;
return Status(KNOWHERE_NO_SPACE, "No space left on the device");
} else {
return KNOWHERE_ERROR;
return Status(KNOWHERE_ERROR, e.what());
}
}
return KNOWHERE_SUCCESS;
return Status::OK();
}
// TODO(linxj): redo here.
void AutoGenParams(const IndexType &type, const long &size, zilliz::knowhere::Config &cfg) {
void
AutoGenParams(const IndexType &type, const long &size, zilliz::knowhere::Config &cfg) {
auto nlist = cfg.get_with_default("nlist", 0);
if (size <= TYPICAL_COUNT / 16384 + 1) {
//handle less row count, avoid nlist set to 0
cfg["nlist"] = 1;
} else if (int(size / TYPICAL_COUNT) *nlist == 0) {
} else if (int(size / TYPICAL_COUNT) * nlist == 0) {
//calculate a proper nlist if nlist not specified or size less than TYPICAL_COUNT
cfg["nlist"] = int(size / TYPICAL_COUNT * 16384);
}
......@@ -270,7 +284,8 @@ void AutoGenParams(const IndexType &type, const long &size, zilliz::knowhere::Co
#define GPU_MAX_NRPOBE 1024
#endif
void ParameterValidation(const IndexType &type, Config &cfg) {
void
ParameterValidation(const IndexType &type, Config &cfg) {
switch (type) {
case IndexType::FAISS_IVFSQ8_GPU:
case IndexType::FAISS_IVFFLAT_GPU:
......@@ -291,7 +306,8 @@ void ParameterValidation(const IndexType &type, Config &cfg) {
}
}
IndexType ConvertToCpuIndexType(const IndexType &type) {
IndexType
ConvertToCpuIndexType(const IndexType &type) {
// TODO(linxj): add IDMAP
switch (type) {
case IndexType::FAISS_IVFFLAT_GPU:
......@@ -308,7 +324,8 @@ IndexType ConvertToCpuIndexType(const IndexType &type) {
}
}
IndexType ConvertToGpuIndexType(const IndexType &type) {
IndexType
ConvertToGpuIndexType(const IndexType &type) {
switch (type) {
case IndexType::FAISS_IVFFLAT_MIX:
case IndexType::FAISS_IVFFLAT_CPU: {
......
......@@ -21,8 +21,7 @@
#include <string>
#include <memory>
#include "utils/Error.h"
#include "utils/Status.h"
#include "knowhere/common/Config.h"
#include "knowhere/common/BinarySet.h"
......@@ -50,62 +49,84 @@ enum class IndexType {
};
class VecIndex;
using VecIndexPtr = std::shared_ptr<VecIndex>;
class VecIndex {
public:
virtual ErrorCode BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt = 0,
const float *xt = nullptr) = 0;
virtual ErrorCode Add(const long &nb,
const float *xb,
const long *ids,
const Config &cfg = Config()) = 0;
virtual ErrorCode Search(const long &nq,
const float *xq,
float *dist,
long *ids,
const Config &cfg = Config()) = 0;
virtual VecIndexPtr CopyToGpu(const int64_t &device_id,
const Config &cfg = Config()) = 0;
virtual VecIndexPtr CopyToCpu(const Config &cfg = Config()) = 0;
virtual VecIndexPtr Clone() = 0;
virtual int64_t GetDeviceId() = 0;
virtual IndexType GetType() = 0;
virtual int64_t Dimension() = 0;
virtual int64_t Count() = 0;
virtual zilliz::knowhere::BinarySet Serialize() = 0;
virtual ErrorCode Load(const zilliz::knowhere::BinarySet &index_binary) = 0;
virtual Status
BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt = 0,
const float *xt = nullptr) = 0;
virtual Status
Add(const long &nb,
const float *xb,
const long *ids,
const Config &cfg = Config()) = 0;
virtual Status
Search(const long &nq,
const float *xq,
float *dist,
long *ids,
const Config &cfg = Config()) = 0;
virtual VecIndexPtr
CopyToGpu(const int64_t &device_id,
const Config &cfg = Config()) = 0;
virtual VecIndexPtr
CopyToCpu(const Config &cfg = Config()) = 0;
virtual VecIndexPtr
Clone() = 0;
virtual int64_t
GetDeviceId() = 0;
virtual IndexType
GetType() = 0;
virtual int64_t
Dimension() = 0;
virtual int64_t
Count() = 0;
virtual zilliz::knowhere::BinarySet
Serialize() = 0;
virtual Status
Load(const zilliz::knowhere::BinarySet &index_binary) = 0;
};
extern ErrorCode write_index(VecIndexPtr index, const std::string &location);
extern Status
write_index(VecIndexPtr index, const std::string &location);
extern VecIndexPtr
read_index(const std::string &location);
extern VecIndexPtr read_index(const std::string &location);
extern VecIndexPtr
GetVecIndexFactory(const IndexType &type, const Config &cfg = Config());
extern VecIndexPtr GetVecIndexFactory(const IndexType &type, const Config &cfg = Config());
extern VecIndexPtr
LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary);
extern VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary);
extern void
AutoGenParams(const IndexType &type, const long &size, Config &cfg);
extern void AutoGenParams(const IndexType &type, const long &size, Config &cfg);
extern void
ParameterValidation(const IndexType &type, Config &cfg);
extern void ParameterValidation(const IndexType &type, Config &cfg);
extern IndexType
ConvertToCpuIndexType(const IndexType &type);
extern IndexType ConvertToCpuIndexType(const IndexType &type);
extern IndexType ConvertToGpuIndexType(const IndexType &type);
extern IndexType
ConvertToGpuIndexType(const IndexType &type);
}
}
......
......@@ -17,8 +17,7 @@
#include "scheduler/Scheduler.h"
#include <gtest/gtest.h>
#include <src/scheduler/tasklabel/DefaultLabel.h>
#include <src/server/ServerConfig.h>
#include "src/scheduler/tasklabel/DefaultLabel.h"
#include "cache/DataObj.h"
#include "cache/GpuCacheMgr.h"
#include "scheduler/task/TestTask.h"
......@@ -35,13 +34,12 @@ namespace engine {
class MockVecIndex : public engine::VecIndex {
public:
virtual ErrorCode BuildAll(const long &nb,
virtual Status BuildAll(const long &nb,
const float *xb,
const long *ids,
const engine::Config &cfg,
const long &nt = 0,
const float *xt = nullptr) {
}
engine::VecIndexPtr Clone() override {
......@@ -56,14 +54,14 @@ public:
return engine::IndexType::INVALID;
}
virtual ErrorCode Add(const long &nb,
virtual Status Add(const long &nb,
const float *xb,
const long *ids,
const engine::Config &cfg = engine::Config()) {
}
virtual ErrorCode Search(const long &nq,
virtual Status Search(const long &nq,
const float *xq,
float *dist,
long *ids,
......@@ -92,7 +90,7 @@ public:
return binset;
}
virtual ErrorCode Load(const zilliz::knowhere::BinarySet &index_binary) {
virtual Status Load(const zilliz::knowhere::BinarySet &index_binary) {
}
......
......@@ -19,7 +19,6 @@
#include "cache/CpuCacheMgr.h"
#include "cache/GpuCacheMgr.h"
#include "server/ServerConfig.h"
#include "utils/Error.h"
#include "src/wrapper/vec_index.h"
......@@ -48,13 +47,13 @@ public:
}
virtual ErrorCode BuildAll(const long &nb,
virtual Status BuildAll(const long &nb,
const float *xb,
const long *ids,
const engine::Config &cfg,
const long &nt = 0,
const float *xt = nullptr) {
return 0;
return Status();
}
engine::VecIndexPtr Clone() override {
......@@ -69,19 +68,19 @@ public:
return engine::IndexType::INVALID;
}
virtual ErrorCode Add(const long &nb,
virtual Status Add(const long &nb,
const float *xb,
const long *ids,
const engine::Config &cfg = engine::Config()) {
return 0;
return Status();
}
virtual ErrorCode Search(const long &nq,
virtual Status Search(const long &nq,
const float *xq,
float *dist,
long *ids,
const engine::Config &cfg = engine::Config()) {
return 0;
return Status();
}
engine::VecIndexPtr CopyToGpu(const int64_t &device_id,
......@@ -106,8 +105,8 @@ public:
return binset;
}
virtual ErrorCode Load(const zilliz::knowhere::BinarySet &index_binary) {
return 0;
virtual Status Load(const zilliz::knowhere::BinarySet &index_binary) {
return Status();
}
public:
......
......@@ -28,6 +28,7 @@ set(wrapper_files
set(util_files
utils.cpp
${MILVUS_ENGINE_SRC}/utils/easylogging++.cc
${MILVUS_ENGINE_SRC}/utils/Status.cpp
)
set(knowhere_libs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册