未验证 提交 ab8339ce 编写于 作者: J Jin Hai 提交者: GitHub

Merge pull request #1555 from yhmo/master

fix sqlite bug
......@@ -69,7 +69,7 @@ StoragePrototype(const std::string& path) {
make_column("flag", &TableSchema::flag_, default_value(0)),
make_column("index_file_size", &TableSchema::index_file_size_),
make_column("engine_type", &TableSchema::engine_type_),
make_column("index_params", &TableSchema::index_params_, default_value("")),
make_column("index_params", &TableSchema::index_params_),
make_column("metric_type", &TableSchema::metric_type_),
make_column("owner_table", &TableSchema::owner_table_, default_value("")),
make_column("partition_tag", &TableSchema::partition_tag_, default_value("")),
......
......@@ -159,13 +159,6 @@ ClientTest::SearchVectors(const std::string& table_name, int64_t topk, int64_t n
topk_query_result);
}
void
ClientTest::SearchVectorsByIds(const std::string& table_name, int64_t topk, int64_t nprobe) {
std::vector<std::string> partition_tags;
milvus::TopKQueryResult topk_query_result;
milvus_sdk::Utils::DoSearch(conn_, table_name, partition_tags, topk, nprobe, search_id_array_, topk_query_result);
}
void
ClientTest::CreateIndex(const std::string& table_name, milvus::IndexType type, int64_t nlist) {
milvus_sdk::TimeRecorder rc("Create index");
......@@ -245,7 +238,6 @@ ClientTest::Test() {
GetVectorById(table_name, search_id_array_[0]);
SearchVectors(table_name, TOP_K, NPROBE);
SearchVectorsByIds(table_name, TOP_K, NPROBE);
CreateIndex(table_name, INDEX_TYPE, NLIST);
ShowTableInfo(table_name);
......
......@@ -29,36 +29,49 @@ class ClientTest {
private:
void
ShowServerVersion();
void
ShowSdkVersion();
void
ShowTables(std::vector<std::string>&);
void
CreateTable(const std::string&, int64_t, milvus::MetricType);
void
DescribeTable(const std::string&);
void
InsertVectors(const std::string&, int64_t);
void
BuildSearchVectors(int64_t, int64_t);
void
Flush(const std::string&);
void
ShowTableInfo(const std::string&);
void
GetVectorById(const std::string&, int64_t);
void
SearchVectors(const std::string&, int64_t, int64_t);
void
SearchVectorsByIds(const std::string&, int64_t, int64_t);
void
CreateIndex(const std::string&, milvus::IndexType, int64_t);
void
PreloadTable(const std::string&);
void
DeleteByIds(const std::string&, const std::vector<int64_t>&);
void
DropIndex(const std::string&);
void
DropTable(const std::string&);
......
......@@ -220,68 +220,6 @@ Utils::DoSearch(std::shared_ptr<milvus::Connection> conn, const std::string& tab
CheckSearchResult(search_record_array, topk_query_result);
}
void
Utils::DoSearch(std::shared_ptr<milvus::Connection> conn, const std::string& table_name,
const std::vector<std::string>& partition_tags, int64_t top_k, int64_t nprobe,
const std::vector<int64_t>& search_id_array, milvus::TopKQueryResult& topk_query_result) {
topk_query_result.clear();
{
BLOCK_SPLITER
JSON json_params = {{"nprobe", nprobe}};
for (auto& search_id : search_id_array) {
milvus_sdk::TimeRecorder rc("search by id " + std::to_string(search_id));
milvus::TopKQueryResult result;
milvus::Status
stat = conn->SearchByID(table_name, partition_tags, search_id, top_k, json_params.dump(), result);
topk_query_result.insert(topk_query_result.end(), std::make_move_iterator(result.begin()),
std::make_move_iterator(result.end()));
std::cout << "SearchByID function call status: " << stat.message() << std::endl;
}
BLOCK_SPLITER
}
if (topk_query_result.size() != search_id_array.size()) {
std::cout << "ERROR: Returned result count does not equal nq" << std::endl;
return;
}
BLOCK_SPLITER
for (size_t i = 0; i < topk_query_result.size(); i++) {
const milvus::QueryResult& one_result = topk_query_result[i];
size_t topk = one_result.ids.size();
auto search_id = search_id_array[i];
std::cout << "No." << i << " vector " << search_id << " top " << topk << " search result:" << std::endl;
for (size_t j = 0; j < topk; j++) {
std::cout << "\t" << one_result.ids[j] << "\t" << one_result.distances[j] << std::endl;
}
}
BLOCK_SPLITER
BLOCK_SPLITER
size_t nq = topk_query_result.size();
for (size_t i = 0; i < nq; i++) {
const milvus::QueryResult& one_result = topk_query_result[i];
auto search_id = search_id_array[i];
uint64_t match_index = one_result.ids.size();
for (uint64_t index = 0; index < one_result.ids.size(); index++) {
if (search_id == one_result.ids[index]) {
match_index = index;
break;
}
}
if (match_index >= one_result.ids.size()) {
std::cout << "The topk result is wrong: not return search target in result set" << std::endl;
} else {
std::cout << "No." << i << " Check result successfully for target: " << search_id << " at top "
<< match_index << std::endl;
}
}
BLOCK_SPLITER
}
void
PrintPartitionStat(const milvus::PartitionStat& partition_stat) {
std::cout << "\tPartition " << partition_stat.tag << " row count: " << partition_stat.row_count << std::endl;
......
......@@ -70,12 +70,6 @@ class Utils {
const std::vector<std::pair<int64_t, milvus::RowRecord>>& search_record_array,
milvus::TopKQueryResult& topk_query_result);
static void
DoSearch(std::shared_ptr<milvus::Connection> conn, const std::string& table_name,
const std::vector<std::string>& partition_tags, int64_t top_k, int64_t nprobe,
const std::vector<int64_t>& search_id_array,
milvus::TopKQueryResult& topk_query_result);
static void
PrintTableInfo(const milvus::TableInfo& info);
};
......
......@@ -314,49 +314,6 @@ ClientProxy::Search(const std::string& table_name, const std::vector<std::string
}
}
Status
ClientProxy::SearchByID(const std::string& table_name,
const std::vector<std::string>& partition_tag_array,
int64_t query_id,
int64_t topk,
const std::string& extra_params,
TopKQueryResult& topk_query_result) {
try {
// step 1: convert vector id array
::milvus::grpc::SearchByIDParam search_param;
ConstructSearchParam(table_name,
partition_tag_array,
topk,
extra_params,
search_param);
search_param.set_id(query_id);
// step 2: search vectors
::milvus::grpc::TopKQueryResult result;
Status status = client_ptr_->SearchByID(search_param, result);
if (result.row_num() == 0) {
return status;
}
// step 4: convert result array
topk_query_result.reserve(result.row_num());
int64_t nq = result.row_num();
int64_t topk = result.ids().size() / nq;
for (int64_t i = 0; i < result.row_num(); i++) {
milvus::QueryResult one_result;
one_result.ids.resize(topk);
one_result.distances.resize(topk);
memcpy(one_result.ids.data(), result.ids().data() + topk * i, topk * sizeof(int64_t));
memcpy(one_result.distances.data(), result.distances().data() + topk * i, topk * sizeof(float));
topk_query_result.emplace_back(one_result);
}
return status;
} catch (std::exception& ex) {
return Status(StatusCode::UnknownError, "Failed to search vectors: " + std::string(ex.what()));
}
}
Status
ClientProxy::DescribeTable(const std::string& table_name, TableSchema& table_schema) {
try {
......
......@@ -63,11 +63,6 @@ class ClientProxy : public Connection {
const std::vector<RowRecord>& query_record_array, int64_t topk, const std::string& extra_params,
TopKQueryResult& topk_query_result) override;
Status
SearchByID(const std::string& table_name, const std::vector<std::string>& partition_tag_array,
int64_t query_id, int64_t topk,
const std::string& extra_params, TopKQueryResult& topk_query_result) override;
Status
DescribeTable(const std::string& table_name, TableSchema& table_schema) override;
......
......@@ -178,26 +178,6 @@ GrpcClient::Search(
return Status::OK();
}
Status
GrpcClient::SearchByID(const ::milvus::grpc::SearchByIDParam& search_param,
::milvus::grpc::TopKQueryResult& topk_query_result) {
::milvus::grpc::TopKQueryResult query_result;
ClientContext context;
::grpc::Status grpc_status = stub_->SearchByID(&context, search_param, &topk_query_result);
if (!grpc_status.ok()) {
std::cerr << "SearchByID rpc failed!" << std::endl;
std::cerr << grpc_status.error_message() << std::endl;
return Status(StatusCode::RPCFailed, grpc_status.error_message());
}
if (topk_query_result.status().error_code() != grpc::SUCCESS) {
std::cerr << topk_query_result.status().reason() << std::endl;
return Status(StatusCode::ServerFailed, topk_query_result.status().reason());
}
return Status::OK();
}
Status
GrpcClient::DescribeTable(const std::string& table_name, ::milvus::grpc::TableSchema& grpc_schema) {
ClientContext context;
......
......@@ -59,9 +59,6 @@ class GrpcClient {
Status
Search(const grpc::SearchParam& search_param, ::milvus::grpc::TopKQueryResult& topk_query_result);
Status
SearchByID(const grpc::SearchByIDParam& search_param, ::milvus::grpc::TopKQueryResult& topk_query_result);
Status
DescribeTable(const std::string& table_name, grpc::TableSchema& grpc_schema);
......
......@@ -334,24 +334,6 @@ class Connection {
const std::vector<RowRecord>& query_record_array, int64_t topk,
const std::string& extra_params, TopKQueryResult& topk_query_result) = 0;
/**
* @brief Search vector by ID
*
* This method is used to query vector in table.
*
* @param table_name, target table's name.
* @param partition_tag_array, target partitions, keep empty if no partition.
* @param query_id, vector id to be queried.
* @param topk, how many similarity vectors will be returned.
* @param extra_params, extra search parameters according to different index type, must be json format.
* @param topk_query_result, result array.
*
* @return Indicate if query is successful.
*/
virtual Status
SearchByID(const std::string& table_name, const PartitionTagList& partition_tag_array, int64_t query_id,
int64_t topk, const std::string& extra_params, TopKQueryResult& topk_query_result) = 0;
/**
* @brief Show table description
*
......
......@@ -100,16 +100,6 @@ ConnectionImpl::Search(const std::string& table_name, const std::vector<std::str
return client_proxy_->Search(table_name, partition_tags, query_record_array, topk, extra_params, topk_query_result);
}
Status
ConnectionImpl::SearchByID(const std::string& table_name,
const std::vector<std::string>& partition_tags,
int64_t query_id,
int64_t topk,
const std::string& extra_params,
TopKQueryResult& topk_query_result) {
return client_proxy_->SearchByID(table_name, partition_tags, query_id, topk, extra_params, topk_query_result);
}
Status
ConnectionImpl::DescribeTable(const std::string& table_name, TableSchema& table_schema) {
return client_proxy_->DescribeTable(table_name, table_schema);
......
......@@ -65,10 +65,6 @@ class ConnectionImpl : public Connection {
const std::vector<RowRecord>& query_record_array, int64_t topk,
const std::string& extra_params, TopKQueryResult& topk_query_result) override;
Status
SearchByID(const std::string& table_name, const std::vector<std::string>& partition_tag_array, int64_t query_id,
int64_t topk, const std::string& extra_params, TopKQueryResult& topk_query_result) override;
Status
DescribeTable(const std::string& table_name, TableSchema& table_schema) override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册