未验证 提交 f8d9bc91 编写于 作者: J Jiquan Long 提交者: GitHub

Unify interface of vector index & scalar index. (#15959)

Signed-off-by: Ndragondriver <jiquan.long@zilliz.com>
上级 c16195c7
......@@ -17,7 +17,7 @@
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
#include "pb/index_cgo_msg.pb.h"
#include "indexbuilder/IndexWrapper.h"
#include "indexbuilder/VecIndexCreator.h"
#include "indexbuilder/index_c.h"
#include "indexbuilder/utils.h"
#include "test_utils/indexbuilder_test_utils.h"
......@@ -64,7 +64,7 @@ IndexBuilder_build(benchmark::State& state) {
for (auto _ : state) {
auto index =
std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str());
std::make_unique<milvus::indexbuilder::VecIndexCreator>(type_params_str.c_str(), index_params_str.c_str());
index->BuildWithoutIds(xb_dataset);
}
}
......@@ -93,7 +93,7 @@ IndexBuilder_build_and_codec(benchmark::State& state) {
for (auto _ : state) {
auto index =
std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str());
std::make_unique<milvus::indexbuilder::VecIndexCreator>(type_params_str.c_str(), index_params_str.c_str());
index->BuildWithoutIds(xb_dataset);
index->Serialize();
......
......@@ -11,9 +11,11 @@
set(INDEXBUILDER_FILES
IndexWrapper.cpp
VecIndexCreator.cpp
index_c.cpp
init_c.cpp
utils.cpp
StringIndexImpl.cpp
)
add_library(milvus_indexbuilder SHARED
${INDEXBUILDER_FILES}
......@@ -27,8 +29,8 @@ endif ()
# link order matters
target_link_libraries(milvus_indexbuilder
knowhere
milvus_common
knowhere
${TBB}
${PLATFORM_LIBS}
pthread
......
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include "knowhere/common/Dataset.h"
#include "knowhere/common/BinarySet.h"
#include <memory>
#include <knowhere/index/Index.h>
namespace milvus::indexbuilder {
class IndexCreatorBase {
public:
virtual ~IndexCreatorBase() = default;
virtual void
Build(const knowhere::DatasetPtr& dataset) = 0;
virtual knowhere::BinarySet
Serialize() = 0;
virtual void
Load(const knowhere::BinarySet&) = 0;
// virtual knowhere::IndexPtr
// GetIndex() = 0;
};
using IndexCreatorBasePtr = std::unique_ptr<IndexCreatorBase>;
} // namespace milvus::indexbuilder
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <pb/schema.pb.h>
#include <cmath>
#include "indexbuilder/IndexCreatorBase.h"
#include "indexbuilder/ScalarIndexCreator.h"
#include "indexbuilder/VecIndexCreator.h"
#include "indexbuilder/type_c.h"
#include <memory>
#include <string>
namespace milvus::indexbuilder {
// consider template factory if too many factories are needed.
class IndexFactory {
public:
IndexFactory() = default;
IndexFactory(const IndexFactory&) = delete;
IndexFactory
operator=(const IndexFactory&) = delete;
public:
static IndexFactory&
GetInstance() {
// thread-safe enough after c++ 11
static IndexFactory instance;
return instance;
}
IndexCreatorBasePtr
CreateIndex(DataType dtype, const char* type_params, const char* index_params) {
auto real_dtype = proto::schema::DataType(dtype);
auto invalid_dtype_msg = std::string("invalid data type: ") + std::to_string(real_dtype);
switch (real_dtype) {
case milvus::proto::schema::Bool:
return std::make_unique<ScalarIndexCreator<bool>>(type_params, index_params);
case milvus::proto::schema::Int8:
return std::make_unique<ScalarIndexCreator<int8_t>>(type_params, index_params);
case milvus::proto::schema::Int16:
return std::make_unique<ScalarIndexCreator<int16_t>>(type_params, index_params);
case milvus::proto::schema::Int32:
return std::make_unique<ScalarIndexCreator<int32_t>>(type_params, index_params);
case milvus::proto::schema::Int64:
return std::make_unique<ScalarIndexCreator<int64_t>>(type_params, index_params);
case milvus::proto::schema::Float:
return std::make_unique<ScalarIndexCreator<float_t>>(type_params, index_params);
case milvus::proto::schema::Double:
return std::make_unique<ScalarIndexCreator<double_t>>(type_params, index_params);
case proto::schema::VarChar:
case milvus::proto::schema::String:
return std::make_unique<ScalarIndexCreator<std::string>>(type_params, index_params);
case milvus::proto::schema::BinaryVector:
case milvus::proto::schema::FloatVector:
return std::make_unique<VecIndexCreator>(type_params, index_params);
case milvus::proto::schema::None:
case milvus::proto::schema::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
case milvus::proto::schema::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
default:
throw std::invalid_argument(invalid_dtype_msg);
}
}
};
} // namespace milvus::indexbuilder
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <knowhere/index/structured_index_simple/StructuredIndexSort.h>
#include <pb/schema.pb.h>
#include "indexbuilder/helper.h"
#include "indexbuilder/StringIndexImpl.h"
#include <string>
#include <memory>
#include <vector>
namespace milvus::indexbuilder {
template <typename T>
inline ScalarIndexCreator<T>::ScalarIndexCreator(const char* type_params, const char* index_params) {
// TODO: move parse-related logic to a common interface.
Helper::ParseFromString(type_params_, std::string(type_params));
Helper::ParseFromString(index_params_, std::string(index_params));
// TODO: create index according to the params.
index_ = std::make_unique<knowhere::scalar::StructuredIndexSort<T>>();
}
template <typename T>
inline void
ScalarIndexCreator<T>::Build(const knowhere::DatasetPtr& dataset) {
auto size = dataset->Get<int64_t>(knowhere::meta::ROWS);
auto data = dataset->Get<const void*>(knowhere::meta::TENSOR);
index_->Build(size, reinterpret_cast<const T*>(data));
}
template <typename T>
inline knowhere::BinarySet
ScalarIndexCreator<T>::Serialize() {
return index_->Serialize(config_);
}
template <typename T>
inline void
ScalarIndexCreator<T>::Load(const knowhere::BinarySet& binary_set) {
index_->Load(binary_set);
}
// not sure that the pointer of a golang bool array acts like other types.
template <>
inline void
ScalarIndexCreator<bool>::Build(const milvus::knowhere::DatasetPtr& dataset) {
auto size = dataset->Get<int64_t>(knowhere::meta::ROWS);
auto data = dataset->Get<const void*>(knowhere::meta::TENSOR);
proto::schema::BoolArray arr;
Helper::ParseParams(arr, data, size);
index_->Build(arr.data().size(), arr.data().data());
}
template <>
inline ScalarIndexCreator<std::string>::ScalarIndexCreator(const char* type_params, const char* index_params) {
// TODO: move parse-related logic to a common interface.
Helper::ParseFromString(type_params_, std::string(type_params));
Helper::ParseFromString(index_params_, std::string(index_params));
// TODO: create index according to the params.
index_ = std::make_unique<StringIndexImpl>();
}
template <>
inline void
ScalarIndexCreator<std::string>::Build(const milvus::knowhere::DatasetPtr& dataset) {
auto size = dataset->Get<int64_t>(knowhere::meta::ROWS);
auto data = dataset->Get<const void*>(knowhere::meta::TENSOR);
proto::schema::StringArray arr;
Helper::ParseParams(arr, data, size);
// TODO: optimize here. avoid memory copy.
std::vector<std::string> vecs{arr.data().begin(), arr.data().end()};
index_->Build(arr.data().size(), vecs.data());
}
} // namespace milvus::indexbuilder
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include "indexbuilder/IndexCreatorBase.h"
#include "knowhere/index/structured_index_simple/StructuredIndex.h"
#include "pb/index_cgo_msg.pb.h"
#include <string>
#include <memory>
namespace milvus::indexbuilder {
template <typename T>
class ScalarIndexCreator : public IndexCreatorBase {
// of course, maybe we can support combination index later.
// for example, we can create index for combination of (field a, field b),
// attribute filtering on the combination can be speed up.
static_assert(std::is_fundamental_v<T> || std::is_same_v<T, std::string>);
public:
ScalarIndexCreator(const char* type_params, const char* index_params);
void
Build(const knowhere::DatasetPtr& dataset) override;
knowhere::BinarySet
Serialize() override;
void
Load(const knowhere::BinarySet&) override;
private:
std::unique_ptr<knowhere::scalar::StructuredIndex<T>> index_ = nullptr;
proto::indexcgo::TypeParams type_params_;
proto::indexcgo::IndexParams index_params_;
milvus::json config_;
};
} // namespace milvus::indexbuilder
#include "ScalarIndexCreator-inl.h"
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "StringIndexImpl.h"
namespace milvus::indexbuilder {
// TODO: optimize here.
knowhere::BinarySet
StringIndexImpl::Serialize(const knowhere::Config& config) {
knowhere::BinarySet res_set;
auto data = this->GetData();
for (const auto& record : data) {
auto idx = record.idx_;
auto str = record.a_;
std::shared_ptr<uint8_t[]> content(new uint8_t[str.length()]);
memcpy(content.get(), str.c_str(), str.length());
res_set.Append(std::to_string(idx), content, str.length());
}
return res_set;
}
void
StringIndexImpl::Load(const knowhere::BinarySet& index_binary) {
std::vector<std::string> vecs;
for (const auto& [k, v] : index_binary.binary_map_) {
std::string str(reinterpret_cast<const char*>(v->data.get()), v->size);
vecs.emplace_back(str);
}
Build(vecs.size(), vecs.data());
}
} // namespace milvus::indexbuilder
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <knowhere/index/structured_index_simple/StructuredIndexSort.h>
#include <string>
namespace milvus::indexbuilder {
class StringIndexImpl : public knowhere::scalar::StructuredIndexSort<std::string> {
public:
knowhere::BinarySet
Serialize(const knowhere::Config& config) override;
void
Load(const knowhere::BinarySet& index_binary) override;
};
} // namespace milvus::indexbuilder
......@@ -15,7 +15,7 @@
#include "exceptions/EasyAssert.h"
#include "pb/index_cgo_msg.pb.h"
#include "indexbuilder/IndexWrapper.h"
#include "indexbuilder/VecIndexCreator.h"
#include "indexbuilder/utils.h"
#include "knowhere/common/Timer.h"
#include "knowhere/common/Utils.h"
......@@ -25,7 +25,7 @@
namespace milvus::indexbuilder {
IndexWrapper::IndexWrapper(const char* serialized_type_params, const char* serialized_index_params) {
VecIndexCreator::VecIndexCreator(const char* serialized_type_params, const char* serialized_index_params) {
type_params_ = std::string(serialized_type_params);
index_params_ = std::string(serialized_index_params);
......@@ -37,18 +37,18 @@ IndexWrapper::IndexWrapper(const char* serialized_type_params, const char* seria
AssertInfo(!is_unsupported(index_type, metric_type), index_type + " doesn't support metric: " + metric_type);
index_ = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(get_index_type(), index_mode);
AssertInfo(index_ != nullptr, "[IndexWrapper]Index is null after create index");
AssertInfo(index_ != nullptr, "[VecIndexCreator]Index is null after create index");
}
template <typename ParamsT>
// ugly here, ParamsT will just be MapParams later
void
IndexWrapper::parse_impl(const std::string& serialized_params_str, knowhere::Config& conf) {
VecIndexCreator::parse_impl(const std::string& serialized_params_str, knowhere::Config& conf) {
bool deserialized_success;
ParamsT params;
deserialized_success = google::protobuf::TextFormat::ParseFromString(serialized_params_str, &params);
AssertInfo(deserialized_success, "[IndexWrapper]Deserialize params failed");
AssertInfo(deserialized_success, "[VecIndexCreator]Deserialize params failed");
for (auto i = 0; i < params.params_size(); ++i) {
const auto& param = params.params(i);
......@@ -112,7 +112,7 @@ IndexWrapper::parse_impl(const std::string& serialized_params_str, knowhere::Con
}
void
IndexWrapper::parse() {
VecIndexCreator::parse() {
namespace indexcgo = milvus::proto::indexcgo;
parse_impl<indexcgo::TypeParams>(type_params_, type_config_);
......@@ -124,10 +124,10 @@ IndexWrapper::parse() {
template <typename T>
void
IndexWrapper::check_parameter(knowhere::Config& conf,
const std::string& key,
std::function<T(std::string)> fn,
std::optional<T> default_v) {
VecIndexCreator::check_parameter(knowhere::Config& conf,
const std::string& key,
std::function<T(std::string)> fn,
std::optional<T> default_v) {
if (!conf.contains(key)) {
if (default_v.has_value()) {
conf[key] = default_v.value();
......@@ -140,7 +140,7 @@ IndexWrapper::check_parameter(knowhere::Config& conf,
template <typename T>
std::optional<T>
IndexWrapper::get_config_by_name(std::string name) {
VecIndexCreator::get_config_by_name(std::string name) {
if (config_.contains(name)) {
return {config_[name].get<T>()};
}
......@@ -148,14 +148,14 @@ IndexWrapper::get_config_by_name(std::string name) {
}
int64_t
IndexWrapper::dim() {
VecIndexCreator::dim() {
auto dimension = get_config_by_name<int64_t>(milvus::knowhere::meta::DIM);
AssertInfo(dimension.has_value(), "[IndexWrapper]Dimension doesn't have value");
AssertInfo(dimension.has_value(), "[VecIndexCreator]Dimension doesn't have value");
return (dimension.value());
}
void
IndexWrapper::BuildWithoutIds(const knowhere::DatasetPtr& dataset) {
VecIndexCreator::BuildWithoutIds(const knowhere::DatasetPtr& dataset) {
auto index_type = get_index_type();
auto index_mode = get_index_mode();
config_[knowhere::meta::ROWS] = dataset->Get<int64_t>(knowhere::meta::ROWS);
......@@ -189,9 +189,9 @@ IndexWrapper::BuildWithoutIds(const knowhere::DatasetPtr& dataset) {
}
void
IndexWrapper::BuildWithIds(const knowhere::DatasetPtr& dataset) {
VecIndexCreator::BuildWithIds(const knowhere::DatasetPtr& dataset) {
AssertInfo(dataset->data().find(milvus::knowhere::meta::IDS) != dataset->data().end(),
"[IndexWrapper]Can't find ids field in dataset");
"[VecIndexCreator]Can't find ids field in dataset");
auto index_type = get_index_type();
auto index_mode = get_index_mode();
config_[knowhere::meta::ROWS] = dataset->Get<int64_t>(knowhere::meta::ROWS);
......@@ -212,7 +212,7 @@ IndexWrapper::BuildWithIds(const knowhere::DatasetPtr& dataset) {
}
void
IndexWrapper::StoreRawData(const knowhere::DatasetPtr& dataset) {
VecIndexCreator::StoreRawData(const knowhere::DatasetPtr& dataset) {
auto index_type = get_index_type();
if (is_in_nm_list(index_type)) {
auto tensor = dataset->Get<const void*>(milvus::knowhere::meta::TENSOR);
......@@ -229,64 +229,25 @@ IndexWrapper::StoreRawData(const knowhere::DatasetPtr& dataset) {
}
}
std::unique_ptr<milvus::knowhere::BinarySet>
IndexWrapper::SerializeBinarySet() {
auto ret = std::make_unique<milvus::knowhere::BinarySet>(index_->Serialize(config_));
milvus::knowhere::BinarySet
VecIndexCreator::Serialize() {
auto ret = index_->Serialize(config_);
auto index_type = get_index_type();
if (is_in_nm_list(index_type)) {
std::shared_ptr<uint8_t[]> raw_data(new uint8_t[raw_data_.size()], std::default_delete<uint8_t[]>());
memcpy(raw_data.get(), raw_data_.data(), raw_data_.size());
ret->Append(RAW_DATA, raw_data, raw_data_.size());
ret.Append(RAW_DATA, raw_data, raw_data_.size());
auto slice_size = get_index_file_slice_size();
// https://github.com/milvus-io/milvus/issues/6421
// Disassemble will only divide the raw vectors, other keys were already divided
knowhere::Disassemble(slice_size * 1024 * 1024, *ret);
knowhere::Disassemble(slice_size * 1024 * 1024, ret);
}
return std::move(ret);
}
/*
* brief Return serialized binary set
* TODO: use a more efficient method to manage memory, consider std::vector later
*/
std::unique_ptr<IndexWrapper::Binary>
IndexWrapper::Serialize() {
auto binarySet = index_->Serialize(config_);
auto index_type = get_index_type();
if (is_in_nm_list(index_type)) {
std::shared_ptr<uint8_t[]> raw_data(new uint8_t[raw_data_.size()], std::default_delete<uint8_t[]>());
memcpy(raw_data.get(), raw_data_.data(), raw_data_.size());
binarySet.Append(RAW_DATA, raw_data, raw_data_.size());
auto slice_size = get_index_file_slice_size();
// https://github.com/milvus-io/milvus/issues/6421
// Disassemble will only divide the raw vectors, other keys were already divided
knowhere::Disassemble(slice_size * 1024 * 1024, binarySet);
}
namespace indexcgo = milvus::proto::indexcgo;
indexcgo::BinarySet ret;
for (auto [key, value] : binarySet.binary_map_) {
auto binary = ret.add_datas();
binary->set_key(key);
binary->set_value(value->data.get(), value->size);
}
std::string serialized_data;
auto ok = ret.SerializeToString(&serialized_data);
AssertInfo(ok, "[IndexWrapper]Can't serialize data to string");
auto binary = std::make_unique<IndexWrapper::Binary>();
binary->data.resize(serialized_data.length());
memcpy(binary->data.data(), serialized_data.c_str(), serialized_data.length());
return binary;
return ret;
}
void
IndexWrapper::LoadFromBinarySet(milvus::knowhere::BinarySet& binary_set) {
VecIndexCreator::Load(const milvus::knowhere::BinarySet& binary_set) {
auto& map_ = binary_set.binary_map_;
for (auto it = map_.begin(); it != map_.end(); ++it) {
if (it->first == RAW_DATA) {
......@@ -300,30 +261,8 @@ IndexWrapper::LoadFromBinarySet(milvus::knowhere::BinarySet& binary_set) {
index_->Load(binary_set);
}
void
IndexWrapper::Load(const char* serialized_sliced_blob_buffer, int32_t size) {
namespace indexcgo = milvus::proto::indexcgo;
auto data = std::string(serialized_sliced_blob_buffer, size);
indexcgo::BinarySet blob_buffer;
auto ok = blob_buffer.ParseFromString(data);
AssertInfo(ok, "[IndexWrapper]Can't parse data from string to blob_buffer");
milvus::knowhere::BinarySet binarySet;
for (auto i = 0; i < blob_buffer.datas_size(); i++) {
const auto& binary = blob_buffer.datas(i);
auto deleter = [&](uint8_t*) {}; // avoid repeated destruction
auto bptr = std::make_shared<milvus::knowhere::Binary>();
bptr->data = std::shared_ptr<uint8_t[]>((uint8_t*)binary.value().c_str(), deleter);
bptr->size = binary.value().length();
binarySet.Append(binary.key(), bptr);
}
index_->Load(binarySet);
}
std::string
IndexWrapper::get_index_type() {
VecIndexCreator::get_index_type() {
// return index_->index_type();
// knowhere bug here
// the index_type of all ivf-based index will change to ivf flat after loaded
......@@ -332,7 +271,7 @@ IndexWrapper::get_index_type() {
}
std::string
IndexWrapper::get_metric_type() {
VecIndexCreator::get_metric_type() {
auto type = get_config_by_name<std::string>(knowhere::Metric::TYPE);
if (type.has_value()) {
return type.value();
......@@ -347,7 +286,7 @@ IndexWrapper::get_metric_type() {
}
knowhere::IndexMode
IndexWrapper::get_index_mode() {
VecIndexCreator::get_index_mode() {
static std::map<std::string, knowhere::IndexMode> mode_map = {
{"CPU", knowhere::IndexMode::MODE_CPU},
{"GPU", knowhere::IndexMode::MODE_GPU},
......@@ -357,20 +296,20 @@ IndexWrapper::get_index_mode() {
}
int64_t
IndexWrapper::get_index_file_slice_size() {
VecIndexCreator::get_index_file_slice_size() {
if (config_.contains(knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
return config_[knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>();
}
return 4; // by default
}
std::unique_ptr<IndexWrapper::QueryResult>
IndexWrapper::Query(const knowhere::DatasetPtr& dataset) {
std::unique_ptr<VecIndexCreator::QueryResult>
VecIndexCreator::Query(const knowhere::DatasetPtr& dataset) {
return std::move(QueryImpl(dataset, config_));
}
std::unique_ptr<IndexWrapper::QueryResult>
IndexWrapper::QueryWithParam(const knowhere::DatasetPtr& dataset, const char* serialized_search_params) {
std::unique_ptr<VecIndexCreator::QueryResult>
VecIndexCreator::QueryWithParam(const knowhere::DatasetPtr& dataset, const char* serialized_search_params) {
namespace indexcgo = milvus::proto::indexcgo;
milvus::knowhere::Config search_conf;
parse_impl<indexcgo::MapParams>(std::string(serialized_search_params), search_conf);
......@@ -378,8 +317,8 @@ IndexWrapper::QueryWithParam(const knowhere::DatasetPtr& dataset, const char* se
return std::move(QueryImpl(dataset, search_conf));
}
std::unique_ptr<IndexWrapper::QueryResult>
IndexWrapper::QueryImpl(const knowhere::DatasetPtr& dataset, const knowhere::Config& conf) {
std::unique_ptr<VecIndexCreator::QueryResult>
VecIndexCreator::QueryImpl(const knowhere::DatasetPtr& dataset, const knowhere::Config& conf) {
auto load_raw_data_closure = [&]() { LoadRawData(); }; // hide this pointer
auto index_type = get_index_type();
if (is_in_nm_list(index_type)) {
......@@ -392,7 +331,7 @@ IndexWrapper::QueryImpl(const knowhere::DatasetPtr& dataset, const knowhere::Con
auto nq = dataset->Get<int64_t>(milvus::knowhere::meta::ROWS);
auto k = config_[milvus::knowhere::meta::TOPK].get<int64_t>();
auto query_res = std::make_unique<IndexWrapper::QueryResult>();
auto query_res = std::make_unique<VecIndexCreator::QueryResult>();
query_res->nq = nq;
query_res->topk = k;
query_res->ids.resize(nq * k);
......@@ -404,7 +343,7 @@ IndexWrapper::QueryImpl(const knowhere::DatasetPtr& dataset, const knowhere::Con
}
void
IndexWrapper::LoadRawData() {
VecIndexCreator::LoadRawData() {
auto index_type = get_index_type();
if (is_in_nm_list(index_type)) {
auto bs = index_->Serialize(config_);
......
......@@ -18,35 +18,31 @@
#include "knowhere/index/vector_index/VecIndex.h"
#include "knowhere/common/BinarySet.h"
#include "indexbuilder/IndexCreatorBase.h"
namespace milvus::indexbuilder {
class IndexWrapper {
// TODO: better to distinguish binary vec & float vec.
class VecIndexCreator : public IndexCreatorBase {
public:
explicit IndexWrapper(const char* serialized_type_params, const char* serialized_index_params);
int64_t
dim();
explicit VecIndexCreator(const char* serialized_type_params, const char* serialized_index_params);
void
BuildWithoutIds(const knowhere::DatasetPtr& dataset);
struct Binary {
std::vector<char> data;
};
std::unique_ptr<Binary>
Serialize();
Build(const knowhere::DatasetPtr& dataset) override {
BuildWithoutIds(dataset);
}
std::unique_ptr<milvus::knowhere::BinarySet>
SerializeBinarySet();
knowhere::BinarySet
Serialize() override;
void
LoadFromBinarySet(milvus::knowhere::BinarySet&);
Load(const knowhere::BinarySet& binary_set) override;
void
Load(const char* serialized_sliced_blob_buffer, int32_t size);
int64_t
dim();
public:
// used for tests
struct QueryResult {
std::vector<milvus::knowhere::IDType> ids;
std::vector<float> distances;
......@@ -104,6 +100,9 @@ class IndexWrapper {
void
BuildWithIds(const knowhere::DatasetPtr& dataset);
void
BuildWithoutIds(const knowhere::DatasetPtr& dataset);
private:
knowhere::VecIndexPtr index_ = nullptr;
std::string type_params_;
......
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include "pb/index_cgo_msg.pb.h"
#include "exceptions/EasyAssert.h"
#include <google/protobuf/text_format.h>
#include <string>
#include <map>
namespace milvus::indexbuilder {
using MapParams = std::map<std::string, std::string>;
struct Helper {
static void
ParseFromString(google::protobuf::Message& params, const std::string& str) {
auto ok = google::protobuf::TextFormat::ParseFromString(str, &params);
AssertInfo(ok, "failed to parse params from string");
}
static void
ParseParams(google::protobuf::Message& params, const void* data, const size_t size) {
auto ok = params.ParseFromArray(data, size);
AssertInfo(ok, "failed to parse params from array");
}
};
} // namespace milvus::indexbuilder
......@@ -17,28 +17,23 @@
#endif
#include "exceptions/EasyAssert.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "indexbuilder/IndexWrapper.h"
#include "indexbuilder/VecIndexCreator.h"
#include "indexbuilder/index_c.h"
class CGODebugUtils {
public:
static int64_t
Strlen(const char* str, int64_t size) {
if (size == 0) {
return size;
} else {
return strlen(str);
}
}
};
#include "indexbuilder/IndexFactory.h"
CStatus
CreateIndex(const char* serialized_type_params, const char* serialized_index_params, CIndex* res_index) {
CreateIndex(DataType dtype,
const char* serialized_type_params,
const char* serialized_index_params,
CIndex* res_index) {
auto status = CStatus();
try {
auto index =
std::make_unique<milvus::indexbuilder::IndexWrapper>(serialized_type_params, serialized_index_params);
AssertInfo(res_index, "failed to create index, passed index was null");
auto index = milvus::indexbuilder::IndexFactory::GetInstance().CreateIndex(dtype, serialized_type_params,
serialized_index_params);
*res_index = index.release();
status.error_code = Success;
status.error_msg = "";
......@@ -49,24 +44,16 @@ CreateIndex(const char* serialized_type_params, const char* serialized_index_par
return status;
}
void
DeleteIndex(CIndex index) {
auto cIndex = (milvus::indexbuilder::IndexWrapper*)index;
delete cIndex;
#ifdef __linux__
malloc_trim(0);
#endif
}
CStatus
BuildFloatVecIndexWithoutIds(CIndex index, int64_t float_value_num, const float* vectors) {
DeleteIndex(CIndex index) {
auto status = CStatus();
try {
auto cIndex = (milvus::indexbuilder::IndexWrapper*)index;
auto dim = cIndex->dim();
auto row_nums = float_value_num / dim;
auto ds = milvus::knowhere::GenDataset(row_nums, dim, vectors);
cIndex->BuildWithoutIds(ds);
AssertInfo(index, "failed to delete index, passed index was null");
auto cIndex = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
delete cIndex;
#ifdef __linux__
malloc_trim(0);
#endif
status.error_code = Success;
status.error_msg = "";
} catch (std::exception& e) {
......@@ -77,14 +64,16 @@ BuildFloatVecIndexWithoutIds(CIndex index, int64_t float_value_num, const float*
}
CStatus
BuildBinaryVecIndexWithoutIds(CIndex index, int64_t data_size, const uint8_t* vectors) {
BuildFloatVecIndex(CIndex index, int64_t float_value_num, const float* vectors) {
auto status = CStatus();
try {
auto cIndex = (milvus::indexbuilder::IndexWrapper*)index;
AssertInfo(index, "failed to build float vector index, passed index was null");
auto real_index = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
auto cIndex = dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
auto dim = cIndex->dim();
auto row_nums = (data_size * 8) / dim;
auto row_nums = float_value_num / dim;
auto ds = milvus::knowhere::GenDataset(row_nums, dim, vectors);
cIndex->BuildWithoutIds(ds);
cIndex->Build(ds);
status.error_code = Success;
status.error_msg = "";
} catch (std::exception& e) {
......@@ -95,12 +84,16 @@ BuildBinaryVecIndexWithoutIds(CIndex index, int64_t data_size, const uint8_t* ve
}
CStatus
SerializeToBinarySet(CIndex index, CBinarySet* c_binary_set) {
BuildBinaryVecIndex(CIndex index, int64_t data_size, const uint8_t* vectors) {
auto status = CStatus();
try {
auto cIndex = (milvus::indexbuilder::IndexWrapper*)index;
auto binary = cIndex->SerializeBinarySet();
*c_binary_set = binary.release();
AssertInfo(index, "failed to build binary vector index, passed index was null");
auto real_index = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
auto cIndex = dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
auto dim = cIndex->dim();
auto row_nums = (data_size * 8) / dim;
auto ds = milvus::knowhere::GenDataset(row_nums, dim, vectors);
cIndex->Build(ds);
status.error_code = Success;
status.error_msg = "";
} catch (std::exception& e) {
......@@ -110,13 +103,22 @@ SerializeToBinarySet(CIndex index, CBinarySet* c_binary_set) {
return status;
}
// field_data:
// 1, serialized proto::schema::BoolArray, if type is bool;
// 2, serialized proto::schema::StringArray, if type is string;
// 3, raw pointer, if type is of fundamental except bool type;
// TODO: optimize here if necessary.
CStatus
SerializeToSlicedBuffer(CIndex index, CBinary* c_binary) {
BuildScalarIndex(CIndex c_index, int64_t size, const void* field_data) {
auto status = CStatus();
try {
auto cIndex = (milvus::indexbuilder::IndexWrapper*)index;
auto binary = cIndex->Serialize();
*c_binary = binary.release();
AssertInfo(c_index, "failed to build scalar index, passed index was null");
auto real_index = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(c_index);
const int64_t dim = 8; // not important here
auto dataset = milvus::knowhere::GenDataset(size, dim, field_data);
real_index->Build(dataset);
status.error_code = Success;
status.error_msg = "";
} catch (std::exception& e) {
......@@ -126,31 +128,14 @@ SerializeToSlicedBuffer(CIndex index, CBinary* c_binary) {
return status;
}
int64_t
GetCBinarySize(CBinary c_binary) {
auto cBinary = (milvus::indexbuilder::IndexWrapper::Binary*)c_binary;
return cBinary->data.size();
}
// Note: the memory of data has been allocated outside
void
GetCBinaryData(CBinary c_binary, void* data) {
auto cBinary = (milvus::indexbuilder::IndexWrapper::Binary*)c_binary;
memcpy(data, cBinary->data.data(), cBinary->data.size());
}
void
DeleteCBinary(CBinary c_binary) {
auto cBinary = (milvus::indexbuilder::IndexWrapper::Binary*)c_binary;
delete cBinary;
}
CStatus
LoadFromSlicedBuffer(CIndex index, const char* serialized_sliced_blob_buffer, int32_t size) {
SerializeIndexToBinarySet(CIndex index, CBinarySet* c_binary_set) {
auto status = CStatus();
try {
auto cIndex = (milvus::indexbuilder::IndexWrapper*)index;
cIndex->Load(serialized_sliced_blob_buffer, size);
AssertInfo(index, "failed to serialize index to binary set, passed index was null");
auto real_index = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
auto binary = std::make_unique<milvus::knowhere::BinarySet>(real_index->Serialize());
*c_binary_set = binary.release();
status.error_code = Success;
status.error_msg = "";
} catch (std::exception& e) {
......@@ -161,12 +146,13 @@ LoadFromSlicedBuffer(CIndex index, const char* serialized_sliced_blob_buffer, in
}
CStatus
LoadFromBinarySet(CIndex index, CBinarySet c_binary_set) {
LoadIndexFromBinarySet(CIndex index, CBinarySet c_binary_set) {
auto status = CStatus();
try {
auto cIndex = (milvus::indexbuilder::IndexWrapper*)index;
auto binary_set = (milvus::knowhere::BinarySet*)c_binary_set;
cIndex->LoadFromBinarySet(*binary_set);
AssertInfo(index, "failed to load index from binary set, passed index was null");
auto real_index = reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
auto binary_set = reinterpret_cast<milvus::knowhere::BinarySet*>(c_binary_set);
real_index->Load(*binary_set);
status.error_code = Success;
status.error_msg = "";
} catch (std::exception& e) {
......@@ -180,7 +166,7 @@ CStatus
QueryOnFloatVecIndex(CIndex index, int64_t float_value_num, const float* vectors, CIndexQueryResult* res) {
auto status = CStatus();
try {
auto cIndex = (milvus::indexbuilder::IndexWrapper*)index;
auto cIndex = (milvus::indexbuilder::VecIndexCreator*)index;
auto dim = cIndex->dim();
auto row_nums = float_value_num / dim;
auto query_ds = milvus::knowhere::GenDataset(row_nums, dim, vectors);
......@@ -204,7 +190,7 @@ QueryOnFloatVecIndexWithParam(CIndex index,
CIndexQueryResult* res) {
auto status = CStatus();
try {
auto cIndex = (milvus::indexbuilder::IndexWrapper*)index;
auto cIndex = (milvus::indexbuilder::VecIndexCreator*)index;
auto dim = cIndex->dim();
auto row_nums = float_value_num / dim;
auto query_ds = milvus::knowhere::GenDataset(row_nums, dim, vectors);
......@@ -224,7 +210,7 @@ CStatus
QueryOnBinaryVecIndex(CIndex index, int64_t data_size, const uint8_t* vectors, CIndexQueryResult* res) {
auto status = CStatus();
try {
auto cIndex = (milvus::indexbuilder::IndexWrapper*)index;
auto cIndex = (milvus::indexbuilder::VecIndexCreator*)index;
auto dim = cIndex->dim();
auto row_nums = (data_size * 8) / dim;
auto query_ds = milvus::knowhere::GenDataset(row_nums, dim, vectors);
......@@ -248,7 +234,7 @@ QueryOnBinaryVecIndexWithParam(CIndex index,
CIndexQueryResult* res) {
auto status = CStatus();
try {
auto cIndex = (milvus::indexbuilder::IndexWrapper*)index;
auto cIndex = (milvus::indexbuilder::VecIndexCreator*)index;
auto dim = cIndex->dim();
auto row_nums = (data_size * 8) / dim;
auto query_ds = milvus::knowhere::GenDataset(row_nums, dim, vectors);
......@@ -268,7 +254,7 @@ CStatus
CreateQueryResult(CIndexQueryResult* res) {
auto status = CStatus();
try {
auto query_result = std::make_unique<milvus::indexbuilder::IndexWrapper::QueryResult>();
auto query_result = std::make_unique<milvus::indexbuilder::VecIndexCreator::QueryResult>();
*res = query_result.release();
status.error_code = Success;
......@@ -282,19 +268,19 @@ CreateQueryResult(CIndexQueryResult* res) {
int64_t
NqOfQueryResult(CIndexQueryResult res) {
auto c_res = (milvus::indexbuilder::IndexWrapper::QueryResult*)res;
auto c_res = (milvus::indexbuilder::VecIndexCreator::QueryResult*)res;
return c_res->nq;
}
int64_t
TopkOfQueryResult(CIndexQueryResult res) {
auto c_res = (milvus::indexbuilder::IndexWrapper::QueryResult*)res;
auto c_res = (milvus::indexbuilder::VecIndexCreator::QueryResult*)res;
return c_res->topk;
}
void
GetIdsOfQueryResult(CIndexQueryResult res, int64_t* ids) {
auto c_res = (milvus::indexbuilder::IndexWrapper::QueryResult*)res;
auto c_res = (milvus::indexbuilder::VecIndexCreator::QueryResult*)res;
auto nq = c_res->nq;
auto k = c_res->topk;
// TODO: how could we avoid memory copy whenever this called
......@@ -303,7 +289,7 @@ GetIdsOfQueryResult(CIndexQueryResult res, int64_t* ids) {
void
GetDistancesOfQueryResult(CIndexQueryResult res, float* distances) {
auto c_res = (milvus::indexbuilder::IndexWrapper::QueryResult*)res;
auto c_res = (milvus::indexbuilder::VecIndexCreator::QueryResult*)res;
auto nq = c_res->nq;
auto k = c_res->topk;
// TODO: how could we avoid memory copy whenever this called
......@@ -314,7 +300,7 @@ CStatus
DeleteIndexQueryResult(CIndexQueryResult res) {
auto status = CStatus();
try {
auto c_res = (milvus::indexbuilder::IndexWrapper::QueryResult*)res;
auto c_res = (milvus::indexbuilder::VecIndexCreator::QueryResult*)res;
delete c_res;
status.error_code = Success;
......@@ -325,8 +311,3 @@ DeleteIndexQueryResult(CIndexQueryResult res) {
}
return status;
}
void
DeleteByteArray(const char* array) {
delete[] array;
}
......@@ -16,50 +16,38 @@ extern "C" {
#endif
#include <stdint.h>
#include "segcore/collection_c.h"
#include "common/type_c.h"
#include "common/vector_index_c.h"
typedef void* CIndex;
typedef void* CIndexQueryResult;
typedef void* CBinary;
// TODO: how could we pass map between go and c++ more efficiently?
// Solution: using Protobuf instead of JSON, this way significantly increase programming efficiency
#include "indexbuilder/type_c.h"
CStatus
CreateIndex(const char* serialized_type_params, const char* serialized_index_params, CIndex* res_index);
void
DeleteIndex(CIndex index);
CreateIndex(enum DataType dtype,
const char* serialized_type_params,
const char* serialized_index_params,
CIndex* res_index);
CStatus
BuildFloatVecIndexWithoutIds(CIndex index, int64_t float_value_num, const float* vectors);
DeleteIndex(CIndex index);
CStatus
BuildBinaryVecIndexWithoutIds(CIndex index, int64_t data_size, const uint8_t* vectors);
BuildFloatVecIndex(CIndex index, int64_t float_value_num, const float* vectors);
CStatus
SerializeToSlicedBuffer(CIndex index, CBinary* c_binary);
BuildBinaryVecIndex(CIndex index, int64_t data_size, const uint8_t* vectors);
// field_data:
// 1, serialized proto::schema::BoolArray, if type is bool;
// 2, serialized proto::schema::StringArray, if type is string;
// 3, raw pointer, if type is of fundamental except bool type;
// TODO: optimize here if necessary.
CStatus
SerializeToBinarySet(CIndex index, CBinarySet* c_binary_set);
int64_t
GetCBinarySize(CBinary c_binary);
// Note: the memory of data is allocated outside
void
GetCBinaryData(CBinary c_binary, void* data);
void
DeleteCBinary(CBinary c_binary);
BuildScalarIndex(CIndex c_index, int64_t size, const void* field_data);
CStatus
LoadFromSlicedBuffer(CIndex index, const char* serialized_sliced_blob_buffer, int32_t size);
SerializeIndexToBinarySet(CIndex index, CBinarySet* c_binary_set);
CStatus
LoadFromBinarySet(CIndex index, CBinarySet c_binary_set);
LoadIndexFromBinarySet(CIndex index, CBinarySet c_binary_set);
CStatus
QueryOnFloatVecIndex(CIndex index, int64_t float_value_num, const float* vectors, CIndexQueryResult* res);
......@@ -99,9 +87,6 @@ GetDistancesOfQueryResult(CIndexQueryResult res, float* distances);
CStatus
DeleteIndexQueryResult(CIndexQueryResult res);
void
DeleteByteArray(const char* array);
#ifdef __cplusplus
};
#endif
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
// pure C don't support that we use schemapb.DataType directly.
// Note: the value of all enumerations must match the corresponding schemapb.DataType.
// TODO: what if there are increments in schemapb.DataType.
enum DataType {
None = 0,
Bool = 1,
Int8 = 2,
Int16 = 3,
Int32 = 4,
Int64 = 5,
Float = 10,
Double = 11,
String = 20,
BinaryVector = 100,
FloatVector = 101,
};
typedef void* CIndex;
typedef void* CIndexQueryResult;
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "indexbuilder/utils.h"
#include <algorithm>
#include <string>
#include <tuple>
#include <vector>
#include <knowhere/index/vector_index/helpers/IndexParameter.h>
#include "knowhere/index/IndexType.h"
namespace milvus::indexbuilder {
std::vector<std::string>
NM_List() {
static std::vector<std::string> ret{
milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT,
milvus::knowhere::IndexEnum::INDEX_NSG,
milvus::knowhere::IndexEnum::INDEX_RHNSWFlat,
};
return ret;
}
std::vector<std::string>
BIN_List() {
static std::vector<std::string> ret{
milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP,
milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT,
};
return ret;
}
std::vector<std::string>
Need_ID_List() {
static std::vector<std::string> ret{
// milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT,
// milvus::knowhere::IndexEnum::INDEX_NSG,
};
return ret;
}
std::vector<std::string>
Need_BuildAll_list() {
static std::vector<std::string> ret{
milvus::knowhere::IndexEnum::INDEX_NSG,
};
return ret;
}
std::vector<std::tuple<std::string, std::string>>
unsupported_index_combinations() {
static std::vector<std::tuple<std::string, std::string>> ret{
std::make_tuple(std::string(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT), std::string(knowhere::Metric::L2)),
};
return ret;
}
template <typename T>
bool
is_in_list(const T& t, std::function<std::vector<T>()> list_func) {
auto l = list_func();
return std::find(l.begin(), l.end(), t) != l.end();
}
bool
is_in_bin_list(const milvus::knowhere::IndexType& index_type) {
return is_in_list<std::string>(index_type, BIN_List);
}
bool
is_in_nm_list(const milvus::knowhere::IndexType& index_type) {
return is_in_list<std::string>(index_type, NM_List);
}
bool
is_in_need_build_all_list(const milvus::knowhere::IndexType& index_type) {
return is_in_list<std::string>(index_type, Need_BuildAll_list);
}
bool
is_in_need_id_list(const milvus::knowhere::IndexType& index_type) {
return is_in_list<std::string>(index_type, Need_ID_List);
}
bool
is_unsupported(const milvus::knowhere::IndexType& index_type, const milvus::knowhere::MetricType& metric_type) {
return is_in_list<std::tuple<std::string, std::string>>(std::make_tuple(index_type, metric_type),
unsupported_index_combinations);
}
} // namespace milvus::indexbuilder
......@@ -15,87 +15,45 @@
#include <string>
#include <tuple>
#include <vector>
#include <functional>
#include <knowhere/common/Typedef.h>
#include "knowhere/index/IndexType.h"
namespace milvus::indexbuilder {
std::vector<std::string>
NM_List() {
static std::vector<std::string> ret{
milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT,
milvus::knowhere::IndexEnum::INDEX_NSG,
milvus::knowhere::IndexEnum::INDEX_RHNSWFlat,
};
return ret;
}
NM_List();
std::vector<std::string>
BIN_List() {
static std::vector<std::string> ret{
milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP,
milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT,
};
return ret;
}
BIN_List();
std::vector<std::string>
Need_ID_List() {
static std::vector<std::string> ret{
// milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT,
// milvus::knowhere::IndexEnum::INDEX_NSG,
};
return ret;
}
Need_ID_List();
std::vector<std::string>
Need_BuildAll_list() {
static std::vector<std::string> ret{
milvus::knowhere::IndexEnum::INDEX_NSG,
};
return ret;
}
Need_BuildAll_list();
std::vector<std::tuple<std::string, std::string>>
unsupported_index_combinations() {
static std::vector<std::tuple<std::string, std::string>> ret{
std::make_tuple(std::string(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT), std::string(knowhere::Metric::L2)),
};
return ret;
}
unsupported_index_combinations();
template <typename T>
bool
is_in_list(const T& t, std::function<std::vector<T>()> list_func) {
auto l = list_func();
return std::find(l.begin(), l.end(), t) != l.end();
}
is_in_list(const T& t, std::function<std::vector<T>()> list_func);
bool
is_in_bin_list(const milvus::knowhere::IndexType& index_type) {
return is_in_list<std::string>(index_type, BIN_List);
}
is_in_bin_list(const milvus::knowhere::IndexType& index_type);
bool
is_in_nm_list(const milvus::knowhere::IndexType& index_type) {
return is_in_list<std::string>(index_type, NM_List);
}
is_in_nm_list(const milvus::knowhere::IndexType& index_type);
bool
is_in_need_build_all_list(const milvus::knowhere::IndexType& index_type) {
return is_in_list<std::string>(index_type, Need_BuildAll_list);
}
is_in_need_build_all_list(const milvus::knowhere::IndexType& index_type);
bool
is_in_need_id_list(const milvus::knowhere::IndexType& index_type) {
return is_in_list<std::string>(index_type, Need_ID_List);
}
is_in_need_id_list(const milvus::knowhere::IndexType& index_type);
bool
is_unsupported(const milvus::knowhere::IndexType& index_type, const milvus::knowhere::MetricType& metric_type) {
return is_in_list<std::tuple<std::string, std::string>>(std::make_tuple(index_type, metric_type),
unsupported_index_combinations);
}
is_unsupported(const milvus::knowhere::IndexType& index_type, const milvus::knowhere::MetricType& metric_type);
} // namespace milvus::indexbuilder
......@@ -31,6 +31,9 @@
#include <google/protobuf/message.h>
#include <google/protobuf/repeated_field.h> // IWYU pragma: export
#include <google/protobuf/extension_set.h> // IWYU pragma: export
#include <google/protobuf/map.h> // IWYU pragma: export
#include <google/protobuf/map_entry.h>
#include <google/protobuf/map_field_inl.h>
#include <google/protobuf/unknown_field_set.h>
#include "common.pb.h"
// @@protoc_insertion_point(includes)
......@@ -48,7 +51,7 @@ struct TableStruct_index_5fcgo_5fmsg_2eproto {
PROTOBUF_SECTION_VARIABLE(protodesc_cold);
static const ::PROTOBUF_NAMESPACE_ID::internal::AuxillaryParseTableField aux[]
PROTOBUF_SECTION_VARIABLE(protodesc_cold);
static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[5]
static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[7]
PROTOBUF_SECTION_VARIABLE(protodesc_cold);
static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[];
static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[];
......@@ -70,6 +73,12 @@ extern IndexParamsDefaultTypeInternal _IndexParams_default_instance_;
class MapParams;
class MapParamsDefaultTypeInternal;
extern MapParamsDefaultTypeInternal _MapParams_default_instance_;
class MapParamsV2;
class MapParamsV2DefaultTypeInternal;
extern MapParamsV2DefaultTypeInternal _MapParamsV2_default_instance_;
class MapParamsV2_ParamsEntry_DoNotUse;
class MapParamsV2_ParamsEntry_DoNotUseDefaultTypeInternal;
extern MapParamsV2_ParamsEntry_DoNotUseDefaultTypeInternal _MapParamsV2_ParamsEntry_DoNotUse_default_instance_;
class TypeParams;
class TypeParamsDefaultTypeInternal;
extern TypeParamsDefaultTypeInternal _TypeParams_default_instance_;
......@@ -81,6 +90,8 @@ template<> ::milvus::proto::indexcgo::Binary* Arena::CreateMaybeMessage<::milvus
template<> ::milvus::proto::indexcgo::BinarySet* Arena::CreateMaybeMessage<::milvus::proto::indexcgo::BinarySet>(Arena*);
template<> ::milvus::proto::indexcgo::IndexParams* Arena::CreateMaybeMessage<::milvus::proto::indexcgo::IndexParams>(Arena*);
template<> ::milvus::proto::indexcgo::MapParams* Arena::CreateMaybeMessage<::milvus::proto::indexcgo::MapParams>(Arena*);
template<> ::milvus::proto::indexcgo::MapParamsV2* Arena::CreateMaybeMessage<::milvus::proto::indexcgo::MapParamsV2>(Arena*);
template<> ::milvus::proto::indexcgo::MapParamsV2_ParamsEntry_DoNotUse* Arena::CreateMaybeMessage<::milvus::proto::indexcgo::MapParamsV2_ParamsEntry_DoNotUse>(Arena*);
template<> ::milvus::proto::indexcgo::TypeParams* Arena::CreateMaybeMessage<::milvus::proto::indexcgo::TypeParams>(Arena*);
PROTOBUF_NAMESPACE_CLOSE
namespace milvus {
......@@ -500,6 +511,180 @@ class MapParams :
};
// -------------------------------------------------------------------
class MapParamsV2_ParamsEntry_DoNotUse : public ::PROTOBUF_NAMESPACE_ID::internal::MapEntry<MapParamsV2_ParamsEntry_DoNotUse,
std::string, std::string,
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_STRING,
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_STRING,
0 > {
public:
typedef ::PROTOBUF_NAMESPACE_ID::internal::MapEntry<MapParamsV2_ParamsEntry_DoNotUse,
std::string, std::string,
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_STRING,
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_STRING,
0 > SuperType;
MapParamsV2_ParamsEntry_DoNotUse();
MapParamsV2_ParamsEntry_DoNotUse(::PROTOBUF_NAMESPACE_ID::Arena* arena);
void MergeFrom(const MapParamsV2_ParamsEntry_DoNotUse& other);
static const MapParamsV2_ParamsEntry_DoNotUse* internal_default_instance() { return reinterpret_cast<const MapParamsV2_ParamsEntry_DoNotUse*>(&_MapParamsV2_ParamsEntry_DoNotUse_default_instance_); }
static bool ValidateKey(std::string* s) {
return ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String(s->data(), s->size(), ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::PARSE, "milvus.proto.indexcgo.MapParamsV2.ParamsEntry.key");
}
static bool ValidateValue(std::string* s) {
return ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String(s->data(), s->size(), ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::PARSE, "milvus.proto.indexcgo.MapParamsV2.ParamsEntry.value");
}
void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& other) final;
::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final;
private:
static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() {
::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_index_5fcgo_5fmsg_2eproto);
return ::descriptor_table_index_5fcgo_5fmsg_2eproto.file_level_metadata[3];
}
public:
};
// -------------------------------------------------------------------
class MapParamsV2 :
public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:milvus.proto.indexcgo.MapParamsV2) */ {
public:
MapParamsV2();
virtual ~MapParamsV2();
MapParamsV2(const MapParamsV2& from);
MapParamsV2(MapParamsV2&& from) noexcept
: MapParamsV2() {
*this = ::std::move(from);
}
inline MapParamsV2& operator=(const MapParamsV2& from) {
CopyFrom(from);
return *this;
}
inline MapParamsV2& operator=(MapParamsV2&& from) noexcept {
if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) {
if (this != &from) InternalSwap(&from);
} else {
CopyFrom(from);
}
return *this;
}
static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() {
return GetDescriptor();
}
static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() {
return GetMetadataStatic().descriptor;
}
static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() {
return GetMetadataStatic().reflection;
}
static const MapParamsV2& default_instance();
static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY
static inline const MapParamsV2* internal_default_instance() {
return reinterpret_cast<const MapParamsV2*>(
&_MapParamsV2_default_instance_);
}
static constexpr int kIndexInFileMessages =
4;
friend void swap(MapParamsV2& a, MapParamsV2& b) {
a.Swap(&b);
}
inline void Swap(MapParamsV2* other) {
if (other == this) return;
InternalSwap(other);
}
// implements Message ----------------------------------------------
inline MapParamsV2* New() const final {
return CreateMaybeMessage<MapParamsV2>(nullptr);
}
MapParamsV2* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final {
return CreateMaybeMessage<MapParamsV2>(arena);
}
void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final;
void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final;
void CopyFrom(const MapParamsV2& from);
void MergeFrom(const MapParamsV2& from);
PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final;
bool IsInitialized() const final;
size_t ByteSizeLong() const final;
#if GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER
const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final;
#else
bool MergePartialFromCodedStream(
::PROTOBUF_NAMESPACE_ID::io::CodedInputStream* input) final;
#endif // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER
void SerializeWithCachedSizes(
::PROTOBUF_NAMESPACE_ID::io::CodedOutputStream* output) const final;
::PROTOBUF_NAMESPACE_ID::uint8* InternalSerializeWithCachedSizesToArray(
::PROTOBUF_NAMESPACE_ID::uint8* target) const final;
int GetCachedSize() const final { return _cached_size_.Get(); }
private:
inline void SharedCtor();
inline void SharedDtor();
void SetCachedSize(int size) const final;
void InternalSwap(MapParamsV2* other);
friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata;
static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() {
return "milvus.proto.indexcgo.MapParamsV2";
}
private:
inline ::PROTOBUF_NAMESPACE_ID::Arena* GetArenaNoVirtual() const {
return nullptr;
}
inline void* MaybeArenaPtr() const {
return nullptr;
}
public:
::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final;
private:
static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() {
::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_index_5fcgo_5fmsg_2eproto);
return ::descriptor_table_index_5fcgo_5fmsg_2eproto.file_level_metadata[kIndexInFileMessages];
}
public:
// nested types ----------------------------------------------------
// accessors -------------------------------------------------------
enum : int {
kParamsFieldNumber = 1,
};
// map<string, string> params = 1;
int params_size() const;
void clear_params();
const ::PROTOBUF_NAMESPACE_ID::Map< std::string, std::string >&
params() const;
::PROTOBUF_NAMESPACE_ID::Map< std::string, std::string >*
mutable_params();
// @@protoc_insertion_point(class_scope:milvus.proto.indexcgo.MapParamsV2)
private:
class _Internal;
::PROTOBUF_NAMESPACE_ID::internal::InternalMetadataWithArena _internal_metadata_;
::PROTOBUF_NAMESPACE_ID::internal::MapField<
MapParamsV2_ParamsEntry_DoNotUse,
std::string, std::string,
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_STRING,
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_STRING,
0 > params_;
mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_;
friend struct ::TableStruct_index_5fcgo_5fmsg_2eproto;
};
// -------------------------------------------------------------------
class Binary :
public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:milvus.proto.indexcgo.Binary) */ {
public:
......@@ -542,7 +727,7 @@ class Binary :
&_Binary_default_instance_);
}
static constexpr int kIndexInFileMessages =
3;
5;
friend void swap(Binary& a, Binary& b) {
a.Swap(&b);
......@@ -692,7 +877,7 @@ class BinarySet :
&_BinarySet_default_instance_);
}
static constexpr int kIndexInFileMessages =
4;
6;
friend void swap(BinarySet& a, BinarySet& b) {
a.Swap(&b);
......@@ -887,6 +1072,30 @@ MapParams::params() const {
// -------------------------------------------------------------------
// -------------------------------------------------------------------
// MapParamsV2
// map<string, string> params = 1;
inline int MapParamsV2::params_size() const {
return params_.size();
}
inline void MapParamsV2::clear_params() {
params_.Clear();
}
inline const ::PROTOBUF_NAMESPACE_ID::Map< std::string, std::string >&
MapParamsV2::params() const {
// @@protoc_insertion_point(field_map:milvus.proto.indexcgo.MapParamsV2.params)
return params_.GetMap();
}
inline ::PROTOBUF_NAMESPACE_ID::Map< std::string, std::string >*
MapParamsV2::mutable_params() {
// @@protoc_insertion_point(field_mutable_map:milvus.proto.indexcgo.MapParamsV2.params)
return params_.MutableMap();
}
// -------------------------------------------------------------------
// Binary
// string key = 1;
......@@ -1036,6 +1245,10 @@ BinarySet::datas() const {
// -------------------------------------------------------------------
// -------------------------------------------------------------------
// -------------------------------------------------------------------
// @@protoc_insertion_point(namespace_scope)
......
此差异已折叠。
......@@ -5531,6 +5531,7 @@ class CreateIndexRequest :
kDbNameFieldNumber = 2,
kCollectionNameFieldNumber = 3,
kFieldNameFieldNumber = 4,
kIndexNameFieldNumber = 6,
kBaseFieldNumber = 1,
};
// repeated .milvus.proto.common.KeyValuePair extra_params = 5;
......@@ -5577,6 +5578,17 @@ class CreateIndexRequest :
std::string* release_field_name();
void set_allocated_field_name(std::string* field_name);
// string index_name = 6;
void clear_index_name();
const std::string& index_name() const;
void set_index_name(const std::string& value);
void set_index_name(std::string&& value);
void set_index_name(const char* value);
void set_index_name(const char* value, size_t size);
std::string* mutable_index_name();
std::string* release_index_name();
void set_allocated_index_name(std::string* index_name);
// .milvus.proto.common.MsgBase base = 1;
bool has_base() const;
void clear_base();
......@@ -5594,6 +5606,7 @@ class CreateIndexRequest :
::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr db_name_;
::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr collection_name_;
::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr field_name_;
::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr index_name_;
::milvus::proto::common::MsgBase* base_;
mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_;
friend struct ::TableStruct_milvus_2eproto;
......@@ -18997,6 +19010,57 @@ CreateIndexRequest::extra_params() const {
return extra_params_;
}
// string index_name = 6;
inline void CreateIndexRequest::clear_index_name() {
index_name_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
}
inline const std::string& CreateIndexRequest::index_name() const {
// @@protoc_insertion_point(field_get:milvus.proto.milvus.CreateIndexRequest.index_name)
return index_name_.GetNoArena();
}
inline void CreateIndexRequest::set_index_name(const std::string& value) {
index_name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value);
// @@protoc_insertion_point(field_set:milvus.proto.milvus.CreateIndexRequest.index_name)
}
inline void CreateIndexRequest::set_index_name(std::string&& value) {
index_name_.SetNoArena(
&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value));
// @@protoc_insertion_point(field_set_rvalue:milvus.proto.milvus.CreateIndexRequest.index_name)
}
inline void CreateIndexRequest::set_index_name(const char* value) {
GOOGLE_DCHECK(value != nullptr);
index_name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value));
// @@protoc_insertion_point(field_set_char:milvus.proto.milvus.CreateIndexRequest.index_name)
}
inline void CreateIndexRequest::set_index_name(const char* value, size_t size) {
index_name_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(),
::std::string(reinterpret_cast<const char*>(value), size));
// @@protoc_insertion_point(field_set_pointer:milvus.proto.milvus.CreateIndexRequest.index_name)
}
inline std::string* CreateIndexRequest::mutable_index_name() {
// @@protoc_insertion_point(field_mutable:milvus.proto.milvus.CreateIndexRequest.index_name)
return index_name_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
}
inline std::string* CreateIndexRequest::release_index_name() {
// @@protoc_insertion_point(field_release:milvus.proto.milvus.CreateIndexRequest.index_name)
return index_name_.ReleaseNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
}
inline void CreateIndexRequest::set_allocated_index_name(std::string* index_name) {
if (index_name != nullptr) {
} else {
}
index_name_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), index_name);
// @@protoc_insertion_point(field_set_allocated:milvus.proto.milvus.CreateIndexRequest.index_name)
}
// -------------------------------------------------------------------
// DescribeIndexRequest
......@@ -15,6 +15,7 @@ include_directories(${CMAKE_HOME_DIRECTORY}/src/thirdparty)
add_definitions(-DMILVUS_TEST_SEGCORE_YAML_PATH="${CMAKE_SOURCE_DIR}/unittest/test_utils/test_segcore.yaml")
if (LINUX)
# TODO: better to use ls/find pattern
set(MILVUS_TEST_FILES
init_gtest.cpp
test_binary.cpp
......@@ -38,10 +39,14 @@ if (LINUX)
test_conf_adapter_mgr.cpp
test_similarity_corelation.cpp
test_utils.cpp
test_scalar_index_creator.cpp
test_index_c_api.cpp
)
# check if memory leak exists in index builder
set(INDEX_BUILDER_TEST_FILES
test_index_wrapper.cpp
test_scalar_index_creator.cpp
test_index_c_api.cpp
)
add_executable(index_builder_test
......
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <google/protobuf/text_format.h>
#include <gtest/gtest.h>
#include <map>
#include <tuple>
#include <knowhere/index/vector_index/helpers/IndexParameter.h>
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
#include <knowhere/index/vector_index/ConfAdapterMgr.h>
#include <knowhere/archive/KnowhereConfig.h>
#include "pb/index_cgo_msg.pb.h"
#include "indexbuilder/VecIndexCreator.h"
#include "indexbuilder/index_c.h"
#include "indexbuilder/utils.h"
#include "test_utils/DataGen.h"
#include "test_utils/indexbuilder_test_utils.h"
#include "indexbuilder/ScalarIndexCreator.h"
#include "indexbuilder/IndexFactory.h"
constexpr int NB = 10;
TEST(FloatVecIndex, All) {
auto index_type = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ;
auto metric_type = milvus::knowhere::Metric::L2;
indexcgo::TypeParams type_params;
indexcgo::IndexParams index_params;
std::tie(type_params, index_params) = generate_params(index_type, metric_type);
std::string type_params_str, index_params_str;
bool ok;
ok = google::protobuf::TextFormat::PrintToString(type_params, &type_params_str);
assert(ok);
ok = google::protobuf::TextFormat::PrintToString(index_params, &index_params_str);
assert(ok);
auto dataset = GenDataset(NB, metric_type, false);
auto xb_data = dataset.get_col<float>(0);
DataType dtype = FloatVector;
CIndex index;
CStatus status;
CBinarySet binary_set;
CIndex copy_index;
{
status = CreateIndex(dtype, type_params_str.c_str(), index_params_str.c_str(), &index);
ASSERT_EQ(Success, status.error_code);
}
{
status = BuildFloatVecIndex(index, NB * DIM, xb_data.data());
ASSERT_EQ(Success, status.error_code);
}
{
status = SerializeIndexToBinarySet(index, &binary_set);
ASSERT_EQ(Success, status.error_code);
}
{
status = CreateIndex(dtype, type_params_str.c_str(), index_params_str.c_str(), &copy_index);
ASSERT_EQ(Success, status.error_code);
}
{
status = LoadIndexFromBinarySet(copy_index, binary_set);
ASSERT_EQ(Success, status.error_code);
}
{
status = DeleteIndex(index);
ASSERT_EQ(Success, status.error_code);
}
{
status = DeleteIndex(copy_index);
ASSERT_EQ(Success, status.error_code);
}
}
TEST(BinaryVecIndex, All) {
auto index_type = milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT;
auto metric_type = milvus::knowhere::Metric::JACCARD;
indexcgo::TypeParams type_params;
indexcgo::IndexParams index_params;
std::tie(type_params, index_params) = generate_params(index_type, metric_type);
std::string type_params_str, index_params_str;
bool ok;
ok = google::protobuf::TextFormat::PrintToString(type_params, &type_params_str);
assert(ok);
ok = google::protobuf::TextFormat::PrintToString(index_params, &index_params_str);
assert(ok);
auto dataset = GenDataset(NB, metric_type, true);
auto xb_data = dataset.get_col<uint8_t>(0);
DataType dtype = BinaryVector;
CIndex index;
CStatus status;
CBinarySet binary_set;
CIndex copy_index;
{
status = CreateIndex(dtype, type_params_str.c_str(), index_params_str.c_str(), &index);
ASSERT_EQ(Success, status.error_code);
}
{
status = BuildBinaryVecIndex(index, NB * DIM / 8, xb_data.data());
ASSERT_EQ(Success, status.error_code);
}
{
status = SerializeIndexToBinarySet(index, &binary_set);
ASSERT_EQ(Success, status.error_code);
}
{
status = CreateIndex(dtype, type_params_str.c_str(), index_params_str.c_str(), &copy_index);
ASSERT_EQ(Success, status.error_code);
}
{
status = LoadIndexFromBinarySet(copy_index, binary_set);
ASSERT_EQ(Success, status.error_code);
}
{
status = DeleteIndex(index);
ASSERT_EQ(Success, status.error_code);
}
{
status = DeleteIndex(copy_index);
ASSERT_EQ(Success, status.error_code);
}
}
TEST(CBoolIndexTest, All) {
schemapb::BoolArray half;
milvus::knowhere::DatasetPtr half_ds;
for (size_t i = 0; i < NB; i++) {
*(half.mutable_data()->Add()) = (i % 2) == 0;
}
half_ds = GenDsFromPB(half);
auto params = GenBoolParams();
for (const auto& tp : params) {
auto type_params = tp.first;
auto index_params = tp.second;
auto type_params_str = generate_type_params(type_params);
auto index_params_str = generate_index_params(index_params);
DataType dtype = Bool;
CIndex index;
CStatus status;
CBinarySet binary_set;
CIndex copy_index;
{
status = CreateIndex(dtype, type_params_str.c_str(), index_params_str.c_str(), &index);
ASSERT_EQ(Success, status.error_code);
}
{
status = BuildScalarIndex(index, half_ds->Get<int64_t>(milvus::knowhere::meta::ROWS),
half_ds->Get<const void*>(milvus::knowhere::meta::TENSOR));
ASSERT_EQ(Success, status.error_code);
}
{
status = SerializeIndexToBinarySet(index, &binary_set);
ASSERT_EQ(Success, status.error_code);
}
{
status = CreateIndex(dtype, type_params_str.c_str(), index_params_str.c_str(), &copy_index);
ASSERT_EQ(Success, status.error_code);
}
{
status = LoadIndexFromBinarySet(copy_index, binary_set);
ASSERT_EQ(Success, status.error_code);
}
{
status = DeleteIndex(index);
ASSERT_EQ(Success, status.error_code);
}
{
status = DeleteIndex(copy_index);
ASSERT_EQ(Success, status.error_code);
}
}
delete[](char*) half_ds->Get<const void*>(milvus::knowhere::meta::TENSOR);
}
// TODO: more scalar type.
TEST(CInt64IndexTest, All) {
auto arr = GenArr<int64_t>(NB);
auto params = GenParams<int64_t>();
for (const auto& tp : params) {
auto type_params = tp.first;
auto index_params = tp.second;
auto type_params_str = generate_type_params(type_params);
auto index_params_str = generate_index_params(index_params);
DataType dtype = Int64;
CIndex index;
CStatus status;
CBinarySet binary_set;
CIndex copy_index;
{
status = CreateIndex(dtype, type_params_str.c_str(), index_params_str.c_str(), &index);
ASSERT_EQ(Success, status.error_code);
}
{
status = BuildScalarIndex(index, arr.size(), arr.data());
ASSERT_EQ(Success, status.error_code);
}
{
status = SerializeIndexToBinarySet(index, &binary_set);
ASSERT_EQ(Success, status.error_code);
}
{
status = CreateIndex(dtype, type_params_str.c_str(), index_params_str.c_str(), &copy_index);
ASSERT_EQ(Success, status.error_code);
}
{
status = LoadIndexFromBinarySet(copy_index, binary_set);
ASSERT_EQ(Success, status.error_code);
}
{
status = DeleteIndex(index);
ASSERT_EQ(Success, status.error_code);
}
{
status = DeleteIndex(copy_index);
ASSERT_EQ(Success, status.error_code);
}
}
}
TEST(CStringIndexTest, All) {
auto strs = GenStrArr(NB);
schemapb::StringArray str_arr;
*str_arr.mutable_data() = {strs.begin(), strs.end()};
auto str_ds = GenDsFromPB(str_arr);
auto params = GenStringParams();
for (const auto& tp : params) {
auto type_params = tp.first;
auto index_params = tp.second;
auto type_params_str = generate_type_params(type_params);
auto index_params_str = generate_index_params(index_params);
DataType dtype = String;
CIndex index;
CStatus status;
CBinarySet binary_set;
CIndex copy_index;
{
status = CreateIndex(dtype, type_params_str.c_str(), index_params_str.c_str(), &index);
ASSERT_EQ(Success, status.error_code);
}
{
status = BuildScalarIndex(index, str_ds->Get<int64_t>(milvus::knowhere::meta::ROWS),
str_ds->Get<const void*>(milvus::knowhere::meta::TENSOR));
ASSERT_EQ(Success, status.error_code);
}
{
status = SerializeIndexToBinarySet(index, &binary_set);
ASSERT_EQ(Success, status.error_code);
}
{
status = CreateIndex(dtype, type_params_str.c_str(), index_params_str.c_str(), &copy_index);
ASSERT_EQ(Success, status.error_code);
}
{
status = LoadIndexFromBinarySet(copy_index, binary_set);
ASSERT_EQ(Success, status.error_code);
}
{
status = DeleteIndex(index);
ASSERT_EQ(Success, status.error_code);
}
{
status = DeleteIndex(copy_index);
ASSERT_EQ(Success, status.error_code);
}
}
delete[](char*) str_ds->Get<const void*>(milvus::knowhere::meta::TENSOR);
}
......@@ -18,7 +18,7 @@
#include <knowhere/index/vector_index/ConfAdapterMgr.h>
#include <knowhere/archive/KnowhereConfig.h>
#include "indexbuilder/IndexWrapper.h"
#include "indexbuilder/VecIndexCreator.h"
#include "indexbuilder/index_c.h"
#include "indexbuilder/utils.h"
#include "pb/index_cgo_msg.pb.h"
......@@ -191,7 +191,7 @@ TEST(BINFLAT, Build) {
}
void
print_query_result(const std::unique_ptr<milvus::indexbuilder::IndexWrapper::QueryResult>& result) {
print_query_result(const std::unique_ptr<milvus::indexbuilder::VecIndexCreator::QueryResult>& result) {
for (auto i = 0; i < result->nq; i++) {
printf("result of %dth query:\n", i);
for (auto j = 0; j < result->topk; j++) {
......@@ -228,7 +228,7 @@ TEST(BinIVFFlat, Build_and_Query) {
auto hit_ids = result->Get<int64_t*>(milvus::knowhere::meta::IDS);
auto distances = result->Get<float*>(milvus::knowhere::meta::DISTANCE);
auto query_res = std::make_unique<milvus::indexbuilder::IndexWrapper::QueryResult>();
auto query_res = std::make_unique<milvus::indexbuilder::VecIndexCreator::QueryResult>();
query_res->nq = nq;
query_res->topk = topk;
query_res->ids.resize(nq * topk);
......@@ -275,7 +275,7 @@ TEST(PQWrapper, Build) {
auto xb_data = dataset.get_col<float>(0);
auto xb_dataset = milvus::knowhere::GenDataset(NB, DIM, xb_data.data());
auto index =
std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str());
std::make_unique<milvus::indexbuilder::VecIndexCreator>(type_params_str.c_str(), index_params_str.c_str());
ASSERT_NO_THROW(index->BuildWithoutIds(xb_dataset));
}
......@@ -295,7 +295,7 @@ TEST(IVFFLATNMWrapper, Build) {
auto xb_data = dataset.get_col<float>(0);
auto xb_dataset = milvus::knowhere::GenDataset(NB, DIM, xb_data.data());
auto index =
std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str());
std::make_unique<milvus::indexbuilder::VecIndexCreator>(type_params_str.c_str(), index_params_str.c_str());
ASSERT_NO_THROW(index->BuildWithoutIds(xb_dataset));
}
......@@ -316,15 +316,22 @@ TEST(IVFFLATNMWrapper, Codec) {
auto xb_data = dataset.get_col<float>(0);
auto xb_dataset = milvus::knowhere::GenDataset(flat_nb, DIM, xb_data.data());
auto index_wrapper =
std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str());
std::make_unique<milvus::indexbuilder::VecIndexCreator>(type_params_str.c_str(), index_params_str.c_str());
ASSERT_NO_THROW(index_wrapper->BuildWithoutIds(xb_dataset));
auto binary = index_wrapper->Serialize();
auto binary_set = index_wrapper->Serialize();
auto copy_index_wrapper =
std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str());
ASSERT_NO_THROW(copy_index_wrapper->Load(binary->data.data(), binary->data.size()));
std::make_unique<milvus::indexbuilder::VecIndexCreator>(type_params_str.c_str(), index_params_str.c_str());
ASSERT_NO_THROW(copy_index_wrapper->Load(binary_set));
ASSERT_EQ(copy_index_wrapper->dim(), copy_index_wrapper->dim());
auto copy_binary = copy_index_wrapper->Serialize();
auto copy_binary_set = copy_index_wrapper->Serialize();
ASSERT_EQ(binary_set.binary_map_.size(), copy_binary_set.binary_map_.size());
for (const auto& [k, v] : binary_set.binary_map_) {
ASSERT_TRUE(copy_binary_set.Contains(k));
}
}
TEST(BinFlatWrapper, Build) {
......@@ -345,7 +352,7 @@ TEST(BinFlatWrapper, Build) {
std::iota(ids.begin(), ids.end(), 0);
auto xb_dataset = milvus::knowhere::GenDataset(NB, DIM, xb_data.data());
auto index =
std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str());
std::make_unique<milvus::indexbuilder::VecIndexCreator>(type_params_str.c_str(), index_params_str.c_str());
ASSERT_NO_THROW(index->BuildWithoutIds(xb_dataset));
// ASSERT_NO_THROW(index->BuildWithIds(xb_dataset));
}
......@@ -368,7 +375,7 @@ TEST(BinIdMapWrapper, Build) {
std::iota(ids.begin(), ids.end(), 0);
auto xb_dataset = milvus::knowhere::GenDataset(NB, DIM, xb_data.data());
auto index =
std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str());
std::make_unique<milvus::indexbuilder::VecIndexCreator>(type_params_str.c_str(), index_params_str.c_str());
ASSERT_NO_THROW(index->BuildWithoutIds(xb_dataset));
// ASSERT_NO_THROW(index->BuildWithIds(xb_dataset));
}
......@@ -399,48 +406,50 @@ INSTANTIATE_TEST_CASE_P(
TEST_P(IndexWrapperTest, Constructor) {
auto index =
std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str());
std::make_unique<milvus::indexbuilder::VecIndexCreator>(type_params_str.c_str(), index_params_str.c_str());
}
TEST_P(IndexWrapperTest, Dim) {
auto index =
std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str());
std::make_unique<milvus::indexbuilder::VecIndexCreator>(type_params_str.c_str(), index_params_str.c_str());
ASSERT_EQ(index->dim(), DIM);
}
TEST_P(IndexWrapperTest, BuildWithoutIds) {
auto index =
std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str());
std::make_unique<milvus::indexbuilder::VecIndexCreator>(type_params_str.c_str(), index_params_str.c_str());
ASSERT_NO_THROW(index->BuildWithoutIds(xb_dataset));
}
TEST_P(IndexWrapperTest, Codec) {
auto index_wrapper =
std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str());
std::make_unique<milvus::indexbuilder::VecIndexCreator>(type_params_str.c_str(), index_params_str.c_str());
ASSERT_NO_THROW(index_wrapper->BuildWithoutIds(xb_dataset));
auto binary = index_wrapper->Serialize();
auto binary_set = index_wrapper->Serialize();
auto copy_index_wrapper =
std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str());
ASSERT_NO_THROW(copy_index_wrapper->Load(binary->data.data(), binary->data.size()));
std::make_unique<milvus::indexbuilder::VecIndexCreator>(type_params_str.c_str(), index_params_str.c_str());
ASSERT_NO_THROW(copy_index_wrapper->Load(binary_set));
ASSERT_EQ(copy_index_wrapper->dim(), copy_index_wrapper->dim());
auto copy_binary = copy_index_wrapper->Serialize();
if (!milvus::indexbuilder::is_in_nm_list(index_type)) {
// binary may be not same due to uncertain internal map order
ASSERT_EQ(binary->data.size(), copy_binary->data.size());
ASSERT_EQ(binary->data, copy_binary->data);
auto copy_binary_set = copy_index_wrapper->Serialize();
ASSERT_EQ(binary_set.binary_map_.size(), copy_binary_set.binary_map_.size());
for (const auto& [k, v] : binary_set.binary_map_) {
ASSERT_TRUE(copy_binary_set.Contains(k));
}
}
TEST_P(IndexWrapperTest, Query) {
auto index_wrapper =
std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str());
std::make_unique<milvus::indexbuilder::VecIndexCreator>(type_params_str.c_str(), index_params_str.c_str());
index_wrapper->BuildWithoutIds(xb_dataset);
std::unique_ptr<milvus::indexbuilder::IndexWrapper::QueryResult> query_result = index_wrapper->Query(xq_dataset);
std::unique_ptr<milvus::indexbuilder::VecIndexCreator::QueryResult> query_result = index_wrapper->Query(xq_dataset);
ASSERT_EQ(query_result->topk, K);
ASSERT_EQ(query_result->nq, NQ);
ASSERT_EQ(query_result->distances.size(), query_result->topk * query_result->nq);
......
此差异已折叠。
......@@ -98,76 +98,76 @@ TEST(SegmentCoreTest, MockTest) {
// Test insert column-based data
TEST(SegmentCoreTest, MockTest2) {
using namespace milvus::segcore;
using namespace milvus::engine;
// schema
auto schema = std::make_shared<Schema>();
schema->AddDebugField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddDebugField("age", DataType::INT32);
// generate random row-based data
std::vector<char> row_data;
std::vector<Timestamp> timestamps;
std::vector<int64_t> uids;
int N = 10000; // number of records
std::default_random_engine e(67);
for (int i = 0; i < N; ++i) {
uids.push_back(100000 + i);
timestamps.push_back(0);
// append vec
float vec[16];
for (auto& x : vec) {
x = e() % 2000 * 0.001 - 1.0;
}
row_data.insert(row_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec));
int age = e() % 100;
row_data.insert(row_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age));
}
auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
assert(row_data.size() == line_sizeof * N);
int64_t size = N;
const int64_t* uids_raw = uids.data();
const Timestamp* timestamps_raw = timestamps.data();
std::vector<std::tuple<Timestamp, idx_t, int64_t>> ordering(size); // timestamp, pk, order_index
for (int i = 0; i < size; ++i) {
ordering[i] = std::make_tuple(timestamps_raw[i], uids_raw[i], i);
}
std::sort(ordering.begin(), ordering.end()); // sort according to timestamp
// convert row-based data to column-based data accordingly
auto sizeof_infos = schema->get_sizeof_infos();
std::vector<int> offset_infos(schema->size() + 1, 0);
std::partial_sum(sizeof_infos.begin(), sizeof_infos.end(), offset_infos.begin() + 1);
std::vector<aligned_vector<uint8_t>> entities(schema->size());
for (int fid = 0; fid < schema->size(); ++fid) {
auto len = sizeof_infos[fid];
entities[fid].resize(len * size);
}
auto raw_data = row_data.data();
std::vector<idx_t> sorted_uids(size);
std::vector<Timestamp> sorted_timestamps(size);
for (int index = 0; index < size; ++index) {
auto [t, uid, order_index] = ordering[index];
sorted_timestamps[index] = t;
sorted_uids[index] = uid;
for (int fid = 0; fid < schema->size(); ++fid) {
auto len = sizeof_infos[fid];
auto offset = offset_infos[fid];
auto src = raw_data + order_index * line_sizeof + offset;
auto dst = entities[fid].data() + index * len;
memcpy(dst, src, len);
}
}
// insert column-based data
ColumnBasedRawData data_chunk{entities, N};
auto segment = CreateGrowingSegment(schema);
auto reserved_begin = segment->PreInsert(N);
segment->Insert(reserved_begin, size, sorted_uids.data(), sorted_timestamps.data(), data_chunk);
using namespace milvus::segcore;
using namespace milvus::engine;
// schema
auto schema = std::make_shared<Schema>();
schema->AddDebugField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddDebugField("age", DataType::INT32);
// generate random row-based data
std::vector<char> row_data;
std::vector<Timestamp> timestamps;
std::vector<int64_t> uids;
int N = 10000; // number of records
std::default_random_engine e(67);
for (int i = 0; i < N; ++i) {
uids.push_back(100000 + i);
timestamps.push_back(0);
// append vec
float vec[16];
for (auto& x : vec) {
x = e() % 2000 * 0.001 - 1.0;
}
row_data.insert(row_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec));
int age = e() % 100;
row_data.insert(row_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age));
}
auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
assert(row_data.size() == line_sizeof * N);
int64_t size = N;
const int64_t* uids_raw = uids.data();
const Timestamp* timestamps_raw = timestamps.data();
std::vector<std::tuple<Timestamp, idx_t, int64_t>> ordering(size); // timestamp, pk, order_index
for (int i = 0; i < size; ++i) {
ordering[i] = std::make_tuple(timestamps_raw[i], uids_raw[i], i);
}
std::sort(ordering.begin(), ordering.end()); // sort according to timestamp
// convert row-based data to column-based data accordingly
auto sizeof_infos = schema->get_sizeof_infos();
std::vector<int> offset_infos(schema->size() + 1, 0);
std::partial_sum(sizeof_infos.begin(), sizeof_infos.end(), offset_infos.begin() + 1);
std::vector<aligned_vector<uint8_t>> entities(schema->size());
for (int fid = 0; fid < schema->size(); ++fid) {
auto len = sizeof_infos[fid];
entities[fid].resize(len * size);
}
auto raw_data = row_data.data();
std::vector<idx_t> sorted_uids(size);
std::vector<Timestamp> sorted_timestamps(size);
for (int index = 0; index < size; ++index) {
auto [t, uid, order_index] = ordering[index];
sorted_timestamps[index] = t;
sorted_uids[index] = uid;
for (int fid = 0; fid < schema->size(); ++fid) {
auto len = sizeof_infos[fid];
auto offset = offset_infos[fid];
auto src = raw_data + order_index * line_sizeof + offset;
auto dst = entities[fid].data() + index * len;
memcpy(dst, src, len);
}
}
// insert column-based data
ColumnBasedRawData data_chunk{entities, N};
auto segment = CreateGrowingSegment(schema);
auto reserved_begin = segment->PreInsert(N);
segment->Insert(reserved_begin, size, sorted_uids.data(), sorted_timestamps.data(), data_chunk);
}
TEST(SegmentCoreTest, SmallIndex) {
......
......@@ -22,18 +22,29 @@
#include <knowhere/index/vector_index/VecIndexFactory.h>
#include "pb/index_cgo_msg.pb.h"
#include "indexbuilder/IndexWrapper.h"
#include "indexbuilder/VecIndexCreator.h"
#include "indexbuilder/index_c.h"
#include "DataGen.h"
#include "indexbuilder/utils.h"
#include "indexbuilder/helper.h"
#define private public
#include "indexbuilder/ScalarIndexCreator.h"
constexpr int64_t DIM = 128;
constexpr int64_t DIM = 8;
constexpr int64_t NQ = 10;
constexpr int64_t K = 4;
#ifdef MILVUS_GPU_VERSION
int DEVICEID = 0;
#endif
namespace indexcgo = milvus::proto::indexcgo;
namespace schemapb = milvus::proto::schema;
using milvus::indexbuilder::MapParams;
using milvus::indexbuilder::ScalarIndexCreator;
using milvus::knowhere::scalar::OperatorType;
using ScalarTestParams = std::pair<MapParams, MapParams>;
namespace {
auto
generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowhere::MetricType& metric_type) {
......@@ -234,7 +245,7 @@ GenDataset(int64_t N, const milvus::knowhere::MetricType& metric_type, bool is_b
}
}
using QueryResultPtr = std::unique_ptr<milvus::indexbuilder::IndexWrapper::QueryResult>;
using QueryResultPtr = std::unique_ptr<milvus::indexbuilder::VecIndexCreator::QueryResult>;
void
PrintQueryResult(const QueryResultPtr& result) {
auto nq = result->nq;
......@@ -327,4 +338,116 @@ CheckDistances(const QueryResultPtr& result,
}
}
}
auto
generate_type_params(const MapParams& m) {
indexcgo::TypeParams p;
for (const auto& [k, v] : m) {
auto kv = p.add_params();
kv->set_key(k);
kv->set_value(v);
}
std::string str;
auto ok = google::protobuf::TextFormat::PrintToString(p, &str);
Assert(ok);
return str;
}
auto
generate_index_params(const MapParams& m) {
indexcgo::IndexParams p;
for (const auto& [k, v] : m) {
auto kv = p.add_params();
kv->set_key(k);
kv->set_value(v);
}
std::string str;
auto ok = google::protobuf::TextFormat::PrintToString(p, &str);
Assert(ok);
return str;
}
// TODO: std::is_arithmetic_v, hard to compare float point value. std::is_integral_v.
template <typename T, typename = typename std::enable_if_t<std::is_arithmetic_v<T>>>
inline auto
GenArr(int64_t n) {
auto max_i8 = std::numeric_limits<int8_t>::max() - 1;
std::vector<T> arr;
arr.resize(n);
for (int64_t i = 0; i < n; i++) {
arr[i] = static_cast<T>(rand() % max_i8);
}
std::sort(arr.begin(), arr.end());
return arr;
}
inline auto
GenStrArr(int64_t n) {
using T = std::string;
std::vector<T> arr;
arr.resize(n);
for (int64_t i = 0; i < n; i++) {
auto gen = std::to_string(std::rand());
arr[i] = gen;
}
std::sort(arr.begin(), arr.end());
return arr;
}
template <typename T, typename = typename std::enable_if_t<std::is_arithmetic_v<T>>>
inline std::vector<ScalarTestParams>
GenParams() {
std::vector<ScalarTestParams> ret;
ret.emplace_back(ScalarTestParams(MapParams(), {{"index_type", "inverted_index"}}));
ret.emplace_back(ScalarTestParams(MapParams(), {{"index_type", "flat"}}));
return ret;
}
std::vector<ScalarTestParams>
GenBoolParams() {
std::vector<ScalarTestParams> ret;
ret.emplace_back(ScalarTestParams(MapParams(), {{"index_type", "inverted_index"}}));
ret.emplace_back(ScalarTestParams(MapParams(), {{"index_type", "flat"}}));
return ret;
}
std::vector<ScalarTestParams>
GenStringParams() {
std::vector<ScalarTestParams> ret;
ret.emplace_back(ScalarTestParams(MapParams(), {{"index_type", "marisa-trie"}}));
return ret;
}
void
PrintMapParam(const ScalarTestParams& tp) {
for (const auto& [k, v] : tp.first) {
std::cout << "k: " << k << ", v: " << v << std::endl;
}
for (const auto& [k, v] : tp.second) {
std::cout << "k: " << k << ", v: " << v << std::endl;
}
}
void
PrintMapParams(const std::vector<ScalarTestParams>& tps) {
for (const auto& tp : tps) {
PrintMapParam(tp);
}
}
template <typename T>
inline void
build_index(const std::unique_ptr<ScalarIndexCreator<T>>& creator, const std::vector<T>& arr) {
const int64_t dim = 8; // not important here
auto dataset = milvus::knowhere::GenDataset(arr.size(), dim, arr.data());
creator->Build(dataset);
}
// memory generated by this function should be freed by the caller.
auto
GenDsFromPB(const google::protobuf::Message& msg) {
auto data = new char[msg.ByteSize()];
msg.SerializeToArray(data, msg.ByteSize());
return milvus::knowhere::GenDataset(msg.ByteSize(), 8, data);
}
} // namespace
......@@ -54,7 +54,10 @@ func estimateIndexSize(dim int64, numRows int64, dataType schemapb.DataType) (ui
return uint64(dim) / 8 * uint64(numRows), nil
}
errMsg := "the field to build index must be a vector field"
log.Error(errMsg)
return 0, errors.New(errMsg)
// TODO: optimize here.
return 0, nil
// errMsg := "the field to build index must be a vector field"
// log.Error(errMsg)
// return 0, errors.New(errMsg)
}
......@@ -79,6 +79,8 @@ func Test_estimateIndexSize(t *testing.T) {
assert.Equal(t, uint64(200), memorySize)
memorySize, err = estimateIndexSize(10, 100, schemapb.DataType_Float)
assert.Error(t, err)
assert.Nil(t, err)
assert.Equal(t, uint64(0), memorySize)
// assert.Error(t, err)
// assert.Equal(t, uint64(0), memorySize)
}
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package indexnode
/*
#cgo CFLAGS: -I${SRCDIR}/../core/output/include
#cgo darwin LDFLAGS: -L${SRCDIR}/../core/output/lib -lmilvus_common -lmilvus_indexbuilder -Wl,-rpath,"${SRCDIR}/../core/output/lib"
#cgo linux LDFLAGS: -L${SRCDIR}/../core/output/lib -lmilvus_common -lmilvus_indexbuilder -Wl,-rpath=${SRCDIR}/../core/output/lib
#cgo windows LDFLAGS: -L${SRCDIR}/../core/output/lib -lmilvus_common -lmilvus_indexbuilder -Wl,-rpath=${SRCDIR}/../core/output/lib
#include <stdlib.h> // free
#include "segcore/collection_c.h"
#include "indexbuilder/index_c.h"
#include "common/vector_index_c.h"
*/
import "C"
import (
"errors"
"fmt"
"path/filepath"
"runtime"
"unsafe"
"go.uber.org/zap"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/indexcgopb"
"github.com/milvus-io/milvus/internal/storage"
)
// Blob is an alias for the storage.Blob type
type Blob = storage.Blob
// Index is an interface used to call the interface to build the index task in 'C'.
type Index interface {
Serialize() ([]*Blob, error)
Load([]*Blob) error
BuildFloatVecIndexWithoutIds(vectors []float32) error
BuildBinaryVecIndexWithoutIds(vectors []byte) error
Delete() error
}
// CIndex is a pointer used to access 'CGO'.
type CIndex struct {
indexPtr C.CIndex
close bool
}
func GetBinarySetKeys(cBinarySet C.CBinarySet) ([]string, error) {
size := int(C.GetBinarySetSize(cBinarySet))
if size == 0 {
return nil, fmt.Errorf("BinarySet size is zero!")
}
datas := make([]unsafe.Pointer, size)
C.GetBinarySetKeys(cBinarySet, unsafe.Pointer(&datas[0]))
ret := make([]string, size)
for i := 0; i < size; i++ {
ret[i] = C.GoString((*C.char)(datas[i]))
}
return ret, nil
}
func GetBinarySetValue(cBinarySet C.CBinarySet, key string) ([]byte, error) {
cIndexKey := C.CString(key)
defer C.free(unsafe.Pointer(cIndexKey))
ret := C.GetBinarySetValueSize(cBinarySet, cIndexKey)
size := int(ret)
if size == 0 {
return nil, fmt.Errorf("GetBinarySetValueSize size is zero!")
}
value := make([]byte, size)
status := C.CopyBinarySetValue(unsafe.Pointer(&value[0]), cIndexKey, cBinarySet)
if err := HandleCStatus(&status, "CopyBinarySetValue failed"); err != nil {
return nil, err
}
return value, nil
}
func (index *CIndex) Serialize() ([]*Blob, error) {
var cBinarySet C.CBinarySet
status := C.SerializeToBinarySet(index.indexPtr, &cBinarySet)
defer func() {
if cBinarySet != nil {
C.DeleteBinarySet(cBinarySet)
}
}()
if err := HandleCStatus(&status, "SerializeToBinarySet failed"); err != nil {
return nil, err
}
keys, err := GetBinarySetKeys(cBinarySet)
if err != nil {
return nil, err
}
ret := make([]*Blob, 0)
for _, key := range keys {
value, err := GetBinarySetValue(cBinarySet, key)
if err != nil {
return nil, err
}
blob := &Blob{
Key: key,
Value: value,
}
ret = append(ret, blob)
}
return ret, nil
}
// Serialize serializes vector data into bytes data so that it can be accessed in 'C'.
/*
func (index *CIndex) Serialize() ([]*Blob, error) {
var cBinary C.CBinary
status := C.SerializeToSlicedBuffer(index.indexPtr, &cBinary)
defer func() {
if cBinary != nil {
C.DeleteCBinary(cBinary)
}
}()
if err := HandleCStatus(&status, "SerializeToSlicedBuffer failed"); err != nil {
return nil, err
}
binarySize := C.GetCBinarySize(cBinary)
binaryData := make([]byte, binarySize)
C.GetCBinaryData(cBinary, unsafe.Pointer(&binaryData[0]))
var blobs indexcgopb.BinarySet
err := proto.Unmarshal(binaryData, &blobs)
if err != nil {
return nil, err
}
ret := make([]*Blob, 0)
for _, data := range blobs.Datas {
ret = append(ret, &Blob{Key: data.Key, Value: data.Value})
}
return ret, nil
}
*/
func (index *CIndex) Load(blobs []*Blob) error {
var cBinarySet C.CBinarySet
status := C.NewBinarySet(&cBinarySet)
defer C.DeleteBinarySet(cBinarySet)
if err := HandleCStatus(&status, "CIndex Load2 NewBinarySet failed"); err != nil {
return err
}
for _, blob := range blobs {
key := blob.Key
byteIndex := blob.Value
indexPtr := unsafe.Pointer(&byteIndex[0])
indexLen := C.int64_t(len(byteIndex))
binarySetKey := filepath.Base(key)
log.Debug("", zap.String("index key", binarySetKey))
indexKey := C.CString(binarySetKey)
status = C.AppendIndexBinary(cBinarySet, indexPtr, indexLen, indexKey)
C.free(unsafe.Pointer(indexKey))
if err := HandleCStatus(&status, "CIndex Load AppendIndexBinary failed"); err != nil {
return err
}
}
status = C.LoadFromBinarySet(index.indexPtr, cBinarySet)
return HandleCStatus(&status, "AppendIndex failed")
}
// Load loads data from 'C'.
/*
func (index *CIndex) Load(blobs []*Blob) error {
binarySet := &indexcgopb.BinarySet{Datas: make([]*indexcgopb.Binary, 0)}
for _, blob := range blobs {
binarySet.Datas = append(binarySet.Datas, &indexcgopb.Binary{Key: blob.Key, Value: blob.Value})
}
datas, err2 := proto.Marshal(binarySet)
if err2 != nil {
return err2
}
status := C.LoadFromSlicedBuffer(index.indexPtr, (*C.char)(unsafe.Pointer(&datas[0])), (C.int32_t)(len(datas)))
return HandleCStatus(&status, "LoadFromSlicedBuffer failed")
}
*/
// BuildFloatVecIndexWithoutIds builds indexes for float vector.
func (index *CIndex) BuildFloatVecIndexWithoutIds(vectors []float32) error {
/*
CStatus
BuildFloatVecIndexWithoutIds(CIndex index, int64_t float_value_num, const float* vectors);
*/
log.Debug("before BuildFloatVecIndexWithoutIds")
status := C.BuildFloatVecIndexWithoutIds(index.indexPtr, (C.int64_t)(len(vectors)), (*C.float)(&vectors[0]))
return HandleCStatus(&status, "BuildFloatVecIndexWithoutIds failed")
}
// BuildBinaryVecIndexWithoutIds builds indexes for binary vector.
func (index *CIndex) BuildBinaryVecIndexWithoutIds(vectors []byte) error {
/*
CStatus
BuildBinaryVecIndexWithoutIds(CIndex index, int64_t data_size, const uint8_t* vectors);
*/
status := C.BuildBinaryVecIndexWithoutIds(index.indexPtr, (C.int64_t)(len(vectors)), (*C.uint8_t)(&vectors[0]))
return HandleCStatus(&status, "BuildBinaryVecIndexWithoutIds failed")
}
// Delete removes the pointer to build the index in 'C'. we can ensure that it is idempotent.
func (index *CIndex) Delete() error {
/*
void
DeleteIndex(CIndex index);
*/
if index.close {
return nil
}
C.DeleteIndex(index.indexPtr)
index.close = true
return nil
}
// NewCIndex creates a new pointer to build the index in 'C'.
func NewCIndex(typeParams, indexParams map[string]string) (Index, error) {
protoTypeParams := &indexcgopb.TypeParams{
Params: make([]*commonpb.KeyValuePair, 0),
}
for key, value := range typeParams {
protoTypeParams.Params = append(protoTypeParams.Params, &commonpb.KeyValuePair{Key: key, Value: value})
}
typeParamsStr := proto.MarshalTextString(protoTypeParams)
protoIndexParams := &indexcgopb.IndexParams{
Params: make([]*commonpb.KeyValuePair, 0),
}
for key, value := range indexParams {
protoIndexParams.Params = append(protoIndexParams.Params, &commonpb.KeyValuePair{Key: key, Value: value})
}
indexParamsStr := proto.MarshalTextString(protoIndexParams)
typeParamsPointer := C.CString(typeParamsStr)
indexParamsPointer := C.CString(indexParamsStr)
defer C.free(unsafe.Pointer(typeParamsPointer))
defer C.free(unsafe.Pointer(indexParamsPointer))
/*
CStatus
CreateIndex(const char* serialized_type_params,
const char* serialized_index_params,
CIndex* res_index);
*/
var indexPtr C.CIndex
log.Debug("Start to create index ...", zap.String("params", indexParamsStr))
status := C.CreateIndex(typeParamsPointer, indexParamsPointer, &indexPtr)
if err := HandleCStatus(&status, "CreateIndex failed"); err != nil {
return nil, err
}
log.Debug("Successfully create index ...")
index := &CIndex{
indexPtr: indexPtr,
close: false,
}
runtime.SetFinalizer(index, func(index *CIndex) {
if index != nil && !index.close {
log.Error("there is leakage in index object, please check.")
}
})
return index, nil
}
// HandleCStatus deal with the error returned from CGO
func HandleCStatus(status *C.CStatus, extraInfo string) error {
if status.error_code == 0 {
return nil
}
errorCode := status.error_code
errorName, ok := commonpb.ErrorCode_name[int32(errorCode)]
if !ok {
errorName = "UnknownError"
}
errorMsg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
finalMsg := fmt.Sprintf("[%s] %s", errorName, errorMsg)
logMsg := fmt.Sprintf("%s, C Runtime Exception: %s\n", extraInfo, finalMsg)
log.Warn(logMsg)
return errors.New(finalMsg)
}
此差异已折叠。
......@@ -25,6 +25,10 @@ import (
"runtime/debug"
"strconv"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/util/indexcgowrapper"
"github.com/milvus-io/milvus/internal/metrics"
"go.uber.org/zap"
......@@ -49,6 +53,8 @@ const (
IndexBuildTaskName = "IndexBuildTask"
)
type Blob = storage.Blob
type task interface {
Ctx() context.Context
ID() UniqueID // return ReqID
......@@ -118,8 +124,8 @@ func (bt *BaseTask) Notify(err error) {
// IndexBuildTask is used to record the information of the index tasks.
type IndexBuildTask struct {
BaseTask
index Index
cm storage.ChunkManager
index indexcgowrapper.CodecIndex
etcdKV *etcdkv.EtcdKV
savePaths []string
req *indexpb.CreateIndexRequest
......@@ -326,7 +332,7 @@ func (it *IndexBuildTask) prepareParams(ctx context.Context) error {
return nil
}
func (it *IndexBuildTask) loadVector(ctx context.Context) (storage.FieldID, storage.FieldData, error) {
func (it *IndexBuildTask) loadFieldData(ctx context.Context) (storage.FieldID, storage.FieldData, error) {
getValueByPath := func(path string) ([]byte, error) {
data, err := it.cm.Read(path)
if err != nil {
......@@ -370,7 +376,7 @@ func (it *IndexBuildTask) loadVector(ctx context.Context) (storage.FieldID, stor
}
loadVectorDuration := it.tr.RecordSpan()
log.Debug("IndexNode load data success", zap.Int64("buildId", it.req.IndexBuildID))
it.tr.Record("load vector data done")
it.tr.Record("load field data done")
var insertCodec storage.InsertCodec
collectionID, partitionID, segmentID, insertData, err2 := insertCodec.DeserializeAll(blobs)
......@@ -415,33 +421,29 @@ func (it *IndexBuildTask) buildIndex(ctx context.Context) ([]*storage.Blob, erro
{
var err error
var fieldData storage.FieldData
fieldID, fieldData, err = it.loadVector(ctx)
fieldID, fieldData, err = it.loadFieldData(ctx)
if err != nil {
return nil, err
}
floatVectorFieldData, fOk := fieldData.(*storage.FloatVectorFieldData)
if fOk {
err := it.index.BuildFloatVecIndexWithoutIds(floatVectorFieldData.Data)
dataset := indexcgowrapper.GenDataset(fieldData)
dType := dataset.DType
if dType != schemapb.DataType_None {
it.index, err = indexcgowrapper.NewCgoIndex(dType, it.newTypeParams, it.newIndexParams)
if err != nil {
log.Error("IndexNode BuildFloatVecIndexWithoutIds failed", zap.Error(err))
log.Error("failed to create index", zap.Error(err))
return nil, err
}
}
binaryVectorFieldData, bOk := fieldData.(*storage.BinaryVectorFieldData)
if bOk {
err := it.index.BuildBinaryVecIndexWithoutIds(binaryVectorFieldData.Data)
err = it.index.Build(dataset)
if err != nil {
log.Error("IndexNode BuildBinaryVecIndexWithoutIds failed", zap.Error(err))
log.Error("failed to build index", zap.Error(err))
return nil, err
}
}
metrics.IndexNodeKnowhereBuildIndexLatency.WithLabelValues(strconv.FormatInt(Params.IndexNodeCfg.NodeID, 10)).Observe(float64(it.tr.RecordSpan()))
if !fOk && !bOk {
return nil, errors.New("we expect FloatVectorFieldData or BinaryVectorFieldData")
}
it.tr.Record("build index done")
}
......@@ -556,25 +558,7 @@ func (it *IndexBuildTask) Execute(ctx context.Context) error {
defer it.releaseMemory()
var err error
it.index, err = NewCIndex(it.newTypeParams, it.newIndexParams)
if err != nil {
it.SetState(TaskStateFailed)
log.Error("IndexNode IndexBuildTask Execute NewCIndex failed",
zap.Int64("buildId", it.req.IndexBuildID),
zap.Error(err))
return err
}
defer func() {
err := it.index.Delete()
if err != nil {
log.Error("IndexNode IndexBuildTask Execute CIndexDelete failed",
zap.Int64("buildId", it.req.IndexBuildID),
zap.Error(err))
}
}()
it.tr.Record("new CIndex")
var blobs []*storage.Blob
blobs, err = it.buildIndex(ctx)
if err != nil {
......
......@@ -18,6 +18,10 @@ message MapParams {
repeated common.KeyValuePair params = 1;
}
message MapParamsV2 {
map<string, string> params = 1;
}
message Binary {
string key = 1;
bytes value = 2;
......
......@@ -139,6 +139,45 @@ func (m *MapParams) GetParams() []*commonpb.KeyValuePair {
return nil
}
type MapParamsV2 struct {
Params map[string]string `protobuf:"bytes,1,rep,name=params,proto3" json:"params,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *MapParamsV2) Reset() { *m = MapParamsV2{} }
func (m *MapParamsV2) String() string { return proto.CompactTextString(m) }
func (*MapParamsV2) ProtoMessage() {}
func (*MapParamsV2) Descriptor() ([]byte, []int) {
return fileDescriptor_c009bd9544a7343c, []int{3}
}
func (m *MapParamsV2) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_MapParamsV2.Unmarshal(m, b)
}
func (m *MapParamsV2) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_MapParamsV2.Marshal(b, m, deterministic)
}
func (m *MapParamsV2) XXX_Merge(src proto.Message) {
xxx_messageInfo_MapParamsV2.Merge(m, src)
}
func (m *MapParamsV2) XXX_Size() int {
return xxx_messageInfo_MapParamsV2.Size(m)
}
func (m *MapParamsV2) XXX_DiscardUnknown() {
xxx_messageInfo_MapParamsV2.DiscardUnknown(m)
}
var xxx_messageInfo_MapParamsV2 proto.InternalMessageInfo
func (m *MapParamsV2) GetParams() map[string]string {
if m != nil {
return m.Params
}
return nil
}
type Binary struct {
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
......@@ -151,7 +190,7 @@ func (m *Binary) Reset() { *m = Binary{} }
func (m *Binary) String() string { return proto.CompactTextString(m) }
func (*Binary) ProtoMessage() {}
func (*Binary) Descriptor() ([]byte, []int) {
return fileDescriptor_c009bd9544a7343c, []int{3}
return fileDescriptor_c009bd9544a7343c, []int{4}
}
func (m *Binary) XXX_Unmarshal(b []byte) error {
......@@ -197,7 +236,7 @@ func (m *BinarySet) Reset() { *m = BinarySet{} }
func (m *BinarySet) String() string { return proto.CompactTextString(m) }
func (*BinarySet) ProtoMessage() {}
func (*BinarySet) Descriptor() ([]byte, []int) {
return fileDescriptor_c009bd9544a7343c, []int{4}
return fileDescriptor_c009bd9544a7343c, []int{5}
}
func (m *BinarySet) XXX_Unmarshal(b []byte) error {
......@@ -229,6 +268,8 @@ func init() {
proto.RegisterType((*TypeParams)(nil), "milvus.proto.indexcgo.TypeParams")
proto.RegisterType((*IndexParams)(nil), "milvus.proto.indexcgo.IndexParams")
proto.RegisterType((*MapParams)(nil), "milvus.proto.indexcgo.MapParams")
proto.RegisterType((*MapParamsV2)(nil), "milvus.proto.indexcgo.MapParamsV2")
proto.RegisterMapType((map[string]string)(nil), "milvus.proto.indexcgo.MapParamsV2.ParamsEntry")
proto.RegisterType((*Binary)(nil), "milvus.proto.indexcgo.Binary")
proto.RegisterType((*BinarySet)(nil), "milvus.proto.indexcgo.BinarySet")
}
......@@ -236,21 +277,24 @@ func init() {
func init() { proto.RegisterFile("index_cgo_msg.proto", fileDescriptor_c009bd9544a7343c) }
var fileDescriptor_c009bd9544a7343c = []byte{
// 250 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0xce, 0xcc, 0x4b, 0x49,
0xad, 0x88, 0x4f, 0x4e, 0xcf, 0x8f, 0xcf, 0x2d, 0x4e, 0xd7, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17,
0x12, 0xcd, 0xcd, 0xcc, 0x29, 0x2b, 0x2d, 0x86, 0xf0, 0xf4, 0xc0, 0x2a, 0x92, 0xd3, 0xf3, 0xa5,
0x78, 0x92, 0xf3, 0x73, 0x73, 0xf3, 0xf3, 0x20, 0xc2, 0x4a, 0xee, 0x5c, 0x5c, 0x21, 0x95, 0x05,
0xa9, 0x01, 0x89, 0x45, 0x89, 0xb9, 0xc5, 0x42, 0x96, 0x5c, 0x6c, 0x05, 0x60, 0x96, 0x04, 0xa3,
0x02, 0xb3, 0x06, 0xb7, 0x91, 0xa2, 0x1e, 0x8a, 0x19, 0x50, 0x9d, 0xde, 0xa9, 0x95, 0x61, 0x89,
0x39, 0xa5, 0xa9, 0x01, 0x89, 0x99, 0x45, 0x41, 0x50, 0x0d, 0x4a, 0x1e, 0x5c, 0xdc, 0x9e, 0x20,
0x2b, 0x28, 0x37, 0xc9, 0x8d, 0x8b, 0xd3, 0x37, 0xb1, 0x80, 0x72, 0x73, 0x0c, 0xb8, 0xd8, 0x9c,
0x32, 0xf3, 0x12, 0x8b, 0x2a, 0x85, 0x04, 0xb8, 0x98, 0xb3, 0x53, 0x2b, 0x25, 0x18, 0x15, 0x18,
0x35, 0x38, 0x83, 0x40, 0x4c, 0x21, 0x11, 0x2e, 0xd6, 0x32, 0x90, 0x06, 0x09, 0x26, 0x05, 0x46,
0x0d, 0x9e, 0x20, 0x08, 0x47, 0xc9, 0x81, 0x8b, 0x13, 0xa2, 0x23, 0x38, 0xb5, 0x44, 0xc8, 0x98,
0x8b, 0x35, 0x25, 0xb1, 0x24, 0x11, 0x66, 0xb1, 0xac, 0x1e, 0xd6, 0xe0, 0xd4, 0x83, 0x68, 0x08,
0x82, 0xa8, 0x75, 0x32, 0x8f, 0x32, 0x4d, 0xcf, 0x2c, 0xc9, 0x28, 0x4d, 0x02, 0xb9, 0x4c, 0x1f,
0xa2, 0x43, 0x37, 0x33, 0x1f, 0xca, 0xd2, 0xcf, 0xcc, 0x2b, 0x49, 0x2d, 0xca, 0x4b, 0xcc, 0xd1,
0x07, 0x1b, 0xa2, 0x0f, 0x33, 0xa4, 0x20, 0x29, 0x89, 0x0d, 0x2c, 0x62, 0x0c, 0x08, 0x00, 0x00,
0xff, 0xff, 0xaf, 0x1f, 0x55, 0x04, 0xca, 0x01, 0x00, 0x00,
// 289 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xa4, 0x90, 0x41, 0x4b, 0xc3, 0x30,
0x14, 0xc7, 0xe9, 0xc6, 0x0a, 0x7d, 0xdd, 0x41, 0xa2, 0x42, 0x19, 0x08, 0xb3, 0xa7, 0x5d, 0x4c,
0x65, 0x43, 0x74, 0x9e, 0x64, 0xe0, 0x54, 0x44, 0x18, 0x55, 0x76, 0xf0, 0x32, 0xd2, 0x2e, 0xd4,
0x60, 0x9b, 0x94, 0x34, 0x1d, 0xf6, 0x5b, 0xf8, 0x91, 0xa5, 0x49, 0x2b, 0x53, 0x94, 0x1d, 0x76,
0xfb, 0xe7, 0xcf, 0xfb, 0xfd, 0xda, 0xf7, 0xe0, 0x90, 0xf1, 0x35, 0xfd, 0x58, 0xc5, 0x89, 0x58,
0x65, 0x45, 0x82, 0x73, 0x29, 0x94, 0x40, 0xc7, 0x19, 0x4b, 0x37, 0x65, 0x61, 0x5e, 0x58, 0x4f,
0xc4, 0x89, 0x18, 0xf4, 0x63, 0x91, 0x65, 0x82, 0x9b, 0xda, 0xbf, 0x03, 0x78, 0xa9, 0x72, 0xba,
0x20, 0x92, 0x64, 0x05, 0x9a, 0x82, 0x9d, 0xeb, 0xe4, 0x59, 0xc3, 0xee, 0xc8, 0x1d, 0x9f, 0xe2,
0x1f, 0x8e, 0x86, 0x7c, 0xa4, 0xd5, 0x92, 0xa4, 0x25, 0x5d, 0x10, 0x26, 0xc3, 0x06, 0xf0, 0xef,
0xc1, 0x7d, 0xa8, 0x3f, 0xb1, 0xbf, 0x69, 0x0e, 0xce, 0x13, 0xc9, 0xf7, 0xf7, 0x7c, 0x5a, 0xe0,
0x7e, 0x8b, 0x96, 0x63, 0x34, 0xff, 0xa5, 0xc2, 0xf8, 0xcf, 0x03, 0xe1, 0x2d, 0x06, 0x9b, 0x70,
0xcb, 0x95, 0xac, 0x5a, 0xef, 0x60, 0x0a, 0xee, 0x56, 0x8d, 0x0e, 0xa0, 0xfb, 0x4e, 0x2b, 0xcf,
0x1a, 0x5a, 0x23, 0x27, 0xac, 0x23, 0x3a, 0x82, 0xde, 0xa6, 0xfe, 0x1b, 0xaf, 0xa3, 0x3b, 0xf3,
0xb8, 0xee, 0x5c, 0x59, 0xfe, 0x39, 0xd8, 0x33, 0xc6, 0xc9, 0x6e, 0xaa, 0xdf, 0x50, 0xfe, 0x0d,
0x38, 0x86, 0x78, 0xa6, 0x0a, 0x4d, 0xa0, 0xb7, 0x26, 0x8a, 0xb4, 0x0b, 0x9c, 0xfc, 0xb3, 0x80,
0x01, 0x42, 0x33, 0x3b, 0xbb, 0x7c, 0xbd, 0x48, 0x98, 0x7a, 0x2b, 0xa3, 0xfa, 0x58, 0x81, 0x21,
0xce, 0x98, 0x68, 0x52, 0xc0, 0xb8, 0xa2, 0x92, 0x93, 0x34, 0xd0, 0x92, 0xa0, 0x95, 0xe4, 0x51,
0x64, 0xeb, 0x66, 0xf2, 0x15, 0x00, 0x00, 0xff, 0xff, 0x83, 0x12, 0x13, 0xfb, 0x5d, 0x02, 0x00,
0x00,
}
......@@ -440,7 +440,9 @@ message CreateIndexRequest {
// The vector field name in this particular collection
string field_name = 4;
// Support keys: index_type,metric_type, params. Different index_type may has different params.
repeated common.KeyValuePair extra_params = 5;
repeated common.KeyValuePair extra_params = 5;
// Version before 2.0.2 doesn't contain index_name, we use default index name.
string index_name = 6;
}
/*
......
......@@ -3213,6 +3213,12 @@ func (cit *createIndexTask) PreExecute(ctx context.Context) error {
collName, fieldName := cit.CollectionName, cit.FieldName
collID, err := globalMetaCache.GetCollectionID(ctx, collName)
if err != nil {
return err
}
cit.collectionID = collID
if err := validateCollectionName(collName); err != nil {
return err
}
......@@ -3245,6 +3251,18 @@ func (cit *createIndexTask) PreExecute(ctx context.Context) error {
indexType = indexparamcheck.IndexFaissIvfPQ // IVF_PQ is the default index type
}
// skip params check of non-vector field.
vecDataTypes := []schemapb.DataType{
schemapb.DataType_FloatVector,
schemapb.DataType_BinaryVector,
}
schema, _ := globalMetaCache.GetCollectionSchema(ctx, collName)
for _, f := range schema.GetFields() {
if f.GetName() == fieldName && !funcutil.SliceContain(vecDataTypes, f.GetDataType()) {
return indexparamcheck.CheckIndexValid(f.GetDataType(), indexType, indexParams)
}
}
adapter, err := indexparamcheck.GetConfAdapterMgrInstance().GetAdapter(indexType)
if err != nil {
log.Warn("Failed to get conf adapter", zap.String("index_type", indexType))
......@@ -3257,8 +3275,6 @@ func (cit *createIndexTask) PreExecute(ctx context.Context) error {
return fmt.Errorf("invalid index params: %v", cit.CreateIndexRequest.ExtraParams)
}
collID, _ := globalMetaCache.GetCollectionID(ctx, collName)
cit.collectionID = collID
return nil
}
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
//go:build linux
// +build linux
package indexcgowrapper
// TODO: add a benchmark to check if any leakage in cgo.
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册