提交 5acc9681 编写于 作者: B bigsheeper 提交者: yefu.chen

Fix search error about metric type

Signed-off-by: Nbigsheeper <yihao.dai@zilliz.com>
上级 d599407e
......@@ -75,6 +75,12 @@ GetTopK(CPlan plan) {
return res;
}
const char*
GetMetricType(CPlan plan) {
auto query_plan = static_cast<milvus::query::Plan*>(plan);
return strdup(query_plan->plan_node_->query_info_.metric_type_.c_str());
}
void
DeletePlan(CPlan cPlan) {
auto plan = (milvus::query::Plan*)cPlan;
......
......@@ -35,6 +35,9 @@ GetNumOfQueries(CPlaceholderGroup placeholder_group);
int64_t
GetTopK(CPlan plan);
const char*
GetMetricType(CPlan plan);
void
DeletePlan(CPlan plan);
......
......@@ -64,6 +64,11 @@ struct SearchResultPair {
return (distance_ < pair.distance_);
}
bool
operator>(const SearchResultPair& pair) const {
return (distance_ > pair.distance_);
}
void
reset_distance() {
distance_ = search_result_->result_distances_[offset_];
......@@ -89,7 +94,7 @@ GetResultData(std::vector<std::vector<int64_t>>& search_records,
AssertInfo(topk > 0, "topK must greater than 0");
for (int i = 0; i < topk; ++i) {
result_pairs[0].reset_distance();
std::sort(result_pairs.begin(), result_pairs.end());
std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
auto& result_pair = result_pairs[0];
auto index = result_pair.index_;
is_selected[index] = true;
......
......@@ -140,6 +140,11 @@ Search(CSegmentBase c_segment,
auto status = CStatus();
try {
auto res = segment->Search(plan, placeholder_groups.data(), timestamps, num_groups, *query_result);
if (plan->plan_node_->query_info_.metric_type_ != "IP") {
for (auto& dis : query_result->result_distances_) {
dis *= -1;
}
}
status.error_code = Success;
status.error_msg = "";
} catch (std::exception& e) {
......
......@@ -259,6 +259,7 @@ message SearchResult {
uint64 timestamp = 6;
int64 result_channelID = 7;
repeated bytes hits = 8;
string metric_type = 9;
}
message TimeTickMsg {
......
......@@ -487,6 +487,7 @@ func (qt *QueryTask) PostExecute() error {
Hits: make([][]byte, 0),
}
const minFloat32 = -1 * float32(math.MaxFloat32)
for i := 0; i < nq; i++ {
locs := make([]int, availableQueryNodeNum)
reducedHits := &servicepb.Hits{
......@@ -496,18 +497,18 @@ func (qt *QueryTask) PostExecute() error {
}
for j := 0; j < topk; j++ {
choice, minDistance := 0, float32(math.MaxFloat32)
choice, maxDistance := 0, minFloat32
for q, loc := range locs { // query num, the number of ways to merge
distance := hits[q][i].Scores[loc]
if distance < minDistance {
if distance > maxDistance {
choice = q
minDistance = distance
maxDistance = distance
}
}
choiceOffset := locs[choice]
// check if distance is valid, `invalid` here means very very big,
// in this process, distance here is the smallest, so the rest of distance are all invalid
if hits[choice][i].Scores[choiceOffset] >= float32(math.MaxFloat32) {
if hits[choice][i].Scores[choiceOffset] <= minFloat32 {
break
}
reducedHits.IDs = append(reducedHits.IDs, hits[choice][i].IDs[choiceOffset])
......@@ -517,6 +518,11 @@ func (qt *QueryTask) PostExecute() error {
reducedHits.Scores = append(reducedHits.Scores, hits[choice][i].Scores[choiceOffset])
locs[choice]++
}
if searchResults[0].MetricType != "IP" {
for k := range reducedHits.Scores {
reducedHits.Scores[k] *= -1
}
}
reducedHitsBs, err := proto.Marshal(reducedHits)
if err != nil {
log.Println("marshal error")
......
......@@ -41,6 +41,13 @@ func (plan *Plan) getTopK() int64 {
return int64(topK)
}
func (plan *Plan) getMetricType() string {
cMetricType := C.GetMetricType(plan.cPlan)
defer C.free(unsafe.Pointer(cMetricType))
metricType := C.GoString(cMetricType)
return metricType
}
func (plan *Plan) delete() {
C.DeletePlan(plan.cPlan)
}
......
......@@ -27,6 +27,8 @@ func TestPlan_Plan(t *testing.T) {
assert.NotEqual(t, plan, nil)
topk := plan.getTopK()
assert.Equal(t, int(topk), 10)
metricType := plan.getMetricType()
assert.Equal(t, metricType, "L2")
plan.delete()
deleteCollection(collection)
}
......
......@@ -336,6 +336,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
Timestamp: searchTimestamp,
ResultChannelID: searchMsg.ResultChannelID,
Hits: hits,
MetricType: plan.getMetricType(),
}
searchResultMsg := &msgstream.SearchResultMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册