提交 dbe4e672 编写于 作者: T TeslaZhao

update

上级 209aef88
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys import sys
import os import os
import base64 import base64
...@@ -12,6 +26,8 @@ except ImportError: ...@@ -12,6 +26,8 @@ except ImportError:
import numpy as np import numpy as np
from paddle_serving_client.utils import MultiThreadRunner from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args, show_latency from paddle_serving_client.utils import benchmark_args, show_latency
def parse_benchmark(filein, fileout): def parse_benchmark(filein, fileout):
with open(filein, "r") as fin: with open(filein, "r") as fin:
res = yaml.load(fin) res = yaml.load(fin)
...@@ -24,6 +40,7 @@ def parse_benchmark(filein, fileout): ...@@ -24,6 +40,7 @@ def parse_benchmark(filein, fileout):
with open(fileout, "w") as fout: with open(fileout, "w") as fout:
yaml.dump(res, fout, default_flow_style=False) yaml.dump(res, fout, default_flow_style=False)
def gen_yml(device): def gen_yml(device):
fin = open("config.yml", "r") fin = open("config.yml", "r")
config = yaml.load(fin) config = yaml.load(fin)
...@@ -33,19 +50,24 @@ def gen_yml(device): ...@@ -33,19 +50,24 @@ def gen_yml(device):
config["op"]["det"]["local_service_conf"]["device_type"] = 1 config["op"]["det"]["local_service_conf"]["device_type"] = 1
config["op"]["det"]["local_service_conf"]["devices"] = "2" config["op"]["det"]["local_service_conf"]["devices"] = "2"
config["op"]["rec"]["local_service_conf"]["device_type"] = 1 config["op"]["rec"]["local_service_conf"]["device_type"] = 1
config["op"]["rec"]["local_service_conf"]["devices"] = "2" config["op"]["rec"]["local_service_conf"]["devices"] = "2"
with open("config2.yml", "w") as fout: with open("config2.yml", "w") as fout:
yaml.dump(config, fout, default_flow_style=False) yaml.dump(config, fout, default_flow_style=False)
def cv2_to_base64(image): def cv2_to_base64(image):
return base64.b64encode(image).decode('utf8') return base64.b64encode(image).decode('utf8')
def run_http(idx, batch_size): def run_http(idx, batch_size):
print("start thread ({})".format(idx)) print("start thread ({})".format(idx))
url = "http://127.0.0.1:9999/ocr/prediction" url = "http://127.0.0.1:9999/ocr/prediction"
start = time.time() start = time.time()
test_img_dir = "imgs/" test_img_dir = "imgs/"
#test_img_dir = "rctw_test/images/"
latency_list = []
total_number = 0
for img_file in os.listdir(test_img_dir): for img_file in os.listdir(test_img_dir):
with open(os.path.join(test_img_dir, img_file), 'rb') as file: with open(os.path.join(test_img_dir, img_file), 'rb') as file:
image_data1 = file.read() image_data1 = file.read()
...@@ -56,15 +78,20 @@ def run_http(idx, batch_size): ...@@ -56,15 +78,20 @@ def run_http(idx, batch_size):
end = time.time() end = time.time()
return [[end - start]] return [[end - start]]
def multithread_http(thread, batch_size): def multithread_http(thread, batch_size):
multi_thread_runner = MultiThreadRunner() multi_thread_runner = MultiThreadRunner()
result = multi_thread_runner.run(run_http , thread, batch_size) result = multi_thread_runner.run(run_http, thread, batch_size)
def run_rpc(thread, batch_size): def run_rpc(thread, batch_size):
client = PipelineClient() client = PipelineClient()
client.connect(['127.0.0.1:18090']) client.connect(['127.0.0.1:18090'])
start = time.time() start = time.time()
test_img_dir = "imgs/" test_img_dir = "imgs/"
#test_img_dir = "rctw_test/images/"
latency_list = []
total_number = 0
for img_file in os.listdir(test_img_dir): for img_file in os.listdir(test_img_dir):
with open(os.path.join(test_img_dir, img_file), 'rb') as file: with open(os.path.join(test_img_dir, img_file), 'rb') as file:
image_data = file.read() image_data = file.read()
...@@ -78,16 +105,17 @@ def run_rpc(thread, batch_size): ...@@ -78,16 +105,17 @@ def run_rpc(thread, batch_size):
def multithread_rpc(thraed, batch_size): def multithread_rpc(thraed, batch_size):
multi_thread_runner = MultiThreadRunner() multi_thread_runner = MultiThreadRunner()
result = multi_thread_runner.run(run_rpc , thread, batch_size) result = multi_thread_runner.run(run_rpc, thread, batch_size)
if __name__ == "__main__": if __name__ == "__main__":
if sys.argv[1] == "yaml": if sys.argv[1] == "yaml":
mode = sys.argv[2] # brpc/ local predictor mode = sys.argv[2] # brpc/ local predictor
thread = int(sys.argv[3]) thread = int(sys.argv[3])
device = sys.argv[4] device = sys.argv[4]
gen_yml(device) gen_yml(device)
elif sys.argv[1] == "run": elif sys.argv[1] == "run":
mode = sys.argv[2] # http/ rpc mode = sys.argv[2] # http/ rpc
thread = int(sys.argv[3]) thread = int(sys.argv[3])
batch_size = int(sys.argv[4]) batch_size = int(sys.argv[4])
if mode == "http": if mode == "http":
...@@ -98,4 +126,3 @@ if __name__ == "__main__": ...@@ -98,4 +126,3 @@ if __name__ == "__main__":
filein = sys.argv[2] filein = sys.argv[2]
fileout = sys.argv[3] fileout = sys.argv[3]
parse_benchmark(filein, fileout) parse_benchmark(filein, fileout)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册