提交 72885057 编写于 作者: J jinhai

Merge branch 'Refactor_Knowhere' into 'branch-0.5.0'

MS-583 Change to Status from errorcode

See merge request megasearch/milvus!599

Former-commit-id: 1beb12a93ae6576ab0d991cfdb47847f7989e477
......@@ -99,11 +99,8 @@ VecIndexPtr ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
}
Status ExecutionEngineImpl::AddWithIds(long n, const float *xdata, const long *xids) {
auto ec = index_->Add(n, xdata, xids);
if (ec != KNOWHERE_SUCCESS) {
return Status(DB_ERROR, "Add error");
}
return Status::OK();
auto status = index_->Add(n, xdata, xids);
return status;
}
size_t ExecutionEngineImpl::Count() const {
......@@ -131,11 +128,8 @@ size_t ExecutionEngineImpl::PhysicalSize() const {
}
Status ExecutionEngineImpl::Serialize() {
auto ec = write_index(index_, location_);
if (ec != KNOWHERE_SUCCESS) {
return Status(DB_ERROR, "Serialize: write to disk error");
}
return Status::OK();
auto status = write_index(index_, location_);
return status;
}
Status ExecutionEngineImpl::Load(bool to_cache) {
......@@ -254,12 +248,11 @@ Status ExecutionEngineImpl::Merge(const std::string &location) {
}
if (auto file_index = std::dynamic_pointer_cast<BFIndex>(to_merge)) {
auto ec = index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds());
if (ec != KNOWHERE_SUCCESS) {
auto status = index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds());
if (!status.ok()) {
ENGINE_LOG_ERROR << "Merge: Add Error";
return Status(DB_ERROR, "Merge: Add Error");
}
return Status::OK();
return status;
} else {
return Status(DB_ERROR, "file index type is not idmap");
}
......@@ -287,11 +280,11 @@ ExecutionEngineImpl::BuildIndex(const std::string &location, EngineType engine_t
build_cfg["nlist"] = nlist_;
AutoGenParams(to_index->GetType(), Count(), build_cfg);
auto ec = to_index->BuildAll(Count(),
auto status = to_index->BuildAll(Count(),
from_index->GetRawVectors(),
from_index->GetRawIds(),
build_cfg);
if (ec != KNOWHERE_SUCCESS) { throw Exception(DB_ERROR, "Build index error"); }
if (!status.ok()) { throw Exception(DB_ERROR, status.message()); }
return std::make_shared<ExecutionEngineImpl>(to_index, location, engine_type, metric_type_, nlist_);
}
......@@ -309,12 +302,11 @@ Status ExecutionEngineImpl::Search(long n,
ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe;
auto cfg = Config::object{{"k", k}, {"nprobe", nprobe}};
auto ec = index_->Search(n, data, distances, labels, cfg);
if (ec != KNOWHERE_SUCCESS) {
auto status = index_->Search(n, data, distances, labels, cfg);
if (!status.ok()) {
ENGINE_LOG_ERROR << "Search error";
return Status(DB_ERROR, "Search: Search Error");
}
return Status::OK();
return status;
}
Status ExecutionEngineImpl::Cache() {
......
......@@ -28,7 +28,8 @@ namespace engine {
constexpr int64_t M_BYTE = 1024 * 1024;
ErrorCode KnowhereResource::Initialize() {
Status
KnowhereResource::Initialize() {
struct GpuResourceSetting {
int64_t pinned_memory = 300*M_BYTE;
int64_t temp_memory = 300*M_BYTE;
......@@ -65,12 +66,13 @@ ErrorCode KnowhereResource::Initialize() {
iter->second.resource_num);
}
return KNOWHERE_SUCCESS;
return Status::OK();
}
ErrorCode KnowhereResource::Finalize() {
Status
KnowhereResource::Finalize() {
knowhere::FaissGpuResourceMgr::GetInstance().Free(); // free gpu resource.
return KNOWHERE_SUCCESS;
return Status::OK();
}
}
......
......@@ -18,7 +18,7 @@
#pragma once
#include "utils/Error.h"
#include "utils/Status.h"
namespace zilliz {
namespace milvus {
......@@ -26,8 +26,11 @@ namespace engine {
class KnowhereResource {
public:
static ErrorCode Initialize();
static ErrorCode Finalize();
static Status
Initialize();
static Status
Finalize();
};
......
......@@ -21,7 +21,6 @@
#include "knowhere/index/vector_index/IndexGPUIVF.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/helpers/Cloner.h"
#include "vec_impl.h"
#include "data_transfer.h"
......@@ -32,12 +31,13 @@ namespace engine {
using namespace zilliz::knowhere;
ErrorCode VecIndexImpl::BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) {
Status
VecIndexImpl::BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) {
try {
dim = cfg["dim"].as<int>();
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
......@@ -49,36 +49,38 @@ ErrorCode VecIndexImpl::BuildAll(const long &nb,
index_->Add(dataset, cfg);
} catch (KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_UNEXPECTED_ERROR;
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (jsoncons::json_exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_INVALID_ARGUMENT;
return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
} catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_ERROR;
return Status(KNOWHERE_ERROR, e.what());
}
return KNOWHERE_SUCCESS;
return Status::OK();
}
ErrorCode VecIndexImpl::Add(const long &nb, const float *xb, const long *ids, const Config &cfg) {
Status
VecIndexImpl::Add(const long &nb, const float *xb, const long *ids, const Config &cfg) {
try {
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
index_->Add(dataset, cfg);
} catch (KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_UNEXPECTED_ERROR;
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (jsoncons::json_exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_INVALID_ARGUMENT;
return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
} catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_ERROR;
return Status(KNOWHERE_ERROR, e.what());
}
return KNOWHERE_SUCCESS;
return Status::OK();
}
ErrorCode VecIndexImpl::Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) {
Status
VecIndexImpl::Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) {
try {
auto k = cfg["k"].as<int>();
auto dataset = GenDataset(nq, dim, xq);
......@@ -117,41 +119,47 @@ ErrorCode VecIndexImpl::Search(const long &nq, const float *xq, float *dist, lon
} catch (KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_UNEXPECTED_ERROR;
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (jsoncons::json_exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_INVALID_ARGUMENT;
return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
} catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_ERROR;
return Status(KNOWHERE_ERROR, e.what());
}
return KNOWHERE_SUCCESS;
return Status::OK();
}
zilliz::knowhere::BinarySet VecIndexImpl::Serialize() {
zilliz::knowhere::BinarySet
VecIndexImpl::Serialize() {
type = ConvertToCpuIndexType(type);
return index_->Serialize();
}
ErrorCode VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) {
Status
VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) {
index_->Load(index_binary);
dim = Dimension();
return KNOWHERE_SUCCESS;
return Status::OK();
}
int64_t VecIndexImpl::Dimension() {
int64_t
VecIndexImpl::Dimension() {
return index_->Dimension();
}
int64_t VecIndexImpl::Count() {
int64_t
VecIndexImpl::Count() {
return index_->Count();
}
IndexType VecIndexImpl::GetType() {
IndexType
VecIndexImpl::GetType() {
return type;
}
VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) {
VecIndexPtr
VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) {
// TODO(linxj): exception handle
auto gpu_index = zilliz::knowhere::cloner::CopyCpuToGpu(index_, device_id, cfg);
auto new_index = std::make_shared<VecIndexImpl>(gpu_index, ConvertToGpuIndexType(type));
......@@ -159,7 +167,8 @@ VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg)
return new_index;
}
VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) {
VecIndexPtr
VecIndexImpl::CopyToCpu(const Config &cfg) {
// TODO(linxj): exception handle
auto cpu_index = zilliz::knowhere::cloner::CopyGpuToCpu(index_, cfg);
auto new_index = std::make_shared<VecIndexImpl>(cpu_index, ConvertToCpuIndexType(type));
......@@ -167,32 +176,37 @@ VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) {
return new_index;
}
VecIndexPtr VecIndexImpl::Clone() {
VecIndexPtr
VecIndexImpl::Clone() {
// TODO(linxj): exception handle
auto clone_index = std::make_shared<VecIndexImpl>(index_->Clone(), type);
clone_index->dim = dim;
return clone_index;
}
int64_t VecIndexImpl::GetDeviceId() {
if (auto device_idx = std::dynamic_pointer_cast<GPUIndex>(index_)){
int64_t
VecIndexImpl::GetDeviceId() {
if (auto device_idx = std::dynamic_pointer_cast<GPUIndex>(index_)) {
return device_idx->GetGpuDevice();
}
// else
return -1; // -1 == cpu
}
float *BFIndex::GetRawVectors() {
float *
BFIndex::GetRawVectors() {
auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_);
if (raw_index) { return raw_index->GetRawVectors(); }
return nullptr;
}
int64_t *BFIndex::GetRawIds() {
int64_t *
BFIndex::GetRawIds() {
return std::static_pointer_cast<IDMAP>(index_)->GetRawIds();
}
ErrorCode BFIndex::Build(const Config &cfg) {
ErrorCode
BFIndex::Build(const Config &cfg) {
try {
dim = cfg["dim"].as<int>();
std::static_pointer_cast<IDMAP>(index_)->Train(cfg);
......@@ -209,12 +223,13 @@ ErrorCode BFIndex::Build(const Config &cfg) {
return KNOWHERE_SUCCESS;
}
ErrorCode BFIndex::BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) {
Status
BFIndex::BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) {
try {
dim = cfg["dim"].as<int>();
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
......@@ -223,24 +238,25 @@ ErrorCode BFIndex::BuildAll(const long &nb,
index_->Add(dataset, cfg);
} catch (KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_UNEXPECTED_ERROR;
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (jsoncons::json_exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_INVALID_ARGUMENT;
return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
} catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_ERROR;
return Status(KNOWHERE_ERROR, e.what());
}
return KNOWHERE_SUCCESS;
return Status::OK();
}
// TODO(linxj): add lock here.
ErrorCode IVFMixIndex::BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) {
Status
IVFMixIndex::BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) {
try {
dim = cfg["dim"].as<int>();
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
......@@ -257,26 +273,27 @@ ErrorCode IVFMixIndex::BuildAll(const long &nb,
type = ConvertToCpuIndexType(type);
} else {
WRAPPER_LOG_ERROR << "Build IVFMIXIndex Failed";
return KNOWHERE_ERROR;
return Status(KNOWHERE_ERROR, "Build IVFMIXIndex Failed");
}
} catch (KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_UNEXPECTED_ERROR;
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (jsoncons::json_exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_INVALID_ARGUMENT;
return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
} catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_ERROR;
return Status(KNOWHERE_ERROR, e.what());
}
return KNOWHERE_SUCCESS;
return Status::OK();
}
ErrorCode IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) {
Status
IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) {
//index_ = std::make_shared<IVF>();
index_->Load(index_binary);
dim = Dimension();
return KNOWHERE_SUCCESS;
return Status::OK();
}
}
......
......@@ -19,7 +19,6 @@
#pragma once
#include "knowhere/index/vector_index/VectorIndex.h"
#include "vec_index.h"
......@@ -31,27 +30,53 @@ class VecIndexImpl : public VecIndex {
public:
explicit VecIndexImpl(std::shared_ptr<zilliz::knowhere::VectorIndex> index, const IndexType &type)
: index_(std::move(index)), type(type) {};
ErrorCode BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) override;
VecIndexPtr CopyToGpu(const int64_t &device_id, const Config &cfg) override;
VecIndexPtr CopyToCpu(const Config &cfg) override;
IndexType GetType() override;
int64_t Dimension() override;
int64_t Count() override;
ErrorCode Add(const long &nb, const float *xb, const long *ids, const Config &cfg) override;
zilliz::knowhere::BinarySet Serialize() override;
ErrorCode Load(const zilliz::knowhere::BinarySet &index_binary) override;
VecIndexPtr Clone() override;
int64_t GetDeviceId() override;
ErrorCode Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) override;
Status
BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) override;
VecIndexPtr
CopyToGpu(const int64_t &device_id, const Config &cfg) override;
VecIndexPtr
CopyToCpu(const Config &cfg) override;
IndexType
GetType() override;
int64_t
Dimension() override;
int64_t
Count() override;
Status
Add(const long &nb, const float *xb, const long *ids, const Config &cfg) override;
zilliz::knowhere::BinarySet
Serialize() override;
Status
Load(const zilliz::knowhere::BinarySet &index_binary) override;
VecIndexPtr
Clone() override;
int64_t
GetDeviceId() override;
Status
Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) override;
protected:
int64_t dim = 0;
IndexType type = IndexType::INVALID;
std::shared_ptr<zilliz::knowhere::VectorIndex> index_ = nullptr;
};
......@@ -60,28 +85,39 @@ class IVFMixIndex : public VecIndexImpl {
explicit IVFMixIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index, const IndexType &type)
: VecIndexImpl(std::move(index), type) {};
ErrorCode BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) override;
ErrorCode Load(const zilliz::knowhere::BinarySet &index_binary) override;
Status
BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) override;
Status
Load(const zilliz::knowhere::BinarySet &index_binary) override;
};
class BFIndex : public VecIndexImpl {
public:
explicit BFIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : VecIndexImpl(std::move(index),
IndexType::FAISS_IDMAP) {};
ErrorCode Build(const Config& cfg);
float *GetRawVectors();
ErrorCode BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) override;
int64_t *GetRawIds();
ErrorCode
Build(const Config &cfg);
float *
GetRawVectors();
Status
BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) override;
int64_t *
GetRawIds();
};
}
......
......@@ -25,7 +25,6 @@
#include "knowhere/index/vector_index/IndexKDT.h"
#include "knowhere/index/vector_index/IndexNSG.h"
#include "knowhere/common/Exception.h"
#include "vec_index.h"
#include "vec_impl.h"
#include "utils/Log.h"
......@@ -39,23 +38,19 @@ namespace engine {
static constexpr float TYPICAL_COUNT = 1000000.0;
struct FileIOWriter {
std::fstream fs;
std::string name;
FileIOWriter(const std::string &fname);
~FileIOWriter();
size_t operator()(void *ptr, size_t size);
};
struct FileIOReader {
std::fstream fs;
std::string name;
FileIOReader(const std::string &fname);
~FileIOReader();
size_t operator()(void *ptr, size_t size);
size_t operator()(void *ptr, size_t size, size_t pos);
size_t
operator()(void *ptr, size_t size);
size_t
operator()(void *ptr, size_t size, size_t pos);
};
FileIOReader::FileIOReader(const std::string &fname) {
......@@ -67,14 +62,27 @@ FileIOReader::~FileIOReader() {
fs.close();
}
size_t FileIOReader::operator()(void *ptr, size_t size) {
size_t
FileIOReader::operator()(void *ptr, size_t size) {
fs.read(reinterpret_cast<char *>(ptr), size);
}
size_t FileIOReader::operator()(void *ptr, size_t size, size_t pos) {
size_t
FileIOReader::operator()(void *ptr, size_t size, size_t pos) {
return 0;
}
struct FileIOWriter {
std::fstream fs;
std::string name;
FileIOWriter(const std::string &fname);
~FileIOWriter();
size_t operator()(void *ptr, size_t size);
};
FileIOWriter::FileIOWriter(const std::string &fname) {
name = fname;
fs = std::fstream(name, std::ios::out | std::ios::binary);
......@@ -84,12 +92,14 @@ FileIOWriter::~FileIOWriter() {
fs.close();
}
size_t FileIOWriter::operator()(void *ptr, size_t size) {
size_t
FileIOWriter::operator()(void *ptr, size_t size) {
fs.write(reinterpret_cast<char *>(ptr), size);
}
VecIndexPtr GetVecIndexFactory(const IndexType &type, const Config &cfg) {
VecIndexPtr
GetVecIndexFactory(const IndexType &type, const Config &cfg) {
std::shared_ptr<zilliz::knowhere::VectorIndex> index;
auto gpu_device = cfg.get_with_default("gpu_id", 0);
switch (type) {
......@@ -145,13 +155,15 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type, const Config &cfg) {
return std::make_shared<VecIndexImpl>(index, type);
}
VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary) {
VecIndexPtr
LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary) {
auto index = GetVecIndexFactory(index_type);
index->Load(index_binary);
return index;
}
VecIndexPtr read_index(const std::string &location) {
VecIndexPtr
read_index(const std::string &location) {
knowhere::BinarySet load_data_list;
FileIOReader reader(location);
reader.fs.seekg(0, reader.fs.end);
......@@ -195,7 +207,8 @@ VecIndexPtr read_index(const std::string &location) {
return LoadVecIndex(current_type, load_data_list);
}
ErrorCode write_index(VecIndexPtr index, const std::string &location) {
Status
write_index(VecIndexPtr index, const std::string &location) {
try {
auto binaryset = index->Serialize();
auto index_type = index->GetType();
......@@ -215,28 +228,29 @@ ErrorCode write_index(VecIndexPtr index, const std::string &location) {
}
} catch (knowhere::KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_UNEXPECTED_ERROR;
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what();
std::string estring(e.what());
if (estring.find("No space left on device") != estring.npos) {
WRAPPER_LOG_ERROR << "No space left on the device";
return KNOWHERE_NO_SPACE;
return Status(KNOWHERE_NO_SPACE, "No space left on the device");
} else {
return KNOWHERE_ERROR;
return Status(KNOWHERE_ERROR, e.what());
}
}
return KNOWHERE_SUCCESS;
return Status::OK();
}
// TODO(linxj): redo here.
void AutoGenParams(const IndexType &type, const long &size, zilliz::knowhere::Config &cfg) {
void
AutoGenParams(const IndexType &type, const long &size, zilliz::knowhere::Config &cfg) {
auto nlist = cfg.get_with_default("nlist", 0);
if (size <= TYPICAL_COUNT / 16384 + 1) {
//handle less row count, avoid nlist set to 0
cfg["nlist"] = 1;
} else if (int(size / TYPICAL_COUNT) *nlist == 0) {
} else if (int(size / TYPICAL_COUNT) * nlist == 0) {
//calculate a proper nlist if nlist not specified or size less than TYPICAL_COUNT
cfg["nlist"] = int(size / TYPICAL_COUNT * 16384);
}
......@@ -270,7 +284,8 @@ void AutoGenParams(const IndexType &type, const long &size, zilliz::knowhere::Co
#define GPU_MAX_NRPOBE 1024
#endif
void ParameterValidation(const IndexType &type, Config &cfg) {
void
ParameterValidation(const IndexType &type, Config &cfg) {
switch (type) {
case IndexType::FAISS_IVFSQ8_GPU:
case IndexType::FAISS_IVFFLAT_GPU:
......@@ -291,7 +306,8 @@ void ParameterValidation(const IndexType &type, Config &cfg) {
}
}
IndexType ConvertToCpuIndexType(const IndexType &type) {
IndexType
ConvertToCpuIndexType(const IndexType &type) {
// TODO(linxj): add IDMAP
switch (type) {
case IndexType::FAISS_IVFFLAT_GPU:
......@@ -308,7 +324,8 @@ IndexType ConvertToCpuIndexType(const IndexType &type) {
}
}
IndexType ConvertToGpuIndexType(const IndexType &type) {
IndexType
ConvertToGpuIndexType(const IndexType &type) {
switch (type) {
case IndexType::FAISS_IVFFLAT_MIX:
case IndexType::FAISS_IVFFLAT_CPU: {
......
......@@ -21,8 +21,7 @@
#include <string>
#include <memory>
#include "utils/Error.h"
#include "utils/Status.h"
#include "knowhere/common/Config.h"
#include "knowhere/common/BinarySet.h"
......@@ -50,62 +49,84 @@ enum class IndexType {
};
class VecIndex;
using VecIndexPtr = std::shared_ptr<VecIndex>;
class VecIndex {
public:
virtual ErrorCode BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt = 0,
const float *xt = nullptr) = 0;
virtual ErrorCode Add(const long &nb,
const float *xb,
const long *ids,
const Config &cfg = Config()) = 0;
virtual ErrorCode Search(const long &nq,
const float *xq,
float *dist,
long *ids,
const Config &cfg = Config()) = 0;
virtual VecIndexPtr CopyToGpu(const int64_t &device_id,
const Config &cfg = Config()) = 0;
virtual VecIndexPtr CopyToCpu(const Config &cfg = Config()) = 0;
virtual VecIndexPtr Clone() = 0;
virtual int64_t GetDeviceId() = 0;
virtual IndexType GetType() = 0;
virtual int64_t Dimension() = 0;
virtual int64_t Count() = 0;
virtual zilliz::knowhere::BinarySet Serialize() = 0;
virtual ErrorCode Load(const zilliz::knowhere::BinarySet &index_binary) = 0;
virtual Status
BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt = 0,
const float *xt = nullptr) = 0;
virtual Status
Add(const long &nb,
const float *xb,
const long *ids,
const Config &cfg = Config()) = 0;
virtual Status
Search(const long &nq,
const float *xq,
float *dist,
long *ids,
const Config &cfg = Config()) = 0;
virtual VecIndexPtr
CopyToGpu(const int64_t &device_id,
const Config &cfg = Config()) = 0;
virtual VecIndexPtr
CopyToCpu(const Config &cfg = Config()) = 0;
virtual VecIndexPtr
Clone() = 0;
virtual int64_t
GetDeviceId() = 0;
virtual IndexType
GetType() = 0;
virtual int64_t
Dimension() = 0;
virtual int64_t
Count() = 0;
virtual zilliz::knowhere::BinarySet
Serialize() = 0;
virtual Status
Load(const zilliz::knowhere::BinarySet &index_binary) = 0;
};
extern ErrorCode write_index(VecIndexPtr index, const std::string &location);
extern Status
write_index(VecIndexPtr index, const std::string &location);
extern VecIndexPtr
read_index(const std::string &location);
extern VecIndexPtr read_index(const std::string &location);
extern VecIndexPtr
GetVecIndexFactory(const IndexType &type, const Config &cfg = Config());
extern VecIndexPtr GetVecIndexFactory(const IndexType &type, const Config &cfg = Config());
extern VecIndexPtr
LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary);
extern VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary);
extern void
AutoGenParams(const IndexType &type, const long &size, Config &cfg);
extern void AutoGenParams(const IndexType &type, const long &size, Config &cfg);
extern void
ParameterValidation(const IndexType &type, Config &cfg);
extern void ParameterValidation(const IndexType &type, Config &cfg);
extern IndexType
ConvertToCpuIndexType(const IndexType &type);
extern IndexType ConvertToCpuIndexType(const IndexType &type);
extern IndexType ConvertToGpuIndexType(const IndexType &type);
extern IndexType
ConvertToGpuIndexType(const IndexType &type);
}
}
......
......@@ -17,8 +17,7 @@
#include "scheduler/Scheduler.h"
#include <gtest/gtest.h>
#include <src/scheduler/tasklabel/DefaultLabel.h>
#include <src/server/ServerConfig.h>
#include "src/scheduler/tasklabel/DefaultLabel.h"
#include "cache/DataObj.h"
#include "cache/GpuCacheMgr.h"
#include "scheduler/task/TestTask.h"
......@@ -35,13 +34,12 @@ namespace engine {
class MockVecIndex : public engine::VecIndex {
public:
virtual ErrorCode BuildAll(const long &nb,
virtual Status BuildAll(const long &nb,
const float *xb,
const long *ids,
const engine::Config &cfg,
const long &nt = 0,
const float *xt = nullptr) {
}
engine::VecIndexPtr Clone() override {
......@@ -56,14 +54,14 @@ public:
return engine::IndexType::INVALID;
}
virtual ErrorCode Add(const long &nb,
virtual Status Add(const long &nb,
const float *xb,
const long *ids,
const engine::Config &cfg = engine::Config()) {
}
virtual ErrorCode Search(const long &nq,
virtual Status Search(const long &nq,
const float *xq,
float *dist,
long *ids,
......@@ -92,7 +90,7 @@ public:
return binset;
}
virtual ErrorCode Load(const zilliz::knowhere::BinarySet &index_binary) {
virtual Status Load(const zilliz::knowhere::BinarySet &index_binary) {
}
......
......@@ -19,7 +19,6 @@
#include "cache/CpuCacheMgr.h"
#include "cache/GpuCacheMgr.h"
#include "server/ServerConfig.h"
#include "utils/Error.h"
#include "src/wrapper/vec_index.h"
......@@ -48,13 +47,13 @@ public:
}
virtual ErrorCode BuildAll(const long &nb,
virtual Status BuildAll(const long &nb,
const float *xb,
const long *ids,
const engine::Config &cfg,
const long &nt = 0,
const float *xt = nullptr) {
return 0;
return Status();
}
engine::VecIndexPtr Clone() override {
......@@ -69,19 +68,19 @@ public:
return engine::IndexType::INVALID;
}
virtual ErrorCode Add(const long &nb,
virtual Status Add(const long &nb,
const float *xb,
const long *ids,
const engine::Config &cfg = engine::Config()) {
return 0;
return Status();
}
virtual ErrorCode Search(const long &nq,
virtual Status Search(const long &nq,
const float *xq,
float *dist,
long *ids,
const engine::Config &cfg = engine::Config()) {
return 0;
return Status();
}
engine::VecIndexPtr CopyToGpu(const int64_t &device_id,
......@@ -106,8 +105,8 @@ public:
return binset;
}
virtual ErrorCode Load(const zilliz::knowhere::BinarySet &index_binary) {
return 0;
virtual Status Load(const zilliz::knowhere::BinarySet &index_binary) {
return Status();
}
public:
......
......@@ -28,6 +28,7 @@ set(wrapper_files
set(util_files
utils.cpp
${MILVUS_ENGINE_SRC}/utils/easylogging++.cc
${MILVUS_ENGINE_SRC}/utils/Status.cpp
)
set(knowhere_libs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册