diff --git a/internal/util/importutil/json_handler.go b/internal/util/importutil/json_handler.go index e5b7e499bb7d80981ac8b2590ea105282d616641..ce6b786fb485972eaa48afad5c6137c35490902d 100644 --- a/internal/util/importutil/json_handler.go +++ b/internal/util/importutil/json_handler.go @@ -47,6 +47,22 @@ type Validator struct { fieldName string // field name } +func getPrimaryKey(obj interface{}, fieldName string, isString bool) (string, error) { + // varchar type primary field, the value must be a string + if isString { + if value, ok := obj.(string); ok { + return value, nil + } + return "", fmt.Errorf("illegal value '%v' for varchar type primary key field '%s'", obj, fieldName) + } + + // int64 type primary field, the value must be json.Number + if num, ok := obj.(json.Number); ok { + return string(num), nil + } + return "", fmt.Errorf("illegal value '%v' for int64 type primary key field '%s'", obj, fieldName) +} + // JSONRowConsumer is row-based json format consumer class type JSONRowConsumer struct { collectionSchema *schemapb.CollectionSchema // collection schema @@ -225,6 +241,7 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error { // consume rows for i := 0; i < len(rows); i++ { row := rows[i] + rowNumber := v.rowCounter + int64(i) // hash to a shard number var shard uint32 @@ -235,7 +252,13 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error { } value := row[v.primaryKey] - pk := string(value.(string)) + pk, err := getPrimaryKey(value, primaryValidator.fieldName, primaryValidator.isString) + if err != nil { + log.Error("JSON row consumer: failed to parse primary key at the row", + zap.Int64("rowNumber", rowNumber), zap.Error(err)) + return fmt.Errorf("failed to parse primary key at the row %d, error: %w", rowNumber, err) + } + hash := typeutil.HashString2Uint32(pk) shard = hash % uint32(v.shardNum) pkArray := v.segmentsData[shard][v.primaryKey].(*storage.StringFieldData) @@ -247,22 +270,28 @@ func 
(v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error { pk = rowIDBegin + int64(i) } else { value := row[v.primaryKey] + strValue, err := getPrimaryKey(value, primaryValidator.fieldName, primaryValidator.isString) + if err != nil { + log.Error("JSON row consumer: failed to parse primary key at the row", + zap.Int64("rowNumber", rowNumber), zap.Error(err)) + return fmt.Errorf("failed to parse primary key at the row %d, error: %w", rowNumber, err) + } + // parse the pk from a string - strValue := string(value.(json.Number)) pk, err = strconv.ParseInt(strValue, 10, 64) if err != nil { log.Error("JSON row consumer: failed to parse primary key at the row", - zap.String("value", strValue), zap.Int64("rowNumber", v.rowCounter+int64(i)), zap.Error(err)) + zap.String("value", strValue), zap.Int64("rowNumber", rowNumber), zap.Error(err)) return fmt.Errorf("failed to parse primary key '%s' at the row %d, error: %w", - strValue, v.rowCounter+int64(i), err) + strValue, rowNumber, err) } } hash, err := typeutil.Hash32Int64(pk) if err != nil { log.Error("JSON row consumer: failed to hash primary key at the row", - zap.Int64("key", pk), zap.Int64("rowNumber", v.rowCounter+int64(i)), zap.Error(err)) - return fmt.Errorf("failed to hash primary key %d at the row %d, error: %w", pk, v.rowCounter+int64(i), err) + zap.Int64("key", pk), zap.Int64("rowNumber", rowNumber), zap.Error(err)) + return fmt.Errorf("failed to hash primary key %d at the row %d, error: %w", pk, rowNumber, err) } shard = hash % uint32(v.shardNum) @@ -282,9 +311,9 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error { value := row[name] if err := validator.convertFunc(value, v.segmentsData[shard][name]); err != nil { log.Error("JSON row consumer: failed to convert value for field at the row", - zap.String("fieldName", validator.fieldName), zap.Int64("rowNumber", v.rowCounter+int64(i)), zap.Error(err)) + zap.String("fieldName", validator.fieldName), 
zap.Int64("rowNumber", rowNumber), zap.Error(err)) return fmt.Errorf("failed to convert value for field '%s' at the row %d, error: %w", - validator.fieldName, v.rowCounter+int64(i), err) + validator.fieldName, rowNumber, err) } } } diff --git a/internal/util/importutil/json_handler_test.go b/internal/util/importutil/json_handler_test.go index 3aa495733b18b2f4c3c450eaeaa688f636f11944..aac860464a363180fd4e701979e5cd523d9b352e 100644 --- a/internal/util/importutil/json_handler_test.go +++ b/internal/util/importutil/json_handler_test.go @@ -18,6 +18,7 @@ package importutil import ( "context" + "encoding/json" "errors" "testing" @@ -243,10 +244,27 @@ func Test_JSONRowConsumerHandle(t *testing.T) { assert.Equal(t, int64(1), consumer.autoIDRange[0]) assert.Equal(t, int64(1+rowCount), consumer.autoIDRange[1]) - // pk is auto-generated byt IDAllocator is nil + // pk is auto-generated but IDAllocator is nil consumer.rowIDAllocator = nil err = consumer.Handle(input) assert.Error(t, err) + + // pk is not auto-generated, pk is not numeric value + input = make([]map[storage.FieldID]interface{}, 1) + input[0] = make(map[int64]interface{}) + input[0][101] = "1" + + schema.Fields[0].AutoID = false + consumer, err = NewJSONRowConsumer(schema, idAllocator, 1, 1, flushFunc) + assert.NotNil(t, consumer) + assert.Nil(t, err) + err = consumer.Handle(input) + assert.Error(t, err) + + // pk is numeric value, but cannot be parsed + input[0][101] = json.Number("A1") + err = consumer.Handle(input) + assert.Error(t, err) }) t.Run("handle varchar pk", func(t *testing.T) { @@ -279,7 +297,7 @@ func Test_JSONRowConsumerHandle(t *testing.T) { input[j][101] = "abc" } - // varchar pk cannot be autoid + // varchar pk cannot be auto-generated err = consumer.Handle(input) assert.NotNil(t, err) @@ -293,5 +311,33 @@ func Test_JSONRowConsumerHandle(t *testing.T) { assert.Nil(t, err) assert.Equal(t, int64(rowCount), consumer.RowCount()) assert.Equal(t, 0, len(consumer.autoIDRange)) + + // pk is not string 
value + input = make([]map[storage.FieldID]interface{}, 1) + input[0] = make(map[int64]interface{}) + input[0][101] = false + err = consumer.Handle(input) + assert.Error(t, err) }) } + +func Test_GetPrimaryKey(t *testing.T) { + fieldName := "dummy" + var obj1 interface{} = "aa" + val, err := getPrimaryKey(obj1, fieldName, true) + assert.Equal(t, val, "aa") + assert.NoError(t, err) + + val, err = getPrimaryKey(obj1, fieldName, false) + assert.Empty(t, val) + assert.Error(t, err) + + var obj2 interface{} = json.Number("10") + val, err = getPrimaryKey(obj2, fieldName, false) + assert.Equal(t, val, "10") + assert.NoError(t, err) + + val, err = getPrimaryKey(obj2, fieldName, true) + assert.Empty(t, val) + assert.Error(t, err) +}