提交 54cd9016 编写于 作者: F fishpenguin

Fix C++ sdk

Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>
上级 a78980e4
......@@ -39,16 +39,6 @@ constexpr milvus::IndexType INDEX_TYPE = milvus::IndexType::IVFSQ8;
constexpr int32_t NLIST = 16384;
constexpr uint64_t FIELD_NUM = 3;
void
PrintHybridQueryResult(const std::vector<int64_t>& id_array, const milvus::HybridQueryResult& result) {
for (size_t i = 0; i < id_array.size(); i++) {
std::string prefix = "No." + std::to_string(i) + " id:" + std::to_string(id_array[i]);
std::cout << prefix << "\t[";
for (size_t j = 0; j < result.attr_records.size(); i++) {
}
}
}
} // namespace
ClientTest::ClientTest(const std::string& address, const std::string& port) {
......@@ -108,38 +98,33 @@ void
ClientTest::Insert(std::string& collection_name, int64_t row_num) {
milvus::FieldValue field_value;
std::unordered_map<std::string, std::vector<int64_t>> int64_value;
std::unordered_map<std::string, std::vector<float>> float_value;
std::vector<int64_t> value1;
std::vector<double> value2;
std::vector<float> value2;
value1.resize(row_num);
value2.resize(row_num);
for (uint64_t i = 0; i < row_num; ++i) {
value1[i] = i;
value2[i] = (double)(i + row_num);
value2[i] = (float)(i + row_num);
}
field_value.int64_value.insert(std::make_pair("field_1", value1));
field_value.float_value.insert(std::make_pair("field_2", value2));
numerica_int_value.insert(std::make_pair("field_1", value1));
numerica_double_value.insert(std::make_pair("field_2", value2));
std::unordered_map<std::string, std::vector<milvus::Entity>> vector_value;
std::vector<milvus::Entity> entity_array;
std::unordered_map<std::string, std::vector<milvus::VectorData>> vector_value;
std::vector<milvus::VectorData> entity_array;
std::vector<int64_t> record_ids;
{ // generate vectors
milvus_sdk::Utils::BuildEntities(0, row_num, entity_array, record_ids, 128);
}
vector_value.insert(std::make_pair("field_3", entity_array));
milvus::HEntity entity = {row_num, numerica_int_value, numerica_double_value, vector_value};
std::vector<uint64_t> id_array;
milvus::Status status = conn_->InsertEntity(collection_name, "", entity, id_array);
field_value.vector_value.insert(std::make_pair("field_3", entity_array));
milvus::Status status = conn_->Insert(collection_name, "", field_value, record_ids);
std::cout << "InsertHybridEntities function call status: " << status.message() << std::endl;
}
void
ClientTest::HybridSearchPB(std::string& collection_name) {
ClientTest::SearchPB(std::string& collection_name) {
std::vector<std::string> partition_tags;
milvus::TopKHybridQueryResult topk_query_result;
milvus::TopKQueryResult topk_query_result;
auto leaf_queries = milvus_sdk::Utils::GenLeafQuery();
......@@ -154,51 +139,52 @@ ClientTest::HybridSearchPB(std::string& collection_name) {
std::string extra_params;
milvus::Status status =
conn_->HybridSearchPB(collection_name, partition_tags, query_clause, extra_params, topk_query_result);
conn_->SearchPB(collection_name, partition_tags, query_clause, extra_params, topk_query_result);
milvus_sdk::Utils::PrintTopKHybridQueryResult(topk_query_result);
std::cout << "HybridSearch function call status: " << status.message() << std::endl;
}
void
ClientTest::HybridSearch(std::string& collection_name) {
ClientTest::Search(std::string& collection_name) {
nlohmann::json dsl_json, vector_param_json;
milvus_sdk::Utils::GenDSLJson(dsl_json, vector_param_json);
std::vector<milvus::Entity> entity_array;
std::vector<milvus::VectorData> entity_array;
std::vector<int64_t> record_ids;
{ // generate vectors
milvus_sdk::Utils::ConstructVector(NQ, COLLECTION_DIMENSION, entity_array);
}
milvus::VectorParam vector_param = {vector_param_json.dump(), entity_array};
std::vector<std::string> partition_tags;
milvus::TopKHybridQueryResult topk_query_result;
auto status = conn_->HybridSearch(collection_name, partition_tags, dsl_json.dump(), vector_param_json.dump(),
entity_array, topk_query_result);
milvus::TopKQueryResult topk_query_result;
auto status = conn_->Search(collection_name, partition_tags, dsl_json.dump(), vector_param, topk_query_result);
milvus_sdk::Utils::PrintTopKHybridQueryResult(topk_query_result);
std::cout << "HybridSearch function call status: " << status.message() << std::endl;
}
void
ClientTest::GetHEntityByID(const std::string& collection_name, const std::vector<int64_t>& id_array) {
milvus::HybridQueryResult result;
ClientTest::GetEntityByID(const std::string& collection_name, const std::vector<int64_t>& id_array) {
std::string result;
{
milvus_sdk::TimeRecorder rc("GetHybridEntityByID");
milvus::Status stat = conn_->GetHEntityByID(collection_name, id_array, result);
milvus::Status stat = conn_->GetEntityByID(collection_name, id_array, result);
std::cout << "GetEntitiesByID function call status: " << stat.message() << std::endl;
}
PrintHybridQueryResult(id_array, result);
std::cout << "GetEntityByID function result: " << result;
}
void
ClientTest::TestHybrid() {
std::string collection_name = "HYBRID_TEST";
CreateHybridCollection(collection_name);
InsertHybridEntities(collection_name, 10000);
CreateCollection(collection_name);
Insert(collection_name, 10000);
Flush(collection_name);
// sleep(2);
// sleep(2);
// HybridSearchPB(collection_name);
HybridSearch(collection_name);
Search(collection_name);
}
......@@ -37,13 +37,13 @@ class ClientTest {
Insert(std::string&, int64_t);
void
HybridSearchPB(std::string&);
SearchPB(std::string&);
void
HybridSearch(std::string&);
Search(std::string&);
void
GetHEntityByID(const std::string&, const std::vector<int64_t>&);
GetEntityByID(const std::string&, const std::vector<int64_t>&);
private:
std::shared_ptr<milvus::Connection> conn_;
......
......@@ -9,11 +9,11 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "include/MilvusApi.h"
#include "include/BooleanQuery.h"
#include "examples/simple/src/ClientTest.h"
#include "examples/utils/TimeRecorder.h"
#include "examples/utils/Utils.h"
#include "examples/simple/src/ClientTest.h"
#include "include/BooleanQuery.h"
#include "include/MilvusApi.h"
#include <iostream>
#include <memory>
......@@ -90,7 +90,9 @@ void
ClientTest::CreateCollection(const std::string& collection_name, int64_t dim, milvus::MetricType type) {
std::vector<milvus::FieldPtr> fields;
milvus::Mapping mapping = {collection_name,};
milvus::Mapping mapping = {
collection_name,
};
milvus::CollectionParam collection_param = {collection_name, dim, COLLECTION_INDEX_FILE_SIZE, type};
milvus::Status stat = conn_->CreateCollection(collection_param);
......@@ -105,33 +107,62 @@ ClientTest::CreateCollection(const std::string& collection_name, int64_t dim, mi
void
ClientTest::GetCollectionInfo(const std::string& collection_name) {
milvus::CollectionParam collection_param;
milvus::Status stat = conn_->GetCollectionInfo(collection_name, collection_param);
std::cout << "DescribeCollection function call status: " << stat.message() << std::endl;
milvus_sdk::Utils::PrintCollectionParam(collection_param);
milvus::FieldPtr field_ptr1 = std::make_shared<milvus::Field>();
milvus::FieldPtr field_ptr2 = std::make_shared<milvus::Field>();
milvus::FieldPtr field_ptr3 = std::make_shared<milvus::Field>();
field_ptr1->field_name = "field_1";
field_ptr1->field_type = milvus::DataType::INT64;
JSON index_param_1;
index_param_1["name"] = "index_1";
field_ptr1->index_params = index_param_1.dump();
field_ptr2->field_name = "field_2";
field_ptr2->field_type = milvus::DataType::FLOAT;
JSON index_param_2;
index_param_2["name"] = "index_2";
field_ptr2->index_params = index_param_2.dump();
field_ptr3->field_name = "field_3";
field_ptr3->field_type = milvus::DataType::FLOAT_VECTOR;
JSON index_param_3;
index_param_3["name"] = "index_3";
index_param_3["index_type"] = "IVFFLAT";
field_ptr3->index_params = index_param_3;
JSON extra_params;
extra_params["dimension"] = COLLECTION_DIMENSION;
field_ptr3->extram_params = extra_params.dump();
milvus::Mapping mapping = {collection_name, {field_ptr1, field_ptr2, field_ptr3}};
milvus::Status stat = conn_->CreateCollection(mapping);
std::cout << "CreateCollection function call status: " << stat.message() << std::endl;
}
void
ClientTest::InsertEntities(const std::string& collection_name, int64_t dim) {
for (int i = 0; i < ADD_ENTITY_LOOP; i++) {
std::vector<milvus::Entity> entity_array;
std::vector<int64_t> record_ids;
int64_t begin_index = i * BATCH_ENTITY_COUNT;
{ // generate vectors
milvus_sdk::TimeRecorder rc("Build entities No." + std::to_string(i));
milvus_sdk::Utils::BuildEntities(begin_index,
begin_index + BATCH_ENTITY_COUNT,
entity_array,
record_ids,
dim);
}
std::string title = "Insert " + std::to_string(entity_array.size()) + " entities No." + std::to_string(i);
milvus_sdk::TimeRecorder rc(title);
milvus::Status stat = conn_->Insert(collection_name, "", entity_array, record_ids);
std::cout << "InsertEntities function call status: " << stat.message() << std::endl;
std::cout << "Returned id array count: " << record_ids.size() << std::endl;
ClientTest::InsertEntities(const std::string& collection_name, int64_t row_num) {
milvus::FieldValue field_value;
std::vector<int64_t> value1;
std::vector<float> value2;
value1.resize(row_num);
value2.resize(row_num);
for (uint64_t i = 0; i < row_num; ++i) {
value1[i] = i;
value2[i] = (float)(i + row_num);
}
field_value.int64_value.insert(std::make_pair("field_1", value1));
field_value.float_value.insert(std::make_pair("field_2", value2));
std::unordered_map<std::string, std::vector<milvus::VectorData>> vector_value;
std::vector<milvus::VectorData> entity_array;
std::vector<int64_t> record_ids;
{ // generate vectors
milvus_sdk::Utils::BuildEntities(0, row_num, entity_array, record_ids, 128);
}
field_value.vector_value.insert(std::make_pair("field_3", entity_array));
milvus::Status status = conn_->Insert(collection_name, "", field_value, record_ids);
std::cout << "InsertHybridEntities function call status: " << status.message() << std::endl;
}
void
......@@ -166,25 +197,35 @@ ClientTest::GetCollectionStats(const std::string& collection_name) {
void
ClientTest::GetEntityByID(const std::string& collection_name, const std::vector<int64_t>& id_array) {
std::vector<milvus::Entity> entities;
std::string result;
{
milvus_sdk::TimeRecorder rc("GetEntityByID");
milvus::Status stat = conn_->GetEntityByID(collection_name, id_array, entities);
std::cout << "GetEntityByID function call status: " << stat.message() << std::endl;
milvus_sdk::TimeRecorder rc("GetHybridEntityByID");
milvus::Status stat = conn_->GetEntityByID(collection_name, id_array, result);
std::cout << "GetEntitiesByID function call status: " << stat.message() << std::endl;
}
for (size_t i = 0; i < entities.size(); i++) {
std::string prefix = "No." + std::to_string(i) + " id:" + std::to_string(id_array[i]);
PrintEntity(prefix, entities[i]);
}
std::cout << "GetEntityByID function result: " << result;
}
void
ClientTest::SearchEntities(const std::string& collection_name, int64_t topk, int64_t nprobe) {
nlohmann::json dsl_json, vector_param_json;
milvus_sdk::Utils::GenDSLJson(dsl_json, vector_param_json);
std::vector<milvus::VectorData> entity_array;
std::vector<int64_t> record_ids;
{ // generate vectors
milvus_sdk::Utils::ConstructVector(NQ, COLLECTION_DIMENSION, entity_array);
}
milvus::VectorParam vector_param = {vector_param_json.dump(), entity_array};
std::vector<std::string> partition_tags;
milvus::TopKQueryResult topk_query_result;
milvus_sdk::Utils::DoSearch(conn_, collection_name, partition_tags, topk, nprobe, search_entity_array_,
topk_query_result);
auto status = conn_->Search(collection_name, partition_tags, dsl_json.dump(), vector_param, topk_query_result);
milvus_sdk::Utils::PrintTopKHybridQueryResult(topk_query_result);
std::cout << "HybridSearch function call status: " << status.message() << std::endl;
}
void
......@@ -205,12 +246,7 @@ ClientTest::SearchEntitiesByID(const std::string& collection_name, int64_t topk,
JSON json_params = {{"nprobe", nprobe}};
milvus_sdk::TimeRecorder rc("Search");
stat = conn_->Search(collection_name,
partition_tags,
entities,
topk,
json_params.dump(),
topk_query_result);
stat = conn_->Search(collection_name, partition_tags, entities, topk, json_params.dump(), topk_query_result);
std::cout << "Search function call status: " << stat.message() << std::endl;
if (topk_query_result.size() != id_array.size()) {
......@@ -308,7 +344,7 @@ ClientTest::Test() {
BuildSearchEntities(NQ, dim);
GetEntityByID(collection_name, search_id_array_);
// SearchEntities(collection_name, TOP_K, NPROBE);
// SearchEntities(collection_name, TOP_K, NPROBE);
SearchEntitiesByID(collection_name, TOP_K, NPROBE);
CreateIndex(collection_name, INDEX_TYPE, NLIST);
......@@ -319,7 +355,7 @@ ClientTest::Test() {
CompactCollection(collection_name);
LoadCollection(collection_name);
SearchEntities(collection_name, TOP_K, NPROBE); // this line get two search error since we delete two entities
SearchEntities(collection_name, TOP_K, NPROBE); // this line get two search error since we delete two entities
DropIndex(collection_name);
DropCollection(collection_name);
......
......@@ -37,7 +37,7 @@ class ClientTest {
ShowCollections(std::vector<std::string>&);
void
CreateCollection(const std::string&, int64_t, milvus::MetricType);
CreateCollection(const std::string& collection_name);
void
GetCollectionInfo(const std::string&);
......
......@@ -148,7 +148,7 @@ Utils::PrintIndexParam(const milvus::IndexParam& index_param) {
}
void
Utils::BuildEntities(int64_t from, int64_t to, std::vector<milvus::Entity>& entity_array,
Utils::BuildEntities(int64_t from, int64_t to, std::vector<milvus::VectorData>& entity_array,
std::vector<int64_t>& entity_ids, int64_t dimension) {
if (to <= from) {
return;
......@@ -159,13 +159,13 @@ Utils::BuildEntities(int64_t from, int64_t to, std::vector<milvus::Entity>& enti
std::default_random_engine e;
std::uniform_real_distribution<float> u(0, 1);
for (int64_t k = from; k < to; k++) {
milvus::Entity entity;
entity.float_data.resize(dimension);
milvus::VectorData vector_data;
vector_data.float_data.resize(dimension);
for (int64_t i = 0; i < dimension; i++) {
entity.float_data[i] = (float)((k + 100) % (i + 1));
vector_data.float_data[i] = (float)((k + 100) % (i + 1));
}
entity_array.emplace_back(entity);
entity_array.emplace_back(vector_data);
entity_ids.push_back(k);
}
}
......
......@@ -54,7 +54,7 @@ class Utils {
PrintIndexParam(const milvus::IndexParam& index_param);
static void
BuildEntities(int64_t from, int64_t to, std::vector<milvus::Entity>& entity_array, std::vector<int64_t>& entity_ids,
BuildEntities(int64_t from, int64_t to, std::vector<milvus::VectorData>& entity_array, std::vector<int64_t>& entity_ids,
int64_t dimension);
static void
......
......@@ -402,7 +402,7 @@ class Connection {
virtual Status
SearchPB(const std::string& collection_name, const std::vector<std::string>& partition_list,
BooleanQueryPtr& boolean_query, TopKQueryResult& query_result) = 0;
BooleanQueryPtr& boolean_query, const std::string& extra_params, TopKQueryResult& query_result) = 0;
/**
* @brief Get collection information
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册