提交 813f5151 编写于 作者: X xj.lin

MS-267 Support Inner Product

update..


Former-commit-id: 8fdbb39fbdd05853f25c374f437f3ed78a46345d
上级 ba7e3c3d
......@@ -30,7 +30,10 @@ ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension,
index_ = CreatetVecIndex(EngineType::FAISS_IDMAP);
if (!index_) throw Exception("Create Empty VecIndex");
auto ec = std::static_pointer_cast<BFIndex>(index_)->Build(dimension);
Config build_cfg;
build_cfg["dim"] = dimension;
AutoGenParams(index_->GetType(), 0, build_cfg);
auto ec = std::static_pointer_cast<BFIndex>(index_)->Build(build_cfg);
if (ec != server::KNOWHERE_SUCCESS) { throw Exception("Build index error"); }
}
......
......@@ -144,10 +144,10 @@ int64_t *BFIndex::GetRawIds() {
return std::static_pointer_cast<IDMAP>(index_)->GetRawIds();
}
server::KnowhereError BFIndex::Build(const int64_t &d) {
server::KnowhereError BFIndex::Build(const Config &cfg) {
try {
dim = d;
std::static_pointer_cast<IDMAP>(index_)->Train(dim);
dim = cfg["dim"].as<int>();
std::static_pointer_cast<IDMAP>(index_)->Train(cfg);
} catch (KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what();
return server::KNOWHERE_UNEXPECTED_ERROR;
......@@ -171,7 +171,7 @@ server::KnowhereError BFIndex::BuildAll(const long &nb,
dim = cfg["dim"].as<int>();
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
std::static_pointer_cast<IDMAP>(index_)->Train(dim);
std::static_pointer_cast<IDMAP>(index_)->Train(cfg);
index_->Add(dataset, cfg);
} catch (KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what();
......
......@@ -57,7 +57,7 @@ class BFIndex : public VecIndexImpl {
public:
explicit BFIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : VecIndexImpl(std::move(index),
IndexType::FAISS_IDMAP) {};
server::KnowhereError Build(const int64_t &d);
server::KnowhereError Build(const Config& cfg);
float *GetRawVectors();
server::KnowhereError BuildAll(const long &nb,
const float *xb,
......
......@@ -180,7 +180,7 @@ server::KnowhereError write_index(VecIndexPtr index, const std::string &location
} catch (knowhere::KnowhereException &e) {
WRAPPER_LOG_ERROR << e.what();
return server::KNOWHERE_UNEXPECTED_ERROR;
} catch (std::exception& e) {
} catch (std::exception &e) {
WRAPPER_LOG_ERROR << e.what();
return server::KNOWHERE_ERROR;
}
......@@ -192,6 +192,7 @@ server::KnowhereError write_index(VecIndexPtr index, const std::string &location
void AutoGenParams(const IndexType &type, const long &size, zilliz::knowhere::Config &cfg) {
if (!cfg.contains("nlist")) { cfg["nlist"] = int(size / 1000000.0 * 16384); }
if (!cfg.contains("gpu_id")) { cfg["gpu_id"] = int(0); }
if (!cfg.contains("metric_type")) { cfg["metric_type"] = "L2"; }
switch (type) {
case IndexType::FAISS_IVFSQ8_MIX: {
......
knowhere @ f866ac4e
Subproject commit b0b9dd18fadbf9dc0fccaad815e14e578a92993e
Subproject commit f866ac4e297dea477ec591a62679cf5cdd219cc8
......@@ -96,17 +96,17 @@ INSTANTIATE_TEST_CASE_P(WrapperParam, KnowhereWrapperTest,
//),
std::make_tuple(IndexType::FAISS_IVFFLAT_MIX, "Default",
64, 100000, 10, 10,
Config::object{{"nlist", 1000}, {"dim", 64}},
Config::object{{"nlist", 1000}, {"dim", 64}, {"metric_type", "L2"}},
Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 5}}
),
std::make_tuple(IndexType::FAISS_IDMAP, "Default",
64, 100000, 10, 10,
Config::object{{"dim", 64}},
Config::object{{"dim", 64}, {"metric_type", "L2"}},
Config::object{{"dim", 64}, {"k", 10}}
),
std::make_tuple(IndexType::FAISS_IVFSQ8_MIX, "Default",
64, 100000, 10, 10,
Config::object{{"dim", 64}, {"nlist", 1000}, {"nbits", 8}},
Config::object{{"dim", 64}, {"nlist", 1000}, {"nbits", 8}, {"metric_type", "L2"}},
Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 5}}
)
//std::make_tuple(IndexType::SPTAG_KDT_RNT_CPU, "Default",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册