From a094d2775560a6dbb6e18cd761b99edd238956c2 Mon Sep 17 00:00:00 2001
From: LDOUBLEV
Date: Mon, 1 Feb 2021 08:08:18 +0000
Subject: [PATCH] opt rec_att_head

---
 ppocr/modeling/heads/rec_att_head.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/ppocr/modeling/heads/rec_att_head.py b/ppocr/modeling/heads/rec_att_head.py
index bfe37e7a..a7cfe128 100644
--- a/ppocr/modeling/heads/rec_att_head.py
+++ b/ppocr/modeling/heads/rec_att_head.py
@@ -64,8 +64,10 @@ class AttentionHead(nn.Layer):
                 (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
                                                                char_onehots)
                 probs_step = self.generator(outputs)
-                probs = paddle.unsqueeze(
-                    probs_step, axis=1) if probs is None else paddle.concat(
+                if probs is None:
+                    probs = paddle.unsqueeze(probs_step, axis=1)
+                else:
+                    probs = paddle.concat(
                         [probs, paddle.unsqueeze(
                             probs_step, axis=1)], axis=1)
                 next_input = probs_step.argmax(axis=1)
@@ -152,8 +154,10 @@ class AttentionLSTM(nn.Layer):
                                                        char_onehots)
                 probs_step = self.generator(hidden[0])
                 hidden = (hidden[1][0], hidden[1][1])
-                probs = paddle.unsqueeze(
-                    probs_step, axis=1) if probs is None else paddle.concat(
+                if probs is None:
+                    probs = paddle.unsqueeze(probs_step, axis=1)
+                else:
+                    probs = paddle.concat(
                         [probs, paddle.unsqueeze(
                             probs_step, axis=1)], axis=1)
--
GitLab
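
Note: the patch is a pure readability refactor. The wrapped ternary that
accumulated `probs` is expanded into an explicit if/else in both
AttentionHead and AttentionLSTM; the paddle.concat continuation lines are
unchanged, which is why they remain as context in both hunks. The sketch
below reproduces the accumulation pattern outside the model to show what
the branch computes. It is a minimal illustration only: the batch size,
num_steps, num_classes, and the random stand-in for self.generator(...)
are assumptions, not taken from the patch.

    # Stack per-step class scores of shape [batch, num_classes] along a
    # new time axis so that after num_steps iterations `probs` has shape
    # [batch, num_steps, num_classes].
    import paddle

    batch, num_steps, num_classes = 2, 4, 10  # illustrative sizes only
    probs = None
    for i in range(num_steps):
        # Stand-in for probs_step = self.generator(outputs) in the model.
        probs_step = paddle.rand([batch, num_classes])
        if probs is None:
            # First step: unsqueeze creates the time axis.
            probs = paddle.unsqueeze(probs_step, axis=1)
        else:
            # Later steps: append along the time axis.
            probs = paddle.concat(
                [probs, paddle.unsqueeze(probs_step, axis=1)], axis=1)

    assert probs.shape == [batch, num_steps, num_classes]

Either form computes the same tensor; the explicit branch simply reads
more clearly than a ternary split across two continuation lines.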