Commit e6e8fa74 authored by Hui Zhang

paddle.broadcast_shape; log_softmax; equal(zeros); register_buffer

Parent 3c2dcfad
@@ -616,7 +616,7 @@ class U2Tester(U2Trainer):
                 shape=[1, encoder_max_time, encoder_model_size],
                 dtype='float32'),  # encoder_out
         ])
-        logger.info(f"Export code: {static_model}")
+        logger.info(f"Export code: {static_model.main_program}")
         paddle.jit.save(static_model, self.args.export_path)
...
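For context, a minimal sketch of how a dygraph layer is typically traced and exported with paddle.jit in Paddle 2.x; the Linear layer and output path below are illustrative placeholders, not the repository's actual model or export path.

    import paddle
    from paddle.static import InputSpec

    # Hypothetical layer standing in for the exported model; the real code
    # traces U2InferModel with InputSpecs for audio features, lengths, etc.
    layer = paddle.nn.Linear(8, 4)
    static_model = paddle.jit.to_static(
        layer, input_spec=[InputSpec(shape=[None, 8], dtype='float32')])
    paddle.jit.save(static_model, "exported/demo_model")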
@@ -954,5 +954,5 @@ class U2InferModel(U2Model):
         # (num_hyps, max_hyps_len, vocab_size)
         decoder_out, _ = self.decoder.export(encoder_out, encoder_mask, hyps,
                                              hyps_masks)
-        decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1)
+        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
         return decoder_out
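The change above swaps the PyTorch-style dim keyword for Paddle's axis. A minimal sketch of the corrected call, assuming Paddle 2.x and made-up tensor shapes:

    import paddle
    import paddle.nn.functional as F

    x = paddle.rand([2, 5, 10])            # (num_hyps, max_hyps_len, vocab_size)
    log_probs = F.log_softmax(x, axis=-1)  # Paddle uses `axis`, not PyTorch's `dim`
    print(log_probs.shape)                 # [2, 5, 10]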
...@@ -99,11 +99,11 @@ class MultiHeadedAttention(nn.Layer): ...@@ -99,11 +99,11 @@ class MultiHeadedAttention(nn.Layer):
""" """
n_batch = value.shape[0] n_batch = value.shape[0]
if mask is not None: if mask is not None:
# TODO(Hui Zhang): slice not support `int`; paddle not has `scalar` tensor.
mask = mask.unsqueeze(1).equal( mask = mask.unsqueeze(1).equal(
paddle.to_tensor(0, dtype=mask.dtype)) # (batch, 1, *, time2) paddle.zeros([1], dtype=mask.dtype)) # (batch, 1, *, time2)
scores = masked_fill(scores, mask, -float('inf')) scores = masked_fill(scores, mask, -float('inf'))
attn = paddle.softmax( attn = paddle.softmax(scores, axis=-1)
scores, axis=-1)
attn = masked_fill(attn, mask, 0.0) # (batch, head, time1, time2) attn = masked_fill(attn, mask, 0.0) # (batch, head, time1, time2)
else: else:
attn = paddle.softmax( attn = paddle.softmax(
......
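The attention change builds the padding mask by comparing against paddle.zeros([1], ...) instead of a scalar paddle.to_tensor(0, ...), since (per the TODO) scalar tensors are not well supported on the export path. A rough sketch of the masking pattern with made-up shapes; the masked_fill helper is inlined here via paddle.where:

    import paddle
    import paddle.nn.functional as F

    scores = paddle.rand([4, 8, 20, 20])              # (batch, head, time1, time2)
    mask = paddle.ones([4, 1, 1, 20], dtype='int64')  # 1 = valid frame, 0 = padding
    # Padding positions are the ones equal to zero; compare against a
    # 1-element tensor rather than a Python scalar, as in the change above.
    pad = mask.equal(paddle.zeros([1], dtype=mask.dtype))
    pad = pad.broadcast_to(scores.shape)
    scores = paddle.where(pad, paddle.full_like(scores, -float('inf')), scores)
    attn = F.softmax(scores, axis=-1)
    attn = paddle.where(pad, paddle.zeros_like(attn), attn)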
@@ -13,8 +13,8 @@
 # limitations under the License.
 """Unility functions for Transformer."""
 from typing import List
-from typing import Tuple
 from typing import Optional
+from typing import Tuple
 from typing import Union
 import paddle
@@ -25,6 +25,7 @@ __all__ = ["masked_fill", "pad_sequence", "add_sos_eos", "th_accuracy"]
 logger = Log(__name__).getlog()
 def is_broadcastable(shp1, shp2):
     for a, b in zip(shp1[::-1], shp2[::-1]):
         if a == 1 or b == 1 or a == b:
@@ -33,17 +34,22 @@ def is_broadcastable(shp1, shp2):
             return False
     return True
 def masked_fill(xs: paddle.Tensor,
                 mask: paddle.Tensor,
                 value: Union[float, int]):
     if paddle.in_dynamic_mode():
         assert is_broadcastable(xs.shape, mask.shape) is True
         bshape = paddle.broadcast_shape(xs.shape, mask.shape)
+    else:
+        # TODO(Hui Zhang): support broadcast_shape in static graph
+        bshape = xs.shape
     mask = mask.broadcast_to(bshape)
     trues = paddle.ones_like(xs) * value
     xs = paddle.where(mask, trues, xs)
     return xs
 def pad_sequence(sequences: List[paddle.Tensor],
                  batch_first: bool=False,
                  padding_value: float=0.0) -> paddle.Tensor:
...
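masked_fill now computes the broadcast shape with paddle.broadcast_shape only in dynamic mode and falls back to xs.shape when tracing a static graph. A small, self-contained usage sketch of the dynamic-mode path, with illustrative shapes:

    import paddle

    def masked_fill(xs, mask, value):
        # condensed dynamic-mode path of the helper above
        bshape = paddle.broadcast_shape(xs.shape, mask.shape)
        mask = mask.broadcast_to(bshape)
        return paddle.where(mask, paddle.ones_like(xs) * value, xs)

    xs = paddle.rand([2, 4, 6])
    mask = paddle.rand([2, 1, 1]) > 0.5   # bool mask, broadcastable against xs
    out = masked_fill(xs, mask, 0.0)
    print(out.shape)                      # [2, 4, 6]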
@@ -109,8 +109,7 @@ class STFT(nn.Layer):
         # (2 * n_bins, 1, kernel_size) # (C_out, C_in, kernel_size)
         w = np.expand_dims(w, 1)
-        weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())
-        self.register_buffer("weight", weight)
+        self.weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())
     def forward(self, x: Tensor, num_samples: Tensor) -> Tuple[Tensor, Tensor]:
         """Compute the stft transform.
...
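The STFT change keeps the analysis-window weights as a plain attribute instead of a registered buffer. The commit does not say why, but the practical difference in Paddle is that register_buffer puts the tensor into the layer's state_dict while a plain attribute does not; a rough illustration:

    import paddle
    import paddle.nn as nn

    class WithBuffer(nn.Layer):
        def __init__(self):
            super().__init__()
            self.register_buffer("weight", paddle.ones([4]))  # tracked in state_dict

    class PlainAttr(nn.Layer):
        def __init__(self):
            super().__init__()
            self.weight = paddle.ones([4])                     # ordinary attribute

    print(list(WithBuffer().state_dict().keys()))  # ['weight']
    print(list(PlainAttr().state_dict().keys()))   # []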