未验证 提交 cbb4052f 编写于 作者: C Cai Yudong 提交者: GitHub

Support search and query output fields using wildcard (#6671)

* update wildcard policy for search and query
Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>

* mark xfail cases
Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>

* mark xfail cases
Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>

* optimize debug log
Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>

* fix static-check
Signed-off-by: Nyudong.cai <yudong.cai@zilliz.com>
上级 faea66f8
......@@ -27,12 +27,6 @@ import (
"time"
"unsafe"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"go.uber.org/zap"
"github.com/golang/protobuf/proto"
......@@ -44,9 +38,12 @@ import (
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
......@@ -1265,6 +1262,49 @@ func (dct *DropCollectionTask) PostExecute(ctx context.Context) error {
return nil
}
// translateOutputFields expands the wildcard entries of outputFields against
// schema and returns the resulting de-duplicated list of field names.
//
// Every entry is trimmed of surrounding whitespace before it is interpreted:
//   - "*" selects every field of the schema that is not a BinaryVector or
//     FloatVector field;
//   - "%" selects every BinaryVector and FloatVector field of the schema;
//   - any other name is passed through as-is; it is not checked against schema.
//
// When addPrimary is true, the name of the field flagged IsPrimaryKey is added
// to the result as well. If several fields carry that flag the last one wins,
// and if none does the empty name is added (behaviour kept from the previous
// implementation).
//
// The order of the result is deterministic: names appear in the order in which
// they are first selected, and wildcard expansions follow the order of
// schema.Fields. The returned slice is never nil and the returned error is
// currently always nil.
func translateOutputFields(outputFields []string, schema *schemapb.CollectionSchema, addPrimary bool) ([]string, error) {
	var primaryFieldName string
	scalarFieldNames := make([]string, 0, len(schema.Fields))
	vectorFieldNames := make([]string, 0, len(schema.Fields))
	for _, field := range schema.Fields {
		if field.IsPrimaryKey {
			primaryFieldName = field.Name
		}
		switch field.DataType {
		case schemapb.DataType_BinaryVector, schemapb.DataType_FloatVector:
			vectorFieldNames = append(vectorFieldNames, field.Name)
		default:
			scalarFieldNames = append(scalarFieldNames, field.Name)
		}
	}

	// selected tracks the names already emitted so that every name is
	// returned once, while resultFieldNames keeps a stable order. Ranging
	// over a map here would randomize the order on every call.
	selected := make(map[string]struct{})
	resultFieldNames := make([]string, 0, len(outputFields)+1)
	addField := func(name string) {
		if _, ok := selected[name]; ok {
			return
		}
		selected[name] = struct{}{}
		resultFieldNames = append(resultFieldNames, name)
	}

	for _, outputFieldName := range outputFields {
		switch name := strings.TrimSpace(outputFieldName); name {
		case "*":
			for _, fieldName := range scalarFieldNames {
				addField(fieldName)
			}
		case "%":
			for _, fieldName := range vectorFieldNames {
				addField(fieldName)
			}
		default:
			addField(name)
		}
	}

	if addPrimary {
		addField(primaryFieldName)
	}
	return resultFieldNames, nil
}
type SearchTask struct {
Condition
*internalpb.SearchRequest
......@@ -1339,23 +1379,6 @@ func (st *SearchTask) getVChannels() ([]vChan, error) {
return st.chMgr.getVChannels(collID)
}
// translateOutputFields adds wildcard support for output fields, see
// https://github.com/milvus-io/milvus/issues/6411
//
// If outputFields consists of exactly one entry that equals "*" once
// surrounding whitespace is trimmed, the names of all schema fields that are
// neither BinaryVector nor FloatVector are returned in schema order.
// In every other case outputFields is returned untouched. The returned error
// is always nil.
func translateOutputFields(outputFields []string, schema *schemapb.CollectionSchema) ([]string, error) {
	isWildcard := len(outputFields) == 1 && strings.TrimSpace(outputFields[0]) == "*"
	if !isWildcard {
		return outputFields, nil
	}

	scalarNames := make([]string, 0)
	for _, f := range schema.Fields {
		switch f.DataType {
		case schemapb.DataType_BinaryVector, schemapb.DataType_FloatVector:
			// vector fields are never part of the "*" expansion
		default:
			scalarNames = append(scalarNames, f.Name)
		}
	}
	return scalarNames, nil
}
func (st *SearchTask) PreExecute(ctx context.Context) error {
st.Base.MsgType = commonpb.MsgType_Search
st.Base.SourceID = Params.ProxyID
......@@ -1416,7 +1439,7 @@ func (st *SearchTask) PreExecute(ctx context.Context) error {
return err
}
outputFields, err := translateOutputFields(st.query.OutputFields, schema)
outputFields, err := translateOutputFields(st.query.OutputFields, schema, false)
if err != nil {
return err
}
......@@ -2099,7 +2122,7 @@ func (rt *RetrieveTask) PreExecute(ctx context.Context) error {
if err != nil {
return err
}
rt.retrieve.OutputFields, err = translateOutputFields(rt.retrieve.OutputFields, schema)
rt.retrieve.OutputFields, err = translateOutputFields(rt.retrieve.OutputFields, schema, true)
if err != nil {
return err
}
......
......@@ -390,102 +390,121 @@ func TestInsertTask_checkRowNums(t *testing.T) {
}
// TestTranslateOutputFields exercises translateOutputFields with explicit
// field names and with the "*" and "%" wildcard entries, comparing the
// returned field names against the expected ones.
func TestTranslateOutputFields(t *testing.T) {
f1 := "field1"
f2 := "field2"
fvec := "fvec"
bvec := "bvec"
all := "*"
allWithWhiteSpace := " * "
allWithLeftWhiteSpace := " *"
allWithRightWhiteSpace := "* "
// Names of the fields declared in the schema under test.
const (
idFieldName = "id"
tsFieldName = "timestamp"
floatVectorFieldName = "float_vector"
binaryVectorFieldName = "binary_vector"
)
var outputFields []string
var err error
// schema has no vector fields
schema1 := &schemapb.CollectionSchema{
schema := &schemapb.CollectionSchema{
Name: "TestTranslateOutputFields",
Description: "TestTranslateOutputFields",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{Name: f1, DataType: schemapb.DataType_Int64},
{Name: f2, DataType: schemapb.DataType_Int64},
{Name: idFieldName, DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{Name: tsFieldName, DataType: schemapb.DataType_Int64},
{Name: floatVectorFieldName, DataType: schemapb.DataType_FloatVector},
{Name: binaryVectorFieldName, DataType: schemapb.DataType_BinaryVector},
},
}
// Empty input and explicit field names are expected back unchanged.
outputFields, err = translateOutputFields([]string{}, schema1)
outputFields, err = translateOutputFields([]string{}, schema, false)
assert.Equal(t, nil, err)
assert.Equal(t, []string{}, outputFields)
assert.ElementsMatch(t, []string{}, outputFields)
outputFields, err = translateOutputFields([]string{f1}, schema1)
outputFields, err = translateOutputFields([]string{idFieldName}, schema, false)
assert.Equal(t, nil, err)
assert.Equal(t, []string{f1}, outputFields)
assert.ElementsMatch(t, []string{idFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{f2}, schema1)
outputFields, err = translateOutputFields([]string{idFieldName, tsFieldName}, schema, false)
assert.Equal(t, nil, err)
assert.Equal(t, []string{f2}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{f1, f2}, schema1)
outputFields, err = translateOutputFields([]string{idFieldName, tsFieldName, floatVectorFieldName}, schema, false)
assert.Equal(t, nil, err)
assert.Equal(t, []string{f1, f2}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName}, outputFields)
// "*" (with or without surrounding whitespace) is expected to expand to the
// non-vector fields.
outputFields, err = translateOutputFields([]string{all}, schema1)
outputFields, err = translateOutputFields([]string{"*"}, schema, false)
assert.Equal(t, nil, err)
assert.Equal(t, []string{f1, f2}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{allWithWhiteSpace}, schema1)
outputFields, err = translateOutputFields([]string{" * "}, schema, false)
assert.Equal(t, nil, err)
assert.Equal(t, []string{f1, f2}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{allWithLeftWhiteSpace}, schema1)
// "%" (with or without surrounding whitespace) is expected to expand to the
// vector fields.
outputFields, err = translateOutputFields([]string{"%"}, schema, false)
assert.Equal(t, nil, err)
assert.Equal(t, []string{f1, f2}, outputFields)
assert.ElementsMatch(t, []string{floatVectorFieldName, binaryVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{allWithRightWhiteSpace}, schema1)
outputFields, err = translateOutputFields([]string{" % "}, schema, false)
assert.Equal(t, nil, err)
assert.Equal(t, []string{f1, f2}, outputFields)
assert.ElementsMatch(t, []string{floatVectorFieldName, binaryVectorFieldName}, outputFields)
// schema has vector fields
schema2 := &schemapb.CollectionSchema{
Name: "TestTranslateOutputFields",
Description: "TestTranslateOutputFields",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{Name: f1, DataType: schemapb.DataType_Int64},
{Name: f2, DataType: schemapb.DataType_Int64},
{Name: fvec, DataType: schemapb.DataType_FloatVector},
{Name: bvec, DataType: schemapb.DataType_BinaryVector},
},
}
// Wildcards combined with each other or with explicit names: the expected
// result contains each selected field once.
outputFields, err = translateOutputFields([]string{"*", "%"}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{"*", tsFieldName}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{"*", floatVectorFieldName}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{"%", floatVectorFieldName}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{floatVectorFieldName, binaryVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{"%", idFieldName}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields)
//=========================================================================
// The calls below that pass addPrimary=true expect the primary key field
// (idFieldName) in the result even when it was not requested.
outputFields, err = translateOutputFields([]string{}, schema, true)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{idFieldName}, schema, true)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{idFieldName, tsFieldName}, schema, true)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{}, schema2)
outputFields, err = translateOutputFields([]string{idFieldName, tsFieldName, floatVectorFieldName}, schema, true)
assert.Equal(t, nil, err)
assert.Equal(t, []string{}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{f1}, schema2)
outputFields, err = translateOutputFields([]string{"*"}, schema, true)
assert.Equal(t, nil, err)
assert.Equal(t, []string{f1}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{f2}, schema2)
outputFields, err = translateOutputFields([]string{"%"}, schema, true)
assert.Equal(t, nil, err)
assert.Equal(t, []string{f2}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{f1, f2}, schema2)
outputFields, err = translateOutputFields([]string{"*", "%"}, schema, true)
assert.Equal(t, nil, err)
assert.Equal(t, []string{f1, f2}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{all}, schema2)
outputFields, err = translateOutputFields([]string{"*", tsFieldName}, schema, true)
assert.Equal(t, nil, err)
assert.Equal(t, []string{f1, f2}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{allWithWhiteSpace}, schema2)
outputFields, err = translateOutputFields([]string{"*", floatVectorFieldName}, schema, true)
assert.Equal(t, nil, err)
assert.Equal(t, []string{f1, f2}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{allWithLeftWhiteSpace}, schema2)
outputFields, err = translateOutputFields([]string{"%", floatVectorFieldName}, schema, true)
assert.Equal(t, nil, err)
assert.Equal(t, []string{f1, f2}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields)
outputFields, err = translateOutputFields([]string{allWithRightWhiteSpace}, schema2)
outputFields, err = translateOutputFields([]string{"%", idFieldName}, schema, true)
assert.Equal(t, nil, err)
assert.Equal(t, []string{f1, f2}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields)
}
......@@ -402,10 +402,10 @@ func (q *queryCollection) doUnsolvedQueryMsg() {
default:
//time.Sleep(10 * time.Millisecond)
serviceTime := q.waitNewTSafe()
st, _ := tsoutil.ParseTS(serviceTime)
log.Debug("get tSafe from flow graph",
zap.Int64("collectionID", q.collectionID),
zap.Any("tSafe", st))
//st, _ := tsoutil.ParseTS(serviceTime)
//log.Debug("get tSafe from flow graph",
// zap.Int64("collectionID", q.collectionID),
// zap.Any("tSafe", st))
q.setServiceableTime(serviceTime)
//log.Debug("query node::doUnsolvedMsg: setServiceableTime", zap.Any("serviceTime", st))
......@@ -416,7 +416,7 @@ func (q *queryCollection) doUnsolvedQueryMsg() {
for _, m := range tempMsg {
guaranteeTs := m.GuaranteeTs()
gt, _ := tsoutil.ParseTS(guaranteeTs)
st, _ = tsoutil.ParseTS(serviceTime)
st, _ := tsoutil.ParseTS(serviceTime)
log.Debug("get query message from unsolvedMsg",
zap.Int64("collectionID", q.collectionID),
zap.Int64("msgID", m.ID()),
......
......@@ -1504,8 +1504,7 @@ func (c *Core) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitionsRe
}, nil
}
log.Debug("ShowPartitions succeed", zap.String("role", Params.RoleName), zap.Int64("msgID", t.Req.Base.MsgID),
zap.String("collection name", in.CollectionName), zap.Strings("partition names", t.Rsp.PartitionNames),
zap.Int64s("partition ids", t.Rsp.PartitionIDs))
zap.String("collection name", in.CollectionName), zap.Int("num of partitions", len(t.Rsp.PartitionNames)))
metrics.RootCoordShowPartitionsCounter.WithLabelValues(metricProxy(in.Base.SourceID), MetricRequestsSuccess).Inc()
t.Rsp.Status = &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
......
......@@ -24,6 +24,7 @@ class TestQueryBase(TestcaseBase):
"""
@pytest.mark.tags(CaseLabel.L0)
@pytest.mark.xfail(reason="issue #6650")
def test_query(self):
"""
target: test query
......@@ -52,6 +53,7 @@ class TestQueryBase(TestcaseBase):
assert len(res) == 0
@pytest.mark.tags(CaseLabel.L0)
@pytest.mark.xfail(reason="issue #6650")
def test_query_auto_id_collection(self):
"""
target: test query with auto_id=True collection
......@@ -135,6 +137,7 @@ class TestQueryBase(TestcaseBase):
collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6650")
def _test_query_expr_term(self):
"""
target: test query with TermExpr
......@@ -264,6 +267,7 @@ class TestQueryBase(TestcaseBase):
collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6650")
def test_query_output_field_none(self):
"""
target: test query with none output field
......@@ -366,6 +370,7 @@ class TestQueryBase(TestcaseBase):
check_items=error)
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6650")
def test_query_empty_output_fields(self):
"""
target: test query with empty output fields
......@@ -393,6 +398,7 @@ class TestQueryBase(TestcaseBase):
check_items=error)
@pytest.mark.tags(CaseLabel.L0)
@pytest.mark.xfail(reason="issue #6650")
def test_query_partition(self):
"""
target: test query on partition
......@@ -426,6 +432,7 @@ class TestQueryBase(TestcaseBase):
check_task=CheckTasks.err_res, check_items=error)
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6650")
def test_query_default_partition(self):
"""
target: test query on default partition
......@@ -519,6 +526,7 @@ class TestQueryOperation(TestcaseBase):
check_items={ct.err_code: 1, ct.err_msg: clem.CollNotLoaded % collection_name})
@pytest.mark.tags(ct.CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6650")
@pytest.mark.parametrize("term_expr", [f'{ct.default_int64_field_name} in [0]'])
def test_query_expr_single_term_array(self, term_expr):
"""
......@@ -535,6 +543,7 @@ class TestQueryOperation(TestcaseBase):
collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
@pytest.mark.tags(ct.CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6650")
@pytest.mark.parametrize("term_expr", [f'{ct.default_int64_field_name} in [0]'])
def test_query_binary_expr_single_term_array(self, term_expr, check_content):
"""
......@@ -552,6 +561,7 @@ class TestQueryOperation(TestcaseBase):
collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
@pytest.mark.tags(ct.CaseLabel.L2)
@pytest.mark.xfail(reason="issue #6650")
def test_query_expr_all_term_array(self):
"""
target: test query with all array term expr
......@@ -621,6 +631,7 @@ class TestQueryOperation(TestcaseBase):
log.debug(res)
@pytest.mark.tags(ct.CaseLabel.L0)
@pytest.mark.xfail(reason="issue #6650")
def test_query_after_index(self):
"""
target: test query after creating index
......@@ -640,6 +651,7 @@ class TestQueryOperation(TestcaseBase):
collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec})
@pytest.mark.tags(ct.CaseLabel.L1)
@pytest.mark.xfail(reason="issue #6650")
def test_query_after_search(self):
"""
target: test query after search
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册