提交 b8972b36 编写于 作者: L LDOUBLEV

add python benchmark for ocr

上级 5d24736a
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
import logging
import paddle
import paddle.inference as paddle_infer
from pathlib import Path
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
class PaddleInferBenchmark(object):
def __init__(self,
config,
model_info: dict={},
data_info: dict={},
perf_info: dict={},
resource_info: dict={},
save_log_path: str="",
**kwargs):
"""
Construct PaddleInferBenchmark Class to format logs.
args:
config(paddle.inference.Config): paddle inference config
model_info(dict): basic model info
{'model_name': 'resnet50'
'precision': 'fp32'}
data_info(dict): input data info
{'batch_size': 1
'shape': '3,224,224'
'data_num': 1000}
perf_info(dict): performance result
{'preprocess_time_s': 1.0
'inference_time_s': 2.0
'postprocess_time_s': 1.0
'total_time_s': 4.0}
resource_info(dict):
cpu and gpu resources
{'cpu_rss': 100
'gpu_rss': 100
'gpu_util': 60}
"""
# PaddleInferBenchmark Log Version
self.log_version = 1.0
# Paddle Version
self.paddle_version = paddle.__version__
self.paddle_commit = paddle.__git_commit__
paddle_infer_info = paddle_infer.get_version()
self.paddle_branch = paddle_infer_info.strip().split(': ')[-1]
# model info
self.model_info = model_info
# data info
self.data_info = data_info
# perf info
self.perf_info = perf_info
try:
self.model_name = model_info['model_name']
self.precision = model_info['precision']
self.batch_size = data_info['batch_size']
self.shape = data_info['shape']
self.data_num = data_info['data_num']
self.preprocess_time_s = round(perf_info['preprocess_time_s'], 4)
self.inference_time_s = round(perf_info['inference_time_s'], 4)
self.postprocess_time_s = round(perf_info['postprocess_time_s'], 4)
self.total_time_s = round(perf_info['total_time_s'], 4)
except:
self.print_help()
raise ValueError(
"Set argument wrong, please check input argument and its type")
# conf info
self.config_status = self.parse_config(config)
self.save_log_path = save_log_path
# mem info
if isinstance(resource_info, dict):
self.cpu_rss_mb = int(resource_info.get('cpu_rss_mb', 0))
self.gpu_rss_mb = int(resource_info.get('gpu_rss_mb', 0))
self.gpu_util = round(resource_info.get('gpu_util', 0), 2)
else:
self.cpu_rss_mb = 0
self.gpu_rss_mb = 0
self.gpu_util = 0
# init benchmark logger
self.benchmark_logger()
def benchmark_logger(self):
"""
benchmark logger
"""
# Init logger
FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
log_output = f"{self.save_log_path}/{self.model_name}.log"
Path(f"{self.save_log_path}").mkdir(parents=True, exist_ok=True)
logging.basicConfig(
level=logging.INFO,
format=FORMAT,
handlers=[
logging.FileHandler(
filename=log_output, mode='w'),
logging.StreamHandler(),
])
self.logger = logging.getLogger(__name__)
self.logger.info(
f"Paddle Inference benchmark log will be saved to {log_output}")
def parse_config(self, config) -> dict:
"""
parse paddle predictor config
args:
config(paddle.inference.Config): paddle inference config
return:
config_status(dict): dict style config info
"""
config_status = {}
config_status['runtime_device'] = "gpu" if config.use_gpu() else "cpu"
config_status['ir_optim'] = config.ir_optim()
config_status['enable_tensorrt'] = config.tensorrt_engine_enabled()
config_status['precision'] = self.precision
config_status['enable_mkldnn'] = config.mkldnn_enabled()
config_status[
'cpu_math_library_num_threads'] = config.cpu_math_library_num_threads(
)
return config_status
def report(self, identifier=None):
"""
print log report
args:
identifier(string): identify log
"""
if identifier:
identifier = f"[{identifier}]"
else:
identifier = ""
self.logger.info("\n")
self.logger.info(
"---------------------- Paddle info ----------------------")
self.logger.info(f"{identifier} paddle_version: {self.paddle_version}")
self.logger.info(f"{identifier} paddle_commit: {self.paddle_commit}")
self.logger.info(f"{identifier} paddle_branch: {self.paddle_branch}")
self.logger.info(f"{identifier} log_api_version: {self.log_version}")
self.logger.info(
"----------------------- Conf info -----------------------")
self.logger.info(
f"{identifier} runtime_device: {self.config_status['runtime_device']}"
)
self.logger.info(
f"{identifier} ir_optim: {self.config_status['ir_optim']}")
self.logger.info(f"{identifier} enable_memory_optim: {True}")
self.logger.info(
f"{identifier} enable_tensorrt: {self.config_status['enable_tensorrt']}"
)
self.logger.info(
f"{identifier} enable_mkldnn: {self.config_status['enable_mkldnn']}")
self.logger.info(
f"{identifier} cpu_math_library_num_threads: {self.config_status['cpu_math_library_num_threads']}"
)
self.logger.info(
"----------------------- Model info ----------------------")
self.logger.info(f"{identifier} model_name: {self.model_name}")
self.logger.info(f"{identifier} precision: {self.precision}")
self.logger.info(
"----------------------- Data info -----------------------")
self.logger.info(f"{identifier} batch_size: {self.batch_size}")
self.logger.info(f"{identifier} input_shape: {self.shape}")
self.logger.info(f"{identifier} data_num: {self.data_num}")
self.logger.info(
"----------------------- Perf info -----------------------")
self.logger.info(
f"{identifier} cpu_rss(MB): {self.cpu_rss_mb}, gpu_rss(MB): {self.gpu_rss_mb}, gpu_util: {self.gpu_util}%"
)
self.logger.info(
f"{identifier} total time spent(s): {self.total_time_s}")
self.logger.info(
f"{identifier} preprocess_time(ms): {round(self.preprocess_time_s*1000, 1)}, inference_time(ms): {round(self.inference_time_s*1000, 1)}, postprocess_time(ms): {round(self.postprocess_time_s*1000, 1)}"
)
def print_help(self):
"""
print function help
"""
print("""Usage:
==== Print inference benchmark logs. ====
config = paddle.inference.Config()
model_info = {'model_name': 'resnet50'
'precision': 'fp32'}
data_info = {'batch_size': 1
'shape': '3,224,224'
'data_num': 1000}
perf_info = {'preprocess_time_s': 1.0
'inference_time_s': 2.0
'postprocess_time_s': 1.0
'total_time_s': 4.0}
resource_info = {'cpu_rss_mb': 100
'gpu_rss_mb': 100
'gpu_util': 60}
log = PaddleInferBenchmark(config, model_info, data_info, perf_info, resource_info)
log('Test')
""")
def __call__(self, identifier=None):
"""
__call__
args:
identifier(string): identify log
"""
self.report(identifier)
......@@ -45,9 +45,11 @@ class TextClassifier(object):
"label_list": args.label_list,
}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors = \
self.predictor, self.input_tensor, self.output_tensors, _ = \
utility.create_predictor(args, 'cls', logger)
self.cls_times = utility.Timer()
def resize_norm_img(self, img):
imgC, imgH, imgW = self.cls_image_shape
h = img.shape[0]
......@@ -83,7 +85,9 @@ class TextClassifier(object):
cls_res = [['', 0.0]] * img_num
batch_num = self.cls_batch_num
elapse = 0
self.cls_times.total_time.start()
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
......@@ -91,6 +95,7 @@ class TextClassifier(object):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
self.cls_times.preprocess_time.start()
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :]
......@@ -98,11 +103,17 @@ class TextClassifier(object):
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
starttime = time.time()
self.cls_times.preprocess_time.end()
self.cls_times.inference_time.start()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run()
prob_out = self.output_tensors[0].copy_to_cpu()
self.cls_times.inference_time.end()
self.cls_times.postprocess_time.start()
self.predictor.try_shrink_memory()
cls_result = self.postprocess_op(prob_out)
self.cls_times.postprocess_time.end()
elapse += time.time() - starttime
for rno in range(len(cls_result)):
label, score = cls_result[rno]
......@@ -110,6 +121,9 @@ class TextClassifier(object):
if '180' in label and score > self.cls_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate(
img_list[indices[beg_img_no + rno]], 1)
self.cls_times.total_time.end()
self.cls_times.img_num += img_num
elapse = self.cls_times.total_time.value()
return img_list, cls_res, elapse
......@@ -141,8 +155,9 @@ def main(args):
for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
cls_res[ino]))
logger.info("Total predict time for {} images, cost: {:.3f}".format(
len(img_list), predict_time))
logger.info(
"The predict time about text angle classify module is as follows: ")
text_classifier.cls_times.info(average=False)
if __name__ == "__main__":
......
......@@ -31,6 +31,8 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
import tools.infer.benchmark_utils as benchmark_utils
logger = get_logger()
......@@ -95,9 +97,10 @@ class TextDetector(object):
self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor(
args, 'det', logger) # paddle.jit.load(args.det_model_dir)
# self.predictor.eval()
self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
args, 'det', logger)
self.det_times = utility.Timer()
def order_points_clockwise(self, pts):
"""
......@@ -155,6 +158,8 @@ class TextDetector(object):
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
self.det_times.total_time.start()
self.det_times.preprocess_time.start()
data = transform(data, self.preprocess_op)
img, shape_list = data
if img is None:
......@@ -162,7 +167,9 @@ class TextDetector(object):
img = np.expand_dims(img, axis=0)
shape_list = np.expand_dims(shape_list, axis=0)
img = img.copy()
starttime = time.time()
self.det_times.preprocess_time.end()
self.det_times.inference_time.start()
self.input_tensor.copy_from_cpu(img)
self.predictor.run()
......@@ -170,6 +177,7 @@ class TextDetector(object):
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
self.det_times.inference_time.end()
preds = {}
if self.det_algorithm == "EAST":
......@@ -184,6 +192,9 @@ class TextDetector(object):
preds['maps'] = outputs[0]
else:
raise NotImplementedError
self.det_times.postprocess_time.start()
self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points']
......@@ -191,8 +202,11 @@ class TextDetector(object):
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
elapse = time.time() - starttime
return dt_boxes, elapse
self.det_times.postprocess_time.end()
self.det_times.total_time.end()
self.det_times.img_num += 1
return dt_boxes, self.det_times.total_time.value()
if __name__ == "__main__":
......@@ -202,6 +216,13 @@ if __name__ == "__main__":
count = 0
total_time = 0
draw_img_save = "./inference_results"
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
# warmup 10 times
fake_img = np.random.uniform(-1, 1, [640, 640, 3]).astype(np.float32)
for i in range(10):
dt_boxes, _ = text_detector(fake_img)
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
for image_file in image_file_list:
......@@ -211,16 +232,56 @@ if __name__ == "__main__":
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
dt_boxes, elapse = text_detector(img)
st = time.time()
dt_boxes, _ = text_detector(img)
elapse = time.time() - st
if count > 0:
total_time += elapse
count += 1
if args.benchmark:
cm, gm, gu = utility.get_current_memory_mb(0)
cpu_mem += cm
gpu_mem += gm
gpu_util += gu
logger.info("Predict time of {}: {}".format(image_file, elapse))
src_im = utility.draw_text_det_res(dt_boxes, image_file)
img_name_pure = os.path.split(image_file)[-1]
img_path = os.path.join(draw_img_save,
"det_res_{}".format(img_name_pure))
cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path))
if count > 1:
logger.info("Avg Time: {}".format(total_time / (count - 1)))
# print the information about memory and time-spent
if args.benchmark:
mems = {
'cpu_rss_mb': cpu_mem / count,
'gpu_rss_mb': gpu_mem / count,
'gpu_util': gpu_util * 100 / count
}
else:
mems = None
logger.info("The predict time about detection module is as follows: ")
det_time_dict = text_detector.det_times.report(average=True)
det_model_name = args.det_model_dir
if args.benchmark:
# construct log information
model_info = {
'model_name': args.det_model_dir.split('/')[-1],
'precision': args.precision
}
data_info = {
'batch_size': 1,
'shape': 'dynamic_shape',
'data_num': det_time_dict['img_num']
}
perf_info = {
'preprocess_time_s': det_time_dict['preprocess_time'],
'inference_time_s': det_time_dict['inference_time'],
'postprocess_time_s': det_time_dict['postprocess_time'],
'total_time_s': det_time_dict['total_time']
}
benchmark_log = benchmark_utils.PaddleInferBenchmark(
text_detector.config, model_info, data_info, perf_info, mems)
benchmark_log("Det")
......@@ -28,6 +28,7 @@ import traceback
import paddle
import tools.infer.utility as utility
import tools.infer.benchmark_utils as benchmark_utils
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
......@@ -41,7 +42,6 @@ class TextRecognizer(object):
self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm
self.max_text_length = args.max_text_length
postprocess_params = {
'name': 'CTCLabelDecode',
"character_type": args.rec_char_type,
......@@ -63,9 +63,11 @@ class TextRecognizer(object):
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors = \
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
self.rec_times = utility.Timer()
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
assert imgC == img.shape[2]
......@@ -166,17 +168,15 @@ class TextRecognizer(object):
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list))
# rec_res = []
self.rec_times.total_time.start()
rec_res = [['', 0.0]] * img_num
batch_num = self.rec_batch_num
elapse = 0
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
self.rec_times.preprocess_time.start()
for ino in range(beg_img_no, end_img_no):
# h, w = img_list[ino].shape[0:2]
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
......@@ -187,9 +187,8 @@ class TextRecognizer(object):
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
else:
norm_img = self.process_image_srn(img_list[indices[ino]],
self.rec_image_shape, 8,
self.max_text_length)
norm_img = self.process_image_srn(
img_list[indices[ino]], self.rec_image_shape, 8, 25)
encoder_word_pos_list = []
gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = []
......@@ -203,7 +202,6 @@ class TextRecognizer(object):
norm_img_batch = norm_img_batch.copy()
if self.rec_algorithm == "SRN":
starttime = time.time()
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = np.concatenate(
......@@ -218,19 +216,23 @@ class TextRecognizer(object):
gsrm_slf_attn_bias1_list,
gsrm_slf_attn_bias2_list,
]
self.rec_times.preprocess_time.end()
self.rec_times.inference_time.start()
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[
i])
input_tensor.copy_from_cpu(inputs[i])
self.predictor.run()
self.rec_times.inference_time.end()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = {"predict": outputs[2]}
else:
starttime = time.time()
self.rec_times.preprocess_time.end()
self.rec_times.inference_time.start()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run()
......@@ -239,22 +241,31 @@ class TextRecognizer(object):
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = outputs[0]
self.predictor.try_shrink_memory()
self.rec_times.inference_time.end()
self.rec_times.postprocess_time.start()
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
elapse += time.time() - starttime
return rec_res, elapse
self.rec_times.postprocess_time.end()
self.rec_times.img_num += int(norm_img_batch.shape[0])
self.rec_times.total_time.end()
return rec_res, self.rec_times.total_time.value()
def main(args):
image_file_list = get_image_file_list(args.image_dir)
text_recognizer = TextRecognizer(args)
total_run_time = 0.0
total_images_num = 0
valid_image_file_list = []
img_list = []
for idx, image_file in enumerate(image_file_list):
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
count = 0
# warmup 10 times
fake_img = np.random.uniform(-1, 1, [1, 32, 320, 3]).astype(np.float32)
for i in range(10):
dt_boxes, _ = text_recognizer(fake_img)
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
......@@ -263,29 +274,54 @@ def main(args):
continue
valid_image_file_list.append(image_file)
img_list.append(img)
if len(img_list) >= args.rec_batch_num or idx == len(
image_file_list) - 1:
try:
rec_res, predict_time = text_recognizer(img_list)
total_run_time += predict_time
except:
logger.info(traceback.format_exc())
logger.info(
"ERROR!!!! \n"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
"If your model has tps module: "
"TPS does not support variable shape.\n"
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
)
exit()
for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[
ino], rec_res[ino]))
total_images_num += len(valid_image_file_list)
valid_image_file_list = []
img_list = []
logger.info("Total predict time for {} images, cost: {:.3f}".format(
total_images_num, total_run_time))
try:
rec_res, _ = text_recognizer(img_list)
if args.benchmark:
cm, gm, gu = utility.get_current_memory_mb(0)
cpu_mem += cm
gpu_mem += gm
gpu_util += gu
count += 1
except Exception as E:
logger.info(traceback.format_exc())
logger.info(E)
exit()
for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
rec_res[ino]))
if args.benchmark:
mems = {
'cpu_rss_mb': cpu_mem / count,
'gpu_rss_mb': gpu_mem / count,
'gpu_util': gpu_util * 100 / count
}
else:
mems = None
logger.info("The predict time about recognizer module is as follows: ")
rec_time_dict = text_recognizer.rec_times.report(average=True)
rec_model_name = args.rec_model_dir
if args.benchmark:
# construct log information
model_info = {
'model_name': args.rec_model_dir.split('/')[-1],
'precision': args.precision
}
data_info = {
'batch_size': args.rec_batch_num,
'shape': 'dynamic_shape',
'data_num': rec_time_dict['img_num']
}
perf_info = {
'preprocess_time_s': rec_time_dict['preprocess_time'],
'inference_time_s': rec_time_dict['inference_time'],
'postprocess_time_s': rec_time_dict['postprocess_time'],
'total_time_s': rec_time_dict['total_time']
}
benchmark_log = benchmark_utils.PaddleInferBenchmark(
text_recognizer.config, model_info, data_info, perf_info, mems)
benchmark_log("Rec")
if __name__ == "__main__":
......
......@@ -13,7 +13,6 @@
# limitations under the License.
import os
import sys
import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
......@@ -32,8 +31,8 @@ import tools.infer.predict_det as predict_det
import tools.infer.predict_cls as predict_cls
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from tools.infer.utility import draw_ocr_box_txt
from tools.infer.utility import draw_ocr_box_txt, get_current_memory_mb
import tools.infer.benchmark_utils as benchmark_utils
logger = get_logger()
......@@ -88,8 +87,7 @@ class TextSystem(object):
def __call__(self, img):
ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img)
logger.info("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse))
if dt_boxes is None:
return None, None
img_crop_list = []
......@@ -103,13 +101,9 @@ class TextSystem(object):
if self.use_angle_cls:
img_crop_list, angle_list, elapse = self.text_classifier(
img_crop_list)
logger.info("cls num : {}, elapse : {}".format(
len(img_crop_list), elapse))
rec_res, elapse = self.text_recognizer(img_crop_list)
logger.info("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse))
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
filter_boxes, filter_rec_res = [], []
for box, rec_reuslt in zip(dt_boxes, rec_res):
text, score = rec_reuslt
......@@ -142,12 +136,15 @@ def sorted_boxes(dt_boxes):
def main(args):
image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list[args.process_id::args.total_process_num]
text_sys = TextSystem(args)
is_visualize = True
font_path = args.vis_font_path
drop_score = args.drop_score
for image_file in image_file_list:
total_time = 0
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
_st = time.time()
count = 0
for idx, image_file in enumerate(image_file_list):
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
......@@ -157,8 +154,16 @@ def main(args):
starttime = time.time()
dt_boxes, rec_res = text_sys(img)
elapse = time.time() - starttime
logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
total_time += elapse
if args.benchmark and idx % 20 == 0:
cm, gm, gu = get_current_memory_mb(0)
cpu_mem += cm
gpu_mem += gm
gpu_util += gu
count += 1
logger.info(
str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse))
for text, score in rec_res:
logger.info("{}, {:.3f}".format(text, score))
......@@ -178,26 +183,74 @@ def main(args):
draw_img_save = "./inference_results/"
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
if flag:
image_file = image_file[:-3] + "png"
cv2.imwrite(
os.path.join(draw_img_save, os.path.basename(image_file)),
draw_img[:, :, ::-1])
logger.info("The visualized image saved in {}".format(
os.path.join(draw_img_save, os.path.basename(image_file))))
logger.info("The predict total time is {}".format(time.time() - _st))
logger.info("\nThe predict total time is {}".format(total_time))
if __name__ == "__main__":
args = utility.parse_args()
if args.use_mp:
p_list = []
total_process_num = args.total_process_num
for process_id in range(total_process_num):
cmd = [sys.executable, "-u"] + sys.argv + [
"--process_id={}".format(process_id),
"--use_mp={}".format(False)
]
p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
p_list.append(p)
for p in p_list:
p.wait()
img_num = text_sys.text_detector.det_times.img_num
if args.benchmark:
mems = {
'cpu_rss_mb': cpu_mem / count,
'gpu_rss_mb': gpu_mem / count,
'gpu_util': gpu_util * 100 / count
}
else:
main(args)
mems = None
det_time_dict = text_sys.text_detector.det_times.report(average=True)
rec_time_dict = text_sys.text_recognizer.rec_times.report(average=True)
det_model_name = args.det_model_dir
rec_model_name = args.rec_model_dir
# construct det log information
model_info = {
'model_name': args.det_model_dir.split('/')[-1],
'precision': args.precision
}
data_info = {
'batch_size': 1,
'shape': 'dynamic_shape',
'data_num': det_time_dict['img_num']
}
perf_info = {
'preprocess_time_s': det_time_dict['preprocess_time'],
'inference_time_s': det_time_dict['inference_time'],
'postprocess_time_s': det_time_dict['postprocess_time'],
'total_time_s': det_time_dict['total_time']
}
benchmark_log = benchmark_utils.PaddleInferBenchmark(
text_sys.text_detector.config, model_info, data_info, perf_info, mems,
args.save_log_path)
benchmark_log("Det")
# construct rec log information
model_info = {
'model_name': args.rec_model_dir.split('/')[-1],
'precision': args.precision
}
data_info = {
'batch_size': args.rec_batch_num,
'shape': 'dynamic_shape',
'data_num': rec_time_dict['img_num']
}
perf_info = {
'preprocess_time_s': rec_time_dict['preprocess_time'],
'inference_time_s': rec_time_dict['inference_time'],
'postprocess_time_s': rec_time_dict['postprocess_time'],
'total_time_s': rec_time_dict['total_time']
}
benchmark_log = benchmark_utils.PaddleInferBenchmark(
text_sys.text_recognizer.config, model_info, data_info, perf_info, mems,
args.save_log_path)
benchmark_log("Rec")
if __name__ == "__main__":
main(utility.parse_args())
......@@ -21,6 +21,9 @@ import json
from PIL import Image, ImageDraw, ImageFont
import math
from paddle import inference
import time
from ppocr.utils.logging import get_logger
logger = get_logger()
def parse_args():
......@@ -32,7 +35,7 @@ def parse_args():
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--use_fp16", type=str2bool, default=False)
parser.add_argument("--precision", type=str, default="fp32")
parser.add_argument("--gpu_mem", type=int, default=500)
# params for text detector
......@@ -98,15 +101,88 @@ def parse_args():
parser.add_argument("--cls_thresh", type=float, default=0.9)
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
parser.add_argument("--cpu_threads", type=int, default=10)
parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--use_mp", type=str2bool, default=False)
parser.add_argument("--total_process_num", type=int, default=1)
parser.add_argument("--process_id", type=int, default=0)
parser.add_argument("--benchmark", type=bool, default=False)
parser.add_argument("--save_log_path", type=str, default="./log_output/")
return parser.parse_args()
class Times(object):
def __init__(self):
self.time = 0.
self.st = 0.
self.et = 0.
def start(self):
self.st = time.time()
def end(self, accumulative=True):
self.et = time.time()
if accumulative:
self.time += self.et - self.st
else:
self.time = self.et - self.st
def reset(self):
self.time = 0.
self.st = 0.
self.et = 0.
def value(self):
return round(self.time, 4)
class Timer(Times):
def __init__(self):
super(Timer, self).__init__()
self.total_time = Times()
self.preprocess_time = Times()
self.inference_time = Times()
self.postprocess_time = Times()
self.img_num = 0
def info(self, average=False):
logger.info("----------------------- Perf info -----------------------")
logger.info("total_time: {}, img_num: {}".format(self.total_time.value(
), self.img_num))
preprocess_time = round(self.preprocess_time.value() / self.img_num,
4) if average else self.preprocess_time.value()
postprocess_time = round(
self.postprocess_time.value() / self.img_num,
4) if average else self.postprocess_time.value()
inference_time = round(self.inference_time.value() / self.img_num,
4) if average else self.inference_time.value()
average_latency = self.total_time.value() / self.img_num
logger.info("average_latency(ms): {:.2f}, QPS: {:2f}".format(
average_latency * 1000, 1 / average_latency))
logger.info(
"preprocess_latency(ms): {:.2f}, inference_latency(ms): {:.2f}, postprocess_latency(ms): {:.2f}".
format(preprocess_time * 1000, inference_time * 1000,
postprocess_time * 1000))
def report(self, average=False):
dic = {}
dic['preprocess_time'] = round(
self.preprocess_time.value() / self.img_num,
4) if average else self.preprocess_time.value()
dic['postprocess_time'] = round(
self.postprocess_time.value() / self.img_num,
4) if average else self.postprocess_time.value()
dic['inference_time'] = round(
self.inference_time.value() / self.img_num,
4) if average else self.inference_time.value()
dic['img_num'] = self.img_num
dic['total_time'] = round(self.total_time.value(), 4)
return dic
def create_predictor(args, mode, logger):
if mode == "det":
model_dir = args.det_model_dir
......@@ -131,6 +207,16 @@ def create_predictor(args, mode, logger):
config = inference.Config(model_file_path, params_file_path)
if hasattr(args, 'precision'):
if args.precision == "fp16" and args.use_tensorrt:
precision = inference.PrecisionType.Half
elif args.precision == "int8":
precision = inference.PrecisionType.Int8
else:
precision = inference.PrecisionType.Float32
else:
precision = inference.PrecisionType.Float32
if args.use_gpu:
config.enable_use_gpu(args.gpu_mem, 0)
if args.use_tensorrt:
......@@ -140,7 +226,10 @@ def create_predictor(args, mode, logger):
max_batch_size=args.max_batch_size)
else:
config.disable_gpu()
config.set_cpu_math_library_num_threads(6)
if hasattr(args, "cpu_threads"):
config.set_cpu_math_library_num_threads(args.cpu_threads)
else:
config.set_cpu_math_library_num_threads(10)
if args.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
......@@ -166,7 +255,7 @@ def create_predictor(args, mode, logger):
for output_name in output_names:
output_tensor = predictor.get_output_handle(output_name)
output_tensors.append(output_tensor)
return predictor, input_tensor, output_tensors
return predictor, input_tensor, output_tensors, config
def draw_e2e_res(dt_boxes, strs, img_path):
......@@ -417,6 +506,31 @@ def draw_boxes(image, boxes, scores=None, drop_score=0.5):
return image
def get_current_memory_mb(gpu_id=None):
"""
It is used to Obtain the memory usage of the CPU and GPU during the running of the program.
And this function Current program is time-consuming.
"""
import pynvml
import psutil
import GPUtil
pid = os.getpid()
p = psutil.Process(pid)
info = p.memory_full_info()
cpu_mem = info.uss / 1024. / 1024.
gpu_mem = 0
gpu_percent = 0
if gpu_id is not None:
GPUs = GPUtil.getGPUs()
gpu_load = GPUs[gpu_id].load
gpu_percent = gpu_load
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
gpu_mem = meminfo.used / 1024. / 1024.
return round(cpu_mem, 4), round(gpu_mem, 4), round(gpu_percent, 4)
if __name__ == '__main__':
test_img = "./doc/test_v2"
predict_txt = "./doc/predict.txt"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册