提交 0a011e56 编写于 作者: littletomatodonkey's avatar littletomatodonkey

add python interface

上级 3ebebae3
......@@ -42,6 +42,7 @@ class TextDetector(object):
def __init__(self, args):
max_side_len = args.det_max_side_len
self.det_algorithm = args.det_algorithm
self.use_zero_copy_run = args.use_zero_copy_run
preprocess_params = {'max_side_len': max_side_len}
postprocess_params = {}
if self.det_algorithm == "DB":
......@@ -138,8 +139,12 @@ class TextDetector(object):
return None, 0
im = im.copy()
starttime = time.time()
im = fluid.core.PaddleTensor(im)
self.predictor.run([im])
if self.use_zero_copy_run:
self.input_tensor.copy_from_cpu(im)
self.predictor.zero_copy_run()
else:
im = fluid.core.PaddleTensor(im)
self.predictor.run([im])
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
......
......@@ -40,6 +40,7 @@ class TextRecognizer(object):
self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm
self.use_zero_copy_run = args.use_zero_copy_run
char_ops_params = {
"character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
......@@ -105,8 +106,12 @@ class TextRecognizer(object):
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
starttime = time.time()
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
self.predictor.run([norm_img_batch])
if self.use_zero_copy_run:
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run()
else:
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
self.predictor.run([norm_img_batch])
if self.loss_type == "ctc":
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
......
......@@ -71,6 +71,7 @@ def parse_args():
default="./ppocr/utils/ppocr_keys_v1.txt")
parser.add_argument("--use_space_char", type=bool, default=True)
parser.add_argument("--enable_mkldnn", type=bool, default=False)
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
return parser.parse_args()
......@@ -105,8 +106,12 @@ def create_predictor(args, mode):
#config.enable_memory_optim()
config.disable_glog_info()
# use zero copy
config.switch_use_feed_fetch_ops(True)
if args.use_zero_copy_run:
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False)
else:
config.switch_use_feed_fetch_ops(True)
predictor = create_paddle_predictor(config)
input_names = predictor.get_input_names()
input_tensor = predictor.get_input_tensor(input_names[0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册