Commit e6e8fa74 authored by Hui Zhang

paddle.broadcast_shape; log_softmax; equal(zeros); register_buffer

Parent 3c2dcfad
@@ -616,7 +616,7 @@ class U2Tester(U2Trainer):
                 shape=[1, encoder_max_time, encoder_model_size],
                 dtype='float32'),  # encoder_out
         ])
-        logger.info(f"Export code: {static_model}")
+        logger.info(f"Export code: {static_model.main_program}")
         paddle.jit.save(static_model, self.args.export_path)
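Note on the logging change: a `StaticFunction` produced by `paddle.jit.to_static` prints only a short repr, while its `main_program` property holds the full static-graph description, which is what you want in an export log. A minimal sketch, assuming Paddle 2.x (the function `f` is a hypothetical stand-in for the exported model):

```python
import paddle
from paddle.static import InputSpec

def f(x):  # hypothetical function standing in for the exported model
    return paddle.nn.functional.relu(x)

static_f = paddle.jit.to_static(f, input_spec=[InputSpec([None, 3], 'float32')])
# Logging the StaticFunction itself only prints its repr; `main_program`
# dumps the whole translated program, which is far more useful for debugging.
print(static_f.main_program)
```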
@@ -954,5 +954,5 @@ class U2InferModel(U2Model):
         # (num_hyps, max_hyps_len, vocab_size)
         decoder_out, _ = self.decoder.export(encoder_out, encoder_mask, hyps,
                                              hyps_masks)
-        decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1)
+        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
         return decoder_out
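Note on the `log_softmax` fix: Paddle's functional ops take `axis=`, not PyTorch's `dim=`, so the old keyword raises a `TypeError`. A minimal sketch, assuming Paddle 2.x:

```python
import paddle
import paddle.nn.functional as F

x = paddle.randn([4, 10])
# `dim=-1` is the PyTorch keyword and fails here; Paddle uses `axis=`.
log_probs = F.log_softmax(x, axis=-1)
print(log_probs.shape)  # [4, 10]
```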
@@ -99,11 +99,11 @@ class MultiHeadedAttention(nn.Layer):
         """
         n_batch = value.shape[0]
         if mask is not None:
+            # TODO(Hui Zhang): slice not support `int`; paddle not has `scalar` tensor.
             mask = mask.unsqueeze(1).equal(
-                paddle.to_tensor(0, dtype=mask.dtype))  # (batch, 1, *, time2)
+                paddle.zeros([1], dtype=mask.dtype))  # (batch, 1, *, time2)
             scores = masked_fill(scores, mask, -float('inf'))
-            attn = paddle.softmax(
-                scores, axis=-1)
+            attn = paddle.softmax(scores, axis=-1)
             attn = masked_fill(attn, mask, 0.0)  # (batch, head, time1, time2)
         else:
             attn = paddle.softmax(
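Note on the mask comparison: the commit replaces `equal(paddle.to_tensor(0, ...))` with a broadcastable `[1]`-shaped zeros tensor, since scalar tensors were not reliably supported at the time (see the TODO in the hunk). A minimal sketch of the trick, assuming Paddle 2.x with broadcasting in elementwise compares:

```python
import paddle

mask = paddle.to_tensor([[1, 1, 0, 0]], dtype='int32')  # (batch, time)
# equal() broadcasts the [1]-shaped zeros tensor against the mask,
# yielding True exactly where the original mask marks padding (== 0).
pad_mask = mask.unsqueeze(1).equal(paddle.zeros([1], dtype=mask.dtype))
print(pad_mask)  # shape (1, 1, 4): [[[False, False, True, True]]]
```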
@@ -13,8 +13,8 @@
 # limitations under the License.
 """Unility functions for Transformer."""
 from typing import List
-from typing import Tuple
 from typing import Optional
+from typing import Tuple
 from typing import Union

 import paddle
@@ -25,6 +25,7 @@ __all__ = ["masked_fill", "pad_sequence", "add_sos_eos", "th_accuracy"]
 logger = Log(__name__).getlog()


 def is_broadcastable(shp1, shp2):
     for a, b in zip(shp1[::-1], shp2[::-1]):
         if a == 1 or b == 1 or a == b:
@@ -33,17 +34,22 @@ def is_broadcastable(shp1, shp2):
             return False
     return True


 def masked_fill(xs: paddle.Tensor,
                 mask: paddle.Tensor,
                 value: Union[float, int]):
     if paddle.in_dynamic_mode():
         assert is_broadcastable(xs.shape, mask.shape) is True
         bshape = paddle.broadcast_shape(xs.shape, mask.shape)
+    else:
+        # TODO(Hui Zhang): support broadcast_shape in static graph
+        bshape = xs.shape
     mask = mask.broadcast_to(bshape)
     trues = paddle.ones_like(xs) * value
     xs = paddle.where(mask, trues, xs)
     return xs
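Note on the `masked_fill` change: `paddle.broadcast_shape` computes the broadcast result from two shape lists without creating tensors, but it is only used in dynamic mode here, hence the static-graph fallback to `xs.shape`. A minimal usage sketch, assuming Paddle >= 2.1:

```python
import paddle

# Pure shape arithmetic: no tensors are built.
print(paddle.broadcast_shape([2, 1, 4], [3, 1]))  # [2, 3, 4]

# masked_fill as defined in the patch above: fill masked positions.
x = paddle.zeros([2, 3])
m = paddle.to_tensor([[True, False, True]])  # broadcasts to (2, 3)
print(masked_fill(x, m, -float('inf')))
```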
 def pad_sequence(sequences: List[paddle.Tensor],
                  batch_first: bool=False,
                  padding_value: float=0.0) -> paddle.Tensor:
@@ -184,4 +190,4 @@ def th_accuracy(pad_outputs: paddle.Tensor,
     #TODO(Hui Zhang): sum not support bool type
     # denominator = paddle.sum(mask)
     denominator = paddle.sum(mask.astype(pad_targets.dtype))
-    return float(numerator) / float(denominator)
\ No newline at end of file
+    return float(numerator) / float(denominator)
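Note on the `th_accuracy` context lines: at the time of this commit `paddle.sum` did not accept bool tensors (see the TODO), so the mask is cast to the target dtype before summing. A minimal sketch, assuming Paddle 2.x:

```python
import paddle

mask = paddle.to_tensor([True, False, True])
# paddle.sum(mask) raised an error for bool input at the time;
# casting to a numeric dtype sidesteps it.
print(paddle.sum(mask.astype('int64')))  # 2
```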
@@ -24,7 +24,7 @@ def frame(x: Tensor,
     hop_length : int
         Number of samples shifted between ajancent frames.
     clip : bool, optional
-        Whether to clip audio that does not fit into the last frame, by 
+        Whether to clip audio that does not fit into the last frame, by
         default True

     Returns
@@ -53,16 +53,16 @@ def frame(x: Tensor,
 class STFT(nn.Layer):
     """A module for computing stft transformation in a differentiable way.

     Parameters
     ------------
     n_fft : int
         Number of samples in a frame.

     hop_length : int
         Number of samples shifted between adjacent frames.

     win_length : int
         Length of the window.
@@ -109,8 +109,7 @@ class STFT(nn.Layer):
         # (2 * n_bins, 1, kernel_size) # (C_out, C_in, kernel_size)
         w = np.expand_dims(w, 1)
-        weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())
-        self.register_buffer("weight", weight)
+        self.weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())

     def forward(self, x: Tensor, num_samples: Tensor) -> Tuple[Tensor, Tensor]:
         """Compute the stft transform.
@@ -118,7 +117,7 @@ class STFT(nn.Layer):
         ------------
         x : Tensor [shape=(B, T)]
             The input waveform.
         num_samples : Tensor
             Number of samples of each waveform.

         Returns
         ------------
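Note on the `register_buffer` change: a registered buffer is tracked by the layer and appears in `state_dict()`, while a plain tensor attribute does not; this commit switches the STFT weight to a plain attribute. A minimal sketch of the difference, assuming Paddle 2.x (`WithBuffer` and `PlainAttr` are hypothetical names):

```python
import paddle
import paddle.nn as nn

class WithBuffer(nn.Layer):
    def __init__(self):
        super().__init__()
        self.register_buffer("weight", paddle.ones([3]))  # saved in state_dict

class PlainAttr(nn.Layer):
    def __init__(self):
        super().__init__()
        self.weight = paddle.ones([3])  # plain attribute, not saved

print(list(WithBuffer().state_dict().keys()))  # ['weight']
print(list(PlainAttr().state_dict().keys()))   # []
```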