diff --git a/README.md b/README.md index b5f61fd67bf75ac8699073e6932c8a43b851da5d..aacf6c31d86d6f8d1b7561cea68bda827b54af04 100644 --- a/README.md +++ b/README.md @@ -76,29 +76,61 @@ Entries to the introduction, and the launch of training and synthsis for differe Parakeet also releases some well-trained parameters for the example models, which can be accessed in the following tables. Each column of these tables lists resources for one model, including the url link to the pre-trained model, the dataset that the model is trained on and the total training steps, and several synthesized audio samples based on the pre-trained model. -- Vocoders +#### Vocoders + +We provide the model checkpoints of WaveFlow with 64 and 128 residual channels, ClariNet and WaveNet.
- - + + + + + + + + + + + + + + + + + +
- WaveFlow + WaveFlow (res. channels 64) - ClariNet + WaveFlow (res. channels 128)
LJSpeech, 2MLJSpeech, 500KLJSpeech, 3020 KLJSpeech
+ +
+ +
+ +
+ +
+ + +
To be added soon
+ ClariNet + + WaveNet +
LJSpeech, 500 KLJSpeech, 2450 K

@@ -111,15 +143,57 @@ Parakeet also releases some well-trained parameters for the example models, whic
+ +
+ +
+ +
+ +
+ + +
-    **Note:** The input mel spectrogams are from validation dataset, which are not seen during training. +      **Note:** The input mel spectrogams are from validation dataset, which are not seen during training. -- TTS models +#### TTS models + +
+ + + + + + + + + + + + + + + + + + +
+ Deep Voice 3 + + Transformer TTS +
LJSpeech LJSpeech
+ To be added soon + + To be added soon +
+
Click each link to download, then you can get the compressed package which contains the pre-trained model and the `yaml` config describing how to train the model. diff --git a/examples/clarinet/README.md b/examples/clarinet/README.md index 9b79897d49237c7f57b70ec0ade908ab65f8441c..ca74b2d90a6b784ae3d99e96f7b992ea919fb433 100644 --- a/examples/clarinet/README.md +++ b/examples/clarinet/README.md @@ -22,49 +22,71 @@ tar xjvf LJSpeech-1.1.tar.bz2 └── utils.py utility functions ``` -## Train +## Saving & Loading +`train.py` and `synthesis.py` have 3 arguments in common, `--checkpooint`, `iteration` and `output`. -Train the model using train.py, follow the usage displayed by `python train.py --help`. +1. `output` is the directory for saving results. +During training, checkpoints are saved in `checkpoints/` in `output` and tensorboard log is save in `log/` in `output`. Other possible outputs are saved in `states/` in `outuput`. +During synthesizing, audio files and other possible outputs are save in `synthesis/` in `output`. +So after training and synthesizing with the same output directory, the file structure of the output directory looks like this. ```text -usage: train.py [-h] [--config CONFIG] [--device DEVICE] [--output OUTPUT] - [--data DATA] [--checkpoint CHECKPOINT] [--wavenet WAVENET] +├── checkpoints/ # checkpoint directory (including *.pdparams, *.pdopt and a text file `checkpoint` that records the latest checkpoint) +├── states/ # audio files generated at validation and other possible outputs +├── log/ # tensorboard log +└── synthesis/ # synthesized audio files and other possible outputs +``` -train a ClariNet model with LJspeech and a trained WaveNet model. +2. `--checkpoint` and `--iteration` for loading from existing checkpoint. Loading existing checkpoiont follows the following rule: +If `--checkpoint` is provided, the checkpoint specified by `--checkpoint` is loaded. +If `--checkpoint` is not provided, we try to load the model specified by `--iteration` from the checkpoint directory. If `--iteration` is not provided, we try to load the latested checkpoint from checkpoint directory. -optional arguments: - -h, --help show this help message and exit - --config CONFIG path of the config file. - --device DEVICE device to use. - --output OUTPUT path to save student. - --data DATA path of LJspeech dataset. - --checkpoint CHECKPOINT checkpoint to load from. - --wavenet WAVENET wavenet checkpoint to use. -``` +## Train -- `--config` is the configuration file to use. The provided configurations can be used directly. And you can change some values in the configuration file and train the model with a different config. -- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt). -- `--checkpoint` is the path of the checkpoint. If it is provided, the model would load the checkpoint before trainig. -- `--output` is the directory to save results, all result are saved in this directory. The structure of the output directory is shown below. +Train the model using train.py, follow the usage displayed by `python train.py --help`. ```text -├── checkpoints # checkpoint -├── states # audio files generated at validation -└── log # tensorboard log -``` +usage: train.py [-h] [--config CONFIG] [--device DEVICE] [--data DATA] + [--checkpoint CHECKPOINT | --iteration ITERATION] + [--wavenet WAVENET] + output + +Train a ClariNet model with LJspeech and a trained WaveNet model. -If `checkpoints` is not empty and argument `--checkpoint` is not specified, the model will be resumed from the latest checkpoint at the beginning of training. +positional arguments: + output path to save experiment results + +optional arguments: + -h, --help show this help message and exit + --config CONFIG path of the config file + --device DEVICE device to use + --data DATA path of LJspeech dataset + --checkpoint CHECKPOINT checkpoint to resume from + --iteration ITERATION the iteration of the checkpoint to load from output directory + --wavenet WAVENET wavenet checkpoint to use +- `--config` is the configuration file to use. The provided configurations can be used directly. And you can change some values in the configuration file and train the model with a different config. - `--device` is the device (gpu id) to use for training. `-1` means CPU. -- `--wavenet` is the path of the wavenet checkpoint to load. If you do not specify `--resume`, then this must be provided. +- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains `metadata.txt`). + +- `--checkpoint` is the path of the checkpoint. +- `--iteration` is the iteration of the checkpoint to load from output directory. +- `output` is the directory to save results, all result are saved in this directory. +See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading. -Before you start training a ClariNet model, you should have trained a WaveNet model with single Gaussian output distribution. Make sure the config of the teacher model matches that of the trained model. +- `--wavenet` is the path of the wavenet checkpoint to load. +When you start training a ClariNet model without loading form a ClariNet checkpoint, you should have trained a WaveNet model with single Gaussian output distribution. Make sure the config of the teacher model matches that of the trained wavenet model. Example script: ```bash -python train.py --config=./configs/clarinet_ljspeech.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0 --conditioner=wavenet_checkpoint/conditioner --conditioner=wavenet_checkpoint/teacher +python train.py + --config=./configs/clarinet_ljspeech.yaml + --data=./LJSpeech-1.1/ + --device=0 + --wavenet="wavenet-step-2000000" + experiment ``` You can monitor training log via tensorboard, using the script below. @@ -77,29 +99,50 @@ tensorboard --logdir=. ## Synthesis ```text usage: synthesis.py [-h] [--config CONFIG] [--device DEVICE] [--data DATA] - checkpoint output + [--checkpoint CHECKPOINT | --iteration ITERATION] + output -train a ClariNet model with LJspeech and a trained WaveNet model. +Synthesize audio files from mel spectrogram in the validation set. positional arguments: - checkpoint checkpoint to load from. - output path to save student. + output path to save the synthesized audio optional arguments: - -h, --help show this help message and exit - --config CONFIG path of the config file. - --device DEVICE device to use. - --data DATA path of LJspeech dataset. + -h, --help show this help message and exit + --config CONFIG path of the config file + --device DEVICE device to use. + --data DATA path of LJspeech dataset + --checkpoint CHECKPOINT checkpoint to resume from + --iteration ITERATION the iteration of the checkpoint to load from output directory ``` - `--config` is the configuration file to use. You should use the same configuration with which you train you model. -- `--data` is the path of the LJspeech dataset. A dataset is not needed for synthesis, but since the input is mel spectrogram, we need to get mel spectrogram from audio files. -- `checkpoint` is the checkpoint to load. -- `output_path` is the directory to save results. The output path contains the generated audio files (`*.wav`). - `--device` is the device (gpu id) to use for training. `-1` means CPU. +- `--data` is the path of the LJspeech dataset. In principle, a dataset is not needed for synthesis, but since the input is mel spectrogram, we need to get mel spectrogram from audio files. +- `--checkpoint` is the checkpoint to load. +- `--iteration` is the iteration of the checkpoint to load from output directory. +- `output` is the directory to save synthesized audio. Audio file is saved in `synthesis/` in `output` directory. +See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading. + Example script: ```bash -python synthesis.py --config=./configs/wavenet_single_gaussian.yaml --data=./LJSpeech-1.1/ --device=0 experiment/checkpoints/step_500000 generated +python synthesis.py \ + --config=./configs/wavenet_single_gaussian.yaml \ + --data=./LJSpeech-1.1/ \ + --device=0 \ + --iteration=500000 \ + experiment +``` + +or + +```bash +python synthesis.py \ + --config=./configs/wavenet_single_gaussian.yaml \ + --data=./LJSpeech-1.1/ \ + --device=0 \ + --checkpoint="experiment/checkpoints/step-500000" \ + experiment ``` diff --git a/examples/clarinet/synthesis.py b/examples/clarinet/synthesis.py index ce16fc1cdd14e18b0299193c6aa9b439a8f524ea..ff086bb5eafbef11c4a29aed7a47e802405a01e1 100644 --- a/examples/clarinet/synthesis.py +++ b/examples/clarinet/synthesis.py @@ -26,29 +26,41 @@ from tensorboardX import SummaryWriter import paddle.fluid.dygraph as dg from paddle import fluid +from parakeet.modules.weight_norm import WeightNormWrapper from parakeet.models.wavenet import WaveNet, UpsampleNet from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo from parakeet.utils.layer_tools import summary, freeze +from parakeet.utils import io -from utils import valid_model, eval_model, load_model +from utils import eval_model sys.path.append("../wavenet") from data import LJSpeechMetaData, Transform, DataCollector if __name__ == "__main__": parser = argparse.ArgumentParser( - description="synthesize audio files from mel spectrogram in the validation set." + description="Synthesize audio files from mel spectrogram in the validation set." ) - parser.add_argument("--config", type=str, help="path of the config file.") + parser.add_argument("--config", type=str, help="path of the config file") parser.add_argument( "--device", type=int, default=-1, help="device to use.") - parser.add_argument("--data", type=str, help="path of LJspeech dataset.") - parser.add_argument( - "checkpoint", type=str, help="checkpoint to load from.") + parser.add_argument("--data", type=str, help="path of LJspeech dataset") + + g = parser.add_mutually_exclusive_group() + g.add_argument("--checkpoint", type=str, help="checkpoint to resume from") + g.add_argument( + "--iteration", + type=int, + help="the iteration of the checkpoint to load from output directory") + parser.add_argument( - "output", type=str, default="experiment", help="path to save student.") + "output", + type=str, + default="experiment", + help="path to save the synthesized audio") args = parser.parse_args() + with open(args.config, 'rt') as f: config = ruamel.yaml.safe_load(f) @@ -136,17 +148,32 @@ if __name__ == "__main__": model = Clarinet(upsample_net, teacher, student, stft, student_log_scale_min, lmd) summary(model) - load_model(model, args.checkpoint) - - # loader - train_loader = fluid.io.DataLoader.from_generator( - capacity=10, return_list=True) - train_loader.set_batch_generator(train_cargo, place) + # load parameters + if args.checkpoint is not None: + # load from args.checkpoint + iteration = io.load_parameters( + model, checkpoint_path=args.checkpoint) + else: + # load from "args.output/checkpoints" + checkpoint_dir = os.path.join(args.output, "checkpoints") + iteration = io.load_parameters( + model, checkpoint_dir=checkpoint_dir, iteration=args.iteration) + assert iteration > 0, "A trained checkpoint is needed." + + # make generation fast + for sublayer in model.sublayers(): + if isinstance(sublayer, WeightNormWrapper): + sublayer.remove_weight_norm() + + # data loader valid_loader = fluid.io.DataLoader.from_generator( capacity=10, return_list=True) valid_loader.set_batch_generator(valid_cargo, place) - if not os.path.exists(args.output): - os.makedirs(args.output) - eval_model(model, valid_loader, args.output, sample_rate) + # the directory to save audio files + synthesis_dir = os.path.join(args.output, "synthesis") + if not os.path.exists(synthesis_dir): + os.makedirs(synthesis_dir) + + eval_model(model, valid_loader, synthesis_dir, iteration, sample_rate) diff --git a/examples/clarinet/train.py b/examples/clarinet/train.py index dcfff9bafa339f01b8eb050270edcb46b7a456c2..82d9aa1557bbc8889ec46ff0b8eb8a3d4e60d429 100644 --- a/examples/clarinet/train.py +++ b/examples/clarinet/train.py @@ -32,27 +32,36 @@ from parakeet.data import TransformDataset, SliceDataset, RandomSampler, Sequent from parakeet.utils.layer_tools import summary, freeze from parakeet.utils import io -from utils import make_output_tree, valid_model, load_wavenet +from utils import make_output_tree, eval_model, load_wavenet + +# import dataset from wavenet sys.path.append("../wavenet") from data import LJSpeechMetaData, Transform, DataCollector if __name__ == "__main__": parser = argparse.ArgumentParser( - description="train a ClariNet model with LJspeech and a trained WaveNet model." + description="Train a ClariNet model with LJspeech and a trained WaveNet model." ) - parser.add_argument("--config", type=str, help="path of the config file.") + parser.add_argument("--config", type=str, help="path of the config file") + parser.add_argument("--device", type=int, default=-1, help="device to use") + parser.add_argument("--data", type=str, help="path of LJspeech dataset") + + g = parser.add_mutually_exclusive_group() + g.add_argument("--checkpoint", type=str, help="checkpoint to resume from") + g.add_argument( + "--iteration", + type=int, + help="the iteration of the checkpoint to load from output directory") + parser.add_argument( - "--device", type=int, default=-1, help="device to use.") + "--wavenet", type=str, help="wavenet checkpoint to use") + parser.add_argument( - "--output", + "output", type=str, default="experiment", - help="path to save student.") - parser.add_argument("--data", type=str, help="path of LJspeech dataset.") - parser.add_argument( - "--checkpoint", type=str, help="checkpoint to load from.") - parser.add_argument( - "--wavenet", type=str, help="wavenet checkpoint to use.") + help="path to save experiment results") + args = parser.parse_args() with open(args.config, 'rt') as f: config = ruamel.yaml.safe_load(f) @@ -169,30 +178,20 @@ if __name__ == "__main__": log_dir = os.path.join(args.output, "log") writer = SummaryWriter(log_dir) - # load wavenet/checkpoint, determine iterations done if args.checkpoint is not None: - iteration = int(os.path.basename(args.checkpoint).split('-')[-1]) + iteration = io.load_parameters( + model, optim, checkpoint_path=args.checkpoint) else: - iteration = io.load_latest_checkpoint(checkpoint_dir) - - if iteration == 0 and args.wavenet is None: - raise Exception( - "you should load from a trained wavenet or resume training; training without a trained wavenet is not recommended." - ) - - if args.wavenet is not None and iteration > 0: - if args.checkpoint is None: - print("Resume training, --wavenet ignored") - else: - print("--checkpoint provided, --wavenet ignored") - - if args.wavenet is not None and iteration == 0: + iteration = io.load_parameters( + model, + optim, + checkpoint_dir=checkpoint_dir, + iteration=args.iteration) + + if iteration == 0: + assert args.wavenet is not None, "When training afresh, a trained wavenet model should be provided." load_wavenet(model, args.wavenet) - # it may overwrite the wavenet loaded - io.load_parameters( - checkpoint_dir, 0, model, optim, file_path=args.checkpoint) - # loader train_loader = fluid.io.DataLoader.from_generator( capacity=10, return_list=True) @@ -205,7 +204,7 @@ if __name__ == "__main__": # training loop global_step = iteration + 1 iterator = iter(tqdm(train_loader)) - while global_step < max_iterations: + while global_step <= max_iterations: try: batch = next(iterator) except StopIteration as e: @@ -226,7 +225,8 @@ if __name__ == "__main__": l = loss_dict["loss"] step_loss = l.numpy()[0] - print("[train] loss: {:<8.6f}".format(step_loss)) + print("[train] global_step: {} loss: {:<8.6f}".format(global_step, + step_loss)) l.backward() optim.minimize(l, grad_clip=clipper) @@ -234,11 +234,9 @@ if __name__ == "__main__": if global_step % eval_interval == 0: # evaluate on valid dataset - valid_model(model, valid_loader, state_dir, global_step, - sample_rate) + eval_model(model, valid_loader, state_dir, global_step, + sample_rate) if global_step % checkpoint_interval == 0: - io.save_latest_parameters(checkpoint_dir, global_step, model, - optim) - io.save_latest_checkpoint(checkpoint_dir, global_step) + io.save_parameters(checkpoint_dir, global_step, model, optim) global_step += 1 diff --git a/examples/clarinet/utils.py b/examples/clarinet/utils.py index 2c3e18428672ccb29110aba3e3bca5d1a377690f..1cbc1b6b4ebee20d241d24a0a4821be49f5320bc 100644 --- a/examples/clarinet/utils.py +++ b/examples/clarinet/utils.py @@ -32,12 +32,12 @@ def make_output_tree(output_dir): os.makedirs(state_dir) -def valid_model(model, valid_loader, output_dir, global_step, sample_rate): +def eval_model(model, valid_loader, output_dir, iteration, sample_rate): model.eval() for i, batch in enumerate(valid_loader): # print("sentence {}".format(i)) path = os.path.join(output_dir, - "step_{}_sentence_{}.wav".format(global_step, i)) + "sentence_{}_step_{}.wav".format(i, iteration)) audio_clips, mel_specs, audio_starts = batch wav_var = model.synthesis(mel_specs) wav_np = wav_var.numpy()[0] @@ -45,42 +45,6 @@ def valid_model(model, valid_loader, output_dir, global_step, sample_rate): print("generated {}".format(path)) -def eval_model(model, valid_loader, output_dir, sample_rate): - model.eval() - for i, batch in enumerate(valid_loader): - # print("sentence {}".format(i)) - path = os.path.join(output_dir, "sentence_{}.wav".format(i)) - audio_clips, mel_specs, audio_starts = batch - wav_var = model.synthesis(mel_specs) - wav_np = wav_var.numpy()[0] - sf.write(path, wav_np, samplerate=sample_rate) - print("generated {}".format(path)) - - -def save_checkpoint(model, optim, checkpoint_dir, global_step): - path = os.path.join(checkpoint_dir, "step_{}".format(global_step)) - dg.save_dygraph(model.state_dict(), path) - print("saving model to {}".format(path + ".pdparams")) - if optim: - dg.save_dygraph(optim.state_dict(), path) - print("saving optimizer to {}".format(path + ".pdopt")) - - -def load_model(model, path): - model_dict, _ = dg.load_dygraph(path) - model.set_dict(model_dict) - print("loaded model from {}.pdparams".format(path)) - - -def load_checkpoint(model, optim, path): - model_dict, optim_dict = dg.load_dygraph(path) - model.set_dict(model_dict) - print("loaded model from {}.pdparams".format(path)) - if optim_dict: - optim.set_dict(optim_dict) - print("loaded optimizer from {}.pdparams".format(path)) - - def load_wavenet(model, path): wavenet_dict, _ = dg.load_dygraph(path) encoder_dict = OrderedDict() diff --git a/examples/deepvoice3/README.md b/examples/deepvoice3/README.md index 7c2ad77cab3677945253f4c69cc0d10f246569de..f1a55df01edc5984a7801276157485a055422a5a 100644 --- a/examples/deepvoice3/README.md +++ b/examples/deepvoice3/README.md @@ -30,29 +30,55 @@ The model consists of an encoder, a decoder and a converter (and a speaker embed └── utils.py utility functions ``` +## Saving & Loading +`train.py` and `synthesis.py` have 3 arguments in common, `--checkpooint`, `iteration` and `output`. + +1. `output` is the directory for saving results. +During training, checkpoints are saved in `checkpoints/` in `output` and tensorboard log is save in `log/` in `output`. Other possible outputs are saved in `states/` in `outuput`. +During synthesizing, audio files and other possible outputs are save in `synthesis/` in `output`. +So after training and synthesizing with the same output directory, the file structure of the output directory looks like this. + +```text +├── checkpoints/ # checkpoint directory (including *.pdparams, *.pdopt and a text file `checkpoint` that records the latest checkpoint) +├── states/ # audio files generated at validation and other possible outputs +├── log/ # tensorboard log +└── synthesis/ # synthesized audio files and other possible outputs +``` + +2. `--checkpoint` and `--iteration` for loading from existing checkpoint. Loading existing checkpoiont follows the following rule: +If `--checkpoint` is provided, the checkpoint specified by `--checkpoint` is loaded. +If `--checkpoint` is not provided, we try to load the model specified by `--iteration` from the checkpoint directory. If `--iteration` is not provided, we try to load the latested checkpoint from checkpoint directory. + ## Train Train the model using train.py, follow the usage displayed by `python train.py --help`. ```text -usage: train.py [-h] [-c CONFIG] [-s DATA] [--checkpoint CHECKPOINT] - [-o OUTPUT] [-g DEVICE] +usage: train.py [-h] [--config CONFIG] [--data DATA] [--device DEVICE] + [--checkpoint CHECKPOINT | --iteration ITERATION] + output Train a Deep Voice 3 model with LJSpeech dataset. +positional arguments: + output path to save results + optional arguments: - -h, --help show this help message and exit - -c CONFIG, --config CONFIG experimrnt config - -s DATA, --data DATA The path of the LJSpeech dataset. - --checkpoint CHECKPOINT checkpoint to load - -o OUTPUT, --output OUTPUT The directory to save result. - -g DEVICE, --device DEVICE device to use + -h, --help show this help message and exit + --config CONFIG experimrnt config + --data DATA The path of the LJSpeech dataset. + --device DEVICE device to use + --checkpoint CHECKPOINT checkpoint to resume from. + --iteration ITERATION the iteration of the checkpoint to load from output directory ``` - `--config` is the configuration file to use. The provided `ljspeech.yaml` can be used directly. And you can change some values in the configuration file and train the model with a different config. - `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt). -- `--checkpoint` is the path of the checkpoint. If it is provided, the model would load the checkpoint before trainig. -- `--output` is the directory to save results, all results are saved in this directory. The structure of the output directory is shown below. +- `--device` is the device (gpu id) to use for training. `-1` means CPU. +- `--checkpoint` is the path of the checkpoint. +- `--iteration` is the iteration of the checkpoint to load from output directory. +See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading. +- `output` is the directory to save results, all results are saved in this directory. The structure of the output directory is shown below. ```text ├── checkpoints # checkpoint @@ -64,14 +90,14 @@ optional arguments: └── waveform # waveform (.wav files) ``` -If `checkpoints` is not empty and argument `--checkpoint` is not specified, the model will be resumed from the latest checkpoint at the beginning of training. - -- `--device` is the device (gpu id) to use for training. `-1` means CPU. - Example script: ```bash -python train.py --config=configs/ljspeech.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0 +python train.py \ + --config=configs/ljspeech.yaml \ + --data=./LJSpeech-1.1/ \ + --device=0 \ + experiment ``` You can monitor training log via tensorboard, using the script below. @@ -83,31 +109,50 @@ tensorboard --logdir=. ## Synthesis ```text -usage: synthesis.py [-h] [-c CONFIG] [-g DEVICE] checkpoint text output_path +usage: synthesis.py [-h] [--config CONFIG] [--device DEVICE] + [--checkpoint CHECKPOINT | --iteration ITERATION] + text output -Synthsize waveform from a checkpoint. +Synthsize waveform with a checkpoint. positional arguments: - checkpoint checkpoint to load. - text text file to synthesize - output_path path to save results + text text file to synthesize + output path to save synthesized audio optional arguments: - -h, --help show this help message and exit - -c CONFIG, --config CONFIG - experiment config. - -g DEVICE, --device DEVICE - device to use + -h, --help show this help message and exit + --config CONFIG experiment config + --device DEVICE device to use + --checkpoint CHECKPOINT checkpoint to resume from + --iteration ITERATION the iteration of the checkpoint to load from output directory ``` - `--config` is the configuration file to use. You should use the same configuration with which you train you model. -- `checkpoint` is the checkpoint to load. -- `text`is the text file to synthesize. -- `output_path` is the directory to save results. The output path contains the generated audio files (`*.wav`) and attention plots (*.png) for each sentence. - `--device` is the device (gpu id) to use for training. `-1` means CPU. +- `--checkpoint` is the path of the checkpoint. +- `--iteration` is the iteration of the checkpoint to load from output directory. +See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading. + +- `text`is the text file to synthesize. +- `output` is the directory to save results. The generated audio files (`*.wav`) and attention plots (*.png) for are save in `synthesis/` in ouput directory. + Example script: ```bash -python synthesis.py --config=configs/ljspeech.yaml --device=0 experiment/checkpoints/model_step_005000000 sentences.txt generated +python synthesis.py \ + --config=configs/ljspeech.yaml \ + --device=0 \ + --checkpoint="experiment/checkpoints/model_step_005000000" \ + sentences.txt experiment +``` + +or + +```bash +python synthesis.py \ + --config=configs/ljspeech.yaml \ + --device=0 \ + --iteration=005000000 \ + sentences.txt experiment ``` diff --git a/examples/deepvoice3/synthesis.py b/examples/deepvoice3/synthesis.py index d3cd9f06cdded50d02c4cfcfad5e1d76d7d32777..b8fb6267fae8d29dc66d34d24517afd6a02febad 100644 --- a/examples/deepvoice3/synthesis.py +++ b/examples/deepvoice3/synthesis.py @@ -27,19 +27,26 @@ from tensorboardX import SummaryWriter from parakeet.g2p import en from parakeet.modules.weight_norm import WeightNormWrapper from parakeet.utils.layer_tools import summary -from parakeet.utils.io import load_parameters +from parakeet.utils import io from utils import make_model, eval_model, plot_alignment if __name__ == "__main__": parser = argparse.ArgumentParser( description="Synthsize waveform with a checkpoint.") - parser.add_argument("-c", "--config", type=str, help="experiment config.") - parser.add_argument("checkpoint", type=str, help="checkpoint to load.") + parser.add_argument("--config", type=str, help="experiment config") + parser.add_argument("--device", type=int, default=-1, help="device to use") + + g = parser.add_mutually_exclusive_group() + g.add_argument("--checkpoint", type=str, help="checkpoint to resume from") + g.add_argument( + "--iteration", + type=int, + help="the iteration of the checkpoint to load from output directory") + parser.add_argument("text", type=str, help="text file to synthesize") - parser.add_argument("output_path", type=str, help="path to save results") parser.add_argument( - "-g", "--device", type=int, default=-1, help="device to use") + "output", type=str, help="path to save synthesized audio") args = parser.parse_args() with open(args.config, 'rt') as f: @@ -103,8 +110,14 @@ if __name__ == "__main__": linear_dim, use_decoder_states, converter_channels, dropout) summary(dv3) - state, _ = dg.load_dygraph(args.checkpoint) - dv3.set_dict(state) + + checkpoint_dir = os.path.join(args.output, "checkpoints") + if args.checkpoint is not None: + iteration = io.load_parameters( + dv3, checkpoint_path=args.checkpoint) + else: + iteration = io.load_parameters( + dv3, checkpoint_dir=checkpoint_dir, iteration=args.iteration) # WARNING: don't forget to remove weight norm to re-compute each wrapped layer's weight # removing weight norm also speeds up computation @@ -112,9 +125,6 @@ if __name__ == "__main__": if isinstance(layer, WeightNormWrapper): layer.remove_weight_norm() - if not os.path.exists(args.output_path): - os.makedirs(args.output_path) - transform_config = config["transform"] c = transform_config["replace_pronunciation_prob"] sample_rate = transform_config["sample_rate"] @@ -128,6 +138,10 @@ if __name__ == "__main__": power = synthesis_config["power"] n_iter = synthesis_config["n_iter"] + synthesis_dir = os.path.join(args.output, "synthesis") + if not os.path.exists(synthesis_dir): + os.makedirs(synthesis_dir) + with open(args.text, "rt", encoding="utf-8") as f: lines = f.readlines() for idx, line in enumerate(lines): @@ -139,7 +153,9 @@ if __name__ == "__main__": preemphasis) plot_alignment( attn, - os.path.join(args.output_path, "test_{}.png".format(idx))) + os.path.join(synthesis_dir, + "test_{}_step_{}.png".format(idx, iteration))) sf.write( - os.path.join(args.output_path, "test_{}.wav".format(idx)), + os.path.join(synthesis_dir, + "test_{}_step{}.wav".format(idx, iteration)), wav, sample_rate) diff --git a/examples/deepvoice3/train.py b/examples/deepvoice3/train.py index 6e0a9ba9c1d8c084ffb351fb53d3395fbbf4dd93..d363e6f7a58fd7c059099df81a9137c90302c717 100644 --- a/examples/deepvoice3/train.py +++ b/examples/deepvoice3/train.py @@ -45,22 +45,24 @@ from utils import make_model, eval_model, save_state, make_output_tree, plot_ali if __name__ == "__main__": parser = argparse.ArgumentParser( description="Train a Deep Voice 3 model with LJSpeech dataset.") - parser.add_argument("-c", "--config", type=str, help="experimrnt config") + parser.add_argument("--config", type=str, help="experimrnt config") parser.add_argument( - "-s", "--data", type=str, default="/workspace/datasets/LJSpeech-1.1/", help="The path of the LJSpeech dataset.") - parser.add_argument("--checkpoint", type=str, help="checkpoint to load") - parser.add_argument( - "-o", - "--output", - type=str, - default="result", - help="The directory to save result.") + parser.add_argument("--device", type=int, default=-1, help="device to use") + + g = parser.add_mutually_exclusive_group() + g.add_argument("--checkpoint", type=str, help="checkpoint to resume from.") + g.add_argument( + "--iteration", + type=int, + help="the iteration of the checkpoint to load from output directory") + parser.add_argument( - "-g", "--device", type=int, default=-1, help="device to use") + "output", type=str, default="experiment", help="path to save results") + args, _ = parser.parse_known_args() with open(args.config, 'rt') as f: config = ruamel.yaml.safe_load(f) @@ -216,11 +218,12 @@ if __name__ == "__main__": writer = SummaryWriter(logdir=log_dir) # load parameters and optimizer, and opdate iterations done sofar - io.load_parameters(ckpt_dir, 0, dv3, optim, file_path=args.checkpoint) if args.checkpoint is not None: - iteration = int(os.path.basename(args.checkpoint).split("-")[-1]) + iteration = io.load_parameters( + dv3, optim, checkpoint_path=args.checkpoint) else: - iteration = io.load_latest_checkpoint(ckpt_dir) + iteration = io.load_parameters( + dv3, optim, checkpoint_dir=ckpt_dir, iteration=args.iteration) # =========================train========================= max_iter = train_config["max_iteration"] @@ -325,7 +328,6 @@ if __name__ == "__main__": # save checkpoint if global_step % save_interval == 0: - io.save_latest_parameters(ckpt_dir, global_step, dv3, optim) - io.save_latest_checkpoint(ckpt_dir, global_step) + io.save_parameters(ckpt_dir, global_step, dv3, optim) global_step += 1 diff --git a/examples/wavenet/README.md b/examples/wavenet/README.md index af34457e69461edc1e32d38e4497bf5f51776851..42defe7c2b6e625995d3511c7c5a0f0655055225 100644 --- a/examples/wavenet/README.md +++ b/examples/wavenet/README.md @@ -22,43 +22,67 @@ tar xjvf LJSpeech-1.1.tar.bz2 └── utils.py utility functions ``` +## Saving & Loading +`train.py` and `synthesis.py` have 3 arguments in common, `--checkpooint`, `iteration` and `output`. + +1. `output` is the directory for saving results. +During training, checkpoints are saved in `checkpoints/` in `output` and tensorboard log is save in `log/` in `output`. Other possible outputs are saved in `states/` in `outuput`. +During synthesizing, audio files and other possible outputs are save in `synthesis/` in `output`. +So after training and synthesizing with the same output directory, the file structure of the output directory looks like this. + +```text +├── checkpoints/ # checkpoint directory (including *.pdparams, *.pdopt and a text file `checkpoint` that records the latest checkpoint) +├── states/ # audio files generated at validation and other possible outputs +├── log/ # tensorboard log +└── synthesis/ # synthesized audio files and other possible outputs +``` + +2. `--checkpoint` and `--iteration` for loading from existing checkpoint. Loading existing checkpoiont follows the following rule: +If `--checkpoint` is provided, the checkpoint specified by `--checkpoint` is loaded. +If `--checkpoint` is not provided, we try to load the model specified by `--iteration` from the checkpoint directory. If `--iteration` is not provided, we try to load the latested checkpoint from checkpoint directory. + ## Train Train the model using train.py. For help on usage, try `python train.py --help`. ```text -usage: train.py [-h] [--data DATA] [--config CONFIG] [--output OUTPUT] - [--device DEVICE] [--checkpoint CHECKPOINT] +usage: train.py [-h] [--data DATA] [--config CONFIG] [--device DEVICE] + [--checkpoint CHECKPOINT | --iteration ITERATION] + output Train a WaveNet model with LJSpeech. +positional arguments: + output path to save results + optional arguments: - -h, --help show this help message and exit - --data DATA path of the LJspeech dataset. - --config CONFIG path of the config file. - --output OUTPUT path to save results. - --device DEVICE device to use. - --checkpoint CHECKPOINT checkpoint to resume from. + -h, --help show this help message and exit + --data DATA path of the LJspeech dataset + --config CONFIG path of the config file + --device DEVICE device to use + --checkpoint CHECKPOINT checkpoint to resume from + --iteration ITERATION the iteration of the checkpoint to load from output directory ``` -- `--config` is the configuration file to use. The provided configurations can be used directly. And you can change some values in the configuration file and train the model with a different config. - `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt). -- `--checkpoint` is the path of the checkpoint. If it is provided, the model would load the checkpoint before training. -- `--output` is the directory to save results, all result are saved in this directory. The structure of the output directory is shown below. +- `--config` is the configuration file to use. The provided configurations can be used directly. And you can change some values in the configuration file and train the model with a different config. +- `--device` is the device (gpu id) to use for training. `-1` means CPU. -```text -├── checkpoints # checkpoint -└── log # tensorboard log -``` +- `--checkpoint` is the path of the checkpoint. +- `--iteration` is the iteration of the checkpoint to load from output directory. +- `output` is the directory to save results, all result are saved in this directory. -If `checkpoints` is not empty and argument `--checkpoint` is not specified, the model will be resumed from the latest checkpoint at the beginning of training. +See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading. -- `--device` is the device (gpu id) to use for training. `-1` means CPU. Example script: ```bash -python train.py --config=./configs/wavenet_single_gaussian.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0 +python train.py \ + --config=./configs/wavenet_single_gaussian.yaml \ + --data=./LJSpeech-1.1/ \ + --device=0 \ + experiment ``` You can monitor training log via TensorBoard, using the script below. @@ -71,29 +95,50 @@ tensorboard --logdir=. ## Synthesis ```text usage: synthesis.py [-h] [--data DATA] [--config CONFIG] [--device DEVICE] - checkpoint output + [--checkpoint CHECKPOINT | --iteration ITERATION] + output -Synthesize valid data from LJspeech with a WaveNet model. +Synthesize valid data from LJspeech with a wavenet model. positional arguments: - checkpoint checkpoint to load. - output path to save results. + output path to save the synthesized audio optional arguments: - -h, --help show this help message and exit - --data DATA path of the LJspeech dataset. - --config CONFIG path of the config file. - --device DEVICE device to use. + -h, --help show this help message and exit + --data DATA path of the LJspeech dataset + --config CONFIG path of the config file + --device DEVICE device to use + --checkpoint CHECKPOINT checkpoint to resume from + --iteration ITERATION the iteration of the checkpoint to load from output directory ``` +- `--data` is the path of the LJspeech dataset. In principle, a dataset is not needed for synthesis, but since the input is mel spectrogram, we need to get mel spectrogram from audio files. - `--config` is the configuration file to use. You should use the same configuration with which you train you model. -- `--data` is the path of the LJspeech dataset. A dataset is not needed for synthesis, but since the input is mel spectrogram, we need to get mel spectrogram from audio files. -- `checkpoint` is the checkpoint to load. -- `output_path` is the directory to save results. The output path contains the generated audio files (`*.wav`). - `--device` is the device (gpu id) to use for training. `-1` means CPU. +- `--checkpoint` is the checkpoint to load. +- `--iteration` is the iteration of the checkpoint to load from output directory. +- `output` is the directory to save synthesized audio. Audio file is saved in `synthesis/` in `output` directory. +See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading. + Example script: ```bash -python synthesis.py --config=./configs/wavenet_single_gaussian.yaml --data=./LJSpeech-1.1/ --device=0 experiment/checkpoints/step_500000 generated +python synthesis.py \ + --config=./configs/wavenet_single_gaussian.yaml \ + --data=./LJSpeech-1.1/ \ + --device=0 \ + --checkpoint="experiment/checkpoints/step-1000000" \ + experiment +``` + +or + +```bash +python synthesis.py \ + --config=./configs/wavenet_single_gaussian.yaml \ + --data=./LJSpeech-1.1/ \ + --device=0 \ + --iteration=1000000 \ + experiment ``` diff --git a/examples/wavenet/synthesis.py b/examples/wavenet/synthesis.py index 5edb1edbe8c83cf665dff60e712fdd4100ea2ad0..65c81dddc89c1abf0a8648b52b37f9da6475438e 100644 --- a/examples/wavenet/synthesis.py +++ b/examples/wavenet/synthesis.py @@ -25,6 +25,7 @@ from parakeet.modules.weight_norm import WeightNormWrapper from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet from parakeet.utils.layer_tools import summary +from parakeet.utils import io from data import LJSpeechMetaData, Transform, DataCollector from utils import make_output_tree, valid_model, eval_model @@ -33,14 +34,22 @@ if __name__ == "__main__": parser = argparse.ArgumentParser( description="Synthesize valid data from LJspeech with a wavenet model.") parser.add_argument( - "--data", type=str, help="path of the LJspeech dataset.") - parser.add_argument("--config", type=str, help="path of the config file.") - parser.add_argument( - "--device", type=int, default=-1, help="device to use.") + "--data", type=str, help="path of the LJspeech dataset") + parser.add_argument("--config", type=str, help="path of the config file") + parser.add_argument("--device", type=int, default=-1, help="device to use") + + g = parser.add_mutually_exclusive_group() + g.add_argument("--checkpoint", type=str, help="checkpoint to resume from") + g.add_argument( + "--iteration", + type=int, + help="the iteration of the checkpoint to load from output directory") - parser.add_argument("checkpoint", type=str, help="checkpoint to load.") parser.add_argument( - "output", type=str, default="experiment", help="path to save results.") + "output", + type=str, + default="experiment", + help="path to save the synthesized audio") args = parser.parse_args() with open(args.config, 'rt') as f: @@ -112,9 +121,15 @@ if __name__ == "__main__": model = ConditionalWavenet(encoder, decoder) summary(model) - model_dict, _ = dg.load_dygraph(args.checkpoint) - print("Loading from {}.pdparams".format(args.checkpoint)) - model.set_dict(model_dict) + # load model parameters + checkpoint_dir = os.path.join(args.output, "checkpoints") + if args.checkpoint: + iteration = io.load_parameters( + model, checkpoint_path=args.checkpoint) + else: + iteration = io.load_parameters( + model, checkpoint_dir=checkpoint_dir, iteration=args.iteration) + assert iteration > 0, "A trained model is needed." # WARNING: don't forget to remove weight norm to re-compute each wrapped layer's weight # removing weight norm also speeds up computation @@ -130,4 +145,8 @@ if __name__ == "__main__": capacity=10, return_list=True) valid_loader.set_batch_generator(valid_cargo, place) - eval_model(model, valid_loader, args.output, sample_rate) + synthesis_dir = os.path.join(args.output, "synthesis") + if not os.path.exists(synthesis_dir): + os.makedirs(synthesis_dir) + + eval_model(model, valid_loader, synthesis_dir, iteration, sample_rate) diff --git a/examples/wavenet/train.py b/examples/wavenet/train.py index 3fdfaeb7db12518980e5bbd6aee4f5fc4e0f230b..14b861bc86dd7ea8add599e20e6ca4fe75194fca 100644 --- a/examples/wavenet/train.py +++ b/examples/wavenet/train.py @@ -33,17 +33,19 @@ if __name__ == "__main__": parser = argparse.ArgumentParser( description="Train a WaveNet model with LJSpeech.") parser.add_argument( - "--data", type=str, help="path of the LJspeech dataset.") - parser.add_argument("--config", type=str, help="path of the config file.") - parser.add_argument( - "--output", - type=str, - default="experiment", - help="path to save results.") - parser.add_argument( - "--device", type=int, default=-1, help="device to use.") + "--data", type=str, help="path of the LJspeech dataset") + parser.add_argument("--config", type=str, help="path of the config file") + parser.add_argument("--device", type=int, default=-1, help="device to use") + + g = parser.add_mutually_exclusive_group() + g.add_argument("--checkpoint", type=str, help="checkpoint to resume from") + g.add_argument( + "--iteration", + type=int, + help="the iteration of the checkpoint to load from output directory") + parser.add_argument( - "--checkpoint", type=str, help="checkpoint to resume from.") + "output", type=str, default="experiment", help="path to save results") args = parser.parse_args() with open(args.config, 'rt') as f: @@ -148,17 +150,19 @@ if __name__ == "__main__": writer = SummaryWriter(log_dir) # load parameters and optimizer, and opdate iterations done sofar - io.load_parameters( - checkpoint_dir, 0, model, optim, file_path=args.checkpoint) if args.checkpoint is not None: - iteration = int(os.path.basename(args.checkpoint).split("-")[-1]) + iteration = io.load_parameters( + model, optim, checkpoint_path=args.checkpoint) else: - iteration = io.load_latest_checkpoint(checkpoint_dir) + iteration = io.load_parameters( + model, + optim, + checkpoint_dir=checkpoint_dir, + iteration=args.iteration) global_step = iteration + 1 iterator = iter(tqdm.tqdm(train_loader)) while global_step <= max_iterations: - print(global_step) try: batch = next(iterator) except StopIteration as e: @@ -187,8 +191,6 @@ if __name__ == "__main__": sample_rate) if global_step % checkpoint_interval == 0: - io.save_latest_parameters(checkpoint_dir, global_step, model, - optim) - io.save_latest_checkpoint(checkpoint_dir, global_step) + io.save_parameters(checkpoint_dir, global_step, model, optim) global_step += 1 diff --git a/examples/wavenet/utils.py b/examples/wavenet/utils.py index cb71acd9e062d3026ea622fc419801ded39d9902..b6037706f5c718559554b0b9af8d34cc33ac119f 100644 --- a/examples/wavenet/utils.py +++ b/examples/wavenet/utils.py @@ -49,11 +49,12 @@ def valid_model(model, valid_loader, writer, global_step, sample_rate): sample_rate) -def eval_model(model, valid_loader, output_dir, sample_rate): +def eval_model(model, valid_loader, output_dir, global_step, sample_rate): model.eval() for i, batch in enumerate(valid_loader): # print("sentence {}".format(i)) - path = os.path.join(output_dir, "sentence_{}.wav".format(i)) + path = os.path.join(output_dir, + "sentence_{}_step_{}.wav".format(i, global_step)) audio_clips, mel_specs, audio_starts = batch wav_var = model.synthesis(mel_specs) wav_np = wav_var.numpy()[0] diff --git a/parakeet/utils/io.py b/parakeet/utils/io.py index e6124008d9e5964b8cce91ae6bd66a69cf061d06..959dbfb9e1f9724fb2349c03b570c1b117969eeb 100644 --- a/parakeet/utils/io.py +++ b/parakeet/utils/io.py @@ -20,6 +20,11 @@ import numpy as np import paddle.fluid.dygraph as dg +def is_main_process(): + local_rank = dg.parallel.Env().local_rank + return local_rank == 0 + + def add_yaml_config_to_args(config): """ Add args in yaml config to the args parsed by argparse. The argument in yaml config will be overwritten by the same argument in argparse if they @@ -41,7 +46,7 @@ def add_yaml_config_to_args(config): return config -def load_latest_checkpoint(checkpoint_dir, rank=0): +def _load_latest_checkpoint(checkpoint_dir): """Get the iteration number corresponding to the latest saved checkpoint Args: @@ -52,26 +57,20 @@ def load_latest_checkpoint(checkpoint_dir, rank=0): Returns: int: the latest iteration number. """ - checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") + checkpoint_record = os.path.join(checkpoint_dir, "checkpoint") # Create checkpoint index file if not exist. - if (not os.path.isfile(checkpoint_path)) and rank == 0: - with open(checkpoint_path, "w") as handle: - handle.write("model_checkpoint_path: step-0") - - # Make sure that other process waits until checkpoint file is created - # by process 0. - while not os.path.isfile(checkpoint_path): - time.sleep(1) + if (not os.path.isfile(checkpoint_record)): + return 0 # Fetch the latest checkpoint index. - with open(checkpoint_path, "r") as handle: + with open(checkpoint_record, "r") as handle: latest_checkpoint = handle.readline().split()[-1] iteration = int(latest_checkpoint.split("-")[-1]) return iteration -def save_latest_checkpoint(checkpoint_dir, iteration): +def _save_checkpoint(checkpoint_dir, iteration): """Save the iteration number of the latest model to be checkpointed. Args: @@ -81,60 +80,73 @@ def save_latest_checkpoint(checkpoint_dir, iteration): Returns: None """ - checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") + checkpoint_record = os.path.join(checkpoint_dir, "checkpoint") # Update the latest checkpoint index. - with open(checkpoint_path, "w") as handle: + with open(checkpoint_record, "w") as handle: handle.write("model_checkpoint_path: step-{}".format(iteration)) -def load_parameters(checkpoint_dir, - rank, - model, +def load_parameters(model, optimizer=None, + checkpoint_dir=None, iteration=None, - file_path=None, + checkpoint_path=None, dtype="float32"): """Load a specific model checkpoint from disk. Args: - checkpoint_dir (str): the directory where checkpoint is saved. - rank (int): the rank of the process in multi-process setting. model (obj): model to load parameters. optimizer (obj, optional): optimizer to load states if needed. Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. iteration (int, optional): if specified, load the specific checkpoint, if not specified, load the latest one. Defaults to None. - file_path (str, optional): if specified, load the checkpoint - stored in the file_path. Defaults to None. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path. Defaults to None. dtype (str, optional): precision of the model parameters. Defaults to float32. Returns: - None + iteration (int): number of iterations that the loaded checkpoint has + been trained. """ - if file_path is None: + + if iteration is not None and checkpoint_dir is None: + raise ValueError( + "When iteration is specified, checkpoint_dir should not be None") + + if checkpoint_path is not None: + # checkpoint is not None + iteration = int(os.path.basename(checkpoint_path).split("-")[-1]) + else: if iteration is None: - iteration = load_latest_checkpoint(checkpoint_dir, rank) - if iteration == 0: - return - file_path = "{}/step-{}".format(checkpoint_dir, iteration) - - model_dict, optimizer_dict = dg.load_dygraph(file_path) - if dtype == "float16": - for k, v in model_dict.items(): - if "conv2d_transpose" in k: - model_dict[k] = v.astype("float32") - else: - model_dict[k] = v.astype(dtype) + iteration = _load_latest_checkpoint(checkpoint_dir) + checkpoint_path = os.path.join(checkpoint_dir, + "step-{}".format(iteration)) + if iteration == 0 and not os.path.exists(checkpoint_path): + # if step-0 exist, it is also loaded + return iteration + + local_rank = dg.parallel.Env().local_rank + model_dict, optimizer_dict = dg.load_dygraph(checkpoint_path) + + # cast to desired data type + for k, v in model_dict.items(): + model_dict[k] = v.astype(dtype) + model.set_dict(model_dict) - print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path)) + print("[checkpoint] Rank {}: loaded model from {}.pdparams".format( + local_rank, checkpoint_path)) + if optimizer and optimizer_dict: optimizer.set_dict(optimizer_dict) - print("[checkpoint] Rank {}: loaded optimizer state from {}".format( - rank, file_path)) + print("[checkpoint] Rank {}: loaded optimizer state from {}.pdopt". + format(local_rank, checkpoint_path)) + + return iteration -def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None): +def save_parameters(checkpoint_dir, iteration, model, optimizer=None): """Checkpoint the latest trained model parameters. Args: @@ -147,12 +159,15 @@ def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None): Returns: None """ - file_path = "{}/step-{}".format(checkpoint_dir, iteration) + checkpoint_path = os.path.join(checkpoint_dir, "step-{}".format(iteration)) model_dict = model.state_dict() - dg.save_dygraph(model_dict, file_path) - print("[checkpoint] Saved model to {}".format(file_path)) + dg.save_dygraph(model_dict, checkpoint_path) + print("[checkpoint] Saved model to {}.pdparams".format(checkpoint_path)) if optimizer: opt_dict = optimizer.state_dict() - dg.save_dygraph(opt_dict, file_path) - print("[checkpoint] Saved optimzier state to {}".format(file_path)) + dg.save_dygraph(opt_dict, checkpoint_path) + print("[checkpoint] Saved optimzier state to {}.pdopt".format( + checkpoint_path)) + + _save_checkpoint(checkpoint_dir, iteration)