plan_parser.go 18.3 KB
Newer Older
1 2 3 4 5 6
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
16

C
Cai Yudong 已提交
17
package proxy
18 19 20

import (
	"fmt"
21
	"math"
22
	"strings"
23 24 25 26 27 28 29 30

	ant_ast "github.com/antonmedv/expr/ast"
	ant_parser "github.com/antonmedv/expr/parser"
	"github.com/milvus-io/milvus/internal/proto/planpb"
	"github.com/milvus-io/milvus/internal/proto/schemapb"
	"github.com/milvus-io/milvus/internal/util/typeutil"
)

31
type parserContext struct {
32 33 34
	schema *typeutil.SchemaHelper
}

35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
type optimizer struct {
	err error
}

func (*optimizer) Enter(*ant_ast.Node) {}

func (optimizer *optimizer) Exit(node *ant_ast.Node) {
	patch := func(newNode ant_ast.Node) {
		ant_ast.Patch(node, newNode)
	}

	switch node := (*node).(type) {
	case *ant_ast.UnaryNode:
		switch node.Operator {
		case "-":
			if i, ok := node.Node.(*ant_ast.IntegerNode); ok {
				patch(&ant_ast.IntegerNode{Value: -i.Value})
			} else if i, ok := node.Node.(*ant_ast.FloatNode); ok {
				patch(&ant_ast.FloatNode{Value: -i.Value})
			} else {
55
				optimizer.err = fmt.Errorf("invalid data type")
56 57 58 59 60 61 62 63
				return
			}
		case "+":
			if i, ok := node.Node.(*ant_ast.IntegerNode); ok {
				patch(&ant_ast.IntegerNode{Value: i.Value})
			} else if i, ok := node.Node.(*ant_ast.FloatNode); ok {
				patch(&ant_ast.FloatNode{Value: i.Value})
			} else {
64
				optimizer.err = fmt.Errorf("invalid data type")
65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
				return
			}
		}

	case *ant_ast.BinaryNode:
		floatNodeLeft, leftFloat := node.Left.(*ant_ast.FloatNode)
		integerNodeLeft, leftInteger := node.Left.(*ant_ast.IntegerNode)
		floatNodeRight, rightFloat := node.Right.(*ant_ast.FloatNode)
		integerNodeRight, rightInteger := node.Right.(*ant_ast.IntegerNode)

		switch node.Operator {
		case "+":
			if leftFloat && rightFloat {
				patch(&ant_ast.FloatNode{Value: floatNodeLeft.Value + floatNodeRight.Value})
			} else if leftFloat && rightInteger {
				patch(&ant_ast.FloatNode{Value: floatNodeLeft.Value + float64(integerNodeRight.Value)})
			} else if leftInteger && rightFloat {
				patch(&ant_ast.FloatNode{Value: float64(integerNodeLeft.Value) + floatNodeRight.Value})
			} else if leftInteger && rightInteger {
				patch(&ant_ast.IntegerNode{Value: integerNodeLeft.Value + integerNodeRight.Value})
			} else {
86
				optimizer.err = fmt.Errorf("invalid data type")
87 88 89 90 91 92 93 94 95 96 97 98
				return
			}
		case "-":
			if leftFloat && rightFloat {
				patch(&ant_ast.FloatNode{Value: floatNodeLeft.Value - floatNodeRight.Value})
			} else if leftFloat && rightInteger {
				patch(&ant_ast.FloatNode{Value: floatNodeLeft.Value - float64(integerNodeRight.Value)})
			} else if leftInteger && rightFloat {
				patch(&ant_ast.FloatNode{Value: float64(integerNodeLeft.Value) - floatNodeRight.Value})
			} else if leftInteger && rightInteger {
				patch(&ant_ast.IntegerNode{Value: integerNodeLeft.Value - integerNodeRight.Value})
			} else {
99
				optimizer.err = fmt.Errorf("invalid data type")
100 101 102 103 104 105 106 107 108 109 110 111
				return
			}
		case "*":
			if leftFloat && rightFloat {
				patch(&ant_ast.FloatNode{Value: floatNodeLeft.Value * floatNodeRight.Value})
			} else if leftFloat && rightInteger {
				patch(&ant_ast.FloatNode{Value: floatNodeLeft.Value * float64(integerNodeRight.Value)})
			} else if leftInteger && rightFloat {
				patch(&ant_ast.FloatNode{Value: float64(integerNodeLeft.Value) * floatNodeRight.Value})
			} else if leftInteger && rightInteger {
				patch(&ant_ast.IntegerNode{Value: integerNodeLeft.Value * integerNodeRight.Value})
			} else {
112
				optimizer.err = fmt.Errorf("invalid data type")
113 114 115 116 117
				return
			}
		case "/":
			if leftFloat && rightFloat {
				if floatNodeRight.Value == 0 {
118
					optimizer.err = fmt.Errorf("divide by zero")
119 120 121 122 123
					return
				}
				patch(&ant_ast.FloatNode{Value: floatNodeLeft.Value / floatNodeRight.Value})
			} else if leftFloat && rightInteger {
				if integerNodeRight.Value == 0 {
124
					optimizer.err = fmt.Errorf("divide by zero")
125 126 127 128 129
					return
				}
				patch(&ant_ast.FloatNode{Value: floatNodeLeft.Value / float64(integerNodeRight.Value)})
			} else if leftInteger && rightFloat {
				if floatNodeRight.Value == 0 {
130
					optimizer.err = fmt.Errorf("divide by zero")
131 132 133 134 135
					return
				}
				patch(&ant_ast.FloatNode{Value: float64(integerNodeLeft.Value) / floatNodeRight.Value})
			} else if leftInteger && rightInteger {
				if integerNodeRight.Value == 0 {
136
					optimizer.err = fmt.Errorf("divide by zero")
137 138 139 140
					return
				}
				patch(&ant_ast.IntegerNode{Value: integerNodeLeft.Value / integerNodeRight.Value})
			} else {
141
				optimizer.err = fmt.Errorf("invalid data type")
142 143 144 145
				return
			}
		case "%":
			if leftInteger && rightInteger {
146 147 148 149
				if integerNodeRight.Value == 0 {
					optimizer.err = fmt.Errorf("modulo by zero")
					return
				}
150 151
				patch(&ant_ast.IntegerNode{Value: integerNodeLeft.Value % integerNodeRight.Value})
			} else {
152
				optimizer.err = fmt.Errorf("invalid data type")
153 154 155 156 157 158 159 160 161 162 163 164
				return
			}
		case "**":
			if leftFloat && rightFloat {
				patch(&ant_ast.FloatNode{Value: math.Pow(floatNodeLeft.Value, floatNodeRight.Value)})
			} else if leftFloat && rightInteger {
				patch(&ant_ast.FloatNode{Value: math.Pow(floatNodeLeft.Value, float64(integerNodeRight.Value))})
			} else if leftInteger && rightFloat {
				patch(&ant_ast.FloatNode{Value: math.Pow(float64(integerNodeLeft.Value), floatNodeRight.Value)})
			} else if leftInteger && rightInteger {
				patch(&ant_ast.IntegerNode{Value: int(math.Pow(float64(integerNodeLeft.Value), float64(integerNodeRight.Value)))})
			} else {
165
				optimizer.err = fmt.Errorf("invalid data type")
166 167 168 169 170 171
				return
			}
		}
	}
}

172 173 174 175
func parseExpr(schema *typeutil.SchemaHelper, exprStr string) (*planpb.Expr, error) {
	if exprStr == "" {
		return nil, nil
	}
176 177 178 179
	ast, err := ant_parser.Parse(exprStr)
	if err != nil {
		return nil, err
	}
180 181 182 183 184 185 186

	optimizer := &optimizer{}
	ant_ast.Walk(&ast.Node, optimizer)
	if optimizer.err != nil {
		return nil, optimizer.err
	}

187
	pc := parserContext{schema}
188
	expr, err := pc.handleExpr(&ast.Node)
189 190 191 192 193
	if err != nil {
		return nil, err
	}

	return expr, nil
194 195
}

196
func createColumnInfo(field *schemapb.FieldSchema) *planpb.ColumnInfo {
197
	return &planpb.ColumnInfo{
198 199 200
		FieldId:      field.FieldID,
		DataType:     field.DataType,
		IsPrimaryKey: field.IsPrimaryKey,
201 202 203
	}
}

204
// isSameOrder reports whether two comparison operators point in the same
// direction: both "less" (`<`, `<=`) or both "greater" (`>`, `>=`). It is used
// to validate chained range comparisons such as `1 < a <= 5`. Operators with no
// ordering (`==`, `!=`, `in`, logical operators, ...) never match, so chains
// like `a != 1 > a` are rejected instead of being merged into a bogus range.
func isSameOrder(opStr1, opStr2 string) bool {
	isLess := func(op string) bool { return op == "<" || op == "<=" }
	isGreater := func(op string) bool { return op == ">" || op == ">=" }
	return (isLess(opStr1) && isLess(opStr2)) || (isGreater(opStr1) && isGreater(opStr2))
}

C
Cai Yudong 已提交
210 211 212 213
func getCompareOpType(opStr string, reverse bool) (op planpb.OpType) {
	switch opStr {
	case ">":
		if reverse {
214
			op = planpb.OpType_LessThan
C
Cai Yudong 已提交
215
		} else {
216
			op = planpb.OpType_GreaterThan
217
		}
C
Cai Yudong 已提交
218 219
	case "<":
		if reverse {
220
			op = planpb.OpType_GreaterThan
C
Cai Yudong 已提交
221 222 223 224 225
		} else {
			op = planpb.OpType_LessThan
		}
	case ">=":
		if reverse {
226
			op = planpb.OpType_LessEqual
C
Cai Yudong 已提交
227 228 229 230 231
		} else {
			op = planpb.OpType_GreaterEqual
		}
	case "<=":
		if reverse {
232
			op = planpb.OpType_GreaterEqual
C
Cai Yudong 已提交
233 234
		} else {
			op = planpb.OpType_LessEqual
235
		}
C
Cai Yudong 已提交
236 237 238 239 240 241
	case "==":
		op = planpb.OpType_Equal
	case "!=":
		op = planpb.OpType_NotEqual
	default:
		op = planpb.OpType_Invalid
242 243 244 245
	}
	return op
}

246 247 248 249 250 251 252 253 254 255 256
// getLogicalOpType maps a logical operator, spelled either symbolically
// (`&&`, `||`) or as a keyword (`and`, `or`), to the plan's binary op.
// Unknown tokens map to BinaryExpr_Invalid.
func getLogicalOpType(opStr string) planpb.BinaryExpr_BinaryOp {
	op := planpb.BinaryExpr_Invalid
	switch opStr {
	case "&&", "and":
		op = planpb.BinaryExpr_LogicalAnd
	case "||", "or":
		op = planpb.BinaryExpr_LogicalOr
	}
	return op
}

257 258 259
func parseBoolNode(nodeRaw *ant_ast.Node) *ant_ast.BoolNode {
	switch node := (*nodeRaw).(type) {
	case *ant_ast.IdentifierNode:
260
		// bool node only accept value 'true' or 'false'
261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277
		val := strings.ToLower(node.Value)
		if val == "true" {
			return &ant_ast.BoolNode{
				Value: true,
			}
		} else if val == "false" {
			return &ant_ast.BoolNode{
				Value: false,
			}
		} else {
			return nil
		}
	default:
		return nil
	}
}

278
func (pc *parserContext) createCmpExpr(left, right ant_ast.Node, operator string) (*planpb.Expr, error) {
279 280 281 282 283 284
	if boolNode := parseBoolNode(&left); boolNode != nil {
		left = boolNode
	}
	if boolNode := parseBoolNode(&right); boolNode != nil {
		right = boolNode
	}
C
Cai Yudong 已提交
285 286
	idNodeLeft, okLeft := left.(*ant_ast.IdentifierNode)
	idNodeRight, okRight := right.(*ant_ast.IdentifierNode)
287

C
Cai Yudong 已提交
288
	if okLeft && okRight {
289
		leftField, err := pc.handleIdentifier(idNodeLeft)
290 291 292
		if err != nil {
			return nil, err
		}
293
		rightField, err := pc.handleIdentifier(idNodeRight)
294 295 296 297 298
		if err != nil {
			return nil, err
		}
		op := getCompareOpType(operator, false)
		if op == planpb.OpType_Invalid {
299
			return nil, fmt.Errorf("invalid binary operator(%s)", operator)
300 301 302 303
		}
		expr := &planpb.Expr{
			Expr: &planpb.Expr_CompareExpr{
				CompareExpr: &planpb.CompareExpr{
304 305
					LeftColumnInfo:  createColumnInfo(leftField),
					RightColumnInfo: createColumnInfo(rightField),
306
					Op:              op,
307 308 309 310 311 312
				},
			},
		}
		return expr, nil
	}

313
	var idNode *ant_ast.IdentifierNode
C
Cai Yudong 已提交
314
	var reverse bool
315
	var valueNode *ant_ast.Node
C
Cai Yudong 已提交
316
	if okLeft {
317
		idNode = idNodeLeft
C
Cai Yudong 已提交
318
		reverse = false
319
		valueNode = &right
C
Cai Yudong 已提交
320
	} else if okRight {
321
		idNode = idNodeRight
C
Cai Yudong 已提交
322
		reverse = true
323
		valueNode = &left
324 325
	} else {
		return nil, fmt.Errorf("compare expr has no identifier")
326
	}
327

328
	field, err := pc.handleIdentifier(idNode)
329 330 331 332
	if err != nil {
		return nil, err
	}

333
	val, err := pc.handleLeafValue(valueNode, field.DataType)
334 335 336 337
	if err != nil {
		return nil, err
	}

C
Cai Yudong 已提交
338
	op := getCompareOpType(operator, reverse)
339
	if op == planpb.OpType_Invalid {
340
		return nil, fmt.Errorf("invalid binary operator(%s)", operator)
341 342
	}

343
	expr := &planpb.Expr{
344 345
		Expr: &planpb.Expr_UnaryRangeExpr{
			UnaryRangeExpr: &planpb.UnaryRangeExpr{
346
				ColumnInfo: createColumnInfo(field),
347 348
				Op:         op,
				Value:      val,
349 350 351
			},
		},
	}
352
	return expr, nil
353 354
}

355
func (pc *parserContext) handleCmpExpr(node *ant_ast.BinaryNode) (*planpb.Expr, error) {
356
	return pc.createCmpExpr(node.Left, node.Right, node.Operator)
357 358
}

359
func (pc *parserContext) handleLogicalExpr(node *ant_ast.BinaryNode) (*planpb.Expr, error) {
360 361
	op := getLogicalOpType(node.Operator)
	if op == planpb.BinaryExpr_Invalid {
362
		return nil, fmt.Errorf("invalid logical operator(%s)", node.Operator)
363 364
	}

365
	leftExpr, err := pc.handleExpr(&node.Left)
366 367 368 369
	if err != nil {
		return nil, err
	}

370
	rightExpr, err := pc.handleExpr(&node.Right)
371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386
	if err != nil {
		return nil, err
	}

	expr := &planpb.Expr{
		Expr: &planpb.Expr_BinaryExpr{
			BinaryExpr: &planpb.BinaryExpr{
				Op:    op,
				Left:  leftExpr,
				Right: rightExpr,
			},
		},
	}
	return expr, nil
}

387
func (pc *parserContext) handleArrayExpr(node *ant_ast.Node, dataType schemapb.DataType) ([]*planpb.GenericValue, error) {
388 389 390 391 392 393
	arrayNode, ok2 := (*node).(*ant_ast.ArrayNode)
	if !ok2 {
		return nil, fmt.Errorf("right operand of the InExpr must be array")
	}
	var arr []*planpb.GenericValue
	for _, element := range arrayNode.Nodes {
C
congqixia 已提交
394 395
		// use value inside
		// #nosec G601
396
		val, err := pc.handleLeafValue(&element, dataType)
397 398 399 400 401 402
		if err != nil {
			return nil, err
		}
		arr = append(arr, val)
	}
	return arr, nil
403 404
}

405
func (pc *parserContext) handleInExpr(node *ant_ast.BinaryNode) (*planpb.Expr, error) {
F
FluorineDog 已提交
406
	if node.Operator != "in" && node.Operator != "not in" {
407
		return nil, fmt.Errorf("invalid operator(%s)", node.Operator)
F
FluorineDog 已提交
408
	}
409 410 411 412
	idNode, ok := node.Left.(*ant_ast.IdentifierNode)
	if !ok {
		return nil, fmt.Errorf("left operand of the InExpr must be identifier")
	}
413
	field, err := pc.handleIdentifier(idNode)
414 415 416
	if err != nil {
		return nil, err
	}
417
	arrayData, err := pc.handleArrayExpr(&node.Right, field.DataType)
418 419 420 421 422 423 424
	if err != nil {
		return nil, err
	}

	expr := &planpb.Expr{
		Expr: &planpb.Expr_TermExpr{
			TermExpr: &planpb.TermExpr{
425
				ColumnInfo: createColumnInfo(field),
426 427 428 429
				Values:     arrayData,
			},
		},
	}
F
FluorineDog 已提交
430 431

	if node.Operator == "not in" {
432
		return pc.createNotExpr(expr)
F
FluorineDog 已提交
433
	}
434
	return expr, nil
435 436
}

437
func (pc *parserContext) combineUnaryRangeExpr(a, b *planpb.UnaryRangeExpr) *planpb.Expr {
438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458
	if a.Op == planpb.OpType_LessEqual || a.Op == planpb.OpType_LessThan {
		a, b = b, a
	}

	lowerInclusive := (a.Op == planpb.OpType_GreaterEqual)
	upperInclusive := (b.Op == planpb.OpType_LessEqual)

	expr := &planpb.Expr{
		Expr: &planpb.Expr_BinaryRangeExpr{
			BinaryRangeExpr: &planpb.BinaryRangeExpr{
				ColumnInfo:     a.ColumnInfo,
				LowerInclusive: lowerInclusive,
				UpperInclusive: upperInclusive,
				LowerValue:     a.Value,
				UpperValue:     b.Value,
			},
		},
	}
	return expr
}

459
func (pc *parserContext) handleMultiCmpExpr(node *ant_ast.BinaryNode) (*planpb.Expr, error) {
460 461 462
	exprs := []*planpb.Expr{}
	curNode := node

C
cxytz01 已提交
463
	// handle multiple relational operators
464 465 466
	for {
		binNodeLeft, LeftOk := curNode.Left.(*ant_ast.BinaryNode)
		if !LeftOk {
467
			expr, err := pc.handleCmpExpr(curNode)
468 469 470 471 472 473 474
			if err != nil {
				return nil, err
			}
			exprs = append(exprs, expr)
			break
		}
		if isSameOrder(node.Operator, binNodeLeft.Operator) {
475
			expr, err := pc.createCmpExpr(binNodeLeft.Right, curNode.Right, curNode.Operator)
476 477 478 479 480
			if err != nil {
				return nil, err
			}
			exprs = append(exprs, expr)
			curNode = binNodeLeft
481 482
		} else {
			return nil, fmt.Errorf("illegal multi-range expr")
483 484 485
		}
	}

486 487
	// combine UnaryRangeExpr to BinaryRangeExpr
	var lastExpr *planpb.UnaryRangeExpr
488
	for i := len(exprs) - 1; i >= 0; i-- {
489 490
		if expr, ok := exprs[i].Expr.(*planpb.Expr_UnaryRangeExpr); ok {
			if lastExpr != nil && expr.UnaryRangeExpr.ColumnInfo.FieldId == lastExpr.ColumnInfo.FieldId {
491
				binaryRangeExpr := pc.combineUnaryRangeExpr(expr.UnaryRangeExpr, lastExpr)
492 493 494 495
				exprs = append(exprs[0:i], append([]*planpb.Expr{binaryRangeExpr}, exprs[i+2:]...)...)
				lastExpr = nil
			} else {
				lastExpr = expr.UnaryRangeExpr
496 497 498 499 500 501
			}
		} else {
			lastExpr = nil
		}
	}

502
	// use `&&` to connect exprs
503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518
	combinedExpr := exprs[len(exprs)-1]
	for i := len(exprs) - 2; i >= 0; i-- {
		expr := exprs[i]
		combinedExpr = &planpb.Expr{
			Expr: &planpb.Expr_BinaryExpr{
				BinaryExpr: &planpb.BinaryExpr{
					Op:    planpb.BinaryExpr_LogicalAnd,
					Left:  combinedExpr,
					Right: expr,
				},
			},
		}
	}
	return combinedExpr, nil
}

519
func (pc *parserContext) handleBinaryExpr(node *ant_ast.BinaryNode) (*planpb.Expr, error) {
520
	switch node.Operator {
521
	case "<", "<=", ">", ">=":
522
		return pc.handleMultiCmpExpr(node)
523
	case "==", "!=":
524
		return pc.handleCmpExpr(node)
525
	case "and", "or", "&&", "||":
526
		return pc.handleLogicalExpr(node)
527
	case "in", "not in":
528
		return pc.handleInExpr(node)
529
	}
530
	return nil, fmt.Errorf("unsupported binary operator %s", node.Operator)
531 532
}

533
func (pc *parserContext) createNotExpr(childExpr *planpb.Expr) (*planpb.Expr, error) {
534 535 536 537 538 539 540 541 542
	expr := &planpb.Expr{
		Expr: &planpb.Expr_UnaryExpr{
			UnaryExpr: &planpb.UnaryExpr{
				Op:    planpb.UnaryExpr_Not,
				Child: childExpr,
			},
		},
	}
	return expr, nil
543 544
}

545
func (pc *parserContext) handleLeafValue(nodeRaw *ant_ast.Node, dataType schemapb.DataType) (gv *planpb.GenericValue, err error) {
546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563
	switch node := (*nodeRaw).(type) {
	case *ant_ast.FloatNode:
		if typeutil.IsFloatingType(dataType) {
			gv = &planpb.GenericValue{
				Val: &planpb.GenericValue_FloatVal{
					FloatVal: node.Value,
				},
			}
		} else {
			return nil, fmt.Errorf("type mismatch")
		}
	case *ant_ast.IntegerNode:
		if typeutil.IsFloatingType(dataType) {
			gv = &planpb.GenericValue{
				Val: &planpb.GenericValue_FloatVal{
					FloatVal: float64(node.Value),
				},
			}
564
		} else if typeutil.IsIntegerType(dataType) {
565 566 567 568 569 570 571 572 573
			gv = &planpb.GenericValue{
				Val: &planpb.GenericValue_Int64Val{
					Int64Val: int64(node.Value),
				},
			}
		} else {
			return nil, fmt.Errorf("type mismatch")
		}
	case *ant_ast.BoolNode:
574
		if typeutil.IsBoolType(dataType) {
575 576 577 578 579 580 581 582 583 584 585 586 587 588 589
			gv = &planpb.GenericValue{
				Val: &planpb.GenericValue_BoolVal{
					BoolVal: node.Value,
				},
			}
		} else {
			return nil, fmt.Errorf("type mismatch")
		}
	default:
		return nil, fmt.Errorf("unsupported leaf node")
	}

	return gv, nil
}

590
func (pc *parserContext) handleIdentifier(node *ant_ast.IdentifierNode) (*schemapb.FieldSchema, error) {
591
	fieldName := node.Value
592
	field, err := pc.schema.GetFieldFromName(fieldName)
593 594 595
	return field, err
}

596
func (pc *parserContext) handleUnaryExpr(node *ant_ast.UnaryNode) (*planpb.Expr, error) {
597 598
	switch node.Operator {
	case "!", "not":
599
		subExpr, err := pc.handleExpr(&node.Node)
600 601 602
		if err != nil {
			return nil, err
		}
603
		return pc.createNotExpr(subExpr)
604 605 606 607 608
	default:
		return nil, fmt.Errorf("invalid unary operator(%s)", node.Operator)
	}
}

609
func (pc *parserContext) handleExpr(nodeRaw *ant_ast.Node) (*planpb.Expr, error) {
610 611 612 613 614 615 616
	switch node := (*nodeRaw).(type) {
	case *ant_ast.IdentifierNode,
		*ant_ast.FloatNode,
		*ant_ast.IntegerNode,
		*ant_ast.BoolNode:
		return nil, fmt.Errorf("scalar expr is not supported yet")
	case *ant_ast.UnaryNode:
617
		expr, err := pc.handleUnaryExpr(node)
618 619 620 621 622
		if err != nil {
			return nil, err
		}
		return expr, nil
	case *ant_ast.BinaryNode:
623
		return pc.handleBinaryExpr(node)
624 625 626 627 628
	default:
		return nil, fmt.Errorf("unsupported node (%s)", node.Type().String())
	}
}

629
func createQueryPlan(schemaPb *schemapb.CollectionSchema, exprStr string, vectorFieldName string, queryInfo *planpb.QueryInfo) (*planpb.PlanNode, error) {
630 631 632 633 634
	schema, err := typeutil.CreateSchemaHelper(schemaPb)
	if err != nil {
		return nil, err
	}

635
	expr, err := parseExpr(schema, exprStr)
636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662
	if err != nil {
		return nil, err
	}
	vectorField, err := schema.GetFieldFromName(vectorFieldName)
	if err != nil {
		return nil, err
	}
	fieldID := vectorField.FieldID
	dataType := vectorField.DataType

	if !typeutil.IsVectorType(dataType) {
		return nil, fmt.Errorf("field (%s) to search is not of vector data type", vectorFieldName)
	}

	planNode := &planpb.PlanNode{
		Node: &planpb.PlanNode_VectorAnns{
			VectorAnns: &planpb.VectorANNS{
				IsBinary:       dataType == schemapb.DataType_BinaryVector,
				Predicates:     expr,
				QueryInfo:      queryInfo,
				PlaceholderTag: "$0",
				FieldId:        fieldID,
			},
		},
	}
	return planNode, nil
}
Y
yukun 已提交
663

664
func createExprPlan(schemaPb *schemapb.CollectionSchema, exprStr string) (*planpb.PlanNode, error) {
Y
yukun 已提交
665 666 667 668 669
	schema, err := typeutil.CreateSchemaHelper(schemaPb)
	if err != nil {
		return nil, err
	}

670
	expr, err := parseExpr(schema, exprStr)
Y
yukun 已提交
671 672 673 674 675 676 677 678 679 680 681
	if err != nil {
		return nil, err
	}

	planNode := &planpb.PlanNode{
		Node: &planpb.PlanNode_Predicates{
			Predicates: expr,
		},
	}
	return planNode, nil
}