Unverified · Commit ddbba2b1 authored by zhoujun, committed by GitHub

Merge pull request #2228 from WenmuZhou/fix_attn_export

fix attn export
...
@@ -38,7 +38,7 @@ class AttentionHead(nn.Layer):
         return input_ont_hot

     def forward(self, inputs, targets=None, batch_max_length=25):
-        batch_size = inputs.shape[0]
+        batch_size = paddle.shape(inputs)[0]
         num_steps = batch_max_length
         hidden = paddle.zeros((batch_size, self.hidden_size))
...
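The one-line change above matters for export: at graph-capture time, `inputs.shape[0]` is read as a plain Python int and gets baked into the exported graph as a constant, while `paddle.shape(inputs)[0]` is a shape op that is re-evaluated on each real input, keeping the batch dimension dynamic. A minimal sketch of the failure mode, using a toy tracer instead of Paddle (all function names here are hypothetical, for illustration only):

```python
# Hypothetical sketch: why baking a Python int at trace time breaks
# variable batch sizes, while a runtime shape op does not.

def trace_with_python_int(example_inputs):
    # Mimics reading `inputs.shape[0]` during export: the value is
    # frozen into the "graph" as a constant.
    frozen_batch = len(example_inputs)

    def run(inputs):
        return frozen_batch  # always returns the traced constant
    return run


def trace_with_shape_op():
    # Mimics `paddle.shape(inputs)[0]`: the shape is computed from
    # the actual input at run time.
    def run(inputs):
        return len(inputs)
    return run


# "Export" both graphs with a dummy batch of size 1.
frozen_graph = trace_with_python_int([[0.0]])
dynamic_graph = trace_with_shape_op()

# Feed a real batch of 3 samples at inference time.
batch = [[0.0], [1.0], [2.0]]
print(frozen_graph(batch))   # 1 — stale value captured at trace time
print(dynamic_graph(batch))  # 3 — true runtime batch size
```

This is why the attention head failed to export correctly with `inputs.shape[0]`: the decoder allocates its hidden state from the batch size, so a frozen batch dimension breaks inference on any other batch size.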