Unverified commit 27b4cbc0 authored by xige-16, committed by GitHub

Cherry pick remove translateHits commit to master (#16436)

Signed-off-by: xige-16 <xi.ge@zilliz.com>
Co-authored-by: bigsheeper <yihao.dai@zilliz.com>
Parent a37479d7
......@@ -117,7 +117,7 @@ datatype_is_floating(DataType datatype) {
class FieldMeta {
public:
static const FieldMeta RowIdMeta;
FieldMeta(const FieldMeta&) = delete;
FieldMeta(const FieldMeta&) = default;
FieldMeta(FieldMeta&&) = default;
FieldMeta&
operator=(const FieldMeta&) = delete;
......
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <limits>
#include <string>
#include <utility>
#include <vector>
#include <boost/align/aligned_allocator.hpp>
#include <boost/dynamic_bitset.hpp>
#include <NamedType/named_type.hpp>
#include "pb/schema.pb.h"
#include "utils/Types.h"
#include "FieldMeta.h"
namespace milvus {
struct SearchResult {
SearchResult() = default;
SearchResult(int64_t num_queries, int64_t topk) : topk_(topk), num_queries_(num_queries) {
auto count = get_row_count();
distances_.resize(count);
ids_.resize(count);
}
int64_t
get_row_count() const {
return topk_ * num_queries_;
}
// vector type
void
AddField(const FieldName& name,
const FieldId id,
DataType data_type,
int64_t dim,
std::optional<MetricType> metric_type) {
this->AddField(FieldMeta(name, id, data_type, dim, metric_type));
}
// scalar type
void
AddField(const FieldName& name, const FieldId id, DataType data_type) {
this->AddField(FieldMeta(name, id, data_type));
}
void
AddField(FieldMeta&& field_meta) {
output_fields_meta_.emplace_back(std::move(field_meta));
}
public:
int64_t num_queries_;
int64_t topk_;
std::vector<float> distances_;
std::vector<int64_t> ids_; // primary keys
public:
// TODO(gexi): utilize these fields
void* segment_;
std::vector<int64_t> result_offsets_;
std::vector<int64_t> primary_keys_;
aligned_vector<char> ids_data_;
std::vector<aligned_vector<char>> output_fields_data_;
std::vector<FieldMeta> output_fields_meta_;
};
using SearchResultPtr = std::shared_ptr<SearchResult>;
using SearchResultOpt = std::optional<SearchResult>;
struct RetrieveResult {
RetrieveResult() = default;
public:
void* segment_;
std::vector<int64_t> result_offsets_;
std::vector<DataArray> field_data_;
};
using RetrieveResultPtr = std::shared_ptr<RetrieveResult>;
using RetrieveResultOpt = std::optional<RetrieveResult>;
} // namespace milvus
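The new SearchResult above assumes a flat, row-major layout: hit r of query q lives at flat index q * topk_ + r (hence get_row_count()), and a fixed-width output field of element size S keeps that hit's bytes at offset (q * topk_ + r) * S inside its aligned_vector<char>. A standalone sketch of that index arithmetic (illustrative names only, not part of this commit):

#include <cstdint>
#include <cstring>
#include <vector>

int main() {
    int64_t nq = 2, topk = 3;
    std::vector<int64_t> ids(nq * topk);      // mirrors SearchResult::ids_
    std::vector<float> distances(nq * topk);  // mirrors SearchResult::distances_
    size_t ele_size = sizeof(int32_t);        // one INT32 output field
    std::vector<char> field(nq * topk * ele_size);  // mirrors one output_fields_data_ entry

    int64_t q = 1, r = 2;        // third hit of the second query
    int64_t loc = q * topk + r;  // flat index, cf. get_row_count()
    ids[loc] = 7;
    distances[loc] = 0.5f;
    int32_t v = 42;
    std::memcpy(field.data() + loc * ele_size, &v, ele_size);
    return 0;
}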
......@@ -61,49 +61,6 @@ constexpr std::false_type always_false{};
template <typename T>
using aligned_vector = std::vector<T, boost::alignment::aligned_allocator<T, 64>>;
///////////////////////////////////////////////////////////////////////////////////////////////////
struct SearchResult {
SearchResult() = default;
SearchResult(int64_t num_queries, int64_t topk) : topk_(topk), num_queries_(num_queries) {
auto count = get_row_count();
distances_.resize(count);
ids_.resize(count);
}
int64_t
get_row_count() const {
return topk_ * num_queries_;
}
public:
int64_t num_queries_;
int64_t topk_;
std::vector<float> distances_;
std::vector<int64_t> ids_;
public:
// TODO(gexi): utilize these fields
void* segment_;
std::vector<int64_t> result_offsets_;
std::vector<int64_t> primary_keys_;
std::vector<std::vector<char>> row_data_;
};
using SearchResultPtr = std::shared_ptr<SearchResult>;
using SearchResultOpt = std::optional<SearchResult>;
struct RetrieveResult {
RetrieveResult() = default;
public:
void* segment_;
std::vector<int64_t> result_offsets_;
std::vector<DataArray> field_data_;
};
using RetrieveResultPtr = std::shared_ptr<RetrieveResult>;
using RetrieveResultOpt = std::optional<RetrieveResult>;
namespace impl {
// hide identifier name to make auto-completion happy
struct FieldIdTag;
......
......@@ -15,8 +15,15 @@
#include <algorithm>
#include "utils/Status.h"
#include "common/type_c.h"
namespace milvus::segcore {
// SearchResultDataBlobs contains the marshal blobs of many `milvus::proto::schema::SearchResultData`
struct SearchResultDataBlobs {
std::vector<std::vector<char>> blobs;
};
Status
merge_into(int64_t num_queries,
int64_t topk,
......
......@@ -13,6 +13,7 @@
#include "common/Consts.h"
#include "common/Types.h"
#include "common/QueryResult.h"
#include "segcore/Reduce.h"
using milvus::SearchResult;
......
......@@ -44,53 +44,38 @@ SegmentInternalInterface::FillTargetEntry(const query::Plan* plan, SearchResult&
AssertInfo(plan, "empty plan");
auto size = results.distances_.size();
AssertInfo(results.ids_.size() == size, "Size of result distances is not equal to size of ids");
Assert(results.row_data_.size() == 0);
std::vector<int64_t> element_sizeofs;
std::vector<aligned_vector<char>> blobs;
// fill row_ids
{
aligned_vector<char> blob(size * sizeof(int64_t));
results.ids_data_.resize(size * sizeof(int64_t));
if (plan->schema_.get_is_auto_id()) {
bulk_subscript(SystemFieldType::RowId, results.ids_.data(), size, blob.data());
bulk_subscript(SystemFieldType::RowId, results.ids_.data(), size, results.ids_data_.data());
} else {
auto key_offset_opt = get_schema().get_primary_key_offset();
AssertInfo(key_offset_opt.has_value(), "Cannot get primary key offset from schema");
auto key_offset = key_offset_opt.value();
AssertInfo(get_schema()[key_offset].get_data_type() == DataType::INT64,
"Primary key field is not INT64 type");
bulk_subscript(key_offset, results.ids_.data(), size, blob.data());
bulk_subscript(key_offset, results.ids_.data(), size, results.ids_data_.data());
}
blobs.emplace_back(std::move(blob));
element_sizeofs.push_back(sizeof(int64_t));
}
// fill other entries except primary key
// fill other entries except primary key by result_offset
for (auto field_offset : plan->target_entries_) {
auto& field_meta = get_schema()[field_offset];
auto element_sizeof = field_meta.get_sizeof();
aligned_vector<char> blob(size * element_sizeof);
bulk_subscript(field_offset, results.ids_.data(), size, blob.data());
blobs.emplace_back(std::move(blob));
element_sizeofs.push_back(element_sizeof);
}
auto target_sizeof = std::accumulate(element_sizeofs.begin(), element_sizeofs.end(), 0);
for (int64_t i = 0; i < size; ++i) {
int64_t element_offset = 0;
std::vector<char> target(target_sizeof);
for (int loc = 0; loc < blobs.size(); ++loc) {
auto element_sizeof = element_sizeofs[loc];
auto blob_ptr = blobs[loc].data();
auto src = blob_ptr + element_sizeof * i;
auto dst = target.data() + element_offset;
memcpy(dst, src, element_sizeof);
element_offset += element_sizeof;
results.output_fields_data_.emplace_back(std::move(blob));
if (field_meta.is_vector()) {
results.AddField(field_meta.get_name(), field_meta.get_id(), field_meta.get_data_type(),
field_meta.get_dim(), field_meta.get_metric_type());
} else {
results.AddField(field_meta.get_name(), field_meta.get_id(), field_meta.get_data_type());
}
assert(element_offset == target_sizeof);
results.row_data_.emplace_back(std::move(target));
}
}
......@@ -162,7 +147,7 @@ CreateScalarArrayFrom(const void* data_raw, int64_t count, DataType data_type) {
return scalar_array;
}
static std::unique_ptr<DataArray>
std::unique_ptr<DataArray>
CreateDataArrayFrom(const void* data_raw, int64_t count, const FieldMeta& field_meta) {
auto data_type = field_meta.get_data_type();
auto data_array = std::make_unique<DataArray>();
......
......@@ -23,6 +23,7 @@
#include "common/Span.h"
#include "common/SystemProperty.h"
#include "common/Types.h"
#include "common/QueryResult.h"
#include "knowhere/index/vector_index/VecIndex.h"
#include "query/Plan.h"
#include "query/PlanNode.h"
......@@ -172,4 +173,10 @@ class SegmentInternalInterface : public SegmentInterface {
mutable std::shared_mutex mutex_;
};
static std::unique_ptr<ScalarArray>
CreateScalarArrayFrom(const void* data_raw, int64_t count, DataType data_type);
std::unique_ptr<DataArray>
CreateDataArrayFrom(const void* data_raw, int64_t count, const FieldMeta& field_meta);
} // namespace milvus::segcore
......@@ -13,49 +13,21 @@
#include <unordered_set>
#include <vector>
#include "Reduce.h"
#include "common/CGoHelper.h"
#include "common/Consts.h"
#include "common/Types.h"
#include "common/QueryResult.h"
#include "exceptions/EasyAssert.h"
#include "log/Log.h"
#include "pb/milvus.pb.h"
#include "query/Plan.h"
#include "segcore/Reduce.h"
#include "segcore/ReduceStructure.h"
#include "segcore/SegmentInterface.h"
#include "segcore/reduce_c.h"
using SearchResult = milvus::SearchResult;
int
MergeInto(int64_t num_queries, int64_t topk, float* distances, int64_t* uids, float* new_distances, int64_t* new_uids) {
auto status = milvus::segcore::merge_into(num_queries, topk, distances, uids, new_distances, new_uids);
return status.code();
}
struct MarshaledHitsPerGroup {
std::vector<std::string> hits_;
std::vector<int64_t> blob_length_;
};
struct MarshaledHits {
explicit MarshaledHits(int64_t num_group) {
marshaled_hits_.resize(num_group);
}
int
get_num_group() {
return marshaled_hits_.size();
}
std::vector<MarshaledHitsPerGroup> marshaled_hits_;
};
void
DeleteMarshaledHits(CMarshaledHits c_marshaled_hits) {
auto hits = (MarshaledHits*)c_marshaled_hits;
delete hits;
}
// void
// PrintSearchResult(char* buf, const milvus::SearchResult* result, int64_t seg_idx, int64_t from, int64_t to) {
// const int64_t MAXLEN = 32;
......@@ -154,6 +126,208 @@ ReduceResultData(std::vector<SearchResult*>& search_results, int64_t nq, int64_t
}
}
void
ReorganizeSearchResults(std::vector<SearchResult*>& search_results,
int32_t nq,
int32_t topK,
milvus::aligned_vector<int64_t>& result_ids,
std::vector<float>& result_distances,
std::vector<milvus::aligned_vector<char>>& result_output_fields_data) {
auto num_segments = search_results.size();
auto results_count = 0;
for (int i = 0; i < num_segments; i++) {
auto search_result = search_results[i];
AssertInfo(search_result != nullptr, "null search result when reorganize");
AssertInfo(search_result->output_fields_meta_.size() == result_output_fields_data.size(),
"illegal fields meta size"
", fields_meta_size = " +
std::to_string(search_result->output_fields_meta_.size()) +
", expected_size = " + std::to_string(result_output_fields_data.size()));
auto num_results = search_result->result_offsets_.size();
if (num_results == 0) {
continue;
}
#pragma omp parallel for
for (int j = 0; j < num_results; j++) {
auto loc = search_result->result_offsets_[j];
// AssertInfo(loc < nq * topK, "result location out of range, location = " +
// std::to_string(loc));
// set result ids
memcpy(&result_ids[loc], &search_result->ids_data_[j * sizeof(int64_t)], sizeof(int64_t));
// set result distances
result_distances[loc] = search_result->distances_[j];
// set result output fields data
for (int k = 0; k < search_result->output_fields_meta_.size(); k++) {
auto ele_size = search_result->output_fields_meta_[k].get_sizeof();
memcpy(&result_output_fields_data[k][loc * ele_size],
&search_result->output_fields_data_[k][j * ele_size], ele_size);
}
}
results_count += num_results;
}
AssertInfo(results_count == nq * topK,
"size of reduce result is not equal to nq * topK"
", result_count = " +
std::to_string(results_count) + ", nq * topK = " + std::to_string(nq * topK));
}
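The loop above is a scatter: each segment holds its surviving hits densely at local positions j, and result_offsets_[j] names the global slot each hit was assigned during reduce. A minimal standalone sketch of the same idea, with simplified stand-in types (not the diff's code):

#include <cstdint>
#include <vector>

// Simplified stand-in for one segment's reduced results.
struct SegmentHits {
    std::vector<int64_t> result_offsets;  // global slot per local hit
    std::vector<float> distances;         // dense local hits
};

// Scatter every segment's local hits into the nq * topK global buffer.
void Scatter(const std::vector<SegmentHits>& segments, std::vector<float>& global_distances) {
    for (const auto& seg : segments) {
        for (size_t j = 0; j < seg.result_offsets.size(); ++j) {
            global_distances[seg.result_offsets[j]] = seg.distances[j];
        }
    }
}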
std::vector<char>
GetSearchResultDataSlice(milvus::aligned_vector<int64_t>& result_ids,
std::vector<float>& result_distances,
std::vector<milvus::aligned_vector<char>>& result_output_fields_data,
int32_t nq,
int32_t topK,
int32_t nq_begin,
int32_t nq_end,
std::vector<milvus::FieldMeta>& output_fields_meta) {
auto search_result_data = std::make_unique<milvus::proto::schema::SearchResultData>();
// set topK and nq
search_result_data->set_top_k(topK);
search_result_data->set_num_queries(nq);
auto offset_begin = nq_begin * topK;
auto offset_end = nq_end * topK;
AssertInfo(offset_begin <= offset_end,
"illegal offsets when GetSearchResultDataSlice"
", offset_begin = " +
std::to_string(offset_begin) + ", offset_end = " + std::to_string(offset_end));
AssertInfo(offset_end <= topK * nq,
"illegal offset_end when GetSearchResultDataSlice"
", offset_end = " +
std::to_string(offset_end) + ", nq = " + std::to_string(nq) + ", topK = " + std::to_string(topK));
// set ids
auto proto_ids = std::make_unique<milvus::proto::schema::IDs>();
auto ids = std::make_unique<milvus::proto::schema::LongArray>();
*ids->mutable_data() = {result_ids.begin() + offset_begin, result_ids.begin() + offset_end};
proto_ids->set_allocated_int_id(ids.release());
search_result_data->set_allocated_ids(proto_ids.release());
AssertInfo(search_result_data->ids().int_id().data_size() == offset_end - offset_begin,
"wrong ids size"
", size = " +
std::to_string(search_result_data->ids().int_id().data_size()) +
", expected size = " + std::to_string(offset_end - offset_begin));
// set scores
*search_result_data->mutable_scores() = {result_distances.begin() + offset_begin,
result_distances.begin() + offset_end};
AssertInfo(search_result_data->scores_size() == offset_end - offset_begin,
"wrong scores size"
", size = " +
std::to_string(search_result_data->scores_size()) +
", expected size = " + std::to_string(offset_end - offset_begin));
// set output fields
for (int i = 0; i < result_output_fields_data.size(); i++) {
auto& field_meta = output_fields_meta[i];
auto field_size = field_meta.get_sizeof();
auto array = milvus::segcore::CreateDataArrayFrom(
result_output_fields_data[i].data() + offset_begin * field_size, offset_end - offset_begin, field_meta);
search_result_data->mutable_fields_data()->AddAllocated(array.release());
}
// SearchResultData to blob
auto size = search_result_data->ByteSize();
auto buffer = std::vector<char>(size);
search_result_data->SerializePartialToArray(buffer.data(), size);
return buffer;
}
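Each returned buffer is one serialized SearchResultData covering queries [nq_begin, nq_end). A minimal round-trip check, assuming the repo's generated pb/schema.pb.h and mirroring the assertions in the ReduceSearchWithExpr test further below (CheckSliceBlob and nq_in_slice are illustrative names):

#include <cassert>
#include <vector>
#include "pb/schema.pb.h"

// Parse one slice blob back and verify its shape for a slice of nq_in_slice queries.
void CheckSliceBlob(const std::vector<char>& blob, int nq_in_slice, int topK) {
    milvus::proto::schema::SearchResultData data;
    bool ok = data.ParseFromArray(blob.data(), static_cast<int>(blob.size()));
    assert(ok);
    assert(data.scores_size() == nq_in_slice * topK);
    assert(data.ids().int_id().data_size() == nq_in_slice * topK);
}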
CStatus
Marshal(CSearchResultDataBlobs* cSearchResultDataBlobs,
CSearchResult* c_search_results,
int32_t num_segments,
int32_t* nq_slice_sizes,
int32_t num_slices) {
try {
// parse search results and get topK, nq
std::vector<SearchResult*> search_results(num_segments);
for (int i = 0; i < num_segments; ++i) {
search_results[i] = static_cast<SearchResult*>(c_search_results[i]);
}
AssertInfo(search_results.size() > 0, "empty search result when Marshal");
auto topK = search_results[0]->topk_;
auto nq = search_results[0]->num_queries_;
// init result ids, distances
auto result_ids = milvus::aligned_vector<int64_t>(nq * topK);
auto result_distances = std::vector<float>(nq * topK);
// init result output fields data
auto& output_fields_meta = search_results[0]->output_fields_meta_;
auto num_output_fields = output_fields_meta.size();
auto result_output_fields_data = std::vector<milvus::aligned_vector<char>>(num_output_fields);
for (int i = 0; i < num_output_fields; i++) {
auto size = output_fields_meta[i].get_sizeof();
result_output_fields_data[i].resize(size * nq * topK);
}
// Reorganize search results, get result ids, distances and output fields data
ReorganizeSearchResults(search_results, nq, topK, result_ids, result_distances, result_output_fields_data);
// prefix sum, get slices offsets
AssertInfo(num_slices > 0, "empty nq_slice_sizes is not allowed");
auto slice_offsets_size = num_slices + 1;
auto slice_offsets = std::vector<int32_t>(slice_offsets_size);
slice_offsets[0] = 0;
slice_offsets[1] = nq_slice_sizes[0];
for (int i = 2; i < slice_offsets_size; i++) {
slice_offsets[i] = slice_offsets[i - 1] + nq_slice_sizes[i - 1];
}
AssertInfo(slice_offsets[num_slices] == nq,
"illegal req sizes"
", slice_offsets[last] = " +
std::to_string(slice_offsets[num_slices]) + ", nq = " + std::to_string(nq));
// get search result data blobs by slices
auto search_result_data_blobs = std::make_unique<milvus::segcore::SearchResultDataBlobs>();
search_result_data_blobs->blobs.resize(num_slices);
#pragma omp parallel for
for (int i = 0; i < num_slices; i++) {
auto proto = GetSearchResultDataSlice(result_ids, result_distances, result_output_fields_data, nq, topK,
slice_offsets[i], slice_offsets[i + 1], output_fields_meta);
search_result_data_blobs->blobs[i] = proto;
}
// set final result ptr
*cSearchResultDataBlobs = search_result_data_blobs.release();
return milvus::SuccessCStatus();
} catch (std::exception& e) {
// nothing to release here: *cSearchResultDataBlobs is only set on success
return milvus::FailureCStatus(UnexpectedError, e.what());
}
}
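The prefix sum in Marshal turns per-slice query counts into offsets: nq_slice_sizes = {5, 5} with nq = 10 yields slice_offsets = {0, 5, 10}, and slice i covers queries [slice_offsets[i], slice_offsets[i+1]). A standalone sketch of just that step (SliceOffsets is an illustrative name, not part of this commit):

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int32_t> SliceOffsets(const std::vector<int32_t>& slice_sizes) {
    std::vector<int32_t> offsets(slice_sizes.size() + 1, 0);
    for (size_t i = 0; i < slice_sizes.size(); ++i) {
        offsets[i + 1] = offsets[i] + slice_sizes[i];
    }
    return offsets;
}

int main() {
    assert(SliceOffsets({5, 5}) == (std::vector<int32_t>{0, 5, 10}));
    return 0;
}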
CStatus
GetSearchResultDataBlob(CProto* searchResultDataBlob,
CSearchResultDataBlobs cSearchResultDataBlobs,
int32_t blob_index) {
try {
auto search_result_data_blobs =
reinterpret_cast<milvus::segcore::SearchResultDataBlobs*>(cSearchResultDataBlobs);
AssertInfo(blob_index < search_result_data_blobs->blobs.size(), "blob_index out of range");
searchResultDataBlob->proto_blob = search_result_data_blobs->blobs[blob_index].data();
searchResultDataBlob->proto_size = search_result_data_blobs->blobs[blob_index].size();
return milvus::SuccessCStatus();
} catch (std::exception& e) {
searchResultDataBlob->proto_blob = nullptr;
searchResultDataBlob->proto_size = 0;
return milvus::FailureCStatus(UnexpectedError, e.what());
}
}
void
DeleteSearchResultDataBlobs(CSearchResultDataBlobs cSearchResultDataBlobs) {
if (cSearchResultDataBlobs == nullptr) {
return;
}
auto search_result_data_blobs = reinterpret_cast<milvus::segcore::SearchResultDataBlobs*>(cSearchResultDataBlobs);
delete search_result_data_blobs;
}
CStatus
ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_results, int64_t num_segments) {
try {
......@@ -190,121 +364,3 @@ ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_resul
return status;
}
}
CStatus
ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits, CSearchResult* c_search_results, int64_t num_segments) {
try {
auto marshaledHits = std::make_unique<MarshaledHits>(1);
auto sr = (SearchResult*)c_search_results[0];
auto topk = sr->topk_;
auto num_queries = sr->num_queries_;
std::vector<float> result_distances(num_queries * topk);
std::vector<std::vector<char>> row_datas(num_queries * topk);
std::vector<int64_t> counts(num_segments);
for (int i = 0; i < num_segments; i++) {
auto search_result = (SearchResult*)c_search_results[i];
AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
auto size = search_result->result_offsets_.size();
if (size == 0) {
continue;
}
#pragma omp parallel for
for (int j = 0; j < size; j++) {
auto loc = search_result->result_offsets_[j];
result_distances[loc] = search_result->distances_[j];
row_datas[loc] = search_result->row_data_[j];
}
counts[i] = size;
}
int64_t total_count = 0;
for (int i = 0; i < num_segments; i++) {
total_count += counts[i];
}
AssertInfo(total_count == num_queries * topk, "the reduces result's size less than total_num_queries*topk");
MarshaledHitsPerGroup& hits_per_group = (*marshaledHits).marshaled_hits_[0];
hits_per_group.hits_.resize(num_queries);
hits_per_group.blob_length_.resize(num_queries);
std::vector<milvus::proto::milvus::Hits> hits(num_queries);
#pragma omp parallel for
for (int m = 0; m < num_queries; m++) {
for (int n = 0; n < topk; n++) {
int64_t result_offset = m * topk + n;
hits[m].add_scores(result_distances[result_offset]);
auto& row_data = row_datas[result_offset];
hits[m].add_row_data(row_data.data(), row_data.size());
hits[m].add_ids(*(int64_t*)row_data.data());
}
}
#pragma omp parallel for
for (int j = 0; j < num_queries; j++) {
auto blob = hits[j].SerializeAsString();
hits_per_group.hits_[j] = blob;
hits_per_group.blob_length_[j] = blob.size();
}
auto status = CStatus();
status.error_code = Success;
status.error_msg = "";
auto marshaled_res = (CMarshaledHits)marshaledHits.release();
*c_marshaled_hits = marshaled_res;
return status;
} catch (std::exception& e) {
auto status = CStatus();
status.error_code = UnexpectedError;
status.error_msg = strdup(e.what());
*c_marshaled_hits = nullptr;
return status;
}
}
int64_t
GetHitsBlobSize(CMarshaledHits c_marshaled_hits) {
int64_t total_size = 0;
auto marshaled_hits = (MarshaledHits*)c_marshaled_hits;
auto num_group = marshaled_hits->get_num_group();
for (int i = 0; i < num_group; i++) {
auto& length_vector = marshaled_hits->marshaled_hits_[i].blob_length_;
for (int j = 0; j < length_vector.size(); j++) {
total_size += length_vector[j];
}
}
return total_size;
}
void
GetHitsBlob(CMarshaledHits c_marshaled_hits, const void* hits) {
auto byte_hits = (char*)hits;
auto marshaled_hits = (MarshaledHits*)c_marshaled_hits;
auto num_group = marshaled_hits->get_num_group();
int offset = 0;
for (int i = 0; i < num_group; i++) {
auto& hits = marshaled_hits->marshaled_hits_[i];
auto num_queries = hits.hits_.size();
for (int j = 0; j < num_queries; j++) {
auto blob_size = hits.blob_length_[j];
memcpy(byte_hits + offset, hits.hits_[j].data(), blob_size);
offset += blob_size;
}
}
}
int64_t
GetNumQueriesPerGroup(CMarshaledHits c_marshaled_hits, int64_t group_index) {
auto marshaled_hits = (MarshaledHits*)c_marshaled_hits;
auto& hits = marshaled_hits->marshaled_hits_[group_index].hits_;
return hits.size();
}
void
GetHitSizePerQueries(CMarshaledHits c_marshaled_hits, int64_t group_index, int64_t* hit_size_peer_query) {
auto marshaled_hits = (MarshaledHits*)c_marshaled_hits;
auto& blob_lens = marshaled_hits->marshaled_hits_[group_index].blob_length_;
for (int i = 0; i < blob_lens.size(); i++) {
hit_size_peer_query[i] = blob_lens[i];
}
}
......@@ -13,38 +13,29 @@
extern "C" {
#endif
#include <stdbool.h>
#include <stdint.h>
#include "common/type_c.h"
#include "segcore/plan_c.h"
#include "segcore/segment_c.h"
typedef void* CMarshaledHits;
void
DeleteMarshaledHits(CMarshaledHits c_marshaled_hits);
int
MergeInto(int64_t num_queries, int64_t topk, float* distances, int64_t* uids, float* new_distances, int64_t* new_uids);
typedef void* CSearchResultDataBlobs;
CStatus
ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* search_results, int64_t num_segments);
CStatus
ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits, CSearchResult* c_search_results, int64_t num_segments);
Marshal(CSearchResultDataBlobs* cSearchResultDataBlobs,
CSearchResult* c_search_results,
int32_t num_segments,
int32_t* nq_slice_sizes,
int32_t num_slices);
int64_t
GetHitsBlobSize(CMarshaledHits c_marshaled_hits);
void
GetHitsBlob(CMarshaledHits c_marshaled_hits, const void* hits);
int64_t
GetNumQueriesPerGroup(CMarshaledHits c_marshaled_hits, int64_t group_index);
CStatus
GetSearchResultDataBlob(CProto* searchResultDataBlob,
CSearchResultDataBlobs cSearchResultDataBlobs,
int32_t blob_index);
void
GetHitSizePerQueries(CMarshaledHits c_marshaled_hits, int64_t group_index, int64_t* hit_size_peer_query);
DeleteSearchResultDataBlobs(CSearchResultDataBlobs cSearchResultDataBlobs);
#ifdef __cplusplus
}
......
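Taken together, the new C API follows a marshal/get/delete lifecycle. A hedged usage sketch mirroring the CApiTest flow further below (ConsumeSearchResults is an illustrative name; Success and CProto come from common/type_c.h):

#include <vector>
#include "segcore/reduce_c.h"

// Reduce has already filled `results`; marshal them into two 5-query slices,
// read each serialized SearchResultData blob, then free the C-owned memory.
void ConsumeSearchResults(std::vector<CSearchResult>& results) {
    std::vector<int32_t> nq_slices = {5, 5};
    CSearchResultDataBlobs blobs = nullptr;
    auto status = Marshal(&blobs, results.data(),
                          static_cast<int32_t>(results.size()),
                          nq_slices.data(), static_cast<int32_t>(nq_slices.size()));
    if (status.error_code != Success) {
        return;
    }
    for (int32_t i = 0; i < 2; ++i) {
        CProto blob;
        if (GetSearchResultDataBlob(&blob, blobs, i).error_code == Success) {
            // blob.proto_blob / blob.proto_size point into C-owned memory;
            // copy the bytes out before DeleteSearchResultDataBlobs.
        }
    }
    DeleteSearchResultDataBlobs(blobs);
}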
......@@ -27,6 +27,7 @@
#include "query/ExprImpl.h"
#include "segcore/Collection.h"
#include "segcore/reduce_c.h"
#include "segcore/Reduce.h"
#include "test_utils/DataGen.h"
#include "utils/Types.h"
......@@ -684,35 +685,6 @@ TEST(CApiTest, GetRowCountTest) {
// DeleteSegment(segment);
//}
TEST(CApiTest, MergeInto) {
std::vector<int64_t> uids;
std::vector<float> distance;
std::vector<int64_t> new_uids;
std::vector<float> new_distance;
int64_t num_queries = 1;
int64_t topk = 2;
uids.push_back(1);
uids.push_back(2);
distance.push_back(5);
distance.push_back(1000);
new_uids.push_back(3);
new_uids.push_back(4);
new_distance.push_back(2);
new_distance.push_back(6);
auto res = MergeInto(num_queries, topk, distance.data(), uids.data(), new_distance.data(), new_uids.data());
ASSERT_EQ(res, 0);
ASSERT_EQ(uids[0], 3);
ASSERT_EQ(distance[0], 2);
ASSERT_EQ(uids[1], 1);
ASSERT_EQ(distance[1], 5);
}
void
CheckSearchResultDuplicate(const std::vector<CSearchResult>& results) {
auto sr = (SearchResult*)results[0];
......@@ -838,89 +810,6 @@ TEST(CApiTest, ReduceRemoveDuplicates) {
DeleteSegment(segment);
}
TEST(CApiTest, Reduce) {
auto collection = NewCollection(get_default_schema_config());
auto segment = NewSegment(collection, Growing, -1);
int N = 10000;
auto [raw_data, timestamps, uids] = generate_data(N);
auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
int64_t offset;
PreInsert(segment, N, &offset);
auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
assert(ins_res.error_code == Success);
const char* dsl_string = R"(
{
"bool": {
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10,
"round_decimal": 3
}
}
}
})";
int num_queries = 10;
auto blob = generate_query_data(num_queries);
void* plan = nullptr;
auto status = CreateSearchPlan(collection, dsl_string, &plan);
assert(status.error_code == Success);
void* placeholderGroup = nullptr;
status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup);
assert(status.error_code == Success);
std::vector<CPlaceholderGroup> placeholderGroups;
placeholderGroups.push_back(placeholderGroup);
timestamps.clear();
timestamps.push_back(1);
std::vector<CSearchResult> results;
CSearchResult res1;
CSearchResult res2;
auto res = Search(segment, plan, placeholderGroup, timestamps[0], &res1, -1);
assert(res.error_code == Success);
res = Search(segment, plan, placeholderGroup, timestamps[0], &res2, -1);
assert(res.error_code == Success);
results.push_back(res1);
results.push_back(res2);
status = ReduceSearchResultsAndFillData(plan, results.data(), results.size());
assert(status.error_code == Success);
void* reorganize_search_result = nullptr;
status = ReorganizeSearchResults(&reorganize_search_result, results.data(), results.size());
assert(status.error_code == Success);
auto hits_blob_size = GetHitsBlobSize(reorganize_search_result);
assert(hits_blob_size > 0);
std::vector<char> hits_blob;
hits_blob.resize(hits_blob_size);
GetHitsBlob(reorganize_search_result, hits_blob.data());
assert(hits_blob.data() != nullptr);
auto num_queries_group = GetNumQueriesPerGroup(reorganize_search_result, 0);
assert(num_queries_group == num_queries);
std::vector<int64_t> hit_size_per_query;
hit_size_per_query.resize(num_queries_group);
GetHitSizePerQueries(reorganize_search_result, 0, hit_size_per_query.data());
assert(hit_size_per_query[0] > 0);
DeleteSearchPlan(plan);
DeletePlaceholderGroup(placeholderGroup);
DeleteSearchResult(res1);
DeleteSearchResult(res2);
DeleteMarshaledHits(reorganize_search_result);
DeleteCollection(collection);
DeleteSegment(segment);
}
TEST(CApiTest, ReduceSearchWithExpr) {
auto collection = NewCollection(get_default_schema_config());
auto segment = NewSegment(collection, Growing, -1);
......@@ -941,9 +830,10 @@ TEST(CApiTest, ReduceSearchWithExpr) {
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
placeholder_tag: "$0">
output_field_ids: 100)";
int topK = 10;
int num_queries = 10;
auto blob = generate_query_data(num_queries);
......@@ -971,29 +861,34 @@ TEST(CApiTest, ReduceSearchWithExpr) {
results.push_back(res1);
results.push_back(res2);
// 1. reduce
status = ReduceSearchResultsAndFillData(plan, results.data(), results.size());
assert(status.error_code == Success);
void* reorganize_search_result = nullptr;
status = ReorganizeSearchResults(&reorganize_search_result, results.data(), results.size());
// 2. marshal
CSearchResultDataBlobs cSearchResultData;
auto req_sizes = std::vector<int32_t>{5, 5};
status = Marshal(&cSearchResultData, results.data(), results.size(), req_sizes.data(), req_sizes.size());
assert(status.error_code == Success);
auto hits_blob_size = GetHitsBlobSize(reorganize_search_result);
assert(hits_blob_size > 0);
std::vector<char> hits_blob;
hits_blob.resize(hits_blob_size);
GetHitsBlob(reorganize_search_result, hits_blob.data());
assert(hits_blob.data() != nullptr);
auto num_queries_group = GetNumQueriesPerGroup(reorganize_search_result, 0);
assert(num_queries_group == num_queries);
std::vector<int64_t> hit_size_per_query;
hit_size_per_query.resize(num_queries_group);
GetHitSizePerQueries(reorganize_search_result, 0, hit_size_per_query.data());
assert(hit_size_per_query[0] > 0);
auto search_result_data_blobs = reinterpret_cast<milvus::segcore::SearchResultDataBlobs*>(cSearchResultData);
// check result
for (int i = 0; i < req_sizes.size(); i++) {
milvus::proto::schema::SearchResultData search_result_data;
auto suc = search_result_data.ParseFromArray(search_result_data_blobs->blobs[i].data(),
search_result_data_blobs->blobs[i].size());
assert(suc);
assert(search_result_data.top_k() == topK);
assert(search_result_data.num_queries() == num_queries);
assert(search_result_data.scores().size() == topK * req_sizes[i]);
assert(search_result_data.ids().int_id().data_size() == topK * req_sizes[i]);
}
DeleteSearchResultDataBlobs(cSearchResultData);
DeleteSearchPlan(plan);
DeletePlaceholderGroup(placeholderGroup);
DeleteSearchResult(res1);
DeleteSearchResult(res2);
DeleteMarshaledHits(reorganize_search_result);
DeleteCollection(collection);
DeleteSegment(segment);
}
......
......@@ -637,9 +637,9 @@ TEST(Query, FillSegment) {
// dispatch here
int N = 100000;
auto dataset = DataGen(schema, N);
const auto std_vec = dataset.get_col<int64_t>(1);
const auto std_vfloat_vec = dataset.get_col<float>(0);
const auto std_i32_vec = dataset.get_col<int32_t>(2);
const auto std_vec = dataset.get_col<int64_t>(1); // ids field
const auto std_vfloat_vec = dataset.get_col<float>(0); // vector field
const auto std_i32_vec = dataset.get_col<int32_t>(2); // scalar field
std::vector<std::unique_ptr<SegmentInternalInterface>> segments;
segments.emplace_back([&] {
......@@ -701,16 +701,20 @@ TEST(Query, FillSegment) {
result->result_offsets_.resize(topk * num_queries);
segment->FillTargetEntry(plan.get(), *result);
auto ans = result->row_data_;
ASSERT_EQ(ans.size(), topk * num_queries);
int64_t std_index = 0;
auto fields_data = result->output_fields_data_;
auto fields_meta = result->output_fields_meta_;
ASSERT_EQ(fields_data.size(), 2);
ASSERT_EQ(fields_meta[0].get_sizeof(), sizeof(float) * dim);
ASSERT_EQ(fields_meta[1].get_sizeof(), sizeof(int32_t));
ASSERT_EQ(fields_data[0].size(), fields_meta[0].get_sizeof() * topk * num_queries);
ASSERT_EQ(fields_data[1].size(), fields_meta[1].get_sizeof() * topk * num_queries);
for (auto& vec : ans) {
ASSERT_EQ(vec.size(), sizeof(int64_t) + sizeof(float) * dim + sizeof(int32_t));
for (int i = 0; i < topk * num_queries; i++) {
int64_t val;
memcpy(&val, vec.data(), sizeof(int64_t));
memcpy(&val, &result->ids_data_[i * sizeof(int64_t)], sizeof(int64_t));
auto internal_offset = result->ids_[std_index];
auto internal_offset = result->ids_[i];
auto std_val = std_vec[internal_offset];
auto std_i32 = std_i32_vec[internal_offset];
std::vector<float> std_vfloat(dim);
......@@ -718,14 +722,16 @@ TEST(Query, FillSegment) {
ASSERT_EQ(val, std_val) << "io:" << internal_offset;
if (val != -1) {
// check vector field
std::vector<float> vfloat(dim);
memcpy(vfloat.data(), &fields_data[0][i * sizeof(float) * dim], dim * sizeof(float));
ASSERT_EQ(vfloat, std_vfloat);
// check int32 field
int i32;
memcpy(vfloat.data(), vec.data() + sizeof(int64_t), dim * sizeof(float));
memcpy(&i32, vec.data() + sizeof(int64_t) + dim * sizeof(float), sizeof(int32_t));
ASSERT_EQ(vfloat, std_vfloat) << std_index;
ASSERT_EQ(i32, std_i32) << std_index;
memcpy(&i32, &fields_data[1][i * sizeof(int32_t)], sizeof(int32_t));
ASSERT_EQ(i32, std_i32);
}
++std_index;
}
}
}
......
......@@ -34,6 +34,7 @@ import (
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/util/cgoconverter"
)
// HandleCStatus deals with the error returned from CGO
......@@ -54,3 +55,13 @@ func HandleCStatus(status *C.CStatus, extraInfo string) error {
log.Warn(logMsg)
return errors.New(finalMsg)
}
// CopyCProtoBlob copies the CProto's blob bytes into a newly allocated Go []byte.
func CopyCProtoBlob(cProto *C.CProto) []byte {
blob := C.GoBytes(unsafe.Pointer(cProto.proto_blob), C.int32_t(cProto.proto_size))
return blob
}

// GetCProtoBlob wraps the CProto's blob as a Go []byte without copying;
// the underlying C memory must remain valid while the slice is in use.
func GetCProtoBlob(cProto *C.CProto) []byte {
_, blob := cgoconverter.UnsafeGoBytes(&cProto.proto_blob, int(cProto.proto_size))
return blob
}
......@@ -1617,6 +1617,62 @@ func produceSimpleRetrieveMsg(ctx context.Context, queryChannel Channel) error {
return nil
}
func checkSearchResult(nq int64, plan *SearchPlan, searchResult *SearchResult) error {
searchResults := make([]*SearchResult, 0)
searchResults = append(searchResults, searchResult)
err := reduceSearchResultsAndFillData(plan, searchResults, 1)
if err != nil {
return err
}
nqOfReqs := []int64{nq / 5, nq / 5, nq / 5, nq / 5, nq / 5}
nqPerSlice := nq / 5
reqSlices, err := getReqSlices(nqOfReqs, nqPerSlice)
if err != nil {
return err
}
res, err := marshal(defaultCollectionID, UniqueID(0), searchResults, 1, reqSlices)
if err != nil {
return err
}
for i := 0; i < len(reqSlices); i++ {
blob, err := getSearchResultDataBlob(res, i)
if err != nil {
return err
}
if len(blob) == 0 {
return fmt.Errorf("wrong search result data blobs when checkSearchResult")
}
result := &schemapb.SearchResultData{}
err = proto.Unmarshal(blob, result)
if err != nil {
return err
}
if result.TopK != defaultTopK {
return fmt.Errorf("unexpected topK when checkSearchResult")
}
if result.NumQueries != nq {
return fmt.Errorf("unexpected nq when checkSearchResult")
}
if len(result.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data) != int(defaultTopK*nq/5) {
return fmt.Errorf("unexpected Ids when checkSearchResult")
}
if len(result.Scores) != int(defaultTopK*nq/5) {
return fmt.Errorf("unexpected Scores when checkSearchResult")
}
}
deleteSearchResults(searchResults)
deleteSearchResultDataBlobs(res)
return nil
}
func initConsumer(ctx context.Context, queryResultChannel Channel) (msgstream.MsgStream, error) {
stream, err := genQueryMsgStream(ctx)
if err != nil {
......
......@@ -1005,11 +1005,6 @@ func (q *queryCollection) search(msg queryMsg) error {
return err
}
schema, err := typeutil.CreateSchemaHelper(collection.schema)
if err != nil {
return err
}
var plan *SearchPlan
if searchMsg.GetDslType() == commonpb.DslType_BoolExprV1 {
expr := searchMsg.SerializedExprPlan
......@@ -1041,21 +1036,21 @@ func (q *queryCollection) search(msg queryMsg) error {
}
defer searchReq.delete()
queryNum := searchReq.getNumOfQuery()
nq := searchReq.getNumOfQuery()
searchRequests := make([]*searchRequest, 0)
searchRequests = append(searchRequests, searchReq)
if searchMsg.GetDslType() == commonpb.DslType_BoolExprV1 {
sp.LogFields(oplog.String("statistical time", "stats start"),
oplog.Object("nq", queryNum),
oplog.Object("nq", nq),
oplog.Object("expr", searchMsg.SerializedExprPlan))
} else {
sp.LogFields(oplog.String("statistical time", "stats start"),
oplog.Object("nq", queryNum),
oplog.Object("nq", nq),
oplog.Object("dsl", searchMsg.Dsl))
}
tr := timerecord.NewTimeRecorder(fmt.Sprintf("search %d(nq=%d, k=%d), msgID = %d", searchMsg.CollectionID, queryNum, topK, searchMsg.ID()))
tr := timerecord.NewTimeRecorder(fmt.Sprintf("search %d(nq=%d, k=%d), msgID = %d", searchMsg.CollectionID, nq, topK, searchMsg.ID()))
// get global sealed segments
var globalSealedSegments []UniqueID
......@@ -1106,7 +1101,7 @@ func (q *queryCollection) search(msg queryMsg) error {
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
ResultChannelID: searchMsg.ResultChannelID,
MetricType: plan.getMetricType(),
NumQueries: queryNum,
NumQueries: nq,
TopK: topK,
SlicedBlob: nil,
SlicedOffset: 1,
......@@ -1136,7 +1131,7 @@ func (q *queryCollection) search(msg queryMsg) error {
}
numSegment := int64(len(searchResults))
var marshaledHits *MarshaledHits
log.Debug("QueryNode reduce data", zap.Int64("msgID", searchMsg.ID()), zap.Int64("numSegment", numSegment))
tr.RecordSpan()
err = reduceSearchResultsAndFillData(plan, searchResults, numSegment)
......@@ -1146,49 +1141,22 @@ func (q *queryCollection) search(msg queryMsg) error {
log.Error("QueryNode reduce data failed", zap.Int64("msgID", searchMsg.ID()), zap.Error(err))
return err
}
marshaledHits, err = reorganizeSearchResults(searchResults, numSegment)
sp.LogFields(oplog.String("statistical time", "reorganizeSearchResults end"))
nqOfReqs := []int64{nq}
nqPerSlice := nq
reqSlices, err := getReqSlices(nqOfReqs, nqPerSlice)
if err != nil {
return err
}
defer deleteMarshaledHits(marshaledHits)
hitsBlob, err := marshaledHits.getHitsBlob()
sp.LogFields(oplog.String("statistical time", "getHitsBlob end"))
blobs, err := marshal(collectionID, searchMsg.ID(), searchResults, int(numSegment), reqSlices)
defer deleteSearchResultDataBlobs(blobs)
sp.LogFields(oplog.String("statistical time", "reorganizeSearchResults end"))
if err != nil {
return err
}
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.QueryNodeID), metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds()))
var offset int64
for index := range searchRequests {
hitBlobSizePeerQuery, err := marshaledHits.hitBlobSizeInGroup(int64(index))
if err != nil {
return err
}
hits := make([][]byte, len(hitBlobSizePeerQuery))
for i, len := range hitBlobSizePeerQuery {
hits[i] = hitsBlob[offset : offset+len]
//test code to checkout marshaled hits
//marshaledHit := hitsBlob[offset:offset+len]
//unMarshaledHit := milvuspb.Hits{}
//err = proto.Unmarshal(marshaledHit, &unMarshaledHit)
//if err != nil {
// return err
//}
//log.Debug("hits msg = ", unMarshaledHit)
offset += len
}
// TODO: remove inefficient code in cgo and use SearchResultData directly
// TODO: Currently add a translate layer from hits to SearchResultData
// TODO: hits marshal and unmarshal is likely bottleneck
transformed, err := translateHits(schema, searchMsg.OutputFieldsId, hits)
if err != nil {
return err
}
byteBlobs, err := proto.Marshal(transformed)
for i := 0; i < len(reqSlices); i++ {
blob, err := getSearchResultDataBlob(blobs, i)
if err != nil {
return err
}
......@@ -1206,9 +1174,9 @@ func (q *queryCollection) search(msg queryMsg) error {
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
ResultChannelID: searchMsg.ResultChannelID,
MetricType: plan.getMetricType(),
NumQueries: queryNum,
NumQueries: nq,
TopK: topK,
SlicedBlob: byteBlobs,
SlicedBlob: blob,
SlicedOffset: 1,
SlicedNumCount: 1,
SealedSegmentIDsSearched: sealedSegmentSearched,
......
......@@ -28,7 +28,11 @@ package querynode
import "C"
import (
"errors"
"unsafe"
"fmt"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
)
// SearchResult contains a pointer to the search result in C++ memory
......@@ -36,10 +40,8 @@ type SearchResult struct {
cSearchResult C.CSearchResult
}
// MarshaledHits contains a pointer to the marshaled hits in C++ memory
type MarshaledHits struct {
cMarshaledHits C.CMarshaledHits
}
// searchResultDataBlobs is the CSearchResultsDataBlobs in C++
type searchResultDataBlobs = C.CSearchResultDataBlobs
// RetrieveResult contains a pointer to the retrieve result in C++ memory
type RetrieveResult struct {
......@@ -65,47 +67,59 @@ func reduceSearchResultsAndFillData(plan *SearchPlan, searchResults []*SearchRes
return nil
}
func reorganizeSearchResults(searchResults []*SearchResult, numSegments int64) (*MarshaledHits, error) {
func marshal(collectionID UniqueID, msgID UniqueID, searchResults []*SearchResult, numSegments int, reqSlices []int32) (searchResultDataBlobs, error) {
log.Debug("start marshal...",
zap.Int64("collectionID", collectionID),
zap.Int64("msgID", msgID),
zap.Int32s("reqSlices", reqSlices))
cSearchResults := make([]C.CSearchResult, 0)
for _, res := range searchResults {
cSearchResults = append(cSearchResults, res.cSearchResult)
}
cSearchResultPtr := (*C.CSearchResult)(&cSearchResults[0])
var cNumSegments = C.int64_t(numSegments)
var cMarshaledHits C.CMarshaledHits
var cNumSegments = C.int32_t(numSegments)
var cSlicesPtr = (*C.int32_t)(&reqSlices[0])
var cNumSlices = C.int32_t(len(reqSlices))
var cSearchResultDataBlobs searchResultDataBlobs
status := C.ReorganizeSearchResults(&cMarshaledHits, cSearchResultPtr, cNumSegments)
status := C.Marshal(&cSearchResultDataBlobs, cSearchResultPtr, cNumSegments, cSlicesPtr, cNumSlices)
if err := HandleCStatus(&status, "ReorganizeSearchResults failed"); err != nil {
return nil, err
}
return &MarshaledHits{cMarshaledHits: cMarshaledHits}, nil
return cSearchResultDataBlobs, nil
}
func (mh *MarshaledHits) getHitsBlobSize() int64 {
res := C.GetHitsBlobSize(mh.cMarshaledHits)
return int64(res)
}
func getReqSlices(nqOfReqs []int64, nqPerSlice int64) ([]int32, error) {
if nqPerSlice == 0 {
return nil, fmt.Errorf("zero nqPerSlice is not allowed")
}
func (mh *MarshaledHits) getHitsBlob() ([]byte, error) {
byteSize := mh.getHitsBlobSize()
result := make([]byte, byteSize)
cResultPtr := unsafe.Pointer(&result[0])
C.GetHitsBlob(mh.cMarshaledHits, cResultPtr)
return result, nil
slices := make([]int32, 0)
for i := 0; i < len(nqOfReqs); i++ {
for j := 0; j < int(nqOfReqs[i]/nqPerSlice); j++ {
slices = append(slices, int32(nqPerSlice))
}
if tailSliceSize := nqOfReqs[i] % nqPerSlice; tailSliceSize > 0 {
slices = append(slices, int32(tailSliceSize))
}
}
return slices, nil
}
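For example, nqOfReqs = [10, 7] with nqPerSlice = 4 yields slices [4, 4, 2, 4, 3]: each request's nq is cut into nqPerSlice-sized chunks plus a tail. The same chunking in a standalone C++ sketch (an illustrative mirror of the Go helper above; ReqSlices is not a name from this commit):

#include <cstdint>
#include <vector>

std::vector<int32_t> ReqSlices(const std::vector<int64_t>& nq_of_reqs, int64_t nq_per_slice) {
    std::vector<int32_t> slices;
    for (int64_t nq : nq_of_reqs) {
        for (int64_t j = 0; j < nq / nq_per_slice; ++j) {
            slices.push_back(static_cast<int32_t>(nq_per_slice));
        }
        if (int64_t tail = nq % nq_per_slice) {
            slices.push_back(static_cast<int32_t>(tail));
        }
    }
    return slices;  // {10, 7} with nq_per_slice = 4 -> {4, 4, 2, 4, 3}
}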
func (mh *MarshaledHits) hitBlobSizeInGroup(groupOffset int64) ([]int64, error) {
cGroupOffset := (C.int64_t)(groupOffset)
numQueries := C.GetNumQueriesPerGroup(mh.cMarshaledHits, cGroupOffset)
result := make([]int64, int64(numQueries))
cResult := (*C.int64_t)(&result[0])
C.GetHitSizePerQueries(mh.cMarshaledHits, cGroupOffset, cResult)
return result, nil
func getSearchResultDataBlob(cSearchResultDataBlobs searchResultDataBlobs, blobIndex int) ([]byte, error) {
var blob C.CProto
status := C.GetSearchResultDataBlob(&blob, cSearchResultDataBlobs, C.int32_t(blobIndex))
if err := HandleCStatus(&status, "marshal failed"); err != nil {
return nil, err
}
return GetCProtoBlob(&blob), nil
}
func deleteMarshaledHits(hits *MarshaledHits) {
C.DeleteMarshaledHits(hits.cMarshaledHits)
func deleteSearchResultDataBlobs(cSearchResultDataBlobs searchResultDataBlobs) {
C.DeleteSearchResultDataBlobs(cSearchResultDataBlobs)
}
func deleteSearchResults(results []*SearchResult) {
......
......@@ -17,6 +17,7 @@
package querynode
import (
"context"
"log"
"math"
"testing"
......@@ -29,35 +30,36 @@ import (
)
func TestReduce_AllFunc(t *testing.T) {
collectionID := UniqueID(0)
segmentID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionID, false)
nq := int64(10)
collection := newCollection(collectionMeta.ID, collectionMeta.Schema)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
assert.Nil(t, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
const DIM = 16
var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
collection, err := node.historical.replica.getCollectionByID(defaultCollectionID)
assert.NoError(t, err)
// start search service
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\n \"topk\": 10 \n,\"round_decimal\": 6\n } \n } \n } \n }"
var searchRawData1 []byte
var searchRawData2 []byte
segment, err := node.historical.replica.getSegmentByID(defaultSegmentID)
assert.NoError(t, err)
// TODO: replace below by genPlaceholderGroup(nq)
vec := genSimpleFloatVectors()
var searchRawData []byte
for i, ele := range vec {
buf := make([]byte, 4)
common.Endian.PutUint32(buf, math.Float32bits(ele+float32(i*2)))
searchRawData1 = append(searchRawData1, buf...)
}
for i, ele := range vec {
buf := make([]byte, 4)
common.Endian.PutUint32(buf, math.Float32bits(ele+float32(i*4)))
searchRawData2 = append(searchRawData2, buf...)
searchRawData = append(searchRawData, buf...)
}
placeholderValue := milvuspb.PlaceholderValue{
Tag: "$0",
Type: milvuspb.PlaceholderType_FloatVector,
Values: [][]byte{searchRawData1, searchRawData2},
Values: [][]byte{},
}
for i := 0; i < int(nq); i++ {
placeholderValue.Values = append(placeholderValue.Values, searchRawData)
}
placeholderGroup := milvuspb.PlaceholderGroup{
......@@ -69,46 +71,24 @@ func TestReduce_AllFunc(t *testing.T) {
log.Print("marshal placeholderGroup failed")
}
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\n \"topk\": 10 \n,\"round_decimal\": 6\n } \n } \n } \n }"
plan, err := createSearchPlan(collection, dslString)
assert.NoError(t, err)
holder, err := parseSearchRequest(plan, placeGroupByte)
assert.NoError(t, err)
placeholderGroups := make([]*searchRequest, 0)
placeholderGroups = append(placeholderGroups, holder)
searchResults := make([]*SearchResult, 0)
searchResult, err := segment.search(plan, placeholderGroups, []Timestamp{0})
assert.Nil(t, err)
searchResults = append(searchResults, searchResult)
err = reduceSearchResultsAndFillData(plan, searchResults, 1)
assert.Nil(t, err)
marshaledHits, err := reorganizeSearchResults(searchResults, 1)
assert.NotNil(t, marshaledHits)
assert.Nil(t, err)
hitsBlob, err := marshaledHits.getHitsBlob()
assert.Nil(t, err)
var offset int64
for index := range placeholderGroups {
hitBolbSizePeerQuery, err := marshaledHits.hitBlobSizeInGroup(int64(index))
assert.Nil(t, err)
for _, len := range hitBolbSizePeerQuery {
marshaledHit := hitsBlob[offset : offset+len]
unMarshaledHit := milvuspb.Hits{}
err = proto.Unmarshal(marshaledHit, &unMarshaledHit)
assert.Nil(t, err)
log.Println("hits msg = ", unMarshaledHit)
offset += len
}
}
assert.NoError(t, err)
err = checkSearchResult(nq, plan, searchResult)
assert.NoError(t, err)
plan.delete()
holder.delete()
deleteSearchResults(searchResults)
deleteMarshaledHits(marshaledHits)
deleteSegment(segment)
deleteCollection(collection)
}
......
......@@ -474,104 +474,61 @@ func TestSegment_segmentDelete(t *testing.T) {
}
func TestSegment_segmentSearch(t *testing.T) {
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionID, false)
collection := newCollection(collectionMeta.ID, collectionMeta.Schema)
assert.Equal(t, collection.ID(), collectionID)
segmentID := UniqueID(0)
segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
assert.Equal(t, segmentID, segment.segmentID)
assert.Nil(t, err)
ids := []int64{1, 2, 3}
timestamps := []uint64{0, 0, 0}
const DIM = 16
const N = 3
var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
var rawData []byte
for _, ele := range vec {
buf := make([]byte, 4)
common.Endian.PutUint32(buf, math.Float32bits(ele))
rawData = append(rawData, buf...)
}
bs := make([]byte, 4)
common.Endian.PutUint32(bs, 1)
rawData = append(rawData, bs...)
var records []*commonpb.Blob
for i := 0; i < N; i++ {
blob := &commonpb.Blob{
Value: rawData,
}
records = append(records, blob)
}
nq := int64(10)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
offset, err := segment.segmentPreInsert(N)
assert.Nil(t, err)
assert.GreaterOrEqual(t, offset, int64(0))
collection, err := node.historical.replica.getCollectionByID(defaultCollectionID)
assert.NoError(t, err)
err = segment.segmentInsert(offset, &ids, &timestamps, &records)
segment, err := node.historical.replica.getSegmentByID(defaultSegmentID)
assert.NoError(t, err)
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\n \"topk\": 10 \n,\"round_decimal\": 6\n } \n } \n } \n }"
// TODO: replace below by genPlaceholderGroup(nq)
vec := genSimpleFloatVectors()
var searchRawData []byte
for _, ele := range vec {
for i, ele := range vec {
buf := make([]byte, 4)
common.Endian.PutUint32(buf, math.Float32bits(ele))
common.Endian.PutUint32(buf, math.Float32bits(ele+float32(i*2)))
searchRawData = append(searchRawData, buf...)
}
placeholderValue := milvuspb.PlaceholderValue{
Tag: "$0",
Type: milvuspb.PlaceholderType_FloatVector,
Values: [][]byte{searchRawData},
Values: [][]byte{},
}
for i := 0; i < int(nq); i++ {
placeholderValue.Values = append(placeholderValue.Values, searchRawData)
}
placeholderGroup := milvuspb.PlaceholderGroup{
Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue},
}
placeHolderGroupBlob, err := proto.Marshal(&placeholderGroup)
placeGroupByte, err := proto.Marshal(&placeholderGroup)
if err != nil {
log.Print("marshal placeholderGroup failed")
}
travelTimestamp := Timestamp(1020)
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\n \"topk\": 10 \n,\"round_decimal\": 6\n } \n } \n } \n }"
plan, err := createSearchPlan(collection, dslString)
assert.NoError(t, err)
holder, err := parseSearchRequest(plan, placeHolderGroupBlob)
holder, err := parseSearchRequest(plan, placeGroupByte)
assert.NoError(t, err)
placeholderGroups := make([]*searchRequest, 0)
placeholderGroups = append(placeholderGroups, holder)
searchResults := make([]*SearchResult, 0)
searchResult, err := segment.search(plan, placeholderGroups, []Timestamp{travelTimestamp})
assert.Nil(t, err)
searchResults = append(searchResults, searchResult)
///////////////////////////////////
numSegment := int64(len(searchResults))
err = reduceSearchResultsAndFillData(plan, searchResults, numSegment)
assert.NoError(t, err)
marshaledHits, err := reorganizeSearchResults(searchResults, numSegment)
assert.NoError(t, err)
hitsBlob, err := marshaledHits.getHitsBlob()
searchResult, err := segment.search(plan, placeholderGroups, []Timestamp{0})
assert.NoError(t, err)
var placeHolderOffset int64
for index := range placeholderGroups {
hitBlobSizePeerQuery, err := marshaledHits.hitBlobSizeInGroup(int64(index))
assert.NoError(t, err)
hits := make([][]byte, 0)
for _, len := range hitBlobSizePeerQuery {
hits = append(hits, hitsBlob[placeHolderOffset:placeHolderOffset+len])
placeHolderOffset += len
}
}
deleteSearchResults(searchResults)
deleteMarshaledHits(marshaledHits)
///////////////////////////////////
err = checkSearchResult(nq, plan, searchResult)
assert.NoError(t, err)
plan.delete()
holder.delete()
......