未验证 提交 110d4d06 编写于 作者: Y yukun 提交者: GitHub

TargetEntry implementation (#2391)

* Add GetEntitiesByID in DBImpl
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Add operator overload functions in ConcurrentBitset
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Add GetEntityByID interface
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Change format of Attributes
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Change hybrid search for new rules
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Fix C++ sdk for new format of hybrid interfaces
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Fix compile bugs
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Fix GetEntityByID
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Fix unittest bugs
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Fix WebControllerTest:test_hybrid bugs
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Fix field names bug in HybridSearch
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Fix HYBRID_SEARCH_TEST caused by const auto&
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Add ConvertRowToColumnJson in WebRequestHandler
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Return target entry in WebServer
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Change ValidateBinaryQuery
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* Add GetEntityByID in WebServer
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>

* [skip ci]Removed unused code in C++ sdk
Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>
上级 a147836b
......@@ -18,6 +18,7 @@
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "segment/Attrs.h"
......@@ -36,6 +37,10 @@ class AttrsFormat {
virtual void
read_uids(const storage::FSHandlerPtr& fs_ptr, std::vector<int64_t>& uids) = 0;
virtual void
read_attrs(const storage::FSHandlerPtr& fs_ptr, const std::string& field_name, off_t offset, size_t num_bytes,
std::vector<uint8_t>& raw_attrs) = 0;
};
using AttrsFormatPtr = std::shared_ptr<AttrsFormat>;
......
......@@ -185,6 +185,34 @@ DefaultAttrsFormat::write(const milvus::storage::FSHandlerPtr& fs_ptr, const mil
}
}
void
DefaultAttrsFormat::read_attrs(const milvus::storage::FSHandlerPtr& fs_ptr, const std::string& field_name, off_t offset,
size_t num_bytes, std::vector<uint8_t>& raw_attrs) {
const std::lock_guard<std::mutex> lock(mutex_);
std::string dir_path = fs_ptr->operation_ptr_->GetDirectory();
if (!boost::filesystem::is_directory(dir_path)) {
std::string err_msg = "Directory: " + dir_path + "does not exist";
LOG_ENGINE_ERROR_ << err_msg;
throw Exception(SERVER_INVALID_ARGUMENT, err_msg);
}
boost::filesystem::path target_path(dir_path);
typedef boost::filesystem::directory_iterator d_it;
d_it it_end;
d_it it(target_path);
for (; it != it_end; ++it) {
const auto& path = it->path();
std::string file_name = path.filename().string();
if (path.extension().string() == raw_attr_extension_ &&
file_name.substr(0, file_name.size() - 3) == field_name) {
size_t nbytes;
read_attrs_internal(fs_ptr, path.string(), offset, num_bytes, raw_attrs, nbytes);
}
}
}
void
DefaultAttrsFormat::read_uids(const milvus::storage::FSHandlerPtr& fs_ptr, std::vector<int64_t>& uids) {
const std::lock_guard<std::mutex> lock(mutex_);
......
......@@ -37,6 +37,10 @@ class DefaultAttrsFormat : public AttrsFormat {
void
write(const storage::FSHandlerPtr& fs_ptr, const segment::AttrsPtr& attr) override;
void
read_attrs(const storage::FSHandlerPtr& fs_ptr, const std::string& field_name, off_t offset, size_t num_bytes,
std::vector<uint8_t>& raw_attrs) override;
void
read_uids(const storage::FSHandlerPtr& fs_ptr, std::vector<int64_t>& uids) override;
......
......@@ -114,6 +114,10 @@ class DB {
GetVectorsByID(const std::string& collection_id, const IDNumbers& id_array,
std::vector<engine::VectorsData>& vectors) = 0;
virtual Status
GetEntitiesByID(const std::string& collection_id, const IDNumbers& id_array,
std::vector<engine::VectorsData>& vectors, std::vector<engine::AttrsData>& attrs) = 0;
virtual Status
GetVectorIDs(const std::string& collection_id, const std::string& segment_id, IDNumbers& vector_ids) = 0;
......@@ -165,9 +169,9 @@ class DB {
virtual Status
HybridQuery(const std::shared_ptr<server::Context>& context, const std::string& collection_id,
const std::vector<std::string>& partition_tags, context::HybridSearchContextPtr hybrid_search_context,
query::GeneralQueryPtr general_query,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type, uint64_t& nq,
engine::ResultIds& result_ids, engine::ResultDistances& result_distances) = 0;
query::GeneralQueryPtr general_query, std::vector<std::string>& field_name,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type,
engine::QueryResult& result) = 0;
}; // DB
using DBPtr = std::shared_ptr<DB>;
......
......@@ -1178,6 +1178,77 @@ DBImpl::GetVectorsByID(const std::string& collection_id, const IDNumbers& id_arr
return status;
}
Status
DBImpl::GetEntitiesByID(const std::string& collection_id, const milvus::engine::IDNumbers& id_array,
std::vector<engine::VectorsData>& vectors, std::vector<engine::AttrsData>& attrs) {
if (!initialized_.load(std::memory_order_acquire)) {
return SHUTDOWN_ERROR;
}
bool has_collection;
auto status = HasCollection(collection_id, has_collection);
if (!has_collection) {
LOG_ENGINE_ERROR_ << "Collection " << collection_id << " does not exist: ";
return Status(DB_NOT_FOUND, "Collection does not exist");
}
if (!status.ok()) {
return status;
}
engine::meta::CollectionSchema collection_schema;
engine::meta::hybrid::FieldsSchema fields_schema;
collection_schema.collection_id_ = collection_id;
status = meta_ptr_->DescribeHybridCollection(collection_schema, fields_schema);
if (!status.ok()) {
return status;
}
std::unordered_map<std::string, engine::meta::hybrid::DataType> attr_type;
for (auto schema : fields_schema.fields_schema_) {
if (schema.field_type_ == (int32_t)engine::meta::hybrid::DataType::VECTOR) {
continue;
}
attr_type.insert(std::make_pair(schema.field_name_, (engine::meta::hybrid::DataType)schema.field_type_));
}
meta::FilesHolder files_holder;
std::vector<int> file_types{meta::SegmentSchema::FILE_TYPE::RAW, meta::SegmentSchema::FILE_TYPE::TO_INDEX,
meta::SegmentSchema::FILE_TYPE::BACKUP};
status = meta_ptr_->FilesByType(collection_id, file_types, files_holder);
if (!status.ok()) {
std::string err_msg = "Failed to get files for GetEntitiesByID: " + status.message();
LOG_ENGINE_ERROR_ << err_msg;
return status;
}
std::vector<meta::CollectionSchema> partition_array;
status = meta_ptr_->ShowPartitions(collection_id, partition_array);
if (!status.ok()) {
std::string err_msg = "Failed to get partitions for GetEntitiesByID: " + status.message();
LOG_ENGINE_ERROR_ << err_msg;
return status;
}
for (auto& schema : partition_array) {
status = meta_ptr_->FilesByType(schema.collection_id_, file_types, files_holder);
if (!status.ok()) {
std::string err_msg = "Failed to get files for GetEntitiesByID: " + status.message();
LOG_ENGINE_ERROR_ << err_msg;
return status;
}
}
if (files_holder.HoldFiles().empty()) {
LOG_ENGINE_DEBUG_ << "No files to get vector by id from";
return Status(DB_NOT_FOUND, "Collection is empty");
}
cache::CpuCacheMgr::GetInstance()->PrintInfo();
status = GetEntitiesByIdHelper(collection_id, id_array, attr_type, vectors, attrs, files_holder);
cache::CpuCacheMgr::GetInstance()->PrintInfo();
return status;
}
Status
DBImpl::GetVectorIDs(const std::string& collection_id, const std::string& segment_id, IDNumbers& vector_ids) {
if (!initialized_.load(std::memory_order_acquire)) {
......@@ -1359,6 +1430,166 @@ DBImpl::GetVectorsByIdHelper(const std::string& collection_id, const IDNumbers&
return Status::OK();
}
Status
DBImpl::GetEntitiesByIdHelper(const std::string& collection_id, const milvus::engine::IDNumbers& id_array,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type,
std::vector<engine::VectorsData>& vectors, std::vector<engine::AttrsData>& attrs,
milvus::engine::meta::FilesHolder& files_holder) {
// attention: this is a copy, not a reference, since the files_holder.UnMarkFile will change the array internal
milvus::engine::meta::SegmentsSchema files = files_holder.HoldFiles();
LOG_ENGINE_DEBUG_ << "Getting vector by id in " << files.size() << " files, id count = " << id_array.size();
// sometimes not all of id_array can be found, we need to return empty vector for id not found
// for example:
// id_array = [1, -1, 2, -1, 3]
// vectors should return [valid_vector, empty_vector, valid_vector, empty_vector, valid_vector]
// the ID2RAW is to ensure returned vector sequence is consist with id_array
using ID2ATTR = std::map<int64_t, engine::AttrsData>;
using ID2VECTOR = std::map<int64_t, engine::VectorsData>;
ID2ATTR map_id2attr;
ID2VECTOR map_id2vector;
IDNumbers temp_ids = id_array;
for (auto& file : files) {
// Load bloom filter
std::string segment_dir;
engine::utils::GetParentPath(file.location_, segment_dir);
segment::SegmentReader segment_reader(segment_dir);
segment::IdBloomFilterPtr id_bloom_filter_ptr;
segment_reader.LoadBloomFilter(id_bloom_filter_ptr);
for (IDNumbers::iterator it = temp_ids.begin(); it != temp_ids.end();) {
int64_t vector_id = *it;
// each id must has a VectorsData
// if vector not found for an id, its VectorsData's vector_count = 0, else 1
AttrsData& attr_ref = map_id2attr[vector_id];
VectorsData& vector_ref = map_id2vector[vector_id];
// Check if the id is present in bloom filter.
if (id_bloom_filter_ptr->Check(vector_id)) {
// Load uids and check if the id is indeed present. If yes, find its offset.
std::vector<segment::doc_id_t> uids;
auto status = segment_reader.LoadUids(uids);
if (!status.ok()) {
return status;
}
auto found = std::find(uids.begin(), uids.end(), vector_id);
if (found != uids.end()) {
auto offset = std::distance(uids.begin(), found);
// Check whether the id has been deleted
segment::DeletedDocsPtr deleted_docs_ptr;
status = segment_reader.LoadDeletedDocs(deleted_docs_ptr);
if (!status.ok()) {
LOG_ENGINE_ERROR_ << status.message();
return status;
}
auto& deleted_docs = deleted_docs_ptr->GetDeletedDocs();
auto deleted = std::find(deleted_docs.begin(), deleted_docs.end(), offset);
if (deleted == deleted_docs.end()) {
// Load raw vector
bool is_binary = utils::IsBinaryMetricType(file.metric_type_);
size_t single_vector_bytes = is_binary ? file.dimension_ / 8 : file.dimension_ * sizeof(float);
std::vector<uint8_t> raw_vector;
status =
segment_reader.LoadVectors(offset * single_vector_bytes, single_vector_bytes, raw_vector);
if (!status.ok()) {
LOG_ENGINE_ERROR_ << status.message();
return status;
}
std::unordered_map<std::string, std::vector<uint8_t>> raw_attrs;
auto attr_it = attr_type.begin();
for (; attr_it != attr_type.end(); attr_it++) {
size_t num_bytes;
switch (attr_it->second) {
case engine::meta::hybrid::DataType::INT8: {
num_bytes = 1;
break;
}
case engine::meta::hybrid::DataType::INT16: {
num_bytes = 2;
break;
}
case engine::meta::hybrid::DataType::INT32: {
num_bytes = 4;
break;
}
case engine::meta::hybrid::DataType::INT64: {
num_bytes = 8;
break;
}
case engine::meta::hybrid::DataType::FLOAT: {
num_bytes = 4;
break;
}
case engine::meta::hybrid::DataType::DOUBLE: {
num_bytes = 8;
break;
}
default: {
std::string msg = "Field type of " + attr_it->first + " is wrong";
return Status{DB_ERROR, msg};
}
}
std::vector<uint8_t> raw_attr;
status = segment_reader.LoadAttrs(attr_it->first, offset * num_bytes, num_bytes, raw_attr);
if (!status.ok()) {
LOG_ENGINE_ERROR_ << status.message();
return status;
}
raw_attrs.insert(std::make_pair(attr_it->first, raw_attr));
}
vector_ref.vector_count_ = 1;
if (is_binary) {
vector_ref.binary_data_.swap(raw_vector);
} else {
std::vector<float> float_vector;
float_vector.resize(file.dimension_);
memcpy(float_vector.data(), raw_vector.data(), single_vector_bytes);
vector_ref.float_data_.swap(float_vector);
}
attr_ref.attr_count_ = 1;
attr_ref.attr_data_ = raw_attrs;
attr_ref.attr_type_ = attr_type;
temp_ids.erase(it);
continue;
}
}
}
it++;
}
// unmark file, allow the file to be deleted
files_holder.UnmarkFile(file);
}
for (auto id : id_array) {
VectorsData& vector_ref = map_id2vector[id];
VectorsData data;
data.vector_count_ = vector_ref.vector_count_;
if (data.vector_count_ > 0) {
data.float_data_ = vector_ref.float_data_; // copy data since there could be duplicated id
data.binary_data_ = vector_ref.binary_data_; // copy data since there could be duplicated id
}
vectors.emplace_back(data);
attrs.emplace_back(map_id2attr[id]);
}
if (vectors.empty()) {
std::string msg = "Vectors not found in collection " + collection_id;
LOG_ENGINE_DEBUG_ << msg;
}
return Status::OK();
}
Status
DBImpl::CreateIndex(const std::shared_ptr<server::Context>& context, const std::string& collection_id,
const CollectionIndex& index) {
......@@ -1554,8 +1785,9 @@ Status
DBImpl::HybridQuery(const std::shared_ptr<server::Context>& context, const std::string& collection_id,
const std::vector<std::string>& partition_tags,
context::HybridSearchContextPtr hybrid_search_context, query::GeneralQueryPtr general_query,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type, uint64_t& nq,
ResultIds& result_ids, ResultDistances& result_distances) {
std::vector<std::string>& field_names,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type,
engine::QueryResult& result) {
auto query_ctx = context->Child("Query");
if (!initialized_.load(std::memory_order_acquire)) {
......@@ -1605,8 +1837,8 @@ DBImpl::HybridQuery(const std::shared_ptr<server::Context>& context, const std::
}
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query
status = HybridQueryAsync(query_ctx, collection_id, files_holder, hybrid_search_context, general_query, attr_type,
nq, result_ids, result_distances);
status = HybridQueryAsync(query_ctx, collection_id, files_holder, hybrid_search_context, general_query, field_names,
attr_type, result);
if (!status.ok()) {
return status;
}
......@@ -1766,11 +1998,11 @@ DBImpl::QueryAsync(const std::shared_ptr<server::Context>& context, meta::FilesH
}
Status
DBImpl::HybridQueryAsync(const std::shared_ptr<server::Context>& context, const std::string& table_id,
DBImpl::HybridQueryAsync(const std::shared_ptr<server::Context>& context, const std::string& collection_id,
meta::FilesHolder& files_holder, context::HybridSearchContextPtr hybrid_search_context,
query::GeneralQueryPtr general_query,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type, uint64_t& nq,
ResultIds& result_ids, ResultDistances& result_distances) {
query::GeneralQueryPtr general_query, std::vector<std::string>& field_names,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type,
engine::QueryResult& result) {
auto query_async_ctx = context->Child("Query Async");
#if 0
......@@ -1789,10 +2021,8 @@ DBImpl::HybridQueryAsync(const std::shared_ptr<server::Context>& context, const
search::TaskInst::GetInstance().load_cv().notify_one();
hybrid_search_context->tasks_.emplace_back(task);
}
#endif
//#if 0
TimeRecorder rc("");
// step 1: construct search job
......@@ -1816,13 +2046,37 @@ DBImpl::HybridQueryAsync(const std::shared_ptr<server::Context>& context, const
}
// step 3: construct results
nq = job->vector_count();
result_ids = job->GetResultIds();
result_distances = job->GetResultDistances();
result.row_num_ = job->vector_count();
result.result_ids_ = job->GetResultIds();
result.result_distances_ = job->GetResultDistances();
// step 4: get entities by result ids
auto status = GetEntitiesByID(collection_id, result.result_ids_, result.vectors_, result.attrs_);
if (!status.ok()) {
query_async_ctx->GetTraceContext()->GetSpan()->Finish();
return status;
}
// step 5: filter entities by field names
std::vector<engine::AttrsData> filter_attrs;
for (auto attr : result.attrs_) {
AttrsData attrs_data;
attrs_data.attr_type_ = attr.attr_type_;
attrs_data.attr_count_ = attr.attr_count_;
attrs_data.id_array_ = attr.id_array_;
for (auto& name : field_names) {
if (attr.attr_data_.find(name) != attr.attr_data_.end()) {
attrs_data.attr_data_.insert(std::make_pair(name, attr.attr_data_.at(name)));
}
}
filter_attrs.emplace_back(attrs_data);
}
result.attrs_ = filter_attrs;
rc.ElapseFromBegin("Engine query totally cost");
query_async_ctx->GetTraceContext()->GetSpan()->Finish();
//#endif
return Status::OK();
}
......
......@@ -125,6 +125,10 @@ class DBImpl : public DB, public server::CacheConfigHandler, public server::Engi
GetVectorsByID(const std::string& collection_id, const IDNumbers& id_array,
std::vector<engine::VectorsData>& vectors) override;
Status
GetEntitiesByID(const std::string& collection_id, const IDNumbers& id_array,
std::vector<engine::VectorsData>& vectors, std::vector<engine::AttrsData>& attrs) override;
Status
GetVectorIDs(const std::string& collection_id, const std::string& segment_id, IDNumbers& vector_ids) override;
......@@ -157,9 +161,9 @@ class DBImpl : public DB, public server::CacheConfigHandler, public server::Engi
Status
HybridQuery(const std::shared_ptr<server::Context>& context, const std::string& collection_id,
const std::vector<std::string>& partition_tags, context::HybridSearchContextPtr hybrid_search_context,
query::GeneralQueryPtr general_query,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type, uint64_t& nq,
ResultIds& result_ids, ResultDistances& result_distances) override;
query::GeneralQueryPtr general_query, std::vector<std::string>& field_names,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type,
engine::QueryResult& result) override;
Status
QueryByIDs(const std::shared_ptr<server::Context>& context, const std::string& collection_id,
......@@ -193,16 +197,22 @@ class DBImpl : public DB, public server::CacheConfigHandler, public server::Engi
ResultDistances& result_distances);
Status
HybridQueryAsync(const std::shared_ptr<server::Context>& context, const std::string& table_id,
HybridQueryAsync(const std::shared_ptr<server::Context>& context, const std::string& collection_id,
meta::FilesHolder& files_holder, context::HybridSearchContextPtr hybrid_search_context,
query::GeneralQueryPtr general_query,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type, uint64_t& nq,
ResultIds& result_ids, ResultDistances& result_distances);
query::GeneralQueryPtr general_query, std::vector<std::string>& field_names,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type,
engine::QueryResult& result);
Status
GetVectorsByIdHelper(const std::string& collection_id, const IDNumbers& id_array,
std::vector<engine::VectorsData>& vectors, meta::FilesHolder& files_holder);
Status
GetEntitiesByIdHelper(const std::string& collection_id, const IDNumbers& id_array,
std::unordered_map<std::string, engine::meta::hybrid::DataType>& attr_type,
std::vector<engine::VectorsData>& vectors, std::vector<engine::AttrsData>& attrs,
meta::FilesHolder& files_holder);
void
InternalFlush(const std::string& collection_id = "");
......
......@@ -22,6 +22,7 @@
#include <vector>
#include "db/engine/ExecutionEngine.h"
#include "db/meta/MetaTypes.h"
#include "segment/Types.h"
#include "utils/Json.h"
......@@ -51,11 +52,25 @@ struct VectorsData {
struct Entity {
uint64_t entity_count_ = 0;
std::vector<uint8_t> attr_value_;
std::unordered_map<std::string, std::vector<std::string>> attr_data_;
std::unordered_map<std::string, VectorsData> vector_data_;
IDNumbers id_array_;
};
struct AttrsData {
uint64_t attr_count_ = 0;
std::unordered_map<std::string, engine::meta::hybrid::DataType> attr_type_;
std::unordered_map<std::string, std::vector<uint8_t>> attr_data_;
IDNumbers id_array_;
};
struct QueryResult {
uint64_t row_num_;
engine::ResultIds result_ids_;
engine::ResultDistances result_distances_;
std::vector<engine::VectorsData> vectors_;
std::vector<engine::AttrsData> attrs_;
};
using File2ErrArray = std::map<std::string, std::vector<std::string>>;
using Table2FileErr = std::map<std::string, File2ErrArray>;
......
......@@ -116,9 +116,13 @@ class ExecutionEngine {
GetVectorByID(const int64_t& id, uint8_t* vector, bool hybrid) = 0;
virtual Status
ExecBinaryQuery(query::GeneralQueryPtr general_query, faiss::ConcurrentBitsetPtr bitset,
std::unordered_map<std::string, DataType>& attr_type, uint64_t& nq, uint64_t& topk,
std::vector<float>& distances, std::vector<int64_t>& labels) = 0;
ExecBinaryQuery(query::GeneralQueryPtr general_query, faiss::ConcurrentBitsetPtr& bitset,
std::unordered_map<std::string, DataType>& attr_type,
milvus::query::VectorQueryPtr& vector_query) = 0;
virtual Status
HybridSearch(query::GeneralQueryPtr general_query, std::unordered_map<std::string, DataType>& attr_type,
uint64_t& nq, uint64_t& topk, std::vector<float>& distances, std::vector<int64_t>& search_ids) = 0;
virtual Status
Search(int64_t n, const float* data, int64_t k, const milvus::json& extra_params, float* distances, int64_t* labels,
......
......@@ -733,63 +733,52 @@ MapAndCopyResult(const knowhere::DatasetPtr& dataset, const std::vector<milvus::
template <typename T>
void
ProcessRangeQuery(std::vector<T> data, T value, query::CompareOperator type, faiss::ConcurrentBitsetPtr& bitset) {
ExecutionEngineImpl::ProcessRangeQuery(std::vector<T> data, T value, query::CompareOperator type,
faiss::ConcurrentBitsetPtr& bitset) {
switch (type) {
case query::CompareOperator::LT: {
for (uint64_t i = 0; i < data.size(); ++i) {
if (data[i] >= value) {
if (!bitset->test(i)) {
bitset->set(i);
}
if (data[i] < value) {
bitset->set(i);
}
}
break;
}
case query::CompareOperator::LTE: {
for (uint64_t i = 0; i < data.size(); ++i) {
if (data[i] > value) {
if (!bitset->test(i)) {
bitset->set(i);
}
if (data[i] <= value) {
bitset->set(i);
}
}
break;
}
case query::CompareOperator::GT: {
for (uint64_t i = 0; i < data.size(); ++i) {
if (data[i] <= value) {
if (!bitset->test(i)) {
bitset->set(i);
}
if (data[i] > value) {
bitset->set(i);
}
}
break;
}
case query::CompareOperator::GTE: {
for (uint64_t i = 0; i < data.size(); ++i) {
if (data[i] < value) {
if (!bitset->test(i)) {
bitset->set(i);
}
if (data[i] >= value) {
bitset->set(i);
}
}
break;
}
case query::CompareOperator::EQ: {
for (uint64_t i = 0; i < data.size(); ++i) {
if (data[i] != value) {
if (!bitset->test(i)) {
bitset->set(i);
}
if (data[i] == value) {
bitset->set(i);
}
}
}
case query::CompareOperator::NE: {
for (uint64_t i = 0; i < data.size(); ++i) {
if (data[i] == value) {
if (!bitset->test(i)) {
bitset->set(i);
}
if (data[i] != value) {
bitset->set(i);
}
}
break;
......@@ -798,23 +787,90 @@ ProcessRangeQuery(std::vector<T> data, T value, query::CompareOperator type, fai
}
Status
ExecutionEngineImpl::ExecBinaryQuery(milvus::query::GeneralQueryPtr general_query, faiss::ConcurrentBitsetPtr bitset,
std::unordered_map<std::string, DataType>& attr_type, uint64_t& nq, uint64_t& topk,
std::vector<float>& distances, std::vector<int64_t>& labels) {
if (bitset == nullptr) {
bitset = std::make_shared<faiss::ConcurrentBitset>(vector_count_);
ExecutionEngineImpl::HybridSearch(milvus::query::GeneralQueryPtr general_query,
std::unordered_map<std::string, DataType>& attr_type, uint64_t& nq, uint64_t& topk,
std::vector<float>& distances, std::vector<int64_t>& search_ids) {
faiss::ConcurrentBitsetPtr bitset;
milvus::query::VectorQueryPtr vector_query;
auto status = ExecBinaryQuery(general_query, bitset, attr_type, vector_query);
// Do search
faiss::ConcurrentBitsetPtr list;
list = index_->GetBlacklist();
// Do AND
for (uint64_t i = 0; i < vector_count_; ++i) {
if (list->test(i) && !bitset->test(i)) {
list->clear(i);
}
}
index_->SetBlacklist(list);
topk = vector_query->topk;
nq = vector_query->query_vector.float_data.size() / dim_;
distances.resize(nq * topk);
search_ids.resize(nq * topk);
status = Search(nq, vector_query->query_vector.float_data.data(), topk, vector_query->extra_params,
distances.data(), search_ids.data());
if (!status.ok()) {
return status;
}
return Status::OK();
}
Status
ExecutionEngineImpl::ExecBinaryQuery(milvus::query::GeneralQueryPtr general_query, faiss::ConcurrentBitsetPtr& bitset,
std::unordered_map<std::string, DataType>& attr_type,
milvus::query::VectorQueryPtr& vector_query) {
if (general_query->leaf == nullptr) {
Status status;
faiss::ConcurrentBitsetPtr left_bitset, right_bitset;
if (general_query->bin->left_query != nullptr) {
status = ExecBinaryQuery(general_query->bin->left_query, bitset, attr_type, nq, topk, distances, labels);
status = ExecBinaryQuery(general_query->bin->left_query, left_bitset, attr_type, vector_query);
if (!status.ok()) {
return status;
}
}
if (general_query->bin->right_query != nullptr) {
status = ExecBinaryQuery(general_query->bin->right_query, bitset, attr_type, nq, topk, distances, labels);
status = ExecBinaryQuery(general_query->bin->right_query, right_bitset, attr_type, vector_query);
if (!status.ok()) {
return status;
}
}
if (left_bitset == nullptr || right_bitset == nullptr) {
bitset = left_bitset != nullptr ? left_bitset : right_bitset;
} else {
switch (general_query->bin->relation) {
case milvus::query::QueryRelation::AND:
case milvus::query::QueryRelation::R1: {
bitset = (*left_bitset) & right_bitset;
break;
}
case milvus::query::QueryRelation::OR:
case milvus::query::QueryRelation::R2:
case milvus::query::QueryRelation::R3: {
bitset = (*left_bitset) | right_bitset;
break;
}
case milvus::query::QueryRelation::R4: {
for (uint64_t i = 0; i < vector_count_; ++i) {
if (left_bitset->test(i) && !right_bitset->test(i)) {
bitset->set(i);
}
}
break;
}
default: {
std::string msg = "Invalid QueryRelation in RangeQuery";
return Status{SERVER_INVALID_ARGUMENT, msg};
}
}
}
return status;
} else {
bitset = std::make_shared<faiss::ConcurrentBitset>(vector_count_);
if (general_query->leaf->term_query != nullptr) {
// process attrs_data
auto field_name = general_query->leaf->term_query->field_name;
......@@ -841,10 +897,8 @@ ExecutionEngineImpl::ExecBinaryQuery(milvus::query::GeneralQueryPtr general_quer
break;
}
}
if (!value_in_term) {
if (!bitset->test(i)) {
bitset->set(i);
}
if (value_in_term) {
bitset->set(i);
}
}
break;
......@@ -868,10 +922,8 @@ ExecutionEngineImpl::ExecBinaryQuery(milvus::query::GeneralQueryPtr general_quer
break;
}
}
if (!value_in_term) {
if (!bitset->test(i)) {
bitset->set(i);
}
if (value_in_term) {
bitset->set(i);
}
}
break;
......@@ -896,10 +948,8 @@ ExecutionEngineImpl::ExecBinaryQuery(milvus::query::GeneralQueryPtr general_quer
break;
}
}
if (!value_in_term) {
if (!bitset->test(i)) {
bitset->set(i);
}
if (value_in_term) {
bitset->set(i);
}
}
break;
......@@ -924,10 +974,8 @@ ExecutionEngineImpl::ExecBinaryQuery(milvus::query::GeneralQueryPtr general_quer
break;
}
}
if (!value_in_term) {
if (!bitset->test(i)) {
bitset->set(i);
}
if (value_in_term) {
bitset->set(i);
}
}
break;
......@@ -952,10 +1000,8 @@ ExecutionEngineImpl::ExecBinaryQuery(milvus::query::GeneralQueryPtr general_quer
break;
}
}
if (!value_in_term) {
if (!bitset->test(i)) {
bitset->set(i);
}
if (value_in_term) {
bitset->set(i);
}
}
break;
......@@ -980,10 +1026,8 @@ ExecutionEngineImpl::ExecBinaryQuery(milvus::query::GeneralQueryPtr general_quer
break;
}
}
if (!value_in_term) {
if (!bitset->test(i)) {
bitset->set(i);
}
if (value_in_term) {
bitset->set(i);
}
}
break;
......@@ -1056,25 +1100,10 @@ ExecutionEngineImpl::ExecBinaryQuery(milvus::query::GeneralQueryPtr general_quer
return Status::OK();
}
if (general_query->leaf->vector_query != nullptr) {
// Do search
faiss::ConcurrentBitsetPtr list;
list = index_->GetBlacklist();
// Do OR
for (uint64_t i = 0; i < vector_count_; ++i) {
if (list->test(i) || bitset->test(i)) {
bitset->set(i);
}
}
index_->SetBlacklist(bitset);
auto vector_query = general_query->leaf->vector_query;
topk = vector_query->topk;
nq = vector_query->query_vector.float_data.size() / dim_;
distances.resize(nq * topk);
labels.resize(nq * topk);
return Search(nq, vector_query->query_vector.float_data.data(), topk, vector_query->extra_params,
distances.data(), labels.data());
// skip vector query
vector_query = general_query->leaf->vector_query;
bitset = nullptr;
return Status::OK();
}
}
}
......
......@@ -70,9 +70,14 @@ class ExecutionEngineImpl : public ExecutionEngine {
GetVectorByID(const int64_t& id, uint8_t* vector, bool hybrid) override;
Status
ExecBinaryQuery(query::GeneralQueryPtr general_query, faiss::ConcurrentBitsetPtr bitset,
std::unordered_map<std::string, DataType>& attr_type, uint64_t& nq, uint64_t& topk,
std::vector<float>& distances, std::vector<int64_t>& labels) override;
ExecBinaryQuery(query::GeneralQueryPtr general_query, faiss::ConcurrentBitsetPtr& bitset,
std::unordered_map<std::string, DataType>& attr_type,
milvus::query::VectorQueryPtr& vector_query) override;
Status
HybridSearch(query::GeneralQueryPtr general_query, std::unordered_map<std::string, DataType>& attr_type,
uint64_t& nq, uint64_t& topk, std::vector<float>& distances,
std::vector<int64_t>& search_ids) override;
Status
Search(int64_t n, const float* data, int64_t k, const milvus::json& extra_params, float* distances, int64_t* labels,
......@@ -113,6 +118,10 @@ class ExecutionEngineImpl : public ExecutionEngine {
knowhere::VecIndexPtr
Load(const std::string& location);
template <typename T>
void
ProcessRangeQuery(std::vector<T> data, T value, query::CompareOperator type, faiss::ConcurrentBitsetPtr& bitset);
void
HybridLoad() const;
......@@ -124,10 +133,9 @@ class ExecutionEngineImpl : public ExecutionEngine {
EngineType index_type_;
MetricType metric_type_;
std::unordered_map<std::string, DataType> attr_types_;
std::unordered_map<std::string, std::vector<uint8_t>> attr_data_;
std::unordered_map<std::string, size_t> attr_size_;
query::BinaryQueryPtr binary_query_;
std::vector<int64_t> entity_ids_;
int64_t vector_count_;
int64_t dim_;
......
......@@ -1061,32 +1061,32 @@ void MilvusService::Stub::experimental_async::InsertEntity(::grpc::ClientContext
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::HEntityIDs>::Create(channel_.get(), cq, rpcmethod_InsertEntity_, context, request, false);
}
::grpc::Status MilvusService::Stub::HybridSearch(::grpc::ClientContext* context, const ::milvus::grpc::HSearchParam& request, ::milvus::grpc::TopKQueryResult* response) {
::grpc::Status MilvusService::Stub::HybridSearch(::grpc::ClientContext* context, const ::milvus::grpc::HSearchParam& request, ::milvus::grpc::HQueryResult* response) {
return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_HybridSearch_, context, request, response);
}
void MilvusService::Stub::experimental_async::HybridSearch(::grpc::ClientContext* context, const ::milvus::grpc::HSearchParam* request, ::milvus::grpc::TopKQueryResult* response, std::function<void(::grpc::Status)> f) {
void MilvusService::Stub::experimental_async::HybridSearch(::grpc::ClientContext* context, const ::milvus::grpc::HSearchParam* request, ::milvus::grpc::HQueryResult* response, std::function<void(::grpc::Status)> f) {
::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_HybridSearch_, context, request, response, std::move(f));
}
void MilvusService::Stub::experimental_async::HybridSearch(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResult* response, std::function<void(::grpc::Status)> f) {
void MilvusService::Stub::experimental_async::HybridSearch(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::HQueryResult* response, std::function<void(::grpc::Status)> f) {
::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_HybridSearch_, context, request, response, std::move(f));
}
void MilvusService::Stub::experimental_async::HybridSearch(::grpc::ClientContext* context, const ::milvus::grpc::HSearchParam* request, ::milvus::grpc::TopKQueryResult* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
void MilvusService::Stub::experimental_async::HybridSearch(::grpc::ClientContext* context, const ::milvus::grpc::HSearchParam* request, ::milvus::grpc::HQueryResult* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_HybridSearch_, context, request, response, reactor);
}
void MilvusService::Stub::experimental_async::HybridSearch(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResult* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
void MilvusService::Stub::experimental_async::HybridSearch(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::HQueryResult* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_HybridSearch_, context, request, response, reactor);
}
::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResult>* MilvusService::Stub::AsyncHybridSearchRaw(::grpc::ClientContext* context, const ::milvus::grpc::HSearchParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResult>::Create(channel_.get(), cq, rpcmethod_HybridSearch_, context, request, true);
::grpc::ClientAsyncResponseReader< ::milvus::grpc::HQueryResult>* MilvusService::Stub::AsyncHybridSearchRaw(::grpc::ClientContext* context, const ::milvus::grpc::HSearchParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::HQueryResult>::Create(channel_.get(), cq, rpcmethod_HybridSearch_, context, request, true);
}
::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResult>* MilvusService::Stub::PrepareAsyncHybridSearchRaw(::grpc::ClientContext* context, const ::milvus::grpc::HSearchParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResult>::Create(channel_.get(), cq, rpcmethod_HybridSearch_, context, request, false);
::grpc::ClientAsyncResponseReader< ::milvus::grpc::HQueryResult>* MilvusService::Stub::PrepareAsyncHybridSearchRaw(::grpc::ClientContext* context, const ::milvus::grpc::HSearchParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::HQueryResult>::Create(channel_.get(), cq, rpcmethod_HybridSearch_, context, request, false);
}
::grpc::Status MilvusService::Stub::HybridSearchInSegments(::grpc::ClientContext* context, const ::milvus::grpc::HSearchInSegmentsParam& request, ::milvus::grpc::TopKQueryResult* response) {
......@@ -1117,11 +1117,11 @@ void MilvusService::Stub::experimental_async::HybridSearchInSegments(::grpc::Cli
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResult>::Create(channel_.get(), cq, rpcmethod_HybridSearchInSegments_, context, request, false);
}
::grpc::Status MilvusService::Stub::GetEntityByID(::grpc::ClientContext* context, const ::milvus::grpc::HEntityIdentity& request, ::milvus::grpc::HEntity* response) {
::grpc::Status MilvusService::Stub::GetEntityByID(::grpc::ClientContext* context, const ::milvus::grpc::VectorsIdentity& request, ::milvus::grpc::HEntity* response) {
return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_GetEntityByID_, context, request, response);
}
void MilvusService::Stub::experimental_async::GetEntityByID(::grpc::ClientContext* context, const ::milvus::grpc::HEntityIdentity* request, ::milvus::grpc::HEntity* response, std::function<void(::grpc::Status)> f) {
void MilvusService::Stub::experimental_async::GetEntityByID(::grpc::ClientContext* context, const ::milvus::grpc::VectorsIdentity* request, ::milvus::grpc::HEntity* response, std::function<void(::grpc::Status)> f) {
::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_GetEntityByID_, context, request, response, std::move(f));
}
......@@ -1129,7 +1129,7 @@ void MilvusService::Stub::experimental_async::GetEntityByID(::grpc::ClientContex
::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_GetEntityByID_, context, request, response, std::move(f));
}
void MilvusService::Stub::experimental_async::GetEntityByID(::grpc::ClientContext* context, const ::milvus::grpc::HEntityIdentity* request, ::milvus::grpc::HEntity* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
void MilvusService::Stub::experimental_async::GetEntityByID(::grpc::ClientContext* context, const ::milvus::grpc::VectorsIdentity* request, ::milvus::grpc::HEntity* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_GetEntityByID_, context, request, response, reactor);
}
......@@ -1137,11 +1137,11 @@ void MilvusService::Stub::experimental_async::GetEntityByID(::grpc::ClientContex
::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_GetEntityByID_, context, request, response, reactor);
}
::grpc::ClientAsyncResponseReader< ::milvus::grpc::HEntity>* MilvusService::Stub::AsyncGetEntityByIDRaw(::grpc::ClientContext* context, const ::milvus::grpc::HEntityIdentity& request, ::grpc::CompletionQueue* cq) {
::grpc::ClientAsyncResponseReader< ::milvus::grpc::HEntity>* MilvusService::Stub::AsyncGetEntityByIDRaw(::grpc::ClientContext* context, const ::milvus::grpc::VectorsIdentity& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::HEntity>::Create(channel_.get(), cq, rpcmethod_GetEntityByID_, context, request, true);
}
::grpc::ClientAsyncResponseReader< ::milvus::grpc::HEntity>* MilvusService::Stub::PrepareAsyncGetEntityByIDRaw(::grpc::ClientContext* context, const ::milvus::grpc::HEntityIdentity& request, ::grpc::CompletionQueue* cq) {
::grpc::ClientAsyncResponseReader< ::milvus::grpc::HEntity>* MilvusService::Stub::PrepareAsyncGetEntityByIDRaw(::grpc::ClientContext* context, const ::milvus::grpc::VectorsIdentity& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::HEntity>::Create(channel_.get(), cq, rpcmethod_GetEntityByID_, context, request, false);
}
......@@ -1375,7 +1375,7 @@ MilvusService::Service::Service() {
AddMethod(new ::grpc::internal::RpcServiceMethod(
MilvusService_method_names[34],
::grpc::internal::RpcMethod::NORMAL_RPC,
new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::HSearchParam, ::milvus::grpc::TopKQueryResult>(
new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::HSearchParam, ::milvus::grpc::HQueryResult>(
std::mem_fn(&MilvusService::Service::HybridSearch), this)));
AddMethod(new ::grpc::internal::RpcServiceMethod(
MilvusService_method_names[35],
......@@ -1385,7 +1385,7 @@ MilvusService::Service::Service() {
AddMethod(new ::grpc::internal::RpcServiceMethod(
MilvusService_method_names[36],
::grpc::internal::RpcMethod::NORMAL_RPC,
new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::HEntityIdentity, ::milvus::grpc::HEntity>(
new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::VectorsIdentity, ::milvus::grpc::HEntity>(
std::mem_fn(&MilvusService::Service::GetEntityByID), this)));
AddMethod(new ::grpc::internal::RpcServiceMethod(
MilvusService_method_names[37],
......@@ -1640,7 +1640,7 @@ MilvusService::Service::~Service() {
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
::grpc::Status MilvusService::Service::HybridSearch(::grpc::ServerContext* context, const ::milvus::grpc::HSearchParam* request, ::milvus::grpc::TopKQueryResult* response) {
::grpc::Status MilvusService::Service::HybridSearch(::grpc::ServerContext* context, const ::milvus::grpc::HSearchParam* request, ::milvus::grpc::HQueryResult* response) {
(void) context;
(void) request;
(void) response;
......@@ -1654,7 +1654,7 @@ MilvusService::Service::~Service() {
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
::grpc::Status MilvusService::Service::GetEntityByID(::grpc::ServerContext* context, const ::milvus::grpc::HEntityIdentity* request, ::milvus::grpc::HEntity* response) {
::grpc::Status MilvusService::Service::GetEntityByID(::grpc::ServerContext* context, const ::milvus::grpc::VectorsIdentity* request, ::milvus::grpc::HEntity* response) {
(void) context;
(void) request;
(void) response;
......
......@@ -253,19 +253,17 @@ message FieldParam {
repeated KeyValuePair extra_params = 4;
}
message VectorFieldValue {
message VectorFieldRecord {
repeated RowRecord value = 1;
}
message FieldValue {
oneof value {
int32 int32_value = 1;
int64 int64_value = 2;
float float_value = 3;
double double_value = 4;
string string_value = 5;
bool bool_value = 6;
VectorFieldValue vector_value = 7;
int64 int64_value = 1;
double double_value = 2;
string string_value = 3;
bool bool_value = 4;
VectorFieldRecord vector_value = 5;
}
}
......@@ -287,10 +285,11 @@ message MappingList {
message TermQuery {
string field_name = 1;
bytes values = 2;
int64 value_num = 3;
float boost = 4;
repeated KeyValuePair extra_params = 5;
repeated int64 int_value = 2;
repeated double double_value = 3;
int64 value_num = 4;
float boost = 5;
repeated KeyValuePair extra_params = 6;
}
enum CompareOperator {
......@@ -358,37 +357,40 @@ message HSearchInSegmentsParam {
///////////////////////////////////////////////////////////////////
message AttrRecord {
repeated string value = 1;
repeated int64 int_value = 1;
repeated double double_value = 2;
}
message HEntity {
Status status = 1;
int64 entity_id = 2;
repeated int64 entity_id = 2;
repeated string field_names = 3;
bytes attr_records = 4;
repeated DataType data_types = 4;
int64 row_num = 5;
repeated FieldValue result_values = 6;
repeated AttrRecord attr_data = 6;
repeated VectorFieldRecord vector_data = 7;
}
message HQueryResult {
Status status = 1;
repeated HEntity entities = 2;
HEntity entity = 2;
int64 row_num = 3;
repeated float score = 4;
repeated float distance = 5;
repeated KeyValuePair extra_params = 6;
}
message HInsertParam {
string collection_name = 1;
string partition_tag = 2;
HEntity entities = 3;
HEntity entity = 3;
repeated int64 entity_id_array = 4;
repeated KeyValuePair extra_params = 5;
}
message HEntityIdentity {
string collection_name = 1;
int64 id = 2;
repeated int64 id = 2;
}
message HEntityIDs {
......@@ -672,12 +674,11 @@ service MilvusService {
rpc InsertEntity(HInsertParam) returns (HEntityIDs) {}
// TODO(yukun): will change to HQueryResult
rpc HybridSearch(HSearchParam) returns (TopKQueryResult) {}
rpc HybridSearch(HSearchParam) returns (HQueryResult) {}
rpc HybridSearchInSegments(HSearchInSegmentsParam) returns (TopKQueryResult) {}
rpc GetEntityByID(HEntityIdentity) returns (HEntity) {}
rpc GetEntityByID(VectorsIdentity) returns (HEntity) {}
rpc GetEntityIDs(HGetEntityIDsParam) returns (HEntityIDs) {}
......
......@@ -22,6 +22,156 @@ namespace faiss {
ConcurrentBitset::ConcurrentBitset(id_type_t capacity) : capacity_(capacity), bitset_((capacity + 8 - 1) >> 3) {
}
std::vector<std::atomic<uint8_t>>&
ConcurrentBitset::bitset() {
return bitset_;
}
ConcurrentBitset&
ConcurrentBitset::operator&=(ConcurrentBitset& bitset) {
// for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) {
// bitset_[i].fetch_and(bitset.bitset()[i].load());
// }
auto u8_1 = const_cast<uint8_t*>(data());
auto u8_2 = const_cast<uint8_t*>(bitset.data());
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
size_t n8 = bitset_.size();
size_t n64 = n8 / 8;
for (size_t i = 0; i < n64; i++) {
u64_1[i] &= u64_2[i];
}
size_t remain = n8 % 8;
u8_1 += n64 * 8;
u8_2 += n64 * 8;
for (size_t i = 0; i < remain; i++) {
u8_1[i] &= u8_2[i];
}
return *this;
}
std::shared_ptr<ConcurrentBitset>
ConcurrentBitset::operator&(std::shared_ptr<ConcurrentBitset>& bitset) {
auto result_bitset = std::make_shared<ConcurrentBitset>(bitset->capacity());
auto result_8 = const_cast<uint8_t*>(result_bitset->data());
auto result_64 = reinterpret_cast<uint64_t*>(result_8);
auto u8_1 = const_cast<uint8_t*>(data());
auto u8_2 = const_cast<uint8_t*>(bitset->data());
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
size_t n8 = bitset_.size();
size_t n64 = n8 / 8;
for (size_t i = 0; i < n64; i++) {
result_64[i] = u64_1[i] & u64_2[i];
}
size_t remain = n8 % 8;
u8_1 += n64 * 8;
u8_2 += n64 * 8;
result_8 += n64 * 8;
for (size_t i = 0; i < remain; i++) {
result_8[i] = u8_1[i] & u8_2[i];
}
return result_bitset;
}
ConcurrentBitset&
ConcurrentBitset::operator|=(ConcurrentBitset& bitset) {
// for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) {
// bitset_[i].fetch_or(bitset.bitset()[i].load());
// }
auto u8_1 = const_cast<uint8_t*>(data());
auto u8_2 = const_cast<uint8_t*>(bitset.data());
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
size_t n8 = bitset_.size();
size_t n64 = n8 / 8;
for (size_t i = 0; i < n64; i++) {
u64_1[i] &= u64_2[i];
}
size_t remain = n8 % 8;
u8_1 += n64 * 8;
u8_2 += n64 * 8;
for (size_t i = 0; i < remain; i++) {
u8_1[i] |= u8_2[i];
}
return *this;
}
std::shared_ptr<ConcurrentBitset>
ConcurrentBitset::operator|(std::shared_ptr<ConcurrentBitset>& bitset) {
auto result_bitset = std::make_shared<ConcurrentBitset>(bitset->capacity());
auto result_8 = const_cast<uint8_t*>(result_bitset->data());
auto result_64 = reinterpret_cast<uint64_t*>(result_8);
auto u8_1 = const_cast<uint8_t*>(data());
auto u8_2 = const_cast<uint8_t*>(bitset->data());
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
size_t n8 = bitset_.size();
size_t n64 = n8 / 8;
for (size_t i = 0; i < n64; i++) {
result_64[i] = u64_1[i] & u64_2[i];
}
size_t remain = n8 % 8;
u8_1 += n64 * 8;
u8_2 += n64 * 8;
result_8 += n64 * 8;
for (size_t i = 0; i < remain; i++) {
result_8[i] = u8_1[i] | u8_2[i];
}
return result_bitset;
}
ConcurrentBitset&
ConcurrentBitset::operator^=(ConcurrentBitset& bitset) {
// for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) {
// bitset_[i].fetch_xor(bitset.bitset()[i].load());
// }
auto u8_1 = const_cast<uint8_t*>(data());
auto u8_2 = const_cast<uint8_t*>(bitset.data());
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
size_t n8 = bitset_.size();
size_t n64 = n8 / 8;
for (size_t i = 0; i < n64; i++) {
u64_1[i] &= u64_2[i];
}
size_t remain = n8 % 8;
u8_1 += n64 * 8;
u8_2 += n64 * 8;
for (size_t i = 0; i < remain; i++) {
u8_1[i] ^= u8_2[i];
}
return *this;
}
bool
ConcurrentBitset::test(id_type_t id) {
return bitset_[id >> 3].load() & (0x1 << (id & 0x7));
......
......@@ -33,6 +33,24 @@ class ConcurrentBitset {
// ConcurrentBitset&
// operator=(const ConcurrentBitset&) = delete;
std::vector<std::atomic<uint8_t>>&
bitset();
ConcurrentBitset&
operator&=(ConcurrentBitset& bitset);
std::shared_ptr<ConcurrentBitset>
operator&(std::shared_ptr<ConcurrentBitset>& bitset);
ConcurrentBitset&
operator|=(ConcurrentBitset& bitset);
std::shared_ptr<ConcurrentBitset>
operator|(std::shared_ptr<ConcurrentBitset>& bitset);
ConcurrentBitset&
operator^=(ConcurrentBitset& bitset);
bool
test(id_type_t id);
......@@ -54,7 +72,6 @@ class ConcurrentBitset {
private:
size_t capacity_;
std::vector<std::atomic<uint8_t>> bitset_;
};
using ConcurrentBitsetPtr = std::shared_ptr<ConcurrentBitset>;
......
......@@ -64,7 +64,7 @@ Status
GenBinaryQuery(BooleanQueryPtr query, BinaryQueryPtr& binary_query) {
if (query->getBooleanQuerys().size() == 0) {
if (binary_query->relation == QueryRelation::AND || binary_query->relation == QueryRelation::OR) {
// Put VectorQuery to the end of leafqueries
// Put VectorQuery to the end of leaf queries
auto query_size = query->getLeafQueries().size();
for (uint64_t i = 0; i < query_size; ++i) {
if (query->getLeafQueries()[i]->vector_query != nullptr) {
......@@ -211,9 +211,8 @@ BinaryQueryHeight(BinaryQueryPtr& binary_query) {
bool
ValidateBinaryQuery(BinaryQueryPtr& binary_query) {
// Only for one layer BooleanQuery
uint64_t height = BinaryQueryHeight(binary_query);
return height > 1 && height < 4;
return height > 1;
}
} // namespace query
......
......@@ -267,8 +267,7 @@ XSearchTask::Execute() {
for (; type_it != attr_type.end(); type_it++) {
types.insert(std::make_pair(type_it->first, (engine::DataType)(type_it->second)));
}
faiss::ConcurrentBitsetPtr bitset;
s = index_engine_->ExecBinaryQuery(general_query, bitset, types, nq, topk, output_distance, output_ids);
s = index_engine_->HybridSearch(general_query, types, nq, topk, output_distance, output_ids);
if (!s.ok()) {
search_job->GetStatus() = s;
......
......@@ -73,6 +73,21 @@ SegmentReader::LoadVectors(off_t offset, size_t num_bytes, std::vector<uint8_t>&
return Status::OK();
}
Status
SegmentReader::LoadAttrs(const std::string& field_name, off_t offset, size_t num_bytes,
std::vector<uint8_t>& raw_attrs) {
codec::DefaultCodec default_codec;
try {
fs_ptr_->operation_ptr_->CreateDirectory();
default_codec.GetAttrsFormat()->read_attrs(fs_ptr_, field_name, offset, num_bytes, raw_attrs);
} catch (std::exception& e) {
std::string err_msg = "Failed to load raw attributes: " + std::string(e.what());
LOG_ENGINE_ERROR_ << err_msg;
return Status(DB_ERROR, err_msg);
}
return Status::OK();
}
Status
SegmentReader::LoadUids(std::vector<doc_id_t>& uids) {
codec::DefaultCodec default_codec;
......
......@@ -42,6 +42,9 @@ class SegmentReader {
Status
LoadVectors(off_t offset, size_t num_bytes, std::vector<uint8_t>& raw_vectors);
Status
LoadAttrs(const std::string& field_name, off_t offset, size_t num_bytes, std::vector<uint8_t>& raw_attrs);
Status
LoadUids(std::vector<doc_id_t>& uids);
......
......@@ -42,6 +42,7 @@
#include "server/delivery/hybrid_request/CreateHybridCollectionRequest.h"
#include "server/delivery/hybrid_request/DescribeHybridCollectionRequest.h"
#include "server/delivery/hybrid_request/GetEntityByIDRequest.h"
#include "server/delivery/hybrid_request/HybridSearchRequest.h"
#include "server/delivery/hybrid_request/InsertEntityRequest.h"
......@@ -304,13 +305,26 @@ RequestHandler::InsertEntity(const std::shared_ptr<Context>& context, const std:
return request_ptr->status();
}
Status
RequestHandler::GetEntityByID(const std::shared_ptr<Context>& context, const std::string& collection_name,
const std::vector<int64_t>& ids, std::vector<engine::AttrsData>& attrs,
std::vector<engine::VectorsData>& vectors) {
BaseRequestPtr request_ptr = GetEntityByIDRequest::Create(context, collection_name, ids, attrs, vectors);
RequestScheduler::ExecRequest(request_ptr);
return request_ptr->status();
}
Status
RequestHandler::HybridSearch(const std::shared_ptr<Context>& context,
context::HybridSearchContextPtr hybrid_search_context, const std::string& collection_name,
std::vector<std::string>& partition_list, milvus::query::GeneralQueryPtr& general_query,
TopKQueryResult& result) {
BaseRequestPtr request_ptr = HybridSearchRequest::Create(context, hybrid_search_context, collection_name,
partition_list, general_query, result);
milvus::json& json_params, std::vector<std::string>& field_names,
engine::QueryResult& result) {
BaseRequestPtr request_ptr =
HybridSearchRequest::Create(context, hybrid_search_context, collection_name, partition_list, general_query,
json_params, field_names, result);
RequestScheduler::ExecRequest(request_ptr);
return request_ptr->status();
......
......@@ -140,10 +140,16 @@ class RequestHandler {
const std::string& partition_tag, uint64_t& row_num, std::vector<std::string>& field_names,
std::vector<uint8_t>& attr_values, std::unordered_map<std::string, engine::VectorsData>& vector_datas);
Status
GetEntityByID(const std::shared_ptr<Context>& context, const std::string& collection_name,
const std::vector<int64_t>& ids, std::vector<engine::AttrsData>& attrs,
std::vector<engine::VectorsData>& vectors);
Status
HybridSearch(const std::shared_ptr<Context>& context, context::HybridSearchContextPtr hybrid_search_context,
const std::string& collection_name, std::vector<std::string>& partition_list,
query::GeneralQueryPtr& boolean_query, TopKQueryResult& result);
query::GeneralQueryPtr& general_query, milvus::json& json_params,
std::vector<std::string>& field_names, engine::QueryResult& result);
};
} // namespace server
......
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "server/delivery/hybrid_request/GetEntityByIDRequest.h"
#include "server/DBWrapper.h"
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
#include "utils/ValidationUtil.h"
#include <memory>
#include <vector>
namespace milvus {
namespace server {
constexpr uint64_t MAX_COUNT_RETURNED = 1000;
GetEntityByIDRequest::GetEntityByIDRequest(const std::shared_ptr<milvus::server::Context>& context,
const std::string& collection_name, const std::vector<int64_t>& ids,
std::vector<engine::AttrsData>& attrs,
std::vector<engine::VectorsData>& vectors)
: BaseRequest(context, BaseRequest::kGetVectorByID),
collection_name_(collection_name),
ids_(ids),
attrs_(attrs),
vectors_(vectors) {
}
BaseRequestPtr
GetEntityByIDRequest::Create(const std::shared_ptr<milvus::server::Context>& context,
const std::string& collection_name, const std::vector<int64_t>& ids,
std::vector<engine::AttrsData>& attrs, std::vector<engine::VectorsData>& vectors) {
return std::shared_ptr<BaseRequest>(new GetEntityByIDRequest(context, collection_name, ids, attrs, vectors));
}
Status
GetEntityByIDRequest::OnExecute() {
try {
std::string hdr = "GetEntitiesByIDRequest(collection=" + collection_name_ + ")";
TimeRecorderAuto rc(hdr);
// step 1: check arguments
if (ids_.empty()) {
return Status(SERVER_INVALID_ARGUMENT, "No entity id specified");
}
if (ids_.size() > MAX_COUNT_RETURNED) {
std::string msg = "Input id array size cannot exceed: " + std::to_string(MAX_COUNT_RETURNED);
return Status(SERVER_INVALID_ARGUMENT, msg);
}
auto status = ValidationUtil::ValidateCollectionName(collection_name_);
if (!status.ok()) {
return status;
}
// only process root collection, ignore partition collection
engine::meta::CollectionSchema collection_schema;
collection_schema.collection_id_ = collection_name_;
status = DBWrapper::DB()->DescribeCollection(collection_schema);
if (!status.ok()) {
if (status.code() == DB_NOT_FOUND) {
return Status(SERVER_COLLECTION_NOT_EXIST, CollectionNotExistMsg(collection_name_));
} else {
return status;
}
} else {
if (!collection_schema.owner_collection_.empty()) {
return Status(SERVER_INVALID_COLLECTION_NAME, CollectionNotExistMsg(collection_name_));
}
}
// step 2: get vector data, now only support get one id
return DBWrapper::DB()->GetEntitiesByID(collection_name_, ids_, vectors_, attrs_);
} catch (std::exception& ex) {
return Status(SERVER_UNEXPECTED_ERROR, ex.what());
}
return Status::OK();
}
} // namespace server
} // namespace milvus
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include "server/delivery/request/BaseRequest.h"
#include <memory>
#include <string>
#include <vector>
namespace milvus {
namespace server {
class GetEntityByIDRequest : public BaseRequest {
public:
static BaseRequestPtr
Create(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
const std::vector<int64_t>& ids, std::vector<engine::AttrsData>& attrs,
std::vector<engine::VectorsData>& vectors);
protected:
GetEntityByIDRequest(const std::shared_ptr<milvus::server::Context>& context, const std::string& collection_name,
const std::vector<int64_t>& ids, std::vector<engine::AttrsData>& attrs,
std::vector<engine::VectorsData>& vectors);
Status
OnExecute() override;
private:
std::string collection_name_;
std::vector<int64_t> ids_;
std::vector<engine::AttrsData>& attrs_;
std::vector<engine::VectorsData>& vectors_;
};
} // namespace server
} // namespace milvus
......@@ -33,12 +33,14 @@ namespace server {
HybridSearchRequest::HybridSearchRequest(const std::shared_ptr<milvus::server::Context>& context,
context::HybridSearchContextPtr& hybrid_search_context,
const std::string& collection_name, std::vector<std::string>& partition_list,
milvus::query::GeneralQueryPtr& general_query, TopKQueryResult& result)
milvus::query::GeneralQueryPtr& general_query, milvus::json& json_params,
std::vector<std::string>& field_names, engine::QueryResult& result)
: BaseRequest(context, BaseRequest::kHybridSearch),
hybrid_search_contxt_(hybrid_search_context),
hybrid_search_context_(hybrid_search_context),
collection_name_(collection_name),
partition_list_(partition_list),
general_query_(general_query),
field_names_(field_names),
result_(result) {
}
......@@ -46,9 +48,11 @@ BaseRequestPtr
HybridSearchRequest::Create(const std::shared_ptr<milvus::server::Context>& context,
context::HybridSearchContextPtr& hybrid_search_context, const std::string& collection_name,
std::vector<std::string>& partition_list, milvus::query::GeneralQueryPtr& general_query,
TopKQueryResult& result) {
milvus::json& json_params, std::vector<std::string>& field_names,
engine::QueryResult& result) {
return std::shared_ptr<BaseRequest>(new HybridSearchRequest(context, hybrid_search_context, collection_name,
partition_list, general_query, result));
partition_list, general_query, json_params, field_names,
result));
}
Status
......@@ -85,18 +89,25 @@ HybridSearchRequest::OnExecute() {
}
std::unordered_map<std::string, engine::meta::hybrid::DataType> attr_type;
for (uint64_t i = 0; i < fields_schema.fields_schema_.size(); ++i) {
for (auto& field_schema : fields_schema.fields_schema_) {
attr_type.insert(
std::make_pair(fields_schema.fields_schema_[i].field_name_,
(engine::meta::hybrid::DataType)fields_schema.fields_schema_[i].field_type_));
std::make_pair(field_schema.field_name_, (engine::meta::hybrid::DataType)field_schema.field_type_));
}
engine::ResultIds result_ids;
engine::ResultDistances result_distances;
uint64_t nq;
if (json_params.contains("field_names")) {
if (json_params["field_names"].is_array()) {
for (auto& name : json_params["field_names"]) {
field_names_.emplace_back(name.get<std::string>());
}
}
} else {
for (auto& field_schema : fields_schema.fields_schema_) {
field_names_.emplace_back(field_schema.field_name_);
}
}
status = DBWrapper::DB()->HybridQuery(context_, collection_name_, partition_list_, hybrid_search_contxt_,
general_query_, attr_type, nq, result_ids, result_distances);
status = DBWrapper::DB()->HybridQuery(context_, collection_name_, partition_list_, hybrid_search_context_,
general_query_, field_names_, attr_type, result_);
#ifdef ENABLE_CPU_PROFILING
ProfilerStop();
......@@ -106,18 +117,14 @@ HybridSearchRequest::OnExecute() {
if (!status.ok()) {
return status;
}
fiu_do_on("SearchRequest.OnExecute.empty_result_ids", result_ids.clear());
if (result_ids.empty()) {
fiu_do_on("SearchRequest.OnExecute.empty_result_ids", result_.result_ids_.clear());
if (result_.result_ids_.empty()) {
return Status::OK(); // empty table
}
auto post_query_ctx = context_->Child("Constructing result");
// step 7: construct result array
result_.row_num_ = nq;
result_.distance_list_ = result_distances;
result_.id_list_ = result_ids;
post_query_ctx->GetTraceContext()->GetSpan()->Finish();
// step 8: print time cost percent
......
......@@ -27,23 +27,25 @@ class HybridSearchRequest : public BaseRequest {
Create(const std::shared_ptr<milvus::server::Context>& context,
context::HybridSearchContextPtr& hybrid_search_context, const std::string& collection_name,
std::vector<std::string>& partition_list, milvus::query::GeneralQueryPtr& general_query,
TopKQueryResult& result);
milvus::json& json_params, std::vector<std::string>& field_names, engine::QueryResult& result);
protected:
HybridSearchRequest(const std::shared_ptr<milvus::server::Context>& context,
context::HybridSearchContextPtr& hybrid_search_context, const std::string& collection_name,
std::vector<std::string>& partition_list, milvus::query::GeneralQueryPtr& general_query,
TopKQueryResult& result);
milvus::json& json_params, std::vector<std::string>& field_names, engine::QueryResult& result);
Status
OnExecute() override;
private:
context::HybridSearchContextPtr hybrid_search_contxt_;
context::HybridSearchContextPtr hybrid_search_context_;
const std::string collection_name_;
std::vector<std::string>& partition_list_;
milvus::query::GeneralQueryPtr& general_query_;
TopKQueryResult& result_;
std::vector<std::string> partition_list_;
milvus::query::GeneralQueryPtr general_query_;
milvus::json json_params;
std::vector<std::string>& field_names_;
engine::QueryResult& result_;
};
} // namespace server
......
......@@ -40,6 +40,7 @@ RequestGroup(BaseRequest::RequestType type) {
{BaseRequest::kGetVectorByID, INFO_REQUEST_GROUP},
{BaseRequest::kGetVectorIDs, INFO_REQUEST_GROUP},
{BaseRequest::kInsertEntity, DDL_DML_REQUEST_GROUP},
{BaseRequest::kGetEntityByID, INFO_REQUEST_GROUP},
// collection operations
{BaseRequest::kShowCollections, INFO_REQUEST_GROUP},
......
......@@ -72,7 +72,9 @@ struct HybridQueryResult {
int64_t row_num_;
engine::ResultIds id_list_;
engine::ResultDistances distance_list_;
engine::Entity entities_;
std::vector<engine::VectorsData> vectors_;
std::vector<engine::AttrsData> attrs_;
};
struct IndexParam {
......@@ -118,6 +120,7 @@ class BaseRequest {
kGetVectorByID,
kGetVectorIDs,
kInsertEntity,
kGetEntityByID,
// collection operations
kShowCollections = 300,
......
......@@ -362,7 +362,11 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service,
::grpc::Status
HybridSearch(::grpc::ServerContext* context, const ::milvus::grpc::HSearchParam* request,
::milvus::grpc::TopKQueryResult* response) override;
::milvus::grpc::HQueryResult* response) override;
::grpc::Status
GetEntityByID(::grpc::ServerContext* context, const ::milvus::grpc::VectorsIdentity* request,
::milvus::grpc::HEntity* response) override;
//
// ::grpc::Status
// HybridSearchInSegments(::grpc::ServerContext* context,
......@@ -370,10 +374,7 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service,
// grpc::HSearchInSegmentsParam* request,
// ::milvus::grpc::HQueryResult* response) override;
//
// ::grpc::Status
// GetEntityByID(::grpc::ServerContext* context,
// const ::milvus::grpc::HEntityIdentity* request,
// ::milvus::grpc::HEntity* response) override;
//
// ::grpc::Status
// GetEntityIDs(::grpc::ServerContext* context,
......
......@@ -577,8 +577,8 @@ class WebController : public oatpp::web::server::api::ApiController {
*
* GetVectorByID ?id=
*/
ENDPOINT("GET", "/collections/{collection_name}/vectors", GetVectors,
PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) {
ENDPOINT("GET", "/collections/{collection_name}/vectors", GetVectors, PATH(String, collection_name),
QUERIES(const QueryParams&, query_params)) {
auto handler = WebRequestHandler();
String response;
auto status_dto = handler.GetVector(collection_name, query_params, response);
......@@ -653,6 +653,36 @@ class WebController : public oatpp::web::server::api::ApiController {
return response;
}
ADD_CORS(EntityOp)
ENDPOINT("PUT", "/hybrid_collections/{collection_name}/entities", EntityOp, PATH(String, collection_name),
BODY_STRING(String, body)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "PUT \'/hybrid_collections/" + collection_name->std_str() +
"/vectors\'");
tr.RecordSection("Received request.");
WebRequestHandler handler = WebRequestHandler();
OString result;
std::shared_ptr<OutgoingResponse> response;
auto status_dto = handler.VectorsOp(collection_name, body, result);
switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS:
response = createResponse(Status::CODE_200, result);
break;
case StatusCode::COLLECTION_NOT_EXISTS:
response = createDtoResponse(Status::CODE_404, status_dto);
break;
default:
response = createDtoResponse(Status::CODE_400, status_dto);
}
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
", reason = " + status_dto->message->std_str() + ". Total cost");
return response;
}
ADD_CORS(VectorsOp)
ENDPOINT("PUT", "/collections/{collection_name}/vectors", VectorsOp, PATH(String, collection_name),
......
......@@ -141,6 +141,9 @@ class WebRequestHandler {
Status
GetVectorsByIDs(const std::string& collection_name, const std::vector<int64_t>& ids, nlohmann::json& json_out);
Status
GetEntityByIDs(const std::string& collection_name, const std::vector<int64_t>& ids, nlohmann::json& json_out);
public:
WebRequestHandler() {
context_ptr_ = GenContextPtr("Web Handler");
......@@ -219,6 +222,9 @@ class WebRequestHandler {
StatusDto::ObjectWrapper
InsertEntity(const OString& collection_name, const OString& body, VectorIdsDto::ObjectWrapper& ids_dto);
StatusDto::ObjectWrapper
GetEntity(const OString& collection_name, const OQueryParams& query_params, OString& response);
StatusDto::ObjectWrapper
GetVector(const OString& collection_name, const OQueryParams& query_params, OString& response);
......
此差异已折叠。
......@@ -988,30 +988,35 @@ TEST_F(RpcHandlerTest, HYBRID_TEST) {
milvus::grpc::HEntityIDs entity_ids;
insert_param.set_collection_name("test_hybrid");
auto entity = insert_param.mutable_entities();
auto entity = insert_param.mutable_entity();
auto field_name_0 = entity->add_field_names();
*field_name_0 = "field_0";
auto field_name_1 = entity->add_field_names();
*field_name_1 = "field_1";
auto attr_record = entity->add_attr_data();
entity->set_row_num(row_num);
std::vector<int64_t> field_value(row_num, 0);
for (uint64_t i = 0; i < row_num; i++) {
field_value[i] = i;
}
entity->set_attr_records(field_value.data(), row_num * sizeof(int64_t));
attr_record->mutable_int_value()->Resize(static_cast<int>(row_num), 0);
memcpy(attr_record->mutable_int_value()->mutable_data(), field_value.data(), row_num * sizeof(int64_t));
std::vector<std::vector<float>> vector_field;
vector_field.resize(row_num);
std::default_random_engine e;
std::uniform_real_distribution<float> u(0, 1);
for (uint64_t i = 0; i < row_num; ++i) {
vector_field[i].resize(dimension);
for (uint64_t j = 0; j < dimension; ++j) {
vector_field[i][j] = (float)((i + 10) / (j + 20));
vector_field[i][j] = u(e);
}
}
auto vector_record = entity->add_result_values();
auto vector_record = entity->add_vector_data();
for (uint64_t i = 0; i < row_num; ++i) {
auto record = vector_record->mutable_vector_value()->add_value();
auto record = vector_record->add_value();
auto vector_data = record->mutable_float_data();
vector_data->Resize(static_cast<int>(vector_field[i].size()), 0.0);
memcpy(vector_data->mutable_data(), vector_field[i].data(), vector_field[i].size() * sizeof(float));
......@@ -1034,7 +1039,8 @@ TEST_F(RpcHandlerTest, HYBRID_TEST) {
term_value[i] = i + nq;
}
term_query->set_value_num(nq);
term_query->set_values(term_value.data(), nq * sizeof(int64_t));
term_query->mutable_int_value()->Resize(static_cast<int>(nq), 0);
memcpy(term_query->mutable_int_value()->mutable_data(), term_value.data(), nq * sizeof(int64_t));
auto vector_query = boolean_query_2->add_general_query()->mutable_vector_query();
vector_query->set_field_name("field_1");
......@@ -1045,7 +1051,7 @@ TEST_F(RpcHandlerTest, HYBRID_TEST) {
for (uint64_t i = 0; i < nq; ++i) {
query_vector[i].resize(dimension);
for (uint64_t j = 0; j < dimension; ++j) {
query_vector[i][j] = (float)((j + 1) / (i + dimension));
query_vector[i][j] = u(e);
}
}
for (auto record : query_vector) {
......@@ -1062,7 +1068,7 @@ TEST_F(RpcHandlerTest, HYBRID_TEST) {
search_extra_param->set_key("params");
search_extra_param->set_value("");
milvus::grpc::TopKQueryResult topk_query_result;
milvus::grpc::HQueryResult topk_query_result;
handler->HybridSearch(&context, &search_param, &topk_query_result);
}
......
......@@ -39,6 +39,9 @@ class ClientTest {
void
HybridSearch(std::string&);
void
GetHEntityByID(const std::string&, const std::vector<int64_t>&);
private:
std::shared_ptr<milvus::Connection> conn_;
std::vector<std::pair<int64_t, milvus::Entity>> search_entity_array_;
......
......@@ -16,6 +16,7 @@
#include <iostream>
#include <memory>
#include <random>
#include <utility>
#include <vector>
......@@ -136,11 +137,13 @@ Utils::BuildEntities(int64_t from, int64_t to, std::vector<milvus::Entity>& enti
entity_array.clear();
entity_ids.clear();
std::default_random_engine e;
std::uniform_real_distribution<float> u(0, 1);
for (int64_t k = from; k < to; k++) {
milvus::Entity entity;
entity.float_data.resize(dimension);
for (int64_t i = 0; i < dimension; i++) {
entity.float_data[i] = (float)((k + 100) % (i + 1));
entity.float_data[i] = (u(e));
}
entity_array.emplace_back(entity);
......@@ -231,10 +234,12 @@ Utils::DoSearch(std::shared_ptr<milvus::Connection> conn, const std::string& col
void ConstructVector(uint64_t nq, uint64_t dimension, std::vector<milvus::Entity>& query_vector) {
query_vector.resize(nq);
std::default_random_engine e;
std::uniform_real_distribution<float> u(0, 1);
for (uint64_t i = 0; i < nq; ++i) {
query_vector[i].float_data.resize(dimension);
for (uint64_t j = 0; j < dimension; ++j) {
query_vector[i].float_data[j] = (float)((i + 100) / (j + 1));
query_vector[i].float_data[j] = u(e);
}
}
}
......@@ -242,20 +247,18 @@ void ConstructVector(uint64_t nq, uint64_t dimension, std::vector<milvus::Entity
std::vector<milvus::LeafQueryPtr>
Utils::GenLeafQuery() {
//Construct TermQuery
uint64_t row_num = 1000;
uint64_t row_num = 10000;
std::vector<int64_t> field_value;
field_value.resize(row_num);
for (uint64_t i = 0; i < row_num; ++i) {
field_value[i] = i;
}
std::vector<int8_t> term_value(row_num * sizeof(int64_t));
memcpy(term_value.data(), field_value.data(), row_num * sizeof(int64_t));
milvus::TermQueryPtr tq = std::make_shared<milvus::TermQuery>();
tq->field_name = "field_1";
tq->field_value = term_value;
tq->int_value = field_value;
//Construct RangeQuery
milvus::CompareExpr ce1 = {milvus::CompareOperator::LTE, "10000"}, ce2 = {milvus::CompareOperator::GTE, "1"};
milvus::CompareExpr ce1 = {milvus::CompareOperator::LTE, "100000"}, ce2 = {milvus::CompareOperator::GTE, "1"};
std::vector<milvus::CompareExpr> ces{ce1, ce2};
milvus::RangeQueryPtr rq = std::make_shared<milvus::RangeQuery>();
rq->field_name = "field_2";
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -115,7 +115,10 @@ class GrpcClient {
InsertEntities(milvus::grpc::HInsertParam& entities, milvus::grpc::HEntityIDs& ids);
Status
HybridSearch(milvus::grpc::HSearchParam& search_param, milvus::grpc::TopKQueryResult& result);
HybridSearch(milvus::grpc::HSearchParam& search_param, milvus::grpc::HQueryResult& result);
Status
GetHEntityByID(milvus::grpc::VectorsIdentity& vectors_identity, milvus::grpc::HEntity& entity);
private:
std::unique_ptr<grpc::MilvusService::Stub> stub_;
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册