Unverified commit e2341b21, authored by liu zhengxi, committed by GitHub

Transfer dynamic_decode and BeamSearchDecoder to v2 op (#35656)

Parent 91cf918f
@@ -958,7 +958,7 @@ class BeamSearchDecoder(Decoder):
         x = nn.unsqueeze(x, [1])  # [batch_size, 1, ...]
         expand_times = [1] * len(x.shape)
         expand_times[1] = beam_size
-        x = nn.expand(x, expand_times)  # [batch_size, beam_size, ...]
+        x = paddle.tile(x, expand_times)  # [batch_size, beam_size, ...]
         x = nn.transpose(x, list(range(2, len(x.shape))) +
                          [0, 1])  # [..., batch_size, beam_size]
         # use 0 to copy to avoid wrong shape
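The hunks in this patch follow one pattern: the legacy `fluid.layers.expand(x, expand_times)` call is swapped for the v2 API `paddle.tile(x, repeat_times)`, which repeats the tensor along each axis by the given factor. A minimal sketch of the equivalence, with an assumed batch size, beam size, and hidden size (none of these values come from the patch):

```python
import paddle

# Hypothetical sizes, chosen only for illustration.
batch_size, beam_size, hidden = 4, 3, 8

x = paddle.ones([batch_size, 1, hidden])  # after unsqueeze: [batch_size, 1, hidden]
repeat_times = [1] * len(x.shape)
repeat_times[1] = beam_size               # repeat only along the beam axis
y = paddle.tile(x, repeat_times)          # [batch_size, beam_size, hidden]
print(y.shape)                            # [4, 3, 8]
```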
@@ -1024,7 +1024,7 @@ class BeamSearchDecoder(Decoder):
         x = nn.unsqueeze(x, [1])
         expand_times = [1] * len(x.shape)
         expand_times[1] = self.beam_size
-        x = nn.expand(x, expand_times)
+        x = paddle.tile(x, expand_times)
         return x

     def _mask_probs(self, probs, finished):
@@ -1050,7 +1050,7 @@ class BeamSearchDecoder(Decoder):
         # TODO: use where_op
         finished = tensor.cast(finished, dtype=probs.dtype)
         probs = nn.elementwise_mul(
-            nn.expand(nn.unsqueeze(finished, [2]), [1, 1, self.vocab_size]),
+            paddle.tile(nn.unsqueeze(finished, [2]), [1, 1, self.vocab_size]),
             self.noend_mask_tensor,
             axis=-1) - nn.elementwise_mul(
                 probs, (finished - 1), axis=0)
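For context, `_mask_probs` keeps finished beams pinned on the end token: it broadcasts `finished` to a `[batch, beam, vocab]` mask with `paddle.tile` and combines it with a "no-end" vector that is 0 at the end token and -inf elsewhere. A rough standalone sketch of that logic, using small made-up sizes and an assumed end-token id of 0 (not taken from the patch):

```python
import numpy as np
import paddle

batch_size, beam_size, vocab_size, kinf = 2, 3, 5, 1e9

probs = paddle.rand([batch_size, beam_size, vocab_size])   # per-beam log-probs
finished = paddle.to_tensor([[1., 0., 0.], [0., 0., 1.]])  # 1.0 => beam is finished

# 0 at the (assumed) end token, -inf elsewhere, mirroring noend_mask_tensor.
noend_mask = paddle.to_tensor(
    np.array([0.] + [-kinf] * (vocab_size - 1), dtype="float32"))

# Finished beams take the no-end mask; unfinished beams keep their probs.
mask = paddle.tile(paddle.unsqueeze(finished, [2]), [1, 1, vocab_size])
masked_probs = mask * noend_mask - probs * (paddle.unsqueeze(finished, [2]) - 1)
print(masked_probs.shape)  # [2, 3, 5]
```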
@@ -1080,7 +1080,7 @@ class BeamSearchDecoder(Decoder):
             batch_size,
             indices.dtype) if batch_size.dtype != indices.dtype else batch_size
         batch_size.stop_gradient = True  # TODO: remove this
-        batch_pos = nn.expand(
+        batch_pos = paddle.tile(
             nn.unsqueeze(
                 tensor.range(
                     0, batch_size, 1, dtype=indices.dtype), [1]),
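Here `batch_pos` pairs every per-beam index with its batch id so the parent beam can be looked up with a gather. A small sketch of the idea with assumed sizes and indices (the real code builds the range with `tensor.range` and broadcasts it with what is now `paddle.tile`):

```python
import paddle

batch_size, beam_size, hidden = 2, 3, 4
# Hypothetical parent-beam indices returned by top-k for each batch element.
indices = paddle.to_tensor([[2, 0, 1], [1, 1, 0]], dtype="int64")

batch_pos = paddle.tile(
    paddle.unsqueeze(
        paddle.arange(0, batch_size, 1, dtype=indices.dtype), [1]),
    [1, beam_size])                                   # [[0, 0, 0], [1, 1, 1]]

coords = paddle.stack([batch_pos, indices], axis=2)   # [batch, beam, 2]
states = paddle.rand([batch_size, beam_size, hidden])
gathered = paddle.gather_nd(states, coords)           # parent-beam states, [2, 3, 4]
print(gathered.shape)
```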
@@ -1140,12 +1140,11 @@ class BeamSearchDecoder(Decoder):
         init_cell_states = map_structure(self._expand_to_beam_size,
                                          initial_cell_states)
-        # TODO: use fill_constant when support variable shape
-        init_inputs = nn.expand(
-            nn.unsqueeze(
-                nn.expand(self.start_token_tensor, [self.batch_size]), [1]),
-            [1, self.beam_size])
-        log_probs = nn.expand(
+        init_inputs = paddle.full(
+            shape=[self.batch_size, self.beam_size],
+            fill_value=self.start_token_tensor,
+            dtype=self.start_token_tensor.dtype)
+        log_probs = paddle.tile(
             tensor.assign(
                 np.array(
                     [[0.] + [-self.kinf] * (self.beam_size - 1)],
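In `initialize`, the nested `nn.expand`/`nn.unsqueeze` construction of the start-token inputs is replaced by a single `paddle.full`, and the initial log-probs (0 for the first beam, -inf for the rest) are broadcast over the batch with `paddle.tile`. A minimal sketch of the resulting tensors, with assumed scalar values for the batch size, beam size, and start token (the real code holds these as tensors on the decoder):

```python
import numpy as np
import paddle

# Hypothetical values; the decoder stores these as attributes/tensors.
batch_size, beam_size, start_token, kinf = 4, 3, 1, 1e9

# Every beam in every batch element starts from the start token.
init_inputs = paddle.full(
    shape=[batch_size, beam_size], fill_value=start_token, dtype="int64")

# Only beam 0 has log-prob 0; the others start at -inf, so the first step
# effectively expands a single hypothesis per batch element.
log_probs = paddle.tile(
    paddle.to_tensor(
        np.array([[0.] + [-kinf] * (beam_size - 1)], dtype="float32")),
    [batch_size, 1])

print(init_inputs.shape, log_probs.shape)  # [4, 3] [4, 3]
```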
@@ -1213,7 +1212,7 @@ class BeamSearchDecoder(Decoder):
             scores = log_probs
         scores = nn.reshape(scores, [-1, self.beam_size * self.vocab_size])
         # TODO: add grad for topk then this beam search can be used to train
-        topk_scores, topk_indices = nn.topk(input=scores, k=self.beam_size)
+        topk_scores, topk_indices = paddle.topk(x=scores, k=self.beam_size)
         beam_indices = nn.elementwise_floordiv(topk_indices,
                                                self.vocab_size_tensor)
         token_indices = nn.elementwise_mod(topk_indices, self.vocab_size_tensor)
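The search step flattens the scores to `[batch, beam * vocab]`, takes the top `beam_size` candidates with `paddle.topk` (which replaces `nn.topk` and takes `x` instead of `input`), and then recovers each candidate's source beam and token with floor-division and modulo. A small sketch with made-up sizes:

```python
import paddle

batch_size, beam_size, vocab_size = 2, 3, 5
scores = paddle.rand([batch_size, beam_size, vocab_size])

flat = paddle.reshape(scores, [-1, beam_size * vocab_size])  # [batch, beam * vocab]
topk_scores, topk_indices = paddle.topk(flat, k=beam_size)   # both [batch, beam]

beam_indices = topk_indices // vocab_size    # which parent beam each winner came from
token_indices = topk_indices % vocab_size    # which vocabulary token it extends with
print(beam_indices.numpy(), token_indices.numpy())
```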