utils.py 8.3 KB
Newer Older
1 2 3 4 5
# -*- coding: utf-8 -*-
import time
import logging
import string
import random
6
import requests
W
wt 已提交
7
import json
W
wt 已提交
8
import os
9
from yaml.representer import SafeRepresenter
10
# from yaml import full_load, dump
11 12
import yaml
import tableprint as tp
13
import config
# Module-wide logger. The name is spelled out (rather than using __name__) so
# records always land under the "milvus_benchmark" logger hierarchy.
logger = logging.getLogger(name="milvus_benchmark.utils")


def timestr_to_int(time_str):
    """
    Convert a duration from the yaml configuration into whole seconds.

    Accepted forms: an int, a digit-only string (seconds), or a string ending
    in "s" (seconds), "m" (minutes) or "h" (hours), e.g. "30s", "5m", "2h".
    Any other string raises Exception("<value> not support").
    """
    if isinstance(time_str, int) or time_str.isdigit():
        return int(time_str)

    seconds_per_unit = {"s": 1, "m": 60, "h": 60 * 60}
    unit = time_str[-1:]
    if unit not in seconds_per_unit:
        raise Exception("%s not support" % time_str)
    # The number is whatever precedes the first occurrence of the unit letter.
    return int(time_str.split(unit)[0]) * seconds_per_unit[unit]


class literal_str(str):
    """Marker subclass of str; a yaml representer registered in this module dumps its instances in literal block ('|') style."""


def change_style(style, representer):
    """
    Wrap a yaml representer so the node it builds is emitted with ``style``.

    :param style: scalar style to force on the node, e.g. '|' for a literal block.
    :param representer: callable(dumper, data) that returns a node object.
    :return: a representer with the same (dumper, data) signature.
    """
    def styled_representer(dumper, data):
        node = representer(dumper, data)
        node.style = style
        return node

    return styled_representer


# Make yaml dump literal_str values in literal block style ('|').
# SafeRepresenter.represent_str already handles some corner cases, so it is
# wrapped here instead of calling represent_scalar directly.
represent_literal_str = change_style(
    style='|',
    representer=SafeRepresenter.represent_str,
)
yaml.add_representer(literal_str, represent_literal_str)


def retry(times, interval=3):
    """
    Retry the decorated function until it returns a truthy value.

    The wrapped function is called at most ``times`` times. A call that
    raises, or that returns a falsy value, counts as a failed attempt: the
    error is logged at INFO level and, if attempts remain, the wrapper sleeps
    ``interval`` seconds before the next try.

    :param times: maximum number of attempts.
    :param interval: seconds to sleep between failed attempts (defaults to 3,
        the value that used to be hard-coded).
    :return: a decorator. The decorated function returns the first truthy
        result, or the last falsy result if every call returned falsy, or
        None when no call produced a result (every call raised, or
        ``times`` <= 0).
    """
    from functools import wraps

    def wrapper(func):
        @wraps(func)
        def newfn(*args, **kwargs):
            # Initialised up front: previously this was unbound (UnboundLocalError)
            # when every attempt raised or when times <= 0.
            result = None
            attempt = 0
            while attempt < times:
                attempt += 1
                try:
                    result = func(*args, **kwargs)
                    if result:
                        break
                    raise Exception("Result false")
                except Exception as e:
                    logger.info(str(e))
                    # No point sleeping after the final attempt.
                    if attempt < times:
                        time.sleep(interval)
            return result
        return newfn
    return wrapper


def convert_nested(dct):
    """
    Expand a flat dict with dotted keys into nested dicts.

    e.g. {"a.b": 1, "a.c": 2, "d": 3} -> {"a": {"b": 1, "c": 2}, "d": 3}
    The input dict is not modified; values are inserted as-is.
    """
    nested = {}
    for dotted_key, value in dct.items():
        *parents, leaf = dotted_key.split(".")
        node = nested
        # Walk (creating as needed) the intermediate levels.
        for part in parents:
            node = node.setdefault(part, {})
        node.update({leaf: value})
    return nested


def get_unique_name(prefix=None):
    """
    Return ``prefix`` followed by 8 random lowercase letters/digits.

    ``prefix`` defaults to "distributed-benchmark-test-" and is not altered;
    only the random suffix is lowercased.
    """
    if prefix is None:
        prefix = "distributed-benchmark-test-"
    alphabet = string.ascii_letters + string.digits
    suffix = "".join(random.choice(alphabet) for _ in range(8))
    return prefix + suffix.lower()


def get_current_time():
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    timestamp_format = '%Y-%m-%d %H:%M:%S'
    # time.strftime formats time.localtime() when no time tuple is given.
    return time.strftime(timestamp_format)


def print_table(headers, columns, data):
    """
    Print a table with tableprint, one row per entry of ``columns``.

    Row i is ``columns[i]`` followed by the items of ``data[i]``.
    """
    rows = [[label, *data[i]] for i, label in enumerate(columns)]
    tp.table(rows, headers)


def get_deploy_mode(deploy_params):
    """
    Get the server deployment mode set in the yaml configuration file.

    Returns:
      - None when deploy_params is empty/None, or when its "milvus" section
        has no "deploy_mode" key;
      - config.DEFUALT_DEPLOY_MODE when the "milvus" section is missing or empty;
      - otherwise the configured mode, which must be one of single, cluster,
        cluster_3rd (raises Exception("Invalid deploy mode: ...") if not).
    """
    if not deploy_params:
        return None

    milvus_params = deploy_params["milvus"] if "milvus" in deploy_params else None
    if not milvus_params:
        # NOTE(review): attribute is spelled DEFUALT here; presumably it matches config — verify.
        return config.DEFUALT_DEPLOY_MODE
    if "deploy_mode" not in milvus_params:
        return None

    deploy_mode = milvus_params["deploy_mode"]
    valid_modes = (config.SINGLE_DEPLOY_MODE, config.CLUSTER_DEPLOY_MODE, config.CLUSTER_3RD_DEPLOY_MODE)
    if deploy_mode not in valid_modes:
        raise Exception("Invalid deploy mode: %s" % deploy_mode)
    return deploy_mode


def get_server_tag(deploy_params):
    """
    Get the server tag from the deployment configuration, or "" if unset.

    e.g.:
        server:
          server_tag: "8c16m"
    """
    has_server = bool(deploy_params) and "server" in deploy_params
    if has_server:
        server_conf = deploy_params["server"]
        if "server_tag" in server_conf:
            return server_conf["server_tag"]
    return ""

def get_server_resource(deploy_params):
    """
    Return the "server_resource" section of the deployment configuration.

    The section object itself is returned (not a copy); a new empty dict is
    returned when deploy_params is empty/None or has no such section.
    """
    has_section = bool(deploy_params) and "server_resource" in deploy_params
    return deploy_params["server_resource"] if has_section else {}


def dict_update(source, target):
    """
    Recursively merge ``source`` into ``target`` in place and return ``target``.

    When both sides hold a dict under the same key, the two dicts are merged
    key by key. Otherwise the value from ``source`` replaces the one in
    ``target``. This includes a dict whose counterpart in ``target`` is
    missing or is not a dict. Values are assigned by reference, not copied.
    """
    for key, value in source.items():
        # Only recurse when the target side is also a dict; recursing into a
        # scalar/None used to fail with TypeError/AttributeError.
        if isinstance(value, dict) and key in target and isinstance(target[key], dict):
            dict_update(value, target[key])
        else:
            target[key] = value
    return target


def update_dict_value(server_resource, values_dict):
    """
    Merge ``server_resource`` into ``values_dict`` via dict_update.

    If either argument is not a dict, ``values_dict`` is returned unchanged.
    """
    both_dicts = isinstance(server_resource, dict) and isinstance(values_dict, dict)
    if both_dicts:
        return dict_update(server_resource, values_dict)
    return values_dict


def search_param_analysis(vector_query, filter_query):
    """
    Search parameter adjustment, applicable pymilvus version >= 2.0.0rc7.dev24.

    :param vector_query: dict whose "vector" entry holds exactly one field:
        {<anns_field>: {"query": ..., "metric_type": ..., "params": ..., "topk": ...}}
    :param filter_query: optional list; if its first element has a "range"
        entry of the form {<field_name>: {"GT": lower, "LT": upper}} (either
        bound optional), it is turned into a boolean expression string.
    :return: dict with keys "data", "anns_field", "param", "limit" and
        "expression" (None when no range bound applies), or False when the
        input is malformed (an error is logged).
    """
    if "vector" not in vector_query:
        logger.error("[search_param_analysis] vector not in vector_query")
        return False
    vector = vector_query["vector"]

    if not (isinstance(vector, dict) and len(vector) == 1):
        logger.error("[search_param_analysis] vector not dict or len != 1: %s" % str(vector))
        return False
    anns_field, vector_params = next(iter(vector.items()))
    data = vector_params["query"]
    param = {"metric_type": vector_params["metric_type"],
             "params": vector_params["params"]}
    limit = vector_params["topk"]

    expression = None
    if isinstance(filter_query, list) and len(filter_query) != 0 and "range" in filter_query[0]:
        filter_range = filter_query[0]["range"]
        if not (isinstance(filter_range, dict) and len(filter_range) == 1):
            logger.error("[search_param_analysis] filter_range not dict or len != 1: %s" % str(filter_range))
            return False
        # The dict key is the field name; its value holds the bounds.
        # (Previously the bounds dict itself was used as the field name.)
        field_name, bounds = next(iter(filter_range.items()))
        conditions = []
        if 'GT' in bounds:
            conditions.append("%s > %s" % (field_name, str(bounds['GT'])))
        if 'LT' in bounds:
            conditions.append("%s < %s" % (field_name, str(bounds['LT'])))
        if conditions:
            expression = ' && '.join(conditions)

    result = {
        "data": data,
        "anns_field": anns_field,
        "param": param,
        "limit": limit,
        "expression": expression
    }
    return result


def modify_file(file_path_list, is_modify=False, input_content=""):
    """
    Make sure every file exists and optionally reset its content.

    Missing parent folders and missing files are created; a new file starts
    empty. When ``is_modify`` is True, the file's content is replaced by
    ``input_content``. This applies to newly created files as well as
    existing ones.

    file_path_list : file list -> list[<file_path>]; a single path string is
        also accepted and treated as a one-element list
    is_modify : does the file need to be reset (only the value True counts)
    input_content : the content to write to the file when is_modify is True
    """
    if isinstance(file_path_list, str):
        # Iterating a bare string would treat each character as a path.
        file_path_list = [file_path_list]
    elif not isinstance(file_path_list, list):
        print("[modify_file] file is not a list.")

    for file_path in file_path_list:
        folder_path = os.path.dirname(file_path)
        # dirname is "" for a bare file name; os.makedirs("") would raise.
        if folder_path and not os.path.isdir(folder_path):
            print("[modify_file] folder(%s) is not exist." % folder_path)
            os.makedirs(folder_path)

        if not os.path.isfile(file_path):
            print("[modify_file] file(%s) is not exist." % file_path)
            # Portable replacement for os.mknod, which is Unix-only.
            open(file_path, "a").close()

        if is_modify is True:
            print("[modify_file] start modifying file(%s)..." % file_path)
            # "w" truncates and rewrites in one step (same as r+/seek(0)/truncate).
            with open(file_path, "w") as f:
                f.write(input_content)
            print("[modify_file] file(%s) modification is complete." % file_path)


def read_json_file(file_name):
    """Parse the JSON file at ``file_name`` and return the resulting object."""
    with open(file_name) as json_fp:
        return json.load(json_fp)


def get_token(url, timeout=30):
    """
    Request ``url`` and return the "token" field of its JSON response body.

    :param url: endpoint expected to answer with a JSON document.
    :param timeout: seconds passed to requests.get. Without a timeout the
        call can block indefinitely if the server never answers.
    :return: the token value, or '' (after printing a message) when the
        response has no "token" key.
    """
    rep = requests.get(url, timeout=timeout)
    data = json.loads(rep.text)
    if 'token' in data:
        token = data['token']
    else:
        token = ''
        print("Can not get token.")
    return token