diff --git a/ppocr/data/imaug/randaugment.py b/ppocr/data/imaug/randaugment.py index 0bfac353906535464eaa6637c3edbc7f0c938502..56f114d2f665f9b326e96819ac3d606c87a6e142 100644 --- a/ppocr/data/imaug/randaugment.py +++ b/ppocr/data/imaug/randaugment.py @@ -117,13 +117,16 @@ class RawRandAugment(object): class RandAugment(RawRandAugment): """ RandAugment wrapper to auto fit different img types """ - def __init__(self, *args, **kwargs): + def __init__(self, prob=0.5, *args, **kwargs): + self.prob = prob if six.PY2: super(RandAugment, self).__init__(*args, **kwargs) else: super().__init__(*args, **kwargs) def __call__(self, data): + if np.random.rand() > self.prob: + return data img = data['image'] if not isinstance(img, Image.Image): img = np.ascontiguousarray(img) diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py index 074172cc947cdc03b21392cf7b109971763f796a..d2592c6c95b0f466ea3ad5b45a35781282c9a492 100755 --- a/tools/infer/predict_cls.py +++ b/tools/infer/predict_cls.py @@ -98,10 +98,10 @@ class TextClassifier(object): norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = norm_img_batch.copy() starttime = time.time() - self.input_tensor.copy_from_cpu(norm_img_batch) self.predictor.run() prob_out = self.output_tensors[0].copy_to_cpu() + self.predictor.try_shrink_memory() cls_result = self.postprocess_op(prob_out) elapse += time.time() - starttime for rno in range(len(cls_result)): diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index b14825bdd8bad55b709d84bdf6df6575d90c7d95..f5ea0504f97f3e40853d431061f7086653f2628e 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -180,7 +180,7 @@ class TextDetector(object): preds['maps'] = outputs[0] else: raise NotImplementedError - + self.predictor.try_shrink_memory() post_result = self.postprocess_op(preds, shape_list) dt_boxes = post_result[0]['points'] if self.det_algorithm == "SAST" and self.det_sast_polygon: diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index b24e57dd973bc0216f2875232bcec6e36ab47e29..1cb6e01b087ff98efb0a57be3cc58a79425fea57 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -237,7 +237,7 @@ class TextRecognizer(object): output = output_tensor.copy_to_cpu() outputs.append(output) preds = outputs[0] - + self.predictor.try_shrink_memory() rec_result = self.postprocess_op(preds) for rno in range(len(rec_result)): rec_res[indices[beg_img_no + rno]] = rec_result[rno] diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 9aa0afed635481859cd31d461a97c451ca72acdc..7391e9365694361c20395b4a63a101bb093f5f94 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -145,7 +145,8 @@ def create_predictor(args, mode, logger): #config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'}) args.rec_batch_num = 1 - # config.enable_memory_optim() + # enable memory optim + config.enable_memory_optim() config.disable_glog_info() config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")