未验证 提交 48706f41 编写于 作者: J Jiquan Long 提交者: GitHub

Migrate scalar index from knowhere (#16174)

Signed-off-by: Ndragondriver <jiquan.long@zilliz.com>
上级 66146944
......@@ -27,3 +27,4 @@ add_subdirectory( query )
add_subdirectory( common )
add_subdirectory( indexbuilder )
add_subdirectory( config )
add_subdirectory( index )
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "common/type_c.h"
#include <string>
namespace milvus {
// Maps a C++ element type to its CDataType tag at compile time.
//
// Only fundamental types and std::string are accepted by the template. Types
// that are accepted but have no dedicated tag in the chain below map to None.
template <typename T, typename = std::enable_if_t<std::is_fundamental_v<T> || std::is_same_v<T, std::string>>>
inline CDataType
GetDType() {
    if constexpr (std::is_same_v<T, bool>) {
        return Bool;
    } else if constexpr (std::is_same_v<T, int8_t>) {
        return Int8;
    } else if constexpr (std::is_same_v<T, int16_t>) {
        return Int16;
    } else if constexpr (std::is_same_v<T, int32_t>) {
        return Int32;
    } else if constexpr (std::is_same_v<T, int64_t>) {
        return Int64;
    } else if constexpr (std::is_same_v<T, float>) {
        return Float;
    } else if constexpr (std::is_same_v<T, double>) {
        return Double;
    } else if constexpr (std::is_same_v<T, std::string>) {
        return VarChar;
    } else {
        return None;
    }
}
} // namespace milvus
......@@ -36,6 +36,28 @@ enum ErrorCode {
IllegalArgument = 5,
};
// Plain C cannot use schemapb.DataType directly, so the data types are mirrored in this enum.
// Note: the value of every enumerator must match the corresponding schemapb.DataType value.
// TODO: decide how to stay in sync when new values are added to schemapb.DataType.
enum CDataType {
None = 0,
Bool = 1,
Int8 = 2,
Int16 = 3,
Int32 = 4,
Int64 = 5,
Float = 10,
Double = 11,
String = 20,
VarChar = 21,
BinaryVector = 100,
FloatVector = 101,
};
// Allows the type to be spelled as plain `CDataType` (without the `enum` keyword) in C code.
typedef enum CDataType CDataType;
typedef struct CStatus {
int error_code;
const char* error_msg;
......
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License
# Collect every source file found in the index directory into INDEX_FILES.
aux_source_directory( ${MILVUS_ENGINE_SRC}/index INDEX_FILES )
# Build the collected sources as the static library `milvus_index`.
add_library( milvus_index STATIC ${INDEX_FILES} )
# Link `milvus_index` against the knowhere and milvus_proto targets.
target_link_libraries(milvus_index
knowhere
milvus_proto
)
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "index/Index.h"
namespace milvus::scalar {
// Empty function: takes no arguments, returns nothing and has no effect.
// NOTE(review): presumably exists so that this source file defines at least one symbol — confirm.
void
dummy() {
}
} // namespace milvus::scalar
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <memory>
#include <knowhere/index/Index.h>
#include <knowhere/common/Dataset.h>
#include <knowhere/index/structured_index_simple/StructuredIndex.h>
namespace milvus::scalar {
// Aliases that bring the knowhere types used by the scalar index layer into milvus::scalar.
using Index = milvus::knowhere::Index;
using IndexPtr = std::unique_ptr<Index>;
using BinarySet = knowhere::BinarySet;
using Config = knowhere::Config;
using DatasetPtr = knowhere::DatasetPtr;
using OperatorType = knowhere::scalar::OperatorType;

// Common base of the scalar indexes: a knowhere Index that can additionally be
// built from a dataset.
class IndexBase : public Index {
 public:
    // Builds the index from `dataset`.
    // Declared public so that it can be invoked through an IndexBasePtr; with the
    // class-default (private) access the call was only reachable through derived types.
    virtual void
    Build(const DatasetPtr& dataset) = 0;
};

using IndexBasePtr = std::unique_ptr<IndexBase>;
} // namespace milvus::scalar
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <string>
#include "ScalarIndexSort.h"
namespace milvus::scalar {
// Creates a scalar index for element type T.
// `index_type` is not consulted here: every request yields a sort-based index.
template <typename T>
inline ScalarIndexPtr<T>
IndexFactory::CreateIndex(std::string index_type) {
    ScalarIndexPtr<T> created = CreateScalarIndexSort<T>();
    return created;
}
} // namespace milvus::scalar
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "index/IndexFactory.h"
#include "index/ScalarIndexSort.h"
namespace milvus::scalar {
// Creates a scalar index for the element type identified by `dtype`.
//
// String and VarChar both map to a std::string index. Every other value that is
// not a scalar type handled below (None, the vector types, unknown values) is
// rejected with std::invalid_argument.
IndexBasePtr
IndexFactory::CreateIndex(CDataType dtype, std::string index_type) {
    switch (dtype) {
        case Bool: {
            return CreateIndex<bool>(index_type);
        }
        case Int8: {
            return CreateIndex<int8_t>(index_type);
        }
        case Int16: {
            return CreateIndex<int16_t>(index_type);
        }
        case Int32: {
            return CreateIndex<int32_t>(index_type);
        }
        case Int64: {
            return CreateIndex<int64_t>(index_type);
        }
        case Float: {
            return CreateIndex<float>(index_type);
        }
        case Double: {
            return CreateIndex<double>(index_type);
        }
        case String:
        case VarChar: {
            return CreateIndex<std::string>(index_type);
        }
        default: {
            // Covers None, BinaryVector, FloatVector and any out-of-range value.
            const std::string prefix("invalid data type: ");
            throw std::invalid_argument(prefix + std::to_string(dtype));
        }
    }
}
} // namespace milvus::scalar
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <utils/Types.h>
#include "index/Index.h"
#include "common/type_c.h"
#include "ScalarIndex.h"
#include <string>
namespace milvus::scalar {
// Factory for scalar indexes, normally reached through GetInstance().
// Copying is disabled.
class IndexFactory {
 public:
    IndexFactory() = default;
    IndexFactory(const IndexFactory&) = delete;
    // Deleted copy assignment, declared with the canonical reference return type.
    IndexFactory&
    operator=(const IndexFactory&) = delete;

 public:
    // Returns the process-wide factory instance.
    static IndexFactory&
    GetInstance() {
        // thread-safe enough after c++ 11
        static IndexFactory instance;
        return instance;
    }

    // Creates an index for the element type identified by `dtype`.
    IndexBasePtr
    CreateIndex(CDataType dtype, std::string index_type);

    // Creates an index for the statically known element type T.
    template <typename T>
    ScalarIndexPtr<T>
    CreateIndex(std::string index_type);
};
} // namespace milvus::scalar
#include "index/IndexFactory-inl.h"
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
namespace milvus::scalar {
// One index record: a value together with the position (row offset) it came from.
//
// All comparison operators look at the value only; `idx_` never takes part in
// ordering or equality, so records with equal values compare equal regardless
// of their offsets.
template <typename T>
struct IndexStructure {
    // Value-initialises the value (zero / false for arithmetic types, empty for
    // std::string). Initialising with a literal 0 would construct a std::string
    // from a null pointer.
    IndexStructure() : a_(), idx_(0) {
    }
    // Record holding `a` with offset 0; used as a probe for binary searches.
    explicit IndexStructure(const T a) : a_(a), idx_(0) {
    }
    // Record holding `a` that originates from offset `idx`.
    IndexStructure(const T a, const size_t idx) : a_(a), idx_(idx) {
    }
    bool
    operator<(const IndexStructure& b) const {
        return a_ < b.a_;
    }
    bool
    operator<=(const IndexStructure& b) const {
        return a_ <= b.a_;
    }
    bool
    operator>(const IndexStructure& b) const {
        return a_ > b.a_;
    }
    bool
    operator>=(const IndexStructure& b) const {
        return a_ >= b.a_;
    }
    bool
    operator==(const IndexStructure& b) const {
        return a_ == b.a_;
    }
    T a_;         // the indexed value
    size_t idx_;  // offset of the value in the original input
};
} // namespace milvus::scalar
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <map>
#include <memory>
#include <string>
#include <boost/dynamic_bitset.hpp>
#include "index/Index.h"
namespace milvus::scalar {
// Bitmap type returned by the query methods of a scalar index.
using TargetBitmap = boost::dynamic_bitset<>;
using TargetBitmapPtr = std::unique_ptr<TargetBitmap>;
// Abstract interface of an index over scalar values of type T.
// Every method is pure virtual; the concrete semantics are defined by the implementations.
template <typename T>
class ScalarIndex : public IndexBase {
public:
// Builds the index from `n` values starting at `values`.
virtual void
Build(size_t n, const T* values) = 0;
// Membership query against the `n` values starting at `values`.
virtual const TargetBitmapPtr
In(size_t n, const T* values) = 0;
// Negated membership query against the `n` values starting at `values`.
virtual const TargetBitmapPtr
NotIn(size_t n, const T* values) = 0;
// One-sided range query: compares against `value` using the operator `op`.
virtual const TargetBitmapPtr
Range(T value, OperatorType op) = 0;
// Two-sided range query; the flags state whether each bound itself is included.
virtual const TargetBitmapPtr
Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) = 0;
};
// Owning pointer to a scalar index over values of type T.
template <typename T>
using ScalarIndexPtr = std::unique_ptr<ScalarIndex<T>>;
} // namespace milvus::scalar
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <algorithm>
#include <memory>
#include <utility>
#include <pb/schema.pb.h>
#include <vector>
#include <string>
#include "knowhere/common/Log.h"
namespace milvus::scalar {
template <typename T>
inline ScalarIndexSort<T>::ScalarIndexSort() : is_built_(false), data_() {
}
template <typename T>
inline ScalarIndexSort<T>::ScalarIndexSort(const size_t n, const T* values) : is_built_(false) {
ScalarIndexSort<T>::Build(n, values);
}
// Builds the index from a dataset: reads the ROWS entry as the element count and
// the TENSOR entry as a pointer to the values, then forwards to Build(n, values).
template <typename T>
inline void
ScalarIndexSort<T>::Build(const DatasetPtr& dataset) {
    const auto row_count = dataset->Get<int64_t>(knowhere::meta::ROWS);
    const auto tensor = dataset->Get<const void*>(knowhere::meta::TENSOR);
    const T* typed_values = reinterpret_cast<const T*>(tensor);
    Build(row_count, typed_values);
}
// Builds the index from the `n` values starting at `values`.
//
// Each value is stored together with its position in the input and the records
// are then sorted by value. Any previous content is discarded first: build()
// is a no-op once the index is marked as built, so appending to an existing
// index would leave unsorted records behind and break the binary searches.
//
// Throws std::invalid_argument (from build()) when `n` is 0.
template <typename T>
inline void
ScalarIndexSort<T>::Build(const size_t n, const T* values) {
    data_.clear();
    is_built_ = false;

    data_.reserve(n);
    for (size_t i = 0; i < n; ++i) {
        // Construct the record in place; the input is only read, never modified.
        data_.emplace_back(values[i], i);
    }
    build();
}
// Sorts the collected records by value and marks the index as built.
// Does nothing when the index is already built; throws std::invalid_argument
// when there are no records to sort.
template <typename T>
inline void
ScalarIndexSort<T>::build() {
    if (!is_built_) {
        if (data_.empty()) {
            throw std::invalid_argument("ScalarIndexSort cannot build null values!");
        }
        std::sort(data_.begin(), data_.end());
        is_built_ = true;
    }
}
// Serializes the index into a BinarySet with two entries:
//   "index_data"   - the raw bytes of the sorted record array,
//   "index_length" - the number of records, stored as a size_t.
// The index is built first if that has not happened yet.
template <typename T>
inline BinarySet
ScalarIndexSort<T>::Serialize(const Config& config) {
    if (!is_built_) {
        build();
    }

    const auto payload_bytes = data_.size() * sizeof(IndexStructure<T>);
    std::shared_ptr<uint8_t[]> payload(new uint8_t[payload_bytes]);
    memcpy(payload.get(), data_.data(), payload_bytes);

    std::shared_ptr<uint8_t[]> count_blob(new uint8_t[sizeof(size_t)]);
    const auto record_count = data_.size();
    memcpy(count_blob.get(), &record_count, sizeof(size_t));

    BinarySet serialized;
    serialized.Append("index_data", payload, payload_bytes);
    serialized.Append("index_length", count_blob, sizeof(size_t));
    return serialized;
}
// Restores the index from a BinarySet produced by Serialize().
//
// Expects an "index_length" entry holding exactly one size_t (the record count)
// and an "index_data" entry holding exactly that many IndexStructure<T> records.
// Both sizes are validated before any bytes are copied, so a truncated or
// oversized entry cannot overflow the destination buffers.
//
// Throws std::invalid_argument when an entry is missing or has an unexpected size.
template <typename T>
inline void
ScalarIndexSort<T>::Load(const BinarySet& index_binary) {
    const auto index_length = index_binary.GetByName("index_length");
    if (index_length == nullptr || static_cast<size_t>(index_length->size) != sizeof(size_t)) {
        throw std::invalid_argument("ScalarIndexSort cannot load: invalid index_length!");
    }
    size_t index_size = 0;
    memcpy(&index_size, index_length->data.get(), sizeof(size_t));

    const auto index_data = index_binary.GetByName("index_data");
    if (index_data == nullptr) {
        throw std::invalid_argument("ScalarIndexSort cannot load: missing index_data!");
    }
    // Compare via division so that a huge record count cannot overflow the product.
    const auto index_data_size = static_cast<size_t>(index_data->size);
    constexpr size_t record_size = sizeof(IndexStructure<T>);
    if (index_data_size % record_size != 0 || index_data_size / record_size != index_size) {
        throw std::invalid_argument("ScalarIndexSort cannot load: index_data size mismatch!");
    }

    data_.resize(index_size);
    if (index_size > 0) {
        memcpy(data_.data(), index_data->data.get(), index_data_size);
    }
    is_built_ = true;
}
// Returns a bitmap with one bit per indexed record in which the bit at a
// record's original offset is set when its value equals one of the `n` values
// starting at `values`. The index is built first if necessary.
template <typename T>
inline const TargetBitmapPtr
ScalarIndexSort<T>::In(const size_t n, const T* values) {
    if (!is_built_) {
        build();
    }
    TargetBitmapPtr hits = std::make_unique<TargetBitmap>(data_.size());
    for (size_t k = 0; k < n; ++k) {
        const T& wanted = values[k];
        // The records are sorted by value, so all matches form one contiguous run.
        auto cursor = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(wanted));
        const auto run_end = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(wanted));
        while (cursor < run_end) {
            if (cursor->a_ != wanted) {
                std::cout << "error happens in ScalarIndexSort<T>::In, experted value is: " << wanted
                          << ", but real value is: " << cursor->a_;
            }
            hits->set(cursor->idx_);
            ++cursor;
        }
    }
    return hits;
}
// Returns a bitmap with one bit per indexed record in which every bit starts
// out set and the bit at a record's original offset is cleared when its value
// equals one of the `n` values starting at `values`. The index is built first
// if necessary.
template <typename T>
inline const TargetBitmapPtr
ScalarIndexSort<T>::NotIn(const size_t n, const T* values) {
    if (!is_built_) {
        build();
    }
    TargetBitmapPtr survivors = std::make_unique<TargetBitmap>(data_.size());
    survivors->set();
    for (size_t k = 0; k < n; ++k) {
        const T& excluded = values[k];
        // The records are sorted by value, so all matches form one contiguous run.
        auto cursor = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(excluded));
        const auto run_end = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(excluded));
        while (cursor < run_end) {
            if (cursor->a_ != excluded) {
                std::cout << "error happens in ScalarIndexSort<T>::NotIn, experted value is: " << excluded
                          << ", but real value is: " << cursor->a_;
            }
            survivors->reset(cursor->idx_);
            ++cursor;
        }
    }
    return survivors;
}
// One-sided range query. Returns a bitmap with one bit per indexed record in
// which the bit at a record's original offset is set when the record's value
// satisfies `record <op> value`. Supported operators are LT, LE, GT and GE;
// any other operator raises std::invalid_argument.
template <typename T>
inline const TargetBitmapPtr
ScalarIndexSort<T>::Range(const T value, const OperatorType op) {
    if (!is_built_) {
        build();
    }
    TargetBitmapPtr hits = std::make_unique<TargetBitmap>(data_.size());

    // Matching records form the half-open run [first, last) of the sorted data.
    auto first = data_.begin();
    auto last = data_.end();
    switch (op) {
        case OperatorType::LT: {
            last = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
            break;
        }
        case OperatorType::LE: {
            last = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
            break;
        }
        case OperatorType::GT: {
            first = std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
            break;
        }
        case OperatorType::GE: {
            first = std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(value));
            break;
        }
        default: {
            throw std::invalid_argument(std::string("Invalid OperatorType: ") + std::to_string((int)op) + "!");
        }
    }

    while (first < last) {
        hits->set(first->idx_);
        ++first;
    }
    return hits;
}
// Two-sided range query. Returns a bitmap with one bit per indexed record in
// which the bit at a record's original offset is set when the record's value
// lies between the two bounds; each flag states whether that bound itself is
// included. If the bounds are given in reverse order they are swapped together
// with their inclusiveness flags.
template <typename T>
inline const TargetBitmapPtr
ScalarIndexSort<T>::Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) {
    if (!is_built_) {
        build();
    }
    TargetBitmapPtr hits = std::make_unique<TargetBitmap>(data_.size());

    if (lower_bound_value > upper_bound_value) {
        std::swap(lower_bound_value, upper_bound_value);
        std::swap(lb_inclusive, ub_inclusive);
    }

    // An inclusive lower bound starts at the first equal record, an exclusive one after the last.
    auto first = lb_inclusive ? std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(lower_bound_value))
                              : std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(lower_bound_value));
    // An inclusive upper bound ends after the last equal record, an exclusive one at the first.
    const auto last = ub_inclusive
                          ? std::upper_bound(data_.begin(), data_.end(), IndexStructure<T>(upper_bound_value))
                          : std::lower_bound(data_.begin(), data_.end(), IndexStructure<T>(upper_bound_value));

    while (first < last) {
        hits->set(first->idx_);
        ++first;
    }
    return hits;
}
// String specialization: the TENSOR entry of the dataset is parsed as a
// serialized schema StringArray (the ROWS entry is passed to the parser as the
// buffer length) and the index is built from the strings it contains.
template <>
inline void
ScalarIndexSort<std::string>::Build(const milvus::scalar::DatasetPtr& dataset) {
    const auto length = dataset->Get<int64_t>(knowhere::meta::ROWS);
    const auto payload = dataset->Get<const void*>(knowhere::meta::TENSOR);

    proto::schema::StringArray parsed;
    parsed.ParseFromArray(payload, length);

    // TODO: optimize here. avoid memory copy.
    std::vector<std::string> copies(parsed.data().begin(), parsed.data().end());
    Build(parsed.data().size(), copies.data());
}
// String specialization: emits one BinarySet entry per record. The entry name
// is the record's original offset in decimal and the payload is the string's
// bytes (without a terminating NUL).
//
// The records and their strings are accessed by reference, so serializing no
// longer copies the whole index and every string in it.
template <>
inline BinarySet
ScalarIndexSort<std::string>::Serialize(const Config& config) {
    BinarySet res_set;
    for (const auto& record : this->GetData()) {
        const std::string& str = record.a_;
        std::shared_ptr<uint8_t[]> content(new uint8_t[str.length()]);
        memcpy(content.get(), str.data(), str.length());
        res_set.Append(std::to_string(record.idx_), content, str.length());
    }
    return res_set;
}
// String specialization: rebuilds the index from the per-record entries written
// by Serialize(), where each entry name is the record's original offset.
//
// The offset is parsed from the entry name and used to place the string, so
// every record gets its original offset back. Relying on the iteration order of
// the map instead would renumber the rows (names order as "0", "1", "10", ...).
//
// Throws std::invalid_argument / std::out_of_range (from std::stoull) when an
// entry name is not a number, and std::invalid_argument when an offset does not
// fit the number of entries.
template <>
inline void
ScalarIndexSort<std::string>::Load(const BinarySet& index_binary) {
    std::vector<std::string> vecs(index_binary.binary_map_.size());
    for (const auto& [k, v] : index_binary.binary_map_) {
        const auto offset = std::stoull(k);
        if (offset >= vecs.size()) {
            throw std::invalid_argument(std::string("ScalarIndexSort cannot load: invalid offset ") + k + "!");
        }
        vecs[offset].assign(reinterpret_cast<const char*>(v->data.get()), v->size);
    }
    Build(vecs.size(), vecs.data());
}
} // namespace milvus::scalar
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "knowhere/common/Exception.h"
#include "index/IndexStructure.h"
#include <string>
#include "index/ScalarIndex.h"
namespace milvus::scalar {
// Scalar index that keeps (value, original offset) records in a vector sorted
// by value and answers queries with binary searches over that vector.
template <typename T>
class ScalarIndexSort : public ScalarIndex<T> {
// Only fundamental types and std::string are supported as element type.
static_assert(std::is_fundamental_v<T> || std::is_same_v<T, std::string>);
public:
// Creates an empty, not yet built index.
ScalarIndexSort();
// Creates an index and builds it from the `n` values at `values`.
ScalarIndexSort(size_t n, const T* values);
// Serializes the index into a BinarySet.
BinarySet
Serialize(const Config& config) override;
// Restores the index from a BinarySet.
void
Load(const BinarySet& index_binary) override;
// Builds the index from a dataset.
void
Build(const DatasetPtr& dataset) override;
// Builds the index from the `n` values at `values`.
void
Build(size_t n, const T* values) override;
// Sorts the stored records and marks the index as built (no-op when already built).
void
build();
// Bitmap of the records whose value is one of the `n` given values.
const TargetBitmapPtr
In(size_t n, const T* values) override;
// Bitmap of the records whose value is none of the `n` given values.
const TargetBitmapPtr
NotIn(size_t n, const T* values) override;
// Bitmap of the records that satisfy `record <op> value`.
const TargetBitmapPtr
Range(T value, OperatorType op) override;
// Bitmap of the records between the two bounds; the flags control inclusiveness.
const TargetBitmapPtr
Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) override;
// Read-only access to the stored records.
const std::vector<IndexStructure<T>>&
GetData() {
return data_;
}
// Number of stored records.
int64_t
Size() override {
return (int64_t)data_.size();
}
// Whether the records have been sorted (or loaded) and the index is ready for queries.
bool
IsBuilt() const {
return is_built_;
}
private:
// Set once the records are sorted or loaded.
bool is_built_;
// The (value, original offset) records.
std::vector<IndexStructure<T>> data_;
};
// Owning pointer to a sort-based scalar index over values of type T.
template <typename T>
using ScalarIndexSortPtr = std::unique_ptr<ScalarIndexSort<T>>;
} // namespace milvus::scalar
#include "index/ScalarIndexSort-inl.h"
namespace milvus::scalar {
// Allocates a new, empty sort-based scalar index for element type T.
template <typename T>
inline ScalarIndexSortPtr<T>
CreateScalarIndexSort() {
    auto created = std::make_unique<ScalarIndexSort<T>>();
    return created;
}
} // namespace milvus::scalar
......@@ -39,7 +39,7 @@ class IndexFactory {
}
IndexCreatorBasePtr
CreateIndex(DataType dtype, const char* type_params, const char* index_params) {
CreateIndex(CDataType dtype, const char* type_params, const char* index_params) {
auto real_dtype = proto::schema::DataType(dtype);
auto invalid_dtype_msg = std::string("invalid data type: ") + std::to_string(real_dtype);
......
......@@ -22,9 +22,10 @@
#include "indexbuilder/VecIndexCreator.h"
#include "indexbuilder/index_c.h"
#include "indexbuilder/IndexFactory.h"
#include "common/type_c.h"
CStatus
CreateIndex(DataType dtype,
CreateIndex(enum CDataType dtype,
const char* serialized_type_params,
const char* serialized_index_params,
CIndex* res_index) {
......
......@@ -21,7 +21,7 @@ extern "C" {
#include "indexbuilder/type_c.h"
CStatus
CreateIndex(enum DataType dtype,
CreateIndex(enum CDataType dtype,
const char* serialized_type_params,
const char* serialized_index_params,
CIndex* res_index);
......
......@@ -11,25 +11,7 @@
#pragma once
// pure C don't support that we use schemapb.DataType directly.
// Note: the value of all enumerations must match the corresponding schemapb.DataType.
// TODO: what if there are increments in schemapb.DataType.
enum DataType {
None = 0,
Bool = 1,
Int8 = 2,
Int16 = 3,
Int32 = 4,
Int64 = 5,
Float = 10,
Double = 11,
String = 20,
BinaryVector = 100,
FloatVector = 101,
};
#include "common/type_c.h"
typedef void* CIndex;
typedef void* CIndexQueryResult;
......@@ -41,6 +41,7 @@ if (LINUX)
test_utils.cpp
test_scalar_index_creator.cpp
test_index_c_api.cpp
test_index.cpp
)
# check if memory leak exists in index builder
set(INDEX_BUILDER_TEST_FILES
......@@ -58,6 +59,7 @@ if (LINUX)
gtest_main
milvus_segcore
milvus_indexbuilder
milvus_index
log
pthread
)
......@@ -88,6 +90,7 @@ target_link_libraries(all_tests
gtest
milvus_segcore
milvus_indexbuilder
milvus_index
log
pthread
)
......
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <google/protobuf/text_format.h>
#include <gtest/gtest.h>
#include <map>
#include <tuple>
#include <knowhere/index/vector_index/helpers/IndexParameter.h>
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
#include <knowhere/index/vector_index/ConfAdapterMgr.h>
#include <knowhere/archive/KnowhereConfig.h>
#include "pb/index_cgo_msg.pb.h"
#define private public
#include "index/IndexFactory.h"
#include "index/Index.h"
#include "index/ScalarIndex.h"
#include "index/ScalarIndexSort.h"
#include "common/CDataType.h"
#include "test_utils/indexbuilder_test_utils.h"
constexpr int64_t nb = 100;
namespace indexcgo = milvus::proto::indexcgo;
namespace schemapb = milvus::proto::schema;
using milvus::scalar::ScalarIndexPtr;
namespace {
// Index type names the typed tests iterate over for element type T.
template <typename T>
inline std::vector<std::string>
GetIndexTypes() {
    std::vector<std::string> names;
    names.emplace_back("inverted_index");
    return names;
}

// Strings use their own index type name.
template <>
inline std::vector<std::string>
GetIndexTypes<std::string>() {
    std::vector<std::string> names;
    names.emplace_back("marisa-trie");
    return names;
}
// Checks In() on `index`, which is expected to have been built from `arr`:
// querying with all of `arr` must hit at least one row, and querying with a
// value one past the last element of `arr` must hit none.
// Floating point element types are skipped entirely.
template <typename T>
inline void
assert_in(const ScalarIndexPtr<T>& index, const std::vector<T>& arr) {
    // hard to compare floating point value.
    if constexpr (!std::is_floating_point_v<T>) {
        auto bitset1 = index->In(arr.size(), arr.data());
        ASSERT_EQ(arr.size(), bitset1->size());
        ASSERT_TRUE(bitset1->any());

        // A single probe value does not need a heap allocation.
        const T absent = static_cast<T>(arr[arr.size() - 1] + 1);
        auto bitset2 = index->In(1, &absent);
        ASSERT_EQ(arr.size(), bitset2->size());
        ASSERT_TRUE(bitset2->none());
    }
}
// Checks NotIn() on `index`, which is expected to have been built from `arr`:
// excluding all of `arr` must leave no row, and excluding only a value one past
// the last element of `arr` must leave at least one row.
template <typename T>
inline void
assert_not_in(const ScalarIndexPtr<T>& index, const std::vector<T>& arr) {
    const auto row_count = arr.size();

    auto all_excluded = index->NotIn(row_count, arr.data());
    ASSERT_EQ(row_count, all_excluded->size());
    ASSERT_TRUE(all_excluded->none());

    auto probe = std::make_unique<T>(arr[row_count - 1] + 1);
    auto none_excluded = index->NotIn(1, probe.get());
    ASSERT_EQ(row_count, none_excluded->size());
    ASSERT_TRUE(none_excluded->any());
}
// Checks both Range() overloads on `index`, which is expected to have been
// built from `arr`: every query is chosen around the first and last element of
// `arr` and must return a bitmap of the right size with at least one hit.
template <typename T>
inline void
assert_range(const ScalarIndexPtr<T>& index, const std::vector<T>& arr) {
    const auto row_count = arr.size();
    auto lowest = arr[0];
    auto highest = arr[row_count - 1];

    auto greater_than = index->Range(lowest - 1, OperatorType::GT);
    ASSERT_EQ(row_count, greater_than->size());
    ASSERT_TRUE(greater_than->any());

    auto greater_equal = index->Range(lowest, OperatorType::GE);
    ASSERT_EQ(row_count, greater_equal->size());
    ASSERT_TRUE(greater_equal->any());

    auto less_than = index->Range(highest + 1, OperatorType::LT);
    ASSERT_EQ(row_count, less_than->size());
    ASSERT_TRUE(less_than->any());

    auto less_equal = index->Range(highest, OperatorType::LE);
    ASSERT_EQ(row_count, less_equal->size());
    ASSERT_TRUE(less_equal->any());

    auto closed_interval = index->Range(lowest, true, highest, true);
    ASSERT_EQ(row_count, closed_interval->size());
    ASSERT_TRUE(closed_interval->any());
}
// String variant: only checks that querying with all of `arr` hits at least one row.
template <>
inline void
assert_in(const ScalarIndexPtr<std::string>& index, const std::vector<std::string>& arr) {
    const auto row_count = arr.size();
    auto hits = index->In(row_count, arr.data());
    ASSERT_EQ(row_count, hits->size());
    ASSERT_TRUE(hits->any());
}
// String variant: only checks that excluding all of `arr` leaves no row.
template <>
inline void
assert_not_in(const ScalarIndexPtr<std::string>& index, const std::vector<std::string>& arr) {
    const auto row_count = arr.size();
    auto remaining = index->NotIn(row_count, arr.data());
    ASSERT_EQ(row_count, remaining->size());
    ASSERT_TRUE(remaining->none());
}
// String variant: checks GE on the first element, LE on the last element and
// the closed interval between them; each must produce at least one hit.
template <>
inline void
assert_range(const ScalarIndexPtr<std::string>& index, const std::vector<std::string>& arr) {
    const auto row_count = arr.size();
    auto lowest = arr[0];
    auto highest = arr[row_count - 1];

    auto greater_equal = index->Range(lowest, OperatorType::GE);
    ASSERT_EQ(row_count, greater_equal->size());
    ASSERT_TRUE(greater_equal->any());

    auto less_equal = index->Range(highest, OperatorType::LE);
    ASSERT_EQ(row_count, less_equal->size());
    ASSERT_TRUE(less_equal->any());

    auto closed_interval = index->Range(lowest, true, highest, true);
    ASSERT_EQ(row_count, closed_interval->size());
    ASSERT_TRUE(closed_interval->any());
}
} // namespace
// Fixture for the type-parameterized scalar index tests; T is the element type
// under test. It currently adds no state: the SetUp/TearDown hooks are left
// commented out.
template <typename T>
class TypedScalarIndexTest : public ::testing::Test {
protected:
// void
// SetUp() override {
// }
// void
// TearDown() override {
// }
};
// Declares TypedScalarIndexTest as a type-parameterized test suite.
TYPED_TEST_CASE_P(TypedScalarIndexTest);
// Prints the name of the element type under test and the CDataType it maps to.
TYPED_TEST_P(TypedScalarIndexTest, Dummy) {
    using T = TypeParam;
    // typeid(T) names the element type itself; `T()` inside typeid is parsed as
    // the function type "function returning T" and would print that instead.
    std::cout << typeid(T).name() << std::endl;
    std::cout << milvus::GetDType<T>() << std::endl;
}
// Creates an index through the data-type based factory entry point for every
// index type name configured for T.
TYPED_TEST_P(TypedScalarIndexTest, Constructor) {
    using T = TypeParam;
    auto& factory = milvus::scalar::IndexFactory::GetInstance();
    auto dtype = milvus::GetDType<T>();
    auto index_types = GetIndexTypes<T>();
    for (const auto& index_type : index_types) {
        auto index = factory.CreateIndex(dtype, index_type);
    }
}
// Builds an index from generated data for every configured index type and
// verifies In() on it. The unused `dtype` local has been removed.
TYPED_TEST_P(TypedScalarIndexTest, In) {
    using T = TypeParam;
    auto index_types = GetIndexTypes<T>();
    for (const auto& index_type : index_types) {
        auto index = milvus::scalar::IndexFactory::GetInstance().CreateIndex<T>(index_type);
        auto arr = GenArr<T>(nb);
        index->Build(nb, arr.data());
        assert_in<T>(index, arr);
    }
}
// Builds an index from generated data for every configured index type and
// verifies NotIn() on it. The unused `dtype` local has been removed.
TYPED_TEST_P(TypedScalarIndexTest, NotIn) {
    using T = TypeParam;
    auto index_types = GetIndexTypes<T>();
    for (const auto& index_type : index_types) {
        auto index = milvus::scalar::IndexFactory::GetInstance().CreateIndex<T>(index_type);
        auto arr = GenArr<T>(nb);
        index->Build(nb, arr.data());
        assert_not_in<T>(index, arr);
    }
}
// Builds an index from generated data for every configured index type and
// verifies the Range() queries on it. The unused `dtype` local has been removed.
TYPED_TEST_P(TypedScalarIndexTest, Range) {
    using T = TypeParam;
    auto index_types = GetIndexTypes<T>();
    for (const auto& index_type : index_types) {
        auto index = milvus::scalar::IndexFactory::GetInstance().CreateIndex<T>(index_type);
        auto arr = GenArr<T>(nb);
        index->Build(nb, arr.data());
        assert_range<T>(index, arr);
    }
}
// Round-trip test: builds an index, serializes it, loads the result into a
// fresh index and runs all query checks against the loaded copy.
// The unused `dtype` local has been removed.
TYPED_TEST_P(TypedScalarIndexTest, Codec) {
    using T = TypeParam;
    auto index_types = GetIndexTypes<T>();
    for (const auto& index_type : index_types) {
        auto index = milvus::scalar::IndexFactory::GetInstance().CreateIndex<T>(index_type);
        auto arr = GenArr<T>(nb);
        index->Build(nb, arr.data());

        auto binary_set = index->Serialize(nullptr);
        auto copy_index = milvus::scalar::IndexFactory::GetInstance().CreateIndex<T>(index_type);
        copy_index->Load(binary_set);

        assert_in<T>(copy_index, arr);
        assert_not_in<T>(copy_index, arr);
        assert_range<T>(copy_index, arr);
    }
}
// TODO: it's easy to overflow for int8_t. Design more reasonable ut.
// Element types the typed scalar index tests are instantiated for
// (the list also contains std::string, despite the alias name).
using ArithmeticT = ::testing::Types<int8_t, int16_t, int32_t, int64_t, float, double, std::string>;
// Registers the tests that belong to the TypedScalarIndexTest suite.
REGISTER_TYPED_TEST_CASE_P(TypedScalarIndexTest, Dummy, Constructor, In, NotIn, Range, Codec);
// Instantiates the registered tests once for every type in ArithmeticT.
INSTANTIATE_TYPED_TEST_CASE_P(ArithmeticCheck, TypedScalarIndexTest, ArithmeticT);
// TODO: bool.
......@@ -26,6 +26,7 @@
#include "test_utils/indexbuilder_test_utils.h"
#include "indexbuilder/ScalarIndexCreator.h"
#include "indexbuilder/IndexFactory.h"
#include "common/type_c.h"
constexpr int NB = 10;
......@@ -44,7 +45,7 @@ TEST(FloatVecIndex, All) {
auto dataset = GenDataset(NB, metric_type, false);
auto xb_data = dataset.get_col<float>(0);
DataType dtype = FloatVector;
CDataType dtype = FloatVector;
CIndex index;
CStatus status;
CBinarySet binary_set;
......@@ -95,7 +96,7 @@ TEST(BinaryVecIndex, All) {
auto dataset = GenDataset(NB, metric_type, true);
auto xb_data = dataset.get_col<uint8_t>(0);
DataType dtype = BinaryVector;
CDataType dtype = BinaryVector;
CIndex index;
CStatus status;
CBinarySet binary_set;
......@@ -147,7 +148,7 @@ TEST(CBoolIndexTest, All) {
auto type_params_str = generate_type_params(type_params);
auto index_params_str = generate_index_params(index_params);
DataType dtype = Bool;
CDataType dtype = Bool;
CIndex index;
CStatus status;
CBinarySet binary_set;
......@@ -198,7 +199,7 @@ TEST(CInt64IndexTest, All) {
auto type_params_str = generate_type_params(type_params);
auto index_params_str = generate_index_params(index_params);
DataType dtype = Int64;
CDataType dtype = Int64;
CIndex index;
CStatus status;
CBinarySet binary_set;
......@@ -248,7 +249,7 @@ TEST(CStringIndexTest, All) {
auto type_params_str = generate_type_params(type_params);
auto index_params_str = generate_index_params(index_params);
DataType dtype = String;
CDataType dtype = String;
CIndex index;
CStatus status;
CBinarySet binary_set;
......
......@@ -135,7 +135,7 @@ assert_range(const std::unique_ptr<ScalarIndexCreator<std::string>>& creator, co
} // namespace
template <typename T>
class TypedScalarIndexTest : public ::testing::Test {
class TypedScalarIndexCreatorTest : public ::testing::Test {
protected:
// void
// SetUp() override {
......@@ -149,15 +149,15 @@ class TypedScalarIndexTest : public ::testing::Test {
// TODO: it's easy to overflow for int8_t. Design more reasonable ut.
using ArithmeticT = ::testing::Types<int8_t, int16_t, int32_t, int64_t, float, double>;
TYPED_TEST_CASE_P(TypedScalarIndexTest);
TYPED_TEST_CASE_P(TypedScalarIndexCreatorTest);
TYPED_TEST_P(TypedScalarIndexTest, Dummy) {
TYPED_TEST_P(TypedScalarIndexCreatorTest, Dummy) {
using T = TypeParam;
std::cout << typeid(T()).name() << std::endl;
PrintMapParams(GenParams<T>());
}
TYPED_TEST_P(TypedScalarIndexTest, Constructor) {
TYPED_TEST_P(TypedScalarIndexCreatorTest, Constructor) {
using T = TypeParam;
for (const auto& tp : GenParams<T>()) {
auto type_params = tp.first;
......@@ -169,7 +169,7 @@ TYPED_TEST_P(TypedScalarIndexTest, Constructor) {
}
}
TYPED_TEST_P(TypedScalarIndexTest, In) {
TYPED_TEST_P(TypedScalarIndexCreatorTest, In) {
using T = TypeParam;
for (const auto& tp : GenParams<T>()) {
auto type_params = tp.first;
......@@ -184,7 +184,7 @@ TYPED_TEST_P(TypedScalarIndexTest, In) {
}
}
TYPED_TEST_P(TypedScalarIndexTest, NotIn) {
TYPED_TEST_P(TypedScalarIndexCreatorTest, NotIn) {
using T = TypeParam;
for (const auto& tp : GenParams<T>()) {
auto type_params = tp.first;
......@@ -199,7 +199,7 @@ TYPED_TEST_P(TypedScalarIndexTest, NotIn) {
}
}
TYPED_TEST_P(TypedScalarIndexTest, Range) {
TYPED_TEST_P(TypedScalarIndexCreatorTest, Range) {
using T = TypeParam;
for (const auto& tp : GenParams<T>()) {
auto type_params = tp.first;
......@@ -214,7 +214,7 @@ TYPED_TEST_P(TypedScalarIndexTest, Range) {
}
}
TYPED_TEST_P(TypedScalarIndexTest, Codec) {
TYPED_TEST_P(TypedScalarIndexCreatorTest, Codec) {
using T = TypeParam;
for (const auto& tp : GenParams<T>()) {
auto type_params = tp.first;
......@@ -238,9 +238,9 @@ TYPED_TEST_P(TypedScalarIndexTest, Codec) {
}
}
REGISTER_TYPED_TEST_CASE_P(TypedScalarIndexTest, Dummy, Constructor, In, NotIn, Range, Codec);
REGISTER_TYPED_TEST_CASE_P(TypedScalarIndexCreatorTest, Dummy, Constructor, In, NotIn, Range, Codec);
INSTANTIATE_TYPED_TEST_CASE_P(ArithmeticCheck, TypedScalarIndexTest, ArithmeticT);
INSTANTIATE_TYPED_TEST_CASE_P(ArithmeticCheck, TypedScalarIndexCreatorTest, ArithmeticT);
class BoolIndexTest : public ::testing::Test {
protected:
......
......@@ -368,8 +368,8 @@ generate_index_params(const MapParams& m) {
}
// TODO: std::is_arithmetic_v, hard to compare float point value. std::is_integral_v.
template <typename T, typename = typename std::enable_if_t<std::is_arithmetic_v<T>>>
inline auto
template <typename T, typename = typename std::enable_if_t<std::is_arithmetic_v<T> || std::is_same_v<T, std::string>>>
inline std::vector<T>
GenArr(int64_t n) {
auto max_i8 = std::numeric_limits<int8_t>::max() - 1;
std::vector<T> arr;
......@@ -394,6 +394,12 @@ GenStrArr(int64_t n) {
return arr;
}
template <>
inline std::vector<std::string>
GenArr<std::string>(int64_t n) {
return GenStrArr(n);
}
template <typename T, typename = typename std::enable_if_t<std::is_arithmetic_v<T>>>
inline std::vector<ScalarTestParams>
GenParams() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册