# utils.py
# -*- coding: utf-8 -*-
import functools
import json
import logging
import os
import random
import string
import time

# from yaml import full_load, dump
import tableprint as tp
import yaml
from yaml.representer import SafeRepresenter

import config

# Module logger; used by retry() and search_param_analysis() in this file.
logger = logging.getLogger("milvus_benchmark.utils")


def timestr_to_int(time_str):
    """
    Parse a test duration set in the yaml configuration file into seconds.

    Accepted forms: an int, a string of digits ("30"), or a number followed by
    one of the unit suffixes "s" (seconds), "m" (minutes) or "h" (hours),
    e.g. "10s", "5m", "2h". Surrounding whitespace is ignored.

    :param time_str: duration as int or str
    :return: duration in seconds (int)
    :raises Exception: if the value is not an int/str or has no supported unit
    :raises ValueError: if the numeric part of a suffixed value is not an integer
    """
    if isinstance(time_str, int):
        return int(time_str)
    if not isinstance(time_str, str):
        # e.g. a YAML float; report it the same way as an unknown unit instead
        # of failing with an AttributeError on .isdigit().
        raise Exception("%s not support" % (time_str,))

    time_str = time_str.strip()
    if time_str.isdigit():
        return int(time_str)

    # seconds per unit suffix
    unit_seconds = {"s": 1, "m": 60, "h": 60 * 60}
    unit = time_str[-1:]
    if unit in unit_seconds:
        # Only strip the trailing unit, so malformed values like "1s2s" are rejected.
        return int(time_str[:-1]) * unit_seconds[unit]
    raise Exception("%s not support" % (time_str,))


class literal_str(str):
    """
    Marker ``str`` subclass. A representer is registered for it with
    yaml.add_representer in this module so that yaml dumps values of this
    type as literal block scalars ('|').
    """


def change_style(style, representer):
    """
    Wrap a yaml representer so that the node it returns gets ``style``.

    :param style: scalar style to force, e.g. '|' for literal block style
    :param representer: callable(dumper, data) returning a yaml node
    :return: a new representer with the same (dumper, data) signature
    """
    def styled(dumper, data):
        node = representer(dumper, data)
        node.style = style
        return node

    return styled


# Representer that emits literal_str values as literal block scalars ('|').
# SafeRepresenter.represent_str handles some corner cases, so build on it
# instead of calling represent_scalar directly.
represent_literal_str = change_style('|', SafeRepresenter.represent_str)

yaml.add_representer(literal_str, represent_literal_str)


def retry(times, interval=3):
    """
    Decorator factory that retries the decorated function until it returns a truthy value.

    Each call of the decorated function makes up to ``times`` attempts. An
    attempt fails when the function raises or returns a falsy value; the
    failure is logged at INFO level and, if attempts remain, the wrapper
    sleeps ``interval`` seconds before trying again.

    :param times: maximum number of attempts
    :param interval: seconds to wait between attempts (default 3)
    :return: decorator; the wrapped function returns the first truthy result,
             otherwise the last falsy result, or None if no attempt returned
    """
    def wrapper(func):
        @functools.wraps(func)
        def newfn(*args, **kwargs):
            # Pre-bind so that a run where every attempt raises (or times <= 0)
            # returns None instead of raising UnboundLocalError.
            result = None
            attempt = 0
            while attempt < times:
                attempt += 1
                try:
                    result = func(*args, **kwargs)
                    if result:
                        break
                    raise Exception("Result false")
                except Exception as e:
                    logger.info(str(e))
                    # No point sleeping after the final attempt.
                    if attempt < times:
                        time.sleep(interval)
            return result
        return newfn
    return wrapper


def convert_nested(dct):
    """
    Expand a flat dict with dot-separated keys into nested dicts.

    e.g. {"a.b": 1, "a.c": 2, "d": 3} -> {"a": {"b": 1, "c": 2}, "d": 3}

    Keys are processed in the dict's iteration order; intermediate levels are
    created on demand and reused when several keys share a prefix.

    :param dct: flat mapping whose keys are strings
    :return: a new nested dict
    """
    nested = {}
    for dotted_key, value in dct.items():
        *parents, leaf = dotted_key.split(".")
        node = nested
        # Walk (creating as needed) one level per key component except the last.
        for part in parents:
            child = node.get(part, {})
            node[part] = child
            node = child
        node.update({leaf: value})
    return nested


def get_unique_name(prefix=None):
    """
    Build a name from ``prefix`` followed by 8 random lowercase letters/digits.

    Uses the non-cryptographic ``random`` module, so names are very likely,
    but not guaranteed, to be unique.

    :param prefix: name prefix; defaults to "distributed-benchmark-test-"
    :return: the generated name (the prefix itself is not case-folded)
    """
    if prefix is None:
        prefix = "distributed-benchmark-test-"
    alphabet = string.ascii_letters + string.digits
    suffix = "".join(random.choice(alphabet) for _ in range(8))
    return prefix + suffix.lower()


def get_current_time():
    """Return the current local time formatted as "YYYY-MM-DD HH:MM:SS"."""
    return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())


def print_table(headers, columns, data):
    """
    Print a table with tableprint.

    Row i is ``columns[i]`` (the row label) followed by the cells of ``data[i]``.

    :param headers: header names, passed straight to tp.table
    :param columns: row labels, one per row
    :param data: sequence of per-row cell sequences, indexed in step with ``columns``
    """
    rows = [[label, *data[i]] for i, label in enumerate(columns)]
    tp.table(rows, headers)


def get_deploy_mode(deploy_params):
    """
    Get the server deployment mode set in the yaml configuration file.

    Valid modes are config.SINGLE_DEPLOY_MODE, config.CLUSTER_DEPLOY_MODE and
    config.CLUSTER_3RD_DEPLOY_MODE (single, cluster, cluster_3rd).

    :param deploy_params: deploy section of the configuration, or None
    :return: None if ``deploy_params`` is empty/None;
             config.DEFUALT_DEPLOY_MODE if there is no (or an empty) "milvus" entry;
             milvus["deploy_mode"] if that key is present;
             None if "milvus" is set but has no "deploy_mode" key
    :raises Exception: if the configured deploy mode is not one of the valid modes
    """
    if not deploy_params:
        return None

    milvus_params = deploy_params["milvus"] if "milvus" in deploy_params else None
    if not milvus_params:
        # NOTE(review): attribute name is spelled "DEFUALT" here; presumably it
        # matches config -- verify there before renaming.
        return config.DEFUALT_DEPLOY_MODE
    if "deploy_mode" not in milvus_params:
        # NOTE(review): unlike a missing "milvus" section this yields None rather
        # than the default mode -- confirm callers expect that.
        return None

    deploy_mode = milvus_params["deploy_mode"]
    valid_modes = [config.SINGLE_DEPLOY_MODE, config.CLUSTER_DEPLOY_MODE, config.CLUSTER_3RD_DEPLOY_MODE]
    if deploy_mode not in valid_modes:
        raise Exception("Invalid deploy mode: %s" % deploy_mode)
    return deploy_mode


def get_server_tag(deploy_params):
    """
    Get the server tag from the service deployment configuration.

    e.g.:
        server:
          server_tag: "8c16m"

    :param deploy_params: deploy section of the configuration, or None
    :return: the "server_tag" value, or "" when it or the "server" section is absent
    """
    if not deploy_params or "server" not in deploy_params:
        return ""
    server = deploy_params["server"]
    return server["server_tag"] if "server_tag" in server else ""

def get_server_resource(deploy_params):
    """
    Return the "server_resource" section of the deploy params.

    :param deploy_params: deploy section of the configuration, or None
    :return: the "server_resource" value, or an empty dict when it is absent
    """
    if not deploy_params or "server_resource" not in deploy_params:
        return {}
    return deploy_params["server_resource"]


def dict_update(source, target):
    """
    Recursively merge ``source`` into ``target`` in place and return ``target``.

    Values from ``source`` win, except that when both sides hold a dict for
    the same key the two dicts are merged recursively. Values inserted from
    ``source`` are stored by reference, not copied.

    :param source: dict whose entries are applied
    :param target: dict that is updated in place
    :return: ``target``
    """
    for key, value in source.items():
        # Only recurse when the existing target value is itself a dict; otherwise
        # (missing, None, scalar, list...) the source value replaces it.
        if isinstance(value, dict) and key in target and isinstance(target[key], dict):
            dict_update(value, target[key])
        else:
            target[key] = value
    return target


def update_dict_value(server_resource, values_dict):
    """
    Merge ``server_resource`` into ``values_dict`` (via dict_update) and return it.

    ``values_dict`` is modified in place. If either argument is not a dict,
    ``values_dict`` is returned unchanged.

    :param server_resource: overrides to apply
    :param values_dict: base values, updated in place
    :return: the merged ``values_dict``
    """
    if not isinstance(server_resource, dict) or not isinstance(values_dict, dict):
        return values_dict
    return dict_update(server_resource, values_dict)


def search_param_analysis(vector_query, filter_query):
    """
    Convert benchmark search parameters into pymilvus search arguments.

    Applicable to pymilvus version >= 2.0.0rc7.dev24.

    :param vector_query: dict with a "vector" entry holding exactly one field:
        {"vector": {<anns_field>: {"query": ..., "metric_type": ...,
                                   "params": ..., "topk": ...}}}
    :param filter_query: optional list; if its first element has a "range"
        entry of the form {<field_name>: {"GT"/"GE"/"LT"/"LE": value, ...}},
        it is turned into a boolean expression such as "f > 1 && f < 10".
    :return: dict with keys "data", "anns_field", "param", "limit" and
        "expression" (None when there is no range filter), or False after
        logging an error when the input does not have the expected shape.
    """
    if "vector" not in vector_query:
        logger.error("[search_param_analysis] vector not in vector_query")
        return False
    vector = vector_query["vector"]

    if not isinstance(vector, dict) or len(vector) != 1:
        logger.error("[search_param_analysis] vector not dict or len != 1: %s" % str(vector))
        return False
    (anns_field, vector_params), = vector.items()
    data = vector_params["query"]
    param = {"metric_type": vector_params["metric_type"],
             "params": vector_params["params"]}
    limit = vector_params["topk"]

    expression = None
    if isinstance(filter_query, list) and len(filter_query) != 0 and "range" in filter_query[0]:
        filter_range = filter_query[0]["range"]
        if not isinstance(filter_range, dict) or len(filter_range) != 1:
            logger.error("[search_param_analysis] filter_range not dict or len != 1: %s" % str(filter_range))
            return False
        # The single key is the field name; its value holds the bounds.
        # (Previously the bounds dict itself was used as the field name.)
        (field_name, bounds), = filter_range.items()
        # Bound keyword -> comparison operator, in the order clauses are emitted.
        range_operators = (("GT", ">"), ("GE", ">="), ("LT", "<"), ("LE", "<="))
        clauses = ["%s %s %s" % (field_name, op, str(bounds[key]))
                   for key, op in range_operators if key in bounds]
        expression = " && ".join(clauses) or None

    return {
        "data": data,
        "anns_field": anns_field,
        "param": param,
        "limit": limit,
        "expression": expression
    }


def modify_file(file_path_list, is_modify=False, input_content=""):
    """
    Make sure every file in ``file_path_list`` exists, optionally resetting its content.

    Missing parent folders and files are created. When ``is_modify`` is True,
    each file (newly created or existing) is truncated and ``input_content``
    is written to it.

    :param file_path_list: list of file paths; a single path string is also accepted
    :param is_modify: whether the file content needs to be reset (must be exactly True)
    :param input_content: content written to each file when ``is_modify`` is True
    """
    if isinstance(file_path_list, str):
        # A bare string would otherwise be iterated character by character.
        file_path_list = [file_path_list]
    elif not isinstance(file_path_list, list):
        print("[modify_file] file is not a list.")

    for file_path in file_path_list:
        folder_path = os.path.dirname(file_path)
        # An empty folder_path means the current directory; os.makedirs("") would raise.
        if folder_path and not os.path.isdir(folder_path):
            print("[modify_file] folder(%s) is not exist." % folder_path)
            os.makedirs(folder_path, exist_ok=True)

        if not os.path.isfile(file_path):
            print("[modify_file] file(%s) is not exist." % file_path)
            # Portable replacement for os.mknod, which is not available on Windows.
            open(file_path, "a").close()

        if is_modify is True:
            print("[modify_file] start modifying file(%s)..." % file_path)
            # Mode "w" truncates, replacing the previous seek(0)/truncate() dance.
            with open(file_path, "w") as f:
                f.write(input_content)
            print("[modify_file] file(%s) modification is complete." % file_path)


def read_json_file(file_name):
    """
    Return the parsed content of a JSON file.

    :param file_name: path of the JSON file
    :return: the decoded JSON value (a dict for a JSON object)
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale encoding.
    with open(file_name, encoding="utf-8") as f:
        return json.load(f)