未验证 提交 93c919e6 编写于 作者: D Double_V 提交者: GitHub

Merge branch 'dygraph' into pgnet-v1

...@@ -117,13 +117,16 @@ class RawRandAugment(object): ...@@ -117,13 +117,16 @@ class RawRandAugment(object):
class RandAugment(RawRandAugment): class RandAugment(RawRandAugment):
""" RandAugment wrapper to auto fit different img types """ """ RandAugment wrapper to auto fit different img types """
def __init__(self, *args, **kwargs): def __init__(self, prob=0.5, *args, **kwargs):
self.prob = prob
if six.PY2: if six.PY2:
super(RandAugment, self).__init__(*args, **kwargs) super(RandAugment, self).__init__(*args, **kwargs)
else: else:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def __call__(self, data): def __call__(self, data):
if np.random.rand() > self.prob:
return data
img = data['image'] img = data['image']
if not isinstance(img, Image.Image): if not isinstance(img, Image.Image):
img = np.ascontiguousarray(img) img = np.ascontiguousarray(img)
......
...@@ -98,10 +98,10 @@ class TextClassifier(object): ...@@ -98,10 +98,10 @@ class TextClassifier(object):
norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy() norm_img_batch = norm_img_batch.copy()
starttime = time.time() starttime = time.time()
self.input_tensor.copy_from_cpu(norm_img_batch) self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run() self.predictor.run()
prob_out = self.output_tensors[0].copy_to_cpu() prob_out = self.output_tensors[0].copy_to_cpu()
self.predictor.try_shrink_memory()
cls_result = self.postprocess_op(prob_out) cls_result = self.postprocess_op(prob_out)
elapse += time.time() - starttime elapse += time.time() - starttime
for rno in range(len(cls_result)): for rno in range(len(cls_result)):
......
...@@ -180,7 +180,7 @@ class TextDetector(object): ...@@ -180,7 +180,7 @@ class TextDetector(object):
preds['maps'] = outputs[0] preds['maps'] = outputs[0]
else: else:
raise NotImplementedError raise NotImplementedError
self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list) post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points'] dt_boxes = post_result[0]['points']
if self.det_algorithm == "SAST" and self.det_sast_polygon: if self.det_algorithm == "SAST" and self.det_sast_polygon:
......
...@@ -237,7 +237,7 @@ class TextRecognizer(object): ...@@ -237,7 +237,7 @@ class TextRecognizer(object):
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
outputs.append(output) outputs.append(output)
preds = outputs[0] preds = outputs[0]
self.predictor.try_shrink_memory()
rec_result = self.postprocess_op(preds) rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)): for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno] rec_res[indices[beg_img_no + rno]] = rec_result[rno]
......
...@@ -145,7 +145,8 @@ def create_predictor(args, mode, logger): ...@@ -145,7 +145,8 @@ def create_predictor(args, mode, logger):
#config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'}) #config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'})
args.rec_batch_num = 1 args.rec_batch_num = 1
# config.enable_memory_optim() # enable memory optim
config.enable_memory_optim()
config.disable_glog_info() config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册