Commit e6e8fa74 authored by: H Hui Zhang

paddle.broadcast_shape; log_softmax; equal(zeros); register_buffer

Parent 3c2dcfad
@@ -616,7 +616,7 @@ class U2Tester(U2Trainer):
shape=[1, encoder_max_time, encoder_model_size],
dtype='float32'), # encoder_out
])
logger.info(f"Export code: {static_model}")
logger.info(f"Export code: {static_model.main_program}")
paddle.jit.save(static_model, self.args.export_path)
......
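Note: the hunk above switches the export log from the wrapped layer object to `static_model.main_program`, so the translated static-graph program is printed before saving. Below is a minimal sketch of the same export flow, assuming Paddle's `paddle.jit.to_static` / `paddle.jit.save` API and a hypothetical `TinyEncoder` layer with toy shapes:

```python
import paddle
from paddle.static import InputSpec

# Hypothetical stand-in for the U2 inference model; shapes are illustrative only.
class TinyEncoder(paddle.nn.Layer):
    def forward(self, x):
        return paddle.nn.functional.relu(x)

layer = TinyEncoder()
static_model = paddle.jit.to_static(
    layer,
    input_spec=[InputSpec(shape=[1, None, 512], dtype='float32')])
paddle.jit.save(static_model, "exported/tiny_encoder")
```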
@@ -954,5 +954,5 @@ class U2InferModel(U2Model):
# (num_hyps, max_hyps_len, vocab_size)
decoder_out, _ = self.decoder.export(encoder_out, encoder_mask, hyps,
hyps_masks)
-decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1)
+decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
return decoder_out
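The fix above renames the Torch-style `dim` keyword to Paddle's `axis`. A quick illustrative check with toy sizes:

```python
import paddle
import paddle.nn.functional as F

# (num_hyps, max_hyps_len, vocab_size) -- toy sizes for illustration.
decoder_out = paddle.randn([4, 6, 10])
# Paddle's log_softmax takes `axis`, not the Torch-style `dim` keyword.
log_probs = F.log_softmax(decoder_out, axis=-1)
print(log_probs.shape)  # [4, 6, 10]
```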
@@ -99,11 +99,11 @@ class MultiHeadedAttention(nn.Layer):
"""
n_batch = value.shape[0]
if mask is not None:
# TODO(Hui Zhang): slice not support `int`; paddle not has `scalar` tensor.
mask = mask.unsqueeze(1).equal(
-paddle.to_tensor(0, dtype=mask.dtype)) # (batch, 1, *, time2)
+paddle.zeros([1], dtype=mask.dtype)) # (batch, 1, *, time2)
scores = masked_fill(scores, mask, -float('inf'))
-attn = paddle.softmax(
-scores, axis=-1)
+attn = paddle.softmax(scores, axis=-1)
attn = masked_fill(attn, mask, 0.0) # (batch, head, time1, time2)
else:
attn = paddle.softmax(
......
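For context, the attention hunk above builds the padding mask by comparing against a one-element zeros tensor (rather than a Python scalar) and then blanks out the masked scores. A self-contained sketch with toy shapes; the project's `masked_fill` helper is approximated here with `paddle.where`:

```python
import paddle
import paddle.nn.functional as F

# Toy padding mask: 1 = valid frame, 0 = padding (shape: batch, time2).
mask = paddle.to_tensor([[1, 1, 0], [1, 0, 0]], dtype='int64')

# Compare against a 1-element zeros tensor, as in the hunk above; broadcasting
# yields a boolean "is padding" mask of shape (batch, 1, time2).
pad_mask = mask.unsqueeze(1).equal(paddle.zeros([1], dtype=mask.dtype))

scores = paddle.randn([2, 1, 3])
# Padded positions get -inf, so softmax assigns them zero attention weight.
neg_inf = paddle.ones_like(scores) * -float('inf')
attn = F.softmax(paddle.where(pad_mask, neg_inf, scores), axis=-1)
print(attn)
```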
@@ -13,8 +13,8 @@
# limitations under the License.
"""Unility functions for Transformer."""
from typing import List
-from typing import Tuple
from typing import Optional
+from typing import Tuple
from typing import Union
import paddle
@@ -25,6 +25,7 @@ __all__ = ["masked_fill", "pad_sequence", "add_sos_eos", "th_accuracy"]
logger = Log(__name__).getlog()
def is_broadcastable(shp1, shp2):
for a, b in zip(shp1[::-1], shp2[::-1]):
if a == 1 or b == 1 or a == b:
@@ -33,17 +34,22 @@ def is_broadcastable(shp1, shp2):
return False
return True
def masked_fill(xs: paddle.Tensor,
mask: paddle.Tensor,
value: Union[float, int]):
if paddle.in_dynamic_mode():
assert is_broadcastable(xs.shape, mask.shape) is True
bshape = paddle.broadcast_shape(xs.shape, mask.shape)
else:
# TODO(Hui Zhang): support broadcast_shape in static graph
bshape = xs.shape
mask = mask.broadcast_to(bshape)
trues = paddle.ones_like(xs) * value
xs = paddle.where(mask, trues, xs)
return xs
def pad_sequence(sequences: List[paddle.Tensor],
batch_first: bool=False,
padding_value: float=0.0) -> paddle.Tensor:
......
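A minimal dynamic-mode walkthrough of what the `masked_fill` helper above does, with toy shapes (illustrative only; per the TODO, the static-graph branch simply falls back to `xs.shape`):

```python
import paddle

xs = paddle.randn([2, 4])
mask = paddle.to_tensor([[True, False, False, False],
                         [True, True, False, False]])

# Dynamic-mode path of the helper: compute the common broadcast shape,
# expand the mask to it, then replace masked positions with the fill value.
bshape = paddle.broadcast_shape(xs.shape, mask.shape)  # [2, 4]
filled = paddle.where(mask.broadcast_to(bshape),
                      paddle.ones_like(xs) * 0.0,
                      xs)
print(filled)  # positions where mask is True are now 0.0
```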
@@ -109,8 +109,7 @@ class STFT(nn.Layer):
# (2 * n_bins, 1, kernel_size) # (C_out, C_in, kernel_size)
w = np.expand_dims(w, 1)
-weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())
-self.register_buffer("weight", weight)
+self.weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())
def forward(self, x: Tensor, num_samples: Tensor) -> Tuple[Tensor, Tensor]:
"""Compute the stft transform.
......
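The STFT hunk above replaces `register_buffer` with a plain tensor attribute. One observable difference, shown with hypothetical `WithBuffer`/`PlainAttr` layers: a registered buffer is tracked in `state_dict()`, while a plain attribute is not.

```python
import numpy as np
import paddle

w = paddle.cast(paddle.to_tensor(np.ones([2, 1, 4])), paddle.get_default_dtype())

class WithBuffer(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        # Registered buffer: non-trainable, but tracked in state_dict().
        self.register_buffer("weight", w)

class PlainAttr(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        # Plain attribute (the form the commit switches to): just a tensor
        # hanging off the layer, not part of state_dict().
        self.weight = w

print("weight" in WithBuffer().state_dict())  # True
print("weight" in PlainAttr().state_dict())   # False
```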