// MainClass.java (5.1 KB) - committed by JinHai-CN
package com;

import io.milvus.client.*;
import org.apache.commons.cli.*;
import org.apache.commons.lang3.RandomStringUtils;
import org.testng.SkipException;
import org.testng.TestNG;
import org.testng.annotations.DataProvider;
import org.testng.xml.XmlClass;
import org.testng.xml.XmlSuite;
import org.testng.xml.XmlTest;

import java.util.ArrayList;
import java.util.List;

/**
 * Command-line entry point and shared TestNG data providers for the Milvus client tests.
 *
 * <p>The server address used by every provider defaults to {@code 127.0.0.1:19530} and can be
 * overridden with {@link #setHost(String)} / {@link #setPort(String)} or with the
 * {@code --host} / {@code --port} command-line options handled by {@link #main(String[])}.
 */
public class MainClass {
    private static String host = "127.0.0.1";
    private static String port = "19530";
    /** Value handed to {@code TableSchema.Builder#withIndexFileSize} when a table is created. */
    public Integer index_file_size = 50;
    /** Vector dimension handed to {@code TableSchema.Builder} when a table is created. */
    public Integer dimension = 128;

    /**
     * Overrides the server host used by all data providers.
     *
     * @param host hostname or IP address of the Milvus server
     */
    public static void setHost(String host) {
        MainClass.host = host;
    }

    /**
     * Overrides the server port used by all data providers.
     *
     * @param port port of the Milvus server, as a string
     */
    public static void setPort(String port) {
        MainClass.port = port;
    }

    /**
     * Supplies the currently configured connection arguments.
     *
     * @return a single row {@code {host, port}}
     */
    @DataProvider(name="DefaultConnectArgs")
    public static Object[][] defaultConnectArgs(){
        return new Object[][]{{host, port}};
    }

    /**
     * Supplies a client on which {@code connect} has been called, plus a random table name.
     * The table itself is not created.
     *
     * @return a single row {@code {client, tableName}}
     */
    @DataProvider(name="ConnectInstance")
    public Object[][] connectInstance(){
        MilvusClient client = newConnectedClient();
        String tableName = RandomStringUtils.randomAlphabetic(10);
        return new Object[][]{{client, tableName}};
    }

    /**
     * Supplies a client that was connected and then disconnected again, plus a random
     * table name.
     *
     * @return a single row {@code {client, tableName}}
     */
    @DataProvider(name="DisConnectInstance")
    public Object[][] disConnectInstance(){
        MilvusClient client = newConnectedClient();
        disconnectQuietly(client);
        String tableName = RandomStringUtils.randomAlphabetic(10);
        return new Object[][]{{client, tableName}};
    }

    /**
     * Creates one table per metric type ({@code L2} and {@code IP}), each through its own
     * client connection, and supplies the client together with the table name.
     *
     * @return one row {@code {client, tableName}} per metric type
     * @throws SkipException if the server rejects any table creation; connections opened
     *         by this call are closed before the exception is thrown
     */
    @DataProvider(name="Table")
    public Object[][] provideTable(){
        MetricType[] metricTypes = { MetricType.L2, MetricType.IP };
        // Size the result from the metric list so adding a metric type cannot overflow it.
        Object[][] tables = new Object[metricTypes.length][2];
        for (int i = 0; i < metricTypes.length; ++i) {
            String tableName = metricTypes[i].toString()+"_"+RandomStringUtils.randomAlphabetic(10);
            MilvusClient client = newConnectedClient();
            TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
                    .withIndexFileSize(index_file_size)
                    .withMetricType(metricTypes[i])
                    .build();
            TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
            Response res = client.createTable(tableSchemaParam);
            if (!res.ok()) {
                System.out.println(res.getMessage());
                // None of these clients will reach a test, so release them here
                // instead of leaking the connections.
                disconnectQuietly(client);
                for (int j = 0; j < i; ++j) {
                    disconnectQuietly((MilvusClient) tables[j][0]);
                }
                throw new SkipException("Table created failed");
            }
            tables[i] = new Object[]{client, tableName};
        }
        return tables;
    }

    /**
     * Parses the optional {@code -h/--host} and {@code -p/--port} arguments, then builds and
     * runs the TestNG suite programmatically.
     *
     * <p>If the arguments cannot be parsed the error is reported on stderr and the suite
     * still runs against the currently configured address.
     *
     * @param args command-line arguments
     */
    public static void main(String[] args) {
        CommandLineParser parser = new DefaultParser();
        Options options = new Options();
        options.addOption("h", "host", true, "milvus-server hostname/ip");
        options.addOption("p", "port", true, "milvus-server port");
        try {
            CommandLine cmd = parser.parse(options, args);
            String hostArg = cmd.getOptionValue("host");
            if (hostArg != null) {
                setHost(hostArg);
            }
            String portArg = cmd.getOptionValue("port");
            if (portArg != null) {
                setPort(portArg);
            }
        }
        catch(ParseException exp) {
            System.err.println("Parsing failed.  Reason: " + exp.getMessage() );
        }
        // Report the address the tests will really use; the parsed option values are
        // null whenever an option is omitted, which would hide the defaults.
        System.out.println("Host: "+MainClass.host+", Port: "+MainClass.port);

        XmlSuite suite = new XmlSuite();
        suite.setName("TmpSuite");

        XmlTest test = new XmlTest(suite);
        test.setName("TmpTest");
        List<XmlClass> classes = new ArrayList<XmlClass>();

        classes.add(new XmlClass("com.TestPing"));
        classes.add(new XmlClass("com.TestAddVectors"));
        classes.add(new XmlClass("com.TestConnect"));
        classes.add(new XmlClass("com.TestDeleteVectors"));
        classes.add(new XmlClass("com.TestIndex"));
        classes.add(new XmlClass("com.TestSearchVectors"));
        classes.add(new XmlClass("com.TestTable"));
        classes.add(new XmlClass("com.TestTableCount"));

        test.setXmlClasses(classes);

        List<XmlSuite> suites = new ArrayList<XmlSuite>();
        suites.add(suite);
        TestNG tng = new TestNG();
        tng.setXmlSuites(suites);
        tng.run();
    }

    /** Builds a client and calls {@code connect} on it with the configured host and port. */
    private static MilvusClient newConnectedClient() {
        MilvusClient client = new MilvusGrpcClient();
        ConnectParam connectParam = new ConnectParam.Builder()
                .withHost(host)
                .withPort(port)
                .build();
        client.connect(connectParam);
        return client;
    }

    /** Disconnects the client; an interrupt is reported and the thread's interrupt flag restored. */
    private static void disconnectQuietly(MilvusClient client) {
        try {
            client.disconnect();
        } catch (InterruptedException e) {
            // Restore the flag so code further up the stack can still observe the interrupt.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        }
    }

}