// Copyright (C) 2019-2020 Zilliz. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software distributed under the License // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License #pragma once #include #include #include #include #include #include "Constants.h" #include "common/Schema.h" #include "index/ScalarIndexSort.h" #include "index/StringIndexSort.h" #include "knowhere/index/VecIndex.h" #include "knowhere/index/VecIndexFactory.h" #include "knowhere/index/vector_index/IndexIVF.h" #include "knowhere/index/vector_index/adapter/VectorAdapter.h" #include "query/SearchOnIndex.h" #include "segcore/SegmentGrowingImpl.h" #include "segcore/SegmentSealedImpl.h" #include "segcore/Utils.h" using boost::algorithm::starts_with; namespace milvus::segcore { struct GeneratedData { std::vector row_ids_; std::vector timestamps_; InsertData* raw_; std::vector field_ids; SchemaPtr schema_; template std::vector get_col(FieldId field_id) const { std::vector ret(raw_->num_rows()); for (auto target_field_data : raw_->fields_data()) { if (field_id.get() != target_field_data.field_id()) { continue; } auto& field_meta = schema_->operator[](field_id); if (field_meta.is_vector()) { if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) { int len = raw_->num_rows() * field_meta.get_dim(); ret.resize(len); auto src_data = reinterpret_cast(target_field_data.vectors().float_vector().data().data()); std::copy_n(src_data, len, ret.data()); } else if (field_meta.get_data_type() == DataType::VECTOR_BINARY) { int len = raw_->num_rows() * (field_meta.get_dim() / 8); ret.resize(len); auto src_data = reinterpret_cast(target_field_data.vectors().binary_vector().data()); std::copy_n(src_data, len, ret.data()); } else { PanicInfo("unsupported"); } return std::move(ret); } switch (field_meta.get_data_type()) { case DataType::BOOL: { auto src_data = reinterpret_cast(target_field_data.scalars().bool_data().data().data()); std::copy_n(src_data, raw_->num_rows(), ret.data()); break; } case DataType::INT8: case DataType::INT16: case DataType::INT32: { auto src_data = reinterpret_cast(target_field_data.scalars().int_data().data().data()); std::copy_n(src_data, raw_->num_rows(), ret.data()); break; } case DataType::INT64: { auto src_data = reinterpret_cast(target_field_data.scalars().long_data().data().data()); std::copy_n(src_data, raw_->num_rows(), ret.data()); break; } case DataType::FLOAT: { auto src_data = reinterpret_cast(target_field_data.scalars().float_data().data().data()); std::copy_n(src_data, raw_->num_rows(), ret.data()); break; } case DataType::DOUBLE: { auto src_data = reinterpret_cast(target_field_data.scalars().double_data().data().data()); std::copy_n(src_data, raw_->num_rows(), ret.data()); break; } case DataType::VARCHAR: { auto ret_data = reinterpret_cast(ret.data()); auto src_data = target_field_data.scalars().string_data().data(); std::copy(src_data.begin(), src_data.end(), ret_data); break; } default: { PanicInfo("unsupported"); } } } return std::move(ret); } std::unique_ptr get_col(FieldId field_id) const { for (auto target_field_data : raw_->fields_data()) { if (field_id.get() == target_field_data.field_id()) { return std::make_unique(target_field_data); } } PanicInfo("field id not find"); } private: GeneratedData() = default; friend GeneratedData DataGen(SchemaPtr schema, int64_t N, uint64_t seed, uint64_t ts_offset, int repeat_count); }; inline GeneratedData DataGen(SchemaPtr schema, int64_t N, uint64_t seed = 42, uint64_t ts_offset = 0, int repeat_count = 1) { using std::vector; std::default_random_engine er(seed); std::normal_distribution<> distr(0, 1); int offset = 0; auto insert_data = std::make_unique(); auto insert_cols = [&insert_data](auto& data, int64_t count, auto& field_meta) { auto array = milvus::segcore::CreateDataArrayFrom(data.data(), count, field_meta); insert_data->mutable_fields_data()->AddAllocated(array.release()); }; for (auto field_id : schema->get_field_ids()) { auto field_meta = schema->operator[](field_id); switch (field_meta.get_data_type()) { case DataType::VECTOR_FLOAT: { auto dim = field_meta.get_dim(); vector final(dim * N); bool is_ip = starts_with(field_meta.get_name().get(), "normalized"); #pragma omp parallel for for (int n = 0; n < N; ++n) { vector data(dim); float sum = 0; std::default_random_engine er2(seed + n); std::normal_distribution<> distr2(0, 1); for (auto& x : data) { x = distr2(er2) + offset; sum += x * x; } if (is_ip) { sum = sqrt(sum); for (auto& x : data) { x /= sum; } } std::copy(data.begin(), data.end(), final.begin() + dim * n); } insert_cols(final, N, field_meta); break; } case DataType::VECTOR_BINARY: { auto dim = field_meta.get_dim(); Assert(dim % 8 == 0); vector data(dim / 8 * N); for (auto& x : data) { x = er(); } insert_cols(data, N, field_meta); break; } case DataType::INT64: { vector data(N); for (int i = 0; i < N; i++) { data[i] = i / repeat_count; } insert_cols(data, N, field_meta); break; } case DataType::INT32: { vector data(N); for (auto& x : data) { x = er() % (2 * N); } insert_cols(data, N, field_meta); break; } case DataType::INT16: { vector data(N); for (auto& x : data) { x = er() % (2 * N); } insert_cols(data, N, field_meta); break; } case DataType::INT8: { vector data(N); for (auto& x : data) { x = er() % (2 * N); } insert_cols(data, N, field_meta); break; } case DataType::FLOAT: { vector data(N); for (auto& x : data) { x = distr(er); } insert_cols(data, N, field_meta); break; } case DataType::DOUBLE: { vector data(N); for (auto& x : data) { x = distr(er); } insert_cols(data, N, field_meta); break; } case DataType::VARCHAR: { vector data(N); for (int i = 0; i < N / repeat_count; i++) { auto str = std::to_string(er()); for (int j = 0; j < repeat_count; j++) { data[i * repeat_count + j] = str; } } insert_cols(data, N, field_meta); break; } default: { throw std::runtime_error("unimplemented"); } } ++offset; } GeneratedData res; res.schema_ = schema; res.raw_ = insert_data.release(); res.raw_->set_num_rows(N); for (int i = 0; i < N; ++i) { res.row_ids_.push_back(i); res.timestamps_.push_back(i + ts_offset); } return res; } inline auto CreatePlaceholderGroup(int64_t num_queries, int dim, int64_t seed = 42) { namespace ser = milvus::proto::common; ser::PlaceholderGroup raw_group; auto value = raw_group.add_placeholders(); value->set_tag("$0"); value->set_type(ser::PlaceholderType::FloatVector); std::normal_distribution dis(0, 1); std::default_random_engine e(seed); for (int i = 0; i < num_queries; ++i) { std::vector vec; for (int d = 0; d < dim; ++d) { vec.push_back(dis(e)); } // std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float)); value->add_values(vec.data(), vec.size() * sizeof(float)); } return raw_group; } inline auto CreatePlaceholderGroup(int64_t num_queries, int dim, const std::vector& vecs) { namespace ser = milvus::proto::common; ser::PlaceholderGroup raw_group; auto value = raw_group.add_placeholders(); value->set_tag("$0"); value->set_type(ser::PlaceholderType::FloatVector); for (int i = 0; i < num_queries; ++i) { std::vector vec; for (int d = 0; d < dim; ++d) { vec.push_back(vecs[i * dim + d]); } value->add_values(vec.data(), vec.size() * sizeof(float)); } return raw_group; } inline auto CreatePlaceholderGroupFromBlob(int64_t num_queries, int dim, const float* src) { namespace ser = milvus::proto::common; ser::PlaceholderGroup raw_group; auto value = raw_group.add_placeholders(); value->set_tag("$0"); value->set_type(ser::PlaceholderType::FloatVector); int64_t src_index = 0; for (int i = 0; i < num_queries; ++i) { std::vector vec; for (int d = 0; d < dim; ++d) { vec.push_back(src[src_index++]); } // std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float)); value->add_values(vec.data(), vec.size() * sizeof(float)); } return raw_group; } inline auto CreateBinaryPlaceholderGroup(int64_t num_queries, int64_t dim, int64_t seed = 42) { assert(dim % 8 == 0); namespace ser = milvus::proto::common; ser::PlaceholderGroup raw_group; auto value = raw_group.add_placeholders(); value->set_tag("$0"); value->set_type(ser::PlaceholderType::BinaryVector); std::default_random_engine e(seed); for (int i = 0; i < num_queries; ++i) { std::vector vec; for (int d = 0; d < dim / 8; ++d) { vec.push_back(e()); } // std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float)); value->add_values(vec.data(), vec.size()); } return raw_group; } inline auto CreateBinaryPlaceholderGroupFromBlob(int64_t num_queries, int64_t dim, const uint8_t* ptr) { assert(dim % 8 == 0); namespace ser = milvus::proto::common; ser::PlaceholderGroup raw_group; auto value = raw_group.add_placeholders(); value->set_tag("$0"); value->set_type(ser::PlaceholderType::BinaryVector); for (int i = 0; i < num_queries; ++i) { std::vector vec; for (int d = 0; d < dim / 8; ++d) { vec.push_back(*ptr); ++ptr; } // std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float)); value->add_values(vec.data(), vec.size()); } return raw_group; } inline auto SearchResultToVector(const SearchResult& sr) { int64_t num_queries = sr.total_nq_; int64_t topk = sr.unity_topK_; std::vector> result; for (int q = 0; q < num_queries; ++q) { for (int k = 0; k < topk; ++k) { int index = q * topk + k; result.emplace_back(std::make_pair(sr.seg_offsets_[index], sr.distances_[index])); } } return result; } inline json SearchResultToJson(const SearchResult& sr) { int64_t num_queries = sr.total_nq_; int64_t topk = sr.unity_topK_; std::vector> results; for (int q = 0; q < num_queries; ++q) { std::vector result; for (int k = 0; k < topk; ++k) { int index = q * topk + k; result.emplace_back(std::to_string(sr.seg_offsets_[index]) + "->" + std::to_string(sr.distances_[index])); } results.emplace_back(std::move(result)); } return json{results}; }; inline void SealedLoadFieldData(const GeneratedData& dataset, SegmentSealed& seg, const std::set& exclude_fields = {}) { auto row_count = dataset.row_ids_.size(); { LoadFieldDataInfo info; FieldMeta field_meta(FieldName("RowID"), RowFieldID, DataType::INT64); auto array = CreateScalarDataArrayFrom(dataset.row_ids_.data(), row_count, field_meta); info.field_data = array.release(); info.row_count = dataset.row_ids_.size(); info.field_id = RowFieldID.get(); // field id for RowId seg.LoadFieldData(info); } { LoadFieldDataInfo info; FieldMeta field_meta(FieldName("Timestamp"), TimestampFieldID, DataType::INT64); auto array = CreateScalarDataArrayFrom(dataset.timestamps_.data(), row_count, field_meta); info.field_data = array.release(); info.row_count = dataset.timestamps_.size(); info.field_id = TimestampFieldID.get(); seg.LoadFieldData(info); } for (auto field_data : dataset.raw_->fields_data()) { int64_t field_id = field_data.field_id(); if (exclude_fields.find(field_id) != exclude_fields.end()) { continue; } LoadFieldDataInfo info; info.field_id = field_data.field_id(); info.row_count = row_count; info.field_data = &field_data; seg.LoadFieldData(info); } } inline std::unique_ptr SealedCreator(SchemaPtr schema, const GeneratedData& dataset) { auto segment = CreateSealedSegment(schema); SealedLoadFieldData(dataset, *segment); return segment; } inline knowhere::VecIndexPtr GenVecIndexing(int64_t N, int64_t dim, const float* vec) { // {knowhere::IndexParams::nprobe, 10}, auto conf = knowhere::Config{{knowhere::meta::METRIC_TYPE, knowhere::metric::L2}, {knowhere::meta::DIM, dim}, {knowhere::indexparam::NLIST, 1024}, {knowhere::meta::DEVICE_ID, 0}}; auto database = knowhere::GenDataset(N, dim, vec); auto indexing = std::make_shared(); indexing->Train(database, conf); indexing->AddWithoutIds(database, conf); return indexing; } template inline scalar::IndexBasePtr GenScalarIndexing(int64_t N, const T* data) { if constexpr (std::is_same_v) { auto indexing = scalar::CreateStringIndexSort(); indexing->Build(N, data); return indexing; } else { auto indexing = scalar::CreateScalarIndexSort(); indexing->Build(N, data); return indexing; } } inline std::vector translate_text_plan_to_binary_plan(const char* text_plan) { proto::plan::PlanNode plan_node; auto ok = google::protobuf::TextFormat::ParseFromString(text_plan, &plan_node); AssertInfo(ok, "Failed to parse"); std::string binary_plan; plan_node.SerializeToString(&binary_plan); std::vector ret; ret.resize(binary_plan.size()); std::memcpy(ret.data(), binary_plan.c_str(), binary_plan.size()); return ret; } inline auto GenTss(int64_t num, int64_t begin_ts) { std::vector tss(num, 0); std::iota(tss.begin(), tss.end(), begin_ts); return tss; } inline auto GenPKs(int64_t num, int64_t begin_pk) { auto arr = std::make_unique(); for (int64_t i = 0; i < num; i++) { arr->add_data(begin_pk + i); } auto ids = std::make_shared(); ids->set_allocated_int_id(arr.release()); return ids; } template inline auto GenPKs(const Iter begin, const Iter end) { auto arr = std::make_unique(); for (auto it = begin; it != end; it++) { arr->add_data(*it); } auto ids = std::make_shared(); ids->set_allocated_int_id(arr.release()); return ids; } inline auto GenPKs(const std::vector& pks) { return GenPKs(pks.begin(), pks.end()); } } // namespace milvus::segcore