提交 9577da9f 编写于 作者: X XuanYang-cn 提交者: yefu.chen

Add writenode main

Signed-off-by: NXuanYang-cn <xuan.yang@zilliz.com>
上级 eb871718
......@@ -71,6 +71,8 @@ build-go: build-cpp
@mkdir -p $(INSTALL_PATH) && go env -w CGO_ENABLED="1" && GO111MODULE=on $(GO) build -o $(INSTALL_PATH)/querynode $(PWD)/cmd/querynode/query_node.go 1>/dev/null
@echo "Building indexbuilder ..."
@mkdir -p $(INSTALL_PATH) && go env -w CGO_ENABLED="1" && GO111MODULE=on $(GO) build -o $(INSTALL_PATH)/indexbuilder $(PWD)/cmd/indexbuilder/indexbuilder.go 1>/dev/null
@echo "Building write node ..."
@mkdir -p $(INSTALL_PATH) && go env -w CGO_ENABLED="1" && GO111MODULE=on $(GO) build -o $(INSTALL_PATH)/writenode $(PWD)/cmd/writenode/writenode.go 1>/dev/null
build-cpp:
@(env bash $(PWD)/scripts/core_build.sh -f "$(CUSTOM_THIRDPARTY_PATH)")
......@@ -104,6 +106,7 @@ install: all
@mkdir -p $(GOPATH)/bin && cp -f $(PWD)/bin/querynode $(GOPATH)/bin/querynode
@mkdir -p $(GOPATH)/bin && cp -f $(PWD)/bin/master $(GOPATH)/bin/master
@mkdir -p $(GOPATH)/bin && cp -f $(PWD)/bin/proxy $(GOPATH)/bin/proxy
@mkdir -p $(GOPATH)/bin && cp -f $(PWD)/bin/writenode $(GOPATH)/bin/writenode
@mkdir -p $(LIBRARY_PATH) && cp -f $(PWD)/internal/core/output/lib/* $(LIBRARY_PATH)
@echo "Installation successful."
......@@ -114,3 +117,4 @@ clean:
@rm -rvf querynode
@rm -rvf master
@rm -rvf proxy
@rm -rvf writenode
......@@ -22,7 +22,7 @@ namespace milvus {
using boost::algorithm::to_lower_copy;
namespace Metric = knowhere::Metric;
static const auto metric_bimap = [] {
static auto map = [] {
boost::bimap<std::string, MetricType> mapping;
using pos = boost::bimap<std::string, MetricType>::value_type;
mapping.insert(pos(to_lower_copy(std::string(Metric::L2)), MetricType::METRIC_L2));
......@@ -38,15 +38,8 @@ static const auto metric_bimap = [] {
MetricType
GetMetricType(const std::string& type_name) {
auto real_name = to_lower_copy(type_name);
AssertInfo(metric_bimap.left.count(real_name), "metric type not found: (" + type_name + ")");
return metric_bimap.left.at(real_name);
}
std::string
MetricTypeToName(MetricType metric_type) {
AssertInfo(metric_bimap.right.count(metric_type),
"metric_type enum(" + std::to_string((int)metric_type) + ") not found");
return metric_bimap.right.at(metric_type);
AssertInfo(map.left.count(real_name), "metric type not found: (" + type_name + ")");
return map.left.at(real_name);
}
} // namespace milvus
......@@ -23,10 +23,8 @@ using engine::FieldElementType;
using engine::QueryResult;
using MetricType = faiss::MetricType;
MetricType
faiss::MetricType
GetMetricType(const std::string& type);
std::string
MetricTypeToName(MetricType metric_type);
// NOTE: dependent type
// used at meta-template programming
......
......@@ -9,7 +9,6 @@ set(MILVUS_QUERY_SRCS
visitors/ExecExprVisitor.cpp
Plan.cpp
Search.cpp
SearchOnSealed.cpp
BruteForceSearch.cpp
)
add_library(milvus_query ${MILVUS_QUERY_SRCS})
......
......@@ -194,12 +194,9 @@ Parser::ParseVecNode(const Json& out_body) {
auto topK = vec_info["topk"];
AssertInfo(topK > 0, "topK must greater than 0");
AssertInfo(topK < 16384, "topK is too large");
auto field_offset_opt = schema.get_offset(field_name);
AssertInfo(field_offset_opt.has_value(), "field_name(" + field_name + ") not found");
auto field_meta = schema.operator[](field_name);
auto vec_node = [&]() -> std::unique_ptr<VectorPlanNode> {
auto field_meta = schema.operator[](field_name);
auto data_type = field_meta.get_data_type();
if (data_type == DataType::VECTOR_FLOAT) {
return std::make_unique<FloatVectorANNS>();
......@@ -211,7 +208,6 @@ Parser::ParseVecNode(const Json& out_body) {
vec_node->query_info_.metric_type_ = vec_info.at("metric_type");
vec_node->query_info_.search_params_ = vec_info.at("params");
vec_node->query_info_.field_id_ = field_name;
vec_node->query_info_.field_offset_ = field_offset_opt.value();
vec_node->placeholder_tag_ = vec_info.at("query");
auto tag = vec_node->placeholder_tag_;
AssertInfo(!tag2field_.count(tag), "duplicated placeholder tag");
......
......@@ -41,7 +41,6 @@ using PlanNodePtr = std::unique_ptr<PlanNode>;
struct QueryInfo {
int64_t topK_;
FieldId field_id_;
int64_t field_offset_;
std::string metric_type_; // TODO: use enum
nlohmann::json search_params_;
};
......
......@@ -29,6 +29,7 @@ create_bitmap_view(std::optional<const BitmapSimple*> bitmaps_opt, int64_t chunk
auto src_vec = ~bitmaps.at(chunk_id);
auto dst = std::make_shared<faiss::ConcurrentBitset>(src_vec.size());
auto iter = reinterpret_cast<BitmapChunk::block_type*>(dst->mutable_data());
boost::to_block_range(src_vec, iter);
return dst;
}
......@@ -130,6 +131,16 @@ QueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
results.topK_ = topK;
results.num_queries_ = num_queries;
// TODO: deprecated code begin
final_uids = results.internal_seg_offsets_;
for (auto& id : final_uids) {
if (id == -1) {
continue;
}
id = record.uids_[id];
}
results.result_ids_ = std::move(final_uids);
// TODO: deprecated code end
return Status::OK();
}
......@@ -208,6 +219,16 @@ BinaryQueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
results.topK_ = topK;
results.num_queries_ = num_queries;
// TODO: deprecated code begin
final_uids = results.internal_seg_offsets_;
for (auto& id : final_uids) {
if (id == -1) {
continue;
}
id = record.uids_[id];
}
results.result_ids_ = std::move(final_uids);
// TODO: deprecated code end
return Status::OK();
}
......
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
//
// Created by mike on 12/26/20.
//
#include "query/SearchOnSealed.h"
#include <knowhere/index/vector_index/VecIndex.h>
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
namespace milvus::query {
aligned_vector<uint8_t>
AssembleBitmap(const BitmapSimple& bitmap_simple) {
int64_t N = 0;
for (auto& bitmap : bitmap_simple) {
N += bitmap.size();
}
aligned_vector<uint8_t> result(upper_align(upper_div(N, 8), sizeof(BitmapChunk::block_type)));
auto acc_byte_count = 0;
for (auto& bitmap_raw : bitmap_simple) {
auto bitmap = ~bitmap_raw;
auto size = bitmap.size();
Assert(size % 8 == 0);
auto byte_count = size / 8;
auto iter = reinterpret_cast<BitmapChunk::block_type*>(result.data() + acc_byte_count);
boost::to_block_range(bitmap, iter);
acc_byte_count += byte_count;
}
return result;
}
void
SearchOnSealed(const Schema& schema,
const segcore::SealedIndexingRecord& record,
const QueryInfo& query_info,
const float* query_data,
int64_t num_queries,
Timestamp timestamp,
std::optional<const BitmapSimple*> bitmaps_opt,
QueryResult& result) {
auto topK = query_info.topK_;
auto field_offset = query_info.field_offset_;
auto& field = schema[field_offset];
Assert(field.get_data_type() == DataType::VECTOR_FLOAT);
auto dim = field.get_dim();
Assert(record.test_readiness(field_offset));
auto indexing_entry = record.get_entry(field_offset);
Assert(indexing_entry->metric_type_ == GetMetricType(query_info.metric_type_));
auto final = [&] {
auto ds = knowhere::GenDataset(num_queries, dim, query_data);
auto conf = query_info.search_params_;
conf[milvus::knowhere::meta::TOPK] = query_info.topK_;
conf[milvus::knowhere::Metric::TYPE] = MetricTypeToName(indexing_entry->metric_type_);
if (bitmaps_opt.has_value()) {
auto bitmap = AssembleBitmap(*bitmaps_opt.value());
return indexing_entry->indexing_->Query(ds, conf, faiss::BitsetView(bitmap.data(), num_queries));
} else {
return indexing_entry->indexing_->Query(ds, conf, nullptr);
}
}();
auto ids = final->Get<idx_t*>(knowhere::meta::IDS);
auto distances = final->Get<float*>(knowhere::meta::DISTANCE);
auto total_num = num_queries * topK;
result.internal_seg_offsets_.resize(total_num);
result.result_distances_.resize(total_num);
result.num_queries_ = num_queries;
result.topK_ = topK;
std::copy_n(ids, total_num, result.internal_seg_offsets_.data());
std::copy_n(distances, total_num, result.result_distances_.data());
}
} // namespace milvus::query
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include "segcore/SealedIndexingRecord.h"
#include "query/PlanNode.h"
#include "query/Search.h"
namespace milvus::query {
void
SearchOnSealed(const Schema& schema,
const segcore::SealedIndexingRecord& record,
const QueryInfo& query_info,
const float* query_data,
int64_t num_queries,
Timestamp timestamp,
std::optional<const BitmapSimple*> bitmaps_opt,
QueryResult& result);
} // namespace milvus::query
......@@ -17,7 +17,6 @@
#include "segcore/SegmentSmallIndex.h"
#include "query/generated/ExecExprVisitor.h"
#include "query/Search.h"
#include "query/SearchOnSealed.h"
namespace milvus::query {
......@@ -64,24 +63,13 @@ ExecPlanNodeVisitor::visit(FloatVectorANNS& node) {
auto& ph = placeholder_group_.at(0);
auto src_data = ph.get_blob<float>();
auto num_queries = ph.num_of_queries_;
ExecExprVisitor::RetType bitmap_holder;
std::optional<const ExecExprVisitor::RetType*> bitset_pack;
if (node.predicate_.has_value()) {
bitmap_holder = ExecExprVisitor(*segment).call_child(*node.predicate_.value());
bitset_pack = &bitmap_holder;
}
auto& sealed_indexing = segment->get_sealed_indexing_record();
if (sealed_indexing.test_readiness(node.query_info_.field_offset_)) {
SearchOnSealed(segment->get_schema(), sealed_indexing, node.query_info_, src_data, num_queries, timestamp_,
bitset_pack, ret);
auto bitmap = ExecExprVisitor(*segment).call_child(*node.predicate_.value());
auto ptr = &bitmap;
QueryBruteForceImpl(*segment, node.query_info_, src_data, num_queries, timestamp_, ptr, ret);
} else {
QueryBruteForceImpl(*segment, node.query_info_, src_data, num_queries, timestamp_, bitset_pack, ret);
QueryBruteForceImpl(*segment, node.query_info_, src_data, num_queries, timestamp_, std::nullopt, ret);
}
ret_ = ret;
}
......@@ -95,16 +83,13 @@ ExecPlanNodeVisitor::visit(BinaryVectorANNS& node) {
auto& ph = placeholder_group_.at(0);
auto src_data = ph.get_blob<uint8_t>();
auto num_queries = ph.num_of_queries_;
ExecExprVisitor::RetType bitmap_holder;
std::optional<const ExecExprVisitor::RetType*> bitset_pack;
if (node.predicate_.has_value()) {
bitmap_holder = ExecExprVisitor(*segment).call_child(*node.predicate_.value());
bitset_pack = &bitmap_holder;
auto bitmap = ExecExprVisitor(*segment).call_child(*node.predicate_.value());
auto ptr = &bitmap;
BinaryQueryBruteForceImpl(*segment, node.query_info_, src_data, num_queries, timestamp_, ptr, ret);
} else {
BinaryQueryBruteForceImpl(*segment, node.query_info_, src_data, num_queries, timestamp_, std::nullopt, ret);
}
BinaryQueryBruteForceImpl(*segment, node.query_info_, src_data, num_queries, timestamp_, bitset_pack, ret);
ret_ = ret;
}
......
......@@ -12,9 +12,7 @@ set(SEGCORE_FILES
Reduce.cpp
plan_c.cpp
reduce_c.cpp
load_index_c.cpp
SealedIndexingRecord.cpp
)
load_index_c.cpp)
add_library(milvus_segcore SHARED
${SEGCORE_FILES}
)
......
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
//
// Created by mike on 12/25/20.
//
#include "segcore/SealedIndexingRecord.h"
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <mutex>
#include <map>
#include <shared_mutex>
#include <utility>
#include <memory>
#include <tbb/concurrent_hash_map.h>
#include "utils/EasyAssert.h"
#include "knowhere/index/vector_index/VecIndex.h"
#include "common/Types.h"
namespace milvus::segcore {
struct SealedIndexingEntry {
MetricType metric_type_;
knowhere::VecIndexPtr indexing_;
};
using SealedIndexingEntryPtr = std::unique_ptr<SealedIndexingEntry>;
struct SealedIndexingRecord {
void
add_entry(int64_t field_offset, SealedIndexingEntryPtr&& ptr) {
std::unique_lock lck(mutex_);
entries_[field_offset] = std::move(ptr);
}
const SealedIndexingEntry*
get_entry(int64_t field_offset) const {
std::shared_lock lck(mutex_);
AssertInfo(entries_.count(field_offset), "field_offset not found");
return entries_.at(field_offset).get();
}
bool
test_readiness(int64_t field_offset) const {
std::shared_lock lck(mutex_);
return entries_.count(field_offset);
}
private:
// field_offset -> SealedIndexingEntry
std::map<int64_t, SealedIndexingEntryPtr> entries_;
mutable std::shared_mutex mutex_;
};
} // namespace milvus::segcore
......@@ -19,7 +19,6 @@
#include "query/deprecated/GeneralQuery.h"
#include "query/Plan.h"
#include "common/LoadIndex.h"
namespace milvus {
namespace segcore {
......@@ -80,9 +79,6 @@ class SegmentBase {
virtual Status
Close() = 0;
virtual Status
LoadIndexing(const LoadIndexInfo& info) = 0;
// // to make all data inserted visible
// // maybe a no-op?
// virtual Status
......
......@@ -84,11 +84,6 @@ class SegmentNaive : public SegmentBase {
Status
BuildIndex(IndexMetaPtr index_meta) override;
Status
LoadIndexing(const LoadIndexInfo& info) override {
PanicInfo("unimplemented");
}
Status
FillTargetEntry(const query::Plan* Plan, QueryResult& results) override {
PanicInfo("unimplemented");
......
......@@ -257,7 +257,6 @@ SegmentSmallIndex::BuildVecIndexImpl(const IndexMeta::Entry& entry) {
Status
SegmentSmallIndex::BuildIndex(IndexMetaPtr remote_index_meta) {
if (remote_index_meta == nullptr) {
PanicInfo("deprecated");
std::cout << "WARN: Null index ptr is detected, use default index" << std::endl;
int dim = 0;
......@@ -286,10 +285,12 @@ SegmentSmallIndex::BuildIndex(IndexMetaPtr remote_index_meta) {
knowhere::IndexMode::MODE_CPU, conf);
remote_index_meta = index_meta;
}
if (record_.ack_responder_.GetAck() < 1024 * 4) {
return Status(SERVER_BUILD_INDEX_ERROR, "too few elements");
}
PanicInfo("unimplemented");
AssertInfo(false, "unimplemented");
return Status::OK();
#if 0
index_meta_ = remote_index_meta;
for (auto& [index_name, entry] : index_meta_->get_entries()) {
......@@ -349,19 +350,11 @@ SegmentSmallIndex::FillTargetEntry(const query::Plan* plan, QueryResult& results
Assert(results.result_offsets_.size() == size);
Assert(results.row_data_.size() == 0);
// TODO: deprecate
results.result_ids_.clear();
results.result_ids_.resize(size);
if (plan->schema_.get_is_auto_id()) {
auto& uids = record_.uids_;
for (int64_t i = 0; i < size; ++i) {
auto seg_offset = results.internal_seg_offsets_[i];
auto row_id = seg_offset == -1 ? -1 : uids[seg_offset];
// TODO: deprecate
results.result_ids_[i] = row_id;
std::vector<char> blob(sizeof(row_id));
memcpy(blob.data(), &row_id, sizeof(row_id));
results.row_data_.emplace_back(std::move(blob));
......@@ -376,10 +369,6 @@ SegmentSmallIndex::FillTargetEntry(const query::Plan* plan, QueryResult& results
for (int64_t i = 0; i < size; ++i) {
auto seg_offset = results.internal_seg_offsets_[i];
auto row_id = seg_offset == -1 ? -1 : uids->operator[](seg_offset);
// TODO: deprecate
results.result_ids_[i] = row_id;
std::vector<char> blob(sizeof(row_id));
memcpy(blob.data(), &row_id, sizeof(row_id));
results.row_data_.emplace_back(std::move(blob));
......@@ -388,20 +377,4 @@ SegmentSmallIndex::FillTargetEntry(const query::Plan* plan, QueryResult& results
return Status::OK();
}
Status
SegmentSmallIndex::LoadIndexing(const LoadIndexInfo& info) {
auto field_offset_opt = schema_->get_offset(info.field_name);
AssertInfo(field_offset_opt.has_value(), "field name(" + info.field_name + ") not found");
Assert(info.index_params.count("metric_type"));
auto metric_type_str = info.index_params.at("metric_type");
auto entry = std::make_unique<SealedIndexingEntry>();
entry->metric_type_ = GetMetricType(metric_type_str);
entry->indexing_ = info.index;
sealed_indexing_record_.add_entry(field_offset_opt.value(), std::move(entry));
return Status::OK();
}
} // namespace milvus::segcore
......@@ -20,9 +20,9 @@
#include <query/PlanNode.h>
#include "AckResponder.h"
#include "SealedIndexingRecord.h"
#include "ConcurrentVector.h"
#include "segcore/SegmentBase.h"
// #include "knowhere/index/structured_index/StructuredIndex.h"
#include "query/deprecated/GeneralQuery.h"
#include "utils/Status.h"
#include "segcore/DeletedRecord.h"
......@@ -79,12 +79,14 @@ class SegmentSmallIndex : public SegmentBase {
Status
DropRawData(std::string_view field_name) override {
PanicInfo("unimplemented");
// TODO: NO-OP
return Status::OK();
}
Status
LoadRawData(std::string_view field_name, const char* blob, int64_t blob_size) override {
PanicInfo("unimplemented");
// TODO: NO-OP
return Status::OK();
}
int64_t
......@@ -106,11 +108,6 @@ class SegmentSmallIndex : public SegmentBase {
return deleted_record_;
}
const SealedIndexingRecord&
get_sealed_indexing_record() const {
return sealed_indexing_record_;
}
const Schema&
get_schema() const {
return *schema_;
......@@ -132,9 +129,6 @@ class SegmentSmallIndex : public SegmentBase {
return 0;
}
Status
LoadIndexing(const LoadIndexInfo& info) override;
public:
friend std::unique_ptr<SegmentBase>
CreateSegment(SchemaPtr schema, int64_t chunk_size);
......@@ -166,7 +160,6 @@ class SegmentSmallIndex : public SegmentBase {
InsertRecord record_;
DeletedRecord deleted_record_;
IndexingRecord indexing_record_;
SealedIndexingRecord sealed_indexing_record_;
tbb::concurrent_unordered_multimap<idx_t, int64_t> uid2offset_;
};
......
......@@ -118,17 +118,21 @@ ResetSearchResult(std::vector<std::vector<int64_t>>& search_records,
std::vector<float> result_distances;
std::vector<int64_t> internal_seg_offsets;
std::vector<int64_t> result_ids;
for (int j = 0; j < search_records[i].size(); j++) {
auto& offset = search_records[i][j];
auto distance = search_result->result_distances_[offset];
auto internal_seg_offset = search_result->internal_seg_offsets_[offset];
auto id = search_result->result_ids_[offset];
result_distances.push_back(distance);
internal_seg_offsets.push_back(internal_seg_offset);
result_ids.push_back(id);
}
search_result->result_distances_ = result_distances;
search_result->internal_seg_offsets_ = internal_seg_offsets;
search_result->result_ids_ = result_ids;
}
}
......
......@@ -13,7 +13,6 @@ set(MILVUS_TEST_FILES
test_bitmap.cpp
test_binary.cpp
test_index_wrapper.cpp
test_sealed.cpp
)
add_executable(all_tests
${MILVUS_TEST_FILES}
......
......@@ -22,7 +22,7 @@ TEST(Binary, Insert) {
int64_t topK = 5;
auto schema = std::make_shared<Schema>();
schema->AddField("vecbin", DataType::VECTOR_BINARY, 128, MetricType::METRIC_Jaccard);
schema->AddField("age", DataType::INT32);
schema->AddField("age", DataType::INT64);
auto dataset = DataGen(schema, N, 10);
auto segment = CreateSegment(schema);
segment->PreInsert(N);
......
......@@ -458,15 +458,13 @@ TEST(Query, FillSegment) {
QueryResult result;
segment->Search(plan.get(), groups.data(), timestamps.data(), 1, result);
// TODO: deprecated result_ids_
ASSERT_EQ(result.result_ids_, result.internal_seg_offsets_);
auto topk = 5;
auto num_queries = 10;
result.result_offsets_.resize(topk * num_queries);
segment->FillTargetEntry(plan.get(), result);
// TODO: deprecated result_ids_
ASSERT_EQ(result.result_ids_, result.internal_seg_offsets_);
auto ans = result.row_data_;
ASSERT_EQ(ans.size(), topk * num_queries);
int64_t std_index = 0;
......
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
//
// Created by mike on 12/28/20.
//
#include "test_utils/DataGen.h"
#include <gtest/gtest.h>
#include <knowhere/index/vector_index/VecIndex.h>
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
#include <knowhere/index/vector_index/VecIndexFactory.h>
#include <knowhere/index/vector_index/IndexIVF.h>
using namespace milvus;
using namespace milvus::segcore;
using namespace milvus;
TEST(Sealed, without_predicate) {
using namespace milvus::query;
using namespace milvus::segcore;
auto schema = std::make_shared<Schema>();
auto dim = 16;
auto topK = 5;
auto metric_type = MetricType::METRIC_L2;
schema->AddField("fakevec", DataType::VECTOR_FLOAT, dim, metric_type);
schema->AddField("age", DataType::FLOAT);
std::string dsl = R"({
"bool": {
"must": [
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5
}
}
}
]
}
})";
int64_t N = 1000 * 1000;
auto dataset = DataGen(schema, N);
auto vec_col = dataset.get_col<float>(0);
for (int64_t i = 0; i < 1000 * dim; ++i) {
vec_col.push_back(0);
}
auto query_ptr = vec_col.data() + 4200 * dim;
auto segment = CreateSegment(schema);
segment->PreInsert(N);
segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_);
auto plan = CreatePlan(*schema, dsl);
auto num_queries = 5;
auto ph_group_raw = CreatePlaceholderGroupFromBlob(num_queries, 16, query_ptr);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
QueryResult qr;
Timestamp time = 1000000;
std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
segment->Search(plan.get(), ph_group_arr.data(), &time, 1, qr);
auto pre_result = QueryResultToJson(qr);
auto indexing = std::make_shared<knowhere::IVF>();
auto conf = knowhere::Config{{knowhere::meta::DIM, dim},
{knowhere::meta::TOPK, topK},
{knowhere::IndexParams::nlist, 100},
{knowhere::IndexParams::nprobe, 10},
{knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
{knowhere::meta::DEVICEID, 0}};
auto database = knowhere::GenDataset(N, dim, vec_col.data() + 1000 * dim);
indexing->Train(database, conf);
indexing->AddWithoutIds(database, conf);
EXPECT_EQ(indexing->Count(), N);
EXPECT_EQ(indexing->Dim(), dim);
auto query_dataset = knowhere::GenDataset(num_queries, dim, query_ptr);
auto result = indexing->Query(query_dataset, conf, nullptr);
auto ids = result->Get<int64_t*>(milvus::knowhere::meta::IDS); // for comparison
auto dis = result->Get<float*>(milvus::knowhere::meta::DISTANCE); // for comparison
std::vector<int64_t> vec_ids(ids, ids + topK * num_queries);
std::vector<float> vec_dis(dis, dis + topK * num_queries);
qr.internal_seg_offsets_ = vec_ids;
qr.result_distances_ = vec_dis;
auto ref_result = QueryResultToJson(qr);
LoadIndexInfo load_info;
load_info.field_name = "fakevec";
load_info.field_id = 42;
load_info.index = indexing;
load_info.index_params["metric_type"] = "L2";
segment->LoadIndexing(load_info);
qr = QueryResult();
segment->Search(plan.get(), ph_group_arr.data(), &time, 1, qr);
auto post_result = QueryResultToJson(qr);
std::cout << ref_result.dump(1);
std::cout << post_result.dump(1);
ASSERT_EQ(ref_result.dump(2), post_result.dump(2));
}
TEST(Sealed, with_predicate) {
using namespace milvus::query;
using namespace milvus::segcore;
auto schema = std::make_shared<Schema>();
auto dim = 16;
auto topK = 5;
auto metric_type = MetricType::METRIC_L2;
schema->AddField("fakevec", DataType::VECTOR_FLOAT, dim, metric_type);
schema->AddField("counter", DataType::INT64);
std::string dsl = R"({
"bool": {
"must": [
{
"range": {
"counter": {
"GE": 420000,
"LT": 420005
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5
}
}
}
]
}
})";
int64_t N = 1000 * 1000;
auto dataset = DataGen(schema, N);
auto vec_col = dataset.get_col<float>(0);
auto query_ptr = vec_col.data() + 420000 * dim;
auto segment = CreateSegment(schema);
segment->PreInsert(N);
segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_);
auto plan = CreatePlan(*schema, dsl);
auto num_queries = 5;
auto ph_group_raw = CreatePlaceholderGroupFromBlob(num_queries, 16, query_ptr);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
QueryResult qr;
Timestamp time = 10000000;
std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
segment->Search(plan.get(), ph_group_arr.data(), &time, 1, qr);
auto pre_qr = qr;
auto indexing = std::make_shared<knowhere::IVF>();
auto conf = knowhere::Config{{knowhere::meta::DIM, dim},
{knowhere::meta::TOPK, topK},
{knowhere::IndexParams::nlist, 100},
{knowhere::IndexParams::nprobe, 10},
{knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
{knowhere::meta::DEVICEID, 0}};
auto database = knowhere::GenDataset(N, dim, vec_col.data());
indexing->Train(database, conf);
indexing->AddWithoutIds(database, conf);
EXPECT_EQ(indexing->Count(), N);
EXPECT_EQ(indexing->Dim(), dim);
auto query_dataset = knowhere::GenDataset(num_queries, dim, query_ptr);
auto result = indexing->Query(query_dataset, conf, nullptr);
LoadIndexInfo load_info;
load_info.field_name = "fakevec";
load_info.field_id = 42;
load_info.index = indexing;
load_info.index_params["metric_type"] = "L2";
segment->LoadIndexing(load_info);
qr = QueryResult();
segment->Search(plan.get(), ph_group_arr.data(), &time, 1, qr);
auto post_qr = qr;
for (int i = 0; i < num_queries; ++i) {
auto offset = i * topK;
ASSERT_EQ(post_qr.internal_seg_offsets_[offset], 420000 + i);
ASSERT_EQ(post_qr.result_distances_[offset], 0.0);
}
}
\ No newline at end of file
......@@ -111,16 +111,8 @@ DataGen(SchemaPtr schema, int64_t N, uint64_t seed = 42) {
}
case engine::DataType::INT64: {
vector<int64_t> data(N);
int64_t index = 0;
// begin with counter
if (field.get_name().rfind("counter", 0) == 0) {
for (auto& x : data) {
x = index++;
}
} else {
for (auto& x : data) {
x = er() % (2 * N);
}
x = er();
}
insert_cols(data);
break;
......@@ -178,26 +170,6 @@ CreatePlaceholderGroup(int64_t num_queries, int dim, int64_t seed = 42) {
return raw_group;
}
inline auto
CreatePlaceholderGroupFromBlob(int64_t num_queries, int dim, const float* src) {
namespace ser = milvus::proto::service;
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::VECTOR_FLOAT);
int64_t src_index = 0;
for (int i = 0; i < num_queries; ++i) {
std::vector<float> vec;
for (int d = 0; d < dim; ++d) {
vec.push_back(src[src_index++]);
}
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
value->add_values(vec.data(), vec.size() * sizeof(float));
}
return raw_group;
}
inline auto
CreateBinaryPlaceholderGroup(int64_t num_queries, int64_t dim, int64_t seed = 42) {
assert(dim % 8 == 0);
......
......@@ -85,7 +85,7 @@ func (ibNode *insertBufferNode) Name() string {
}
func (ibNode *insertBufferNode) Operate(in []*Msg) []*Msg {
log.Println("=========== insert buffer Node Operating")
// log.Println("=========== insert buffer Node Operating")
if len(in) != 1 {
log.Println("Error: Invalid operate message input in insertBuffertNode, input length = ", len(in))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册