# benchmark.py
# HTTP/RPC benchmark client for the ppyolo_mbv3 PaddleServing pipeline example.
import sys
import os
import yaml
import requests
import time
import json
import cv2
import base64
try:
    from paddle_serving_server_gpu.pipeline import PipelineClient
except ImportError:
    from paddle_serving_server.pipeline import PipelineClient
import numpy as np
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args, show_latency

def cv2_to_base64(image):
    """Encode raw image bytes as a UTF-8 base64 string for the HTTP payload.

    Args:
        image: raw image file content as bytes.

    Returns:
        The base64 representation decoded to a str.
    """
    encoded = base64.b64encode(image)
    return encoded.decode('utf8')

def parse_benchmark(filein, fileout):
    """Filter per-call records out of a tracer benchmark YAML dump.

    Reads the tracer output in ``filein``, drops every key under the
    "DAG" section whose name contains "call" (the raw per-call entries),
    and writes the remaining summary to ``fileout``.

    Args:
        filein: path of the YAML file produced by the pipeline tracer.
        fileout: path to write the filtered YAML summary to.
    """
    with open(filein, "r") as fin:
        # safe_load: plain yaml.load without an explicit Loader is
        # deprecated and can construct arbitrary objects; this file is
        # plain data, so the safe loader is sufficient.
        res = yaml.safe_load(fin)
    # Collect keys first — a dict cannot be mutated while iterating it.
    del_list = [key for key in res["DAG"] if "call" in key]
    for key in del_list:
        del res["DAG"][key]
    with open(fileout, "w") as fout:
        yaml.dump(res, fout, default_flow_style=False)

def gen_yml(device, gpu_id):
    """Generate config2.yml from config.yml for a benchmark run.

    Enables the tracer with a 30 s reporting interval and, for GPU runs,
    switches the ppyolo_mbv3 op to GPU (device_type 1) on the given card.

    Args:
        device: "gpu" to run on GPU; any other value keeps the CPU config.
        gpu_id: GPU card id string written to the op's "devices" field
            (only consulted when device == "gpu").
    """
    # Context manager replaces the manual open/close pair; safe_load
    # replaces the deprecated loader-less yaml.load.
    with open("config.yml", "r") as fin:
        config = yaml.safe_load(fin)
    config["dag"]["tracer"] = {"interval_s": 30}
    if device == "gpu":
        local_conf = config["op"]["ppyolo_mbv3"]["local_service_conf"]
        local_conf["device_type"] = 1
        local_conf["devices"] = gpu_id
    with open("config2.yml", "w") as fout:
        yaml.dump(config, fout, default_flow_style=False)

def run_http(idx, batch_size):
    """Worker body for one HTTP benchmark thread.

    Posts base64-encoded copies of 000000570688.jpg to the local
    ppyolo_mbv3 pipeline service for roughly 70 seconds and records the
    latency of every request.

    Args:
        idx: worker index, printed for progress tracking only.
        batch_size: number of image copies sent in each request.

    Returns:
        A list of three lists: [[elapsed_seconds], latencies_ms, [request_count]],
        the shape expected by MultiThreadRunner aggregation.
    """
    print("start thread ({})".format(idx))
    url = "http://127.0.0.1:18082/ppyolo_mbv3/prediction"
    with open(os.path.join(".", "000000570688.jpg"), 'rb') as img_file:
        raw_bytes = img_file.read()
    encoded_image = cv2_to_base64(raw_bytes)

    latency_list = []
    total_num = 0
    start = time.time()
    while True:
        l_start = time.time()
        payload = {
            "key": ["image_" + str(j) for j in range(batch_size)],
            "value": [encoded_image] * batch_size,
        }
        requests.post(url=url, data=json.dumps(payload))
        l_end = time.time()
        total_num += 1
        end = time.time()
        latency_list.append(l_end * 1000 - l_start * 1000)
        # Stop after ~70 seconds of wall-clock time.
        if end - start > 70:
            break
    return [[end - start], latency_list, [total_num]]

def multithread_http(thread, batch_size):
    """Run the HTTP benchmark across `thread` concurrent workers.

    Aggregates the per-thread results returned by run_http and prints
    total wall time, mean per-thread time, request count, average QPS,
    and the latency distribution.

    Args:
        thread: number of concurrent worker threads.
        batch_size: images per request, forwarded to run_http.
    """
    runner = MultiThreadRunner()
    start = time.time()
    result = runner.run(run_http, thread, batch_size)
    end = time.time()
    total_cost = end - start

    # result[0] holds per-thread elapsed seconds, result[2] per-thread
    # request counts — one entry per worker.
    avg_cost = sum(result[0][i] for i in range(thread)) / thread
    total_number = sum(result[2][i] for i in range(thread))

    print("Total cost: {}s".format(total_cost))
    print("Each thread cost: {}s. ".format(avg_cost))
    print("Total count: {}. ".format(total_number))
    print("AVG QPS: {} samples/s".format(batch_size * total_number /
                                         total_cost))
    show_latency(result[1])

def run_rpc(thread, batch_size):
    """RPC benchmark worker — not implemented for this example.

    Placeholder kept so the "rpc" mode dispatch stays wired up; it
    accepts the standard worker arguments and does nothing.
    """
    return None

def multithread_rpc(thread, batch_size):
    """Run the RPC benchmark across `thread` workers (run_rpc is a stub).

    Fixes the original's misspelled parameter (``thraed``): the body read
    the name ``thread``, which only resolved by accident to the module
    global assigned in ``__main__``. The parameter is now named
    ``thread`` and is actually used. Call sites are positional, so the
    rename is backward-compatible.

    Args:
        thread: number of concurrent worker threads.
        batch_size: forwarded to run_rpc.
    """
    multi_thread_runner = MultiThreadRunner()
    result = multi_thread_runner.run(run_rpc, thread, batch_size)

if __name__ == "__main__":
    command = sys.argv[1]
    if command == "yaml":
        # Usage: benchmark.py yaml <mode> <thread> <device> <gpu_id>
        mode = sys.argv[2]  # brpc / local predictor
        thread = int(sys.argv[3])
        device = sys.argv[4]
        gpu_id = sys.argv[5]
        gen_yml(device, gpu_id)
    elif command == "run":
        # Usage: benchmark.py run <http|rpc> <thread> <batch_size>
        mode = sys.argv[2]  # http / rpc
        thread = int(sys.argv[3])
        batch_size = int(sys.argv[4])
        if mode == "http":
            multithread_http(thread, batch_size)
        elif mode == "rpc":
            multithread_rpc(thread, batch_size)
    elif command == "dump":
        # Usage: benchmark.py dump <filein> <fileout>
        filein = sys.argv[2]
        fileout = sys.argv[3]
        parse_benchmark(filein, fileout)