提交 c459b725 编写于 作者: z37757's avatar z37757

添加RFL CNT分支infer支持

上级 3f8602c1
......@@ -103,7 +103,6 @@ class RFLHead(nn.Layer):
else:
seq_outputs = self.seq_head(seq_inputs, None,
self.batch_max_legnth)
return cnt_outputs, seq_outputs
else:
seq_outputs = None
return cnt_outputs, seq_outputs
return cnt_outputs
......@@ -287,9 +287,8 @@ class RFLLabelDecode(BaseRecLabelDecode):
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
cnt_pred, preds = preds
if preds is not None:
if len(preds) == 2:
cnt_pred, preds = preds
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
......@@ -302,9 +301,12 @@ class RFLLabelDecode(BaseRecLabelDecode):
return text, label
else:
cnt_pred = preds
if isinstance(cnt_pred, paddle.Tensor):
cnt_pred = cnt_pred.numpy()
cnt_length = []
for lens in cnt_pred:
length = round(paddle.sum(lens).item())
length = round(np.sum(lens))
cnt_length.append(length)
if label is None:
return cnt_length
......
......@@ -97,7 +97,8 @@ def main():
elif config['Architecture']['algorithm'] == "SAR":
op[op_name]['keep_keys'] = ['image', 'valid_ratio']
elif config['Architecture']['algorithm'] == "RobustScanner":
op[op_name]['keep_keys'] = ['image', 'valid_ratio', 'word_positons']
op[op_name][
'keep_keys'] = ['image', 'valid_ratio', 'word_positons']
else:
op[op_name]['keep_keys'] = ['image']
transforms.append(op)
......@@ -136,9 +137,10 @@ def main():
if config['Architecture']['algorithm'] == "RobustScanner":
valid_ratio = np.expand_dims(batch[1], axis=0)
word_positons = np.expand_dims(batch[2], axis=0)
img_metas = [paddle.to_tensor(valid_ratio),
paddle.to_tensor(word_positons),
]
img_metas = [
paddle.to_tensor(valid_ratio),
paddle.to_tensor(word_positons),
]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN":
......@@ -160,6 +162,10 @@ def main():
"score": float(post_result[key][0][1]),
}
info = json.dumps(rec_info, ensure_ascii=False)
elif isinstance(post_result, list) and isinstance(post_result[0],
int):
# for RFLearning CNT branch
info = str(post_result[0])
else:
if len(post_result[0]) >= 2:
info = post_result[0][0] + "\t" + str(post_result[0][1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册