Commit 94918305 authored by Hui Zhang

type_as to astype

Parent 82b8296f
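
This commit replaces the PyTorch-style `type_as(x)` calls with Paddle's `astype(x.dtype)`, because the boolean masks involved cannot be multiplied or summed directly. A minimal sketch of the pattern, with illustrative tensor names and shapes that are not taken from the diff:

import paddle

x = paddle.randn([2, 1, 1, 4])                    # illustrative feature tensor
masks = paddle.ones([2, 1, 1, 4], dtype='bool')   # illustrative non-pad mask

# masks = masks.type_as(x)      # the PyTorch-style call this commit removes
masks = masks.astype(x.dtype)   # Paddle cast: bool -> float32
x = x.multiply(masks)           # multiplying by a bool mask directly is not supported
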
@@ -114,7 +114,6 @@ class ConvBn(nn.Layer):
         masks = make_non_pad_mask(x_len) #[B, T]
         masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]
         # TODO(Hui Zhang): not support bool multiply
-        # masks = masks.type_as(x)
         masks = masks.astype(x.dtype)
         x = x.multiply(masks)

@@ -159,7 +159,7 @@ class BaseEncoder(nn.Layer):
         if self.global_cmvn is not None:
             xs = self.global_cmvn(xs)
         #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
-        xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0)
+        xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
         #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
         masks = masks.astype(paddle.bool)
         #TODO(Hui Zhang): mask_pad = ~masks

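Because `stride_slice` inside the embedding cannot handle bool tensors (per the TODOs above), the mask is cast to the feature dtype before `self.embed` and back to bool afterwards. A small round-trip sketch under assumed shapes, with the embedding call elided:

import paddle

xs = paddle.randn([2, 50, 80])                   # illustrative (B, T, D) features
masks = paddle.ones([2, 1, 50], dtype='bool')    # illustrative non-pad mask

masks_f = masks.astype(xs.dtype)     # bool -> float32 so slicing ops accept it
# ... masks_f would be passed through self.embed(...) here ...
masks = masks_f.astype(paddle.bool)  # restore the boolean view for later mask logic
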
@@ -136,7 +136,7 @@ class LabelSmoothingLoss(nn.Layer):
         #TODO(Hui Zhang): sum not support bool type
         #total = len(target) - int(ignore.sum())
-        total = len(target) - int(ignore.type_as(target).sum())
+        total = len(target) - int(ignore.astype(target.dtype).sum())
         denom = total if self.normalize_length else B
         #numer = (kl * (1 - ignore)).sum()
         numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum()

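The same cast shows up when counting non-ignored targets in the label-smoothing loss, since summing a bool tensor is the unsupported step. A sketch of the counting idiom, using a hypothetical `IGNORE_ID`:

import paddle

IGNORE_ID = -1                                   # hypothetical padding label id
target = paddle.to_tensor([2, 5, IGNORE_ID, 7, IGNORE_ID])
ignore = target == IGNORE_ID                     # bool mask over padded positions

# int(ignore.sum()) is the blocked path noted in the TODO; cast first instead:
total = len(target) - int(ignore.astype(target.dtype).sum())
print(total)  # 3 valid tokens
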
@@ -159,8 +159,8 @@ def th_accuracy(pad_outputs: paddle.Tensor,
     #     pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
     numerator = (
         pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
-    numerator = paddle.sum(numerator.type_as(pad_targets))
+    numerator = paddle.sum(numerator.astype(pad_targets.dtype))
     #TODO(Hui Zhang): sum not support bool type
     # denominator = paddle.sum(mask)
-    denominator = paddle.sum(mask.type_as(pad_targets))
+    denominator = paddle.sum(mask.astype(pad_targets.dtype))
     return float(numerator) / float(denominator)
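
`th_accuracy` applies the same fix twice: both the element-wise correctness mask and the padding mask are cast before `paddle.sum`. A self-contained sketch with made-up predictions and an assumed ignore label of 0:

import paddle

pad_pred = paddle.to_tensor([1, 2, 3, 0])
pad_targets = paddle.to_tensor([1, 2, 4, 0])
mask = pad_targets != 0                          # assumed ignore label: 0

correct = pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
numerator = paddle.sum(correct.astype(pad_targets.dtype))    # bool -> int64 before sum
denominator = paddle.sum(mask.astype(pad_targets.dtype))
print(float(numerator) / float(denominator))     # 2/3 of non-ignored tokens match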