提交 c4e8d504 编写于 作者: F fishpenguin

Count entities got wrong result with binary vectors

Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>
上级 e9e5ca66
......@@ -20,6 +20,7 @@ Please mark all changes in change log and use the issue from GitHub
- \#2695 The number of fields should be limited
- \#2696 Check the validity of the parameters of creating collection: segment_size
- \#2697 Index can not be created
- \#2698 Count entities got wrong result with binary vectors
## Feature
- \#2319 Redo metadata to support MVCC
......
......@@ -974,24 +974,23 @@ DBImpl::InsertEntities(const std::string& collection_id, const std::string& part
record.partition_tag = partition_tag;
record.ids = entity.id_array_.data();
record.length = entity.entity_count_;
record.attr_data = attr_data;
record.attr_nbytes = attr_nbytes;
record.attr_data_size = attr_data_size;
auto vector_it = entity.vector_data_.begin();
if (vector_it->second.binary_data_.empty()) {
record.type = wal::MXLogType::Entity;
record.type = wal::MXLogType::InsertVector;
record.data = vector_it->second.float_data_.data();
record.data_size = vector_it->second.float_data_.size() * sizeof(float);
record.attr_data = attr_data;
record.attr_nbytes = attr_nbytes;
record.attr_data_size = attr_data_size;
} else {
// record.type = wal::MXLogType::InsertBinary;
// record.data = entities.vector_data_[0].binary_data_.data();
// record.length = entities.vector_data_[0].binary_data_.size() * sizeof(uint8_t);
record.type = wal::MXLogType::InsertBinary;
record.data = vector_it->second.binary_data_.data();
record.data_size = vector_it->second.binary_data_.size() * sizeof(uint8_t);
}
status = ExecWalRecord(record);
}
return status;
}
......@@ -3111,9 +3110,16 @@ DBImpl::ExecWalRecord(const wal::MXLogRecord& record) {
return status;
}
status = mem_mgr_->InsertVectors(target_collection_name, record.length, record.ids,
(record.data_size / record.length / sizeof(uint8_t)),
(const u_int8_t*)record.data, record.lsn);
// Describe the WAL payload as binary vector data for the entity insert path.
Vectors vectors;
vectors.vector_type_ = Vectors::BINARY;
vectors.binary_vector = (const uint8_t*)record.data;
// The dimension argument here is the number of bytes per entity: total payload bytes divided by the
// entity count (sizeof(uint8_t) is the element size).
status = mem_mgr_->InsertEntities(target_collection_name, record.length, record.ids,
(record.data_size / record.length / sizeof(uint8_t)), vectors,
record.attr_nbytes, record.attr_data_size, record.attr_data, record.lsn);
// status = mem_mgr_->InsertVectors(target_collection_name, record.length, record.ids,
// (record.data_size / record.length / sizeof(uint8_t)),
// (const u_int8_t*)record.data, record.lsn);
force_flush_if_mem_full();
// metrics
......@@ -3129,9 +3135,16 @@ DBImpl::ExecWalRecord(const wal::MXLogRecord& record) {
return status;
}
status = mem_mgr_->InsertVectors(target_collection_name, record.length, record.ids,
(record.data_size / record.length / sizeof(float)),
(const float*)record.data, record.lsn);
// Describe the WAL payload as float vector data for the entity insert path.
Vectors vectors;
vectors.vector_type_ = Vectors::FLOAT;
vectors.float_vector = static_cast<const float*>(record.data);
// The dimension is the number of floats per entity, so the per-entity byte count has to be divided by
// sizeof(float). Dividing by sizeof(uint8_t) would inflate the dimension by a factor of sizeof(float)
// and make the callee copy length * dim * sizeof(float) bytes, i.e. read past the end of record.data.
status = mem_mgr_->InsertEntities(target_collection_name, record.length, record.ids,
(record.data_size / record.length / sizeof(float)), vectors,
record.attr_nbytes, record.attr_data_size, record.attr_data, record.lsn);
// status = mem_mgr_->InsertVectors(target_collection_name, record.length, record.ids,
// (record.data_size / record.length / sizeof(float)),
// (const float*)record.data, record.lsn);
force_flush_if_mem_full();
// metrics
......
......@@ -948,7 +948,12 @@ ExecutionEngineImpl::HybridSearch(scheduler::SearchJobPtr search_job,
auto vector_query = search_job->query_ptr()->vectors.at(vector_placeholder);
int64_t topk = vector_query->topk;
int64_t nq = vector_query->query_vector.float_data.size() / dim_;
// Number of query vectors; stays 0 when the query carries neither float nor binary data.
int64_t nq = 0;
if (!vector_query->query_vector.float_data.empty()) {
// Float queries: one element per dimension.
nq = vector_query->query_vector.float_data.size() / dim_;
} else if (!vector_query->query_vector.binary_data.empty()) {
// Binary queries: each byte holds 8 dimensions, so convert bytes to bits before dividing.
nq = vector_query->query_vector.binary_data.size() * 8 / dim_;
}
engine::VectorsData vectors;
vectors.vector_count_ = nq;
......
......@@ -23,6 +23,17 @@
namespace milvus {
namespace engine {
/**
 * Tagged pair of raw pointers describing one batch of vector data.
 *
 * `vector_type_` selects which pointer is meaningful: `float_vector` for FLOAT, `binary_vector` for BINARY.
 * The struct only stores the pointers; it has no destructor and never frees what they point to.
 *
 * Every member has a default so that a default-constructed instance, or one where only the relevant
 * pointer was assigned, can be copied and inspected without touching indeterminate values.
 */
struct Vectors {
    typedef enum {
        FLOAT,
        BINARY,
    } VECTOR_TYPE;

    VECTOR_TYPE vector_type_ = FLOAT;
    const float* float_vector = nullptr;
    const uint8_t* binary_vector = nullptr;
};
class MemManager {
public:
virtual Status
......@@ -39,6 +50,12 @@ class MemManager {
const std::unordered_map<std::string, uint64_t>& attr_size,
const std::unordered_map<std::string, std::vector<uint8_t>>& attr_data, uint64_t lsn) = 0;
// Insert `length` entities into `collection_id`, with the vector payload described by a typed
// `Vectors` view (float or binary) rather than a bare pointer.
// `vector_ids` supplies one id per entity and `dim` is the per-entity element count of the payload.
// The three attr_* maps are keyed by string (presumably the attribute field name; confirm with callers).
// Pure virtual: concrete managers provide the implementation.
virtual Status
InsertEntities(const std::string& collection_id, int64_t length, const IDNumber* vector_ids, int64_t dim,
const Vectors vectors, const std::unordered_map<std::string, uint64_t>& attr_nbytes,
const std::unordered_map<std::string, uint64_t>& attr_size,
const std::unordered_map<std::string, std::vector<uint8_t>>& attr_data, uint64_t lsn) = 0;
virtual Status
DeleteVector(const std::string& collection_id, IDNumber vector_id, uint64_t lsn) = 0;
......
......@@ -84,6 +84,33 @@ MemManagerImpl::InsertEntities(const std::string& collection_id, int64_t length,
return InsertEntitiesNoLock(collection_id, source, lsn);
}
/**
 * Stage a batch of entities described by a typed `Vectors` view and insert it under the manager lock.
 *
 * For FLOAT the payload is read as `length * dim` floats, for BINARY as `length * dim` bytes; in both
 * cases `length` ids are copied from `vector_ids`. For any other type tag nothing is copied and the
 * staged data is left default-constructed.
 *
 * @param collection_id  target collection
 * @param length         number of entities in the batch
 * @param vector_ids     array of `length` entity ids
 * @param dim            elements per entity (floats for FLOAT, bytes for BINARY)
 * @param vectors        type tag plus pointer to the raw payload
 * @param attr_nbytes    forwarded unchanged to the VectorSource
 * @param attr_size      forwarded unchanged to the VectorSource
 * @param attr_data      forwarded unchanged to the VectorSource
 * @param lsn            forwarded to InsertEntitiesNoLock
 * @return status of InsertEntitiesNoLock
 */
Status
MemManagerImpl::InsertEntities(const std::string& collection_id, int64_t length,
                               const milvus::engine::IDNumber* vector_ids, int64_t dim, const Vectors vectors,
                               const std::unordered_map<std::string, uint64_t>& attr_nbytes,
                               const std::unordered_map<std::string, uint64_t>& attr_size,
                               const std::unordered_map<std::string, std::vector<uint8_t>>& attr_data, uint64_t lsn) {
    VectorsData staged;

    const bool is_float = (vectors.vector_type_ == Vectors::FLOAT);
    const bool is_binary = !is_float && (vectors.vector_type_ == Vectors::BINARY);

    if (is_float || is_binary) {
        staged.vector_count_ = length;

        // Copy the payload into the buffer that matches the type tag.
        if (is_float) {
            staged.float_data_.resize(length * dim);
            memcpy(staged.float_data_.data(), vectors.float_vector, length * dim * sizeof(float));
        } else {
            staged.binary_data_.resize(length * dim);
            memcpy(staged.binary_data_.data(), vectors.binary_vector, length * dim * sizeof(uint8_t));
        }

        // Ids are handled the same way for both payload types.
        staged.id_array_.resize(length);
        memcpy(staged.id_array_.data(), vector_ids, length * sizeof(IDNumber));
    }

    VectorSourcePtr source = std::make_shared<VectorSource>(staged, attr_nbytes, attr_size, attr_data);

    // The copy above runs without the lock; only the actual insert is serialized.
    std::unique_lock<std::mutex> lock(mutex_);
    return InsertEntitiesNoLock(collection_id, source, lsn);
}
Status
MemManagerImpl::InsertVectorsNoLock(const std::string& collection_id, const VectorSourcePtr& source, uint64_t lsn) {
MemTablePtr mem = GetMemByTable(collection_id);
......
......@@ -55,6 +55,12 @@ class MemManagerImpl : public MemManager, public server::CacheConfigHandler {
const std::unordered_map<std::string, uint64_t>& attr_size,
const std::unordered_map<std::string, std::vector<uint8_t>>& attr_data, uint64_t lsn) override;
// Override of the MemManager::InsertEntities overload that takes a typed `Vectors` view
// (float or binary payload) together with the per-attribute maps.
Status
InsertEntities(const std::string& collection_id, int64_t length, const IDNumber* vector_ids, int64_t dim,
const Vectors vectors, const std::unordered_map<std::string, uint64_t>& attr_nbytes,
const std::unordered_map<std::string, uint64_t>& attr_size,
const std::unordered_map<std::string, std::vector<uint8_t>>& attr_data, uint64_t lsn) override;
Status
DeleteVector(const std::string& collection_id, IDNumber vector_id, uint64_t lsn) override;
......
......@@ -120,13 +120,29 @@ VectorSource::AddEntities(const milvus::segment::SegmentWriterPtr& segment_write
return status;
}
std::vector<uint8_t> vectors;
auto size = num_entities_added * collection_file_schema.dimension_ * sizeof(float);
vectors.resize(size);
memcpy(vectors.data(), vectors_.float_data_.data() + current_num_vectors_added * collection_file_schema.dimension_,
size);
LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld]", "insert", 0) << "Insert into segment";
status = segment_writer_ptr->AddVectors(collection_file_schema.file_id_, vectors, vector_ids_to_add);
// Pass the next `num_entities_added` vectors, starting at offset `current_num_vectors_added`, to the
// segment writer as a raw byte range. Float data is used when present, otherwise binary data; if both
// buffers are empty no vectors are written and `status` keeps its previous value.
if (!vectors_.float_data_.empty()) {
    LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld]", "insert", 0) << "Insert float data into segment";
    auto size = num_entities_added * collection_file_schema.dimension_ * sizeof(float);
    float* ptr = vectors_.float_data_.data() + current_num_vectors_added * collection_file_schema.dimension_;
    // The writer takes bytes, so the float buffer is reinterpreted rather than copied.
    status = segment_writer_ptr->AddVectors(collection_file_schema.file_id_, reinterpret_cast<uint8_t*>(ptr), size,
                                            vector_ids_to_add);
} else if (!vectors_.binary_data_.empty()) {
    LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld]", "insert", 0) << "Insert binary data into segment";
    // NOTE(review): SingleVectorSize() is presumably the packed byte count of one binary vector for
    // this dimension - confirm against its definition.
    auto size = num_entities_added * SingleVectorSize(collection_file_schema.dimension_) * sizeof(uint8_t);
    uint8_t* ptr = vectors_.binary_data_.data() +
                   current_num_vectors_added * SingleVectorSize(collection_file_schema.dimension_);
    status = segment_writer_ptr->AddVectors(collection_file_schema.file_id_, ptr, size, vector_ids_to_add);
}
// std::vector<uint8_t> vectors;
// auto size = num_entities_added * collection_file_schema.dimension_ * sizeof(float);
// vectors.resize(size);
// memcpy(vectors.data(), vectors_.float_data_.data() + current_num_vectors_added *
// collection_file_schema.dimension_,
// size);
// LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld]", "insert", 0) << "Insert into segment";
// status = segment_writer_ptr->AddVectors(collection_file_schema.file_id_, vectors, vector_ids_to_add);
if (status.ok()) {
current_num_vectors_added += num_entities_added;
vector_ids_.insert(vector_ids_.end(), std::make_move_iterator(vector_ids_to_add.begin()),
......
......@@ -42,8 +42,6 @@ struct MXLogRecord {
uint32_t data_size;
const void* data;
std::vector<std::string> field_names;
// std::vector<uint32_t> attrs_size;
// std::vector<const void* > attrs_data;
std::unordered_map<std::string, uint64_t> attr_nbytes;
std::unordered_map<std::string, uint64_t> attr_data_size;
std::unordered_map<std::string, std::vector<uint8_t>> attr_data;
......
......@@ -495,9 +495,9 @@ WalManager::InsertEntities(const std::string& collection_id, const std::string&
const std::unordered_map<std::string, std::vector<uint8_t>>& attrs) {
MXLogType log_type;
if (std::is_same<T, float>::value) {
log_type = MXLogType::Entity;
} else {
return false;
log_type = MXLogType::InsertVector;
} else if (std::is_same<T, uint8_t>::value) {
log_type = MXLogType::InsertBinary;
} else {
// No WAL record type exists for any other element type. Reject the call instead of continuing
// with an uninitialized log_type.
return false;
}
size_t entity_num = entity_ids.size();
......
......@@ -247,7 +247,11 @@ XSearchTask::Execute() {
s = index_engine_->HybridSearch(search_job, types, output_distance, output_ids, hybrid);
auto vector_query = query_ptr->vectors.begin()->second;
topk = vector_query->topk;
nq = vector_query->query_vector.float_data.size() / file_->dimension_;
if (!vector_query->query_vector.float_data.empty()) {
// Float queries: one element per dimension.
nq = vector_query->query_vector.float_data.size() / file_->dimension_;
} else if (!vector_query->query_vector.binary_data.empty()) {
// Binary queries are packed 8 dimensions per byte, so convert bytes to bits before dividing by
// the dimension. Any other factor gives the wrong query count.
nq = vector_query->query_vector.binary_data.size() * 8 / file_->dimension_;
}
search_job->vector_count() = nq;
} else {
s = index_engine_->Search(output_ids, output_distance, search_job, hybrid);
......
......@@ -36,7 +36,8 @@ CreateHybridCollectionRequest::CreateHybridCollectionRequest(
collection_name_(collection_name),
field_types_(field_types),
field_index_params_(field_index_params),
field_params_(field_params) {
field_params_(field_params),
extra_params_(extra_params) {
}
BaseRequestPtr
......
......@@ -618,7 +618,7 @@ GrpcRequestHandler::CreateCollection(::grpc::ServerContext* context, const ::mil
field_types, field_index_params, field_params, json_params);
LOG_SERVER_INFO_ << LogOut("Request [%s] %s end.", GetContext(context)->RequestID().c_str(), __func__);
SET_RESPONSE(response, status, context);
SET_RESPONSE(response, status, context)
return ::grpc::Status::OK;
}
......@@ -667,8 +667,6 @@ GrpcRequestHandler::CreateIndex(::grpc::ServerContext* context, const ::milvus::
}
}
std::string param = json_params.dump();
Status status = request_handler_.CreateIndex(GetContext(context), request->collection_name(), request->field_name(),
request->index_name(), json_params);
......
......@@ -9,28 +9,29 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "include/MilvusApi.h"
#include "examples/binary_vector/src/ClientTest.h"
#include "examples/utils/TimeRecorder.h"
#include "examples/utils/Utils.h"
#include "include/MilvusApi.h"
#include <iostream>
#include <memory>
#include <random>
#include <utility>
#include <vector>
#include <random>
namespace {
constexpr int64_t BATCH_ENTITY_COUNT = 100000;
constexpr int64_t BATCH_ENTITY_COUNT = 10000;
constexpr int64_t NQ = 5;
constexpr int64_t TOP_K = 10;
constexpr int64_t NPROBE = 32;
constexpr int64_t SEARCH_TARGET = 5000; // change this value, result is different, ensure less than BATCH_ENTITY_COUNT
constexpr int64_t ADD_ENTITY_LOOP = 10;
constexpr int64_t DIMENSION = 128;
void
BuildBinaryVectors(int64_t from, int64_t to, std::vector<milvus::Entity>& entity_array,
BuildBinaryVectors(int64_t from, int64_t to, std::vector<milvus::VectorData>& entity_array,
std::vector<int64_t>& entity_ids, int64_t dimension) {
if (to <= from) {
return;
......@@ -44,7 +45,7 @@ BuildBinaryVectors(int64_t from, int64_t to, std::vector<milvus::Entity>& entity
dim_byte++;
}
for (int64_t k = from; k < to; k++) {
milvus::Entity entity;
milvus::VectorData entity;
entity.binary_data.resize(dim_byte);
for (int64_t i = 0; i < dim_byte; i++) {
entity.binary_data[i] = (uint8_t)lrand48();
......@@ -56,46 +57,52 @@ BuildBinaryVectors(int64_t from, int64_t to, std::vector<milvus::Entity>& entity
}
void
TestProcess(std::shared_ptr<milvus::Connection> connection,
const milvus::CollectionParam& collection_param,
TestProcess(std::shared_ptr<milvus::Connection> connection, const milvus::Mapping& mapping,
const milvus::IndexParam& index_param) {
milvus::Status stat;
{ // create collection
stat = connection->CreateCollection(collection_param);
JSON extra_params;
extra_params["segment_size"] = 1024;
stat = connection->CreateCollection(mapping, extra_params.dump());
std::cout << "CreateCollection function call status: " << stat.message() << std::endl;
milvus_sdk::Utils::PrintCollectionParam(collection_param);
milvus_sdk::Utils::PrintCollectionParam(mapping);
}
std::vector<std::pair<int64_t, milvus::Entity>> search_entity_array;
std::vector<std::pair<int64_t, milvus::VectorData>> search_entity_array;
{ // insert vectors
for (int i = 0; i < ADD_ENTITY_LOOP; i++) {
std::vector<milvus::Entity> entity_array;
milvus::FieldValue field_value;
std::vector<milvus::VectorData> entity_array;
std::vector<int64_t> entity_ids;
int64_t begin_index = i * BATCH_ENTITY_COUNT;
{ // generate vectors
milvus_sdk::TimeRecorder rc("Build entities No." + std::to_string(i));
BuildBinaryVectors(begin_index,
begin_index + BATCH_ENTITY_COUNT,
entity_array,
entity_ids,
collection_param.dimension);
BuildBinaryVectors(begin_index, begin_index + BATCH_ENTITY_COUNT, entity_array, entity_ids, DIMENSION);
}
if (search_entity_array.size() < NQ) {
search_entity_array.push_back(std::make_pair(entity_ids[SEARCH_TARGET], entity_array[SEARCH_TARGET]));
}
std::vector<int64_t> int64_data(BATCH_ENTITY_COUNT);
for (int j = begin_index; j < begin_index + BATCH_ENTITY_COUNT; j++) {
int64_data[j - begin_index] = j - begin_index;
}
field_value.int64_value.insert(std::make_pair("field_1", int64_data));
field_value.vector_value.insert(std::make_pair("field_vec", entity_array));
std::string title = "Insert " + std::to_string(entity_array.size()) + " entities No." + std::to_string(i);
milvus_sdk::TimeRecorder rc(title);
stat = connection->Insert(collection_param.collection_name, "", entity_array, entity_ids);
stat = connection->Insert(mapping.collection_name, "", field_value, entity_ids);
std::cout << "Insert function call status: " << stat.message() << std::endl;
std::cout << "Returned id array count: " << entity_ids.size() << std::endl;
}
}
{ // flush buffer
std::vector<std::string> collections = {collection_param.collection_name};
std::vector<std::string> collections = {mapping.collection_name};
stat = connection->Flush(collections);
std::cout << "Flush function call status: " << stat.message() << std::endl;
}
......@@ -103,13 +110,8 @@ TestProcess(std::shared_ptr<milvus::Connection> connection,
{ // search vectors
std::vector<std::string> partition_tags;
milvus::TopKQueryResult topk_query_result;
milvus_sdk::Utils::DoSearch(connection,
collection_param.collection_name,
partition_tags,
TOP_K,
NPROBE,
search_entity_array,
topk_query_result);
milvus_sdk::Utils::DoSearch(connection, mapping.collection_name, partition_tags, TOP_K, NPROBE,
search_entity_array, topk_query_result);
}
{ // wait until index building finishes
......@@ -123,17 +125,12 @@ TestProcess(std::shared_ptr<milvus::Connection> connection,
{ // search vectors
std::vector<std::string> partition_tags;
milvus::TopKQueryResult topk_query_result;
milvus_sdk::Utils::DoSearch(connection,
collection_param.collection_name,
partition_tags,
TOP_K,
NPROBE,
search_entity_array,
topk_query_result);
milvus_sdk::Utils::DoSearch(connection, mapping.collection_name, partition_tags, TOP_K, NPROBE,
search_entity_array, topk_query_result);
}
{ // drop collection
stat = connection->DropCollection(collection_param.collection_name);
stat = connection->DropCollection(mapping.collection_name);
std::cout << "DropCollection function call status: " << stat.message() << std::endl;
}
}
......@@ -153,58 +150,58 @@ ClientTest::Test(const std::string& address, const std::string& port) {
}
{
milvus::CollectionParam collection_param = {
"collection_1",
512, // dimension
256, // index file size
milvus::MetricType::TANIMOTO
};
milvus::FieldPtr field_ptr1 = std::make_shared<milvus::Field>();
field_ptr1->field_name = "field_1";
field_ptr1->field_type = milvus::DataType::INT64;
JSON index_param_1;
index_param_1["name"] = "index_1";
field_ptr1->index_params = index_param_1.dump();
milvus::FieldPtr field_ptr2 = std::make_shared<milvus::Field>();
field_ptr2->field_type = milvus::DataType::BINARY_VECTOR;
field_ptr2->field_name = "field_vec";
JSON index_param_2;
index_param_2["name"] = "index_3";
index_param_2["index_type"] = "IVFFLAT";
field_ptr2->index_params = index_param_2.dump();
JSON extra_params;
extra_params["dimension"] = 128;
extra_params["metric_type"] = "TANIMOTO";
field_ptr2->extra_params = extra_params.dump();
milvus::Mapping mapping = {"collection_1", {field_ptr1, field_ptr2}};
JSON json_params = {{"nlist", 1024}};
milvus::IndexParam index_param = {
collection_param.collection_name,
milvus::IndexType::IVFFLAT,
json_params.dump()
};
TestProcess(connection, collection_param, index_param);
}
// The index has to target the binary vector field of this mapping ("field_vec"); the mapping
// declares no field named "field_2".
milvus::IndexParam index_param = {mapping.collection_name, "field_vec", "index_3", json_params.dump()};
{
milvus::CollectionParam collection_param = {
"collection_2",
512, // dimension
512, // index file size
milvus::MetricType::SUBSTRUCTURE
};
JSON json_params = {};
milvus::IndexParam index_param = {
collection_param.collection_name,
milvus::IndexType::FLAT,
json_params.dump()
};
TestProcess(connection, collection_param, index_param);
TestProcess(connection, mapping, index_param);
}
{
milvus::CollectionParam collection_param = {
"collection_3",
128, // dimension
1024, // index file size
milvus::MetricType::SUPERSTRUCTURE
};
JSON json_params = {};
milvus::IndexParam index_param = {
collection_param.collection_name,
milvus::IndexType::FLAT,
json_params.dump()
};
TestProcess(connection, collection_param, index_param);
}
// {
// milvus::Mapping collection_param = {"collection_2",
// 512, // dimension
// 512, // index file size
// milvus::MetricType::SUBSTRUCTURE};
//
// JSON json_params = {};
// milvus::IndexParam index_param = {collection_param.collection_name, milvus::IndexType::FLAT,
// json_params.dump()};
//
// TestProcess(connection, collection_param, index_param);
// }
//
// {
// milvus::Mapping mapping = {"collection_3",
// 128, // dimension
// 1024, // index file size
// milvus::MetricType::SUPERSTRUCTURE};
//
// JSON json_params = {};
// milvus::IndexParam index_param = {collection_param.collection_name, milvus::IndexType::FLAT,
// json_params.dump()};
//
// TestProcess(connection, collection_param, index_param);
// }
milvus::Connection::Destroy(connection);
}
......@@ -103,53 +103,26 @@ ClientTest::CreateCollection(const std::string& collection_name) {
index_param_2["name"] = "index_2";
field_ptr2->index_params = index_param_2.dump();
field_ptr3->field_name = "field_3";
field_ptr3->field_name = "field_vec";
field_ptr3->field_type = milvus::DataType::FLOAT_VECTOR;
JSON index_param_3;
index_param_3["name"] = "index_3";
index_param_3["index_type"] = "IVFFLAT";
field_ptr3->index_params = index_param_3.dump();
JSON extra_params;
extra_params["dimension"] = COLLECTION_DIMENSION;
field_ptr3->extra_params = extra_params.dump();
JSON extra_params_3;
extra_params_3["dimension"] = COLLECTION_DIMENSION;
field_ptr3->extra_params = extra_params_3.dump();
// Collection-level extra params. segment_size is given as a number; a blank string is not a usable size.
JSON extra_params;
extra_params["segment_size"] = 1024;
milvus::Mapping mapping = {collection_name, {field_ptr1, field_ptr2, field_ptr3}};
milvus::Status stat = conn_->CreateCollection(mapping);
milvus::Status stat = conn_->CreateCollection(mapping, extra_params.dump());
std::cout << "CreateCollection function call status: " << stat.message() << std::endl;
}
void
ClientTest::GetCollectionInfo(const std::string& collection_name) {
// milvus::FieldPtr field_ptr1 = std::make_shared<milvus::Field>();
// milvus::FieldPtr field_ptr2 = std::make_shared<milvus::Field>();
// milvus::FieldPtr field_ptr3 = std::make_shared<milvus::Field>();
// field_ptr1->field_name = "field_1";
// field_ptr1->field_type = milvus::DataType::INT64;
// JSON index_param_1;
// index_param_1["name"] = "index_1";
// field_ptr1->index_params = index_param_1.dump();
//
// field_ptr2->field_name = "field_2";
// field_ptr2->field_type = milvus::DataType::FLOAT;
// JSON index_param_2;
// index_param_2["name"] = "index_2";
// field_ptr2->index_params = index_param_2.dump();
//
// field_ptr3->field_name = "field_3";
// field_ptr3->field_type = milvus::DataType::FLOAT_VECTOR;
// JSON index_param_3;
// index_param_3["name"] = "index_3";
// index_param_3["index_type"] = "IVFFLAT";
// field_ptr3->index_params = index_param_3.dump();
// JSON extra_params;
// extra_params["dimension"] = COLLECTION_DIMENSION;
// field_ptr3->extra_params = extra_params.dump();
//
// milvus::Mapping mapping = {collection_name, {field_ptr1, field_ptr2, field_ptr3}};
//
// milvus::Status stat = conn_->CreateCollection(mapping);
// std::cout << "CreateCollection function call status: " << stat.message() << std::endl;
}
void
......@@ -174,25 +147,11 @@ ClientTest::InsertEntities(const std::string& collection_name, int64_t row_num)
milvus_sdk::Utils::BuildEntities(0, row_num, entity_array, record_ids, COLLECTION_DIMENSION);
}
field_value.vector_value.insert(std::make_pair("field_3", entity_array));
field_value.vector_value.insert(std::make_pair("field_vec", entity_array));
milvus::Status status = conn_->Insert(collection_name, "", field_value, record_ids);
std::cout << "InsertEntities function call status: " << status.message() << std::endl;
}
// void
// ClientTest::BuildSearchEntities(int64_t nq, int64_t dim) {
// search_entity_array_.clear();
// search_id_array_.clear();
// for (int64_t i = 0; i < nq; i++) {
// std::vector<milvus::Entity> entity_array;
// std::vector<int64_t> record_ids;
// int64_t index = i * BATCH_ENTITY_COUNT + SEARCH_TARGET;
// milvus_sdk::Utils::BuildEntities(index, index + 1, entity_array, record_ids, dim);
// search_entity_array_.push_back(std::make_pair(record_ids[0], entity_array[0]));
// search_id_array_.push_back(record_ids[0]);
// }
//}
void
ClientTest::Flush(const std::string& collection_name) {
milvus_sdk::TimeRecorder rc("Flush");
......@@ -282,7 +241,7 @@ ClientTest::CreateIndex(const std::string& collection_name, int64_t nlist) {
milvus_sdk::TimeRecorder rc("Create index");
std::cout << "Wait until create all index done" << std::endl;
JSON json_params = {{"nlist", nlist}, {"index_type", "IVFFLAT"}};
milvus::IndexParam index1 = {collection_name, "field_3", "index_3", json_params.dump()};
milvus::IndexParam index1 = {collection_name, "field_vec", "index_3", json_params.dump()};
milvus_sdk::Utils::PrintIndexParam(index1);
milvus::Status stat = conn_->CreateIndex(index1);
std::cout << "CreateIndex function call status: " << stat.message() << std::endl;
......@@ -371,6 +330,6 @@ ClientTest::Test() {
LoadCollection(collection_name);
SearchEntities(collection_name, TOP_K, NPROBE); // two search errors are expected here because two entities were deleted
DropIndex(collection_name, "field_3", "index_3");
DropIndex(collection_name, "field_vec", "index_3");
DropCollection(collection_name);
}
......@@ -126,9 +126,9 @@ Utils::PrintCollectionParam(const milvus::Mapping& mapping) {
std::cout << "Collection name: " << mapping.collection_name << std::endl;
for (const auto& field : mapping.fields) {
std::cout << "field_name: " << field->field_name;
std::cout << "field_type: " << std::to_string((int)field->field_type);
std::cout << "index_param: " << field->index_params;
std::cout << "extra_param:" << field->extra_params;
std::cout << "\tfield_type: " << std::to_string((int)field->field_type);
std::cout << "\tindex_param: " << field->index_params;
std::cout << "\textra_param:" << field->extra_params << std::endl;
}
BLOCK_SPLITER
}
......@@ -227,24 +227,26 @@ Utils::CheckSearchResult(const std::vector<std::pair<int64_t, milvus::VectorData
void
Utils::DoSearch(std::shared_ptr<milvus::Connection> conn, const std::string& collection_name,
const std::vector<std::string>& partition_tags, int64_t top_k, int64_t nprobe,
const std::vector<std::pair<int64_t, milvus::VectorData>>& entity_array,
std::vector<std::pair<int64_t, milvus::VectorData>> entity_array,
milvus::TopKQueryResult& topk_query_result) {
topk_query_result.clear();
nlohmann::json dsl_json, vector_param_json;
GenDSLJson(dsl_json, vector_param_json);
std::vector<milvus::VectorData> temp_entity_array;
for (auto& pair : entity_array) {
temp_entity_array.push_back(pair.second);
}
milvus::VectorParam vector_param = {vector_param_json.dump(), temp_entity_array};
{
BLOCK_SPLITER
JSON json_params = {{"nprobe", nprobe}};
milvus_sdk::TimeRecorder rc("Search");
BLOCK_SPLITER
}
JSON json_params = {{"nprobe", nprobe}};
milvus_sdk::TimeRecorder rc("Search");
auto status = conn->Search(collection_name, partition_tags, dsl_json.dump(), vector_param, topk_query_result);
PrintSearchResult(entity_array, topk_query_result);
CheckSearchResult(entity_array, topk_query_result);
PrintTopKQueryResult(topk_query_result);
// PrintSearchResult(entity_array, topk_query_result);
}
void
......@@ -286,7 +288,7 @@ Utils::GenLeafQuery() {
uint64_t NPROBE = 32;
milvus::VectorQueryPtr vq = std::make_shared<milvus::VectorQuery>();
ConstructVector(NQ, DIMENSION, vq->query_vector);
vq->field_name = "field_3";
vq->field_name = "field_vec";
vq->topk = 10;
JSON json_params = {{"nprobe", NPROBE}};
vq->extra_params = json_params.dump();
......@@ -340,7 +342,7 @@ Utils::GenDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json) {
query_vector_json["topk"] = topk;
vector_extra_params["nprobe"] = 64;
query_vector_json["params"] = vector_extra_params;
vector_param_json[placeholder]["field_3"] = query_vector_json;
vector_param_json[placeholder]["field_vec"] = query_vector_json;
}
void
......
......@@ -54,8 +54,8 @@ class Utils {
PrintIndexParam(const milvus::IndexParam& index_param);
static void
BuildEntities(int64_t from, int64_t to, std::vector<milvus::VectorData>& entity_array, std::vector<int64_t>& entity_ids,
int64_t dimension);
BuildEntities(int64_t from, int64_t to, std::vector<milvus::VectorData>& entity_array,
std::vector<int64_t>& entity_ids, int64_t dimension);
static void
PrintSearchResult(const std::vector<std::pair<int64_t, milvus::VectorData>>& entity_array,
......@@ -68,7 +68,7 @@ class Utils {
static void
DoSearch(std::shared_ptr<milvus::Connection> conn, const std::string& collection_name,
const std::vector<std::string>& partition_tags, int64_t top_k, int64_t nprobe,
const std::vector<std::pair<int64_t, milvus::VectorData>>& entity_array,
std::vector<std::pair<int64_t, milvus::VectorData>> search_entity_array,
milvus::TopKQueryResult& topk_query_result);
static void
......
......@@ -451,7 +451,7 @@ ClientProxy::SetConfig(const std::string& node_name, const std::string& value) c
}
Status
ClientProxy::CreateCollection(const Mapping& mapping) {
ClientProxy::CreateCollection(const Mapping& mapping, const std::string& extra_params) {
try {
::milvus::grpc::Mapping grpc_mapping;
grpc_mapping.set_collection_name(mapping.collection_name);
......@@ -470,6 +470,9 @@ ClientProxy::CreateCollection(const Mapping& mapping) {
grpc_extra_param->set_key(EXTRA_PARAM_KEY);
grpc_extra_param->set_value(field->extra_params);
}
// Attach the collection-level extra params (passed in as a string) to the mapping itself,
// under the same EXTRA_PARAM_KEY that the per-field extra params use.
auto grpc_param = grpc_mapping.add_extra_params();
grpc_param->set_key(EXTRA_PARAM_KEY);
grpc_param->set_value(extra_params);
return client_ptr_->CreateCollection(grpc_mapping);
} catch (std::exception& ex) {
......
......@@ -51,7 +51,7 @@ class ClientProxy : public Connection {
SetConfig(const std::string& node_name, const std::string& value) const override;
Status
CreateCollection(const Mapping& mapping) override;
CreateCollection(const Mapping& mapping, const std::string& extra_params) override;
bool
HasCollection(const std::string& collection_name) override;
......
......@@ -282,7 +282,7 @@ class Connection {
* @return Indicate if collection is created successfully
*/
virtual Status
CreateCollection(const Mapping& mapping) = 0;
CreateCollection(const Mapping& mapping, const std::string& extra_params) = 0;
/**
* @brief Test collection existence method
......
......@@ -77,8 +77,8 @@ ConnectionImpl::SetConfig(const std::string& node_name, const std::string& value
}
Status
ConnectionImpl::CreateCollection(const Mapping& mapping) {
return client_proxy_->CreateCollection(mapping);
ConnectionImpl::CreateCollection(const Mapping& mapping, const std::string& extra_params) {
return client_proxy_->CreateCollection(mapping, extra_params);
}
bool
......
......@@ -53,7 +53,7 @@ class ConnectionImpl : public Connection {
SetConfig(const std::string& node_name, const std::string& value) const override;
Status
CreateCollection(const Mapping& mapping) override;
CreateCollection(const Mapping& mapping, const std::string& extra_params) override;
bool
HasCollection(const std::string& collection_name) override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册