From 03c0415ec62dd33a8e846ee3d28ea2f3baac4592 Mon Sep 17 00:00:00 2001 From: ThreadDao Date: Fri, 21 Aug 2020 14:19:13 +0800 Subject: [PATCH] [skip ci] fix generate default entities (#3382) * java main class Signed-off-by: zongyufen * [skip ci] fix java sdk test Signed-off-by: zongyufen * [skip ci] fix java sdk test Signed-off-by: zongyufen * [skip ci] fix generate default entities Signed-off-by: zongyufen --- .../main/java/com/TestCollectionCount.java | 9 +-- .../src/main/java/com/TestCollectionInfo.java | 8 +-- .../src/main/java/com/TestGetEntityByID.java | 5 +- .../src/main/java/com/TestInsertEntities.java | 29 +++++----- .../src/main/java/com/Utils.java | 56 ++++++++++++++----- 5 files changed, 67 insertions(+), 40 deletions(-) diff --git a/tests/milvus-java-test/src/main/java/com/TestCollectionCount.java b/tests/milvus-java-test/src/main/java/com/TestCollectionCount.java index 1c4eb4195..f1fd6dc07 100644 --- a/tests/milvus-java-test/src/main/java/com/TestCollectionCount.java +++ b/tests/milvus-java-test/src/main/java/com/TestCollectionCount.java @@ -4,6 +4,7 @@ import io.milvus.client.*; import org.testng.Assert; import org.testng.annotations.Test; +import java.nio.ByteBuffer; import java.util.List; import java.util.Map; @@ -11,10 +12,10 @@ public class TestCollectionCount { int segmentRowCount = 5000; int dimension = 128; int nb = 10000; -// List> vectors = Utils.genVectors(nb, dimension, true); -// List vectorsBinary = Utils.genBinaryVectors(nb, dimension); - List> defaultEntities = Utils.genDefaultEntities(dimension,nb,false); - List> defaultBinaryEntities = Utils.genDefaultEntities(dimension,nb,true); + List> vectors = Utils.genVectors(nb, dimension, true); + List vectorsBinary = Utils.genBinaryVectors(nb, dimension); + List> defaultEntities = Utils.genDefaultEntities(dimension,nb,vectors); + List> defaultBinaryEntities = Utils.genDefaultBinaryEntities(dimension,nb,vectorsBinary); @Test(dataProvider = "Collection", dataProviderClass = MainClass.class) public void testCollectionCountNoVectors(MilvusClient client, String collectionName) { diff --git a/tests/milvus-java-test/src/main/java/com/TestCollectionInfo.java b/tests/milvus-java-test/src/main/java/com/TestCollectionInfo.java index 001e8eb2f..2d95d9494 100644 --- a/tests/milvus-java-test/src/main/java/com/TestCollectionInfo.java +++ b/tests/milvus-java-test/src/main/java/com/TestCollectionInfo.java @@ -20,10 +20,10 @@ public class TestCollectionInfo { String defaultIndexType = "FLAT"; String metricType = "L2"; String indexParam = Utils.setIndexParam(indexType,metricType,nList); -// List> vectors = Utils.genVectors(nb, dimension, true); -// List vectorsBinary = Utils.genBinaryVectors(nb, dimension); - List> defaultEntities = Utils.genDefaultEntities(dimension,nb,false); - List> defaultBinaryEntities = Utils.genDefaultEntities(dimension,nb,true); + List> vectors = Utils.genVectors(nb, dimension, true); + List vectorsBinary = Utils.genBinaryVectors(nb, dimension); + List> defaultEntities = Utils.genDefaultEntities(dimension,nb,vectors); + List> defaultBinaryEntities = Utils.genDefaultBinaryEntities(dimension,nb,vectorsBinary); @Test(dataProvider = "Collection", dataProviderClass = MainClass.class) public void testGetEntityIdsAfterDeleteEntities(MilvusClient client, String collectionName) { diff --git a/tests/milvus-java-test/src/main/java/com/TestGetEntityByID.java b/tests/milvus-java-test/src/main/java/com/TestGetEntityByID.java index 8d7702220..a039a6f7d 100644 --- a/tests/milvus-java-test/src/main/java/com/TestGetEntityByID.java +++ b/tests/milvus-java-test/src/main/java/com/TestGetEntityByID.java @@ -27,9 +27,10 @@ // client.flush(collectionName); // GetEntityByIDResponse res = client.getEntityByID(collectionName, ids.subList(0, get_length)); // assert (res.getResponse().ok()); +// assert (res.getValidIds(), ids.subList(0, get_length)); // for (int i = 0; i < get_length; i++) { -// List> fields = res.getFieldsMap(); -// assert (res.getFieldsMap().get(i).equals(vectors.get(i))); +// List> fieldsMap = res.getFieldsMap(); +// assert (fieldsMap.get(i).get("float_vector").equals(defaultEntities.get(defaultEntities.size()-1).get("values").get("float_vector"))); // } // } // diff --git a/tests/milvus-java-test/src/main/java/com/TestInsertEntities.java b/tests/milvus-java-test/src/main/java/com/TestInsertEntities.java index 80b10afe7..bad09a656 100644 --- a/tests/milvus-java-test/src/main/java/com/TestInsertEntities.java +++ b/tests/milvus-java-test/src/main/java/com/TestInsertEntities.java @@ -9,16 +9,15 @@ import java.nio.ByteBuffer; import java.util.*; import java.util.stream.Collectors; import java.util.stream.LongStream; -import java.util.stream.Stream; public class TestInsertEntities { int dimension = 128; String tag = "tag"; int nb = 8000; -// List> vectors = Utils.genVectors(nb, dimension, true); -// List vectorsBinary = Utils.genBinaryVectors(nb, dimension); - List> defaultEntities = Utils.genDefaultEntities(dimension,nb,false); - List> defaultBinaryEntities = Utils.genDefaultEntities(dimension,nb,true); + List> vectors = Utils.genVectors(nb, dimension, true); + List vectorsBinary = Utils.genBinaryVectors(nb, dimension); + List> defaultEntities = Utils.genDefaultEntities(dimension,nb,vectors); + List> defaultBinaryEntities = Utils.genDefaultBinaryEntities(dimension,nb,vectorsBinary); @Test(dataProvider = "Collection", dataProviderClass = MainClass.class) public void testInsertEntitiesCollectionNotExisted(MilvusClient client, String collectionName) throws InterruptedException { @@ -72,19 +71,19 @@ public class TestInsertEntities { assert(!res.getResponse().ok()); } - @Test(dataProvider = "Collection", dataProviderClass = MainClass.class) - public void testInsertEntityWithInvalidDimension(MilvusClient client, String collectionName) { -// vectors.get(0).add((float) 0); - List> entities = Utils.genDefaultEntities(dimension+1,nb,false); - InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(entities).build(); - InsertResponse res = client.insert(insertParam); - assert(!res.getResponse().ok()); - } +// @Test(dataProvider = "Collection", dataProviderClass = MainClass.class) +// public void testInsertEntityWithInvalidDimension(MilvusClient client, String collectionName) { +//// vectors.get(0).add((float) 0); +// List> entities = Utils.genDefaultEntities(dimension+1,nb,vectors); +// InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(entities).build(); +// InsertResponse res = client.insert(insertParam); +// assert(!res.getResponse().ok()); +// } @Test(dataProvider = "Collection", dataProviderClass = MainClass.class) public void testInsertEntityWithInvalidVectors(MilvusClient client, String collectionName) { // vectors.set(0, new ArrayList<>()); - List> invalidEntities = Utils.genDefaultEntities(dimension,nb,false); + List> invalidEntities = Utils.genDefaultEntities(dimension,nb,new ArrayList<>()); invalidEntities.forEach(entity ->{ if("float_vector".equals(entity.get("field"))){ entity.put("values",new ArrayList<>()); @@ -171,7 +170,7 @@ public class TestInsertEntities { @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class) public void testInsertBinaryEntityWithInvalidDimension(MilvusClient client, String collectionName) { List vectorsBinary = Utils.genBinaryVectors(nb, dimension-1); - List> binaryEntities = Utils.genDefaultEntities(dimension-1,nb,true); + List> binaryEntities = Utils.genDefaultBinaryEntities(dimension-1,nb,vectorsBinary); InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(binaryEntities).build(); InsertResponse res = client.insert(insertParam); assert(!res.getResponse().ok()); diff --git a/tests/milvus-java-test/src/main/java/com/Utils.java b/tests/milvus-java-test/src/main/java/com/Utils.java index f5ee4134d..6578d03de 100644 --- a/tests/milvus-java-test/src/main/java/com/Utils.java +++ b/tests/milvus-java-test/src/main/java/com/Utils.java @@ -79,17 +79,17 @@ public class Utils { return defaultFieldList; } - public static List> genDefaultEntities(int dimension, int vectorCount, boolean isBinary){ - List> fields = genDefaultFields(dimension, isBinary); + public static List> genDefaultEntities(int dimension, int vectorCount, List> vectors){ + List> fieldsMap = genDefaultFields(dimension, false); List intValues = new ArrayList<>(vectorCount); List floatValues = new ArrayList<>(vectorCount); - List> vectors = genVectors(vectorCount,dimension,false); - List binaryVectors = genBinaryVectors(vectorCount,dimension); +// List> vectors = genVectors(vectorCount,dimension,false); +// List binaryVectors = genBinaryVectors(vectorCount,dimension); for (int i = 0; i < vectorCount; ++i) { intValues.add((long) i); floatValues.add((float) i); } - for(Map field: fields){ + for(Map field: fieldsMap){ String fieldType = field.get("field").toString(); switch (fieldType){ case "int64": @@ -101,11 +101,36 @@ public class Utils { case "float_vector": field.put("values",vectors); break; + } + } + return fieldsMap; + } + + public static List> genDefaultBinaryEntities(int dimension, int vectorCount, List vectorsBinary){ + List> binaryFieldsMap = genDefaultFields(dimension, true); + List intValues = new ArrayList<>(vectorCount); + List floatValues = new ArrayList<>(vectorCount); +// List> vectors = genVectors(vectorCount,dimension,false); +// List binaryVectors = genBinaryVectors(vectorCount,dimension); + for (int i = 0; i < vectorCount; ++i) { + intValues.add((long) i); + floatValues.add((float) i); + } + for(Map field: binaryFieldsMap){ + String fieldType = field.get("field").toString(); + switch (fieldType){ + case "int64": + field.put("values",intValues); + break; + case "float": + field.put("values",floatValues); + break; case "binary_vector": - field.put("values",binaryVectors); + field.put("values",vectorsBinary); + break; } } - return fields; + return binaryFieldsMap; } public static String setIndexParam(String indexType, String metricType, int nlist) { @@ -149,12 +174,13 @@ public class Utils { Integer value = jsonObject.getInteger(key); return value; } -// public static List getVector(List> entities, int i){ -// List vector = new ArrayList<>(); -// entities.forEach(entity -> { -// if("float_vector".equals(entity.get("field"))){ -// vector.add(entity.get("values").get(i)); -// } -// }); -// } + public static List getVector(List> entities, int i){ + List vector = new ArrayList<>(); + entities.forEach(entity -> { + if("float_vector".equals(entity.get("field")) && Objects.nonNull(entity.get("values"))){ + vector.add(((List)entity.get("values")).get(i)); + } + }); + return vector; + } } -- GitLab