提交 c459b725 作者: z37757

添加RFL CNT分支infer支持

上级 3f8602c1
...@@ -103,7 +103,6 @@ class RFLHead(nn.Layer): ...@@ -103,7 +103,6 @@ class RFLHead(nn.Layer):
else: else:
seq_outputs = self.seq_head(seq_inputs, None, seq_outputs = self.seq_head(seq_inputs, None,
self.batch_max_legnth) self.batch_max_legnth)
return cnt_outputs, seq_outputs
else: else:
seq_outputs = None return cnt_outputs
return cnt_outputs, seq_outputs
...@@ -287,9 +287,8 @@ class RFLLabelDecode(BaseRecLabelDecode): ...@@ -287,9 +287,8 @@ class RFLLabelDecode(BaseRecLabelDecode):
return result_list return result_list
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
cnt_pred, preds = preds if len(preds) == 2:
if preds is not None: cnt_pred, preds = preds
if isinstance(preds, paddle.Tensor): if isinstance(preds, paddle.Tensor):
preds = preds.numpy() preds = preds.numpy()
preds_idx = preds.argmax(axis=2) preds_idx = preds.argmax(axis=2)
...@@ -302,9 +301,12 @@ class RFLLabelDecode(BaseRecLabelDecode): ...@@ -302,9 +301,12 @@ class RFLLabelDecode(BaseRecLabelDecode):
return text, label return text, label
else: else:
cnt_pred = preds
if isinstance(cnt_pred, paddle.Tensor):
cnt_pred = cnt_pred.numpy()
cnt_length = [] cnt_length = []
for lens in cnt_pred: for lens in cnt_pred:
length = round(paddle.sum(lens).item()) length = round(np.sum(lens))
cnt_length.append(length) cnt_length.append(length)
if label is None: if label is None:
return cnt_length return cnt_length
......
...@@ -97,7 +97,8 @@ def main(): ...@@ -97,7 +97,8 @@ def main():
elif config['Architecture']['algorithm'] == "SAR": elif config['Architecture']['algorithm'] == "SAR":
op[op_name]['keep_keys'] = ['image', 'valid_ratio'] op[op_name]['keep_keys'] = ['image', 'valid_ratio']
elif config['Architecture']['algorithm'] == "RobustScanner": elif config['Architecture']['algorithm'] == "RobustScanner":
op[op_name]['keep_keys'] = ['image', 'valid_ratio', 'word_positons'] op[op_name][
'keep_keys'] = ['image', 'valid_ratio', 'word_positons']
else: else:
op[op_name]['keep_keys'] = ['image'] op[op_name]['keep_keys'] = ['image']
transforms.append(op) transforms.append(op)
...@@ -136,9 +137,10 @@ def main(): ...@@ -136,9 +137,10 @@ def main():
if config['Architecture']['algorithm'] == "RobustScanner": if config['Architecture']['algorithm'] == "RobustScanner":
valid_ratio = np.expand_dims(batch[1], axis=0) valid_ratio = np.expand_dims(batch[1], axis=0)
word_positons = np.expand_dims(batch[2], axis=0) word_positons = np.expand_dims(batch[2], axis=0)
img_metas = [paddle.to_tensor(valid_ratio), img_metas = [
paddle.to_tensor(word_positons), paddle.to_tensor(valid_ratio),
] paddle.to_tensor(word_positons),
]
images = np.expand_dims(batch[0], axis=0) images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images) images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN": if config['Architecture']['algorithm'] == "SRN":
...@@ -160,6 +162,10 @@ def main(): ...@@ -160,6 +162,10 @@ def main():
"score": float(post_result[key][0][1]), "score": float(post_result[key][0][1]),
} }
info = json.dumps(rec_info, ensure_ascii=False) info = json.dumps(rec_info, ensure_ascii=False)
elif isinstance(post_result, list) and isinstance(post_result[0],
int):
# for RFLearning CNT branch
info = str(post_result[0])
else: else:
if len(post_result[0]) >= 2: if len(post_result[0]) >= 2:
info = post_result[0][0] + "\t" + str(post_result[0][1]) info = post_result[0][0] + "\t" + str(post_result[0][1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册