From c4e8d50446314532b5fcd9328ccebe3acc918190 Mon Sep 17 00:00:00 2001 From: fishpenguin Date: Fri, 3 Jul 2020 17:21:11 +0800 Subject: [PATCH] Count entities got wrong result with binary vectors Signed-off-by: fishpenguin --- CHANGELOG.md | 1 + core/src/db/DBImpl.cpp | 41 +++-- core/src/db/engine/ExecutionEngineImpl.cpp | 7 +- core/src/db/insert/MemManager.h | 17 ++ core/src/db/insert/MemManagerImpl.cpp | 27 +++ core/src/db/insert/MemManagerImpl.h | 6 + core/src/db/insert/VectorSource.cpp | 30 +++- core/src/db/wal/WalDefinations.h | 2 - core/src/db/wal/WalManager.cpp | 6 +- core/src/scheduler/task/SearchTask.cpp | 6 +- .../CreateHybridCollectionRequest.cpp | 3 +- .../server/grpc_impl/GrpcRequestHandler.cpp | 4 +- sdk/examples/binary_vector/src/ClientTest.cpp | 157 +++++++++--------- sdk/examples/simple/src/ClientTest.cpp | 61 ++----- sdk/examples/utils/Utils.cpp | 30 ++-- sdk/examples/utils/Utils.h | 6 +- sdk/grpc/ClientProxy.cpp | 5 +- sdk/grpc/ClientProxy.h | 2 +- sdk/include/MilvusApi.h | 2 +- sdk/interface/ConnectionImpl.cpp | 4 +- sdk/interface/ConnectionImpl.h | 2 +- 21 files changed, 233 insertions(+), 186 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 87d6d4b7..296ab33b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ Please mark all changes in change log and use the issue from GitHub - \#2695 The number of fields should be limited - \#2696 Check the validity of the parameters of creating collection: segment_size - \#2697 Index can not be created +- \#2698 Count entities got wrong result with binary vectors ## Feature - \#2319 Redo metadata to support MVCC diff --git a/core/src/db/DBImpl.cpp b/core/src/db/DBImpl.cpp index 14b1f6af..9e9748fd 100644 --- a/core/src/db/DBImpl.cpp +++ b/core/src/db/DBImpl.cpp @@ -974,24 +974,23 @@ DBImpl::InsertEntities(const std::string& collection_id, const std::string& part record.partition_tag = partition_tag; record.ids = entity.id_array_.data(); record.length = entity.entity_count_; + 
record.attr_data = attr_data; + record.attr_nbytes = attr_nbytes; + record.attr_data_size = attr_data_size; auto vector_it = entity.vector_data_.begin(); if (vector_it->second.binary_data_.empty()) { - record.type = wal::MXLogType::Entity; + record.type = wal::MXLogType::InsertVector; record.data = vector_it->second.float_data_.data(); record.data_size = vector_it->second.float_data_.size() * sizeof(float); - record.attr_data = attr_data; - record.attr_nbytes = attr_nbytes; - record.attr_data_size = attr_data_size; } else { - // record.type = wal::MXLogType::InsertBinary; - // record.data = entities.vector_data_[0].binary_data_.data(); - // record.length = entities.vector_data_[0].binary_data_.size() * sizeof(uint8_t); + record.type = wal::MXLogType::InsertBinary; + record.data = vector_it->second.binary_data_.data(); + record.data_size = vector_it->second.binary_data_.size() * sizeof(uint8_t); } status = ExecWalRecord(record); } - return status; } @@ -3111,9 +3110,16 @@ DBImpl::ExecWalRecord(const wal::MXLogRecord& record) { return status; } - status = mem_mgr_->InsertVectors(target_collection_name, record.length, record.ids, - (record.data_size / record.length / sizeof(uint8_t)), - (const u_int8_t*)record.data, record.lsn); + Vectors vectors; + vectors.vector_type_ = Vectors::BINARY; + vectors.binary_vector = (const uint8_t*)record.data; + status = mem_mgr_->InsertEntities(target_collection_name, record.length, record.ids, + (record.data_size / record.length / sizeof(uint8_t)), vectors, + record.attr_nbytes, record.attr_data_size, record.attr_data, record.lsn); + + // status = mem_mgr_->InsertVectors(target_collection_name, record.length, record.ids, + // (record.data_size / record.length / sizeof(uint8_t)), + // (const u_int8_t*)record.data, record.lsn); force_flush_if_mem_full(); // metrics @@ -3129,9 +3135,16 @@ DBImpl::ExecWalRecord(const wal::MXLogRecord& record) { return status; } - status = mem_mgr_->InsertVectors(target_collection_name, record.length, 
record.ids, - (record.data_size / record.length / sizeof(float)), - (const float*)record.data, record.lsn); + Vectors vectors; + vectors.vector_type_ = Vectors::FLOAT; + vectors.float_vector = (const float*)record.data; + status = mem_mgr_->InsertEntities(target_collection_name, record.length, record.ids, + (record.data_size / record.length / sizeof(float)), vectors, /* dim = float elements per vector; dividing by sizeof(uint8_t) gave a byte count and a 4x over-read downstream */ + record.attr_nbytes, record.attr_data_size, record.attr_data, record.lsn); + + // status = mem_mgr_->InsertVectors(target_collection_name, record.length, record.ids, + // (record.data_size / record.length / sizeof(float)), + // (const float*)record.data, record.lsn); force_flush_if_mem_full(); // metrics diff --git a/core/src/db/engine/ExecutionEngineImpl.cpp b/core/src/db/engine/ExecutionEngineImpl.cpp index 3b882772..cb04016e 100644 --- a/core/src/db/engine/ExecutionEngineImpl.cpp +++ b/core/src/db/engine/ExecutionEngineImpl.cpp @@ -948,7 +948,12 @@ ExecutionEngineImpl::HybridSearch(scheduler::SearchJobPtr search_job, auto vector_query = search_job->query_ptr()->vectors.at(vector_placeholder); int64_t topk = vector_query->topk; - int64_t nq = vector_query->query_vector.float_data.size() / dim_; + int64_t nq = 0; + if (!vector_query->query_vector.float_data.empty()) { + nq = vector_query->query_vector.float_data.size() / dim_; + } else if (!vector_query->query_vector.binary_data.empty()) { + nq = vector_query->query_vector.binary_data.size() * 8 / dim_; + } engine::VectorsData vectors; vectors.vector_count_ = nq; diff --git a/core/src/db/insert/MemManager.h b/core/src/db/insert/MemManager.h index 63b0c963..3c35dab7 100644 --- a/core/src/db/insert/MemManager.h +++ b/core/src/db/insert/MemManager.h @@ -23,6 +23,17 @@ namespace milvus { namespace engine { +struct Vectors { /* view of one insert batch: vector_type_ selects which pointer the callee reads */ + typedef enum { + FLOAT, + BINARY, + } VECTOR_TYPE; + + VECTOR_TYPE vector_type_ = FLOAT; + const float* float_vector = nullptr; /* callers set only one of the two pointers; keep the other well-defined */ + const uint8_t* binary_vector = nullptr; +}; + class MemManager { public: virtual Status @@ -39,6 +50,12 @@ class MemManager { 
const std::unordered_map& attr_size, const std::unordered_map>& attr_data, uint64_t lsn) = 0; + virtual Status + InsertEntities(const std::string& collection_id, int64_t length, const IDNumber* vector_ids, int64_t dim, + const Vectors vectors, const std::unordered_map& attr_nbytes, + const std::unordered_map& attr_size, + const std::unordered_map>& attr_data, uint64_t lsn) = 0; + virtual Status DeleteVector(const std::string& collection_id, IDNumber vector_id, uint64_t lsn) = 0; diff --git a/core/src/db/insert/MemManagerImpl.cpp b/core/src/db/insert/MemManagerImpl.cpp index 437a0313..b29be378 100644 --- a/core/src/db/insert/MemManagerImpl.cpp +++ b/core/src/db/insert/MemManagerImpl.cpp @@ -84,6 +84,33 @@ MemManagerImpl::InsertEntities(const std::string& collection_id, int64_t length, return InsertEntitiesNoLock(collection_id, source, lsn); } +Status +MemManagerImpl::InsertEntities(const std::string& collection_id, int64_t length, + const milvus::engine::IDNumber* vector_ids, int64_t dim, const Vectors vectors, + const std::unordered_map& attr_nbytes, + const std::unordered_map& attr_size, + const std::unordered_map>& attr_data, uint64_t lsn) { + VectorsData vectors_data; + if (vectors.vector_type_ == Vectors::FLOAT) { + vectors_data.vector_count_ = length; + vectors_data.float_data_.resize(length * dim); + memcpy(vectors_data.float_data_.data(), vectors.float_vector, length * dim * sizeof(float)); + vectors_data.id_array_.resize(length); + memcpy(vectors_data.id_array_.data(), vector_ids, length * sizeof(IDNumber)); + } else if (vectors.vector_type_ == Vectors::BINARY) { + vectors_data.vector_count_ = length; + vectors_data.binary_data_.resize(length * dim); + memcpy(vectors_data.binary_data_.data(), vectors.binary_vector, length * dim * sizeof(uint8_t)); + vectors_data.id_array_.resize(length); + memcpy(vectors_data.id_array_.data(), vector_ids, length * sizeof(IDNumber)); + } + VectorSourcePtr source = std::make_shared(vectors_data, attr_nbytes, attr_size, 
attr_data); + + std::unique_lock lock(mutex_); + + return InsertEntitiesNoLock(collection_id, source, lsn); +} + Status MemManagerImpl::InsertVectorsNoLock(const std::string& collection_id, const VectorSourcePtr& source, uint64_t lsn) { MemTablePtr mem = GetMemByTable(collection_id); diff --git a/core/src/db/insert/MemManagerImpl.h b/core/src/db/insert/MemManagerImpl.h index 0eb957f4..088d3567 100644 --- a/core/src/db/insert/MemManagerImpl.h +++ b/core/src/db/insert/MemManagerImpl.h @@ -55,6 +55,12 @@ class MemManagerImpl : public MemManager, public server::CacheConfigHandler { const std::unordered_map& attr_size, const std::unordered_map>& attr_data, uint64_t lsn) override; + Status + InsertEntities(const std::string& collection_id, int64_t length, const IDNumber* vector_ids, int64_t dim, + const Vectors vectors, const std::unordered_map& attr_nbytes, + const std::unordered_map& attr_size, + const std::unordered_map>& attr_data, uint64_t lsn) override; + Status DeleteVector(const std::string& collection_id, IDNumber vector_id, uint64_t lsn) override; diff --git a/core/src/db/insert/VectorSource.cpp b/core/src/db/insert/VectorSource.cpp index ad659c19..bac70f5d 100644 --- a/core/src/db/insert/VectorSource.cpp +++ b/core/src/db/insert/VectorSource.cpp @@ -120,13 +120,29 @@ VectorSource::AddEntities(const milvus::segment::SegmentWriterPtr& segment_write return status; } - std::vector vectors; - auto size = num_entities_added * collection_file_schema.dimension_ * sizeof(float); - vectors.resize(size); - memcpy(vectors.data(), vectors_.float_data_.data() + current_num_vectors_added * collection_file_schema.dimension_, - size); - LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld]", "insert", 0) << "Insert into segment"; - status = segment_writer_ptr->AddVectors(collection_file_schema.file_id_, vectors, vector_ids_to_add); + if (!vectors_.float_data_.empty()) { + LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld]", "insert", 0) << "Insert float data into segment"; + auto size = 
num_entities_added * collection_file_schema.dimension_ * sizeof(float); + float* ptr = vectors_.float_data_.data() + current_num_vectors_added * collection_file_schema.dimension_; + status = + segment_writer_ptr->AddVectors(collection_file_schema.file_id_, (uint8_t*)ptr, size, vector_ids_to_add); + } else if (!vectors_.binary_data_.empty()) { + LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld]", "insert", 0) << "Insert binary data into segment"; + std::vector vectors; + auto size = num_entities_added * SingleVectorSize(collection_file_schema.dimension_) * sizeof(uint8_t); + uint8_t* ptr = vectors_.binary_data_.data() + + current_num_vectors_added * SingleVectorSize(collection_file_schema.dimension_); + status = segment_writer_ptr->AddVectors(collection_file_schema.file_id_, ptr, size, vector_ids_to_add); + } + + // std::vector vectors; + // auto size = num_entities_added * collection_file_schema.dimension_ * sizeof(float); + // vectors.resize(size); + // memcpy(vectors.data(), vectors_.float_data_.data() + current_num_vectors_added * + // collection_file_schema.dimension_, + // size); + // LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld]", "insert", 0) << "Insert into segment"; + // status = segment_writer_ptr->AddVectors(collection_file_schema.file_id_, vectors, vector_ids_to_add); if (status.ok()) { current_num_vectors_added += num_entities_added; vector_ids_.insert(vector_ids_.end(), std::make_move_iterator(vector_ids_to_add.begin()), diff --git a/core/src/db/wal/WalDefinations.h b/core/src/db/wal/WalDefinations.h index 6808e902..385c8302 100644 --- a/core/src/db/wal/WalDefinations.h +++ b/core/src/db/wal/WalDefinations.h @@ -42,8 +42,6 @@ struct MXLogRecord { uint32_t data_size; const void* data; std::vector field_names; - // std::vector attrs_size; - // std::vector attrs_data; std::unordered_map attr_nbytes; std::unordered_map attr_data_size; std::unordered_map> attr_data; diff --git a/core/src/db/wal/WalManager.cpp b/core/src/db/wal/WalManager.cpp index 61db243f..eaba69e9 100644 
--- a/core/src/db/wal/WalManager.cpp +++ b/core/src/db/wal/WalManager.cpp @@ -495,9 +495,9 @@ WalManager::InsertEntities(const std::string& collection_id, const std::string& const std::unordered_map>& attrs) { MXLogType log_type; if (std::is_same::value) { - log_type = MXLogType::Entity; - } else { - return false; + log_type = MXLogType::InsertVector; + } else if (std::is_same::value) { + log_type = MXLogType::InsertBinary; } size_t entity_num = entity_ids.size(); diff --git a/core/src/scheduler/task/SearchTask.cpp b/core/src/scheduler/task/SearchTask.cpp index 35b623f2..d224f07f 100644 --- a/core/src/scheduler/task/SearchTask.cpp +++ b/core/src/scheduler/task/SearchTask.cpp @@ -247,7 +247,11 @@ XSearchTask::Execute() { s = index_engine_->HybridSearch(search_job, types, output_distance, output_ids, hybrid); auto vector_query = query_ptr->vectors.begin()->second; topk = vector_query->topk; - nq = vector_query->query_vector.float_data.size() / file_->dimension_; + if (!vector_query->query_vector.float_data.empty()) { + nq = vector_query->query_vector.float_data.size() / file_->dimension_; + } else if (!vector_query->query_vector.binary_data.empty()) { + nq = vector_query->query_vector.binary_data.size() * 8 / file_->dimension_; /* binary vectors are bit-packed: 8 dimensions per byte */ + } search_job->vector_count() = nq; } else { s = index_engine_->Search(output_ids, output_distance, search_job, hybrid); diff --git a/core/src/server/delivery/hybrid_request/CreateHybridCollectionRequest.cpp b/core/src/server/delivery/hybrid_request/CreateHybridCollectionRequest.cpp index f37aeb9a..8f9d00fa 100644 --- a/core/src/server/delivery/hybrid_request/CreateHybridCollectionRequest.cpp +++ b/core/src/server/delivery/hybrid_request/CreateHybridCollectionRequest.cpp @@ -36,7 +36,8 @@ CreateHybridCollectionRequest::CreateHybridCollectionRequest( collection_name_(collection_name), field_types_(field_types), field_index_params_(field_index_params), - field_params_(field_params) { + field_params_(field_params), + extra_params_(extra_params) 
{ } BaseRequestPtr diff --git a/core/src/server/grpc_impl/GrpcRequestHandler.cpp b/core/src/server/grpc_impl/GrpcRequestHandler.cpp index d247923c..8efe097d 100644 --- a/core/src/server/grpc_impl/GrpcRequestHandler.cpp +++ b/core/src/server/grpc_impl/GrpcRequestHandler.cpp @@ -618,7 +618,7 @@ GrpcRequestHandler::CreateCollection(::grpc::ServerContext* context, const ::mil field_types, field_index_params, field_params, json_params); LOG_SERVER_INFO_ << LogOut("Request [%s] %s end.", GetContext(context)->RequestID().c_str(), __func__); - SET_RESPONSE(response, status, context); + SET_RESPONSE(response, status, context) return ::grpc::Status::OK; } @@ -667,8 +667,6 @@ GrpcRequestHandler::CreateIndex(::grpc::ServerContext* context, const ::milvus:: } } - std::string param = json_params.dump(); - Status status = request_handler_.CreateIndex(GetContext(context), request->collection_name(), request->field_name(), request->index_name(), json_params); diff --git a/sdk/examples/binary_vector/src/ClientTest.cpp b/sdk/examples/binary_vector/src/ClientTest.cpp index 60792375..d8b9ba3e 100644 --- a/sdk/examples/binary_vector/src/ClientTest.cpp +++ b/sdk/examples/binary_vector/src/ClientTest.cpp @@ -9,28 +9,29 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License. 
-#include "include/MilvusApi.h" #include "examples/binary_vector/src/ClientTest.h" #include "examples/utils/TimeRecorder.h" #include "examples/utils/Utils.h" +#include "include/MilvusApi.h" #include #include +#include #include #include -#include namespace { -constexpr int64_t BATCH_ENTITY_COUNT = 100000; +constexpr int64_t BATCH_ENTITY_COUNT = 10000; constexpr int64_t NQ = 5; constexpr int64_t TOP_K = 10; constexpr int64_t NPROBE = 32; constexpr int64_t SEARCH_TARGET = 5000; // change this value, result is different, ensure less than BATCH_ENTITY_COUNT constexpr int64_t ADD_ENTITY_LOOP = 10; +constexpr int64_t DIMENSION = 128; void -BuildBinaryVectors(int64_t from, int64_t to, std::vector& entity_array, +BuildBinaryVectors(int64_t from, int64_t to, std::vector& entity_array, std::vector& entity_ids, int64_t dimension) { if (to <= from) { return; @@ -44,7 +45,7 @@ BuildBinaryVectors(int64_t from, int64_t to, std::vector& entity dim_byte++; } for (int64_t k = from; k < to; k++) { - milvus::Entity entity; + milvus::VectorData entity; entity.binary_data.resize(dim_byte); for (int64_t i = 0; i < dim_byte; i++) { entity.binary_data[i] = (uint8_t)lrand48(); @@ -56,46 +57,52 @@ BuildBinaryVectors(int64_t from, int64_t to, std::vector& entity } void -TestProcess(std::shared_ptr connection, - const milvus::CollectionParam& collection_param, +TestProcess(std::shared_ptr connection, const milvus::Mapping& mapping, const milvus::IndexParam& index_param) { milvus::Status stat; { // create collection - stat = connection->CreateCollection(collection_param); + JSON extra_params; + extra_params["segment_size"] = 1024; + stat = connection->CreateCollection(mapping, extra_params.dump()); std::cout << "CreateCollection function call status: " << stat.message() << std::endl; - milvus_sdk::Utils::PrintCollectionParam(collection_param); + milvus_sdk::Utils::PrintCollectionParam(mapping); } - std::vector> search_entity_array; + std::vector> search_entity_array; { // insert vectors for (int 
i = 0; i < ADD_ENTITY_LOOP; i++) { - std::vector entity_array; + milvus::FieldValue field_value; + + std::vector entity_array; std::vector entity_ids; int64_t begin_index = i * BATCH_ENTITY_COUNT; { // generate vectors milvus_sdk::TimeRecorder rc("Build entities No." + std::to_string(i)); - BuildBinaryVectors(begin_index, - begin_index + BATCH_ENTITY_COUNT, - entity_array, - entity_ids, - collection_param.dimension); + BuildBinaryVectors(begin_index, begin_index + BATCH_ENTITY_COUNT, entity_array, entity_ids, DIMENSION); } if (search_entity_array.size() < NQ) { search_entity_array.push_back(std::make_pair(entity_ids[SEARCH_TARGET], entity_array[SEARCH_TARGET])); } + std::vector int64_data(BATCH_ENTITY_COUNT); + for (int j = begin_index; j < begin_index + BATCH_ENTITY_COUNT; j++) { + int64_data[j - begin_index] = j - begin_index; + } + field_value.int64_value.insert(std::make_pair("field_1", int64_data)); + field_value.vector_value.insert(std::make_pair("field_vec", entity_array)); + std::string title = "Insert " + std::to_string(entity_array.size()) + " entities No." 
+ std::to_string(i); milvus_sdk::TimeRecorder rc(title); - stat = connection->Insert(collection_param.collection_name, "", entity_array, entity_ids); + stat = connection->Insert(mapping.collection_name, "", field_value, entity_ids); std::cout << "Insert function call status: " << stat.message() << std::endl; std::cout << "Returned id array count: " << entity_ids.size() << std::endl; } } { // flush buffer - std::vector collections = {collection_param.collection_name}; + std::vector collections = {mapping.collection_name}; stat = connection->Flush(collections); std::cout << "Flush function call status: " << stat.message() << std::endl; } @@ -103,13 +110,8 @@ TestProcess(std::shared_ptr connection, { // search vectors std::vector partition_tags; milvus::TopKQueryResult topk_query_result; - milvus_sdk::Utils::DoSearch(connection, - collection_param.collection_name, - partition_tags, - TOP_K, - NPROBE, - search_entity_array, - topk_query_result); + milvus_sdk::Utils::DoSearch(connection, mapping.collection_name, partition_tags, TOP_K, NPROBE, + search_entity_array, topk_query_result); } { // wait unit build index finish @@ -123,17 +125,12 @@ TestProcess(std::shared_ptr connection, { // search vectors std::vector partition_tags; milvus::TopKQueryResult topk_query_result; - milvus_sdk::Utils::DoSearch(connection, - collection_param.collection_name, - partition_tags, - TOP_K, - NPROBE, - search_entity_array, - topk_query_result); + milvus_sdk::Utils::DoSearch(connection, mapping.collection_name, partition_tags, TOP_K, NPROBE, + search_entity_array, topk_query_result); } { // drop collection - stat = connection->DropCollection(collection_param.collection_name); + stat = connection->DropCollection(mapping.collection_name); std::cout << "DropCollection function call status: " << stat.message() << std::endl; } } @@ -153,58 +150,58 @@ ClientTest::Test(const std::string& address, const std::string& port) { } { - milvus::CollectionParam collection_param = { - "collection_1", - 
512, // dimension - 256, // index file size - milvus::MetricType::TANIMOTO - }; + milvus::FieldPtr field_ptr1 = std::make_shared(); + field_ptr1->field_name = "field_1"; + field_ptr1->field_type = milvus::DataType::INT64; + JSON index_param_1; + index_param_1["name"] = "index_1"; + field_ptr1->index_params = index_param_1.dump(); + + milvus::FieldPtr field_ptr2 = std::make_shared(); + field_ptr2->field_type = milvus::DataType::BINARY_VECTOR; + field_ptr2->field_name = "field_vec"; + JSON index_param_2; + index_param_2["name"] = "index_3"; + index_param_2["index_type"] = "IVFFLAT"; + field_ptr2->index_params = index_param_2.dump(); + JSON extra_params; + extra_params["dimension"] = DIMENSION; /* must match the dimension passed to BuildBinaryVectors */ + extra_params["metric_type"] = "TANIMOTO"; + field_ptr2->extra_params = extra_params.dump(); + + milvus::Mapping mapping = {"collection_1", {field_ptr1, field_ptr2}}; JSON json_params = {{"nlist", 1024}}; - milvus::IndexParam index_param = { - collection_param.collection_name, - milvus::IndexType::IVFFLAT, - json_params.dump() - }; - - TestProcess(connection, collection_param, index_param); - } + milvus::IndexParam index_param = {mapping.collection_name, "field_vec", "index_3", json_params.dump()}; /* index the vector field declared in the mapping; no "field_2" exists */ - { - milvus::CollectionParam collection_param = { - "collection_2", - 512, // dimension - 512, // index file size - milvus::MetricType::SUBSTRUCTURE - }; - - JSON json_params = {}; - milvus::IndexParam index_param = { - collection_param.collection_name, - milvus::IndexType::FLAT, - json_params.dump() - }; - - TestProcess(connection, collection_param, index_param); + TestProcess(connection, mapping, index_param); } - { - milvus::CollectionParam collection_param = { - "collection_3", - 128, // dimension - 1024, // index file size - milvus::MetricType::SUPERSTRUCTURE - }; - - JSON json_params = {}; - milvus::IndexParam index_param = { - collection_param.collection_name, - milvus::IndexType::FLAT, - json_params.dump() - }; - - TestProcess(connection, collection_param, index_param); - } + // { 
+ // milvus::Mapping collection_param = {"collection_2", + // 512, // dimension + // 512, // index file size + // milvus::MetricType::SUBSTRUCTURE}; + // + // JSON json_params = {}; + // milvus::IndexParam index_param = {collection_param.collection_name, milvus::IndexType::FLAT, + // json_params.dump()}; + // + // TestProcess(connection, collection_param, index_param); + // } + // + // { + // milvus::Mapping mapping = {"collection_3", + // 128, // dimension + // 1024, // index file size + // milvus::MetricType::SUPERSTRUCTURE}; + // + // JSON json_params = {}; + // milvus::IndexParam index_param = {collection_param.collection_name, milvus::IndexType::FLAT, + // json_params.dump()}; + // + // TestProcess(connection, collection_param, index_param); + // } milvus::Connection::Destroy(connection); } diff --git a/sdk/examples/simple/src/ClientTest.cpp b/sdk/examples/simple/src/ClientTest.cpp index 58fcc76a..d88ce691 100644 --- a/sdk/examples/simple/src/ClientTest.cpp +++ b/sdk/examples/simple/src/ClientTest.cpp @@ -103,53 +103,26 @@ ClientTest::CreateCollection(const std::string& collection_name) { index_param_2["name"] = "index_2"; field_ptr2->index_params = index_param_2.dump(); - field_ptr3->field_name = "field_3"; + field_ptr3->field_name = "field_vec"; field_ptr3->field_type = milvus::DataType::FLOAT_VECTOR; JSON index_param_3; index_param_3["name"] = "index_3"; index_param_3["index_type"] = "IVFFLAT"; field_ptr3->index_params = index_param_3.dump(); - JSON extra_params; - extra_params["dimension"] = COLLECTION_DIMENSION; - field_ptr3->extra_params = extra_params.dump(); + JSON extra_params_3; + extra_params_3["dimension"] = COLLECTION_DIMENSION; + field_ptr3->extra_params = extra_params_3.dump(); + JSON extra_params; + extra_params["segment_size"] = 1024; /* numeric value, same as the binary_vector example; a blank string is not a usable size */ milvus::Mapping mapping = {collection_name, {field_ptr1, field_ptr2, field_ptr3}}; - milvus::Status stat = conn_->CreateCollection(mapping); + milvus::Status stat = conn_->CreateCollection(mapping, 
extra_params.dump()); std::cout << "CreateCollection function call status: " << stat.message() << std::endl; } void ClientTest::GetCollectionInfo(const std::string& collection_name) { - // milvus::FieldPtr field_ptr1 = std::make_shared(); - // milvus::FieldPtr field_ptr2 = std::make_shared(); - // milvus::FieldPtr field_ptr3 = std::make_shared(); - // field_ptr1->field_name = "field_1"; - // field_ptr1->field_type = milvus::DataType::INT64; - // JSON index_param_1; - // index_param_1["name"] = "index_1"; - // field_ptr1->index_params = index_param_1.dump(); - // - // field_ptr2->field_name = "field_2"; - // field_ptr2->field_type = milvus::DataType::FLOAT; - // JSON index_param_2; - // index_param_2["name"] = "index_2"; - // field_ptr2->index_params = index_param_2.dump(); - // - // field_ptr3->field_name = "field_3"; - // field_ptr3->field_type = milvus::DataType::FLOAT_VECTOR; - // JSON index_param_3; - // index_param_3["name"] = "index_3"; - // index_param_3["index_type"] = "IVFFLAT"; - // field_ptr3->index_params = index_param_3.dump(); - // JSON extra_params; - // extra_params["dimension"] = COLLECTION_DIMENSION; - // field_ptr3->extra_params = extra_params.dump(); - // - // milvus::Mapping mapping = {collection_name, {field_ptr1, field_ptr2, field_ptr3}}; - // - // milvus::Status stat = conn_->CreateCollection(mapping); - // std::cout << "CreateCollection function call status: " << stat.message() << std::endl; } void @@ -174,25 +147,11 @@ ClientTest::InsertEntities(const std::string& collection_name, int64_t row_num) milvus_sdk::Utils::BuildEntities(0, row_num, entity_array, record_ids, COLLECTION_DIMENSION); } - field_value.vector_value.insert(std::make_pair("field_3", entity_array)); + field_value.vector_value.insert(std::make_pair("field_vec", entity_array)); milvus::Status status = conn_->Insert(collection_name, "", field_value, record_ids); std::cout << "InsertEntities function call status: " << status.message() << std::endl; } -// void -// 
ClientTest::BuildSearchEntities(int64_t nq, int64_t dim) { -// search_entity_array_.clear(); -// search_id_array_.clear(); -// for (int64_t i = 0; i < nq; i++) { -// std::vector entity_array; -// std::vector record_ids; -// int64_t index = i * BATCH_ENTITY_COUNT + SEARCH_TARGET; -// milvus_sdk::Utils::BuildEntities(index, index + 1, entity_array, record_ids, dim); -// search_entity_array_.push_back(std::make_pair(record_ids[0], entity_array[0])); -// search_id_array_.push_back(record_ids[0]); -// } -//} - void ClientTest::Flush(const std::string& collection_name) { milvus_sdk::TimeRecorder rc("Flush"); @@ -282,7 +241,7 @@ ClientTest::CreateIndex(const std::string& collection_name, int64_t nlist) { milvus_sdk::TimeRecorder rc("Create index"); std::cout << "Wait until create all index done" << std::endl; JSON json_params = {{"nlist", nlist}, {"index_type", "IVFFLAT"}}; - milvus::IndexParam index1 = {collection_name, "field_3", "index_3", json_params.dump()}; + milvus::IndexParam index1 = {collection_name, "field_vec", "index_3", json_params.dump()}; milvus_sdk::Utils::PrintIndexParam(index1); milvus::Status stat = conn_->CreateIndex(index1); std::cout << "CreateIndex function call status: " << stat.message() << std::endl; @@ -371,6 +330,6 @@ ClientTest::Test() { LoadCollection(collection_name); SearchEntities(collection_name, TOP_K, NPROBE); // this line get two search error since we delete two entities - DropIndex(collection_name, "field_3", "index_3"); + DropIndex(collection_name, "field_vec", "index_3"); DropCollection(collection_name); } diff --git a/sdk/examples/utils/Utils.cpp b/sdk/examples/utils/Utils.cpp index f7119a7c..ed0c2cbb 100644 --- a/sdk/examples/utils/Utils.cpp +++ b/sdk/examples/utils/Utils.cpp @@ -126,9 +126,9 @@ Utils::PrintCollectionParam(const milvus::Mapping& mapping) { std::cout << "Collection name: " << mapping.collection_name << std::endl; for (const auto& field : mapping.fields) { std::cout << "field_name: " << field->field_name; - 
std::cout << "field_type: " << std::to_string((int)field->field_type); - std::cout << "index_param: " << field->index_params; - std::cout << "extra_param:" << field->extra_params; + std::cout << "\tfield_type: " << std::to_string((int)field->field_type); + std::cout << "\tindex_param: " << field->index_params; + std::cout << "\textra_param:" << field->extra_params << std::endl; } BLOCK_SPLITER } @@ -227,24 +227,26 @@ Utils::CheckSearchResult(const std::vector conn, const std::string& collection_name, const std::vector& partition_tags, int64_t top_k, int64_t nprobe, - const std::vector>& entity_array, + std::vector> entity_array, milvus::TopKQueryResult& topk_query_result) { topk_query_result.clear(); + nlohmann::json dsl_json, vector_param_json; + GenDSLJson(dsl_json, vector_param_json); + std::vector temp_entity_array; for (auto& pair : entity_array) { temp_entity_array.push_back(pair.second); } + milvus::VectorParam vector_param = {vector_param_json.dump(), temp_entity_array}; - { - BLOCK_SPLITER - JSON json_params = {{"nprobe", nprobe}}; - milvus_sdk::TimeRecorder rc("Search"); - BLOCK_SPLITER - } + JSON json_params = {{"nprobe", nprobe}}; + milvus_sdk::TimeRecorder rc("Search"); + + auto status = conn->Search(collection_name, partition_tags, dsl_json.dump(), vector_param, topk_query_result); - PrintSearchResult(entity_array, topk_query_result); - CheckSearchResult(entity_array, topk_query_result); + PrintTopKQueryResult(topk_query_result); + // PrintSearchResult(entity_array, topk_query_result); } void @@ -286,7 +288,7 @@ Utils::GenLeafQuery() { uint64_t NPROBE = 32; milvus::VectorQueryPtr vq = std::make_shared(); ConstructVector(NQ, DIMENSION, vq->query_vector); - vq->field_name = "field_3"; + vq->field_name = "field_vec"; vq->topk = 10; JSON json_params = {{"nprobe", NPROBE}}; vq->extra_params = json_params.dump(); @@ -340,7 +342,7 @@ Utils::GenDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json) { query_vector_json["topk"] = topk; 
vector_extra_params["nprobe"] = 64; query_vector_json["params"] = vector_extra_params; - vector_param_json[placeholder]["field_3"] = query_vector_json; + vector_param_json[placeholder]["field_vec"] = query_vector_json; } void diff --git a/sdk/examples/utils/Utils.h b/sdk/examples/utils/Utils.h index 51a22413..b27c4e24 100644 --- a/sdk/examples/utils/Utils.h +++ b/sdk/examples/utils/Utils.h @@ -54,8 +54,8 @@ class Utils { PrintIndexParam(const milvus::IndexParam& index_param); static void - BuildEntities(int64_t from, int64_t to, std::vector& entity_array, std::vector& entity_ids, - int64_t dimension); + BuildEntities(int64_t from, int64_t to, std::vector& entity_array, + std::vector& entity_ids, int64_t dimension); static void PrintSearchResult(const std::vector>& entity_array, @@ -68,7 +68,7 @@ class Utils { static void DoSearch(std::shared_ptr conn, const std::string& collection_name, const std::vector& partition_tags, int64_t top_k, int64_t nprobe, - const std::vector>& entity_array, + std::vector> search_entity_array, milvus::TopKQueryResult& topk_query_result); static void diff --git a/sdk/grpc/ClientProxy.cpp b/sdk/grpc/ClientProxy.cpp index 5abc0b50..fcd90b1f 100644 --- a/sdk/grpc/ClientProxy.cpp +++ b/sdk/grpc/ClientProxy.cpp @@ -451,7 +451,7 @@ ClientProxy::SetConfig(const std::string& node_name, const std::string& value) c } Status -ClientProxy::CreateCollection(const Mapping& mapping) { +ClientProxy::CreateCollection(const Mapping& mapping, const std::string& extra_params) { try { ::milvus::grpc::Mapping grpc_mapping; grpc_mapping.set_collection_name(mapping.collection_name); @@ -470,6 +470,9 @@ ClientProxy::CreateCollection(const Mapping& mapping) { grpc_extra_param->set_key(EXTRA_PARAM_KEY); grpc_extra_param->set_value(field->extra_params); } + auto grpc_param = grpc_mapping.add_extra_params(); + grpc_param->set_key(EXTRA_PARAM_KEY); + grpc_param->set_value(extra_params); return client_ptr_->CreateCollection(grpc_mapping); } catch (std::exception& ex) 
{ diff --git a/sdk/grpc/ClientProxy.h b/sdk/grpc/ClientProxy.h index 84d09992..4fd148d1 100644 --- a/sdk/grpc/ClientProxy.h +++ b/sdk/grpc/ClientProxy.h @@ -51,7 +51,7 @@ class ClientProxy : public Connection { SetConfig(const std::string& node_name, const std::string& value) const override; Status - CreateCollection(const Mapping& mapping) override; + CreateCollection(const Mapping& mapping, const std::string& extra_params) override; bool HasCollection(const std::string& collection_name) override; diff --git a/sdk/include/MilvusApi.h b/sdk/include/MilvusApi.h index ac2f9a9e..2bbc3c65 100644 --- a/sdk/include/MilvusApi.h +++ b/sdk/include/MilvusApi.h @@ -282,7 +282,7 @@ class Connection { * @return Indicate if collection is created successfully */ virtual Status - CreateCollection(const Mapping& mapping) = 0; + CreateCollection(const Mapping& mapping, const std::string& extra_params) = 0; /** * @brief Test collection existence method diff --git a/sdk/interface/ConnectionImpl.cpp b/sdk/interface/ConnectionImpl.cpp index d2e45339..6fd9293f 100644 --- a/sdk/interface/ConnectionImpl.cpp +++ b/sdk/interface/ConnectionImpl.cpp @@ -77,8 +77,8 @@ ConnectionImpl::SetConfig(const std::string& node_name, const std::string& value } Status -ConnectionImpl::CreateCollection(const Mapping& mapping) { - return client_proxy_->CreateCollection(mapping); +ConnectionImpl::CreateCollection(const Mapping& mapping, const std::string& extra_params) { + return client_proxy_->CreateCollection(mapping, extra_params); } bool diff --git a/sdk/interface/ConnectionImpl.h b/sdk/interface/ConnectionImpl.h index 5cd2e4ae..0f8eadf8 100644 --- a/sdk/interface/ConnectionImpl.h +++ b/sdk/interface/ConnectionImpl.h @@ -53,7 +53,7 @@ class ConnectionImpl : public Connection { SetConfig(const std::string& node_name, const std::string& value) const override; Status - CreateCollection(const Mapping& mapping) override; + CreateCollection(const Mapping& mapping, const std::string& extra_params) override; bool 
HasCollection(const std::string& collection_name) override; -- GitLab