benchmark.py 5.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

W
wangjiawei04 已提交
15 16 17 18 19 20 21
import sys
import os
import base64
import yaml
import requests
import time
import json
T
TeslaZhao 已提交
22 23

from paddle_serving_server.pipeline import PipelineClient
W
wangjiawei04 已提交
24 25 26
import numpy as np
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args, show_latency
27 28


W
wangjiawei04 已提交
29 30
def parse_benchmark(filein, fileout):
    """Copy the YAML document in *filein* to *fileout* without "call" entries.

    Every key of the top-level ``"DAG"`` mapping whose name contains the
    substring ``"call"`` is deleted; the rest of the document is written
    back in block (non-flow) style.

    Args:
        filein: Path of the YAML file to read.
        fileout: Path of the YAML file to write.
    """
    with open(filein, "r") as src:
        report = yaml.load(src, yaml.FullLoader)
        dag = report["DAG"]
        # Gather the names first so the mapping is not modified while it is
        # being iterated.
        doomed = [name for name in dag if "call" in name]
        for name in doomed:
            del dag[name]
    with open(fileout, "w") as dst:
        yaml.dump(report, dst, default_flow_style=False)

41

W
wangjiawei04 已提交
42 43
def gen_yml(device):
    """Write ``config2.yml`` as a modified copy of ``config.yml``.

    ``config["dag"]["tracer"]`` is always set to ``{"interval_s": 10}``.
    When *device* equals ``"gpu"``, the ``local_service_conf`` of both the
    ``det`` and ``rec`` ops additionally gets ``device_type = 1`` and
    ``devices = "2"``; for any other value those settings are left as loaded.

    Args:
        device: Device selector; only the exact string ``"gpu"`` changes the
            op settings.
    """
    # Use a context manager so the input file is closed even when parsing
    # raises.
    with open("config.yml", "r") as fin:
        config = yaml.load(fin, yaml.FullLoader)
    config["dag"]["tracer"] = {"interval_s": 10}
    if device == "gpu":
        # Both ops receive the same device settings.
        for op_name in ("det", "rec"):
            service_conf = config["op"][op_name]["local_service_conf"]
            service_conf["device_type"] = 1
            service_conf["devices"] = "2"
    with open("config2.yml", "w") as fout:
        yaml.dump(config, fout, default_flow_style=False)

55

W
wangjiawei04 已提交
56 57 58
def cv2_to_base64(image):
    """Return the bytes-like *image* as a base64-encoded text string."""
    encoded = base64.b64encode(image)
    return encoded.decode('utf8')

59

W
wangjiawei04 已提交
60 61
def run_http(idx, batch_size):
    """Post every file in ``imgs/`` to the HTTP OCR endpoint, one per request.

    Each file is read as bytes, base64-encoded and sent as a JSON body of the
    form ``{"key": ["image"], "value": [<base64>]}``; the decoded JSON
    response is printed.

    Args:
        idx: Worker index; only used in the start-up message.
        batch_size: Not used by this function; accepted so the signature
            matches the arguments the runner passes.

    Returns:
        A three-element list ``[[elapsed], latencies, [count]]`` where
        ``elapsed`` is the wall-clock time of the whole loop in seconds,
        ``latencies`` holds one per-request duration in milliseconds, and
        ``count`` is the number of requests sent.
    """
    print("start thread ({})".format(idx))
    url = "http://127.0.0.1:9999/ocr/prediction"
    test_img_dir = "imgs/"
    latency_list = []
    total_number = 0
    start = time.time()
    for img_file in os.listdir(test_img_dir):
        l_start = time.time()
        with open(os.path.join(test_img_dir, img_file), 'rb') as file:
            image_data1 = file.read()
        image = cv2_to_base64(image_data1)
        data = {"key": ["image"], "value": [image]}
        r = requests.post(url=url, data=json.dumps(data))
        print(r.json())
        l_end = time.time()
        latency_list.append(l_end * 1000 - l_start * 1000)
        total_number = total_number + 1
    # Take the end timestamp after the loop so it is defined even when the
    # image directory is empty.
    end = time.time()
    return [[end - start], latency_list, [total_number]]

W
wangjiawei04 已提交
83 84 85

def multithread_http(thread, batch_size):
    """Run ``run_http`` through ``MultiThreadRunner`` and print a summary.

    Prints the overall wall-clock time, the mean of the per-worker times,
    the summed request count and a QPS figure, then passes the collected
    latencies to ``show_latency``.

    Args:
        thread: Worker count handed to the runner; also the number of
            per-worker entries read back from the result.
        batch_size: Forwarded to the runner and used as the multiplier in
            the printed QPS figure.
    """
    runner = MultiThreadRunner()
    began = time.time()
    result = runner.run(run_http, thread, batch_size)
    total_cost = time.time() - began

    # Presumably result[0] / result[2] are the per-worker elapsed times and
    # request counts merged by the runner -- confirm against MultiThreadRunner.
    cost_sum = 0
    total_number = 0
    for worker in range(thread):
        cost_sum += result[0][worker]
        total_number += result[2][worker]
    avg_cost = cost_sum / thread

    print(f"Total cost: {total_cost}s")
    print(f"Each thread cost: {avg_cost}s. ")
    print(f"Total count: {total_number}. ")
    print(f"AVG QPS: {batch_size * total_number / total_cost} samples/s")
    show_latency(result[1])

W
wangjiawei04 已提交
103 104 105 106 107

def run_rpc(thread, batch_size):
    """Send every file in ``imgs/`` to the pipeline service over RPC.

    A ``PipelineClient`` is connected to ``127.0.0.1:18090``; each file is
    read as bytes, base64-encoded and submitted with
    ``feed_dict={"image": ...}`` and ``fetch=["res"]``. Every reply is printed.

    Args:
        thread: Not used by this function.
        batch_size: Not used by this function.

    Returns:
        A three-element list ``[[elapsed], latencies, [count]]`` where
        ``elapsed`` is the wall-clock time in seconds measured from just
        after the client connects, ``latencies`` holds one per-request
        duration in milliseconds, and ``count`` is the number of requests.
    """
    client = PipelineClient()
    client.connect(['127.0.0.1:18090'])
    began = time.time()
    image_dir = "imgs/"
    latencies = []
    sent = 0
    for name in os.listdir(image_dir):
        tick = time.time()
        with open(os.path.join(image_dir, name), 'rb') as handle:
            payload = cv2_to_base64(handle.read())
        reply = client.predict(feed_dict={"image": payload}, fetch=["res"])
        print(reply)
        tock = time.time()
        latencies.append(tock * 1000 - tick * 1000)
        sent += 1
    finished = time.time()
    return [[finished - began], latencies, [sent]]
W
wangjiawei04 已提交
124 125 126 127


def multithread_rpc(thraed, batch_size):
    """Run ``run_rpc`` through ``MultiThreadRunner`` and print a summary.

    Prints the overall wall-clock time, the mean of the per-worker times,
    the summed request count and a QPS figure, then passes the collected
    latencies to ``show_latency``.

    Args:
        thraed: Worker count handed to the runner; also the number of
            per-worker entries read back from the result. The spelling is
            kept so the signature stays unchanged for existing callers.
        batch_size: Forwarded to the runner and used as the multiplier in
            the printed QPS figure.
    """
    # Bind the misspelled parameter to a correctly spelled local so the body
    # uses the argument it was given rather than whatever module-level
    # ``thread`` may (or may not) exist.
    thread = thraed
    multi_thread_runner = MultiThreadRunner()
    start = time.time()
    result = multi_thread_runner.run(run_rpc, thread, batch_size)
    end = time.time()
    total_cost = end - start
    avg_cost = 0
    total_number = 0
    for i in range(thread):
        avg_cost += result[0][i]
        total_number += result[2][i]
    avg_cost = avg_cost / thread
    print("Total cost: {}s".format(total_cost))
    print("Each thread cost: {}s. ".format(avg_cost))
    print("Total count: {}. ".format(total_number))
    print("AVG QPS: {} samples/s".format(batch_size * total_number /
                                         total_cost))
    show_latency(result[1])

W
wangjiawei04 已提交
145 146 147

if __name__ == "__main__":
    # Command-line forms (the first argument selects the action):
    #   yaml <mode> <thread> <device>          -> gen_yml(device)
    #   run  <http|rpc> <thread> <batch_size>  -> multithread_http / _rpc
    #   dump <filein> <fileout>                -> parse_benchmark
    action = sys.argv[1]
    if action == "dump":
        filein, fileout = sys.argv[2], sys.argv[3]
        parse_benchmark(filein, fileout)
    elif action == "run":
        mode = sys.argv[2]  # http/ rpc
        thread = int(sys.argv[3])
        batch_size = int(sys.argv[4])
        # Any other mode value is silently ignored.
        if mode == "rpc":
            multithread_rpc(thread, batch_size)
        elif mode == "http":
            multithread_http(thread, batch_size)
    elif action == "yaml":
        # mode and thread are parsed but only device is passed on.
        mode = sys.argv[2]  # brpc/  local predictor
        thread = int(sys.argv[3])
        device = sys.argv[4]
        gen_yml(device)