未验证 提交 92e2a26a 编写于 作者: yukun 提交者: GitHub

Fix test_query_range_valid_ranges bug (#3224)

* Fix TestSearchDSL level 2 bugs
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix QueryTest
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Add annotation in milvus.proto
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix CreateIndex in C++ sdk
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix C++ sdk range test
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix test_search_ip_index_partitions
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix test_query_range_valid_ranges
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix GetCollectionInfo
Signed-off-by: fishpenguin <kun.yu@zilliz.com>
Co-authored-by: quicksilver <zhifeng.zhang@zilliz.com>
上级 3735e3d1
......@@ -462,32 +462,35 @@ ExecutionEngineImpl::IndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, const
segment_reader_->GetSegment(segment_ptr);
knowhere::IndexPtr index_ptr = nullptr;
auto attr_index = segment_ptr->GetStructuredIndex(field_name, index_ptr);
if (!index_ptr) {
return Status(DB_ERROR, "Get field: " + field_name + " structured index failed");
}
switch (data_type) {
case DataType::INT8: {
ProcessIndexedTermQuery<int8_t>(bitset, index_ptr, term_values_json);
STATUS_CHECK(ProcessIndexedTermQuery<int8_t>(bitset, index_ptr, term_values_json));
break;
}
case DataType::INT16: {
ProcessIndexedTermQuery<int16_t>(bitset, index_ptr, term_values_json);
STATUS_CHECK(ProcessIndexedTermQuery<int16_t>(bitset, index_ptr, term_values_json));
break;
}
case DataType::INT32: {
ProcessIndexedTermQuery<int32_t>(bitset, index_ptr, term_values_json);
STATUS_CHECK(ProcessIndexedTermQuery<int32_t>(bitset, index_ptr, term_values_json));
break;
}
case DataType::INT64: {
ProcessIndexedTermQuery<int64_t>(bitset, index_ptr, term_values_json);
STATUS_CHECK(ProcessIndexedTermQuery<int64_t>(bitset, index_ptr, term_values_json));
break;
}
case DataType::FLOAT: {
ProcessIndexedTermQuery<float>(bitset, index_ptr, term_values_json);
STATUS_CHECK(ProcessIndexedTermQuery<float>(bitset, index_ptr, term_values_json));
break;
}
case DataType::DOUBLE: {
ProcessIndexedTermQuery<double>(bitset, index_ptr, term_values_json);
STATUS_CHECK(ProcessIndexedTermQuery<double>(bitset, index_ptr, term_values_json));
break;
}
default: { return Status{SERVER_INVALID_ARGUMENT, "Attribute:" + field_name + " type is wrong"}; }
default: { return Status(SERVER_INVALID_ARGUMENT, "Attribute:" + field_name + " type is wrong"); }
}
return Status::OK();
}
......@@ -544,27 +547,27 @@ ExecutionEngineImpl::IndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, const
auto status = Status::OK();
switch (data_type) {
case DataType::INT8: {
ProcessIndexedRangeQuery<int8_t>(bitset, index_ptr, range_values_json);
STATUS_CHECK(ProcessIndexedRangeQuery<int8_t>(bitset, index_ptr, range_values_json));
break;
}
case DataType::INT16: {
ProcessIndexedRangeQuery<int16_t>(bitset, index_ptr, range_values_json);
STATUS_CHECK(ProcessIndexedRangeQuery<int16_t>(bitset, index_ptr, range_values_json));
break;
}
case DataType::INT32: {
ProcessIndexedRangeQuery<int32_t>(bitset, index_ptr, range_values_json);
STATUS_CHECK(ProcessIndexedRangeQuery<int32_t>(bitset, index_ptr, range_values_json));
break;
}
case DataType::INT64: {
ProcessIndexedRangeQuery<int64_t>(bitset, index_ptr, range_values_json);
STATUS_CHECK(ProcessIndexedRangeQuery<int64_t>(bitset, index_ptr, range_values_json));
break;
}
case DataType::FLOAT: {
ProcessIndexedRangeQuery<float>(bitset, index_ptr, range_values_json);
STATUS_CHECK(ProcessIndexedRangeQuery<float>(bitset, index_ptr, range_values_json));
break;
}
case DataType::DOUBLE: {
ProcessIndexedRangeQuery<double>(bitset, index_ptr, range_values_json);
STATUS_CHECK(ProcessIndexedRangeQuery<double>(bitset, index_ptr, range_values_json));
break;
}
default:
......
......@@ -105,7 +105,7 @@ ConcurrentBitset::operator|=(ConcurrentBitset& bitset) {
size_t n64 = n8 / 8;
for (size_t i = 0; i < n64; i++) {
u64_1[i] &= u64_2[i];
u64_1[i] |= u64_2[i];
}
size_t remain = n8 % 8;
......@@ -134,7 +134,7 @@ ConcurrentBitset::operator|(const std::shared_ptr<ConcurrentBitset>& bitset) {
size_t n64 = n8 / 8;
for (size_t i = 0; i < n64; i++) {
result_64[i] = u64_1[i] & u64_2[i];
result_64[i] = u64_1[i] | u64_2[i];
}
size_t remain = n8 % 8;
......@@ -150,7 +150,7 @@ ConcurrentBitset::operator|(const std::shared_ptr<ConcurrentBitset>& bitset) {
ConcurrentBitset&
ConcurrentBitset::operator^=(ConcurrentBitset& bitset) {
// for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) {
// for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) {
// bitset_[i].fetch_xor(bitset.bitset()[i].load());
// }
......
......@@ -51,6 +51,10 @@ Job::TaskDone(Task* task) {
return;
}
auto json = task->Dump();
std::string task_desc = json.dump();
LOG_SERVER_DEBUG_ << LogOut("scheduler job [%ld] task %s finish", id(), task_desc.c_str());
std::unique_lock<std::mutex> lock(mutex_);
for (JobTasks::iterator iter = tasks_.begin(); iter != tasks_.end(); ++iter) {
if (task == (*iter).get()) {
......@@ -61,10 +65,6 @@ Job::TaskDone(Task* task) {
if (tasks_.empty()) {
cv_.notify_all();
}
auto json = task->Dump();
std::string task_desc = json.dump();
LOG_SERVER_DEBUG_ << LogOut("scheduler job [%ld] task %s finish", id(), task_desc.c_str());
}
void
......
......@@ -24,7 +24,7 @@ Task::Load(LoadType type, uint8_t device_id) {
if (job_) {
if (!status.ok()) {
job_->status() = status;
job_->TaskDone(this);
// job_->TaskDone(this);
}
} else {
LOG_ENGINE_ERROR_ << "Scheduler task's parent job not specified!";
......
......@@ -52,20 +52,22 @@ GetCollectionInfoReq::OnExecute() {
for (auto& field_kv : field_mappings) {
auto field = field_kv.first;
FieldSchema field_schema;
milvus::json field_index_param;
auto field_elements = field_kv.second;
for (const auto& element : field_elements) {
if (element->GetFtype() == (engine::snapshot::FTYPE_TYPE)engine::FieldElementType::FET_INDEX) {
field_index_param = element->GetParams();
auto type = element->GetTypeName();
field_schema.index_params_ = field_index_param;
field_schema.index_params_[engine::PARAM_INDEX_TYPE] = element->GetTypeName();
break;
}
}
auto field_name = field->GetName();
FieldSchema field_schema;
field_schema.field_type_ = (engine::DataType)field->GetFtype();
field_schema.field_params_ = field->GetParams();
field_schema.index_params_ = field_index_param;
collection_schema_.fields_.insert(std::make_pair(field_name, field_schema));
}
......
......@@ -1038,7 +1038,11 @@ GrpcRequestHandler::DescribeCollection(::grpc::ServerContext* context, const ::m
for (auto& item : field_schema.index_params_.items()) {
auto grpc_index_param = field->add_index_params();
grpc_index_param->set_key(item.key());
grpc_index_param->set_value(item.value());
if (item.value().is_object()) {
grpc_index_param->set_value(item.value().dump());
} else {
grpc_index_param->set_value(item.value());
}
}
}
......
......@@ -326,7 +326,7 @@ ClientTest::Test() {
ListCollections(table_array);
CreateCollection(collection_name);
GetCollectionInfo(collection_name);
// GetCollectionInfo(collection_name);
GetCollectionStats(collection_name);
ListCollections(table_array);
......@@ -336,6 +336,7 @@ ClientTest::Test() {
Flush(collection_name);
CountEntities(collection_name);
CreateIndex(collection_name, 1024);
GetCollectionInfo(collection_name);
// GetCollectionStats(collection_name);
//
BuildVectors(NQ, COLLECTION_DIMENSION);
......
......@@ -553,7 +553,7 @@ class TestSearchBase:
if min_distance > tmp_dis:
min_distance = tmp_dis
res = connect.search(collection, query)
assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= gen_inaccuracy(res[0]._distances[0])
assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= epsilon
# TODO
@pytest.mark.level(2)
......@@ -1133,7 +1133,7 @@ class TestSearchDSL(object):
return request.param
# TODO:
def _test_query_range_valid_ranges(self, connect, collection, get_valid_ranges):
def test_query_range_valid_ranges(self, connect, collection, get_valid_ranges):
'''
method: build query with valid ranges
expected: pass
......
......@@ -44,7 +44,7 @@ default_index_params = [
{"nlist": 1024, "m": 16},
{"M": 48, "efConstruction": 500},
# {"search_length": 50, "out_degree": 40, "candidate_pool_size": 100, "knng": 50},
{"n_trees": 4},
{"n_trees": 50},
{"nlist": 1024},
{"nlist": 1024}
]
......@@ -314,7 +314,7 @@ def gen_default_term_expr(keyword="term", values=None):
def gen_default_range_expr(keyword="range", ranges=None):
if ranges is None:
ranges = {"GT": 1, "LT": nb // 2}
expr = {keyword: {"int64": {"ranges": ranges}}}
expr = {keyword: {"int64": ranges}}
return expr
......@@ -341,7 +341,7 @@ def gen_invalid_ranges():
def gen_valid_ranges():
ranges = [
{"GT": 0, "LT": nb//2},
{"GT": nb, "LT": nb*2},
{"GT": nb // 2, "LT": nb*2},
{"GT": 0},
{"LT": nb},
{"GT": -1, "LT": top_k},
......@@ -766,7 +766,7 @@ def get_search_param(index_type):
elif index_type == "NSG":
search_params.update({"search_length": 100})
elif index_type == "ANNOY":
search_params.update({"search_k": 100})
search_params.update({"search_k": 1000})
else:
logging.getLogger().error("Invalid index_type.")
raise Exception("Invalid index_type.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册