Unverified commit a55fd2e5, authored by 艾梦, committed by GitHub

[TTS]Fix diffusion wavenet denoiser final conv init param (#2868)

* add diffusion module for training diffsinger

* add wavenet denoiser final conv initializer
Parent 896da6dc
@@ -40,7 +40,7 @@ class WaveNetDenoiser(nn.Layer):
         layers (int, optional):
             Number of residual blocks inside, by default 20
         stacks (int, optional):
-            The number of groups to split the residual blocks into, by default 4
+            The number of groups to split the residual blocks into, by default 5
             Within each group, the dilation of the residual block grows exponentially.
         residual_channels (int, optional):
             Residual channel of the residual blocks, by default 256
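The stacks bump from 4 to 5 changes how the 20 residual blocks are grouped. As the docstring above notes, the dilation grows exponentially within each group and resets between groups; a minimal sketch of that standard WaveNet scheme (the helper below is illustrative, not the repository's code):

# Illustrative helper: the usual WaveNet dilation grouping, where
# `layers` residual blocks are split into `stacks` equal groups and
# the dilation doubles within each group.
def wavenet_dilations(layers: int = 20, stacks: int = 5) -> list:
    layers_per_stack = layers // stacks
    return [2 ** (i % layers_per_stack) for i in range(layers)]

print(wavenet_dilations(20, 5))  # five groups of [1, 2, 4, 8]
print(wavenet_dilations(20, 4))  # four groups of [1, 2, 4, 8, 16]

With stacks=5 the maximum dilation per group shrinks from 16 to 8, but the 1-2-4-8 pattern repeats one more time.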
@@ -64,7 +64,7 @@ class WaveNetDenoiser(nn.Layer):
             out_channels: int=80,
             kernel_size: int=3,
             layers: int=20,
-            stacks: int=4,
+            stacks: int=5,
             residual_channels: int=256,
             gate_channels: int=512,
             skip_channels: int=256,
@@ -72,7 +72,7 @@ class WaveNetDenoiser(nn.Layer):
             dropout: float=0.,
             bias: bool=True,
             use_weight_norm: bool=False,
-            init_type: str="kaiming_uniform", ):
+            init_type: str="kaiming_normal", ):
         super().__init__()
         # initialize parameters
@@ -118,18 +118,15 @@ class WaveNetDenoiser(nn.Layer):
                 bias=bias)
             self.conv_layers.append(conv)
 
+        final_conv = nn.Conv1D(skip_channels, out_channels, 1, bias_attr=True)
+        nn.initializer.Constant(0.0)(final_conv.weight)
         self.last_conv_layers = nn.Sequential(nn.ReLU(),
                                               nn.Conv1D(
                                                   skip_channels,
                                                   skip_channels,
                                                   1,
                                                   bias_attr=True),
-                                              nn.ReLU(),
-                                              nn.Conv1D(
-                                                  skip_channels,
-                                                  out_channels,
-                                                  1,
-                                                  bias_attr=True))
+                                              nn.ReLU(), final_conv)
 
         if use_weight_norm:
             self.apply_weight_norm()
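This is the change the commit title refers to: the final 1x1 convolution is pulled out into final_conv so its weight can be zero-initialized before it is appended to the Sequential. Zero-initializing the last projection is a common diffusion-model trick: the freshly constructed denoiser predicts zeros, so early training is not dominated by a randomly initialized output layer. A standalone sketch of the same pattern with toy sizes (the bias is zeroed too, only so the demo's zero-output claim holds regardless of Paddle's default bias init):

import paddle
from paddle import nn

skip_channels, out_channels = 256, 80

# Build the final projection separately so it can be re-initialized in place.
final_conv = nn.Conv1D(skip_channels, out_channels, 1, bias_attr=True)
nn.initializer.Constant(0.0)(final_conv.weight)
nn.initializer.Constant(0.0)(final_conv.bias)

last_conv_layers = nn.Sequential(
    nn.ReLU(),
    nn.Conv1D(skip_channels, skip_channels, 1, bias_attr=True),
    nn.ReLU(),
    final_conv)

x = paddle.randn([1, skip_channels, 10])
print(float(last_conv_layers(x).abs().max()))  # 0.0: the head starts at zero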
@@ -200,10 +197,6 @@ class GaussianDiffusion(nn.Layer):
     Args:
         denoiser (Layer, optional):
             The model used for denoising noises.
-            In fact, the denoiser model performs the operation
-            of producing a output with more noises from the noisy input.
-            Then we use the diffusion algorithm to calculate
-            the input with the output to get the denoised result.
         num_train_timesteps (int, optional):
             The number of timesteps between the noise and the real during training, by default 1000.
         beta_start (float, optional):
@@ -233,7 +226,8 @@ class GaussianDiffusion(nn.Layer):
         >>> def callback(index, timestep, num_timesteps, sample):
         >>>     nonlocal pbar
         >>>     if pbar is None:
-        >>>         pbar = tqdm(total=num_timesteps-index)
+        >>>         pbar = tqdm(total=num_timesteps)
+        >>>         pbar.update(index)
         >>>     pbar.update()
         >>>
         >>>     return callback
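The callback fix makes the progress bar honest when denoising starts partway through the schedule: the bar's total is now the full num_timesteps, and the steps already elapsed at the first call are credited up front via pbar.update(index). A self-contained sketch of the corrected pattern, with a toy loop standing in for diffusion.inference (the four-argument callback signature follows the doctest above):

from tqdm import tqdm

def create_progress_callback():
    pbar = None

    def callback(index, timestep, num_timesteps, sample):
        nonlocal pbar
        if pbar is None:
            # Size the bar to the whole schedule, then credit any
            # steps that had already elapsed before the first call.
            pbar = tqdm(total=num_timesteps)
            pbar.update(index)
        pbar.update()  # one tick per denoising step

    return callback

# Toy driver: 100 denoising steps, decreasing timestep, no sample payload.
cb = create_progress_callback()
for i in range(100):
    cb(i, 100 - i, 100, None)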
@@ -247,7 +241,7 @@ class GaussianDiffusion(nn.Layer):
         >>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step)
         >>> with paddle.no_grad():
         >>>     sample = diffusion.inference(
-        >>>         paddle.randn(x.shape), c, x,
+        >>>         paddle.randn(x.shape), c, ref_x=x_in,
         >>>         num_inference_steps=infer_steps,
         >>>         scheduler_type=scheduler_type,
         >>>         callback=create_progress_callback())
@@ -262,7 +256,7 @@ class GaussianDiffusion(nn.Layer):
         >>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step)
         >>> with paddle.no_grad():
         >>>     sample = diffusion.inference(
-        >>>         paddle.randn(x.shape), c, x_in,
+        >>>         paddle.randn(x.shape), c, ref_x=x_in,
         >>>         num_inference_steps=infer_steps,
         >>>         scheduler_type=scheduler_type,
         >>>         callback=create_progress_callback())
@@ -277,11 +271,11 @@ class GaussianDiffusion(nn.Layer):
         >>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step)
         >>> with paddle.no_grad():
         >>>     sample = diffusion.inference(
-        >>>         paddle.randn(x.shape), c, None,
+        >>>         paddle.randn(x.shape), c, ref_x=x_in,
         >>>         num_inference_steps=infer_steps,
         >>>         scheduler_type=scheduler_type,
         >>>         callback=create_progress_callback())
-        100%|█████| 25/25 [00:01<00:00, 19.75it/s]
+        100%|█████| 34/34 [00:01<00:00, 19.75it/s]
         >>>
         >>> # ds=1000, K_step=100, scheduler=pndm, infer_step=50, from aux fs2 mel output
         >>> ds = 1000
@@ -292,11 +286,11 @@ class GaussianDiffusion(nn.Layer):
         >>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step)
         >>> with paddle.no_grad():
         >>>     sample = diffusion.inference(
-        >>>         paddle.randn(x.shape), c, x,
+        >>>         paddle.randn(x.shape), c, ref_x=x_in,
         >>>         num_inference_steps=infer_steps,
         >>>         scheduler_type=scheduler_type,
         >>>         callback=create_progress_callback())
-        100%|█████| 5/5 [00:00<00:00, 23.80it/s]
+        100%|█████| 14/14 [00:00<00:00, 23.80it/s]
     """