未验证 提交 e2341b21 编写于 作者: L liu zhengxi 提交者: GitHub

Transfer dynamic_decode and BeamSearchDecoder to v2 op (#35656)

上级 91cf918f
@@ -958,7 +958,7 @@ class BeamSearchDecoder(Decoder):
 x = nn.unsqueeze(x, [1]) # [batch_size, 1, ...]
 expand_times = [1] * len(x.shape)
 expand_times[1] = beam_size
-x = nn.expand(x, expand_times) # [batch_size, beam_size, ...]
+x = paddle.tile(x, expand_times) # [batch_size, beam_size, ...]
 x = nn.transpose(x, list(range(2, len(x.shape))) +
 [0, 1]) # [..., batch_size, beam_size]
 # use 0 to copy to avoid wrong shape
@@ -1024,7 +1024,7 @@ class BeamSearchDecoder(Decoder):
 x = nn.unsqueeze(x, [1])
 expand_times = [1] * len(x.shape)
 expand_times[1] = self.beam_size
-x = nn.expand(x, expand_times)
+x = paddle.tile(x, expand_times)
 return x
 def _mask_probs(self, probs, finished):
@@ -1050,7 +1050,7 @@ class BeamSearchDecoder(Decoder):
 # TODO: use where_op
 finished = tensor.cast(finished, dtype=probs.dtype)
 probs = nn.elementwise_mul(
-nn.expand(nn.unsqueeze(finished, [2]), [1, 1, self.vocab_size]),
+paddle.tile(nn.unsqueeze(finished, [2]), [1, 1, self.vocab_size]),
 self.noend_mask_tensor,
 axis=-1) - nn.elementwise_mul(
 probs, (finished - 1), axis=0)
@@ -1080,7 +1080,7 @@ class BeamSearchDecoder(Decoder):
 batch_size,
 indices.dtype) if batch_size.dtype != indices.dtype else batch_size
 batch_size.stop_gradient = True # TODO: remove this
-batch_pos = nn.expand(
+batch_pos = paddle.tile(
 nn.unsqueeze(
 tensor.range(
 0, batch_size, 1, dtype=indices.dtype), [1]),
@@ -1140,12 +1140,11 @@ class BeamSearchDecoder(Decoder):
 init_cell_states = map_structure(self._expand_to_beam_size,
 initial_cell_states)
-# TODO: use fill_constant when support variable shape
-init_inputs = nn.expand(
-nn.unsqueeze(
-nn.expand(self.start_token_tensor, [self.batch_size]), [1]),
-[1, self.beam_size])
-log_probs = nn.expand(
+init_inputs = paddle.full(
+shape=[self.batch_size, self.beam_size],
+fill_value=self.start_token_tensor,
+dtype=self.start_token_tensor.dtype)
+log_probs = paddle.tile(
 tensor.assign(
 np.array(
 [[0.] + [-self.kinf] * (self.beam_size - 1)],
@@ -1213,7 +1212,7 @@ class BeamSearchDecoder(Decoder):
 scores = log_probs
 scores = nn.reshape(scores, [-1, self.beam_size * self.vocab_size])
 # TODO: add grad for topk then this beam search can be used to train
-topk_scores, topk_indices = nn.topk(input=scores, k=self.beam_size)
+topk_scores, topk_indices = paddle.topk(x=scores, k=self.beam_size)
 beam_indices = nn.elementwise_floordiv(topk_indices,
 self.vocab_size_tensor)
 token_indices = nn.elementwise_mod(topk_indices, self.vocab_size_tensor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册