提交 d77557b1 编写于 作者: C chenfeiyu

fix for examples/wavenet: remove weight norm after loading model

上级 0b96eeae
...@@ -21,6 +21,7 @@ from tensorboardX import SummaryWriter ...@@ -21,6 +21,7 @@ from tensorboardX import SummaryWriter
from paddle import fluid from paddle import fluid
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
from parakeet.modules.weight_norm import WeightNormWrapper
from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
from parakeet.utils.layer_tools import summary from parakeet.utils.layer_tools import summary
...@@ -114,6 +115,10 @@ if __name__ == "__main__": ...@@ -114,6 +115,10 @@ if __name__ == "__main__":
print("Loading from {}.pdparams".format(args.checkpoint)) print("Loading from {}.pdparams".format(args.checkpoint))
model.set_dict(model_dict) model.set_dict(model_dict)
for layer in model.sublayers():
if isinstance(layer, WeightNormWrapper):
layer.remove_weight_norm()
train_loader = fluid.io.DataLoader.from_generator( train_loader = fluid.io.DataLoader.from_generator(
capacity=10, return_list=True) capacity=10, return_list=True)
train_loader.set_batch_generator(train_cargo, place) train_loader.set_batch_generator(train_cargo, place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册