diff --git a/examples/deepvoice3/README.md b/examples/deepvoice3/README.md index 4f939e12715010304f22a99a7f966adc4e73a12e..3e4b0b385a201548dbfc5917cffda77f3c3b2cd0 100644 --- a/examples/deepvoice3/README.md +++ b/examples/deepvoice3/README.md @@ -87,7 +87,7 @@ runs/Jul07_09-39-34_instance-mqcyj27y-4/ ... ``` -Since e use waveflow to synthesize audio while training, so download the trained waveflow model and extract it in current directory before training. +Since we use waveflow to synthesize audio while training, download the trained waveflow model and extract it in the current directory before training. ```bash wget https://paddlespeech.bj.bcebos.com/Parakeet/waveflow_res128_ljspeech_ckpt_1.0.zip diff --git a/examples/deepvoice3/configs/ljspeech.yaml b/examples/deepvoice3/configs/ljspeech.yaml index cbcaa9c5c834561fa4bae41c1f7fb5ed14f67738..1e8ec7b2d77bcf2d994b8ad976b305fc22bcf70a 100644 --- a/examples/deepvoice3/configs/ljspeech.yaml +++ b/examples/deepvoice3/configs/ljspeech.yaml @@ -39,6 +39,7 @@ clip_value: 5.0 clip_norm: 100.0 # training: +max_iteration: 1000000 batch_size: 16 report_interval: 10000 save_interval: 10000 diff --git a/examples/deepvoice3/data.py b/examples/deepvoice3/data.py index 3e30c95e71207be4ca6b775e0552d10ac933b2ed..984f9639a05196c5fb6a9c6234d038e0ed5e380d 100644 --- a/examples/deepvoice3/data.py +++ b/examples/deepvoice3/data.py @@ -62,10 +62,8 @@ class DataCollector(object): for example in examples: text, spec, mel, _ = example text_seqs.append(en.text_to_sequence(text, self.p_pronunciation)) - # if max_frames - mel.shape[0] < 0: - # import pdb; pdb.set_trace() - specs.append(np.pad(spec, [(0, max_frames - spec.shape[0]), (0, 0)])) - mels.append(np.pad(mel, [(0, max_frames - mel.shape[0]), (0, 0)])) + specs.append(np.pad(spec, [(0, max_frames - spec.shape[0]), (0, 0)], mode="constant")) + mels.append(np.pad(mel, [(0, max_frames - mel.shape[0]), (0, 0)], mode="constant")) specs = np.stack(specs) mels = np.stack(mels) diff --git 
a/examples/deepvoice3/train.py b/examples/deepvoice3/train.py index 07f5c9435cab854bfe1eb386371641d166d021e5..8e629c360d4af0f20611b2ab17890cbd5e8c5fbe 100644 --- a/examples/deepvoice3/train.py +++ b/examples/deepvoice3/train.py @@ -81,7 +81,7 @@ def train(args, config): optim = create_optimizer(model, config) global global_step - max_iteration = 1000000 + max_iteration = config["max_iteration"] iterator = iter(tqdm.tqdm(train_loader)) while global_step <= max_iteration: