提交 797ca389 编写于 作者: H Hui Zhang

paddle support some bool op

上级 a463f940
...@@ -162,10 +162,7 @@ class U2BaseModel(nn.Layer): ...@@ -162,10 +162,7 @@ class U2BaseModel(nn.Layer):
encoder_out, encoder_mask = self.encoder(speech, speech_lengths) encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
encoder_time = time.time() - start encoder_time = time.time() - start
#logger.debug(f"encoder time: {encoder_time}") #logger.debug(f"encoder time: {encoder_time}")
#TODO(Hui Zhang): sum not support bool type encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
#encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(
1) #[B, 1, T] -> [B]
# 2a. Attention-decoder branch # 2a. Attention-decoder branch
loss_att = None loss_att = None
...@@ -320,8 +317,7 @@ class U2BaseModel(nn.Layer): ...@@ -320,8 +317,7 @@ class U2BaseModel(nn.Layer):
# 2. Decoder forward step by step # 2. Decoder forward step by step
for i in range(1, maxlen + 1): for i in range(1, maxlen + 1):
# Stop if all batch and all beam produce eos # Stop if all batch and all beam produce eos
# TODO(Hui Zhang): if end_flag.sum() == running_size: if end_flag.sum() == running_size:
if end_flag.cast(paddle.int64).sum() == running_size:
break break
# 2.1 Forward decoder step # 2.1 Forward decoder step
...@@ -407,9 +403,7 @@ class U2BaseModel(nn.Layer): ...@@ -407,9 +403,7 @@ class U2BaseModel(nn.Layer):
speech, speech_lengths, decoding_chunk_size, speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks, simulate_streaming) num_decoding_left_chunks, simulate_streaming)
maxlen = encoder_out.size(1) maxlen = encoder_out.size(1)
# (TODO Hui Zhang): bool no support reduce_sum encoder_out_lens = encoder_mask.squeeze(1).sum(1)
# encoder_out_lens = encoder_mask.squeeze(1).sum(1)
encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1)
ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size) ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size)
topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1) topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1)
......
...@@ -163,10 +163,7 @@ class U2STBaseModel(nn.Layer): ...@@ -163,10 +163,7 @@ class U2STBaseModel(nn.Layer):
encoder_out, encoder_mask = self.encoder(speech, speech_lengths) encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
encoder_time = time.time() - start encoder_time = time.time() - start
#logger.debug(f"encoder time: {encoder_time}") #logger.debug(f"encoder time: {encoder_time}")
#TODO(Hui Zhang): sum not support bool type encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
#encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(
1) #[B, 1, T] -> [B]
# 2a. ST-decoder branch # 2a. ST-decoder branch
start = time.time() start = time.time()
...@@ -363,8 +360,7 @@ class U2STBaseModel(nn.Layer): ...@@ -363,8 +360,7 @@ class U2STBaseModel(nn.Layer):
# 2. Decoder forward step by step # 2. Decoder forward step by step
for i in range(1, maxlen + 1): for i in range(1, maxlen + 1):
# Stop if all batch and all beam produce eos # Stop if all batch and all beam produce eos
# TODO(Hui Zhang): if end_flag.sum() == running_size: if end_flag.sum() == running_size:
if end_flag.cast(paddle.int64).sum() == running_size:
break break
# 2.1 Forward decoder step # 2.1 Forward decoder step
......
...@@ -109,8 +109,8 @@ class MultiHeadedAttention(nn.Layer): ...@@ -109,8 +109,8 @@ class MultiHeadedAttention(nn.Layer):
p_attn = self.dropout(attn) p_attn = self.dropout(attn)
x = paddle.matmul(p_attn, value) # (batch, head, time1, d_k) x = paddle.matmul(p_attn, value) # (batch, head, time1, d_k)
x = x.transpose([0, 2, 1, 3]).contiguous().view( x = x.transpose([0, 2, 1, 3]).view(n_batch, -1, self.h *
n_batch, -1, self.h * self.d_k) # (batch, time1, d_model) self.d_k) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model) return self.linear_out(x) # (batch, time1, d_model)
......
...@@ -124,9 +124,7 @@ class TransformerDecoder(nn.Layer): ...@@ -124,9 +124,7 @@ class TransformerDecoder(nn.Layer):
# m: (1, L, L) # m: (1, L, L)
m = subsequent_mask(tgt_mask.size(-1)).unsqueeze(0) m = subsequent_mask(tgt_mask.size(-1)).unsqueeze(0)
# tgt_mask: (B, L, L) # tgt_mask: (B, L, L)
# TODO(Hui Zhang): not support & for tensor tgt_mask = tgt_mask & m
# tgt_mask = tgt_mask & m
tgt_mask = tgt_mask.logical_and(m)
x, _ = self.embed(tgt) x, _ = self.embed(tgt)
for layer in self.decoders: for layer in self.decoders:
...@@ -137,9 +135,7 @@ class TransformerDecoder(nn.Layer): ...@@ -137,9 +135,7 @@ class TransformerDecoder(nn.Layer):
if self.use_output_layer: if self.use_output_layer:
x = self.output_layer(x) x = self.output_layer(x)
# TODO(Hui Zhang): reduce_sum not support bool type olens = tgt_mask.sum(1)
# olens = tgt_mask.sum(1)
olens = tgt_mask.astype(paddle.int).sum(1)
return x, olens return x, olens
def forward_one_step( def forward_one_step(
......
...@@ -162,8 +162,7 @@ class BaseEncoder(nn.Layer): ...@@ -162,8 +162,7 @@ class BaseEncoder(nn.Layer):
xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0) xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks = masks.astype(paddle.bool) masks = masks.astype(paddle.bool)
#TODO(Hui Zhang): mask_pad = ~masks mask_pad = ~masks
mask_pad = masks.logical_not()
chunk_masks = add_optional_chunk_mask( chunk_masks = add_optional_chunk_mask(
xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk, xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,
decoding_chunk_size, self.static_chunk_size, decoding_chunk_size, self.static_chunk_size,
......
...@@ -124,9 +124,9 @@ class LabelSmoothingLoss(nn.Layer): ...@@ -124,9 +124,9 @@ class LabelSmoothingLoss(nn.Layer):
# use zeros_like instead of torch.no_grad() for true_dist, # use zeros_like instead of torch.no_grad() for true_dist,
# since no_grad() can not be exported by JIT # since no_grad() can not be exported by JIT
true_dist = paddle.full_like(x, self.smoothing / (self.size - 1)) true_dist = paddle.full_like(x, self.smoothing / (self.size - 1))
ignore = target == self.padding_idx # (B,) ignore = (target == self.padding_idx) # (B,)
# target = target * (1 - ignore) # avoid -1 index #TODO(Hui Zhang): target = target * (1 - ignore) # avoid -1 index
target = target.masked_fill(ignore, 0) # avoid -1 index target = target.masked_fill(ignore, 0) # avoid -1 index
# true_dist.scatter_(1, target.unsqueeze(1), self.confidence) # true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
target_mask = F.one_hot(target, self.size) target_mask = F.one_hot(target, self.size)
...@@ -135,10 +135,8 @@ class LabelSmoothingLoss(nn.Layer): ...@@ -135,10 +135,8 @@ class LabelSmoothingLoss(nn.Layer):
kl = self.criterion(F.log_softmax(x, axis=1), true_dist) kl = self.criterion(F.log_softmax(x, axis=1), true_dist)
#TODO(Hui Zhang): sum not support bool type total = len(target) - int(ignore.sum())
#total = len(target) - int(ignore.sum())
total = len(target) - int(ignore.type_as(target).sum())
denom = total if self.normalize_length else B denom = total if self.normalize_length else B
#numer = (kl * (1 - ignore)).sum() #TODO(Hui Zhang): numer = (kl * (1 - ignore)).sum()
numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum() numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum()
return numer / denom return numer / denom
...@@ -69,8 +69,7 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor: ...@@ -69,8 +69,7 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
[1, 1, 1, 0, 0], [1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]] [1, 1, 0, 0, 0]]
""" """
#TODO(Hui Zhang): return ~make_pad_mask(lengths), not support ~ return ~make_pad_mask(lengths)
return make_pad_mask(lengths).logical_not()
def subsequent_mask(size: int) -> paddle.Tensor: def subsequent_mask(size: int) -> paddle.Tensor:
...@@ -92,12 +91,7 @@ def subsequent_mask(size: int) -> paddle.Tensor: ...@@ -92,12 +91,7 @@ def subsequent_mask(size: int) -> paddle.Tensor:
[1, 1, 1]] [1, 1, 1]]
""" """
ret = paddle.ones([size, size], dtype=paddle.bool) ret = paddle.ones([size, size], dtype=paddle.bool)
#TODO(Hui Zhang): tril not support bool return paddle.tril(ret)
#return paddle.tril(ret)
ret = ret.astype(paddle.float)
ret = paddle.tril(ret)
ret = ret.astype(paddle.bool)
return ret
def subsequent_chunk_mask( def subsequent_chunk_mask(
...@@ -186,15 +180,13 @@ def add_optional_chunk_mask(xs: paddle.Tensor, ...@@ -186,15 +180,13 @@ def add_optional_chunk_mask(xs: paddle.Tensor,
chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size, chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size,
num_left_chunks) # (L, L) num_left_chunks) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
# chunk_masks = masks & chunk_masks # (B, L, L) chunk_masks = masks & chunk_masks # (B, L, L)
chunk_masks = masks.logical_and(chunk_masks) # (B, L, L)
elif static_chunk_size > 0: elif static_chunk_size > 0:
num_left_chunks = num_decoding_left_chunks num_left_chunks = num_decoding_left_chunks
chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size, chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size,
num_left_chunks) # (L, L) num_left_chunks) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
# chunk_masks = masks & chunk_masks # (B, L, L) chunk_masks = masks & chunk_masks # (B, L, L)
chunk_masks = masks.logical_and(chunk_masks) # (B, L, L)
else: else:
chunk_masks = masks chunk_masks = masks
return chunk_masks return chunk_masks
......
...@@ -168,13 +168,7 @@ def th_accuracy(pad_outputs: paddle.Tensor, ...@@ -168,13 +168,7 @@ def th_accuracy(pad_outputs: paddle.Tensor,
pad_pred = pad_outputs.view( pad_pred = pad_outputs.view(
pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)).argmax(2) pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)).argmax(2)
mask = pad_targets != ignore_label mask = pad_targets != ignore_label
#TODO(Hui Zhang): sum not support bool type numerator = paddle.sum(
# numerator = paddle.sum(
# pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator = (
pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator = paddle.sum(numerator.type_as(pad_targets)) denominator = paddle.sum(mask)
#TODO(Hui Zhang): sum not support bool type
# denominator = paddle.sum(mask)
denominator = paddle.sum(mask.type_as(pad_targets))
return float(numerator) / float(denominator) return float(numerator) / float(denominator)
# Reference # Reference
* [delta](https://github.com/Delta-ML/delta.git)
* [espnet](https://github.com/espnet/espnet.git)
* [kaldi](https://github.com/kaldi-asr/kaldi.git)
* [wenet](https://github.com/mobvoi/wenet) * [wenet](https://github.com/mobvoi/wenet)
...@@ -37,13 +37,13 @@ class TestU2Model(unittest.TestCase): ...@@ -37,13 +37,13 @@ class TestU2Model(unittest.TestCase):
def test_make_non_pad_mask(self): def test_make_non_pad_mask(self):
res = make_non_pad_mask(self.lengths) res = make_non_pad_mask(self.lengths)
res2 = make_pad_mask(self.lengths).logical_not() res2 = ~make_pad_mask(self.lengths)
self.assertSequenceEqual(res.numpy().tolist(), self.masks.tolist()) self.assertSequenceEqual(res.numpy().tolist(), self.masks.tolist())
self.assertSequenceEqual(res.numpy().tolist(), res2.numpy().tolist()) self.assertSequenceEqual(res.numpy().tolist(), res2.numpy().tolist())
def test_make_pad_mask(self): def test_make_pad_mask(self):
res = make_pad_mask(self.lengths) res = make_pad_mask(self.lengths)
res1 = make_non_pad_mask(self.lengths).logical_not() res1 = ~make_non_pad_mask(self.lengths)
self.assertSequenceEqual(res.numpy().tolist(), self.pad_masks.tolist()) self.assertSequenceEqual(res.numpy().tolist(), self.pad_masks.tolist())
self.assertSequenceEqual(res.numpy().tolist(), res1.tolist()) self.assertSequenceEqual(res.numpy().tolist(), res1.tolist())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册