提交 e6e8fa74 编写于 作者: H Hui Zhang

paddle.broadcast_shape; log_softmax; equal(zeros); register_buffer

上级 3c2dcfad
......@@ -616,7 +616,7 @@ class U2Tester(U2Trainer):
shape=[1, encoder_max_time, encoder_model_size],
dtype='float32'), # encoder_out
])
logger.info(f"Export code: {static_model}")
logger.info(f"Export code: {static_model.main_program}")
paddle.jit.save(static_model, self.args.export_path)
......
......@@ -954,5 +954,5 @@ class U2InferModel(U2Model):
# (num_hyps, max_hyps_len, vocab_size)
decoder_out, _ = self.decoder.export(encoder_out, encoder_mask, hyps,
hyps_masks)
decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1)
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
return decoder_out
......@@ -99,11 +99,11 @@ class MultiHeadedAttention(nn.Layer):
"""
n_batch = value.shape[0]
if mask is not None:
# TODO(Hui Zhang): slice not support `int`; paddle not has `scalar` tensor.
mask = mask.unsqueeze(1).equal(
paddle.to_tensor(0, dtype=mask.dtype)) # (batch, 1, *, time2)
paddle.zeros([1], dtype=mask.dtype)) # (batch, 1, *, time2)
scores = masked_fill(scores, mask, -float('inf'))
attn = paddle.softmax(
scores, axis=-1)
attn = paddle.softmax(scores, axis=-1)
attn = masked_fill(attn, mask, 0.0) # (batch, head, time1, time2)
else:
attn = paddle.softmax(
......
......@@ -13,8 +13,8 @@
# limitations under the License.
"""Unility functions for Transformer."""
from typing import List
from typing import Tuple
from typing import Optional
from typing import Tuple
from typing import Union
import paddle
......@@ -25,6 +25,7 @@ __all__ = ["masked_fill", "pad_sequence", "add_sos_eos", "th_accuracy"]
logger = Log(__name__).getlog()
def is_broadcastable(shp1, shp2):
for a, b in zip(shp1[::-1], shp2[::-1]):
if a == 1 or b == 1 or a == b:
......@@ -33,17 +34,22 @@ def is_broadcastable(shp1, shp2):
return False
return True
def masked_fill(xs: paddle.Tensor,
mask: paddle.Tensor,
value: Union[float, int]):
if paddle.in_dynamic_mode():
assert is_broadcastable(xs.shape, mask.shape) is True
bshape = paddle.broadcast_shape(xs.shape, mask.shape)
bshape = paddle.broadcast_shape(xs.shape, mask.shape)
else:
# TODO(Hui Zhang): support broadcast_shape in static graph
bshape = xs.shape
mask = mask.broadcast_to(bshape)
trues = paddle.ones_like(xs) * value
xs = paddle.where(mask, trues, xs)
return xs
def pad_sequence(sequences: List[paddle.Tensor],
batch_first: bool=False,
padding_value: float=0.0) -> paddle.Tensor:
......@@ -184,4 +190,4 @@ def th_accuracy(pad_outputs: paddle.Tensor,
#TODO(Hui Zhang): sum not support bool type
# denominator = paddle.sum(mask)
denominator = paddle.sum(mask.astype(pad_targets.dtype))
return float(numerator) / float(denominator)
\ No newline at end of file
return float(numerator) / float(denominator)
......@@ -24,7 +24,7 @@ def frame(x: Tensor,
hop_length : int
Number of samples shifted between ajancent frames.
clip : bool, optional
Whether to clip audio that does not fit into the last frame, by
Whether to clip audio that does not fit into the last frame, by
default True
Returns
......@@ -53,16 +53,16 @@ def frame(x: Tensor,
class STFT(nn.Layer):
"""A module for computing stft transformation in a differentiable way.
"""A module for computing stft transformation in a differentiable way.
Parameters
------------
n_fft : int
Number of samples in a frame.
hop_length : int
Number of samples shifted between adjacent frames.
win_length : int
Length of the window.
......@@ -109,8 +109,7 @@ class STFT(nn.Layer):
# (2 * n_bins, 1, kernel_size) # (C_out, C_in, kernel_size)
w = np.expand_dims(w, 1)
weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())
self.register_buffer("weight", weight)
self.weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())
def forward(self, x: Tensor, num_samples: Tensor) -> Tuple[Tensor, Tensor]:
"""Compute the stft transform.
......@@ -118,7 +117,7 @@ class STFT(nn.Layer):
------------
x : Tensor [shape=(B, T)]
The input waveform.
num_samples : Tensor
num_samples : Tensor
Number of samples of each waveform.
Returns
------------
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册