Utils.java 7.3 KB
Newer Older
D
del-zhenwu 已提交
1 2
package com;

3
import io.milvus.client.*;
D
del-zhenwu 已提交
4
import com.alibaba.fastjson.JSONObject;
5
import org.apache.commons.lang3.RandomStringUtils;
D
del-zhenwu 已提交
6 7

import java.nio.ByteBuffer;
8
import java.util.*;
D
del-zhenwu 已提交
9 10 11 12 13 14 15 16 17 18
import java.util.stream.Collectors;

public class Utils {

    /**
     * Scales a vector to unit Euclidean (L2) length.
     *
     * <p>The input list is not modified; a new mutable list is returned.
     * A vector whose magnitude is exactly zero (including the empty vector)
     * cannot be scaled, so it is returned as an unchanged copy instead of
     * being turned into {@code NaN} components by a division by zero.
     *
     * @param w2v the vector components; must not be {@code null} or contain {@code null}
     * @return a new list holding each component divided by the vector's L2 norm
     */
    public static List<Float> normalize(List<Float> w2v){
        // Accumulate in float, in encounter order, to keep results identical
        // to the previous stream-based reduction for non-zero vectors.
        float squareSum = 0f;
        for (float component : w2v) {
            squareSum += component * component;
        }
        final float norm = (float) Math.sqrt(squareSum);

        // 0/0 would yield NaN for every component; leave zero vectors as they are.
        if (norm == 0f) {
            return new ArrayList<>(w2v);
        }

        List<Float> normalized = new ArrayList<>(w2v.size());
        for (float component : w2v) {
            normalized.add(component / norm);
        }
        return normalized;
    }
19

20 21 22 23 24
    /**
     * Produces a name made of the trimmed base string, an underscore and
     * ten random alphabetic characters.
     *
     * @param str_value the base string; {@code null} or blank input falls back to {@code "test"}
     * @return the trimmed base followed by {@code "_"} and a 10-letter random suffix
     */
    public static String genUniqueStr(String str_value){
        String base;
        if (str_value != null && !str_value.trim().isEmpty()) {
            base = str_value;
        } else {
            base = "test";
        }
        String randomSuffix = "_" + RandomStringUtils.randomAlphabetic(10);
        return base.trim() + randomSuffix;
    }
25

26 27
    public static List<List<Float>> genVectors(int vectorCount, int dimension, boolean norm) {
        List<List<Float>> vectors = new ArrayList<>();
D
del-zhenwu 已提交
28
        Random random = new Random();
29
        for (int i = 0; i < vectorCount; ++i) {
D
del-zhenwu 已提交
30
            List<Float> vector = new ArrayList<>();
31
            for (int j = 0; j < dimension; ++j) {
D
del-zhenwu 已提交
32 33 34 35 36
                vector.add(random.nextFloat());
            }
            if (norm == true) {
                vector = normalize(vector);
            }
37
            vectors.add(vector);
D
del-zhenwu 已提交
38
        }
39
        return vectors;
D
del-zhenwu 已提交
40 41 42 43 44 45 46 47 48 49 50 51 52
    }

    /**
     * Generates random binary vectors, each backed by its own byte array.
     *
     * <p>Each buffer is allocated with {@code dimension / 8} bytes (integer
     * division, narrowed to {@code int}) and filled with random bytes.
     *
     * @param vectorCount number of buffers to generate; values &lt;= 0 yield an empty list
     * @param dimension   vector dimension in bits
     * @return a mutable list of {@code vectorCount} heap byte buffers
     */
    static List<ByteBuffer> genBinaryVectors(long vectorCount, long dimension) {
        final int bytesPerVector = (int) (dimension / 8);
        final List<ByteBuffer> result = new ArrayList<>();
        final Random rng = new Random();
        for (long remaining = vectorCount; remaining > 0; remaining--) {
            ByteBuffer buffer = ByteBuffer.allocate(bytesPerVector);
            // Fill the backing array in place; the buffer's position stays at 0.
            rng.nextBytes(buffer.array());
            result.add(buffer);
        }
        return result;
    }
53

54 55 56 57 58 59 60 61 62 63 64 65 66
    /**
     * Builds the two scalar field descriptions shared by the default schemas.
     *
     * <p>Each description is a map with a {@code "field"} name and a
     * {@code "type"} entry: first {@code "int64"} ({@code DataType.INT64}),
     * then {@code "float"} ({@code DataType.FLOAT}).
     *
     * @return a new mutable list containing the int64 and float descriptions, in that order
     */
    private static List<Map<String, Object>> genBaseFieldsWithoutVector(){
        Map<String, Object> int64Field = new HashMap<>();
        int64Field.put("field", "int64");
        int64Field.put("type", DataType.INT64);

        Map<String, Object> floatField = new HashMap<>();
        floatField.put("field", "float");
        floatField.put("type", DataType.FLOAT);

        // Wrapped in an ArrayList so callers can keep appending fields.
        return new ArrayList<>(Arrays.asList(int64Field, floatField));
    }
67
    
68 69 70 71 72 73 74 75 76 77 78 79 80
    public static List<Map<String, Object>> genDefaultFields(int dimension, boolean isBinary){
        List<Map<String, Object>> defaultFieldList = genBaseFieldsWithoutVector();
        Map<String, Object> vectorField = new HashMap<>();
        if (isBinary){
            vectorField.put("field","binary_vector");
            vectorField.put("type",DataType.VECTOR_BINARY);
        }else {
            vectorField.put("field","float_vector");
            vectorField.put("type",DataType.VECTOR_FLOAT);
        }
        JSONObject jsonObject = new JSONObject();
        jsonObject.put("dim", dimension);
        vectorField.put("params", jsonObject.toString());
D
del-zhenwu 已提交
81

82 83 84 85
        defaultFieldList.add(vectorField);
        return defaultFieldList;
    }

86 87
    public static List<Map<String,Object>> genDefaultEntities(int dimension, int vectorCount, List<List<Float>> vectors){
        List<Map<String,Object>> fieldsMap = genDefaultFields(dimension, false);
88 89
        List<Long> intValues = new ArrayList<>(vectorCount);
        List<Float> floatValues = new ArrayList<>(vectorCount);
90 91
//        List<List<Float>> vectors = genVectors(vectorCount,dimension,false);
//        List<ByteBuffer> binaryVectors = genBinaryVectors(vectorCount,dimension);
92 93 94 95
        for (int i = 0; i < vectorCount; ++i) {
            intValues.add((long) i);
            floatValues.add((float) i);
        }
96
        for(Map<String,Object> field: fieldsMap){
97 98 99 100 101 102 103 104 105 106 107
            String fieldType = field.get("field").toString();
            switch (fieldType){
                case "int64":
                    field.put("values",intValues);
                    break;
                case "float":
                    field.put("values",floatValues);
                    break;
                case "float_vector":
                    field.put("values",vectors);
                    break;
108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
            }
        }
        return fieldsMap;
    }

    public static List<Map<String,Object>> genDefaultBinaryEntities(int dimension, int vectorCount, List<ByteBuffer> vectorsBinary){
        List<Map<String,Object>> binaryFieldsMap = genDefaultFields(dimension, true);
        List<Long> intValues = new ArrayList<>(vectorCount);
        List<Float> floatValues = new ArrayList<>(vectorCount);
//        List<List<Float>> vectors = genVectors(vectorCount,dimension,false);
//        List<ByteBuffer> binaryVectors = genBinaryVectors(vectorCount,dimension);
        for (int i = 0; i < vectorCount; ++i) {
            intValues.add((long) i);
            floatValues.add((float) i);
        }
        for(Map<String,Object> field: binaryFieldsMap){
            String fieldType = field.get("field").toString();
            switch (fieldType){
                case "int64":
                    field.put("values",intValues);
                    break;
                case "float":
                    field.put("values",floatValues);
                    break;
132
                case "binary_vector":
133 134
                    field.put("values",vectorsBinary);
                    break;
135 136
            }
        }
137
        return binaryFieldsMap;
138 139 140 141 142 143 144 145 146
    }

    public static String setIndexParam(String indexType, String metricType, int nlist) {
//        ("{\"index_type\": \"IVF_SQ8\", \"metric_type\": \"L2\", \"\"params\": {\"nlist\": 2048}}")
//        JSONObject indexParam = new JSONObject();
//        indexParam.put("nlist", nlist);
//        return JSONObject.toJSONString(indexParam);
        String indexParams = String.format("{\"index_type\": %s, \"metric_type\": %s, \"params\": {\"nlist\": %s}}", indexType, metricType, nlist);
        return indexParams;
D
del-zhenwu 已提交
147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
    }

    /**
     * Serializes a search parameter object holding a single {@code "nprobe"} entry.
     *
     * @param nprobe the value stored under {@code "nprobe"}
     * @return the JSON text of the parameter object
     */
    public static String setSearchParam(int nprobe) {
        JSONObject params = new JSONObject();
        params.put("nprobe", nprobe);
        String serialized = JSONObject.toJSONString(params);
        return serialized;
    }

    /**
     * Parses a JSON object string and reads one entry as an {@code int}
     * via {@code JSONObject.getIntValue}.
     *
     * @param indexParam JSON text of the index parameters
     * @param key        the entry to read
     * @return the value {@code getIntValue} reports for {@code key}
     */
    public static int getIndexParamValue(String indexParam, String key) {
        JSONObject parsed = JSONObject.parseObject(indexParam);
        return parsed.getIntValue(key);
    }

    /**
     * Parses collection info JSON text into a {@code JSONObject}.
     *
     * @param collectionInfo JSON text to parse
     * @return the result of {@code JSONObject.parseObject} for the given text
     */
    public static JSONObject getCollectionInfo(String collectionInfo) {
        JSONObject parsed = JSONObject.parseObject(collectionInfo);
        return parsed;
    }

    /**
     * Wraps a single {@code int} id in a new mutable list of longs.
     *
     * @param id the id to wrap; widened to {@code long}
     * @return a new mutable list containing only {@code id}
     */
    public static List<Long> toListIds(int id) {
        // Copy into an ArrayList so the returned list stays mutable.
        return new ArrayList<>(Collections.singletonList((long) id));
    }

    /**
     * Wraps a single {@code long} id in a new mutable list.
     *
     * @param id the id to wrap
     * @return a new mutable list containing only {@code id}
     */
    public static List<Long> toListIds(long id) {
        // Copy into an ArrayList so the returned list stays mutable.
        return new ArrayList<>(Collections.singletonList(id));
    }
174 175 176 177 178 179 180

    /**
     * Parses a JSON object string, prints it to standard output and returns
     * one of its entries as an {@code int}.
     *
     * <p>Previously a missing key (or input that parsed to no object) surfaced
     * as a message-less {@link NullPointerException} from implicit unboxing.
     * The exception type is kept, but it now states what was missing.
     *
     * @param params JSON text of the parameters
     * @param key    the entry to read
     * @return the integer stored under {@code key}
     * @throws NullPointerException if {@code params} does not parse to an object
     *         or holds no integer value for {@code key}
     */
    public static int getParam(String params, String key){
        JSONObject jsonObject = Objects.requireNonNull(
                JSONObject.parseObject(params),
                "params did not parse to a JSON object: " + params);
        System.out.println(jsonObject.toString());
        Integer value = jsonObject.getInteger(key);
        return Objects.requireNonNull(
                value, "no integer value for key '" + key + "' in " + params);
    }
181

182 183 184 185 186 187 188 189 190
    /**
     * Extracts the {@code i}-th float vector from a list of field maps.
     *
     * <p>Every map whose {@code "field"} entry equals {@code "float_vector"}
     * and whose {@code "values"} entry is non-null is inspected. When the
     * element at index {@code i} is itself a list (the shape stored by
     * {@link #genDefaultEntities}), its components are appended to the
     * result. Previously that inner list was added as a single element,
     * leaving a {@code List<Float>} that actually contained a {@code List}
     * and failed with {@link ClassCastException} on first typed access.
     * A flat list of floats is still supported: the single float at index
     * {@code i} is appended.
     *
     * @param entities field maps carrying {@code "field"} and {@code "values"} entries
     * @param i        index of the vector to extract
     * @return a new mutable list with the vector's components; empty when no
     *         float vector field with values is present
     * @throws IndexOutOfBoundsException if {@code i} is outside the values list
     */
    public static List<Float> getVector(List<Map<String,Object>> entities, int i){
        List<Float> vector = new ArrayList<>();
        for (Map<String,Object> entity : entities) {
            Object values = entity.get("values");
            if (!"float_vector".equals(entity.get("field")) || Objects.isNull(values)) {
                continue;
            }
            Object element = ((List<?>) values).get(i);
            if (element instanceof List) {
                // Nested shape: values is a list of vectors; copy the selected one.
                for (Object component : (List<?>) element) {
                    vector.add((Float) component);
                }
            } else {
                // Flat shape: values is a list of floats.
                vector.add((Float) element);
            }
        }
        return vector;
    }
D
del-zhenwu 已提交
191
}