未验证 提交 d6471e9c 编写于 作者: D del-zhenwu 提交者: GitHub

update java tests (#2246)

Signed-off-by: Nzw <zw@milvus.io>
Co-authored-by: Nzw <zw@milvus.io>
上级 80751436
<?xml version="1.0" encoding="UTF-8"?>
<module org.jetbrains.idea.maven.project.MavenProjectsManager.isMavenModule="true" type="JAVA_MODULE" version="4">
<component name="NewModuleRootManager" LANGUAGE_LEVEL="JDK_1_8">
<output url="file://$MODULE_DIR$/target/classes" />
<output-test url="file://$MODULE_DIR$/target/test-classes" />
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/src/main/java" isTestSource="false" />
<excludeFolder url="file://$MODULE_DIR$/target" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
<orderEntry type="library" name="Maven: io.milvus:milvus-sdk-java:0.6.0-SNAPSHOT" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-lang3:3.9" level="project" />
<orderEntry type="library" name="Maven: commons-cli:commons-cli:1.4" level="project" />
<orderEntry type="library" name="Maven: org.testng:testng:6.14.3" level="project" />
<orderEntry type="library" name="Maven: com.beust:jcommander:1.72" level="project" />
<orderEntry type="library" name="Maven: org.apache-extras.beanshell:bsh:2.0b6" level="project" />
<orderEntry type="library" name="Maven: com.alibaba:fastjson:1.2.47" level="project" />
<orderEntry type="library" name="Maven: junit:junit:4.13" level="project" />
<orderEntry type="library" name="Maven: org.hamcrest:hamcrest-core:1.3" level="project" />
<orderEntry type="library" name="Maven: io.milvus:milvus-sdk-java:0.8.0-SNAPSHOT" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven.plugins:maven-gpg-plugin:1.6" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-plugin-api:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-project:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-settings:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-profile:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-artifact-manager:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven.wagon:wagon-provider-api:1.0-beta-6" level="project" />
<orderEntry type="library" name="Maven: backport-util-concurrent:backport-util-concurrent:3.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-plugin-registry:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.codehaus.plexus:plexus-interpolation:1.11" level="project" />
<orderEntry type="library" name="Maven: org.codehaus.plexus:plexus-container-default:1.0-alpha-9-stable-1" level="project" />
<orderEntry type="library" name="Maven: classworlds:classworlds:1.1-alpha-2" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-artifact:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-repository-metadata:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-model:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.codehaus.plexus:plexus-utils:3.0.20" level="project" />
<orderEntry type="library" name="Maven: org.sonatype.plexus:plexus-sec-dispatcher:1.4" level="project" />
<orderEntry type="library" name="Maven: org.sonatype.plexus:plexus-cipher:1.4" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: io.grpc:grpc-netty-shaded:1.27.2" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: io.grpc:grpc-core:1.27.2" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: com.google.android:annotations:4.1.1.4" level="project" />
<orderEntry type="library" scope="RUNTIME" name="Maven: io.perfmark:perfmark-api:0.19.0" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-protobuf:1.27.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-api:1.27.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-context:1.27.2" level="project" />
<orderEntry type="library" name="Maven: com.google.code.findbugs:jsr305:3.0.2" level="project" />
<orderEntry type="library" name="Maven: org.codehaus.mojo:animal-sniffer-annotations:1.18" level="project" />
<orderEntry type="library" name="Maven: com.google.protobuf:protobuf-java:3.11.0" level="project" />
<orderEntry type="library" name="Maven: com.google.guava:guava:28.1-android" level="project" />
<orderEntry type="library" name="Maven: com.google.guava:failureaccess:1.0.1" level="project" />
<orderEntry type="library" name="Maven: com.google.guava:listenablefuture:9999.0-empty-to-avoid-conflict-with-guava" level="project" />
<orderEntry type="library" name="Maven: org.checkerframework:checker-compat-qual:2.5.5" level="project" />
<orderEntry type="library" name="Maven: com.google.j2objc:j2objc-annotations:1.3" level="project" />
<orderEntry type="library" name="Maven: com.google.api.grpc:proto-google-common-protos:1.17.0" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-protobuf-lite:1.27.2" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-stub:1.27.2" level="project" />
<orderEntry type="library" name="Maven: com.google.protobuf:protobuf-java-util:3.11.0" level="project" />
<orderEntry type="library" name="Maven: com.google.code.gson:gson:2.8.6" level="project" />
<orderEntry type="library" name="Maven: com.google.errorprone:error_prone_annotations:2.3.4" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-text:1.6" level="project" />
<orderEntry type="library" name="Maven: org.apache.commons:commons-collections4:4.4" level="project" />
<orderEntry type="library" name="Maven: org.json:json:20190722" level="project" />
</component>
</module>
\ No newline at end of file
......@@ -28,6 +28,16 @@
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>8</source>
<target>8</target>
</configuration>
<version>3.8.1</version>
</plugin>
</plugins>
</build>
......@@ -69,7 +79,7 @@
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.10</version>
<version>3.9</version>
</dependency>
<dependency>
......@@ -84,6 +94,13 @@
<version>6.14.3</version>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
<version>1.2.47</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
......@@ -99,7 +116,7 @@
<dependency>
<groupId>io.milvus</groupId>
<artifactId>milvus-sdk-java</artifactId>
<version>0.4.1-SNAPSHOT</version>
<version>0.8.0-SNAPSHOT</version>
</dependency>
<!-- <dependency>-->
......
......@@ -15,7 +15,7 @@ import java.util.List;
public class MainClass {
private static String host = "127.0.0.1";
private static String port = "19532";
private static int port = 19530;
private int index_file_size = 50;
public int dimension = 128;
......@@ -23,7 +23,7 @@ public class MainClass {
MainClass.host = host;
}
public static void setPort(String port) {
public static void setPort(int port) {
MainClass.port = port;
}
......@@ -40,8 +40,8 @@ public class MainClass {
.withPort(port)
.build();
client.connect(connectParam);
String tableName = RandomStringUtils.randomAlphabetic(10);
return new Object[][]{{client, tableName}};
String collectionName = RandomStringUtils.randomAlphabetic(10);
return new Object[][]{{client, collectionName}};
}
@DataProvider(name="DisConnectInstance")
......@@ -58,16 +58,16 @@ public class MainClass {
} catch (InterruptedException e) {
e.printStackTrace();
}
String tableName = RandomStringUtils.randomAlphabetic(10);
return new Object[][]{{client, tableName}};
String collectionName = RandomStringUtils.randomAlphabetic(10);
return new Object[][]{{client, collectionName}};
}
@DataProvider(name="Table")
public Object[][] provideTable() throws ConnectFailedException {
Object[][] tables = new Object[2][2];
@DataProvider(name="Collection")
public Object[][] provideCollection() throws ConnectFailedException, InterruptedException {
Object[][] collections = new Object[2][2];
MetricType[] metricTypes = { MetricType.L2, MetricType.IP };
for (int i = 0; i < metricTypes.length; ++i) {
String tableName = metricTypes[i].toString()+"_"+RandomStringUtils.randomAlphabetic(10);
String collectionName = metricTypes[i].toString()+"_"+RandomStringUtils.randomAlphabetic(10);
// Generate connection instance
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
......@@ -75,20 +75,60 @@ public class MainClass {
.withPort(port)
.build();
client.connect(connectParam);
TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
// List<String> tableNames = client.showCollections().getCollectionNames();
// for (int j = 0; j < tableNames.size(); ++j
// ) {
// client.dropCollection(tableNames.get(j));
// }
// Thread.currentThread().sleep(2000);
CollectionMapping cm = new CollectionMapping.Builder(collectionName, dimension)
.withIndexFileSize(index_file_size)
.withMetricType(metricTypes[i])
.build();
Response res = client.createTable(tableSchema);
Response res = client.createCollection(cm);
if (!res.ok()) {
System.out.println(res.getMessage());
throw new SkipException("Table created failed");
throw new SkipException("Collection created failed");
}
tables[i] = new Object[]{client, tableName};
collections[i] = new Object[]{client, collectionName};
}
return tables;
return collections;
}
@DataProvider(name="BinaryCollection")
public Object[][] provideBinaryCollection() throws ConnectFailedException, InterruptedException {
Object[][] collections = new Object[3][2];
MetricType[] metricTypes = { MetricType.JACCARD, MetricType.HAMMING, MetricType.TANIMOTO };
for (int i = 0; i < metricTypes.length; ++i) {
String collectionName = metricTypes[i].toString()+"_"+RandomStringUtils.randomAlphabetic(10);
// Generate connection instance
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
.withHost(host)
.withPort(port)
.build();
client.connect(connectParam);
// List<String> tableNames = client.showCollections().getCollectionNames();
// for (int j = 0; j < tableNames.size(); ++j
// ) {
// client.dropCollection(tableNames.get(j));
// }
// Thread.currentThread().sleep(2000);
CollectionMapping cm = new CollectionMapping.Builder(collectionName, dimension)
.withIndexFileSize(index_file_size)
.withMetricType(metricTypes[i])
.build();
Response res = client.createCollection(cm);
if (!res.ok()) {
System.out.println(res.getMessage());
throw new SkipException("Collection created failed");
}
collections[i] = new Object[]{client, collectionName};
}
return collections;
}
public static void main(String[] args) {
CommandLineParser parser = new DefaultParser();
Options options = new Options();
......@@ -102,7 +142,7 @@ public class MainClass {
}
String port = cmd.getOptionValue("port");
if (port != null) {
setPort(port);
setPort(Integer.parseInt(port));
}
System.out.println("Host: "+host+", Port: "+port);
}
......@@ -110,13 +150,6 @@ public class MainClass {
System.err.println("Parsing failed. Reason: " + exp.getMessage() );
}
// TestListenerAdapter tla = new TestListenerAdapter();
// TestNG testng = new TestNG();
// testng.setTestClasses(new Class[] { TestPing.class });
// testng.setTestClasses(new Class[] { TestConnect.class });
// testng.addListener(tla);
// testng.run();
XmlSuite suite = new XmlSuite();
suite.setName("TmpSuite");
......@@ -129,9 +162,15 @@ public class MainClass {
classes.add(new XmlClass("com.TestConnect"));
classes.add(new XmlClass("com.TestDeleteVectors"));
classes.add(new XmlClass("com.TestIndex"));
classes.add(new XmlClass("com.TestCompact"));
classes.add(new XmlClass("com.TestSearchVectors"));
classes.add(new XmlClass("com.TestTable"));
classes.add(new XmlClass("com.TestTableCount"));
classes.add(new XmlClass("com.TestCollection"));
classes.add(new XmlClass("com.TestCollectionCount"));
classes.add(new XmlClass("com.TestFlush"));
classes.add(new XmlClass("com.TestPartition"));
classes.add(new XmlClass("com.TestGetVectorByID"));
classes.add(new XmlClass("com.TestCollectionInfo"));
classes.add(new XmlClass("com.TestSearchByIds"));
test.setXmlClasses(classes) ;
......
......@@ -5,6 +5,7 @@ import org.apache.commons.lang3.RandomStringUtils;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
......@@ -12,188 +13,164 @@ import java.util.stream.Stream;
public class TestAddVectors {
int dimension = 128;
String tag = "tag";
public List<List<Float>> gen_vectors(Integer nb) {
List<List<Float>> xb = new LinkedList<>();
Random random = new Random();
for (int i = 0; i < nb; ++i) {
LinkedList<Float> vector = new LinkedList<>();
for (int j = 0; j < dimension; j++) {
vector.add(random.nextFloat());
}
xb.add(vector);
}
return xb;
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_add_vectors_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
int nb = 1000;
List<List<Float>> vectors = gen_vectors(nb);
String tableNameNew = tableName + "_";
InsertParam insertParam = new InsertParam.Builder(tableNameNew, vectors).build();
int nb = 8000;
List<List<Float>> vectors = Utils.genVectors(nb, dimension, true);
List<ByteBuffer> vectorsBinary = Utils.genBinaryVectors(nb, dimension);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_add_vectors_collection_not_existed(MilvusClient client, String collectionName) throws InterruptedException {
String collectionNameNew = collectionName + "_";
InsertParam insertParam = new InsertParam.Builder(collectionNameNew).withFloatVectors(vectors).build();
InsertResponse res = client.insert(insertParam);
assert(!res.getResponse().ok());
}
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void test_add_vectors_without_connect(MilvusClient client, String tableName) throws InterruptedException {
int nb = 100;
List<List<Float>> vectors = gen_vectors(nb);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
public void test_add_vectors_without_connect(MilvusClient client, String collectionName) throws InterruptedException {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
InsertResponse res = client.insert(insertParam);
assert(!res.getResponse().ok());
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_add_vectors(MilvusClient client, String tableName) throws InterruptedException {
int nb = 1000;
List<List<Float>> vectors = gen_vectors(nb);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_add_vectors(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
Thread.currentThread().sleep(1000);
// Assert table row count
Assert.assertEquals(client.getTableRowCount(tableName).getTableRowCount(), nb);
Response res_flush = client.flush(collectionName);
assert(res_flush.ok());
// Assert collection row count
Assert.assertEquals(client.getCollectionRowCount(collectionName).getCollectionRowCount(), nb);
}
// @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
// public void test_add_vectors_timeout(MilvusClient client, String tableName) throws InterruptedException {
// int nb = 200000;
// List<List<Float>> vectors = gen_vectors(nb);
// System.out.println(new Date());
// InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withTimeout(1).build();
// InsertResponse res = client.insert(insertParam);
// assert(!res.getResponse().ok());
// }
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_add_vectors_big_data(MilvusClient client, String tableName) throws InterruptedException {
int nb = 500000;
List<List<Float>> vectors = gen_vectors(nb);
System.out.println(new Date());
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_add_vectors_with_ids(MilvusClient client, String tableName) throws InterruptedException {
int nb = 1000;
List<List<Float>> vectors = gen_vectors(nb);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_add_vectors_with_ids(MilvusClient client, String collectionName) {
// Add vectors with ids
List<Long> vectorIds;
vectorIds = Stream.iterate(0L, n -> n)
.limit(nb)
.collect(Collectors.toList());
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withVectorIds(vectorIds).build();
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
Thread.currentThread().sleep(2000);
// Assert table row count
Assert.assertEquals(client.getTableRowCount(tableName).getTableRowCount(), nb);
Response res_flush = client.flush(collectionName);
assert(res_flush.ok());
// Assert collection row count
Assert.assertEquals(client.getCollectionRowCount(collectionName).getCollectionRowCount(), nb);
}
// TODO: MS-628
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_add_vectors_with_invalid_ids(MilvusClient client, String tableName) {
int nb = 10;
List<List<Float>> vectors = gen_vectors(nb);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_add_vectors_with_invalid_ids(MilvusClient client, String collectionName) {
// Add vectors with ids
List<Long> vectorIds;
vectorIds = Stream.iterate(0L, n -> n)
.limit(nb+1)
.collect(Collectors.toList());
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withVectorIds(vectorIds).build();
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
InsertResponse res = client.insert(insertParam);
assert(!res.getResponse().ok());
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_add_vectors_with_invalid_dimension(MilvusClient client, String tableName) {
int nb = 1000;
List<List<Float>> vectors = gen_vectors(nb);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_add_vectors_with_invalid_dimension(MilvusClient client, String collectionName) {
vectors.get(0).add((float) 0);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
InsertResponse res = client.insert(insertParam);
assert(!res.getResponse().ok());
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_add_vectors_with_invalid_vectors(MilvusClient client, String tableName) {
int nb = 1000;
List<List<Float>> vectors = gen_vectors(nb);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_add_vectors_with_invalid_vectors(MilvusClient client, String collectionName) {
vectors.set(0, new ArrayList<>());
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
InsertResponse res = client.insert(insertParam);
assert(!res.getResponse().ok());
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_add_vectors_repeatably(MilvusClient client, String tableName) throws InterruptedException {
int nb = 100000;
int loops = 10;
List<List<Float>> vectors = gen_vectors(nb);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
InsertResponse res = null;
for (int i = 0; i < loops; ++i ) {
long startTime = System.currentTimeMillis();
res = client.insert(insertParam);
long endTime = System.currentTimeMillis();
System.out.println("Total execution time: " + (endTime-startTime) + "ms");
}
Thread.currentThread().sleep(1000);
// Assert table row count
Assert.assertEquals(client.getTableRowCount(tableName).getTableRowCount(), nb * loops);
}
// ----------------------------- partition cases in Insert ---------------------------------
// Add vectors into table with given tag
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_add_vectors_partition(MilvusClient client, String tableName) throws InterruptedException {
int nb = 1000;
List<List<Float>> vectors = gen_vectors(nb);
String partitionName = RandomStringUtils.randomAlphabetic(10);
io.milvus.client.Partition partition = new io.milvus.client.Partition.Builder(tableName, partitionName, tag).build();
Response createpResponse = client.createPartition(partition);
// Add vectors into collection with given tag
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_add_vectors_partition(MilvusClient client, String collectionName) {
Response createpResponse = client.createPartition(collectionName, tag);
assert(createpResponse.ok());
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withPartitionTag(tag).build();
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).withPartitionTag(tag).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
Thread.currentThread().sleep(1000);
// Assert table row count
Assert.assertEquals(client.getTableRowCount(tableName).getTableRowCount(), nb);
Response res_flush = client.flush(collectionName);
assert(res_flush.ok());
// Assert collection row count
Assert.assertEquals(client.getCollectionRowCount(collectionName).getCollectionRowCount(), nb);
}
// Add vectors into table, which tag not existed
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_add_vectors_partition_tag_not_existed(MilvusClient client, String tableName) {
int nb = 1000;
String newTag = RandomStringUtils.randomAlphabetic(10);
List<List<Float>> vectors = gen_vectors(nb);
String partitionName = RandomStringUtils.randomAlphabetic(10);
io.milvus.client.Partition partition = new io.milvus.client.Partition.Builder(tableName, partitionName, tag).build();
Response createpResponse = client.createPartition(partition);
// Add vectors into collection, which tag not existed
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_add_vectors_partition_tag_not_existed(MilvusClient client, String collectionName) {
Response createpResponse = client.createPartition(collectionName, tag);
assert(createpResponse.ok());
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withPartitionTag(newTag).build();
String tag = RandomStringUtils.randomAlphabetic(10);
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).withPartitionTag(tag).build();
InsertResponse res = client.insert(insertParam);
assert(!res.getResponse().ok());
}
// Create table, add vectors into table
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_add_vectors_partition_A(MilvusClient client, String tableName) throws InterruptedException {
int nb = 1000;
List<List<Float>> vectors = gen_vectors(nb);
String partitionName = RandomStringUtils.randomAlphabetic(10);
io.milvus.client.Partition partition = new io.milvus.client.Partition.Builder(tableName, partitionName, tag).build();
Response createpResponse = client.createPartition(partition);
assert(createpResponse.ok());
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
// Binary tests
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_add_vectors_partition_A_binary(MilvusClient client, String collectionName) {
Response createpResponse = client.createPartition(collectionName, tag);
InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).withPartitionTag(tag).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
Response res_flush = client.flush(collectionName);
assert(res_flush.ok());
// Assert collection row count
Assert.assertEquals(client.getCollectionRowCount(collectionName).getCollectionRowCount(), nb);
}
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_add_vectors_binary(MilvusClient client, String collectionName) {
System.out.println(collectionName);
InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
Response res_flush = client.flush(collectionName);
assert(res_flush.ok());
// Assert collection row count
Assert.assertEquals(client.getCollectionRowCount(collectionName).getCollectionRowCount(), nb);
}
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_add_vectors_with_ids_binary(MilvusClient client, String collectionName) {
// Add vectors with ids
List<Long> vectorIds;
vectorIds = Stream.iterate(0L, n -> n)
.limit(nb)
.collect(Collectors.toList());
InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
Thread.currentThread().sleep(1000);
// Assert table row count
Assert.assertEquals(client.getTableRowCount(tableName).getTableRowCount(), nb);
Response res_flush = client.flush(collectionName);
assert(res_flush.ok());
// Assert collection row count
Assert.assertEquals(client.getCollectionRowCount(collectionName).getCollectionRowCount(), nb);
}
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_add_vectors_with_invalid_ids_binary(MilvusClient client, String collectionName) {
// Add vectors with ids
List<Long> vectorIds;
vectorIds = Stream.iterate(0L, n -> n)
.limit(nb+1)
.collect(Collectors.toList());
InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).withVectorIds(vectorIds).build();
InsertResponse res = client.insert(insertParam);
assert(!res.getResponse().ok());
}
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_add_vectors_with_invalid_dimension_binary(MilvusClient client, String collectionName) {
List<ByteBuffer> vectorsBinary = Utils.genBinaryVectors(nb, dimension-1);
InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
InsertResponse res = client.insert(insertParam);
assert(!res.getResponse().ok());
}
}
package com;
import io.milvus.client.*;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.List;
public class TestCollection {
int index_file_size = 50;
int dimension = 128;
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
public void test_create_table(MilvusClient client, String collectionName){
CollectionMapping tableSchema = new CollectionMapping.Builder(collectionName, dimension)
.withIndexFileSize(index_file_size)
.withMetricType(MetricType.L2)
.build();
Response res = client.createCollection(tableSchema);
assert(res.ok());
Assert.assertEquals(res.ok(), true);
}
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void test_create_table_disconnect(MilvusClient client, String collectionName){
CollectionMapping tableSchema = new CollectionMapping.Builder(collectionName, dimension)
.withIndexFileSize(index_file_size)
.withMetricType(MetricType.L2)
.build();
Response res = client.createCollection(tableSchema);
assert(!res.ok());
}
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
public void test_create_table_repeatably(MilvusClient client, String collectionName){
CollectionMapping tableSchema = new CollectionMapping.Builder(collectionName, dimension)
.withIndexFileSize(index_file_size)
.withMetricType(MetricType.L2)
.build();
Response res = client.createCollection(tableSchema);
Assert.assertEquals(res.ok(), true);
Response res_new = client.createCollection(tableSchema);
Assert.assertEquals(res_new.ok(), false);
}
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
public void test_create_table_wrong_params(MilvusClient client, String collectionName){
Integer dimension = 0;
CollectionMapping tableSchema = new CollectionMapping.Builder(collectionName, dimension)
.withIndexFileSize(index_file_size)
.withMetricType(MetricType.L2)
.build();
Response res = client.createCollection(tableSchema);
System.out.println(res.toString());
Assert.assertEquals(res.ok(), false);
}
@Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
public void test_show_tables(MilvusClient client, String collectionName){
Integer tableNum = 10;
ShowCollectionsResponse res = null;
for (int i = 0; i < tableNum; ++i) {
String collectionNameNew = collectionName+"_"+Integer.toString(i);
CollectionMapping tableSchema = new CollectionMapping.Builder(collectionNameNew, dimension)
.withIndexFileSize(index_file_size)
.withMetricType(MetricType.L2)
.build();
client.createCollection(tableSchema);
List<String> collectionNames = client.showCollections().getCollectionNames();
Assert.assertTrue(collectionNames.contains(collectionNameNew));
}
}
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void test_show_tables_without_connect(MilvusClient client, String collectionName){
ShowCollectionsResponse res = client.showCollections();
assert(!res.getResponse().ok());
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_drop_table(MilvusClient client, String collectionName) throws InterruptedException {
Response res = client.dropCollection(collectionName);
assert(res.ok());
Thread.currentThread().sleep(1000);
List<String> collectionNames = client.showCollections().getCollectionNames();
Assert.assertFalse(collectionNames.contains(collectionName));
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_drop_table_not_existed(MilvusClient client, String collectionName) {
Response res = client.dropCollection(collectionName+"_");
assert(!res.ok());
List<String> collectionNames = client.showCollections().getCollectionNames();
Assert.assertTrue(collectionNames.contains(collectionName));
}
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void test_drop_table_without_connect(MilvusClient client, String collectionName) {
Response res = client.dropCollection(collectionName);
assert(!res.ok());
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_describe_table(MilvusClient client, String collectionName) {
DescribeCollectionResponse res = client.describeCollection(collectionName);
assert(res.getResponse().ok());
CollectionMapping tableSchema = res.getCollectionMapping().get();
Assert.assertEquals(tableSchema.getDimension(), dimension);
Assert.assertEquals(tableSchema.getCollectionName(), collectionName);
Assert.assertEquals(tableSchema.getIndexFileSize(), index_file_size);
Assert.assertEquals(tableSchema.getMetricType().name(), collectionName.substring(0,2));
}
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void test_describe_table_without_connect(MilvusClient client, String collectionName) {
DescribeCollectionResponse res = client.describeCollection(collectionName);
assert(!res.getResponse().ok());
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_has_table_not_existed(MilvusClient client, String collectionName) {
HasCollectionResponse res = client.hasCollection(collectionName+"_");
assert(res.getResponse().ok());
Assert.assertFalse(res.hasCollection());
}
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void test_has_table_without_connect(MilvusClient client, String collectionName) {
HasCollectionResponse res = client.hasCollection(collectionName);
assert(!res.getResponse().ok());
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_has_table(MilvusClient client, String collectionName) {
HasCollectionResponse res = client.hasCollection(collectionName);
assert(res.getResponse().ok());
Assert.assertTrue(res.hasCollection());
}
}
package com;
import io.milvus.client.*;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.nio.ByteBuffer;
import java.util.List;
public class TestCollectionCount {
int index_file_size = 50;
int dimension = 128;
int nb = 10000;
List<List<Float>> vectors = Utils.genVectors(nb, dimension, true);
List<ByteBuffer> vectorsBinary = Utils.genBinaryVectors(nb, dimension);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_collection_count_no_vectors(MilvusClient client, String collectionName) {
Assert.assertEquals(client.getCollectionRowCount(collectionName).getCollectionRowCount(), 0);
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_collection_count_collection_not_existed(MilvusClient client, String collectionName) {
GetCollectionRowCountResponse res = client.getCollectionRowCount(collectionName+"_");
assert(!res.getResponse().ok());
}
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void test_collection_count_without_connect(MilvusClient client, String collectionName) {
GetCollectionRowCountResponse res = client.getCollectionRowCount(collectionName+"_");
assert(!res.getResponse().ok());
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_collection_count(MilvusClient client, String collectionName) throws InterruptedException {
// Add vectors
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
client.insert(insertParam);
client.flush(collectionName);
Assert.assertEquals(client.getCollectionRowCount(collectionName).getCollectionRowCount(), nb);
}
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_collection_count_binary(MilvusClient client, String collectionName) throws InterruptedException {
// Add vectors
InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
client.insert(insertParam);
client.flush(collectionName);
Assert.assertEquals(client.getCollectionRowCount(collectionName).getCollectionRowCount(), nb);
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_collection_count_multi_collections(MilvusClient client, String collectionName) throws InterruptedException {
Integer collectionNum = 10;
GetCollectionRowCountResponse res;
for (int i = 0; i < collectionNum; ++i) {
String collectionNameNew = collectionName + "_" + i;
CollectionMapping collectionSchema = new CollectionMapping.Builder(collectionNameNew, dimension)
.withIndexFileSize(index_file_size)
.withMetricType(MetricType.L2)
.build();
client.createCollection(collectionSchema);
// Add vectors
InsertParam insertParam = new InsertParam.Builder(collectionNameNew).withFloatVectors(vectors).build();
client.insert(insertParam);
client.flush(collectionNameNew);
}
for (int i = 0; i < collectionNum; ++i) {
String collectionNameNew = collectionName + "_" + i;
res = client.getCollectionRowCount(collectionNameNew);
Assert.assertEquals(res.getCollectionRowCount(), nb);
}
}
}
package com;
import com.alibaba.fastjson.JSONObject;
import io.milvus.client.*;
import org.testng.annotations.Test;
import java.nio.ByteBuffer;
import java.util.List;
public class TestCollectionInfo {
int dimension = 128;
int nb = 8000;
int n_list = 1024;
int default_n_list = 16384;
IndexType indexType = IndexType.IVF_SQ8;
IndexType defaultIndexType = IndexType.FLAT;
String indexParam = Utils.setIndexParam(n_list);
List<List<Float>> vectors = Utils.genVectors(nb, dimension, true);
List<ByteBuffer> vectorsBinary = Utils.genBinaryVectors(nb, dimension);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_get_vector_ids_after_delete_vectors(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
InsertResponse resInsert = client.insert(insertParam);
client.flush(collectionName);
List<Long> idsBefore = resInsert.getVectorIds();
client.deleteById(collectionName, idsBefore.get(0));
client.flush(collectionName);
Response res = client.showCollectionInfo(collectionName);
System.out.println(res.getMessage());
JSONObject collectionInfo = Utils.getCollectionInfo(res.getMessage());
int row_count = collectionInfo.getIntValue("row_count");
assert(row_count == nb-1);
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_get_vector_ids_after_delete_vectors_indexed(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
InsertResponse resInsert = client.insert(insertParam);
client.flush(collectionName);
Index index = new Index.Builder(collectionName, indexType).withParamsInJson(indexParam).build();
client.createIndex(index);
List<Long> idsBefore = resInsert.getVectorIds();
client.deleteById(collectionName, idsBefore.get(0));
client.flush(collectionName);
Response res = client.showCollectionInfo(collectionName);
System.out.println(res.getMessage());
JSONObject collectionInfo = Utils.getCollectionInfo(res.getMessage());
int row_count = collectionInfo.getIntValue("row_count");
assert(row_count == nb-1);
}
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_get_vector_ids_after_delete_vectors_binary(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
InsertResponse resInsert = client.insert(insertParam);
client.flush(collectionName);
List<Long> idsBefore = resInsert.getVectorIds();
client.deleteById(collectionName, idsBefore.get(0));
client.flush(collectionName);
Response res = client.showCollectionInfo(collectionName);
System.out.println(res.getMessage());
JSONObject collectionInfo = Utils.getCollectionInfo(res.getMessage());
int row_count = collectionInfo.getIntValue("row_count");
assert(row_count == nb-1);
}
}
\ No newline at end of file
package com;
import io.milvus.client.*;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.nio.ByteBuffer;
import java.util.List;
public class TestCompact {
int dimension = 128;
int nb = 8000;
List<List<Float>> vectors = Utils.genVectors(nb, dimension, true);
List<ByteBuffer> vectorsBinary = Utils.genBinaryVectors(nb, dimension);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_compact_after_delete(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
List<Long> ids = res.getVectorIds();
client.flush(collectionName);
Response res_delete = client.deleteByIds(collectionName, ids);
assert(res_delete.ok());
client.flush(collectionName);
Response res_compact = client.compact(collectionName);
assert(res_compact.ok());
Assert.assertEquals(client.getCollectionRowCount(collectionName).getCollectionRowCount(), 0);
}
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_compact_after_delete_binary(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
List<Long> ids = res.getVectorIds();
client.flush(collectionName);
Response res_delete = client.deleteByIds(collectionName, ids);
assert(res_delete.ok());
client.flush(collectionName);
Response res_compact = client.compact(collectionName);
assert(res_compact.ok());
Assert.assertEquals(client.getCollectionRowCount(collectionName).getCollectionRowCount(), 0);
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_compact_no_table(MilvusClient client, String collectionName) {
String name = "";
Response res_compact = client.compact(name);
assert(!res_compact.ok());
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_compact_empty_table(MilvusClient client, String collectionName) {
Response res_compact = client.compact(collectionName);
assert(res_compact.ok());
}
}
......@@ -7,7 +7,7 @@ import org.testng.annotations.Test;
public class TestConnect {
@Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
public void test_connect(String host, String port) throws ConnectFailedException {
public void test_connect(String host, int port) throws ConnectFailedException {
System.out.println("Host: "+host+", Port: "+port);
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
......@@ -20,7 +20,7 @@ public class TestConnect {
}
@Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
public void test_connect_repeat(String host, String port) {
public void test_connect_repeat(String host, int port) {
MilvusGrpcClient client = new MilvusGrpcClient();
Response res = null;
......@@ -39,7 +39,7 @@ public class TestConnect {
}
@Test(dataProvider="InvalidConnectArgs")
public void test_connect_invalid_connect_args(String ip, String port) {
public void test_connect_invalid_connect_args(String ip, int port) {
MilvusClient client = new MilvusGrpcClient();
Response res = null;
try {
......@@ -55,29 +55,27 @@ public class TestConnect {
assert(!client.isConnected());
}
// TODO: MS-615
@DataProvider(name="InvalidConnectArgs")
public Object[][] generate_invalid_connect_args() {
String port = "19530";
String ip = "";
int port = 19530;
return new Object[][]{
{"1.1.1.1", port},
{"255.255.0.0", port},
{"1.2.2", port},
{"中文", port},
{"www.baidu.com", "100000"},
{"127.0.0.1", "100000"},
{"www.baidu.com", "80"},
{"www.baidu.com", 100000},
{"127.0.0.1", 100000},
{"www.baidu.com", 80},
};
}
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void test_disconnect(MilvusClient client, String tableName){
public void test_disconnect(MilvusClient client, String collectionName){
assert(!client.isConnected());
}
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void test_disconnect_repeatably(MilvusClient client, String tableName){
public void test_disconnect_repeatably(MilvusClient client, String collectionName){
Response res = null;
try {
res = client.disconnect();
......
package com;
import java.util.*;
import io.milvus.client.*;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
public class TestDeleteVectors {
int index_file_size = 50;
int dimension = 128;
int nb = 8000;
List<List<Float>> vectors = Utils.genVectors(nb, dimension, true);
List<ByteBuffer> vectorsBinary = Utils.genBinaryVectors(nb, dimension);
public List<List<Float>> gen_vectors(Integer nb) {
List<List<Float>> xb = new LinkedList<>();
Random random = new Random();
for (int i = 0; i < nb; ++i) {
LinkedList<Float> vector = new LinkedList<>();
for (int j = 0; j < dimension; j++) {
vector.add(random.nextFloat());
}
xb.add(vector);
}
return xb;
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_delete_vectors(MilvusClient client, String collectionName) {
// Add vectors
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
List<Long> ids = res.getVectorIds();
client.flush(collectionName);
Response res_delete = client.deleteByIds(collectionName, ids);
assert(res_delete.ok());
client.flush(collectionName);
// Assert collection row count
Assert.assertEquals(client.getCollectionRowCount(collectionName).getCollectionRowCount(), 0);
}
public static Date getDeltaDate(int delta) {
Date today = new Date();
Calendar c = Calendar.getInstance();
c.setTime(today);
c.add(Calendar.DAY_OF_MONTH, delta);
return c.getTime();
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_delete_single_vector(MilvusClient client, String collectionName) {
List<List<Float>> del_vector = new ArrayList<>();
del_vector.add(vectors.get(0));
List<Long> del_ids = new ArrayList<>();
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
List<Long> ids = res.getVectorIds();
del_ids.add(ids.get(0));
client.flush(collectionName);
Response res_delete = client.deleteById(collectionName, ids.get(0));
assert(res_delete.ok());
client.flush(collectionName);
// Assert collection row count
Assert.assertEquals(client.getCollectionRowCount(collectionName).getCollectionRowCount(), nb - 1);
GetVectorsByIdsResponse res_get = client.getVectorsByIds(collectionName, del_ids);
assert(res_get.getResponse().ok());
assert(res_get.getFloatVectors().get(0).size() == 0);
}
// @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
// public void test_delete_vectors(MilvusClient client, String tableName) throws InterruptedException {
// int nb = 10000;
// List<List<Float>> vectors = gen_vectors(nb);
// // Add vectors
// InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
// InsertResponse res = client.insert(insertParam);
// assert(res.getResponse().ok());
// Thread.sleep(1000);
// DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1));
// DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build();
// Response res_delete = client.deleteByRange(param);
// assert(res_delete.ok());
// Thread.sleep(1000);
// // Assert table row count
// Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), 0);
// }
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_delete_vectors_collection_not_existed(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
client.flush(collectionName);
List<Long> ids = res.getVectorIds();
Response res_delete = client.deleteByIds(collectionName + "_not_existed", ids);
assert(!res_delete.ok());
}
// @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
// public void test_delete_vectors_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
// String tableNameNew = tableName + "_";
// DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1));
// DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableNameNew).build();
// Response res_delete = client.deleteByRange(param);
// assert(!res_delete.ok());
// }
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_delete_vector_id_not_existed(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
List<Long> ids = new ArrayList<Long>();
ids.add((long)123456);
ids.add((long)1234561);
client.flush(collectionName);
Response res_delete = client.deleteByIds(collectionName, ids);
assert(res_delete.ok());
client.flush(collectionName);
// Assert collection row count
Assert.assertEquals(client.getCollectionRowCount(collectionName).getCollectionRowCount(), nb);
}
// @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
// public void test_delete_vectors_without_connect(MilvusClient client, String tableName) throws InterruptedException {
// DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1));
// DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build();
// Response res_delete = client.deleteByRange(param);
// assert(!res_delete.ok());
// }
//
// @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
// public void test_delete_vectors_table_empty(MilvusClient client, String tableName) throws InterruptedException {
// DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1));
// DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build();
// Response res_delete = client.deleteByRange(param);
// assert(res_delete.ok());
// }
// @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
// public void test_delete_vectors_invalid_date_range(MilvusClient client, String tableName) throws InterruptedException {
// int nb = 100;
// List<List<Float>> vectors = gen_vectors(nb);
// // Add vectors
// InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
// InsertResponse res = client.insert(insertParam);
// assert(res.getResponse().ok());
// Thread.sleep(1000);
// DateRange dateRange = new DateRange(getDeltaDate(1), getDeltaDate(0));
// DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build();
// Response res_delete = client.deleteByRange(param);
// assert(!res_delete.ok());
// }
// Below tests binary vectors
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_delete_vectors_binary(MilvusClient client, String collectionName) {
// Add vectors
InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
List<Long> ids = res.getVectorIds();
client.flush(collectionName);
Response res_delete = client.deleteByIds(collectionName, ids);
assert(res_delete.ok());
client.flush(collectionName);
// Assert collection row count
Assert.assertEquals(client.getCollectionRowCount(collectionName).getCollectionRowCount(), 0);
}
// @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
// public void test_delete_vectors_invalid_date_range_1(MilvusClient client, String tableName) throws InterruptedException {
// int nb = 100;
// List<List<Float>> vectors = gen_vectors(nb);
// InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
// InsertResponse res = client.insert(insertParam);
// assert(res.getResponse().ok());
// DateRange dateRange = new DateRange(getDeltaDate(2), getDeltaDate(-1));
// DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build();
// Response res_delete = client.deleteByRange(param);
// assert(!res_delete.ok());
// }
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_delete_single_vector_binary(MilvusClient client, String collectionName) {
List<ByteBuffer> del_vector = new ArrayList<>();
del_vector.add(vectorsBinary.get(0));
InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
List<Long> ids = res.getVectorIds();
client.flush(collectionName);
Response res_delete = client.deleteById(collectionName, ids.get(0));
assert(res_delete.ok());
client.flush(collectionName);
// Assert collection row count
Assert.assertEquals(client.getCollectionRowCount(collectionName).getCollectionRowCount(), nb - 1);
// Cannot search for the vector
SearchParam searchParam = new SearchParam.Builder(collectionName)
.withBinaryVectors(del_vector)
.withTopK(1)
.withParamsInJson("{\"nprobe\": 20}")
.build();
SearchResponse res_search = client.search(searchParam);
assert(res_search.getResultIdsList().size() == 1);
}
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_delete_vectors_collection_not_existed_binary(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
List<Long> ids = res.getVectorIds();
client.flush(collectionName);
Response res_delete = client.deleteByIds(collectionName + "_not_existed", ids);
assert(!res_delete.ok());
}
// @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
// public void test_delete_vectors_no_result(MilvusClient client, String tableName) throws InterruptedException {
// int nb = 100;
// List<List<Float>> vectors = gen_vectors(nb);
// InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
// InsertResponse res = client.insert(insertParam);
// assert(res.getResponse().ok());
// Thread.sleep(1000);
// DateRange dateRange = new DateRange(getDeltaDate(-3), getDeltaDate(-2));
// DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build();
// Response res_delete = client.deleteByRange(param);
// assert(res_delete.ok());
// Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb);
// }
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_delete_vector_id_not_existed_binary(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
List<Long> ids = new ArrayList<Long>();
ids.add((long)123456);
ids.add((long)1234561);
client.flush(collectionName);
Response res_delete = client.deleteByIds(collectionName, ids);
assert(res_delete.ok());
client.flush(collectionName);
// Assert collection row count
Assert.assertEquals(client.getCollectionRowCount(collectionName).getCollectionRowCount(), nb);
}
}
package com;
import com.google.common.util.concurrent.ListenableFuture;
import java.util.concurrent.ExecutionException;
import io.milvus.client.*;
import org.apache.commons.lang3.RandomStringUtils;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
public class TestFlush {
int index_file_size = 50;
int dimension = 128;
int nb = 8000;
List<List<Float>> vectors = Utils.genVectors(nb, dimension, true);
List<ByteBuffer> vectorsBinary = Utils.genBinaryVectors(nb, dimension);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_flush_collection_not_existed(MilvusClient client, String collectionName) {
String newCollection = "not_existed";
Response res = client.flush(newCollection);
assert(!res.ok());
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_flush_empty_collection(MilvusClient client, String collectionName) {
Response res = client.flush(collectionName);
assert(res.ok());
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_add_collections_flush(MilvusClient client, String collectionName) {
List<String> names = new ArrayList<>();
for (int i = 0; i < 10; i++) {
names.add(RandomStringUtils.randomAlphabetic(10));
CollectionMapping tableSchema = new CollectionMapping.Builder(names.get(i), dimension)
.withIndexFileSize(index_file_size)
.withMetricType(MetricType.IP)
.build();
client.createCollection(tableSchema);
InsertParam insertParam = new InsertParam.Builder(names.get(i)).withFloatVectors(vectors).build();
client.insert(insertParam);
System.out.println("Table " + names.get(i) + " created.");
}
Response res = client.flush(names);
assert(res.ok());
for (int i = 0; i < 10; i++) {
// check row count
Assert.assertEquals(client.getCollectionRowCount(names.get(i)).getCollectionRowCount(), nb);
}
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_add_collections_flush_async(MilvusClient client, String collectionName) throws ExecutionException, InterruptedException {
List<String> names = new ArrayList<>();
for (int i = 0; i < 100; i++) {
names.add(RandomStringUtils.randomAlphabetic(10));
CollectionMapping tableSchema = new CollectionMapping.Builder(names.get(i), dimension)
.withIndexFileSize(index_file_size)
.withMetricType(MetricType.IP)
.build();
client.createCollection(tableSchema);
InsertParam insertParam = new InsertParam.Builder(names.get(i)).withFloatVectors(vectors).build();
client.insert(insertParam);
System.out.println("Collection " + names.get(i) + " created.");
}
ListenableFuture<Response> flushResponseFuture = client.flushAsync(names);
flushResponseFuture.get();
for (int i = 0; i < 100; i++) {
// check row count
Assert.assertEquals(client.getCollectionRowCount(names.get(i)).getCollectionRowCount(), nb);
}
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_add_flush_multiple_times(MilvusClient client, String collectionName) {
for (int i = 0; i < 10; i++) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
client.insert(insertParam);
Response res = client.flush(collectionName);
assert(res.ok());
Assert.assertEquals(client.getCollectionRowCount(collectionName).getCollectionRowCount(), nb * (i+1));
}
}
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_add_flush_multiple_times_binary(MilvusClient client, String collectionName) {
for (int i = 0; i < 10; i++) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
client.insert(insertParam);
Response res = client.flush(collectionName);
assert(res.ok());
Assert.assertEquals(client.getCollectionRowCount(collectionName).getCollectionRowCount(), nb * (i+1));
}
}
}
\ No newline at end of file
package com;
import io.milvus.client.*;
import org.testng.annotations.Test;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
public class TestGetVectorByID {
int dimension = 128;
int nb = 8000;
public List<Long> get_ids = Utils.toListIds(1111);
List<List<Float>> vectors = Utils.genVectors(nb, dimension, true);
List<ByteBuffer> vectorsBinary = Utils.genBinaryVectors(nb, dimension);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_get_vector_by_id_valid(MilvusClient client, String collectionName) {
int get_length = 100;
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
InsertResponse resInsert = client.insert(insertParam);
List<Long> ids = resInsert.getVectorIds();
client.flush(collectionName);
GetVectorsByIdsResponse res = client.getVectorsByIds(collectionName, ids.subList(0, get_length));
assert (res.getResponse().ok());
for (int i = 0; i < get_length; i++) {
assert (res.getFloatVectors().get(i).equals(vectors.get(i)));
}
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_get_vector_by_id_after_delete(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
InsertResponse resInsert = client.insert(insertParam);
List<Long> ids = resInsert.getVectorIds();
Response res_delete = client.deleteById(collectionName, ids.get(0));
assert(res_delete.ok());
client.flush(collectionName);
GetVectorsByIdsResponse res = client.getVectorsByIds(collectionName, ids.subList(0, 1));
assert (res.getResponse().ok());
assert (res.getFloatVectors().get(0).size() == 0);
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_get_vector_by_id_collection_name_not_existed(MilvusClient client, String collectionName) {
String newCollection = "not_existed";
GetVectorsByIdsResponse res = client.getVectorsByIds(newCollection, get_ids);
assert(!res.getResponse().ok());
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_get_vector_id_not_existed(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
client.insert(insertParam);
client.flush(collectionName);
GetVectorsByIdsResponse res = client.getVectorsByIds(collectionName, get_ids);
assert (res.getFloatVectors().get(0).size() == 0);
}
// Binary tests
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_get_vector_by_id_valid_binary(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
InsertResponse resInsert = client.insert(insertParam);
List<Long> ids = resInsert.getVectorIds();
client.flush(collectionName);
GetVectorsByIdsResponse res = client.getVectorsByIds(collectionName, ids.subList(0, 1));
assert res.getBinaryVectors().get(0).equals(vectorsBinary.get(0).rewind());
}
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_get_vector_by_id_after_delete_binary(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
InsertResponse resInsert = client.insert(insertParam);
List<Long> ids = resInsert.getVectorIds();
Response res_delete = client.deleteById(collectionName, ids.get(0));
assert(res_delete.ok());
client.flush(collectionName);
GetVectorsByIdsResponse res = client.getVectorsByIds(collectionName, ids.subList(0, 1));
assert (res.getFloatVectors().get(0).size() == 0);
}
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_get_vector_id_not_existed_binary(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
client.insert(insertParam);
client.flush(collectionName);
GetVectorsByIdsResponse res = client.getVectorsByIds(collectionName, get_ids);
assert (res.getFloatVectors().get(0).size() == 0);
}
}
\ No newline at end of file
......@@ -15,53 +15,36 @@ import java.util.stream.Collectors;
public class TestMix {
private int dimension = 128;
private int nb = 100000;
int nq = 10;
int n_list = 8192;
int n_probe = 20;
int top_k = 10;
double epsilon = 0.001;
int index_file_size = 20;
String indexParam = "{\"nlist\":\"1024\"}";
public List<Float> normalize(List<Float> w2v){
float squareSum = w2v.stream().map(x -> x * x).reduce((float) 0, Float::sum);
final float norm = (float) Math.sqrt(squareSum);
w2v = w2v.stream().map(x -> x / norm).collect(Collectors.toList());
return w2v;
}
List<List<Float>> vectors = Utils.genVectors(nb, dimension, true);
List<List<Float>> queryVectors = vectors.subList(0,nq);
public List<List<Float>> gen_vectors(int nb, boolean norm) {
List<List<Float>> xb = new ArrayList<>();
Random random = new Random();
for (int i = 0; i < nb; ++i) {
List<Float> vector = new ArrayList<>();
for (int j = 0; j < dimension; j++) {
vector.add(random.nextFloat());
}
if (norm == true) {
vector = normalize(vector);
}
xb.add(vector);
}
return xb;
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_search_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_search_vectors_threads(MilvusClient client, String collectionName) throws InterruptedException {
int thread_num = 10;
int nq = 5;
List<List<Float>> vectors = gen_vectors(nb, false);
List<List<Float>> queryVectors = vectors.subList(0,nq);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
client.insert(insertParam);
Index index = new Index.Builder().withIndexType(IndexType.IVF_SQ8)
.withNList(n_list)
Index index = new Index.Builder(collectionName, IndexType.IVF_SQ8)
.withParamsInJson(indexParam)
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
client.createIndex(createIndexParam);
client.createIndex(index);
ForkJoinPool executor = new ForkJoinPool();
for (int i = 0; i < thread_num; i++) {
executor.execute(
() -> {
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
String params = "{\"nprobe\":\"1\"}";
SearchParam searchParam = new SearchParam.Builder(collectionName)
.withFloatVectors(queryVectors)
.withParamsInJson(params)
.withTopK(top_k).build();
SearchResponse res_search = client.search(searchParam);
assert (res_search.getResponse().ok());
});
......@@ -71,7 +54,7 @@ public class TestMix {
}
@Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
public void test_connect_threads(String host, String port) throws ConnectFailedException {
public void test_connect_threads(String host, int port) throws ConnectFailedException {
int thread_num = 100;
ForkJoinPool executor = new ForkJoinPool();
for (int i = 0; i < thread_num; i++) {
......@@ -99,11 +82,10 @@ public class TestMix {
executor.shutdown();
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_add_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_add_vectors_threads(MilvusClient client, String collectionName) throws InterruptedException {
int thread_num = 10;
List<List<Float>> vectors = gen_vectors(nb,false);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
ForkJoinPool executor = new ForkJoinPool();
for (int i = 0; i < thread_num; i++) {
executor.execute(
......@@ -116,19 +98,16 @@ public class TestMix {
executor.shutdown();
Thread.sleep(2000);
GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableName);
Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb);
GetCollectionRowCountResponse getCollectionRowCountResponse = client.getCollectionRowCount(collectionName);
Assert.assertEquals(getCollectionRowCountResponse.getCollectionRowCount(), thread_num * nb);
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_add_vectors_partition_threads(MilvusClient client, String tableName) throws InterruptedException {
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_add_vectors_partition_threads(MilvusClient client, String collectionName) throws InterruptedException {
int thread_num = 10;
String tag = RandomStringUtils.randomAlphabetic(10);
String partitionName = RandomStringUtils.randomAlphabetic(10);
io.milvus.client.Partition partition = new io.milvus.client.Partition.Builder(tableName, partitionName, tag).build();
client.createPartition(partition);
List<List<Float>> vectors = gen_vectors(nb,false);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withPartitionTag(tag).build();
client.createPartition(collectionName, tag);
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).withPartitionTag(tag).build();
ForkJoinPool executor = new ForkJoinPool();
for (int i = 0; i < thread_num; i++) {
executor.execute(
......@@ -141,42 +120,39 @@ public class TestMix {
executor.shutdown();
Thread.sleep(2000);
GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableName);
Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb);
GetCollectionRowCountResponse getCollectionRowCountResponse = client.getCollectionRowCount(collectionName);
Assert.assertEquals(getCollectionRowCountResponse.getCollectionRowCount(), thread_num * nb);
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_add_index_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_add_index_vectors_threads(MilvusClient client, String collectionName) throws InterruptedException {
int thread_num = 50;
List<List<Float>> vectors = gen_vectors(nb,false);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
ForkJoinPool executor = new ForkJoinPool();
for (int i = 0; i < thread_num; i++) {
executor.execute(
() -> {
InsertResponse res_insert = client.insert(insertParam);
Index index = new Index.Builder().withIndexType(IndexType.IVF_SQ8)
.withNList(n_list)
Index index = new Index.Builder(collectionName, IndexType.IVF_SQ8)
.withParamsInJson("{\"nlist\":\"1024\"}")
.build();
CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
client.createIndex(createIndexParam);
client.createIndex(index);
assert (res_insert.getResponse().ok());
});
}
executor.awaitQuiescence(300, TimeUnit.SECONDS);
executor.shutdown();
Thread.sleep(2000);
GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableName);
Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb);
GetCollectionRowCountResponse getCollectionRowCountResponse = client.getCollectionRowCount(collectionName);
Assert.assertEquals(getCollectionRowCountResponse.getCollectionRowCount(), thread_num * nb);
}
@Test(dataProvider = "Table", dataProviderClass = MainClass.class)
public void test_add_search_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_add_search_vectors_threads(MilvusClient client, String collectionName) throws InterruptedException {
int thread_num = 50;
int nq = 5;
List<List<Float>> vectors = gen_vectors(nb, true);
List<List<Float>> queryVectors = vectors.subList(0,nq);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
ForkJoinPool executor = new ForkJoinPool();
for (int i = 0; i < thread_num; i++) {
executor.execute(
......@@ -188,14 +164,17 @@ public class TestMix {
} catch (InterruptedException e) {
e.printStackTrace();
}
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
SearchParam searchParam = new SearchParam.Builder(collectionName)
.withFloatVectors(queryVectors)
.withParamsInJson("{\"nlist\":\"1024\"}")
.withTopK(top_k).build();
SearchResponse res_search = client.search(searchParam);
assert (res_search.getResponse().ok());
List<List<SearchResponse.QueryResult>> res = client.search(searchParam).getQueryResultsList();
double distance = res.get(0).get(0).getDistance();
if (tableName.startsWith("L2")) {
if (collectionName.startsWith("L2")) {
Assert.assertEquals(distance, 0.0, epsilon);
}else if (tableName.startsWith("IP")) {
}else if (collectionName.startsWith("IP")) {
Assert.assertEquals(distance, 1.0, epsilon);
}
});
......@@ -203,14 +182,13 @@ public class TestMix {
executor.awaitQuiescence(300, TimeUnit.SECONDS);
executor.shutdown();
Thread.sleep(2000);
GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableName);
Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb);
GetCollectionRowCountResponse getCollectionRowCountResponse = client.getCollectionRowCount(collectionName);
Assert.assertEquals(getCollectionRowCountResponse.getCollectionRowCount(), thread_num * nb);
}
@Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
public void test_create_insert_delete_threads(String host, String port) {
public void test_create_insert_delete_threads(String host, int port) {
int thread_num = 100;
List<List<Float>> vectors = gen_vectors(nb,false);
ForkJoinPool executor = new ForkJoinPool();
for (int i = 0; i < thread_num; i++) {
executor.execute(
......@@ -226,15 +204,15 @@ public class TestMix {
e.printStackTrace();
}
assert(client.isConnected());
String tableName = RandomStringUtils.randomAlphabetic(10);
TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
String collectionName = RandomStringUtils.randomAlphabetic(10);
CollectionMapping tableSchema = new CollectionMapping.Builder(collectionName, dimension)
.withIndexFileSize(index_file_size)
.withMetricType(MetricType.IP)
.build();
client.createTable(tableSchema);
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
client.createCollection(tableSchema);
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
client.insert(insertParam);
Response response = client.dropTable(tableName);
Response response = client.dropCollection(collectionName);
Assert.assertTrue(response.ok());
try {
client.disconnect();
......
......@@ -3,9 +3,7 @@ package com;
import io.milvus.client.*;
import org.apache.commons.cli.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
......@@ -13,7 +11,7 @@ import java.util.stream.Stream;
public class TestPS {
private static int dimension = 128;
private static String host = "192.168.1.101";
private static String host = "127.0.0.1";
private static String port = "19530";
public static void setHost(String host) {
......@@ -24,37 +22,20 @@ public class TestPS {
TestPS.port = port;
}
public static List<Float> normalize(List<Float> w2v){
float squareSum = w2v.stream().map(x -> x * x).reduce((float) 0, Float::sum);
final float norm = (float) Math.sqrt(squareSum);
w2v = w2v.stream().map(x -> x / norm).collect(Collectors.toList());
return w2v;
}
public static List<List<Float>> gen_vectors(int nb, boolean norm) {
List<List<Float>> xb = new ArrayList<>();
Random random = new Random();
for (int i = 0; i < nb; ++i) {
List<Float> vector = new ArrayList<>();
for (int j = 0; j < dimension; j++) {
vector.add(random.nextFloat());
}
if (norm == true) {
vector = normalize(vector);
}
xb.add(vector);
}
return xb;
}
public static void main(String[] args) throws ConnectFailedException {
int nb = 10000;
int nq = 5;
int nb = 1;
int nprobe = 32;
int top_k = 10;
int loops = 100000;
// int index_file_size = 1024;
String tableName = "sift_1b_2048_128_l2";
String collectionName = "sift_1b_2048_128_l2";
List<List<Float>> vectors = Utils.genVectors(nb, dimension, true);
List<List<Float>> queryVectors = vectors.subList(0, nq);
CommandLineParser parser = new DefaultParser();
Options options = new Options();
......@@ -76,16 +57,14 @@ public class TestPS {
System.err.println("Parsing failed. Reason: " + exp.getMessage() );
}
List<List<Float>> vectors = gen_vectors(nb, true);
List<List<Float>> queryVectors = gen_vectors(nq, true);
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
.withHost(host)
.withPort(port)
.withPort(Integer.parseInt(port))
.build();
client.connect(connectParam);
// String tableName = RandomStringUtils.randomAlphabetic(10);
// TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
// String collectionName = RandomStringUtils.randomAlphabetic(10);
// TableSchema tableSchema = new TableSchema.Builder(collectionName, dimension)
// .withIndexFileSize(index_file_size)
// .withMetricType(MetricType.IP)
// .build();
......@@ -94,7 +73,7 @@ public class TestPS {
vectorIds = Stream.iterate(0L, n -> n)
.limit(nb)
.collect(Collectors.toList());
InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withVectorIds(vectorIds).build();
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).withVectorIds(vectorIds).build();
ForkJoinPool executor_search = new ForkJoinPool();
for (int i = 0; i < loops; i++) {
executor_search.execute(
......@@ -102,14 +81,14 @@ public class TestPS {
InsertResponse res_insert = client.insert(insertParam);
assert (res_insert.getResponse().ok());
System.out.println("In insert");
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(nprobe).withTopK(top_k).build();
SearchParam searchParam = new SearchParam.Builder(collectionName).withFloatVectors(queryVectors).withTopK(top_k).build();
SearchResponse res_search = client.search(searchParam);
assert (res_search.getResponse().ok());
});
}
executor_search.awaitQuiescence(300, TimeUnit.SECONDS);
executor_search.shutdown();
GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableName);
System.out.println(getTableRowCountResponse.getTableRowCount());
GetCollectionRowCountResponse getTableRowCountResponse = client.getCollectionRowCount(collectionName);
System.out.println(getTableRowCountResponse.getCollectionRowCount());
}
}
package com;
import io.milvus.client.HasPartitionResponse;
import io.milvus.client.MilvusClient;
import io.milvus.client.Response;
import io.milvus.client.ShowPartitionsResponse;
import org.apache.commons.lang3.RandomStringUtils;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.List;
public class TestPartition {
int dimension = 128;
// ----------------------------- create partition cases in ---------------------------------
// create partition
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_create_partition(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
// show partitions
List<String> partitions = client.showPartitions(collectionName).getPartitionList();
System.out.println(partitions);
Assert.assertTrue(partitions.contains(tag));
}
// create partition, tag name existed
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_create_partition_tag_name_existed(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
Response createpResponseNew = client.createPartition(collectionName, tag);
assert (!createpResponseNew.ok());
}
// ----------------------------- has partition cases in ---------------------------------
// has partition, tag name not existed
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_has_partition_tag_name_not_existed(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
String tagNew = RandomStringUtils.randomAlphabetic(10);
HasPartitionResponse haspResponse = client.hasPartition(collectionName, tagNew);
assert (haspResponse.ok());
Assert.assertFalse(haspResponse.hasPartition());
}
// has partition, tag name existed
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_has_partition_tag_name_existed(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
HasPartitionResponse haspResponse = client.hasPartition(collectionName, tag);
assert (haspResponse.ok());
Assert.assertTrue(haspResponse.hasPartition());
}
// ----------------------------- drop partition cases in ---------------------------------
// drop a partition created before, drop by partition name
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_drop_partition(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponseNew = client.createPartition(collectionName, tag);
assert (createpResponseNew.ok());
Response response = client.dropPartition(collectionName, tag);
assert (response.ok());
// show partitions
System.out.println(client.showPartitions(collectionName).getPartitionList());
int length = client.showPartitions(collectionName).getPartitionList().size();
// _default
Assert.assertEquals(length, 1);
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_drop_partition_default(MilvusClient client, String collectionName) {
String tag = "_default";
Response createpResponseNew = client.createPartition(collectionName, tag);
assert (!createpResponseNew.ok());
// show partitions
// System.out.println(client.showPartitions(collectionName).getPartitionList());
// int length = client.showPartitions(collectionName).getPartitionList().size();
// // _default
// Assert.assertEquals(length, 1);
}
// drop a partition repeat created before, drop by partition name
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_drop_partition_repeat(MilvusClient client, String collectionName) throws InterruptedException {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
Response response = client.dropPartition(collectionName, tag);
assert (response.ok());
Thread.currentThread().sleep(2000);
Response newResponse = client.dropPartition(collectionName, tag);
assert (!newResponse.ok());
}
// drop a partition not created before
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_drop_partition_not_existed(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
String tagNew = RandomStringUtils.randomAlphabetic(10);
Response response = client.dropPartition(collectionName, tagNew);
assert(!response.ok());
}
// drop a partition not created before
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_drop_partition_tag_not_existed(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert(createpResponse.ok());
String newTag = RandomStringUtils.randomAlphabetic(10);
Response response = client.dropPartition(collectionName, newTag);
assert(!response.ok());
}
// ----------------------------- show partitions cases in ---------------------------------
// create partition, then show partitions
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_show_partitions(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
ShowPartitionsResponse response = client.showPartitions(collectionName);
assert (response.getResponse().ok());
Assert.assertTrue(response.getPartitionList().contains(tag));
}
// create multi partition, then show partitions
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_show_partitions_multi(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
Response createpResponse = client.createPartition(collectionName, tag);
assert (createpResponse.ok());
String tagNew = RandomStringUtils.randomAlphabetic(10);
Response newCreatepResponse = client.createPartition(collectionName, tagNew);
assert (newCreatepResponse.ok());
ShowPartitionsResponse response = client.showPartitions(collectionName);
assert (response.getResponse().ok());
System.out.println(response.getPartitionList());
Assert.assertTrue(response.getPartitionList().contains(tag));
Assert.assertTrue(response.getPartitionList().contains(tagNew));
}
}
......@@ -5,7 +5,7 @@ import org.testng.annotations.Test;
public class TestPing {
@Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
public void test_server_status(String host, String port) throws ConnectFailedException {
public void test_server_status(String host, int port) throws ConnectFailedException {
System.out.println("Host: "+host+", Port: "+port);
MilvusClient client = new MilvusGrpcClient();
ConnectParam connectParam = new ConnectParam.Builder()
......@@ -18,7 +18,7 @@ public class TestPing {
}
@Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
public void test_server_status_without_connected(MilvusGrpcClient client, String tableName){
public void test_server_status_without_connected(MilvusClient client, String collectionName) throws ConnectFailedException {
Response res = client.getServerStatus();
assert (!res.ok());
}
......
package com;
import io.milvus.client.*;
import org.apache.commons.lang3.RandomStringUtils;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
public class TestSearchByIds {
int dimension = 128;
int n_list = 1024;
int default_n_list = 16384;
int nb = 10000;
int n_probe = 20;
int top_k = 10;
int nq = 5;
double epsilon = 0.001;
IndexType indexType = IndexType.IVF_SQ8;
IndexType defaultIndexType = IndexType.FLAT;
List<Long> default_ids = Utils.toListIds(1111);
List<List<Float>> vectors = Utils.genVectors(nb, dimension, true);
List<ByteBuffer> vectorsBinary = Utils.genBinaryVectors(nb, dimension);
String indexParam = Utils.setIndexParam(n_list);
public String searchParamStr = Utils.setSearchParam(n_probe);
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_search_collection_not_existed(MilvusClient client, String collectionName) {
String collectionNameNew = collectionName + "_";
SearchByIdsParam searchParam = new SearchByIdsParam.Builder(collectionNameNew)
.withParamsInJson(searchParamStr)
.withTopK(top_k)
.withIDs(default_ids)
.build();
SearchResponse res_search = client.searchByIds(searchParam);
assert (!res_search.getResponse().ok());
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_search_collection_empty(MilvusClient client, String collectionName) {
SearchByIdsParam searchParam = new SearchByIdsParam.Builder(collectionName)
.withParamsInJson(searchParamStr)
.withTopK(top_k)
.withIDs(default_ids)
.build();
SearchResponse res_search = client.searchByIds(searchParam);
assert (!res_search.getResponse().ok());
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_search_no_result(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
client.insert(insertParam);
client.flush(collectionName);
SearchByIdsParam searchParam = new SearchByIdsParam.Builder(collectionName)
.withParamsInJson(searchParamStr)
.withTopK(top_k)
.withIDs(default_ids)
.build();
List<List<SearchResponse.QueryResult>> res_search = client.searchByIds(searchParam).getQueryResultsList();
assert (client.searchByIds(searchParam).getResponse().ok());
Assert.assertEquals(res_search.get(0).size(), top_k);
Assert.assertEquals(res_search.size(), default_ids.size());
Assert.assertEquals(res_search.get(0).get(0).getVectorId(), -1);
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_search_index_IVFLAT(MilvusClient client, String collectionName) {
IndexType indexType = IndexType.IVFLAT;
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
InsertResponse res_insert = client.insert(insertParam);
client.flush(collectionName);
Index index = new Index.Builder(collectionName, indexType).withParamsInJson(indexParam).build();
client.createIndex(index);
SearchByIdsParam searchParam = new SearchByIdsParam.Builder(collectionName)
.withParamsInJson(searchParamStr)
.withTopK(top_k)
.withIDs(res_insert.getVectorIds())
.build();
List<List<SearchResponse.QueryResult>> res_search = client.searchByIds(searchParam).getQueryResultsList();
for (int i=0; i<vectors.size(); ++i) {
Assert.assertEquals(res_search.get(i).size(), top_k);
long vectorId = res_search.get(i).get(0).getVectorId();
long insertId = res_insert.getVectorIds().get(i);
Assert.assertEquals(vectorId, insertId);
double distance = res_search.get(i).get(0).getDistance();
if (collectionName.startsWith("L2")) {
Assert.assertEquals(distance, 0.0, epsilon);
}else if (collectionName.startsWith("IP")) {
Assert.assertEquals(distance, 1.0, epsilon);
}
}
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_search_partition(MilvusClient client, String collectionName) {
IndexType indexType = IndexType.IVFLAT;
String tag = RandomStringUtils.randomAlphabetic(10);
client.createPartition(collectionName, tag);
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).withPartitionTag(tag).build();
InsertResponse res_insert = client.insert(insertParam);
client.flush(collectionName);
Index index = new Index.Builder(collectionName, indexType).withParamsInJson(indexParam).build();
client.createIndex(index);
List<String> queryTags = new ArrayList<>();
queryTags.add(tag);
SearchByIdsParam searchParam = new SearchByIdsParam.Builder(collectionName)
.withParamsInJson(searchParamStr)
.withTopK(top_k)
.withIDs(Utils.toListIds(res_insert.getVectorIds().get(0)))
.withPartitionTags(queryTags)
.build();
List<List<SearchResponse.QueryResult>> res_search = client.searchByIds(searchParam).getQueryResultsList();
double distance = res_search.get(0).get(0).getDistance();
if (collectionName.startsWith("L2")) {
Assert.assertEquals(distance, 0.0, epsilon);
}else if (collectionName.startsWith("IP")) {
Assert.assertEquals(distance, 1.0, epsilon);
}
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_search_partition_not_exited(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
client.createPartition(collectionName, tag);
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).withPartitionTag(tag).build();
client.insert(insertParam);
client.flush(collectionName);
String tagNew = RandomStringUtils.randomAlphabetic(10);
List<String> queryTags = new ArrayList<>();
queryTags.add(tagNew);
SearchByIdsParam searchParam = new SearchByIdsParam.Builder(collectionName)
.withParamsInJson(searchParamStr)
.withTopK(top_k)
.withIDs(default_ids)
.withPartitionTags(queryTags)
.build();
SearchResponse res_search = client.searchByIds(searchParam);
assert (!res_search.getResponse().ok());
Assert.assertEquals(res_search.getQueryResultsList().size(), 0);
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_search_partition_empty(MilvusClient client, String collectionName) {
String tag = RandomStringUtils.randomAlphabetic(10);
client.createPartition(collectionName, tag);
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).withPartitionTag(tag).build();
client.insert(insertParam);
client.flush(collectionName);
List<String> queryTags = new ArrayList<>();
queryTags.add(tag);
SearchByIdsParam searchParam = new SearchByIdsParam.Builder(collectionName)
.withParamsInJson(searchParamStr)
.withTopK(top_k)
.withIDs(default_ids)
.withPartitionTags(queryTags)
.build();
SearchResponse res_search = client.searchByIds(searchParam);
assert (res_search.getResponse().ok());
Assert.assertEquals(res_search.getQueryResultsList().size(), 1);
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_search_invalid_n_probe(MilvusClient client, String collectionName) {
int n_probe_new = 0;
String searchParamStrNew = Utils.setSearchParam(n_probe_new);
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
client.insert(insertParam);
SearchByIdsParam searchParam = new SearchByIdsParam.Builder(collectionName)
.withParamsInJson(searchParamStrNew)
.withTopK(top_k)
.withIDs(default_ids)
.build();
SearchResponse res_search = client.searchByIds(searchParam);
assert (!res_search.getResponse().ok());
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void test_search_invalid_top_k(MilvusClient client, String collectionName) {
int top_k_new = 0;
InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(vectors).build();
client.insert(insertParam);
SearchByIdsParam searchParam = new SearchByIdsParam.Builder(collectionName)
.withParamsInJson(searchParamStr)
.withTopK(top_k_new)
.withIDs(default_ids)
.build();
SearchResponse res_search = client.searchByIds(searchParam);
assert (!res_search.getResponse().ok());
}
// Binary tests
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_search_collection_not_existed_binary(MilvusClient client, String collectionName) {
String collectionNameNew = collectionName + "_";
SearchByIdsParam searchParam = new SearchByIdsParam.Builder(collectionNameNew)
.withParamsInJson(searchParamStr)
.withTopK(top_k)
.withIDs(default_ids)
.build();
SearchResponse res_search = client.searchByIds(searchParam);
assert (!res_search.getResponse().ok());
}
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_search_ids_binary(MilvusClient client, String collectionName) {
InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
InsertResponse res_insert = client.insert(insertParam);
client.flush(collectionName);
SearchByIdsParam searchParam = new SearchByIdsParam.Builder(collectionName)
.withParamsInJson(searchParamStr)
.withTopK(top_k)
.withIDs(res_insert.getVectorIds())
.build();
List<List<SearchResponse.QueryResult>> res_search = client.searchByIds(searchParam).getQueryResultsList();
for (int i = 0; i < top_k; i++) {
long insert_id = res_insert.getVectorIds().get(i);
long get_id = res_search.get(i).get(0).getVectorId();
System.out.println(get_id);
Assert.assertEquals(insert_id, get_id);
}
}
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_search_invalid_n_probe_binary(MilvusClient client, String collectionName) {
int n_probe_new = 0;
String searchParamStrNew = Utils.setSearchParam(n_probe_new);
InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
client.insert(insertParam);
SearchByIdsParam searchParam = new SearchByIdsParam.Builder(collectionName)
.withParamsInJson(searchParamStrNew)
.withTopK(top_k)
.withIDs(default_ids)
.build();
SearchResponse res_search = client.searchByIds(searchParam);
assert (!res_search.getResponse().ok());
}
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void test_search_invalid_top_k_binary(MilvusClient client, String collectionName) {
int top_k_new = 0;
InsertParam insertParam = new InsertParam.Builder(collectionName).withBinaryVectors(vectorsBinary).build();
client.insert(insertParam);
SearchByIdsParam searchParam = new SearchByIdsParam.Builder(collectionName)
.withParamsInJson(searchParamStr)
.withTopK(top_k_new)
.withIDs(default_ids)
.build();
SearchResponse res_search = client.searchByIds(searchParam);
assert (!res_search.getResponse().ok());
}
}
package com;
import com.alibaba.fastjson.JSONObject;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
public class Utils {
public static List<Float> normalize(List<Float> w2v){
float squareSum = w2v.stream().map(x -> x * x).reduce((float) 0, Float::sum);
final float norm = (float) Math.sqrt(squareSum);
w2v = w2v.stream().map(x -> x / norm).collect(Collectors.toList());
return w2v;
}
public static List<List<Float>> genVectors(int nb, int dimension, boolean norm) {
List<List<Float>> xb = new ArrayList<>();
Random random = new Random();
for (int i = 0; i < nb; ++i) {
List<Float> vector = new ArrayList<>();
for (int j = 0; j < dimension; j++) {
vector.add(random.nextFloat());
}
if (norm == true) {
vector = normalize(vector);
}
xb.add(vector);
}
return xb;
}
static List<ByteBuffer> genBinaryVectors(long vectorCount, long dimension) {
Random random = new Random();
List<ByteBuffer> vectors = new ArrayList<>();
final long dimensionInByte = dimension / 8;
for (long i = 0; i < vectorCount; ++i) {
ByteBuffer byteBuffer = ByteBuffer.allocate((int) dimensionInByte);
random.nextBytes(byteBuffer.array());
vectors.add(byteBuffer);
}
return vectors;
}
public static String setIndexParam(int nlist) {
JSONObject indexParam = new JSONObject();
indexParam.put("nlist", nlist);
return JSONObject.toJSONString(indexParam);
}
public static String setSearchParam(int nprobe) {
JSONObject searchParam = new JSONObject();
searchParam.put("nprobe", nprobe);
return JSONObject.toJSONString(searchParam);
}
public static int getIndexParamValue(String indexParam, String key) {
return JSONObject.parseObject(indexParam).getIntValue(key);
}
public static JSONObject getCollectionInfo(String collectionInfo) {
return JSONObject.parseObject(collectionInfo);
}
public static List<Long> toListIds(int id) {
List<Long> ids = new ArrayList<>();
ids.add((long)id);
return ids;
}
public static List<Long> toListIds(long id) {
List<Long> ids = new ArrayList<>();
ids.add(id);
return ids;
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册