Commit 550022ea authored by LDOUBLEV

Fix AttnLabelEncode length check: reject texts with len(text) >= max_text_len (was >), leaving room for the SOS/EOS tokens added during encoding; also remove the leftover __main__ smoke-test code from the attention head module.

Parent e7d24ac8
......@@ -211,7 +211,7 @@ class AttnLabelEncode(BaseRecLabelEncode):
text = self.encode(text)
if text is None:
return None
if len(text) > self.max_text_len:
if len(text) >= self.max_text_len:
return None
data['length'] = np.array(len(text))
text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
......
......@@ -194,18 +194,3 @@ class AttentionLSTMCell(nn.Layer):
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
if __name__ == '__main__':
    # Standalone smoke test for the attention head: feed random features and
    # labels through the model and report the output shape.
    # NOTE(review): relies on `paddle`, `np`, and `Attention` already being
    # in module scope — confirm the file-level imports when running this.
    paddle.disable_static()

    attention_head = Attention(100, 200, 10)

    # Fake batch: 2 samples, 10 time steps, 100 feature channels; label
    # sequences of length 21 drawn from 10 classes.
    feats = np.random.uniform(-1, 1, [2, 10, 100]).astype(np.float32)
    labels = np.random.randint(0, 10, [2, 21]).astype(np.int32)
    feats_tensor = paddle.to_tensor(feats)
    labels_tensor = paddle.to_tensor(labels)

    output = attention_head(
        inputs=feats_tensor,
        targets=labels_tensor,
        is_train=True,
        batch_max_length=20)
    print("res: ", output.shape)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册