提交 b2003e9f 编写于 作者: H Heisenberg

MS-583 Change to Status from errorcode


Former-commit-id: eaef6a33657b660e4e4c832a56b6982f74e49ec2
上级 c7ec7ced
...@@ -99,11 +99,8 @@ VecIndexPtr ExecutionEngineImpl::CreatetVecIndex(EngineType type) { ...@@ -99,11 +99,8 @@ VecIndexPtr ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
} }
Status ExecutionEngineImpl::AddWithIds(long n, const float *xdata, const long *xids) { Status ExecutionEngineImpl::AddWithIds(long n, const float *xdata, const long *xids) {
auto ec = index_->Add(n, xdata, xids); auto status = index_->Add(n, xdata, xids);
if (ec != KNOWHERE_SUCCESS) { return status;
return Status(DB_ERROR, "Add error");
}
return Status::OK();
} }
size_t ExecutionEngineImpl::Count() const { size_t ExecutionEngineImpl::Count() const {
...@@ -131,11 +128,8 @@ size_t ExecutionEngineImpl::PhysicalSize() const { ...@@ -131,11 +128,8 @@ size_t ExecutionEngineImpl::PhysicalSize() const {
} }
Status ExecutionEngineImpl::Serialize() { Status ExecutionEngineImpl::Serialize() {
auto ec = write_index(index_, location_); auto status = write_index(index_, location_);
if (ec != KNOWHERE_SUCCESS) { return status;
return Status(DB_ERROR, "Serialize: write to disk error");
}
return Status::OK();
} }
Status ExecutionEngineImpl::Load(bool to_cache) { Status ExecutionEngineImpl::Load(bool to_cache) {
...@@ -254,12 +248,11 @@ Status ExecutionEngineImpl::Merge(const std::string &location) { ...@@ -254,12 +248,11 @@ Status ExecutionEngineImpl::Merge(const std::string &location) {
} }
if (auto file_index = std::dynamic_pointer_cast<BFIndex>(to_merge)) { if (auto file_index = std::dynamic_pointer_cast<BFIndex>(to_merge)) {
auto ec = index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds()); auto status = index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds());
if (ec != KNOWHERE_SUCCESS) { if (!status.ok()) {
ENGINE_LOG_ERROR << "Merge: Add Error"; ENGINE_LOG_ERROR << "Merge: Add Error";
return Status(DB_ERROR, "Merge: Add Error");
} }
return Status::OK(); return status;
} else { } else {
return Status(DB_ERROR, "file index type is not idmap"); return Status(DB_ERROR, "file index type is not idmap");
} }
...@@ -287,11 +280,11 @@ ExecutionEngineImpl::BuildIndex(const std::string &location, EngineType engine_t ...@@ -287,11 +280,11 @@ ExecutionEngineImpl::BuildIndex(const std::string &location, EngineType engine_t
build_cfg["nlist"] = nlist_; build_cfg["nlist"] = nlist_;
AutoGenParams(to_index->GetType(), Count(), build_cfg); AutoGenParams(to_index->GetType(), Count(), build_cfg);
auto ec = to_index->BuildAll(Count(), auto status = to_index->BuildAll(Count(),
from_index->GetRawVectors(), from_index->GetRawVectors(),
from_index->GetRawIds(), from_index->GetRawIds(),
build_cfg); build_cfg);
if (ec != KNOWHERE_SUCCESS) { throw Exception(DB_ERROR, "Build index error"); } if (!status.ok()) { throw Exception(DB_ERROR, status.message()); }
return std::make_shared<ExecutionEngineImpl>(to_index, location, engine_type, metric_type_, nlist_); return std::make_shared<ExecutionEngineImpl>(to_index, location, engine_type, metric_type_, nlist_);
} }
...@@ -309,12 +302,11 @@ Status ExecutionEngineImpl::Search(long n, ...@@ -309,12 +302,11 @@ Status ExecutionEngineImpl::Search(long n,
ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe; ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe;
auto cfg = Config::object{{"k", k}, {"nprobe", nprobe}}; auto cfg = Config::object{{"k", k}, {"nprobe", nprobe}};
auto ec = index_->Search(n, data, distances, labels, cfg); auto status = index_->Search(n, data, distances, labels, cfg);
if (ec != KNOWHERE_SUCCESS) { if (!status.ok()) {
ENGINE_LOG_ERROR << "Search error"; ENGINE_LOG_ERROR << "Search error";
return Status(DB_ERROR, "Search: Search Error");
} }
return Status::OK(); return status;
} }
Status ExecutionEngineImpl::Cache() { Status ExecutionEngineImpl::Cache() {
......
...@@ -28,7 +28,8 @@ namespace engine { ...@@ -28,7 +28,8 @@ namespace engine {
constexpr int64_t M_BYTE = 1024 * 1024; constexpr int64_t M_BYTE = 1024 * 1024;
ErrorCode KnowhereResource::Initialize() { Status
KnowhereResource::Initialize() {
struct GpuResourceSetting { struct GpuResourceSetting {
int64_t pinned_memory = 300*M_BYTE; int64_t pinned_memory = 300*M_BYTE;
int64_t temp_memory = 300*M_BYTE; int64_t temp_memory = 300*M_BYTE;
...@@ -65,12 +66,13 @@ ErrorCode KnowhereResource::Initialize() { ...@@ -65,12 +66,13 @@ ErrorCode KnowhereResource::Initialize() {
iter->second.resource_num); iter->second.resource_num);
} }
return KNOWHERE_SUCCESS; return Status::OK();
} }
ErrorCode KnowhereResource::Finalize() { Status
KnowhereResource::Finalize() {
knowhere::FaissGpuResourceMgr::GetInstance().Free(); // free gpu resource. knowhere::FaissGpuResourceMgr::GetInstance().Free(); // free gpu resource.
return KNOWHERE_SUCCESS; return Status::OK();
} }
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#pragma once #pragma once
#include "utils/Error.h" #include "utils/Status.h"
namespace zilliz { namespace zilliz {
namespace milvus { namespace milvus {
...@@ -26,8 +26,11 @@ namespace engine { ...@@ -26,8 +26,11 @@ namespace engine {
class KnowhereResource { class KnowhereResource {
public: public:
static ErrorCode Initialize(); static Status
static ErrorCode Finalize(); Initialize();
static Status
Finalize();
}; };
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
#include "knowhere/index/vector_index/IndexGPUIVF.h" #include "knowhere/index/vector_index/IndexGPUIVF.h"
#include "knowhere/common/Exception.h" #include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/helpers/Cloner.h" #include "knowhere/index/vector_index/helpers/Cloner.h"
#include "vec_impl.h" #include "vec_impl.h"
#include "data_transfer.h" #include "data_transfer.h"
...@@ -32,12 +31,13 @@ namespace engine { ...@@ -32,12 +31,13 @@ namespace engine {
using namespace zilliz::knowhere; using namespace zilliz::knowhere;
ErrorCode VecIndexImpl::BuildAll(const long &nb, Status
const float *xb, VecIndexImpl::BuildAll(const long &nb,
const long *ids, const float *xb,
const Config &cfg, const long *ids,
const long &nt, const Config &cfg,
const float *xt) { const long &nt,
const float *xt) {
try { try {
dim = cfg["dim"].as<int>(); dim = cfg["dim"].as<int>();
auto dataset = GenDatasetWithIds(nb, dim, xb, ids); auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
...@@ -49,36 +49,38 @@ ErrorCode VecIndexImpl::BuildAll(const long &nb, ...@@ -49,36 +49,38 @@ ErrorCode VecIndexImpl::BuildAll(const long &nb,
index_->Add(dataset, cfg); index_->Add(dataset, cfg);
} catch (KnowhereException &e) { } catch (KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_UNEXPECTED_ERROR; return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (jsoncons::json_exception &e) { } catch (jsoncons::json_exception &e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_INVALID_ARGUMENT; return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
} catch (std::exception &e) { } catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_ERROR; return Status(KNOWHERE_ERROR, e.what());
} }
return KNOWHERE_SUCCESS; return Status::OK();
} }
ErrorCode VecIndexImpl::Add(const long &nb, const float *xb, const long *ids, const Config &cfg) { Status
VecIndexImpl::Add(const long &nb, const float *xb, const long *ids, const Config &cfg) {
try { try {
auto dataset = GenDatasetWithIds(nb, dim, xb, ids); auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
index_->Add(dataset, cfg); index_->Add(dataset, cfg);
} catch (KnowhereException &e) { } catch (KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_UNEXPECTED_ERROR; return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (jsoncons::json_exception &e) { } catch (jsoncons::json_exception &e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_INVALID_ARGUMENT; return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
} catch (std::exception &e) { } catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_ERROR; return Status(KNOWHERE_ERROR, e.what());
} }
return KNOWHERE_SUCCESS; return Status::OK();
} }
ErrorCode VecIndexImpl::Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) { Status
VecIndexImpl::Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) {
try { try {
auto k = cfg["k"].as<int>(); auto k = cfg["k"].as<int>();
auto dataset = GenDataset(nq, dim, xq); auto dataset = GenDataset(nq, dim, xq);
...@@ -117,41 +119,47 @@ ErrorCode VecIndexImpl::Search(const long &nq, const float *xq, float *dist, lon ...@@ -117,41 +119,47 @@ ErrorCode VecIndexImpl::Search(const long &nq, const float *xq, float *dist, lon
} catch (KnowhereException &e) { } catch (KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_UNEXPECTED_ERROR; return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (jsoncons::json_exception &e) { } catch (jsoncons::json_exception &e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_INVALID_ARGUMENT; return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
} catch (std::exception &e) { } catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_ERROR; return Status(KNOWHERE_ERROR, e.what());
} }
return KNOWHERE_SUCCESS; return Status::OK();
} }
zilliz::knowhere::BinarySet VecIndexImpl::Serialize() { zilliz::knowhere::BinarySet
VecIndexImpl::Serialize() {
type = ConvertToCpuIndexType(type); type = ConvertToCpuIndexType(type);
return index_->Serialize(); return index_->Serialize();
} }
ErrorCode VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) { Status
VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) {
index_->Load(index_binary); index_->Load(index_binary);
dim = Dimension(); dim = Dimension();
return KNOWHERE_SUCCESS; return Status::OK();
} }
int64_t VecIndexImpl::Dimension() { int64_t
VecIndexImpl::Dimension() {
return index_->Dimension(); return index_->Dimension();
} }
int64_t VecIndexImpl::Count() { int64_t
VecIndexImpl::Count() {
return index_->Count(); return index_->Count();
} }
IndexType VecIndexImpl::GetType() { IndexType
VecIndexImpl::GetType() {
return type; return type;
} }
VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) { VecIndexPtr
VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) {
// TODO(linxj): exception handle // TODO(linxj): exception handle
auto gpu_index = zilliz::knowhere::cloner::CopyCpuToGpu(index_, device_id, cfg); auto gpu_index = zilliz::knowhere::cloner::CopyCpuToGpu(index_, device_id, cfg);
auto new_index = std::make_shared<VecIndexImpl>(gpu_index, ConvertToGpuIndexType(type)); auto new_index = std::make_shared<VecIndexImpl>(gpu_index, ConvertToGpuIndexType(type));
...@@ -159,7 +167,8 @@ VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) ...@@ -159,7 +167,8 @@ VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg)
return new_index; return new_index;
} }
VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) { VecIndexPtr
VecIndexImpl::CopyToCpu(const Config &cfg) {
// TODO(linxj): exception handle // TODO(linxj): exception handle
auto cpu_index = zilliz::knowhere::cloner::CopyGpuToCpu(index_, cfg); auto cpu_index = zilliz::knowhere::cloner::CopyGpuToCpu(index_, cfg);
auto new_index = std::make_shared<VecIndexImpl>(cpu_index, ConvertToCpuIndexType(type)); auto new_index = std::make_shared<VecIndexImpl>(cpu_index, ConvertToCpuIndexType(type));
...@@ -167,32 +176,37 @@ VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) { ...@@ -167,32 +176,37 @@ VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) {
return new_index; return new_index;
} }
VecIndexPtr VecIndexImpl::Clone() { VecIndexPtr
VecIndexImpl::Clone() {
// TODO(linxj): exception handle // TODO(linxj): exception handle
auto clone_index = std::make_shared<VecIndexImpl>(index_->Clone(), type); auto clone_index = std::make_shared<VecIndexImpl>(index_->Clone(), type);
clone_index->dim = dim; clone_index->dim = dim;
return clone_index; return clone_index;
} }
int64_t VecIndexImpl::GetDeviceId() { int64_t
if (auto device_idx = std::dynamic_pointer_cast<GPUIndex>(index_)){ VecIndexImpl::GetDeviceId() {
if (auto device_idx = std::dynamic_pointer_cast<GPUIndex>(index_)) {
return device_idx->GetGpuDevice(); return device_idx->GetGpuDevice();
} }
// else // else
return -1; // -1 == cpu return -1; // -1 == cpu
} }
float *BFIndex::GetRawVectors() { float *
BFIndex::GetRawVectors() {
auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_); auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_);
if (raw_index) { return raw_index->GetRawVectors(); } if (raw_index) { return raw_index->GetRawVectors(); }
return nullptr; return nullptr;
} }
int64_t *BFIndex::GetRawIds() { int64_t *
BFIndex::GetRawIds() {
return std::static_pointer_cast<IDMAP>(index_)->GetRawIds(); return std::static_pointer_cast<IDMAP>(index_)->GetRawIds();
} }
ErrorCode BFIndex::Build(const Config &cfg) { ErrorCode
BFIndex::Build(const Config &cfg) {
try { try {
dim = cfg["dim"].as<int>(); dim = cfg["dim"].as<int>();
std::static_pointer_cast<IDMAP>(index_)->Train(cfg); std::static_pointer_cast<IDMAP>(index_)->Train(cfg);
...@@ -209,12 +223,13 @@ ErrorCode BFIndex::Build(const Config &cfg) { ...@@ -209,12 +223,13 @@ ErrorCode BFIndex::Build(const Config &cfg) {
return KNOWHERE_SUCCESS; return KNOWHERE_SUCCESS;
} }
ErrorCode BFIndex::BuildAll(const long &nb, Status
const float *xb, BFIndex::BuildAll(const long &nb,
const long *ids, const float *xb,
const Config &cfg, const long *ids,
const long &nt, const Config &cfg,
const float *xt) { const long &nt,
const float *xt) {
try { try {
dim = cfg["dim"].as<int>(); dim = cfg["dim"].as<int>();
auto dataset = GenDatasetWithIds(nb, dim, xb, ids); auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
...@@ -223,24 +238,25 @@ ErrorCode BFIndex::BuildAll(const long &nb, ...@@ -223,24 +238,25 @@ ErrorCode BFIndex::BuildAll(const long &nb,
index_->Add(dataset, cfg); index_->Add(dataset, cfg);
} catch (KnowhereException &e) { } catch (KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_UNEXPECTED_ERROR; return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (jsoncons::json_exception &e) { } catch (jsoncons::json_exception &e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_INVALID_ARGUMENT; return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
} catch (std::exception &e) { } catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_ERROR; return Status(KNOWHERE_ERROR, e.what());
} }
return KNOWHERE_SUCCESS; return Status::OK();
} }
// TODO(linxj): add lock here. // TODO(linxj): add lock here.
ErrorCode IVFMixIndex::BuildAll(const long &nb, Status
const float *xb, IVFMixIndex::BuildAll(const long &nb,
const long *ids, const float *xb,
const Config &cfg, const long *ids,
const long &nt, const Config &cfg,
const float *xt) { const long &nt,
const float *xt) {
try { try {
dim = cfg["dim"].as<int>(); dim = cfg["dim"].as<int>();
auto dataset = GenDatasetWithIds(nb, dim, xb, ids); auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
...@@ -257,26 +273,27 @@ ErrorCode IVFMixIndex::BuildAll(const long &nb, ...@@ -257,26 +273,27 @@ ErrorCode IVFMixIndex::BuildAll(const long &nb,
type = ConvertToCpuIndexType(type); type = ConvertToCpuIndexType(type);
} else { } else {
WRAPPER_LOG_ERROR << "Build IVFMIXIndex Failed"; WRAPPER_LOG_ERROR << "Build IVFMIXIndex Failed";
return KNOWHERE_ERROR; return Status(KNOWHERE_ERROR, "Build IVFMIXIndex Failed");
} }
} catch (KnowhereException &e) { } catch (KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_UNEXPECTED_ERROR; return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (jsoncons::json_exception &e) { } catch (jsoncons::json_exception &e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_INVALID_ARGUMENT; return Status(KNOWHERE_INVALID_ARGUMENT, e.what());
} catch (std::exception &e) { } catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_ERROR; return Status(KNOWHERE_ERROR, e.what());
} }
return KNOWHERE_SUCCESS; return Status::OK();
} }
ErrorCode IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) { Status
IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) {
//index_ = std::make_shared<IVF>(); //index_ = std::make_shared<IVF>();
index_->Load(index_binary); index_->Load(index_binary);
dim = Dimension(); dim = Dimension();
return KNOWHERE_SUCCESS; return Status::OK();
} }
} }
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
#pragma once #pragma once
#include "knowhere/index/vector_index/VectorIndex.h" #include "knowhere/index/vector_index/VectorIndex.h"
#include "vec_index.h" #include "vec_index.h"
...@@ -31,27 +30,53 @@ class VecIndexImpl : public VecIndex { ...@@ -31,27 +30,53 @@ class VecIndexImpl : public VecIndex {
public: public:
explicit VecIndexImpl(std::shared_ptr<zilliz::knowhere::VectorIndex> index, const IndexType &type) explicit VecIndexImpl(std::shared_ptr<zilliz::knowhere::VectorIndex> index, const IndexType &type)
: index_(std::move(index)), type(type) {}; : index_(std::move(index)), type(type) {};
ErrorCode BuildAll(const long &nb,
const float *xb, Status
const long *ids, BuildAll(const long &nb,
const Config &cfg, const float *xb,
const long &nt, const long *ids,
const float *xt) override; const Config &cfg,
VecIndexPtr CopyToGpu(const int64_t &device_id, const Config &cfg) override; const long &nt,
VecIndexPtr CopyToCpu(const Config &cfg) override; const float *xt) override;
IndexType GetType() override;
int64_t Dimension() override; VecIndexPtr
int64_t Count() override; CopyToGpu(const int64_t &device_id, const Config &cfg) override;
ErrorCode Add(const long &nb, const float *xb, const long *ids, const Config &cfg) override;
zilliz::knowhere::BinarySet Serialize() override; VecIndexPtr
ErrorCode Load(const zilliz::knowhere::BinarySet &index_binary) override; CopyToCpu(const Config &cfg) override;
VecIndexPtr Clone() override;
int64_t GetDeviceId() override; IndexType
ErrorCode Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) override; GetType() override;
int64_t
Dimension() override;
int64_t
Count() override;
Status
Add(const long &nb, const float *xb, const long *ids, const Config &cfg) override;
zilliz::knowhere::BinarySet
Serialize() override;
Status
Load(const zilliz::knowhere::BinarySet &index_binary) override;
VecIndexPtr
Clone() override;
int64_t
GetDeviceId() override;
Status
Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) override;
protected: protected:
int64_t dim = 0; int64_t dim = 0;
IndexType type = IndexType::INVALID; IndexType type = IndexType::INVALID;
std::shared_ptr<zilliz::knowhere::VectorIndex> index_ = nullptr; std::shared_ptr<zilliz::knowhere::VectorIndex> index_ = nullptr;
}; };
...@@ -60,28 +85,39 @@ class IVFMixIndex : public VecIndexImpl { ...@@ -60,28 +85,39 @@ class IVFMixIndex : public VecIndexImpl {
explicit IVFMixIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index, const IndexType &type) explicit IVFMixIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index, const IndexType &type)
: VecIndexImpl(std::move(index), type) {}; : VecIndexImpl(std::move(index), type) {};
ErrorCode BuildAll(const long &nb, Status
const float *xb, BuildAll(const long &nb,
const long *ids, const float *xb,
const Config &cfg, const long *ids,
const long &nt, const Config &cfg,
const float *xt) override; const long &nt,
ErrorCode Load(const zilliz::knowhere::BinarySet &index_binary) override; const float *xt) override;
Status
Load(const zilliz::knowhere::BinarySet &index_binary) override;
}; };
class BFIndex : public VecIndexImpl { class BFIndex : public VecIndexImpl {
public: public:
explicit BFIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : VecIndexImpl(std::move(index), explicit BFIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : VecIndexImpl(std::move(index),
IndexType::FAISS_IDMAP) {}; IndexType::FAISS_IDMAP) {};
ErrorCode Build(const Config& cfg);
float *GetRawVectors(); ErrorCode
ErrorCode BuildAll(const long &nb, Build(const Config &cfg);
const float *xb,
const long *ids, float *
const Config &cfg, GetRawVectors();
const long &nt,
const float *xt) override; Status
int64_t *GetRawIds(); BuildAll(const long &nb,
const float *xb,
const long *ids,
const Config &cfg,
const long &nt,
const float *xt) override;
int64_t *
GetRawIds();
}; };
} }
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#include "knowhere/index/vector_index/IndexKDT.h" #include "knowhere/index/vector_index/IndexKDT.h"
#include "knowhere/index/vector_index/IndexNSG.h" #include "knowhere/index/vector_index/IndexNSG.h"
#include "knowhere/common/Exception.h" #include "knowhere/common/Exception.h"
#include "vec_index.h" #include "vec_index.h"
#include "vec_impl.h" #include "vec_impl.h"
#include "utils/Log.h" #include "utils/Log.h"
...@@ -39,23 +38,19 @@ namespace engine { ...@@ -39,23 +38,19 @@ namespace engine {
static constexpr float TYPICAL_COUNT = 1000000.0; static constexpr float TYPICAL_COUNT = 1000000.0;
struct FileIOWriter {
std::fstream fs;
std::string name;
FileIOWriter(const std::string &fname);
~FileIOWriter();
size_t operator()(void *ptr, size_t size);
};
struct FileIOReader { struct FileIOReader {
std::fstream fs; std::fstream fs;
std::string name; std::string name;
FileIOReader(const std::string &fname); FileIOReader(const std::string &fname);
~FileIOReader(); ~FileIOReader();
size_t operator()(void *ptr, size_t size);
size_t operator()(void *ptr, size_t size, size_t pos); size_t
operator()(void *ptr, size_t size);
size_t
operator()(void *ptr, size_t size, size_t pos);
}; };
FileIOReader::FileIOReader(const std::string &fname) { FileIOReader::FileIOReader(const std::string &fname) {
...@@ -67,14 +62,27 @@ FileIOReader::~FileIOReader() { ...@@ -67,14 +62,27 @@ FileIOReader::~FileIOReader() {
fs.close(); fs.close();
} }
size_t FileIOReader::operator()(void *ptr, size_t size) { size_t
FileIOReader::operator()(void *ptr, size_t size) {
fs.read(reinterpret_cast<char *>(ptr), size); fs.read(reinterpret_cast<char *>(ptr), size);
} }
size_t FileIOReader::operator()(void *ptr, size_t size, size_t pos) { size_t
FileIOReader::operator()(void *ptr, size_t size, size_t pos) {
return 0; return 0;
} }
struct FileIOWriter {
std::fstream fs;
std::string name;
FileIOWriter(const std::string &fname);
~FileIOWriter();
size_t operator()(void *ptr, size_t size);
};
FileIOWriter::FileIOWriter(const std::string &fname) { FileIOWriter::FileIOWriter(const std::string &fname) {
name = fname; name = fname;
fs = std::fstream(name, std::ios::out | std::ios::binary); fs = std::fstream(name, std::ios::out | std::ios::binary);
...@@ -84,12 +92,14 @@ FileIOWriter::~FileIOWriter() { ...@@ -84,12 +92,14 @@ FileIOWriter::~FileIOWriter() {
fs.close(); fs.close();
} }
size_t FileIOWriter::operator()(void *ptr, size_t size) { size_t
FileIOWriter::operator()(void *ptr, size_t size) {
fs.write(reinterpret_cast<char *>(ptr), size); fs.write(reinterpret_cast<char *>(ptr), size);
} }
VecIndexPtr GetVecIndexFactory(const IndexType &type, const Config &cfg) { VecIndexPtr
GetVecIndexFactory(const IndexType &type, const Config &cfg) {
std::shared_ptr<zilliz::knowhere::VectorIndex> index; std::shared_ptr<zilliz::knowhere::VectorIndex> index;
auto gpu_device = cfg.get_with_default("gpu_id", 0); auto gpu_device = cfg.get_with_default("gpu_id", 0);
switch (type) { switch (type) {
...@@ -145,13 +155,15 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type, const Config &cfg) { ...@@ -145,13 +155,15 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type, const Config &cfg) {
return std::make_shared<VecIndexImpl>(index, type); return std::make_shared<VecIndexImpl>(index, type);
} }
VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary) { VecIndexPtr
LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary) {
auto index = GetVecIndexFactory(index_type); auto index = GetVecIndexFactory(index_type);
index->Load(index_binary); index->Load(index_binary);
return index; return index;
} }
VecIndexPtr read_index(const std::string &location) { VecIndexPtr
read_index(const std::string &location) {
knowhere::BinarySet load_data_list; knowhere::BinarySet load_data_list;
FileIOReader reader(location); FileIOReader reader(location);
reader.fs.seekg(0, reader.fs.end); reader.fs.seekg(0, reader.fs.end);
...@@ -195,7 +207,8 @@ VecIndexPtr read_index(const std::string &location) { ...@@ -195,7 +207,8 @@ VecIndexPtr read_index(const std::string &location) {
return LoadVecIndex(current_type, load_data_list); return LoadVecIndex(current_type, load_data_list);
} }
ErrorCode write_index(VecIndexPtr index, const std::string &location) { Status
write_index(VecIndexPtr index, const std::string &location) {
try { try {
auto binaryset = index->Serialize(); auto binaryset = index->Serialize();
auto index_type = index->GetType(); auto index_type = index->GetType();
...@@ -215,28 +228,29 @@ ErrorCode write_index(VecIndexPtr index, const std::string &location) { ...@@ -215,28 +228,29 @@ ErrorCode write_index(VecIndexPtr index, const std::string &location) {
} }
} catch (knowhere::KnowhereException &e) { } catch (knowhere::KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_UNEXPECTED_ERROR; return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (std::exception &e) { } catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
std::string estring(e.what()); std::string estring(e.what());
if (estring.find("No space left on device") != estring.npos) { if (estring.find("No space left on device") != estring.npos) {
WRAPPER_LOG_ERROR << "No space left on the device"; WRAPPER_LOG_ERROR << "No space left on the device";
return KNOWHERE_NO_SPACE; return Status(KNOWHERE_NO_SPACE, "No space left on the device");
} else { } else {
return KNOWHERE_ERROR; return Status(KNOWHERE_ERROR, e.what());
} }
} }
return KNOWHERE_SUCCESS; return Status::OK();
} }
// TODO(linxj): redo here. // TODO(linxj): redo here.
void AutoGenParams(const IndexType &type, const long &size, zilliz::knowhere::Config &cfg) { void
AutoGenParams(const IndexType &type, const long &size, zilliz::knowhere::Config &cfg) {
auto nlist = cfg.get_with_default("nlist", 0); auto nlist = cfg.get_with_default("nlist", 0);
if (size <= TYPICAL_COUNT / 16384 + 1) { if (size <= TYPICAL_COUNT / 16384 + 1) {
//handle less row count, avoid nlist set to 0 //handle less row count, avoid nlist set to 0
cfg["nlist"] = 1; cfg["nlist"] = 1;
} else if (int(size / TYPICAL_COUNT) *nlist == 0) { } else if (int(size / TYPICAL_COUNT) * nlist == 0) {
//calculate a proper nlist if nlist not specified or size less than TYPICAL_COUNT //calculate a proper nlist if nlist not specified or size less than TYPICAL_COUNT
cfg["nlist"] = int(size / TYPICAL_COUNT * 16384); cfg["nlist"] = int(size / TYPICAL_COUNT * 16384);
} }
...@@ -270,7 +284,8 @@ void AutoGenParams(const IndexType &type, const long &size, zilliz::knowhere::Co ...@@ -270,7 +284,8 @@ void AutoGenParams(const IndexType &type, const long &size, zilliz::knowhere::Co
#define GPU_MAX_NRPOBE 1024 #define GPU_MAX_NRPOBE 1024
#endif #endif
void ParameterValidation(const IndexType &type, Config &cfg) { void
ParameterValidation(const IndexType &type, Config &cfg) {
switch (type) { switch (type) {
case IndexType::FAISS_IVFSQ8_GPU: case IndexType::FAISS_IVFSQ8_GPU:
case IndexType::FAISS_IVFFLAT_GPU: case IndexType::FAISS_IVFFLAT_GPU:
...@@ -291,7 +306,8 @@ void ParameterValidation(const IndexType &type, Config &cfg) { ...@@ -291,7 +306,8 @@ void ParameterValidation(const IndexType &type, Config &cfg) {
} }
} }
IndexType ConvertToCpuIndexType(const IndexType &type) { IndexType
ConvertToCpuIndexType(const IndexType &type) {
// TODO(linxj): add IDMAP // TODO(linxj): add IDMAP
switch (type) { switch (type) {
case IndexType::FAISS_IVFFLAT_GPU: case IndexType::FAISS_IVFFLAT_GPU:
...@@ -308,7 +324,8 @@ IndexType ConvertToCpuIndexType(const IndexType &type) { ...@@ -308,7 +324,8 @@ IndexType ConvertToCpuIndexType(const IndexType &type) {
} }
} }
IndexType ConvertToGpuIndexType(const IndexType &type) { IndexType
ConvertToGpuIndexType(const IndexType &type) {
switch (type) { switch (type) {
case IndexType::FAISS_IVFFLAT_MIX: case IndexType::FAISS_IVFFLAT_MIX:
case IndexType::FAISS_IVFFLAT_CPU: { case IndexType::FAISS_IVFFLAT_CPU: {
......
...@@ -21,8 +21,7 @@ ...@@ -21,8 +21,7 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include "utils/Error.h" #include "utils/Status.h"
#include "knowhere/common/Config.h" #include "knowhere/common/Config.h"
#include "knowhere/common/BinarySet.h" #include "knowhere/common/BinarySet.h"
...@@ -50,62 +49,84 @@ enum class IndexType { ...@@ -50,62 +49,84 @@ enum class IndexType {
}; };
class VecIndex; class VecIndex;
using VecIndexPtr = std::shared_ptr<VecIndex>; using VecIndexPtr = std::shared_ptr<VecIndex>;
class VecIndex { class VecIndex {
public: public:
virtual ErrorCode BuildAll(const long &nb, virtual Status
const float *xb, BuildAll(const long &nb,
const long *ids, const float *xb,
const Config &cfg, const long *ids,
const long &nt = 0, const Config &cfg,
const float *xt = nullptr) = 0; const long &nt = 0,
const float *xt = nullptr) = 0;
virtual ErrorCode Add(const long &nb,
const float *xb, virtual Status
const long *ids, Add(const long &nb,
const Config &cfg = Config()) = 0; const float *xb,
const long *ids,
virtual ErrorCode Search(const long &nq, const Config &cfg = Config()) = 0;
const float *xq,
float *dist, virtual Status
long *ids, Search(const long &nq,
const Config &cfg = Config()) = 0; const float *xq,
float *dist,
virtual VecIndexPtr CopyToGpu(const int64_t &device_id, long *ids,
const Config &cfg = Config()) = 0; const Config &cfg = Config()) = 0;
virtual VecIndexPtr CopyToCpu(const Config &cfg = Config()) = 0; virtual VecIndexPtr
CopyToGpu(const int64_t &device_id,
virtual VecIndexPtr Clone() = 0; const Config &cfg = Config()) = 0;
virtual int64_t GetDeviceId() = 0; virtual VecIndexPtr
CopyToCpu(const Config &cfg = Config()) = 0;
virtual IndexType GetType() = 0;
virtual VecIndexPtr
virtual int64_t Dimension() = 0; Clone() = 0;
virtual int64_t Count() = 0; virtual int64_t
GetDeviceId() = 0;
virtual zilliz::knowhere::BinarySet Serialize() = 0;
virtual IndexType
virtual ErrorCode Load(const zilliz::knowhere::BinarySet &index_binary) = 0; GetType() = 0;
virtual int64_t
Dimension() = 0;
virtual int64_t
Count() = 0;
virtual zilliz::knowhere::BinarySet
Serialize() = 0;
virtual Status
Load(const zilliz::knowhere::BinarySet &index_binary) = 0;
}; };
extern ErrorCode write_index(VecIndexPtr index, const std::string &location); extern Status
write_index(VecIndexPtr index, const std::string &location);
extern VecIndexPtr
read_index(const std::string &location);
extern VecIndexPtr read_index(const std::string &location); extern VecIndexPtr
GetVecIndexFactory(const IndexType &type, const Config &cfg = Config());
extern VecIndexPtr GetVecIndexFactory(const IndexType &type, const Config &cfg = Config()); extern VecIndexPtr
LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary);
extern VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary); extern void
AutoGenParams(const IndexType &type, const long &size, Config &cfg);
extern void AutoGenParams(const IndexType &type, const long &size, Config &cfg); extern void
ParameterValidation(const IndexType &type, Config &cfg);
extern void ParameterValidation(const IndexType &type, Config &cfg); extern IndexType
ConvertToCpuIndexType(const IndexType &type);
extern IndexType ConvertToCpuIndexType(const IndexType &type); extern IndexType
extern IndexType ConvertToGpuIndexType(const IndexType &type); ConvertToGpuIndexType(const IndexType &type);
} }
} }
......
...@@ -17,8 +17,7 @@ ...@@ -17,8 +17,7 @@
#include "scheduler/Scheduler.h" #include "scheduler/Scheduler.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <src/scheduler/tasklabel/DefaultLabel.h> #include "src/scheduler/tasklabel/DefaultLabel.h"
#include <src/server/ServerConfig.h>
#include "cache/DataObj.h" #include "cache/DataObj.h"
#include "cache/GpuCacheMgr.h" #include "cache/GpuCacheMgr.h"
#include "scheduler/task/TestTask.h" #include "scheduler/task/TestTask.h"
...@@ -35,13 +34,12 @@ namespace engine { ...@@ -35,13 +34,12 @@ namespace engine {
class MockVecIndex : public engine::VecIndex { class MockVecIndex : public engine::VecIndex {
public: public:
virtual ErrorCode BuildAll(const long &nb, virtual Status BuildAll(const long &nb,
const float *xb, const float *xb,
const long *ids, const long *ids,
const engine::Config &cfg, const engine::Config &cfg,
const long &nt = 0, const long &nt = 0,
const float *xt = nullptr) { const float *xt = nullptr) {
} }
engine::VecIndexPtr Clone() override { engine::VecIndexPtr Clone() override {
...@@ -56,14 +54,14 @@ public: ...@@ -56,14 +54,14 @@ public:
return engine::IndexType::INVALID; return engine::IndexType::INVALID;
} }
virtual ErrorCode Add(const long &nb, virtual Status Add(const long &nb,
const float *xb, const float *xb,
const long *ids, const long *ids,
const engine::Config &cfg = engine::Config()) { const engine::Config &cfg = engine::Config()) {
} }
virtual ErrorCode Search(const long &nq, virtual Status Search(const long &nq,
const float *xq, const float *xq,
float *dist, float *dist,
long *ids, long *ids,
...@@ -92,7 +90,7 @@ public: ...@@ -92,7 +90,7 @@ public:
return binset; return binset;
} }
virtual ErrorCode Load(const zilliz::knowhere::BinarySet &index_binary) { virtual Status Load(const zilliz::knowhere::BinarySet &index_binary) {
} }
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
#include "cache/CpuCacheMgr.h" #include "cache/CpuCacheMgr.h"
#include "cache/GpuCacheMgr.h" #include "cache/GpuCacheMgr.h"
#include "server/ServerConfig.h" #include "server/ServerConfig.h"
#include "utils/Error.h" #include "utils/Error.h"
#include "src/wrapper/vec_index.h" #include "src/wrapper/vec_index.h"
...@@ -48,13 +47,13 @@ public: ...@@ -48,13 +47,13 @@ public:
} }
virtual ErrorCode BuildAll(const long &nb, virtual Status BuildAll(const long &nb,
const float *xb, const float *xb,
const long *ids, const long *ids,
const engine::Config &cfg, const engine::Config &cfg,
const long &nt = 0, const long &nt = 0,
const float *xt = nullptr) { const float *xt = nullptr) {
return 0; return Status();
} }
engine::VecIndexPtr Clone() override { engine::VecIndexPtr Clone() override {
...@@ -69,19 +68,19 @@ public: ...@@ -69,19 +68,19 @@ public:
return engine::IndexType::INVALID; return engine::IndexType::INVALID;
} }
virtual ErrorCode Add(const long &nb, virtual Status Add(const long &nb,
const float *xb, const float *xb,
const long *ids, const long *ids,
const engine::Config &cfg = engine::Config()) { const engine::Config &cfg = engine::Config()) {
return 0; return Status();
} }
virtual ErrorCode Search(const long &nq, virtual Status Search(const long &nq,
const float *xq, const float *xq,
float *dist, float *dist,
long *ids, long *ids,
const engine::Config &cfg = engine::Config()) { const engine::Config &cfg = engine::Config()) {
return 0; return Status();
} }
engine::VecIndexPtr CopyToGpu(const int64_t &device_id, engine::VecIndexPtr CopyToGpu(const int64_t &device_id,
...@@ -106,8 +105,8 @@ public: ...@@ -106,8 +105,8 @@ public:
return binset; return binset;
} }
virtual ErrorCode Load(const zilliz::knowhere::BinarySet &index_binary) { virtual Status Load(const zilliz::knowhere::BinarySet &index_binary) {
return 0; return Status();
} }
public: public:
......
...@@ -28,6 +28,7 @@ set(wrapper_files ...@@ -28,6 +28,7 @@ set(wrapper_files
set(util_files set(util_files
utils.cpp utils.cpp
${MILVUS_ENGINE_SRC}/utils/easylogging++.cc ${MILVUS_ENGINE_SRC}/utils/easylogging++.cc
${MILVUS_ENGINE_SRC}/utils/Status.cpp
) )
set(knowhere_libs set(knowhere_libs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册