提交 173693f4 编写于 作者: C chenfeiyu 提交者: liuyibing01

fix missing imports, fix ljspeech.yaml config key: encoder_channels

上级 5beef513
...@@ -21,6 +21,7 @@ transform: ...@@ -21,6 +21,7 @@ transform:
# db scale # db scale
min_level_db: -100 min_level_db: -100
ref_level_db: 20 ref_level_db: 20
clip_norm: true
loss: loss:
...@@ -48,20 +49,20 @@ model: ...@@ -48,20 +49,20 @@ model:
embedding_weight_std: 0.1 embedding_weight_std: 0.1
freeze_embedding: false freeze_embedding: false
padding_idx: 0 padding_idx: 0
encoder_channels: 256 encoder_channels: 512
# decoder # decoder
query_position_rate: 1.0 query_position_rate: 1.0
key_position_rate: 1.29 key_position_rate: 1.29
trainable_positional_encodings: false trainable_positional_encodings: false
kernel_size: 3 kernel_size: 3
decoder_channels: 512 decoder_channels: 256
downsample_factor: 4 downsample_factor: 4
outputs_per_step: 1 outputs_per_step: 1
# attention # attention
key_position_rate: true key_projection: true
value_position_rate: true value_projection: true
force_monotonic_attention: true force_monotonic_attention: true
window_backward: -1 window_backward: -1
window_ahead: 3 window_ahead: 3
...@@ -88,16 +89,3 @@ train: ...@@ -88,16 +89,3 @@ train:
snap_interval: 1000 snap_interval: 1000
eval_interval: 10000 eval_interval: 10000
save_interval: 10000 save_interval: 10000
import os import os
import argparse import argparse
import ruamel.yamls import ruamel.yaml
import numpy as np import numpy as np
import soundfile as sf import soundfile as sf
...@@ -22,6 +22,11 @@ if __name__ == "__main__": ...@@ -22,6 +22,11 @@ if __name__ == "__main__":
parser.add_argument("checkpoint", type=str, help="checkpoint to load.") parser.add_argument("checkpoint", type=str, help="checkpoint to load.")
parser.add_argument("text", type=str, help="text file to synthesize") parser.add_argument("text", type=str, help="text file to synthesize")
parser.add_argument("output_path", type=str, help="path to save results") parser.add_argument("output_path", type=str, help="path to save results")
parser.add_argument("-g",
"--device",
type=int,
default=-1,
help="device to use")
args = parser.parse_args() args = parser.parse_args()
with open(args.config, 'rt') as f: with open(args.config, 'rt') as f:
...@@ -67,7 +72,7 @@ if __name__ == "__main__": ...@@ -67,7 +72,7 @@ if __name__ == "__main__":
use_memory_mask = model_config["use_memory_mask"] use_memory_mask = model_config["use_memory_mask"]
query_position_rate = model_config["query_position_rate"] query_position_rate = model_config["query_position_rate"]
key_position_rate = model_config["key_position_rate"] key_position_rate = model_config["key_position_rate"]
window_behind = model_config["window_behind"] window_backward = model_config["window_backward"]
window_ahead = model_config["window_ahead"] window_ahead = model_config["window_ahead"]
key_projection = model_config["key_projection"] key_projection = model_config["key_projection"]
value_projection = model_config["value_projection"] value_projection = model_config["value_projection"]
...@@ -76,11 +81,12 @@ if __name__ == "__main__": ...@@ -76,11 +81,12 @@ if __name__ == "__main__":
freeze_embedding, filter_size, encoder_channels, freeze_embedding, filter_size, encoder_channels,
n_mels, decoder_channels, r, n_mels, decoder_channels, r,
trainable_positional_encodings, use_memory_mask, trainable_positional_encodings, use_memory_mask,
query_position_rate, key_position_rate, window_behind, query_position_rate, key_position_rate,
window_ahead, key_projection, value_projection, window_backward, window_ahead, key_projection,
downsample_factor, linear_dim, use_decoder_states, value_projection, downsample_factor, linear_dim,
converter_channels, dropout) use_decoder_states, converter_channels, dropout)
summary(dv3)
state, _ = dg.load_dygraph(args.checkpoint) state, _ = dg.load_dygraph(args.checkpoint)
dv3.set_dict(state) dv3.set_dict(state)
......
import os import os
import argparse import argparse
import ruamel.yamls import ruamel.yaml
import numpy as np import numpy as np
from matplotlib import cm from matplotlib import cm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -15,10 +15,9 @@ import paddle.fluid.layers as F ...@@ -15,10 +15,9 @@ import paddle.fluid.layers as F
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
from parakeet.g2p import en from parakeet.g2p import en
from parakeet.models.deepvoice3.encoder import ConvSpec
from parakeet.data import FilterDataset, TransformDataset, FilterDataset from parakeet.data import FilterDataset, TransformDataset, FilterDataset
from parakeet.data import DataCargo, PartialyRandomizedSimilarTimeLengthSampler, SequentialSampler from parakeet.data import DataCargo, PartialyRandomizedSimilarTimeLengthSampler, SequentialSampler
from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3 from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3, ConvSpec
from parakeet.models.deepvoice3.loss import TTSLoss from parakeet.models.deepvoice3.loss import TTSLoss
from parakeet.utils.layer_tools import summary from parakeet.utils.layer_tools import summary
...@@ -128,7 +127,7 @@ if __name__ == "__main__": ...@@ -128,7 +127,7 @@ if __name__ == "__main__":
use_memory_mask = model_config["use_memory_mask"] use_memory_mask = model_config["use_memory_mask"]
query_position_rate = model_config["query_position_rate"] query_position_rate = model_config["query_position_rate"]
key_position_rate = model_config["key_position_rate"] key_position_rate = model_config["key_position_rate"]
window_behind = model_config["window_behind"] window_backward = model_config["window_backward"]
window_ahead = model_config["window_ahead"] window_ahead = model_config["window_ahead"]
key_projection = model_config["key_projection"] key_projection = model_config["key_projection"]
value_projection = model_config["value_projection"] value_projection = model_config["value_projection"]
...@@ -137,10 +136,10 @@ if __name__ == "__main__": ...@@ -137,10 +136,10 @@ if __name__ == "__main__":
freeze_embedding, filter_size, encoder_channels, freeze_embedding, filter_size, encoder_channels,
n_mels, decoder_channels, r, n_mels, decoder_channels, r,
trainable_positional_encodings, use_memory_mask, trainable_positional_encodings, use_memory_mask,
query_position_rate, key_position_rate, window_behind, query_position_rate, key_position_rate,
window_ahead, key_projection, value_projection, window_backward, window_ahead, key_projection,
downsample_factor, linear_dim, use_decoder_states, value_projection, downsample_factor, linear_dim,
converter_channels, dropout) use_decoder_states, converter_channels, dropout)
# =========================loss========================= # =========================loss=========================
loss_config = config["loss"] loss_config = config["loss"]
......
from .dataset import *
from .datacargo import *
from .sampler import *
from .batch import *
from parakeet.models.deepvoice3.encoder import Encoder from parakeet.models.deepvoice3.encoder import Encoder, ConvSpec
from parakeet.models.deepvoice3.decoder import Decoder from parakeet.models.deepvoice3.decoder import Decoder, WindowRange
from parakeet.models.deepvoice3.converter import Converter from parakeet.models.deepvoice3.converter import Converter
from parakeet.models.deepvoice3.model import DeepVoice3 from parakeet.models.deepvoice3.model import DeepVoice3
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册