提交 e51e2b2b 编写于 作者: T tink2123

add autolog for infer_rec

上级 38801c7f
...@@ -64,6 +64,24 @@ class TextRecognizer(object): ...@@ -64,6 +64,24 @@ class TextRecognizer(object):
self.postprocess_op = build_post_process(postprocess_params) self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \ self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger) utility.create_predictor(args, 'rec', logger)
self.benchmark = args.benchmark
if args.benchmark:
import auto_log
pid = os.getpid()
self.autolog = auto_log.AutoLogger(
model_name="rec",
model_precision=args.precision,
batch_size=6,
data_shape="dynamic",
save_path="./output/auto_log.lpg",
inference_config=self.config,
pids=pid,
process_name=None,
gpu_ids=0 if args.use_gpu else None,
time_keys=[
'preprocess_time', 'inference_time', 'postprocess_time'
],
warmup=10)
def resize_norm_img(self, img, max_wh_ratio): def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape imgC, imgH, imgW = self.rec_image_shape
...@@ -168,6 +186,8 @@ class TextRecognizer(object): ...@@ -168,6 +186,8 @@ class TextRecognizer(object):
rec_res = [['', 0.0]] * img_num rec_res = [['', 0.0]] * img_num
batch_num = self.rec_batch_num batch_num = self.rec_batch_num
st = time.time() st = time.time()
if self.benchmark:
self.autolog.times.start()
for beg_img_no in range(0, img_num, batch_num): for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num) end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = [] norm_img_batch = []
...@@ -196,6 +216,8 @@ class TextRecognizer(object): ...@@ -196,6 +216,8 @@ class TextRecognizer(object):
norm_img_batch.append(norm_img[0]) norm_img_batch.append(norm_img[0])
norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy() norm_img_batch = norm_img_batch.copy()
if self.benchmark:
self.autolog.times.stamp()
if self.rec_algorithm == "SRN": if self.rec_algorithm == "SRN":
encoder_word_pos_list = np.concatenate(encoder_word_pos_list) encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
...@@ -222,6 +244,8 @@ class TextRecognizer(object): ...@@ -222,6 +244,8 @@ class TextRecognizer(object):
for output_tensor in self.output_tensors: for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
outputs.append(output) outputs.append(output)
if self.benchmark:
self.autolog.times.stamp()
preds = {"predict": outputs[2]} preds = {"predict": outputs[2]}
else: else:
self.input_tensor.copy_from_cpu(norm_img_batch) self.input_tensor.copy_from_cpu(norm_img_batch)
...@@ -231,11 +255,14 @@ class TextRecognizer(object): ...@@ -231,11 +255,14 @@ class TextRecognizer(object):
for output_tensor in self.output_tensors: for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
outputs.append(output) outputs.append(output)
if self.benchmark:
self.autolog.times.stamp()
preds = outputs[0] preds = outputs[0]
rec_result = self.postprocess_op(preds) rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)): for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno] rec_res[indices[beg_img_no + rno]] = rec_result[rno]
if self.benchmark:
self.autolog.times.end(stamp=True)
return rec_res, time.time() - st return rec_res, time.time() - st
...@@ -251,9 +278,6 @@ def main(args): ...@@ -251,9 +278,6 @@ def main(args):
for i in range(10): for i in range(10):
res = text_recognizer([img]) res = text_recognizer([img])
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
count = 0
for image_file in image_file_list: for image_file in image_file_list:
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
if not flag: if not flag:
...@@ -273,6 +297,8 @@ def main(args): ...@@ -273,6 +297,8 @@ def main(args):
for ino in range(len(img_list)): for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
rec_res[ino])) rec_res[ino]))
if args.benchmark:
text_recognizer.autolog.report()
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册