Commit cf17b649 authored by chenfeiyu

Merge branch 'master' of upstream

@@ -76,29 +76,61 @@ Entries to the introduction, and the launch of training and synthesis for differe
Parakeet also releases some well-trained parameters for the example models, which can be accessed in the following tables. Each column of these tables lists resources for one model, including the URL of the pre-trained model, the dataset the model is trained on, the total training steps, and several audio samples synthesized with the pre-trained model.
#### Vocoders
We provide model checkpoints for WaveFlow (with 64 and 128 residual channels), ClariNet and WaveNet.
<div align="center">
<table>
<thead>
<tr>
<th style="width: 250px">
WaveFlow
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/waveflow_res64_ljspeech_ckpt_1.0.zip">WaveFlow (res. channels 64)</a>
</th>
<th style="width: 250px">
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/clarinet_ljspeech_ckpt_1.0.zip">ClariNet</a>
WaveFlow (res. channels 128)
</th>
</tr>
</thead>
<tbody>
<tr>
        <th>LJSpeech, 3020 K</th>
        <th>LJSpeech</th>
</tr>
<tr>
<th>
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/waveflow_res64_ljspeech_samples_1.0/step_3020k_sentence_0.wav">
<img src="images/audio_icon.png" width=250 /></a><br>
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/waveflow_res64_ljspeech_samples_1.0/step_3020k_sentence_1.wav">
<img src="images/audio_icon.png" width=250 /></a><br>
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/waveflow_res64_ljspeech_samples_1.0/step_3020k_sentence_2.wav">
<img src="images/audio_icon.png" width=250 /></a><br>
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/waveflow_res64_ljspeech_samples_1.0/step_3020k_sentence_3.wav">
<img src="images/audio_icon.png" width=250 /></a><br>
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/waveflow_res64_ljspeech_samples_1.0/step_3020k_sentence_4.wav">
<img src="images/audio_icon.png" width=250 /></a>
</th>
<th>
To be added soon
</th>
</tr>
</tbody>
<thead>
<tr>
<th style="width: 250px">
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/clarinet_ljspeech_ckpt_1.0.zip">ClariNet</a>
</th>
<th style="width: 250px">
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/wavenet_ljspeech_ckpt_1.0.zip">WaveNet</a>
</th>
</tr>
</thead>
<tbody>
<tr>
<th>LJSpeech, 500 K</th>
<th>LJSpeech, 2450 K</th>
</tr>
<tr>
<th>
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/clarinet_ljspeech_samples_1.0/step_500000_sentence_0.wav">
<img src="images/audio_icon.png" width=250 /></a><br>
@@ -111,15 +143,57 @@ Parakeet also releases some well-trained parameters for the example models, whic
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/clarinet_ljspeech_samples_1.0/step_500000_sentence_4.wav">
<img src="images/audio_icon.png" width=250 /></a>
</th>
<th>
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/wavenet_ljspeech_samples_1.0/step_2450k_sentence_0.wav">
<img src="images/audio_icon.png" width=250 /></a><br>
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/wavenet_ljspeech_samples_1.0/step_2450k_sentence_1.wav">
<img src="images/audio_icon.png" width=250 /></a><br>
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/wavenet_ljspeech_samples_1.0/step_2450k_sentence_2.wav">
<img src="images/audio_icon.png" width=250 /></a><br>
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/wavenet_ljspeech_samples_1.0/step_2450k_sentence_3.wav">
<img src="images/audio_icon.png" width=250 /></a><br>
<a href="https://paddlespeech.bj.bcebos.com/Parakeet/wavenet_ljspeech_samples_1.0/step_2450k_sentence_4.wav">
<img src="images/audio_icon.png" width=250 /></a>
</th>
</tr>
</tbody>
</table>
</div>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;**Note:** The input mel spectrograms are from the validation dataset, which is not seen during training.
#### TTS models
<div align="center">
<table>
<thead>
<tr>
<th style="width: 250px">
Deep Voice 3
</th>
<th style="width: 250px">
Transformer TTS
</th>
</tr>
</thead>
<tbody>
<tr>
<th>LJSpeech </th>
<th>LJSpeech </th>
</tr>
<tr>
<th style="height: 150px">
To be added soon
</th>
<th >
To be added soon
</th>
</tr>
</tbody>
</table>
</div>
Click each link to download the compressed package, which contains the pre-trained model and the `yaml` config describing how to train the model.
@@ -22,49 +22,71 @@ tar xjvf LJSpeech-1.1.tar.bz2
└── utils.py utility functions
```
## Saving & Loading

`train.py` and `synthesis.py` have 3 arguments in common, `--checkpoint`, `--iteration` and `output`.

1. `output` is the directory for saving results.
During training, checkpoints are saved in `checkpoints/` in `output` and the tensorboard log is saved in `log/` in `output`. Other possible outputs are saved in `states/` in `output`.
During synthesizing, audio files and other possible outputs are saved in `synthesis/` in `output`.
So after training and synthesizing with the same output directory, the file structure of the output directory looks like this.

```text
├── checkpoints/ # checkpoint directory (including *.pdparams, *.pdopt and a text file `checkpoint` that records the latest checkpoint)
├── states/ # audio files generated at validation and other possible outputs
├── log/ # tensorboard log
└── synthesis/ # synthesized audio files and other possible outputs
```
2. `--checkpoint` and `--iteration` are used for loading from an existing checkpoint. Loading an existing checkpoint follows this rule:
If `--checkpoint` is provided, the checkpoint specified by `--checkpoint` is loaded.
If `--checkpoint` is not provided, we try to load the checkpoint specified by `--iteration` from the checkpoint directory. If `--iteration` is not provided, we try to load the latest checkpoint from the checkpoint directory.
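
For reference, this resolution rule can be sketched in a few lines of Python (a simplified view of `parakeet.utils.io.load_parameters`, whose full implementation appears later in this commit; the `checkpoint` record file stores the latest saved step):

```python
import os

def resolve_checkpoint(output, checkpoint=None, iteration=None):
    """Sketch: how --checkpoint / --iteration pick a checkpoint path."""
    if checkpoint is not None:
        return checkpoint  # an explicit path always wins
    checkpoint_dir = os.path.join(output, "checkpoints")
    if iteration is None:
        # fall back to the step recorded in checkpoints/checkpoint,
        # a single line like "model_checkpoint_path: step-500000"
        with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
            iteration = int(f.readline().split()[-1].split("-")[-1])
    return os.path.join(checkpoint_dir, "step-{}".format(iteration))
```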
## Train
Train the model using train.py. Follow the usage displayed by `python train.py --help`.

```text
usage: train.py [-h] [--config CONFIG] [--device DEVICE] [--data DATA]
                [--checkpoint CHECKPOINT | --iteration ITERATION]
                [--wavenet WAVENET]
                output

Train a ClariNet model with LJspeech and a trained WaveNet model.

positional arguments:
  output                path to save experiment results

optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       path of the config file
  --device DEVICE       device to use
  --data DATA           path of LJspeech dataset
  --checkpoint CHECKPOINT
                        checkpoint to resume from
  --iteration ITERATION
                        the iteration of the checkpoint to load from output
                        directory
  --wavenet WAVENET     wavenet checkpoint to use
```
- `--config` is the configuration file to use. The provided configurations can be used directly. You can also change some values in the configuration file and train the model with a different config.
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains `metadata.txt`).
- `--checkpoint` is the path of the checkpoint.
- `--iteration` is the iteration of the checkpoint to load from the output directory.
- `output` is the directory to save results, all results are saved in this directory.

See [Saving-&-Loading](#saving--loading) for details of checkpoint loading.

- `--wavenet` is the path of the wavenet checkpoint to load.
When you start training a ClariNet model without loading from a ClariNet checkpoint, you should have trained a WaveNet model with a single Gaussian output distribution. Make sure the config of the teacher model matches that of the trained wavenet model.
Example script:
```bash
python train.py \
--config=./configs/clarinet_ljspeech.yaml \
--data=./LJSpeech-1.1/ \
--device=0 \
--wavenet="wavenet-step-2000000" \
experiment
```
You can monitor the training log via TensorBoard, using the script below.
@@ -77,29 +99,50 @@ tensorboard --logdir=.
## Synthesis
```text
usage: synthesis.py [-h] [--config CONFIG] [--device DEVICE] [--data DATA]
                    [--checkpoint CHECKPOINT | --iteration ITERATION]
                    output

Synthesize audio files from mel spectrogram in the validation set.

positional arguments:
  output                path to save the synthesized audio

optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       path of the config file
  --device DEVICE       device to use
  --data DATA           path of LJspeech dataset
  --checkpoint CHECKPOINT
                        checkpoint to resume from
  --iteration ITERATION
                        the iteration of the checkpoint to load from output
                        directory
```
- `--config` is the configuration file to use. You should use the same configuration with which you trained your model.
- `--device` is the device (gpu id) to use. `-1` means CPU.
- `--data` is the path of the LJspeech dataset. In principle, a dataset is not needed for synthesis, but since the input is mel spectrogram, we need to get mel spectrogram from audio files.
- `--checkpoint` is the checkpoint to load.
- `--iteration` is the iteration of the checkpoint to load from the output directory.
- `output` is the directory to save synthesized audio. Audio files are saved in `synthesis/` in the `output` directory.

See [Saving-&-Loading](#saving--loading) for details of checkpoint loading.
Example script:
```bash
python synthesis.py \
--config=./configs/wavenet_single_gaussian.yaml \
--data=./LJSpeech-1.1/ \
--device=0 \
--iteration=500000 \
experiment
```
or
```bash
python synthesis.py \
--config=./configs/wavenet_single_gaussian.yaml \
--data=./LJSpeech-1.1/ \
--device=0 \
--checkpoint="experiment/checkpoints/step-500000" \
experiment
```
@@ -26,29 +26,41 @@ from tensorboardX import SummaryWriter
import paddle.fluid.dygraph as dg
from paddle import fluid
from parakeet.modules.weight_norm import WeightNormWrapper
from parakeet.models.wavenet import WaveNet, UpsampleNet
from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo
from parakeet.utils.layer_tools import summary, freeze
from parakeet.utils import io
from utils import eval_model
sys.path.append("../wavenet")
from data import LJSpeechMetaData, Transform, DataCollector
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Synthesize audio files from mel spectrogram in the validation set."
    )
    parser.add_argument("--config", type=str, help="path of the config file")
    parser.add_argument(
        "--device", type=int, default=-1, help="device to use")
    parser.add_argument("--data", type=str, help="path of LJspeech dataset")

    g = parser.add_mutually_exclusive_group()
    g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
    g.add_argument(
        "--iteration",
        type=int,
        help="the iteration of the checkpoint to load from output directory")

    parser.add_argument(
        "output",
        type=str,
        default="experiment",
        help="path to save the synthesized audio")

    args = parser.parse_args()
    with open(args.config, 'rt') as f:
        config = ruamel.yaml.safe_load(f)
@@ -136,17 +148,32 @@ if __name__ == "__main__":
    model = Clarinet(upsample_net, teacher, student, stft,
                     student_log_scale_min, lmd)
    summary(model)

    # load parameters
    if args.checkpoint is not None:
        # load from args.checkpoint
        iteration = io.load_parameters(
            model, checkpoint_path=args.checkpoint)
    else:
        # load from "args.output/checkpoints"
        checkpoint_dir = os.path.join(args.output, "checkpoints")
        iteration = io.load_parameters(
            model, checkpoint_dir=checkpoint_dir, iteration=args.iteration)
    assert iteration > 0, "A trained checkpoint is needed."

    # make generation fast by removing weight normalization wrappers
    for sublayer in model.sublayers():
        if isinstance(sublayer, WeightNormWrapper):
            sublayer.remove_weight_norm()

    # data loader
    valid_loader = fluid.io.DataLoader.from_generator(
        capacity=10, return_list=True)
    valid_loader.set_batch_generator(valid_cargo, place)

    # the directory to save audio files
    synthesis_dir = os.path.join(args.output, "synthesis")
    if not os.path.exists(synthesis_dir):
        os.makedirs(synthesis_dir)

    eval_model(model, valid_loader, synthesis_dir, iteration, sample_rate)
@@ -32,27 +32,36 @@ from parakeet.data import TransformDataset, SliceDataset, RandomSampler, Sequent
from parakeet.utils.layer_tools import summary, freeze
from parakeet.utils import io
from utils import make_output_tree, eval_model, load_wavenet
# import dataset from wavenet
sys.path.append("../wavenet")
from data import LJSpeechMetaData, Transform, DataCollector
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train a ClariNet model with LJspeech and a trained WaveNet model."
    )
    parser.add_argument("--config", type=str, help="path of the config file")
    parser.add_argument("--device", type=int, default=-1, help="device to use")
    parser.add_argument("--data", type=str, help="path of LJspeech dataset")

    g = parser.add_mutually_exclusive_group()
    g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
    g.add_argument(
        "--iteration",
        type=int,
        help="the iteration of the checkpoint to load from output directory")

    parser.add_argument(
        "--wavenet", type=str, help="wavenet checkpoint to use")
    parser.add_argument(
        "output",
        type=str,
        default="experiment",
        help="path to save experiment results")

    args = parser.parse_args()
    with open(args.config, 'rt') as f:
        config = ruamel.yaml.safe_load(f)
@@ -169,30 +178,20 @@ if __name__ == "__main__":
    log_dir = os.path.join(args.output, "log")
    writer = SummaryWriter(log_dir)

    # load wavenet/checkpoint, determine iterations done
    if args.checkpoint is not None:
        iteration = io.load_parameters(
            model, optim, checkpoint_path=args.checkpoint)
    else:
        iteration = io.load_parameters(
            model,
            optim,
            checkpoint_dir=checkpoint_dir,
            iteration=args.iteration)

    if iteration == 0:
        assert args.wavenet is not None, "When training afresh, a trained wavenet model should be provided."
        load_wavenet(model, args.wavenet)

    # loader
    train_loader = fluid.io.DataLoader.from_generator(
        capacity=10, return_list=True)
@@ -205,7 +204,7 @@ if __name__ == "__main__":
    # training loop
    global_step = iteration + 1
    iterator = iter(tqdm(train_loader))
    while global_step <= max_iterations:
        try:
            batch = next(iterator)
        except StopIteration as e:
@@ -226,7 +225,8 @@ if __name__ == "__main__":
        l = loss_dict["loss"]
        step_loss = l.numpy()[0]
        print("[train] global_step: {} loss: {:<8.6f}".format(
            global_step, step_loss))

        l.backward()
        optim.minimize(l, grad_clip=clipper)
@@ -234,11 +234,9 @@ if __name__ == "__main__":
        if global_step % eval_interval == 0:
            # evaluate on valid dataset
            eval_model(model, valid_loader, state_dir, global_step,
                       sample_rate)
        if global_step % checkpoint_interval == 0:
            io.save_parameters(checkpoint_dir, global_step, model, optim)

        global_step += 1
@@ -32,12 +32,12 @@ def make_output_tree(output_dir):
        os.makedirs(state_dir)
def eval_model(model, valid_loader, output_dir, iteration, sample_rate):
    model.eval()
    for i, batch in enumerate(valid_loader):
        path = os.path.join(output_dir,
                            "sentence_{}_step_{}.wav".format(i, iteration))
        audio_clips, mel_specs, audio_starts = batch
        wav_var = model.synthesis(mel_specs)
        wav_np = wav_var.numpy()[0]
@@ -45,42 +45,6 @@ def valid_model(model, valid_loader, output_dir, global_step, sample_rate):
        print("generated {}".format(path))
def load_wavenet(model, path):
    wavenet_dict, _ = dg.load_dygraph(path)
    encoder_dict = OrderedDict()
@@ -30,29 +30,55 @@ The model consists of an encoder, a decoder and a converter (and a speaker embed
└── utils.py utility functions
```
## Saving & Loading

`train.py` and `synthesis.py` have 3 arguments in common, `--checkpoint`, `--iteration` and `output`.

1. `output` is the directory for saving results.
During training, checkpoints are saved in `checkpoints/` in `output` and the tensorboard log is saved in `log/` in `output`. Other possible outputs are saved in `states/` in `output`.
During synthesizing, audio files and other possible outputs are saved in `synthesis/` in `output`.
So after training and synthesizing with the same output directory, the file structure of the output directory looks like this.

```text
├── checkpoints/      # checkpoint directory (including *.pdparams, *.pdopt and a text file `checkpoint` that records the latest checkpoint)
├── states/           # audio files generated at validation and other possible outputs
├── log/              # tensorboard log
└── synthesis/        # synthesized audio files and other possible outputs
```

2. `--checkpoint` and `--iteration` are used for loading from an existing checkpoint. Loading an existing checkpoint follows this rule:
If `--checkpoint` is provided, the checkpoint specified by `--checkpoint` is loaded.
If `--checkpoint` is not provided, we try to load the checkpoint specified by `--iteration` from the checkpoint directory. If `--iteration` is not provided, we try to load the latest checkpoint from the checkpoint directory.
## Train
Train the model using train.py, follow the usage displayed by `python train.py --help`.
```text
usage: train.py [-h] [--config CONFIG] [--data DATA] [--device DEVICE]
                [--checkpoint CHECKPOINT | --iteration ITERATION]
                output

Train a Deep Voice 3 model with LJSpeech dataset.

positional arguments:
  output                path to save results

optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       experiment config
  --data DATA           The path of the LJSpeech dataset.
  --device DEVICE       device to use
  --checkpoint CHECKPOINT
                        checkpoint to resume from
  --iteration ITERATION
                        the iteration of the checkpoint to load from output
                        directory
```
- `--config` is the configuration file to use. The provided `ljspeech.yaml` can be used directly. You can also change some values in the configuration file and train the model with a different config.
- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains `metadata.txt`).
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
- `--checkpoint` is the path of the checkpoint.
- `--iteration` is the iteration of the checkpoint to load from the output directory.

See [Saving-&-Loading](#saving--loading) for details of checkpoint loading.

- `output` is the directory to save results, all results are saved in this directory. The structure of the output directory is shown below.
```text
├── checkpoints # checkpoint
@@ -64,14 +90,14 @@ optional arguments:
└── waveform # waveform (.wav files)
```
Example script:
```bash
python train.py \
--config=configs/ljspeech.yaml \
--data=./LJSpeech-1.1/ \
--device=0 \
experiment
```
You can monitor the training log via TensorBoard, using the script below.
@@ -83,31 +109,50 @@ tensorboard --logdir=.
## Synthesis
```text
usage: synthesis.py [-h] [--config CONFIG] [--device DEVICE]
                    [--checkpoint CHECKPOINT | --iteration ITERATION]
                    text output

Synthesize waveform with a checkpoint.

positional arguments:
  text                  text file to synthesize
  output                path to save synthesized audio

optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       experiment config
  --device DEVICE       device to use
  --checkpoint CHECKPOINT
                        checkpoint to resume from
  --iteration ITERATION
                        the iteration of the checkpoint to load from output
                        directory
```
- `--config` is the configuration file to use. You should use the same configuration with which you trained your model.
- `--device` is the device (gpu id) to use. `-1` means CPU.
- `--checkpoint` is the path of the checkpoint.
- `--iteration` is the iteration of the checkpoint to load from the output directory.

See [Saving-&-Loading](#saving--loading) for details of checkpoint loading.

- `text` is the text file to synthesize.
- `output` is the directory to save results. The generated audio files (`*.wav`) and attention plots (`*.png`) are saved in `synthesis/` in the output directory.
Example script:
```bash
python synthesis.py \
--config=configs/ljspeech.yaml \
--device=0 \
--checkpoint="experiment/checkpoints/model_step_005000000" \
sentences.txt experiment
```
or
```bash
python synthesis.py \
--config=configs/ljspeech.yaml \
--device=0 \
--iteration=005000000 \
sentences.txt experiment
```
@@ -27,19 +27,26 @@ from tensorboardX import SummaryWriter
from parakeet.g2p import en
from parakeet.modules.weight_norm import WeightNormWrapper
from parakeet.utils.layer_tools import summary
from parakeet.utils import io
from utils import make_model, eval_model, plot_alignment
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Synthesize waveform with a checkpoint.")
    parser.add_argument("--config", type=str, help="experiment config")
    parser.add_argument("--device", type=int, default=-1, help="device to use")

    g = parser.add_mutually_exclusive_group()
    g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
    g.add_argument(
        "--iteration",
        type=int,
        help="the iteration of the checkpoint to load from output directory")

    parser.add_argument("text", type=str, help="text file to synthesize")
    parser.add_argument(
        "output", type=str, help="path to save synthesized audio")

    args = parser.parse_args()

    with open(args.config, 'rt') as f:
@@ -103,8 +110,14 @@ if __name__ == "__main__":
                     linear_dim, use_decoder_states, converter_channels, dropout)
    summary(dv3)

    checkpoint_dir = os.path.join(args.output, "checkpoints")
    if args.checkpoint is not None:
        iteration = io.load_parameters(
            dv3, checkpoint_path=args.checkpoint)
    else:
        iteration = io.load_parameters(
            dv3, checkpoint_dir=checkpoint_dir, iteration=args.iteration)

    # WARNING: don't forget to remove weight norm to re-compute each wrapped layer's weight
    # removing weight norm also speeds up computation
@@ -112,9 +125,6 @@ if __name__ == "__main__":
        if isinstance(layer, WeightNormWrapper):
            layer.remove_weight_norm()
    transform_config = config["transform"]
    c = transform_config["replace_pronunciation_prob"]
    sample_rate = transform_config["sample_rate"]
@@ -128,6 +138,10 @@ if __name__ == "__main__":
    power = synthesis_config["power"]
    n_iter = synthesis_config["n_iter"]

    synthesis_dir = os.path.join(args.output, "synthesis")
    if not os.path.exists(synthesis_dir):
        os.makedirs(synthesis_dir)
    with open(args.text, "rt", encoding="utf-8") as f:
        lines = f.readlines()

    for idx, line in enumerate(lines):
@@ -139,7 +153,9 @@ if __name__ == "__main__":
                               preemphasis)
        plot_alignment(
            attn,
            os.path.join(synthesis_dir,
                         "test_{}_step_{}.png".format(idx, iteration)))
        sf.write(
            os.path.join(synthesis_dir,
                         "test_{}_step_{}.wav".format(idx, iteration)),
            wav, sample_rate)
@@ -45,22 +45,24 @@ from utils import make_model, eval_model, save_state, make_output_tree, plot_ali
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train a Deep Voice 3 model with LJSpeech dataset.")
    parser.add_argument("--config", type=str, help="experiment config")
    parser.add_argument(
        "--data",
        type=str,
        default="/workspace/datasets/LJSpeech-1.1/",
        help="The path of the LJSpeech dataset.")
    parser.add_argument("--device", type=int, default=-1, help="device to use")

    g = parser.add_mutually_exclusive_group()
    g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
    g.add_argument(
        "--iteration",
        type=int,
        help="the iteration of the checkpoint to load from output directory")

    parser.add_argument(
        "output", type=str, default="experiment", help="path to save results")

    args, _ = parser.parse_known_args()
    with open(args.config, 'rt') as f:
        config = ruamel.yaml.safe_load(f)
@@ -216,11 +218,12 @@ if __name__ == "__main__":
    writer = SummaryWriter(logdir=log_dir)

    # load parameters and optimizer, and update iterations done so far
    if args.checkpoint is not None:
        iteration = io.load_parameters(
            dv3, optim, checkpoint_path=args.checkpoint)
    else:
        iteration = io.load_parameters(
            dv3, optim, checkpoint_dir=ckpt_dir, iteration=args.iteration)
    # =========================train=========================
    max_iter = train_config["max_iteration"]
@@ -325,7 +328,6 @@ if __name__ == "__main__":
            # save checkpoint
            if global_step % save_interval == 0:
                io.save_parameters(ckpt_dir, global_step, dv3, optim)

            global_step += 1
@@ -22,43 +22,67 @@ tar xjvf LJSpeech-1.1.tar.bz2
└── utils.py utility functions
```
## Saving & Loading

`train.py` and `synthesis.py` have 3 arguments in common, `--checkpoint`, `--iteration` and `output`.

1. `output` is the directory for saving results.
During training, checkpoints are saved in `checkpoints/` in `output` and the tensorboard log is saved in `log/` in `output`. Other possible outputs are saved in `states/` in `output`.
During synthesizing, audio files and other possible outputs are saved in `synthesis/` in `output`.
So after training and synthesizing with the same output directory, the file structure of the output directory looks like this.

```text
├── checkpoints/      # checkpoint directory (including *.pdparams, *.pdopt and a text file `checkpoint` that records the latest checkpoint)
├── states/           # audio files generated at validation and other possible outputs
├── log/              # tensorboard log
└── synthesis/        # synthesized audio files and other possible outputs
```

2. `--checkpoint` and `--iteration` are used for loading from an existing checkpoint. Loading an existing checkpoint follows this rule:
If `--checkpoint` is provided, the checkpoint specified by `--checkpoint` is loaded.
If `--checkpoint` is not provided, we try to load the checkpoint specified by `--iteration` from the checkpoint directory. If `--iteration` is not provided, we try to load the latest checkpoint from the checkpoint directory.
## Train
Train the model using train.py. For help on usage, try `python train.py --help`.
```text
usage: train.py [-h] [--data DATA] [--config CONFIG] [--device DEVICE]
                [--checkpoint CHECKPOINT | --iteration ITERATION]
                output

Train a WaveNet model with LJSpeech.

positional arguments:
  output                path to save results

optional arguments:
  -h, --help            show this help message and exit
  --data DATA           path of the LJspeech dataset
  --config CONFIG       path of the config file
  --device DEVICE       device to use
  --checkpoint CHECKPOINT
                        checkpoint to resume from
  --iteration ITERATION
                        the iteration of the checkpoint to load from output
                        directory
```
- `--config` is the configuration file to use. The provided configurations can be used directly. You can also change some values in the configuration file and train the model with a different config.
- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains `metadata.txt`).
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
- `--checkpoint` is the path of the checkpoint.
- `--iteration` is the iteration of the checkpoint to load from the output directory.
- `output` is the directory to save results, all results are saved in this directory.

See [Saving-&-Loading](#saving--loading) for details of checkpoint loading.
Example script:
```bash
python train.py \
--config=./configs/wavenet_single_gaussian.yaml \
--data=./LJSpeech-1.1/ \
--device=0 \
experiment
```
You can monitor the training log via TensorBoard, using the script below.
@@ -71,29 +95,50 @@ tensorboard --logdir=.
## Synthesis
```text
usage: synthesis.py [-h] [--data DATA] [--config CONFIG] [--device DEVICE]
                    [--checkpoint CHECKPOINT | --iteration ITERATION]
                    output

Synthesize valid data from LJspeech with a wavenet model.

positional arguments:
  output                path to save the synthesized audio

optional arguments:
  -h, --help            show this help message and exit
  --data DATA           path of the LJspeech dataset
  --config CONFIG       path of the config file
  --device DEVICE       device to use
  --checkpoint CHECKPOINT
                        checkpoint to resume from
  --iteration ITERATION
                        the iteration of the checkpoint to load from output
                        directory
```
- `--config` is the configuration file to use. You should use the same configuration with which you trained your model.
- `--data` is the path of the LJspeech dataset. In principle, a dataset is not needed for synthesis, but since the input is mel spectrogram, we need to get mel spectrogram from audio files.
- `--device` is the device (gpu id) to use. `-1` means CPU.
- `--checkpoint` is the checkpoint to load.
- `--iteration` is the iteration of the checkpoint to load from the output directory.
- `output` is the directory to save synthesized audio. Audio files are saved in `synthesis/` in the `output` directory.

See [Saving-&-Loading](#saving--loading) for details of checkpoint loading.
Example script:
```bash
python synthesis.py \
--config=./configs/wavenet_single_gaussian.yaml \
--data=./LJSpeech-1.1/ \
--device=0 \
--checkpoint="experiment/checkpoints/step-1000000" \
experiment
```
or
```bash
python synthesis.py \
--config=./configs/wavenet_single_gaussian.yaml \
--data=./LJSpeech-1.1/ \
--device=0 \
--iteration=1000000 \
experiment
```
@@ -25,6 +25,7 @@ from parakeet.modules.weight_norm import WeightNormWrapper
from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
from parakeet.utils.layer_tools import summary
from parakeet.utils import io
from data import LJSpeechMetaData, Transform, DataCollector
from utils import make_output_tree, valid_model, eval_model
@@ -33,14 +34,22 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Synthesize valid data from LJspeech with a wavenet model.")
    parser.add_argument(
        "--data", type=str, help="path of the LJspeech dataset")
    parser.add_argument("--config", type=str, help="path of the config file")
    parser.add_argument("--device", type=int, default=-1, help="device to use")

    g = parser.add_mutually_exclusive_group()
    g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
    g.add_argument(
        "--iteration",
        type=int,
        help="the iteration of the checkpoint to load from output directory")

    parser.add_argument(
        "output",
        type=str,
        default="experiment",
        help="path to save the synthesized audio")

    args = parser.parse_args()

    with open(args.config, 'rt') as f:
@@ -112,9 +121,15 @@ if __name__ == "__main__":
    model = ConditionalWavenet(encoder, decoder)
    summary(model)

    # load model parameters
    checkpoint_dir = os.path.join(args.output, "checkpoints")
    if args.checkpoint:
        iteration = io.load_parameters(
            model, checkpoint_path=args.checkpoint)
    else:
        iteration = io.load_parameters(
            model, checkpoint_dir=checkpoint_dir, iteration=args.iteration)
    assert iteration > 0, "A trained model is needed."

    # WARNING: don't forget to remove weight norm to re-compute each wrapped layer's weight
    # removing weight norm also speeds up computation
@@ -130,4 +145,8 @@ if __name__ == "__main__":
        capacity=10, return_list=True)
    valid_loader.set_batch_generator(valid_cargo, place)

    synthesis_dir = os.path.join(args.output, "synthesis")
    if not os.path.exists(synthesis_dir):
        os.makedirs(synthesis_dir)

    eval_model(model, valid_loader, synthesis_dir, iteration, sample_rate)
@@ -33,17 +33,19 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train a WaveNet model with LJSpeech.")
    parser.add_argument(
        "--data", type=str, help="path of the LJspeech dataset")
    parser.add_argument("--config", type=str, help="path of the config file")
    parser.add_argument("--device", type=int, default=-1, help="device to use")

    g = parser.add_mutually_exclusive_group()
    g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
    g.add_argument(
        "--iteration",
        type=int,
        help="the iteration of the checkpoint to load from output directory")

    parser.add_argument(
        "output", type=str, default="experiment", help="path to save results")

    args = parser.parse_args()

    with open(args.config, 'rt') as f:
@@ -148,17 +150,19 @@ if __name__ == "__main__":
    writer = SummaryWriter(log_dir)

    # load parameters and optimizer, and update iterations done so far
    if args.checkpoint is not None:
        iteration = io.load_parameters(
            model, optim, checkpoint_path=args.checkpoint)
    else:
        iteration = io.load_parameters(
            model,
            optim,
            checkpoint_dir=checkpoint_dir,
            iteration=args.iteration)

    global_step = iteration + 1
    iterator = iter(tqdm.tqdm(train_loader))
    while global_step <= max_iterations:
        try:
            batch = next(iterator)
        except StopIteration as e:
@@ -187,8 +191,6 @@ if __name__ == "__main__":
                        sample_rate)

        if global_step % checkpoint_interval == 0:
            io.save_parameters(checkpoint_dir, global_step, model, optim)

        global_step += 1
@@ -49,11 +49,12 @@ def valid_model(model, valid_loader, writer, global_step, sample_rate):
                    sample_rate)
def eval_model(model, valid_loader, output_dir, global_step, sample_rate):
    model.eval()
    for i, batch in enumerate(valid_loader):
        path = os.path.join(output_dir,
                            "sentence_{}_step_{}.wav".format(i, global_step))
        audio_clips, mel_specs, audio_starts = batch
        wav_var = model.synthesis(mel_specs)
        wav_np = wav_var.numpy()[0]
@@ -20,6 +20,11 @@ import numpy as np
import paddle.fluid.dygraph as dg
def is_main_process():
    local_rank = dg.parallel.Env().local_rank
    return local_rank == 0
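

# A typical use of is_main_process (sketch) is to guard rank-0-only side
# effects in multi-process training, e.g.
#     if is_main_process():
#         writer = SummaryWriter(log_dir)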
def add_yaml_config_to_args(config):
    """ Add args in yaml config to the args parsed by argparse. The argument in
    yaml config will be overwritten by the same argument in argparse if they
@@ -41,7 +46,7 @@ def add_yaml_config_to_args(config):
    return config
def _load_latest_checkpoint(checkpoint_dir):
    """Get the iteration number corresponding to the latest saved checkpoint.

    Args:
@@ -52,26 +57,20 @@ def load_latest_checkpoint(checkpoint_dir, rank=0):
    Returns:
        int: the latest iteration number.
    """
    checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
    # If no checkpoint record file exists yet, start from scratch.
    if not os.path.isfile(checkpoint_record):
        return 0

    # Fetch the latest checkpoint index.
    with open(checkpoint_record, "r") as handle:
        latest_checkpoint = handle.readline().split()[-1]
        iteration = int(latest_checkpoint.split("-")[-1])
    return iteration
def _save_checkpoint(checkpoint_dir, iteration):
    """Save the iteration number of the latest model to be checkpointed.

    Args:
@@ -81,60 +80,73 @@ def save_latest_checkpoint(checkpoint_dir, iteration):
    Returns:
        None
    """
    checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
    # Update the latest checkpoint index.
    with open(checkpoint_record, "w") as handle:
        handle.write("model_checkpoint_path: step-{}".format(iteration))
def load_parameters(model,
                    optimizer=None,
                    checkpoint_dir=None,
                    iteration=None,
                    checkpoint_path=None,
                    dtype="float32"):
"""Load a specific model checkpoint from disk.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
rank (int): the rank of the process in multi-process setting.
model (obj): model to load parameters.
optimizer (obj, optional): optimizer to load states if needed.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
iteration (int, optional): if specified, load the specific checkpoint,
if not specified, load the latest one. Defaults to None.
file_path (str, optional): if specified, load the checkpoint
stored in the file_path. Defaults to None.
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path. Defaults to None.
dtype (str, optional): precision of the model parameters.
Defaults to float32.
Returns:
None
iteration (int): number of iterations that the loaded checkpoint has
been trained.
"""
    if iteration is not None and checkpoint_dir is None:
        raise ValueError(
            "When iteration is specified, checkpoint_dir should not be None")

    if checkpoint_path is not None:
        # an explicit path is given; the step is encoded in its basename
        iteration = int(os.path.basename(checkpoint_path).split("-")[-1])
    else:
        if iteration is None:
            iteration = _load_latest_checkpoint(checkpoint_dir)
        checkpoint_path = os.path.join(checkpoint_dir,
                                       "step-{}".format(iteration))
        if iteration == 0 and not os.path.exists(checkpoint_path):
            # nothing to load; if a step-0 checkpoint exists, it is also loaded
            return iteration

    local_rank = dg.parallel.Env().local_rank
    model_dict, optimizer_dict = dg.load_dygraph(checkpoint_path)

    # cast to desired data type
    for k, v in model_dict.items():
        model_dict[k] = v.astype(dtype)

    model.set_dict(model_dict)
    print("[checkpoint] Rank {}: loaded model from {}.pdparams".format(
        local_rank, checkpoint_path))

    if optimizer and optimizer_dict:
        optimizer.set_dict(optimizer_dict)
        print("[checkpoint] Rank {}: loaded optimizer state from {}.pdopt".
              format(local_rank, checkpoint_path))

    return iteration
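

# Usage sketch (hypothetical paths; model/optim are assumed to be built).
# Each call returns the iteration of the loaded checkpoint:
#
#     # 1. load an explicit checkpoint file; the step is parsed from its basename
#     iteration = load_parameters(
#         model, optim, checkpoint_path="experiment/checkpoints/step-500000")
#
#     # 2. load a specific iteration from a checkpoint directory
#     iteration = load_parameters(
#         model, optim, checkpoint_dir="experiment/checkpoints", iteration=500000)
#
#     # 3. load the latest checkpoint recorded in checkpoints/checkpoint
#     iteration = load_parameters(
#         model, optim, checkpoint_dir="experiment/checkpoints")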
def save_parameters(checkpoint_dir, iteration, model, optimizer=None):
    """Checkpoint the latest trained model parameters.

    Args:
@@ -147,12 +159,15 @@ def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
    Returns:
        None
    """
    checkpoint_path = os.path.join(checkpoint_dir, "step-{}".format(iteration))
    model_dict = model.state_dict()
    dg.save_dygraph(model_dict, checkpoint_path)
    print("[checkpoint] Saved model to {}.pdparams".format(checkpoint_path))

    if optimizer:
        opt_dict = optimizer.state_dict()
        dg.save_dygraph(opt_dict, checkpoint_path)
        print("[checkpoint] Saved optimizer state to {}.pdopt".format(
            checkpoint_path))

    _save_checkpoint(checkpoint_dir, iteration)