From a7322245cd778b246c2a0b535f1d1dd973b5c199 Mon Sep 17 00:00:00 2001 From: zhoujun Date: Thu, 18 Mar 2021 15:26:34 +0800 Subject: [PATCH] =?UTF-8?q?Revert=20"=E4=BF=AE=E5=A4=8D=E6=8E=A8=E7=90=86?= =?UTF-8?q?=E8=BF=87=E7=A8=8B=E4=B8=AD=E7=9A=84=E5=86=85=E5=AD=98=E6=B3=84?= =?UTF-8?q?=E9=9C=B2=E9=97=AE=E9=A2=98"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/infer/predict_cls.py | 2 +- tools/infer/predict_det.py | 2 +- tools/infer/predict_rec.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py index d2592c6c..074172cc 100755 --- a/tools/infer/predict_cls.py +++ b/tools/infer/predict_cls.py @@ -98,10 +98,10 @@ class TextClassifier(object): norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = norm_img_batch.copy() starttime = time.time() + self.input_tensor.copy_from_cpu(norm_img_batch) self.predictor.run() prob_out = self.output_tensors[0].copy_to_cpu() - self.predictor.try_shrink_memory() cls_result = self.postprocess_op(prob_out) elapse += time.time() - starttime for rno in range(len(cls_result)): diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index f5ea0504..b14825bd 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -180,7 +180,7 @@ class TextDetector(object): preds['maps'] = outputs[0] else: raise NotImplementedError - self.predictor.try_shrink_memory() + post_result = self.postprocess_op(preds, shape_list) dt_boxes = post_result[0]['points'] if self.det_algorithm == "SAST" and self.det_sast_polygon: diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 1cb6e01b..b24e57dd 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -237,7 +237,7 @@ class TextRecognizer(object): output = output_tensor.copy_to_cpu() outputs.append(output) preds = outputs[0] - self.predictor.try_shrink_memory() + rec_result = self.postprocess_op(preds) for rno in range(len(rec_result)): rec_res[indices[beg_img_no + rno]] = rec_result[rno] -- GitLab