From c459b7256538b9e788c81c540b615f4fe7911b81 Mon Sep 17 00:00:00 2001 From: zhiminzhang0830 <452516515@qq.com> Date: Sat, 8 Oct 2022 11:20:36 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0RFL=20CNT=E5=88=86=E6=94=AFin?= =?UTF-8?q?fer=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppocr/modeling/heads/rec_rfl_head.py | 5 ++--- ppocr/postprocess/rec_postprocess.py | 10 ++++++---- tools/infer_rec.py | 14 ++++++++++---- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/ppocr/modeling/heads/rec_rfl_head.py b/ppocr/modeling/heads/rec_rfl_head.py index b5452ec1..1ded8cde 100644 --- a/ppocr/modeling/heads/rec_rfl_head.py +++ b/ppocr/modeling/heads/rec_rfl_head.py @@ -103,7 +103,6 @@ class RFLHead(nn.Layer): else: seq_outputs = self.seq_head(seq_inputs, None, self.batch_max_legnth) + return cnt_outputs, seq_outputs else: - seq_outputs = None - - return cnt_outputs, seq_outputs + return cnt_outputs diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index e754c950..40ba5c20 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -287,9 +287,8 @@ class RFLLabelDecode(BaseRecLabelDecode): return result_list def __call__(self, preds, label=None, *args, **kwargs): - cnt_pred, preds = preds - if preds is not None: - + if len(preds) == 2: + cnt_pred, preds = preds if isinstance(preds, paddle.Tensor): preds = preds.numpy() preds_idx = preds.argmax(axis=2) @@ -302,9 +301,12 @@ class RFLLabelDecode(BaseRecLabelDecode): return text, label else: + cnt_pred = preds + if isinstance(cnt_pred, paddle.Tensor): + cnt_pred = cnt_pred.numpy() cnt_length = [] for lens in cnt_pred: - length = round(paddle.sum(lens).item()) + length = round(np.sum(lens)) cnt_length.append(length) if label is None: return cnt_length diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 14b14544..cb8a6ec3 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -97,7 +97,8 @@ def main(): elif config['Architecture']['algorithm'] == "SAR": op[op_name]['keep_keys'] = ['image', 'valid_ratio'] elif config['Architecture']['algorithm'] == "RobustScanner": - op[op_name]['keep_keys'] = ['image', 'valid_ratio', 'word_positons'] + op[op_name][ + 'keep_keys'] = ['image', 'valid_ratio', 'word_positons'] else: op[op_name]['keep_keys'] = ['image'] transforms.append(op) @@ -136,9 +137,10 @@ def main(): if config['Architecture']['algorithm'] == "RobustScanner": valid_ratio = np.expand_dims(batch[1], axis=0) word_positons = np.expand_dims(batch[2], axis=0) - img_metas = [paddle.to_tensor(valid_ratio), - paddle.to_tensor(word_positons), - ] + img_metas = [ + paddle.to_tensor(valid_ratio), + paddle.to_tensor(word_positons), + ] images = np.expand_dims(batch[0], axis=0) images = paddle.to_tensor(images) if config['Architecture']['algorithm'] == "SRN": @@ -160,6 +162,10 @@ def main(): "score": float(post_result[key][0][1]), } info = json.dumps(rec_info, ensure_ascii=False) + elif isinstance(post_result, list) and isinstance(post_result[0], + int): + # for RFLearning CNT branch + info = str(post_result[0]) else: if len(post_result[0]) >= 2: info = post_result[0][0] + "\t" + str(post_result[0][1]) -- GitLab