未验证 提交 271112ca 编写于 作者: 小湉湉's avatar 小湉湉 提交者: GitHub

fix vits reduce_sum's input/output dtype, test=tts (#3028)

上级 34f2995b
......@@ -155,12 +155,10 @@ class StochasticDurationPredictor(nn.Layer):
z_u, z1 = paddle.split(z_q, [1, 1], 1)
u = F.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += paddle.sum(
(F.log_sigmoid(z_u) + F.log_sigmoid(-z_u)) * x_mask, [1, 2])
logq = (paddle.sum(-0.5 *
(math.log(2 * math.pi) +
(e_q**2)) * x_mask, [1, 2]) - logdet_tot_q)
tmp1 = (F.log_sigmoid(z_u) + F.log_sigmoid(-z_u)) * x_mask
logdet_tot_q += paddle.sum(tmp1, [1, 2])
tmp2 = -0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask
logq = (paddle.sum(tmp2, [1, 2]) - logdet_tot_q)
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
logdet_tot += logdet
......@@ -168,8 +166,8 @@ class StochasticDurationPredictor(nn.Layer):
for flow in self.flows:
z, logdet = flow(z, x_mask, g=x, inverse=inverse)
logdet_tot = logdet_tot + logdet
nll = (paddle.sum(0.5 * (math.log(2 * math.pi) +
(z**2)) * x_mask, [1, 2]) - logdet_tot)
tmp3 = 0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask
nll = (paddle.sum(tmp3, [1, 2]) - logdet_tot)
# (B,)
return nll + logq
else:
......
......@@ -371,8 +371,9 @@ class VITSGenerator(nn.Layer):
# (B, H, T_text)
s_p_sq_r = paddle.exp(-2 * logs_p)
# (B, 1, T_text)
tmp1 = -0.5 * math.log(2 * math.pi) - logs_p
neg_x_ent_1 = paddle.sum(
-0.5 * math.log(2 * math.pi) - logs_p,
tmp1,
[1],
keepdim=True, )
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
......@@ -384,8 +385,9 @@ class VITSGenerator(nn.Layer):
z_p.transpose([0, 2, 1]),
(m_p * s_p_sq_r), )
# (B, 1, T_text)
tmp2 = -0.5 * (m_p**2) * s_p_sq_r
neg_x_ent_4 = paddle.sum(
-0.5 * (m_p**2) * s_p_sq_r,
tmp2,
[1],
keepdim=True, )
# (B, T_feats, T_text)
......@@ -403,7 +405,6 @@ class VITSGenerator(nn.Layer):
w = attn.sum(2)
dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
dur_nll = dur_nll / paddle.sum(x_mask)
# expand the length to match with the feature sequence
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
m_p = paddle.matmul(attn.squeeze(1),
......@@ -511,8 +512,9 @@ class VITSGenerator(nn.Layer):
# (B, H, T_text)
s_p_sq_r = paddle.exp(-2 * logs_p)
# (B, 1, T_text)
tmp3 = -0.5 * math.log(2 * math.pi) - logs_p
neg_x_ent_1 = paddle.sum(
-0.5 * math.log(2 * math.pi) - logs_p,
tmp3,
[1],
keepdim=True, )
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
......@@ -524,8 +526,9 @@ class VITSGenerator(nn.Layer):
z_p.transpose([0, 2, 1]),
(m_p * s_p_sq_r), )
# (B, 1, T_text)
tmp4 = -0.5 * (m_p**2) * s_p_sq_r
neg_x_ent_4 = paddle.sum(
-0.5 * (m_p**2) * s_p_sq_r,
tmp4,
[1],
keepdim=True, )
# (B, T_feats, T_text)
......
......@@ -61,8 +61,12 @@ def piecewise_rational_quadratic_transform(
def mask_preprocess(x, mask):
# bins.dtype = int32
B, C, T, bins = paddle.shape(x)
new_x = paddle.zeros([mask.sum(), bins])
mask_int = paddle.cast(mask, dtype='int64')
# paddle.sum 输入是 int32 或 bool 的时候,输出是 int64
# paddle.zeros (fill_constant) 的 shape 会被强制转成 int32 类型
new_x = paddle.zeros([paddle.sum(mask_int), bins])
for i in range(bins):
new_x[:, i] = x[:, :, :, i][mask]
return new_x
......@@ -240,4 +244,7 @@ def rational_quadratic_spline(
def _searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return paddle.sum(inputs[..., None] >= bin_locations, axis=-1) - 1
mask = inputs[..., None] >= bin_locations
mask_int = paddle.cast(mask, 'int64')
out = paddle.sum(mask_int, axis=-1) - 1
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册