提交 94918305 编写于 作者: H Hui Zhang

type_as to astype

上级 82b8296f
...@@ -114,7 +114,6 @@ class ConvBn(nn.Layer): ...@@ -114,7 +114,6 @@ class ConvBn(nn.Layer):
masks = make_non_pad_mask(x_len) #[B, T] masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T] masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]
# TODO(Hui Zhang): not support bool multiply # TODO(Hui Zhang): not support bool multiply
# masks = masks.type_as(x)
masks = masks.astype(x.dtype) masks = masks.astype(x.dtype)
x = x.multiply(masks) x = x.multiply(masks)
......
...@@ -159,7 +159,7 @@ class BaseEncoder(nn.Layer): ...@@ -159,7 +159,7 @@ class BaseEncoder(nn.Layer):
if self.global_cmvn is not None: if self.global_cmvn is not None:
xs = self.global_cmvn(xs) xs = self.global_cmvn(xs)
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0) xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks = masks.astype(paddle.bool) masks = masks.astype(paddle.bool)
#TODO(Hui Zhang): mask_pad = ~masks #TODO(Hui Zhang): mask_pad = ~masks
......
...@@ -136,7 +136,7 @@ class LabelSmoothingLoss(nn.Layer): ...@@ -136,7 +136,7 @@ class LabelSmoothingLoss(nn.Layer):
#TODO(Hui Zhang): sum not support bool type #TODO(Hui Zhang): sum not support bool type
#total = len(target) - int(ignore.sum()) #total = len(target) - int(ignore.sum())
total = len(target) - int(ignore.type_as(target).sum()) total = len(target) - int(ignore.astype(target.dtype).sum())
denom = total if self.normalize_length else B denom = total if self.normalize_length else B
#numer = (kl * (1 - ignore)).sum() #numer = (kl * (1 - ignore)).sum()
numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum() numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum()
......
...@@ -159,8 +159,8 @@ def th_accuracy(pad_outputs: paddle.Tensor, ...@@ -159,8 +159,8 @@ def th_accuracy(pad_outputs: paddle.Tensor,
# pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) # pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator = ( numerator = (
pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator = paddle.sum(numerator.type_as(pad_targets)) numerator = paddle.sum(numerator.astype(pad_targets.dtype))
#TODO(Hui Zhang): sum not support bool type #TODO(Hui Zhang): sum not support bool type
# denominator = paddle.sum(mask) # denominator = paddle.sum(mask)
denominator = paddle.sum(mask.type_as(pad_targets)) denominator = paddle.sum(mask.astype(pad_targets.dtype))
return float(numerator) / float(denominator) return float(numerator) / float(denominator)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册