提交 c6a3b800 编写于 作者: P peng.xu

Merge branch 'branch-0.5.0' into 'branch-0.5.0'

format wrapper code

See merge request megasearch/milvus!640

Former-commit-id: ceceba71c60d0e5f0b43ea12bbfd43b046618eb7
...@@ -6,5 +6,4 @@ ...@@ -6,5 +6,4 @@
*easylogging++* *easylogging++*
*SqliteMetaImpl.cpp *SqliteMetaImpl.cpp
*src/grpc* *src/grpc*
*src/core* *src/core*
*src/wrapper* \ No newline at end of file
\ No newline at end of file
...@@ -15,13 +15,14 @@ ...@@ -15,13 +15,14 @@
// specific language governing permissions and limitations // specific language governing permissions and limitations
// under the License. // under the License.
#include "wrapper/ConfAdapter.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#include "utils/Log.h"
#include <cmath> #include <cmath>
#include "ConfAdapter.h" #include <memory>
#include "src/utils/Log.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
// TODO: add conf checker // TODO(lxj): add conf checker
namespace zilliz { namespace zilliz {
namespace milvus { namespace milvus {
...@@ -42,7 +43,7 @@ ConfAdapter::MatchBase(knowhere::Config conf) { ...@@ -42,7 +43,7 @@ ConfAdapter::MatchBase(knowhere::Config conf) {
} }
knowhere::Config knowhere::Config
ConfAdapter::Match(const TempMetaConf &metaconf) { ConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::Cfg>(); auto conf = std::make_shared<knowhere::Cfg>();
conf->d = metaconf.dim; conf->d = metaconf.dim;
conf->metric_type = metaconf.metric_type; conf->metric_type = metaconf.metric_type;
...@@ -52,14 +53,14 @@ ConfAdapter::Match(const TempMetaConf &metaconf) { ...@@ -52,14 +53,14 @@ ConfAdapter::Match(const TempMetaConf &metaconf) {
} }
knowhere::Config knowhere::Config
ConfAdapter::MatchSearch(const TempMetaConf &metaconf, const IndexType &type) { ConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
auto conf = std::make_shared<knowhere::Cfg>(); auto conf = std::make_shared<knowhere::Cfg>();
conf->k = metaconf.k; conf->k = metaconf.k;
return conf; return conf;
} }
knowhere::Config knowhere::Config
IVFConfAdapter::Match(const TempMetaConf &metaconf) { IVFConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::IVFCfg>(); auto conf = std::make_shared<knowhere::IVFCfg>();
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist); conf->nlist = MatchNlist(metaconf.size, metaconf.nlist);
conf->d = metaconf.dim; conf->d = metaconf.dim;
...@@ -72,7 +73,7 @@ IVFConfAdapter::Match(const TempMetaConf &metaconf) { ...@@ -72,7 +73,7 @@ IVFConfAdapter::Match(const TempMetaConf &metaconf) {
static constexpr float TYPICAL_COUNT = 1000000.0; static constexpr float TYPICAL_COUNT = 1000000.0;
int64_t int64_t
IVFConfAdapter::MatchNlist(const int64_t &size, const int64_t &nlist) { IVFConfAdapter::MatchNlist(const int64_t& size, const int64_t& nlist) {
if (size <= TYPICAL_COUNT / 16384 + 1) { if (size <= TYPICAL_COUNT / 16384 + 1) {
// handle less row count, avoid nlist set to 0 // handle less row count, avoid nlist set to 0
return 1; return 1;
...@@ -84,7 +85,7 @@ IVFConfAdapter::MatchNlist(const int64_t &size, const int64_t &nlist) { ...@@ -84,7 +85,7 @@ IVFConfAdapter::MatchNlist(const int64_t &size, const int64_t &nlist) {
} }
knowhere::Config knowhere::Config
IVFConfAdapter::MatchSearch(const TempMetaConf &metaconf, const IndexType &type) { IVFConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
auto conf = std::make_shared<knowhere::IVFCfg>(); auto conf = std::make_shared<knowhere::IVFCfg>();
conf->k = metaconf.k; conf->k = metaconf.k;
conf->nprobe = metaconf.nprobe; conf->nprobe = metaconf.nprobe;
...@@ -95,17 +96,16 @@ IVFConfAdapter::MatchSearch(const TempMetaConf &metaconf, const IndexType &type) ...@@ -95,17 +96,16 @@ IVFConfAdapter::MatchSearch(const TempMetaConf &metaconf, const IndexType &type)
case IndexType::FAISS_IVFPQ_GPU: case IndexType::FAISS_IVFPQ_GPU:
if (conf->nprobe > GPU_MAX_NRPOBE) { if (conf->nprobe > GPU_MAX_NRPOBE) {
WRAPPER_LOG_WARNING << "When search with GPU, nprobe shoud be no more than " << GPU_MAX_NRPOBE WRAPPER_LOG_WARNING << "When search with GPU, nprobe shoud be no more than " << GPU_MAX_NRPOBE
<< ", but you passed " << conf->nprobe << ", but you passed " << conf->nprobe << ". Search with " << GPU_MAX_NRPOBE
<< ". Search with " << GPU_MAX_NRPOBE << " instead"; << " instead";
conf->nprobe = GPU_MAX_NRPOBE; conf->nprobe = GPU_MAX_NRPOBE;
} }
} }
return conf; return conf;
} }
knowhere::Config knowhere::Config
IVFSQConfAdapter::Match(const TempMetaConf &metaconf) { IVFSQConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::IVFSQCfg>(); auto conf = std::make_shared<knowhere::IVFSQCfg>();
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist); conf->nlist = MatchNlist(metaconf.size, metaconf.nlist);
conf->d = metaconf.dim; conf->d = metaconf.dim;
...@@ -117,7 +117,7 @@ IVFSQConfAdapter::Match(const TempMetaConf &metaconf) { ...@@ -117,7 +117,7 @@ IVFSQConfAdapter::Match(const TempMetaConf &metaconf) {
} }
knowhere::Config knowhere::Config
IVFPQConfAdapter::Match(const TempMetaConf &metaconf) { IVFPQConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::IVFPQCfg>(); auto conf = std::make_shared<knowhere::IVFPQCfg>();
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist); conf->nlist = MatchNlist(metaconf.size, metaconf.nlist);
conf->d = metaconf.dim; conf->d = metaconf.dim;
...@@ -130,7 +130,7 @@ IVFPQConfAdapter::Match(const TempMetaConf &metaconf) { ...@@ -130,7 +130,7 @@ IVFPQConfAdapter::Match(const TempMetaConf &metaconf) {
} }
knowhere::Config knowhere::Config
NSGConfAdapter::Match(const TempMetaConf &metaconf) { NSGConfAdapter::Match(const TempMetaConf& metaconf) {
auto conf = std::make_shared<knowhere::NSGCfg>(); auto conf = std::make_shared<knowhere::NSGCfg>();
conf->nlist = MatchNlist(metaconf.size, metaconf.nlist); conf->nlist = MatchNlist(metaconf.size, metaconf.nlist);
conf->d = metaconf.dim; conf->d = metaconf.dim;
...@@ -146,20 +146,20 @@ NSGConfAdapter::Match(const TempMetaConf &metaconf) { ...@@ -146,20 +146,20 @@ NSGConfAdapter::Match(const TempMetaConf &metaconf) {
conf->candidate_pool_size = 200 + 100 * scale_factor; conf->candidate_pool_size = 200 + 100 * scale_factor;
MatchBase(conf); MatchBase(conf);
// WRAPPER_LOG_DEBUG << "nlist: " << conf->nlist // WRAPPER_LOG_DEBUG << "nlist: " << conf->nlist
// << ", gpu_id: " << conf->gpu_id << ", d: " << conf->d // << ", gpu_id: " << conf->gpu_id << ", d: " << conf->d
// << ", nprobe: " << conf->nprobe << ", knng: " << conf->knng; // << ", nprobe: " << conf->nprobe << ", knng: " << conf->knng;
return conf; return conf;
} }
knowhere::Config knowhere::Config
NSGConfAdapter::MatchSearch(const TempMetaConf &metaconf, const IndexType &type) { NSGConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) {
auto conf = std::make_shared<knowhere::NSGCfg>(); auto conf = std::make_shared<knowhere::NSGCfg>();
conf->k = metaconf.k; conf->k = metaconf.k;
conf->search_length = metaconf.search_length; conf->search_length = metaconf.search_length;
return conf; return conf;
} }
} } // namespace engine
} } // namespace milvus
} } // namespace zilliz
...@@ -15,12 +15,12 @@ ...@@ -15,12 +15,12 @@
// specific language governing permissions and limitations // specific language governing permissions and limitations
// under the License. // under the License.
#pragma once #pragma once
#include "knowhere/common/Config.h"
#include "VecIndex.h" #include "VecIndex.h"
#include "knowhere/common/Config.h"
#include <memory>
namespace zilliz { namespace zilliz {
namespace milvus { namespace milvus {
...@@ -42,16 +42,17 @@ struct TempMetaConf { ...@@ -42,16 +42,17 @@ struct TempMetaConf {
class ConfAdapter { class ConfAdapter {
public: public:
virtual knowhere::Config virtual knowhere::Config
Match(const TempMetaConf &metaconf); Match(const TempMetaConf& metaconf);
virtual knowhere::Config virtual knowhere::Config
MatchSearch(const TempMetaConf &metaconf, const IndexType &type); MatchSearch(const TempMetaConf& metaconf, const IndexType& type);
// virtual void // virtual void
// Dump(){} // Dump(){}
protected: protected:
static void MatchBase(knowhere::Config conf); static void
MatchBase(knowhere::Config conf);
}; };
using ConfAdapterPtr = std::shared_ptr<ConfAdapter>; using ConfAdapterPtr = std::shared_ptr<ConfAdapter>;
...@@ -59,36 +60,37 @@ using ConfAdapterPtr = std::shared_ptr<ConfAdapter>; ...@@ -59,36 +60,37 @@ using ConfAdapterPtr = std::shared_ptr<ConfAdapter>;
class IVFConfAdapter : public ConfAdapter { class IVFConfAdapter : public ConfAdapter {
public: public:
knowhere::Config knowhere::Config
Match(const TempMetaConf &metaconf) override; Match(const TempMetaConf& metaconf) override;
knowhere::Config knowhere::Config
MatchSearch(const TempMetaConf &metaconf, const IndexType &type) override; MatchSearch(const TempMetaConf& metaconf, const IndexType& type) override;
protected: protected:
static int64_t MatchNlist(const int64_t &size, const int64_t &nlist); static int64_t
MatchNlist(const int64_t& size, const int64_t& nlist);
}; };
class IVFSQConfAdapter : public IVFConfAdapter { class IVFSQConfAdapter : public IVFConfAdapter {
public: public:
knowhere::Config knowhere::Config
Match(const TempMetaConf &metaconf) override; Match(const TempMetaConf& metaconf) override;
}; };
class IVFPQConfAdapter : public IVFConfAdapter { class IVFPQConfAdapter : public IVFConfAdapter {
public: public:
knowhere::Config knowhere::Config
Match(const TempMetaConf &metaconf) override; Match(const TempMetaConf& metaconf) override;
}; };
class NSGConfAdapter : public IVFConfAdapter { class NSGConfAdapter : public IVFConfAdapter {
public: public:
knowhere::Config knowhere::Config
Match(const TempMetaConf &metaconf) override; Match(const TempMetaConf& metaconf) override;
knowhere::Config knowhere::Config
MatchSearch(const TempMetaConf &metaconf, const IndexType &type) final; MatchSearch(const TempMetaConf& metaconf, const IndexType& type) final;
}; };
} } // namespace engine
} } // namespace milvus
} } // namespace zilliz
...@@ -15,18 +15,17 @@ ...@@ -15,18 +15,17 @@
// specific language governing permissions and limitations // specific language governing permissions and limitations
// under the License. // under the License.
#include "wrapper/ConfAdapterMgr.h"
#include "src/utils/Exception.h" #include "utils/Exception.h"
#include "ConfAdapterMgr.h"
namespace zilliz { namespace zilliz {
namespace milvus { namespace milvus {
namespace engine { namespace engine {
ConfAdapterPtr ConfAdapterPtr
AdapterMgr::GetAdapter(const IndexType &indexType) { AdapterMgr::GetAdapter(const IndexType& indexType) {
if (!init_) RegisterAdapter(); if (!init_)
RegisterAdapter();
auto it = table_.find(indexType); auto it = table_.find(indexType);
if (it != table_.end()) { if (it != table_.end()) {
...@@ -36,8 +35,8 @@ AdapterMgr::GetAdapter(const IndexType &indexType) { ...@@ -36,8 +35,8 @@ AdapterMgr::GetAdapter(const IndexType &indexType) {
} }
} }
#define REGISTER_CONF_ADAPTER(T, KEY, NAME) static AdapterMgr::register_t<T> reg_##NAME##_(KEY)
#define REGISTER_CONF_ADAPTER(T, KEY, NAME) static AdapterMgr::register_t<T>reg_##NAME##_(KEY)
void void
AdapterMgr::RegisterAdapter() { AdapterMgr::RegisterAdapter() {
init_ = true; init_ = true;
...@@ -58,7 +57,6 @@ AdapterMgr::RegisterAdapter() { ...@@ -58,7 +57,6 @@ AdapterMgr::RegisterAdapter() {
REGISTER_CONF_ADAPTER(NSGConfAdapter, IndexType::NSG_MIX, nsg_mix); REGISTER_CONF_ADAPTER(NSGConfAdapter, IndexType::NSG_MIX, nsg_mix);
} }
} // engine } // namespace engine
} // milvus } // namespace milvus
} // zilliz } // namespace zilliz
...@@ -15,12 +15,13 @@ ...@@ -15,12 +15,13 @@
// specific language governing permissions and limitations // specific language governing permissions and limitations
// under the License. // under the License.
#pragma once #pragma once
#include "VecIndex.h"
#include "ConfAdapter.h" #include "ConfAdapter.h"
#include "VecIndex.h"
#include <map>
#include <memory>
namespace zilliz { namespace zilliz {
namespace milvus { namespace milvus {
...@@ -28,23 +29,21 @@ namespace engine { ...@@ -28,23 +29,21 @@ namespace engine {
class AdapterMgr { class AdapterMgr {
public: public:
template<typename T> template <typename T>
struct register_t { struct register_t {
explicit register_t(const IndexType &key) { explicit register_t(const IndexType& key) {
AdapterMgr::GetInstance().table_.emplace(key, [] { AdapterMgr::GetInstance().table_.emplace(key, [] { return std::make_shared<T>(); });
return std::make_shared<T>();
});
} }
}; };
static AdapterMgr & static AdapterMgr&
GetInstance() { GetInstance() {
static AdapterMgr instance; static AdapterMgr instance;
return instance; return instance;
} }
ConfAdapterPtr ConfAdapterPtr
GetAdapter(const IndexType &indexType); GetAdapter(const IndexType& indexType);
void void
RegisterAdapter(); RegisterAdapter();
...@@ -54,10 +53,6 @@ class AdapterMgr { ...@@ -54,10 +53,6 @@ class AdapterMgr {
std::map<IndexType, std::function<ConfAdapterPtr()> > table_; std::map<IndexType, std::function<ConfAdapterPtr()> > table_;
}; };
} // namespace engine
} // engine } // namespace milvus
} // milvus } // namespace zilliz
} // zilliz
...@@ -15,39 +15,38 @@ ...@@ -15,39 +15,38 @@
// specific language governing permissions and limitations // specific language governing permissions and limitations
// under the License. // under the License.
#include "wrapper/DataTransfer.h" #include "wrapper/DataTransfer.h"
#include <vector>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector>
namespace zilliz { namespace zilliz {
namespace milvus { namespace milvus {
namespace engine { namespace engine {
knowhere::DatasetPtr knowhere::DatasetPtr
GenDatasetWithIds(const int64_t &nb, const int64_t &dim, const float *xb, const int64_t *ids) { GenDatasetWithIds(const int64_t& nb, const int64_t& dim, const float* xb, const int64_t* ids) {
std::vector<int64_t> shape{nb, dim}; std::vector<int64_t> shape{nb, dim};
auto tensor = knowhere::ConstructFloatTensor((uint8_t *) xb, nb * dim * sizeof(float), shape); auto tensor = knowhere::ConstructFloatTensor((uint8_t*)xb, nb * dim * sizeof(float), shape);
std::vector<knowhere::TensorPtr> tensors{tensor}; std::vector<knowhere::TensorPtr> tensors{tensor};
std::vector<knowhere::FieldPtr> tensor_fields{knowhere::ConstructFloatField("data")}; std::vector<knowhere::FieldPtr> tensor_fields{knowhere::ConstructFloatField("data")};
auto tensor_schema = std::make_shared<knowhere::Schema>(tensor_fields); auto tensor_schema = std::make_shared<knowhere::Schema>(tensor_fields);
auto id_array = knowhere::ConstructInt64Array((uint8_t *) ids, nb * sizeof(int64_t)); auto id_array = knowhere::ConstructInt64Array((uint8_t*)ids, nb * sizeof(int64_t));
std::vector<knowhere::ArrayPtr> arrays{id_array}; std::vector<knowhere::ArrayPtr> arrays{id_array};
std::vector<knowhere::FieldPtr> array_fields{knowhere::ConstructInt64Field("id")}; std::vector<knowhere::FieldPtr> array_fields{knowhere::ConstructInt64Field("id")};
auto array_schema = std::make_shared<knowhere::Schema>(tensor_fields); auto array_schema = std::make_shared<knowhere::Schema>(tensor_fields);
auto dataset = std::make_shared<knowhere::Dataset>(std::move(arrays), array_schema, auto dataset =
std::move(tensors), tensor_schema); std::make_shared<knowhere::Dataset>(std::move(arrays), array_schema, std::move(tensors), tensor_schema);
return dataset; return dataset;
} }
knowhere::DatasetPtr knowhere::DatasetPtr
GenDataset(const int64_t &nb, const int64_t &dim, const float *xb) { GenDataset(const int64_t& nb, const int64_t& dim, const float* xb) {
std::vector<int64_t> shape{nb, dim}; std::vector<int64_t> shape{nb, dim};
auto tensor = knowhere::ConstructFloatTensor((uint8_t *) xb, nb * dim * sizeof(float), shape); auto tensor = knowhere::ConstructFloatTensor((uint8_t*)xb, nb * dim * sizeof(float), shape);
std::vector<knowhere::TensorPtr> tensors{tensor}; std::vector<knowhere::TensorPtr> tensors{tensor};
std::vector<knowhere::FieldPtr> tensor_fields{knowhere::ConstructFloatField("data")}; std::vector<knowhere::FieldPtr> tensor_fields{knowhere::ConstructFloatField("data")};
auto tensor_schema = std::make_shared<knowhere::Schema>(tensor_fields); auto tensor_schema = std::make_shared<knowhere::Schema>(tensor_fields);
...@@ -56,6 +55,6 @@ GenDataset(const int64_t &nb, const int64_t &dim, const float *xb) { ...@@ -56,6 +55,6 @@ GenDataset(const int64_t &nb, const int64_t &dim, const float *xb) {
return dataset; return dataset;
} }
} // namespace engine } // namespace engine
} // namespace milvus } // namespace milvus
} // namespace zilliz } // namespace zilliz
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
// specific language governing permissions and limitations // specific language governing permissions and limitations
// under the License. // under the License.
#pragma once #pragma once
#include "knowhere/adapter/Structure.h" #include "knowhere/adapter/Structure.h"
...@@ -25,11 +24,11 @@ namespace milvus { ...@@ -25,11 +24,11 @@ namespace milvus {
namespace engine { namespace engine {
extern zilliz::knowhere::DatasetPtr extern zilliz::knowhere::DatasetPtr
GenDatasetWithIds(const int64_t &nb, const int64_t &dim, const float *xb, const int64_t *ids); GenDatasetWithIds(const int64_t& nb, const int64_t& dim, const float* xb, const int64_t* ids);
extern zilliz::knowhere::DatasetPtr extern zilliz::knowhere::DatasetPtr
GenDataset(const int64_t &nb, const int64_t &dim, const float *xb); GenDataset(const int64_t& nb, const int64_t& dim, const float* xb);
} // namespace engine } // namespace engine
} // namespace milvus } // namespace milvus
} // namespace zilliz } // namespace zilliz
...@@ -15,16 +15,15 @@ ...@@ -15,16 +15,15 @@
// specific language governing permissions and limitations // specific language governing permissions and limitations
// under the License. // under the License.
#include "wrapper/KnowhereResource.h" #include "wrapper/KnowhereResource.h"
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" #include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
#include "server/Config.h" #include "server/Config.h"
#include <map> #include <map>
#include <set> #include <set>
#include <vector>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector>
namespace zilliz { namespace zilliz {
namespace milvus { namespace milvus {
...@@ -43,22 +42,24 @@ KnowhereResource::Initialize() { ...@@ -43,22 +42,24 @@ KnowhereResource::Initialize() {
GpuResourcesArray gpu_resources; GpuResourcesArray gpu_resources;
Status s; Status s;
//get build index gpu resource // get build index gpu resource
server::Config &config = server::Config::GetInstance(); server::Config& config = server::Config::GetInstance();
int32_t build_index_gpu; int32_t build_index_gpu;
s = config.GetDBConfigBuildIndexGPU(build_index_gpu); s = config.GetDBConfigBuildIndexGPU(build_index_gpu);
if (!s.ok()) return s; if (!s.ok())
return s;
gpu_resources.insert(std::make_pair(build_index_gpu, GpuResourceSetting())); gpu_resources.insert(std::make_pair(build_index_gpu, GpuResourceSetting()));
//get search gpu resource // get search gpu resource
std::vector<std::string> pool; std::vector<std::string> pool;
s = config.GetResourceConfigPool(pool); s = config.GetResourceConfigPool(pool);
if (!s.ok()) return s; if (!s.ok())
return s;
std::set<uint64_t> gpu_ids; std::set<uint64_t> gpu_ids;
for (auto &resource : pool) { for (auto& resource : pool) {
if (resource.length() < 4 || resource.substr(0, 3) != "gpu") { if (resource.length() < 4 || resource.substr(0, 3) != "gpu") {
// invalid // invalid
continue; continue;
...@@ -67,12 +68,10 @@ KnowhereResource::Initialize() { ...@@ -67,12 +68,10 @@ KnowhereResource::Initialize() {
gpu_resources.insert(std::make_pair(gpu_id, GpuResourceSetting())); gpu_resources.insert(std::make_pair(gpu_id, GpuResourceSetting()));
} }
//init gpu resources // init gpu resources
for (auto iter = gpu_resources.begin(); iter != gpu_resources.end(); ++iter) { for (auto iter = gpu_resources.begin(); iter != gpu_resources.end(); ++iter) {
knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(iter->first, knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(iter->first, iter->second.pinned_memory,
iter->second.pinned_memory, iter->second.temp_memory, iter->second.resource_num);
iter->second.temp_memory,
iter->second.resource_num);
} }
return Status::OK(); return Status::OK();
...@@ -80,10 +79,10 @@ KnowhereResource::Initialize() { ...@@ -80,10 +79,10 @@ KnowhereResource::Initialize() {
Status Status
KnowhereResource::Finalize() { KnowhereResource::Finalize() {
knowhere::FaissGpuResourceMgr::GetInstance().Free(); // free gpu resource. knowhere::FaissGpuResourceMgr::GetInstance().Free(); // free gpu resource.
return Status::OK(); return Status::OK();
} }
} // namespace engine } // namespace engine
} // namespace milvus } // namespace milvus
} // namespace zilliz } // namespace zilliz
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
// specific language governing permissions and limitations // specific language governing permissions and limitations
// under the License. // under the License.
#pragma once #pragma once
#include "utils/Status.h" #include "utils/Status.h"
...@@ -33,6 +32,6 @@ class KnowhereResource { ...@@ -33,6 +32,6 @@ class KnowhereResource {
Finalize(); Finalize();
}; };
} // namespace engine } // namespace engine
} // namespace milvus } // namespace milvus
} // namespace zilliz } // namespace zilliz
...@@ -15,14 +15,13 @@ ...@@ -15,14 +15,13 @@
// specific language governing permissions and limitations // specific language governing permissions and limitations
// under the License. // under the License.
#include "wrapper/VecImpl.h" #include "wrapper/VecImpl.h"
#include "utils/Log.h" #include "DataTransfer.h"
#include "knowhere/index/vector_index/IndexIDMAP.h"
#include "knowhere/index/vector_index/IndexGPUIVF.h"
#include "knowhere/common/Exception.h" #include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexGPUIVF.h"
#include "knowhere/index/vector_index/IndexIDMAP.h"
#include "knowhere/index/vector_index/helpers/Cloner.h" #include "knowhere/index/vector_index/helpers/Cloner.h"
#include "DataTransfer.h" #include "utils/Log.h"
/* /*
* no parameter check in this layer. * no parameter check in this layer.
...@@ -34,12 +33,8 @@ namespace milvus { ...@@ -34,12 +33,8 @@ namespace milvus {
namespace engine { namespace engine {
Status Status
VecIndexImpl::BuildAll(const int64_t &nb, VecIndexImpl::BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
const float *xb, const float* xt) {
const int64_t *ids,
const Config &cfg,
const int64_t &nt,
const float *xt) {
try { try {
dim = cfg->d; dim = cfg->d;
auto dataset = GenDatasetWithIds(nb, dim, xb, ids); auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
...@@ -49,10 +44,10 @@ VecIndexImpl::BuildAll(const int64_t &nb, ...@@ -49,10 +44,10 @@ VecIndexImpl::BuildAll(const int64_t &nb,
auto model = index_->Train(dataset, cfg); auto model = index_->Train(dataset, cfg);
index_->set_index_model(model); index_->set_index_model(model);
index_->Add(dataset, cfg); index_->Add(dataset, cfg);
} catch (knowhere::KnowhereException &e) { } catch (knowhere::KnowhereException& e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what()); return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (std::exception &e) { } catch (std::exception& e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return Status(KNOWHERE_ERROR, e.what()); return Status(KNOWHERE_ERROR, e.what());
} }
...@@ -60,15 +55,15 @@ VecIndexImpl::BuildAll(const int64_t &nb, ...@@ -60,15 +55,15 @@ VecIndexImpl::BuildAll(const int64_t &nb,
} }
Status Status
VecIndexImpl::Add(const int64_t &nb, const float *xb, const int64_t *ids, const Config &cfg) { VecIndexImpl::Add(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg) {
try { try {
auto dataset = GenDatasetWithIds(nb, dim, xb, ids); auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
index_->Add(dataset, cfg); index_->Add(dataset, cfg);
} catch (knowhere::KnowhereException &e) { } catch (knowhere::KnowhereException& e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what()); return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (std::exception &e) { } catch (std::exception& e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return Status(KNOWHERE_ERROR, e.what()); return Status(KNOWHERE_ERROR, e.what());
} }
...@@ -76,7 +71,7 @@ VecIndexImpl::Add(const int64_t &nb, const float *xb, const int64_t *ids, const ...@@ -76,7 +71,7 @@ VecIndexImpl::Add(const int64_t &nb, const float *xb, const int64_t *ids, const
} }
Status Status
VecIndexImpl::Search(const int64_t &nq, const float *xq, float *dist, int64_t *ids, const Config &cfg) { VecIndexImpl::Search(const int64_t& nq, const float* xq, float* dist, int64_t* ids, const Config& cfg) {
try { try {
auto k = cfg->k; auto k = cfg->k;
auto dataset = GenDataset(nq, dim, xq); auto dataset = GenDataset(nq, dim, xq);
...@@ -110,10 +105,10 @@ VecIndexImpl::Search(const int64_t &nq, const float *xq, float *dist, int64_t *i ...@@ -110,10 +105,10 @@ VecIndexImpl::Search(const int64_t &nq, const float *xq, float *dist, int64_t *i
// TODO(linxj): avoid copy here. // TODO(linxj): avoid copy here.
memcpy(ids, p_ids, sizeof(int64_t) * nq * k); memcpy(ids, p_ids, sizeof(int64_t) * nq * k);
memcpy(dist, p_dist, sizeof(float) * nq * k); memcpy(dist, p_dist, sizeof(float) * nq * k);
} catch (knowhere::KnowhereException &e) { } catch (knowhere::KnowhereException& e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what()); return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (std::exception &e) { } catch (std::exception& e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return Status(KNOWHERE_ERROR, e.what()); return Status(KNOWHERE_ERROR, e.what());
} }
...@@ -127,7 +122,7 @@ VecIndexImpl::Serialize() { ...@@ -127,7 +122,7 @@ VecIndexImpl::Serialize() {
} }
Status Status
VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) { VecIndexImpl::Load(const zilliz::knowhere::BinarySet& index_binary) {
index_->Load(index_binary); index_->Load(index_binary);
dim = Dimension(); dim = Dimension();
return Status::OK(); return Status::OK();
...@@ -149,7 +144,7 @@ VecIndexImpl::GetType() { ...@@ -149,7 +144,7 @@ VecIndexImpl::GetType() {
} }
VecIndexPtr VecIndexPtr
VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) { VecIndexImpl::CopyToGpu(const int64_t& device_id, const Config& cfg) {
// TODO(linxj): exception handle // TODO(linxj): exception handle
auto gpu_index = zilliz::knowhere::cloner::CopyCpuToGpu(index_, device_id, cfg); auto gpu_index = zilliz::knowhere::cloner::CopyCpuToGpu(index_, device_id, cfg);
auto new_index = std::make_shared<VecIndexImpl>(gpu_index, ConvertToGpuIndexType(type)); auto new_index = std::make_shared<VecIndexImpl>(gpu_index, ConvertToGpuIndexType(type));
...@@ -158,7 +153,7 @@ VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) { ...@@ -158,7 +153,7 @@ VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) {
} }
VecIndexPtr VecIndexPtr
VecIndexImpl::CopyToCpu(const Config &cfg) { VecIndexImpl::CopyToCpu(const Config& cfg) {
// TODO(linxj): exception handle // TODO(linxj): exception handle
auto cpu_index = zilliz::knowhere::cloner::CopyGpuToCpu(index_, cfg); auto cpu_index = zilliz::knowhere::cloner::CopyGpuToCpu(index_, cfg);
auto new_index = std::make_shared<VecIndexImpl>(cpu_index, ConvertToCpuIndexType(type)); auto new_index = std::make_shared<VecIndexImpl>(cpu_index, ConvertToCpuIndexType(type));
...@@ -180,30 +175,32 @@ VecIndexImpl::GetDeviceId() { ...@@ -180,30 +175,32 @@ VecIndexImpl::GetDeviceId() {
return device_idx->GetGpuDevice(); return device_idx->GetGpuDevice();
} }
// else // else
return -1; // -1 == cpu return -1; // -1 == cpu
} }
float * float*
BFIndex::GetRawVectors() { BFIndex::GetRawVectors() {
auto raw_index = std::dynamic_pointer_cast<knowhere::IDMAP>(index_); auto raw_index = std::dynamic_pointer_cast<knowhere::IDMAP>(index_);
if (raw_index) { return raw_index->GetRawVectors(); } if (raw_index) {
return raw_index->GetRawVectors();
}
return nullptr; return nullptr;
} }
int64_t * int64_t*
BFIndex::GetRawIds() { BFIndex::GetRawIds() {
return std::static_pointer_cast<knowhere::IDMAP>(index_)->GetRawIds(); return std::static_pointer_cast<knowhere::IDMAP>(index_)->GetRawIds();
} }
ErrorCode ErrorCode
BFIndex::Build(const Config &cfg) { BFIndex::Build(const Config& cfg) {
try { try {
dim = cfg->d; dim = cfg->d;
std::static_pointer_cast<knowhere::IDMAP>(index_)->Train(cfg); std::static_pointer_cast<knowhere::IDMAP>(index_)->Train(cfg);
} catch (knowhere::KnowhereException &e) { } catch (knowhere::KnowhereException& e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_UNEXPECTED_ERROR; return KNOWHERE_UNEXPECTED_ERROR;
} catch (std::exception &e) { } catch (std::exception& e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return KNOWHERE_ERROR; return KNOWHERE_ERROR;
} }
...@@ -211,22 +208,18 @@ BFIndex::Build(const Config &cfg) { ...@@ -211,22 +208,18 @@ BFIndex::Build(const Config &cfg) {
} }
Status Status
BFIndex::BuildAll(const int64_t &nb, BFIndex::BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
const float *xb, const float* xt) {
const int64_t *ids,
const Config &cfg,
const int64_t &nt,
const float *xt) {
try { try {
dim = cfg->d; dim = cfg->d;
auto dataset = GenDatasetWithIds(nb, dim, xb, ids); auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
std::static_pointer_cast<knowhere::IDMAP>(index_)->Train(cfg); std::static_pointer_cast<knowhere::IDMAP>(index_)->Train(cfg);
index_->Add(dataset, cfg); index_->Add(dataset, cfg);
} catch (knowhere::KnowhereException &e) { } catch (knowhere::KnowhereException& e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what()); return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (std::exception &e) { } catch (std::exception& e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return Status(KNOWHERE_ERROR, e.what()); return Status(KNOWHERE_ERROR, e.what());
} }
...@@ -235,12 +228,8 @@ BFIndex::BuildAll(const int64_t &nb, ...@@ -235,12 +228,8 @@ BFIndex::BuildAll(const int64_t &nb,
// TODO(linxj): add lock here. // TODO(linxj): add lock here.
Status Status
IVFMixIndex::BuildAll(const int64_t &nb, IVFMixIndex::BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
const float *xb, const float* xt) {
const int64_t *ids,
const Config &cfg,
const int64_t &nt,
const float *xt) {
try { try {
dim = cfg->d; dim = cfg->d;
auto dataset = GenDatasetWithIds(nb, dim, xb, ids); auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
...@@ -259,10 +248,10 @@ IVFMixIndex::BuildAll(const int64_t &nb, ...@@ -259,10 +248,10 @@ IVFMixIndex::BuildAll(const int64_t &nb,
WRAPPER_LOG_ERROR << "Build IVFMIXIndex Failed"; WRAPPER_LOG_ERROR << "Build IVFMIXIndex Failed";
return Status(KNOWHERE_ERROR, "Build IVFMIXIndex Failed"); return Status(KNOWHERE_ERROR, "Build IVFMIXIndex Failed");
} }
} catch (knowhere::KnowhereException &e) { } catch (knowhere::KnowhereException& e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what()); return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (std::exception &e) { } catch (std::exception& e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return Status(KNOWHERE_ERROR, e.what()); return Status(KNOWHERE_ERROR, e.what());
} }
...@@ -270,12 +259,12 @@ IVFMixIndex::BuildAll(const int64_t &nb, ...@@ -270,12 +259,12 @@ IVFMixIndex::BuildAll(const int64_t &nb,
} }
Status Status
IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) { IVFMixIndex::Load(const zilliz::knowhere::BinarySet& index_binary) {
index_->Load(index_binary); index_->Load(index_binary);
dim = Dimension(); dim = Dimension();
return Status::OK(); return Status::OK();
} }
} // namespace engine } // namespace engine
} // namespace milvus } // namespace milvus
} // namespace zilliz } // namespace zilliz
...@@ -15,14 +15,13 @@ ...@@ -15,14 +15,13 @@
// specific language governing permissions and limitations // specific language governing permissions and limitations
// under the License. // under the License.
#pragma once #pragma once
#include "knowhere/index/vector_index/VectorIndex.h"
#include "VecIndex.h" #include "VecIndex.h"
#include "knowhere/index/vector_index/VectorIndex.h"
#include <utility>
#include <memory> #include <memory>
#include <utility>
namespace zilliz { namespace zilliz {
namespace milvus { namespace milvus {
...@@ -30,23 +29,19 @@ namespace engine { ...@@ -30,23 +29,19 @@ namespace engine {
class VecIndexImpl : public VecIndex { class VecIndexImpl : public VecIndex {
public: public:
explicit VecIndexImpl(std::shared_ptr<zilliz::knowhere::VectorIndex> index, const IndexType &type) explicit VecIndexImpl(std::shared_ptr<zilliz::knowhere::VectorIndex> index, const IndexType& type)
: index_(std::move(index)), type(type) { : index_(std::move(index)), type(type) {
} }
Status Status
BuildAll(const int64_t &nb, BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
const float *xb, const float* xt) override;
const int64_t *ids,
const Config &cfg,
const int64_t &nt,
const float *xt) override;
VecIndexPtr VecIndexPtr
CopyToGpu(const int64_t &device_id, const Config &cfg) override; CopyToGpu(const int64_t& device_id, const Config& cfg) override;
VecIndexPtr VecIndexPtr
CopyToCpu(const Config &cfg) override; CopyToCpu(const Config& cfg) override;
IndexType IndexType
GetType() override; GetType() override;
...@@ -58,13 +53,13 @@ class VecIndexImpl : public VecIndex { ...@@ -58,13 +53,13 @@ class VecIndexImpl : public VecIndex {
Count() override; Count() override;
Status Status
Add(const int64_t &nb, const float *xb, const int64_t *ids, const Config &cfg) override; Add(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg) override;
zilliz::knowhere::BinarySet zilliz::knowhere::BinarySet
Serialize() override; Serialize() override;
Status Status
Load(const zilliz::knowhere::BinarySet &index_binary) override; Load(const zilliz::knowhere::BinarySet& index_binary) override;
VecIndexPtr VecIndexPtr
Clone() override; Clone() override;
...@@ -73,7 +68,7 @@ class VecIndexImpl : public VecIndex { ...@@ -73,7 +68,7 @@ class VecIndexImpl : public VecIndex {
GetDeviceId() override; GetDeviceId() override;
Status Status
Search(const int64_t &nq, const float *xq, float *dist, int64_t *ids, const Config &cfg) override; Search(const int64_t& nq, const float* xq, float* dist, int64_t* ids, const Config& cfg) override;
protected: protected:
int64_t dim = 0; int64_t dim = 0;
...@@ -85,46 +80,38 @@ class VecIndexImpl : public VecIndex { ...@@ -85,46 +80,38 @@ class VecIndexImpl : public VecIndex {
class IVFMixIndex : public VecIndexImpl { class IVFMixIndex : public VecIndexImpl {
public: public:
explicit IVFMixIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index, const IndexType &type) explicit IVFMixIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index, const IndexType& type)
: VecIndexImpl(std::move(index), type) { : VecIndexImpl(std::move(index), type) {
} }
Status Status
BuildAll(const int64_t &nb, BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
const float *xb, const float* xt) override;
const int64_t *ids,
const Config &cfg,
const int64_t &nt,
const float *xt) override;
Status Status
Load(const zilliz::knowhere::BinarySet &index_binary) override; Load(const zilliz::knowhere::BinarySet& index_binary) override;
}; };
class BFIndex : public VecIndexImpl { class BFIndex : public VecIndexImpl {
public: public:
explicit BFIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : VecIndexImpl(std::move(index), explicit BFIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index)
IndexType::FAISS_IDMAP) { : VecIndexImpl(std::move(index), IndexType::FAISS_IDMAP) {
} }
ErrorCode ErrorCode
Build(const Config &cfg); Build(const Config& cfg);
float * float*
GetRawVectors(); GetRawVectors();
Status Status
BuildAll(const int64_t &nb, BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
const float *xb, const float* xt) override;
const int64_t *ids,
const Config &cfg, int64_t*
const int64_t &nt,
const float *xt) override;
int64_t *
GetRawIds(); GetRawIds();
}; };
} // namespace engine } // namespace engine
} // namespace milvus } // namespace milvus
} // namespace zilliz } // namespace zilliz
...@@ -16,17 +16,17 @@ ...@@ -16,17 +16,17 @@
// under the License. // under the License.
#include "wrapper/VecIndex.h" #include "wrapper/VecIndex.h"
#include "knowhere/index/vector_index/IndexIVF.h" #include "VecImpl.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexGPUIVF.h" #include "knowhere/index/vector_index/IndexGPUIVF.h"
#include "knowhere/index/vector_index/IndexIVFSQ.h"
#include "knowhere/index/vector_index/IndexGPUIVFSQ.h"
#include "knowhere/index/vector_index/IndexIVFPQ.h"
#include "knowhere/index/vector_index/IndexGPUIVFPQ.h" #include "knowhere/index/vector_index/IndexGPUIVFPQ.h"
#include "knowhere/index/vector_index/IndexGPUIVFSQ.h"
#include "knowhere/index/vector_index/IndexIDMAP.h" #include "knowhere/index/vector_index/IndexIDMAP.h"
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/IndexIVFPQ.h"
#include "knowhere/index/vector_index/IndexIVFSQ.h"
#include "knowhere/index/vector_index/IndexKDT.h" #include "knowhere/index/vector_index/IndexKDT.h"
#include "knowhere/index/vector_index/IndexNSG.h" #include "knowhere/index/vector_index/IndexNSG.h"
#include "knowhere/common/Exception.h"
#include "VecImpl.h"
#include "utils/Log.h" #include "utils/Log.h"
#include <cuda.h> #include <cuda.h>
...@@ -39,18 +39,18 @@ struct FileIOReader { ...@@ -39,18 +39,18 @@ struct FileIOReader {
std::fstream fs; std::fstream fs;
std::string name; std::string name;
explicit FileIOReader(const std::string &fname); explicit FileIOReader(const std::string& fname);
~FileIOReader(); ~FileIOReader();
size_t size_t
operator()(void *ptr, size_t size); operator()(void* ptr, size_t size);
size_t size_t
operator()(void *ptr, size_t size, size_t pos); operator()(void* ptr, size_t size, size_t pos);
}; };
FileIOReader::FileIOReader(const std::string &fname) { FileIOReader::FileIOReader(const std::string& fname) {
name = fname; name = fname;
fs = std::fstream(name, std::ios::in | std::ios::binary); fs = std::fstream(name, std::ios::in | std::ios::binary);
} }
...@@ -60,12 +60,12 @@ FileIOReader::~FileIOReader() { ...@@ -60,12 +60,12 @@ FileIOReader::~FileIOReader() {
} }
size_t size_t
FileIOReader::operator()(void *ptr, size_t size) { FileIOReader::operator()(void* ptr, size_t size) {
fs.read(reinterpret_cast<char *>(ptr), size); fs.read(reinterpret_cast<char*>(ptr), size);
} }
size_t size_t
FileIOReader::operator()(void *ptr, size_t size, size_t pos) { FileIOReader::operator()(void* ptr, size_t size, size_t pos) {
return 0; return 0;
} }
...@@ -73,12 +73,13 @@ struct FileIOWriter { ...@@ -73,12 +73,13 @@ struct FileIOWriter {
std::fstream fs; std::fstream fs;
std::string name; std::string name;
explicit FileIOWriter(const std::string &fname); explicit FileIOWriter(const std::string& fname);
~FileIOWriter(); ~FileIOWriter();
size_t operator()(void *ptr, size_t size); size_t
operator()(void* ptr, size_t size);
}; };
FileIOWriter::FileIOWriter(const std::string &fname) { FileIOWriter::FileIOWriter(const std::string& fname) {
name = fname; name = fname;
fs = std::fstream(name, std::ios::out | std::ios::binary); fs = std::fstream(name, std::ios::out | std::ios::binary);
} }
...@@ -88,14 +89,14 @@ FileIOWriter::~FileIOWriter() { ...@@ -88,14 +89,14 @@ FileIOWriter::~FileIOWriter() {
} }
size_t size_t
FileIOWriter::operator()(void *ptr, size_t size) { FileIOWriter::operator()(void* ptr, size_t size) {
fs.write(reinterpret_cast<char *>(ptr), size); fs.write(reinterpret_cast<char*>(ptr), size);
} }
VecIndexPtr VecIndexPtr
GetVecIndexFactory(const IndexType &type, const Config &cfg) { GetVecIndexFactory(const IndexType& type, const Config& cfg) {
std::shared_ptr<zilliz::knowhere::VectorIndex> index; std::shared_ptr<zilliz::knowhere::VectorIndex> index;
auto gpu_device = -1; // TODO(linxj): remove hardcode here auto gpu_device = -1; // TODO(linxj): remove hardcode here
switch (type) { switch (type) {
case IndexType::FAISS_IDMAP: { case IndexType::FAISS_IDMAP: {
index = std::make_shared<zilliz::knowhere::IDMAP>(); index = std::make_shared<zilliz::knowhere::IDMAP>();
...@@ -141,22 +142,20 @@ GetVecIndexFactory(const IndexType &type, const Config &cfg) { ...@@ -141,22 +142,20 @@ GetVecIndexFactory(const IndexType &type, const Config &cfg) {
index = std::make_shared<zilliz::knowhere::NSG>(gpu_device); index = std::make_shared<zilliz::knowhere::NSG>(gpu_device);
break; break;
} }
default: { default: { return nullptr; }
return nullptr;
}
} }
return std::make_shared<VecIndexImpl>(index, type); return std::make_shared<VecIndexImpl>(index, type);
} }
VecIndexPtr VecIndexPtr
LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary) { LoadVecIndex(const IndexType& index_type, const zilliz::knowhere::BinarySet& index_binary) {
auto index = GetVecIndexFactory(index_type); auto index = GetVecIndexFactory(index_type);
index->Load(index_binary); index->Load(index_binary);
return index; return index;
} }
VecIndexPtr VecIndexPtr
read_index(const std::string &location) { read_index(const std::string& location) {
knowhere::BinarySet load_data_list; knowhere::BinarySet load_data_list;
FileIOReader reader(location); FileIOReader reader(location);
reader.fs.seekg(0, reader.fs.end); reader.fs.seekg(0, reader.fs.end);
...@@ -201,28 +200,28 @@ read_index(const std::string &location) { ...@@ -201,28 +200,28 @@ read_index(const std::string &location) {
} }
Status Status
write_index(VecIndexPtr index, const std::string &location) { write_index(VecIndexPtr index, const std::string& location) {
try { try {
auto binaryset = index->Serialize(); auto binaryset = index->Serialize();
auto index_type = index->GetType(); auto index_type = index->GetType();
FileIOWriter writer(location); FileIOWriter writer(location);
writer(&index_type, sizeof(IndexType)); writer(&index_type, sizeof(IndexType));
for (auto &iter : binaryset.binary_map_) { for (auto& iter : binaryset.binary_map_) {
auto meta = iter.first.c_str(); auto meta = iter.first.c_str();
size_t meta_length = iter.first.length(); size_t meta_length = iter.first.length();
writer(&meta_length, sizeof(meta_length)); writer(&meta_length, sizeof(meta_length));
writer((void *) meta, meta_length); writer((void*)meta, meta_length);
auto binary = iter.second; auto binary = iter.second;
int64_t binary_length = binary->size; int64_t binary_length = binary->size;
writer(&binary_length, sizeof(binary_length)); writer(&binary_length, sizeof(binary_length));
writer((void *) binary->data.get(), binary_length); writer((void*)binary->data.get(), binary_length);
} }
} catch (knowhere::KnowhereException &e) { } catch (knowhere::KnowhereException& e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what()); return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
} catch (std::exception &e) { } catch (std::exception& e) {
WRAPPER_LOG_ERROR << e.what(); WRAPPER_LOG_ERROR << e.what();
std::string estring(e.what()); std::string estring(e.what());
if (estring.find("No space left on device") != estring.npos) { if (estring.find("No space left on device") != estring.npos) {
...@@ -236,7 +235,7 @@ write_index(VecIndexPtr index, const std::string &location) { ...@@ -236,7 +235,7 @@ write_index(VecIndexPtr index, const std::string &location) {
} }
IndexType IndexType
ConvertToCpuIndexType(const IndexType &type) { ConvertToCpuIndexType(const IndexType& type) {
// TODO(linxj): add IDMAP // TODO(linxj): add IDMAP
switch (type) { switch (type) {
case IndexType::FAISS_IVFFLAT_GPU: case IndexType::FAISS_IVFFLAT_GPU:
...@@ -247,14 +246,12 @@ ConvertToCpuIndexType(const IndexType &type) { ...@@ -247,14 +246,12 @@ ConvertToCpuIndexType(const IndexType &type) {
case IndexType::FAISS_IVFSQ8_MIX: { case IndexType::FAISS_IVFSQ8_MIX: {
return IndexType::FAISS_IVFSQ8_CPU; return IndexType::FAISS_IVFSQ8_CPU;
} }
default: { default: { return type; }
return type;
}
} }
} }
IndexType IndexType
ConvertToGpuIndexType(const IndexType &type) { ConvertToGpuIndexType(const IndexType& type) {
switch (type) { switch (type) {
case IndexType::FAISS_IVFFLAT_MIX: case IndexType::FAISS_IVFFLAT_MIX:
case IndexType::FAISS_IVFFLAT_CPU: { case IndexType::FAISS_IVFFLAT_CPU: {
...@@ -264,12 +261,10 @@ ConvertToGpuIndexType(const IndexType &type) { ...@@ -264,12 +261,10 @@ ConvertToGpuIndexType(const IndexType &type) {
case IndexType::FAISS_IVFSQ8_CPU: { case IndexType::FAISS_IVFSQ8_CPU: {
return IndexType::FAISS_IVFSQ8_GPU; return IndexType::FAISS_IVFSQ8_GPU;
} }
default: { default: { return type; }
return type;
}
} }
} }
} // namespace engine } // namespace engine
} // namespace milvus } // namespace milvus
} // namespace zilliz } // namespace zilliz
...@@ -15,15 +15,14 @@ ...@@ -15,15 +15,14 @@
// specific language governing permissions and limitations // specific language governing permissions and limitations
// under the License. // under the License.
#pragma once #pragma once
#include <string>
#include <memory> #include <memory>
#include <string>
#include "utils/Status.h"
#include "knowhere/common/Config.h"
#include "knowhere/common/BinarySet.h" #include "knowhere/common/BinarySet.h"
#include "knowhere/common/Config.h"
#include "utils/Status.h"
namespace zilliz { namespace zilliz {
namespace milvus { namespace milvus {
...@@ -36,7 +35,7 @@ enum class IndexType { ...@@ -36,7 +35,7 @@ enum class IndexType {
FAISS_IDMAP = 1, FAISS_IDMAP = 1,
FAISS_IVFFLAT_CPU, FAISS_IVFFLAT_CPU,
FAISS_IVFFLAT_GPU, FAISS_IVFFLAT_GPU,
FAISS_IVFFLAT_MIX, // build on gpu and search on cpu FAISS_IVFFLAT_MIX, // build on gpu and search on cpu
FAISS_IVFPQ_CPU, FAISS_IVFPQ_CPU,
FAISS_IVFPQ_GPU, FAISS_IVFPQ_GPU,
SPTAG_KDT_RNT_CPU, SPTAG_KDT_RNT_CPU,
...@@ -53,32 +52,20 @@ using VecIndexPtr = std::shared_ptr<VecIndex>; ...@@ -53,32 +52,20 @@ using VecIndexPtr = std::shared_ptr<VecIndex>;
class VecIndex { class VecIndex {
public: public:
virtual Status virtual Status
BuildAll(const int64_t &nb, BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt = 0,
const float *xb, const float* xt = nullptr) = 0;
const int64_t *ids,
const Config &cfg,
const int64_t &nt = 0,
const float *xt = nullptr) = 0;
virtual Status virtual Status
Add(const int64_t &nb, Add(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg = Config()) = 0;
const float *xb,
const int64_t *ids,
const Config &cfg = Config()) = 0;
virtual Status virtual Status
Search(const int64_t &nq, Search(const int64_t& nq, const float* xq, float* dist, int64_t* ids, const Config& cfg = Config()) = 0;
const float *xq,
float *dist,
int64_t *ids,
const Config &cfg = Config()) = 0;
virtual VecIndexPtr virtual VecIndexPtr
CopyToGpu(const int64_t &device_id, CopyToGpu(const int64_t& device_id, const Config& cfg = Config()) = 0;
const Config &cfg = Config()) = 0;
virtual VecIndexPtr virtual VecIndexPtr
CopyToCpu(const Config &cfg = Config()) = 0; CopyToCpu(const Config& cfg = Config()) = 0;
virtual VecIndexPtr virtual VecIndexPtr
Clone() = 0; Clone() = 0;
...@@ -99,27 +86,27 @@ class VecIndex { ...@@ -99,27 +86,27 @@ class VecIndex {
Serialize() = 0; Serialize() = 0;
virtual Status virtual Status
Load(const zilliz::knowhere::BinarySet &index_binary) = 0; Load(const zilliz::knowhere::BinarySet& index_binary) = 0;
}; };
extern Status extern Status
write_index(VecIndexPtr index, const std::string &location); write_index(VecIndexPtr index, const std::string& location);
extern VecIndexPtr extern VecIndexPtr
read_index(const std::string &location); read_index(const std::string& location);
extern VecIndexPtr extern VecIndexPtr
GetVecIndexFactory(const IndexType &type, const Config &cfg = Config()); GetVecIndexFactory(const IndexType& type, const Config& cfg = Config());
extern VecIndexPtr extern VecIndexPtr
LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary); LoadVecIndex(const IndexType& index_type, const zilliz::knowhere::BinarySet& index_binary);
extern IndexType extern IndexType
ConvertToCpuIndexType(const IndexType &type); ConvertToCpuIndexType(const IndexType& type);
extern IndexType extern IndexType
ConvertToGpuIndexType(const IndexType &type); ConvertToGpuIndexType(const IndexType& type);
} // namespace engine } // namespace engine
} // namespace milvus } // namespace milvus
} // namespace zilliz } // namespace zilliz
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册