提交 483e5038 编写于 作者: z37757's avatar z37757

通过变量类型判断是否是visual

上级 c25eec88
......@@ -36,22 +36,26 @@ class RFLLoss(nn.Layer):
self.total_loss = {}
total_loss = 0.0
if isinstance(predicts, tuple) or isinstance(predicts, list):
cnt_outputs, seq_outputs = predicts
else:
cnt_outputs, seq_outputs = predicts, None
# batch [image, label, length, cnt_label]
if predicts[0] is not None:
cnt_loss = self.cnt_loss(predicts[0],
if cnt_outputs is not None:
cnt_loss = self.cnt_loss(cnt_outputs,
paddle.cast(batch[3], paddle.float32))
self.total_loss['cnt_loss'] = cnt_loss
total_loss += cnt_loss
if predicts[1] is not None:
if seq_outputs is not None:
targets = batch[1].astype("int64")
label_lengths = batch[2].astype('int64')
batch_size, num_steps, num_classes = predicts[1].shape[0], predicts[
1].shape[1], predicts[1].shape[2]
assert len(targets.shape) == len(list(predicts[1].shape)) - 1, \
batch_size, num_steps, num_classes = seq_outputs.shape[
0], seq_outputs.shape[1], seq_outputs.shape[2]
assert len(targets.shape) == len(list(seq_outputs.shape)) - 1, \
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"
inputs = predicts[1][:, :-1, :]
inputs = seq_outputs[:, :-1, :]
targets = targets[:, 1:]
inputs = paddle.reshape(inputs, [-1, inputs.shape[-1]])
......
......@@ -287,12 +287,13 @@ class RFLLabelDecode(BaseRecLabelDecode):
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
if len(preds) == 2:
cnt_pred, preds = preds
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
# if seq_outputs is not None:
if isinstance(preds, tuple) or isinstance(preds, list):
cnt_outputs, seq_outputs = preds
if isinstance(seq_outputs, paddle.Tensor):
seq_outputs = seq_outputs.numpy()
preds_idx = seq_outputs.argmax(axis=2)
preds_prob = seq_outputs.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
......@@ -301,11 +302,11 @@ class RFLLabelDecode(BaseRecLabelDecode):
return text, label
else:
cnt_pred = preds
if isinstance(cnt_pred, paddle.Tensor):
cnt_pred = cnt_pred.numpy()
cnt_outputs = preds
if isinstance(cnt_outputs, paddle.Tensor):
cnt_outputs = cnt_outputs.numpy()
cnt_length = []
for lens in cnt_pred:
for lens in cnt_outputs:
length = round(np.sum(lens))
cnt_length.append(length)
if label is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册