提交 59af7359 编写于 作者: W WenmuZhou

Adapt inference code to the Paddle 2.0 inference API

上级 7efa3975
...@@ -39,7 +39,6 @@ class TextClassifier(object): ...@@ -39,7 +39,6 @@ class TextClassifier(object):
self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")] self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
self.cls_batch_num = args.cls_batch_num self.cls_batch_num = args.cls_batch_num
self.cls_thresh = args.cls_thresh self.cls_thresh = args.cls_thresh
self.use_zero_copy_run = args.use_zero_copy_run
postprocess_params = { postprocess_params = {
'name': 'ClsPostProcess', 'name': 'ClsPostProcess',
"label_list": args.label_list, "label_list": args.label_list,
...@@ -99,12 +98,8 @@ class TextClassifier(object): ...@@ -99,12 +98,8 @@ class TextClassifier(object):
norm_img_batch = norm_img_batch.copy() norm_img_batch = norm_img_batch.copy()
starttime = time.time() starttime = time.time()
if self.use_zero_copy_run:
self.input_tensor.copy_from_cpu(norm_img_batch) self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run() self.predictor.run()
else:
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
self.predictor.run([norm_img_batch])
prob_out = self.output_tensors[0].copy_to_cpu() prob_out = self.output_tensors[0].copy_to_cpu()
cls_result = self.postprocess_op(prob_out) cls_result = self.postprocess_op(prob_out)
elapse += time.time() - starttime elapse += time.time() - starttime
...@@ -143,10 +138,11 @@ def main(args): ...@@ -143,10 +138,11 @@ def main(args):
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
exit() exit()
for ino in range(len(img_list)): for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[ logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
ino])) cls_res[ino]))
logger.info("Total predict time for {} images, cost: {:.3f}".format( logger.info("Total predict time for {} images, cost: {:.3f}".format(
len(img_list), predict_time)) len(img_list), predict_time))
if __name__ == "__main__": if __name__ == "__main__":
main(utility.parse_args()) main(utility.parse_args())
...@@ -37,7 +37,6 @@ class TextDetector(object): ...@@ -37,7 +37,6 @@ class TextDetector(object):
def __init__(self, args): def __init__(self, args):
self.args = args self.args = args
self.det_algorithm = args.det_algorithm self.det_algorithm = args.det_algorithm
self.use_zero_copy_run = args.use_zero_copy_run
pre_process_list = [{ pre_process_list = [{
'DetResizeForTest': { 'DetResizeForTest': {
'limit_side_len': args.det_limit_side_len, 'limit_side_len': args.det_limit_side_len,
...@@ -72,7 +71,9 @@ class TextDetector(object): ...@@ -72,7 +71,9 @@ class TextDetector(object):
postprocess_params["nms_thresh"] = args.det_east_nms_thresh postprocess_params["nms_thresh"] = args.det_east_nms_thresh
elif self.det_algorithm == "SAST": elif self.det_algorithm == "SAST":
pre_process_list[0] = { pre_process_list[0] = {
'DetResizeForTest': {'resize_long': args.det_limit_side_len} 'DetResizeForTest': {
'resize_long': args.det_limit_side_len
}
} }
postprocess_params['name'] = 'SASTPostProcess' postprocess_params['name'] = 'SASTPostProcess'
postprocess_params["score_thresh"] = args.det_sast_score_thresh postprocess_params["score_thresh"] = args.det_sast_score_thresh
...@@ -161,12 +162,8 @@ class TextDetector(object): ...@@ -161,12 +162,8 @@ class TextDetector(object):
img = img.copy() img = img.copy()
starttime = time.time() starttime = time.time()
if self.use_zero_copy_run:
self.input_tensor.copy_from_cpu(img) self.input_tensor.copy_from_cpu(img)
self.predictor.zero_copy_run() self.predictor.run()
else:
im = paddle.fluid.core.PaddleTensor(img)
self.predictor.run([im])
outputs = [] outputs = []
for output_tensor in self.output_tensors: for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
......
...@@ -39,7 +39,6 @@ class TextRecognizer(object): ...@@ -39,7 +39,6 @@ class TextRecognizer(object):
self.character_type = args.rec_char_type self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm self.rec_algorithm = args.rec_algorithm
self.use_zero_copy_run = args.use_zero_copy_run
postprocess_params = { postprocess_params = {
'name': 'CTCLabelDecode', 'name': 'CTCLabelDecode',
"character_type": args.rec_char_type, "character_type": args.rec_char_type,
...@@ -101,12 +100,8 @@ class TextRecognizer(object): ...@@ -101,12 +100,8 @@ class TextRecognizer(object):
norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy() norm_img_batch = norm_img_batch.copy()
starttime = time.time() starttime = time.time()
if self.use_zero_copy_run:
self.input_tensor.copy_from_cpu(norm_img_batch) self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run() self.predictor.run()
else:
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
self.predictor.run([norm_img_batch])
outputs = [] outputs = []
for output_tensor in self.output_tensors: for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
...@@ -145,8 +140,8 @@ def main(args): ...@@ -145,8 +140,8 @@ def main(args):
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
exit() exit()
for ino in range(len(img_list)): for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[ logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
ino])) rec_res[ino]))
logger.info("Total predict time for {} images, cost: {:.3f}".format( logger.info("Total predict time for {} images, cost: {:.3f}".format(
len(img_list), predict_time)) len(img_list), predict_time))
......
...@@ -20,8 +20,7 @@ import numpy as np ...@@ -20,8 +20,7 @@ import numpy as np
import json import json
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
import math import math
from paddle.fluid.core import AnalysisConfig from paddle import inference
from paddle.fluid.core import create_paddle_predictor
def parse_args(): def parse_args():
...@@ -83,8 +82,6 @@ def parse_args(): ...@@ -83,8 +82,6 @@ def parse_args():
parser.add_argument("--cls_thresh", type=float, default=0.9) parser.add_argument("--cls_thresh", type=float, default=0.9)
parser.add_argument("--enable_mkldnn", type=str2bool, default=False) parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
parser.add_argument("--use_zero_copy_run", type=str2bool, default=False)
parser.add_argument("--use_pdserving", type=str2bool, default=False) parser.add_argument("--use_pdserving", type=str2bool, default=False)
return parser.parse_args() return parser.parse_args()
...@@ -110,14 +107,14 @@ def create_predictor(args, mode, logger): ...@@ -110,14 +107,14 @@ def create_predictor(args, mode, logger):
logger.info("not find params file path {}".format(params_file_path)) logger.info("not find params file path {}".format(params_file_path))
sys.exit(0) sys.exit(0)
config = AnalysisConfig(model_file_path, params_file_path) config = inference.Config(model_file_path, params_file_path)
if args.use_gpu: if args.use_gpu:
config.enable_use_gpu(args.gpu_mem, 0) config.enable_use_gpu(args.gpu_mem, 0)
if args.use_tensorrt: if args.use_tensorrt:
config.enable_tensorrt_engine( config.enable_tensorrt_engine(
precision_mode=AnalysisConfig.Precision.Half precision_mode=inference.PrecisionType.Half
if args.use_fp16 else AnalysisConfig.Precision.Float32, if args.use_fp16 else inference.PrecisionType.Float32,
max_batch_size=args.max_batch_size) max_batch_size=args.max_batch_size)
else: else:
config.disable_gpu() config.disable_gpu()
...@@ -130,20 +127,18 @@ def create_predictor(args, mode, logger): ...@@ -130,20 +127,18 @@ def create_predictor(args, mode, logger):
# config.enable_memory_optim() # config.enable_memory_optim()
config.disable_glog_info() config.disable_glog_info()
if args.use_zero_copy_run:
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False) config.switch_use_feed_fetch_ops(False)
else:
config.switch_use_feed_fetch_ops(True)
predictor = create_paddle_predictor(config) # create predictor
predictor = inference.create_predictor(config)
input_names = predictor.get_input_names() input_names = predictor.get_input_names()
for name in input_names: for name in input_names:
input_tensor = predictor.get_input_tensor(name) input_tensor = predictor.get_input_handle(name)
output_names = predictor.get_output_names() output_names = predictor.get_output_names()
output_tensors = [] output_tensors = []
for output_name in output_names: for output_name in output_names:
output_tensor = predictor.get_output_tensor(output_name) output_tensor = predictor.get_output_handle(output_name)
output_tensors.append(output_tensor) output_tensors.append(output_tensor)
return predictor, input_tensor, output_tensors return predictor, input_tensor, output_tensors
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册