Commit be70b41f authored by liuyibing01

Merge branch 'master' into 'master'

fixes for wavenet and modules

See merge request !47
...@@ -101,6 +101,8 @@ if __name__ == "__main__": ...@@ -101,6 +101,8 @@ if __name__ == "__main__":
state, _ = dg.load_dygraph(args.checkpoint) state, _ = dg.load_dygraph(args.checkpoint)
dv3.set_dict(state) dv3.set_dict(state)
# WARNING: don't forget to remove weight norm to re-compute each wrapped layer's weight
# removing weight norm also speeds up computation
for layer in dv3.sublayers(): for layer in dv3.sublayers():
if isinstance(layer, WeightNormWrapper): if isinstance(layer, WeightNormWrapper):
layer.remove_weight_norm() layer.remove_weight_norm()
......
...@@ -21,6 +21,7 @@ from tensorboardX import SummaryWriter ...@@ -21,6 +21,7 @@ from tensorboardX import SummaryWriter
from paddle import fluid from paddle import fluid
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
from parakeet.modules.weight_norm import WeightNormWrapper
from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
from parakeet.utils.layer_tools import summary from parakeet.utils.layer_tools import summary
...@@ -114,6 +115,12 @@ if __name__ == "__main__": ...@@ -114,6 +115,12 @@ if __name__ == "__main__":
print("Loading from {}.pdparams".format(args.checkpoint)) print("Loading from {}.pdparams".format(args.checkpoint))
model.set_dict(model_dict) model.set_dict(model_dict)
# WARNING: don't forget to remove weight norm to re-compute each wrapped layer's weight
# removing weight norm also speeds up computation
for layer in model.sublayers():
if isinstance(layer, WeightNormWrapper):
layer.remove_weight_norm()
train_loader = fluid.io.DataLoader.from_generator( train_loader = fluid.io.DataLoader.from_generator(
capacity=10, return_list=True) capacity=10, return_list=True)
train_loader.set_batch_generator(train_cargo, place) train_loader.set_batch_generator(train_cargo, place)
......
...@@ -313,6 +313,7 @@ class WaveNet(dg.Layer): ...@@ -313,6 +313,7 @@ class WaveNet(dg.Layer):
""" """
# Causal Conv # Causal Conv
if self.loss_type == "softmax": if self.loss_type == "softmax":
x = F.clip(x, min=-1., max=0.99999)
x = quantize(x, self.output_dim) x = quantize(x, self.output_dim)
x = self.embed(x) # (B, T, C), T=1 x = self.embed(x) # (B, T, C), T=1
else: else:
......
...@@ -86,7 +86,7 @@ class Conv1D(dg.Conv2D): ...@@ -86,7 +86,7 @@ class Conv1D(dg.Conv2D):
stride=1, stride=1,
padding=0, padding=0,
dilation=1, dilation=1,
groups=None, groups=1,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
use_cudnn=True, use_cudnn=True,
...@@ -128,7 +128,7 @@ class Conv1DTranspose(dg.Conv2DTranspose): ...@@ -128,7 +128,7 @@ class Conv1DTranspose(dg.Conv2DTranspose):
padding=0, padding=0,
stride=1, stride=1,
dilation=1, dilation=1,
groups=None, groups=1,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
use_cudnn=True, use_cudnn=True,
...@@ -179,7 +179,7 @@ class Conv1DCell(Conv1D): ...@@ -179,7 +179,7 @@ class Conv1DCell(Conv1D):
filter_size, filter_size,
dilation=1, dilation=1,
causal=False, causal=False,
groups=None, groups=1,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
use_cudnn=True, use_cudnn=True,
...@@ -225,6 +225,12 @@ class Conv1DCell(Conv1D): ...@@ -225,6 +225,12 @@ class Conv1DCell(Conv1D):
def start_sequence(self): def start_sequence(self):
"""Prepare the Conv1DCell to generate a new sequence, this method should be called before calling add_input multiple times. """Prepare the Conv1DCell to generate a new sequence, this method should be called before calling add_input multiple times.
WARNING:
This method accesses `self.weight` directly. If a `Conv1DCell` object is wrapped in a `WeightNormWrapper`, make sure this method is called only after the `WeightNormWrapper`'s hook is called.
        `WeightNormWrapper` removes the wrapped layer's `weight`, and has a `weight_v` and `weight_g` to re-compute the wrapped layer's weight as $weight = weight_g * weight_v / ||weight_v||$. (Recomputing the `weight` is done in a hook that runs before calling the wrapped layer's `forward` method.)
Whenever a `WeightNormWrapper`'s `forward` method is called, the wrapped layer's weight is updated. But when loading from a checkpoint, `weight_v` and `weight_g` are updated but the wrapped layer's weight is not, since it is no longer a `Parameter`. You should manually call `remove_weight_norm` or `hook` to re-compute the wrapped layer's weight before calling this method if you don't call `forward` first.
So when loading a model which uses `Conv1DCell` objects wrapped in `WeightNormWrapper`s, remember to call `remove_weight_norm` for all `WeightNormWrapper`s before synthesizing. Also, removing weight norm speeds up computation.
""" """
if not self.causal: if not self.causal:
raise ValueError( raise ValueError(
......
...@@ -151,7 +151,7 @@ def Conv1D(num_channels, ...@@ -151,7 +151,7 @@ def Conv1D(num_channels,
stride=1, stride=1,
padding=0, padding=0,
dilation=1, dilation=1,
groups=None, groups=1,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
use_cudnn=True, use_cudnn=True,
...@@ -170,7 +170,7 @@ def Conv1DTranspose(num_channels, ...@@ -170,7 +170,7 @@ def Conv1DTranspose(num_channels,
padding=0, padding=0,
stride=1, stride=1,
dilation=1, dilation=1,
groups=None, groups=1,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
use_cudnn=True, use_cudnn=True,
...@@ -188,7 +188,7 @@ def Conv1DCell(num_channels, ...@@ -188,7 +188,7 @@ def Conv1DCell(num_channels,
filter_size, filter_size,
dilation=1, dilation=1,
causal=False, causal=False,
groups=None, groups=1,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
use_cudnn=True, use_cudnn=True,
...@@ -207,7 +207,7 @@ def Conv2D(num_channels, ...@@ -207,7 +207,7 @@ def Conv2D(num_channels,
stride=1, stride=1,
padding=0, padding=0,
dilation=1, dilation=1,
groups=None, groups=1,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
use_cudnn=True, use_cudnn=True,
...@@ -228,7 +228,7 @@ def Conv2DTranspose(num_channels, ...@@ -228,7 +228,7 @@ def Conv2DTranspose(num_channels,
padding=0, padding=0,
stride=1, stride=1,
dilation=1, dilation=1,
groups=None, groups=1,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
use_cudnn=True, use_cudnn=True,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册