提交 7a50b68f 编写于 作者: Y Youwei Song 提交者: hong

fix input shape for new Embedding (#4037)

test=develop
上级 182d5092
...@@ -4,3 +4,4 @@ ...@@ -4,3 +4,4 @@
*.pyc *.pyc
*~ *~
*.vscode *.vscode
*.idea
\ No newline at end of file
...@@ -54,12 +54,12 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head): ...@@ -54,12 +54,12 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
""" """
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data( src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False) [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
src_word = src_word.reshape(-1, src_max_len, 1) src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len, 1) src_pos = src_pos.reshape(-1, src_max_len)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data( trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True) [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_word = trg_word.reshape(-1, trg_max_len, 1) trg_word = trg_word.reshape(-1, trg_max_len)
trg_pos = trg_pos.reshape(-1, trg_max_len, 1) trg_pos = trg_pos.reshape(-1, trg_max_len)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :], trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32") [1, 1, trg_max_len, 1]).astype("float32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册