提交 d77557b1 编写于 作者: C chenfeiyu

fix for examples/wavenet: remove weight norm after loading model

上级 0b96eeae
......@@ -21,6 +21,7 @@ from tensorboardX import SummaryWriter
from paddle import fluid
import paddle.fluid.dygraph as dg
from parakeet.modules.weight_norm import WeightNormWrapper
from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
from parakeet.utils.layer_tools import summary
......@@ -114,6 +115,10 @@ if __name__ == "__main__":
print("Loading from {}.pdparams".format(args.checkpoint))
model.set_dict(model_dict)
for layer in model.sublayers():
if isinstance(layer, WeightNormWrapper):
layer.remove_weight_norm()
train_loader = fluid.io.DataLoader.from_generator(
capacity=10, return_list=True)
train_loader.set_batch_generator(train_cargo, place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册