未验证 提交 c8094e65 编写于 作者: T topduke 提交者: GitHub

Update rec_nrtr_optim_head.py

上级 c6359258
......@@ -216,7 +216,7 @@ class TransformerOptim(nn.Layer):
new_shape = (n_curr_active_inst * n_bm, *d_hs)
beamed_tensor = beamed_tensor.reshape(
[n_prev_active_inst, -1]) #contiguous()
[n_prev_active_inst, -1])
beamed_tensor = beamed_tensor.index_select(
paddle.to_tensor(curr_active_inst_idx), axis=0)
beamed_tensor = beamed_tensor.reshape([*new_shape])
......@@ -337,7 +337,7 @@ class TransformerOptim(nn.Layer):
n_inst, len_s, d_h = src_enc.shape
src_enc = paddle.concat([src_enc for i in range(n_bm)], axis=1)
src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose(
[1, 0, 2]) #repeat(1, n_bm, 1)
[1, 0, 2])
#-- Prepare beams
inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册