From 271112ca69a8a73500c4fec0f83cda53672f620b Mon Sep 17 00:00:00 2001 From: TianYuan Date: Mon, 13 Mar 2023 21:12:45 +0800 Subject: [PATCH] fix vits reduce_sum's input/output dtype, test=tts (#3028) --- paddlespeech/t2s/models/vits/duration_predictor.py | 14 ++++++-------- paddlespeech/t2s/models/vits/generator.py | 13 ++++++++----- paddlespeech/t2s/models/vits/transform.py | 11 +++++++++-- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/paddlespeech/t2s/models/vits/duration_predictor.py b/paddlespeech/t2s/models/vits/duration_predictor.py index b0bb68d0..12177fbc 100644 --- a/paddlespeech/t2s/models/vits/duration_predictor.py +++ b/paddlespeech/t2s/models/vits/duration_predictor.py @@ -155,12 +155,10 @@ class StochasticDurationPredictor(nn.Layer): z_u, z1 = paddle.split(z_q, [1, 1], 1) u = F.sigmoid(z_u) * x_mask z0 = (w - u) * x_mask - logdet_tot_q += paddle.sum( - (F.log_sigmoid(z_u) + F.log_sigmoid(-z_u)) * x_mask, [1, 2]) - logq = (paddle.sum(-0.5 * - (math.log(2 * math.pi) + - (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q) - + tmp1 = (F.log_sigmoid(z_u) + F.log_sigmoid(-z_u)) * x_mask + logdet_tot_q += paddle.sum(tmp1, [1, 2]) + tmp2 = -0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask + logq = (paddle.sum(tmp2, [1, 2]) - logdet_tot_q) logdet_tot = 0 z0, logdet = self.log_flow(z0, x_mask) logdet_tot += logdet @@ -168,8 +166,8 @@ class StochasticDurationPredictor(nn.Layer): for flow in self.flows: z, logdet = flow(z, x_mask, g=x, inverse=inverse) logdet_tot = logdet_tot + logdet - nll = (paddle.sum(0.5 * (math.log(2 * math.pi) + - (z**2)) * x_mask, [1, 2]) - logdet_tot) + tmp3 = 0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask + nll = (paddle.sum(tmp3, [1, 2]) - logdet_tot) # (B,) return nll + logq else: diff --git a/paddlespeech/t2s/models/vits/generator.py b/paddlespeech/t2s/models/vits/generator.py index fbd2d665..44bd7898 100644 --- a/paddlespeech/t2s/models/vits/generator.py +++ b/paddlespeech/t2s/models/vits/generator.py @@ -371,8 +371,9 @@ class VITSGenerator(nn.Layer): # (B, H, T_text) s_p_sq_r = paddle.exp(-2 * logs_p) # (B, 1, T_text) + tmp1 = -0.5 * math.log(2 * math.pi) - logs_p neg_x_ent_1 = paddle.sum( - -0.5 * math.log(2 * math.pi) - logs_p, + tmp1, [1], keepdim=True, ) # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) @@ -384,8 +385,9 @@ class VITSGenerator(nn.Layer): z_p.transpose([0, 2, 1]), (m_p * s_p_sq_r), ) # (B, 1, T_text) + tmp2 = -0.5 * (m_p**2) * s_p_sq_r neg_x_ent_4 = paddle.sum( - -0.5 * (m_p**2) * s_p_sq_r, + tmp2, [1], keepdim=True, ) # (B, T_feats, T_text) @@ -403,7 +405,6 @@ class VITSGenerator(nn.Layer): w = attn.sum(2) dur_nll = self.duration_predictor(x, x_mask, w=w, g=g) dur_nll = dur_nll / paddle.sum(x_mask) - # expand the length to match with the feature sequence # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) m_p = paddle.matmul(attn.squeeze(1), @@ -511,8 +512,9 @@ class VITSGenerator(nn.Layer): # (B, H, T_text) s_p_sq_r = paddle.exp(-2 * logs_p) # (B, 1, T_text) + tmp3 = -0.5 * math.log(2 * math.pi) - logs_p neg_x_ent_1 = paddle.sum( - -0.5 * math.log(2 * math.pi) - logs_p, + tmp3, [1], keepdim=True, ) # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) @@ -524,8 +526,9 @@ class VITSGenerator(nn.Layer): z_p.transpose([0, 2, 1]), (m_p * s_p_sq_r), ) # (B, 1, T_text) + tmp4 = -0.5 * (m_p**2) * s_p_sq_r neg_x_ent_4 = paddle.sum( - -0.5 * (m_p**2) * s_p_sq_r, + tmp4, [1], keepdim=True, ) # (B, T_feats, T_text) diff --git a/paddlespeech/t2s/models/vits/transform.py b/paddlespeech/t2s/models/vits/transform.py index 61bd5ee2..0edc1d09 100644 --- a/paddlespeech/t2s/models/vits/transform.py +++ b/paddlespeech/t2s/models/vits/transform.py @@ -61,8 +61,12 @@ def piecewise_rational_quadratic_transform( def mask_preprocess(x, mask): + # bins.dtype = int32 B, C, T, bins = paddle.shape(x) - new_x = paddle.zeros([mask.sum(), bins]) + mask_int = paddle.cast(mask, dtype='int64') + # paddle.sum 输入是 int32 或 bool 的时候,输出是 int64 + # paddle.zeros (fill_constant) 的 shape 会被强制转成 int32 类型 + new_x = paddle.zeros([paddle.sum(mask_int), bins]) for i in range(bins): new_x[:, i] = x[:, :, :, i][mask] return new_x @@ -240,4 +244,7 @@ def rational_quadratic_spline( def _searchsorted(bin_locations, inputs, eps=1e-6): bin_locations[..., -1] += eps - return paddle.sum(inputs[..., None] >= bin_locations, axis=-1) - 1 + mask = inputs[..., None] >= bin_locations + mask_int = paddle.cast(mask, 'int64') + out = paddle.sum(mask_int, axis=-1) - 1 + return out -- GitLab