# utils.py: command-line / YAML configuration helpers and dygraph
# checkpoint save/load utilities.
import argparse
import itertools
import os
import time

import numpy as np
import paddle.fluid.dygraph as dg
import ruamel.yaml


def str2bool(v):
    """Interpret a string (e.g. a command-line value) as a boolean.

    Matching is case-insensitive: "true", "t" and "1" map to True and every
    other string maps to False. A value that is already a ``bool`` is
    returned unchanged, so the function can also be applied to non-string
    defaults without raising AttributeError.

    Args:
        v (str | bool): the value to interpret.

    Returns:
        bool: the interpreted value.
    """
    if isinstance(v, bool):
        return v
    return v.lower() in ("true", "t", "1")


def add_config_options_to_parser(parser):
    """Register the data, training and model options on ``parser``.

    No option sets an explicit default, so each one is ``None`` unless given
    on the command line. ``add_yaml_config`` relies on this: it only copies a
    YAML value onto attributes that are still ``None``, which makes
    command-line values take precedence over the ``--config`` file.

    Args:
        parser (argparse.ArgumentParser): parser to extend in place.
    """
    # Dataset / audio feature options.
    parser.add_argument(
        '--valid_size', type=int, help="size of the valid dataset")
    parser.add_argument(
        '--segment_length',
        type=int,
        help="the length of audio clip for training")
    parser.add_argument(
        '--sample_rate', type=int, help="sampling rate of audio data file")
    parser.add_argument(
        '--fft_window_shift',
        type=int,
        help="the shift of fft window for each frame")
    parser.add_argument(
        '--fft_window_size',
        type=int,
        help="the size of fft window for each frame")
    parser.add_argument(
        '--fft_size', type=int, help="the size of fft filter on each frame")
    parser.add_argument(
        '--mel_bands',
        type=int,
        help="the number of mel bands when calculating mel spectrograms")
    parser.add_argument(
        '--mel_fmin',
        type=float,
        help="lowest frequency in calculating mel spectrograms")
    parser.add_argument(
        '--mel_fmax',
        type=float,
        help="highest frequency in calculating mel spectrograms")

    # Training options.
    parser.add_argument(
        '--seed', type=int, help="seed of random initialization for the model")
    parser.add_argument(
        '--learning_rate', type=float, help="learning rate for training")
    parser.add_argument(
        '--batch_size', type=int, help="batch size for training")
    parser.add_argument(
        '--test_every', type=int, help="test interval during training")
    parser.add_argument(
        '--save_every',
        type=int,
        help="checkpointing interval during training")
    parser.add_argument(
        '--max_iterations', type=int, help="maximum training iterations")

    # Model (flow) options.
    parser.add_argument(
        '--sigma',
        type=float,
        help="standard deviation of the latent Gaussian variable")
    parser.add_argument('--n_flows', type=int, help="number of flows")
    parser.add_argument(
        '--n_group',
        type=int,
        help="number of adjacent audio samples to squeeze into one column")
    parser.add_argument(
        '--n_layers',
        type=int,
        help="number of conv2d layer in one wavenet-like flow architecture")
    parser.add_argument(
        '--n_channels', type=int, help="number of residual channels in flow")
    parser.add_argument(
        '--kernel_h',
        type=int,
        help="height of the kernel in the conv2d layer")
    parser.add_argument(
        '--kernel_w', type=int, help="width of the kernel in the conv2d layer")

    # Path to the YAML file consumed by add_yaml_config.
    parser.add_argument('--config', type=str, help="Path to the config file.")

def add_yaml_config(config):
    """Fill unset attributes of ``config`` from the YAML file ``config.config``.

    Values already on ``config`` win: a key from the YAML file is copied only
    when the attribute is missing or still ``None``. YAML keys with no
    matching attribute are added as new attributes.

    Args:
        config (argparse.Namespace): parsed arguments; ``config.config`` is
            the path of the YAML file to read.

    Returns:
        argparse.Namespace: the same ``config`` object, updated in place.
    """
    # The module-level ``ruamel.yaml.safe_load`` helper is deprecated and
    # raises from ruamel.yaml 0.18 on; ``YAML(typ="safe")`` is the supported
    # equivalent and, like safe_load, builds only plain Python objects.
    yaml_loader = ruamel.yaml.YAML(typ="safe")
    with open(config.config, "rt", encoding="utf-8") as f:
        yaml_cfg = yaml_loader.load(f)
    # An empty YAML document loads as None; treat it as "nothing to merge".
    if yaml_cfg is None:
        yaml_cfg = {}
    # vars() returns the Namespace's own __dict__, so writes update config.
    cfg_vars = vars(config)
    for k, v in yaml_cfg.items():
        if cfg_vars.get(k) is not None:
            continue
        cfg_vars[k] = v
    return config


def load_latest_checkpoint(checkpoint_dir, rank=0):
    """Return the iteration recorded in the checkpoint index file.

    The index file is ``<checkpoint_dir>/checkpoint``. Its first line has the
    form ``model_checkpoint_path: step-<iteration>``. If the file does not
    exist yet, the process with ``rank == 0`` creates it, pointing at
    ``step-0``, and creates ``checkpoint_dir`` if necessary. Every rank then
    polls once per second until the file exists and parses it.

    Args:
        checkpoint_dir (str): directory holding the checkpoints.
        rank (int): rank of the calling process; only rank 0 writes.

    Returns:
        int: the latest recorded iteration (0 for a fresh directory).
    """
    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
    # Create the checkpoint index file if it does not exist.
    if rank == 0 and not os.path.isfile(checkpoint_path):
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        # Write a temp file and rename it into place. The rename is atomic,
        # so other ranks polling below never see an existing but still-empty
        # index file, which would make the parsing step raise IndexError.
        tmp_path = checkpoint_path + ".tmp"
        with open(tmp_path, "w") as handle:
            handle.write("model_checkpoint_path: step-0")
        os.replace(tmp_path, checkpoint_path)

    # Make sure that other processes wait until the index file is created
    # by process 0.
    while not os.path.isfile(checkpoint_path):
        time.sleep(1)

    # Fetch the latest checkpoint index: "... step-<n>" -> n.
    with open(checkpoint_path, "r") as handle:
        latest_checkpoint = handle.readline().split()[-1]
    return int(latest_checkpoint.split("-")[-1])


def save_latest_checkpoint(checkpoint_dir, iteration):
    """Record ``iteration`` as the latest checkpoint in the index file.

    Overwrites ``<checkpoint_dir>/checkpoint`` with the single line
    ``model_checkpoint_path: step-<iteration>``. This is the format that
    ``load_latest_checkpoint`` parses.

    Args:
        checkpoint_dir (str): directory holding the checkpoints.
        iteration (int): training step to record.
    """
    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
    # Write a sibling temp file and atomically rename it over the index, so a
    # concurrent reader never observes a truncated or partially written file.
    tmp_path = checkpoint_path + ".tmp"
    with open(tmp_path, "w") as handle:
        handle.write("model_checkpoint_path: step-{}".format(iteration))
    os.replace(tmp_path, checkpoint_path)


def load_parameters(checkpoint_dir,
                    rank,
                    model,
                    optimizer=None,
                    iteration=None,
                    file_path=None,
                    dtype="float32"):
    """Load model (and optionally optimizer) state from a dygraph checkpoint.

    The checkpoint is chosen in this order:
    1. an explicit ``file_path``;
    2. otherwise ``<checkpoint_dir>/step-<iteration>``;
    3. when ``iteration`` is also None, the iteration recorded in the index
       file, read via ``load_latest_checkpoint``.
    Iteration 0 means nothing has been saved yet. In that case the function
    returns without loading anything.

    Args:
        checkpoint_dir (str): directory holding the checkpoints.
        rank (int): rank of the calling process, forwarded to
            ``load_latest_checkpoint`` and used in log messages.
        model: layer whose state is restored with ``model.set_dict``.
        optimizer: optional optimizer, restored with ``optimizer.set_dict``
            when ``dg.load_dygraph`` returned a non-empty optimizer state.
        iteration (int, optional): checkpoint step to load.
        file_path (str, optional): checkpoint path as accepted by
            ``dg.load_dygraph``. Overrides ``checkpoint_dir``/``iteration``.
        dtype (str): when "float16", every model parameter is cast to
            float16, except parameters whose name contains
            "conv2d_transpose", which are cast to float32.
    """
    if file_path is None:
        if iteration is None:
            iteration = load_latest_checkpoint(checkpoint_dir, rank)
        if iteration == 0:
            # Nothing saved yet: keep the model's current parameters.
            return
        file_path = "{}/step-{}".format(checkpoint_dir, iteration)

    model_dict, optimizer_dict = dg.load_dygraph(file_path)
    if dtype == "float16":
        # NOTE(review): conv2d_transpose parameters are kept in float32,
        # presumably because that op lacks float16 support; confirm.
        for k, v in model_dict.items():
            if "conv2d_transpose" in k:
                model_dict[k] = v.astype("float32")
            else:
                model_dict[k] = v.astype(dtype)

    model.set_dict(model_dict)
    print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path))
    if optimizer and optimizer_dict:
        optimizer.set_dict(optimizer_dict)
        print("[checkpoint] Rank {}: loaded optimizer state from {}".format(
            rank, file_path))


def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
    """Save model (and optionally optimizer) state for ``iteration``.

    The state is written with ``dg.save_dygraph`` under
    ``<checkpoint_dir>/step-<iteration>``. This function does not update the
    checkpoint index file; ``save_latest_checkpoint`` does that.

    Args:
        checkpoint_dir (str): directory holding the checkpoints.
        iteration (int): training step, used in the checkpoint name.
        model: layer whose ``state_dict()`` is saved.
        optimizer: optional optimizer whose ``state_dict()`` is also saved.
    """
    file_path = "{}/step-{}".format(checkpoint_dir, iteration)
    model_dict = model.state_dict()
    dg.save_dygraph(model_dict, file_path)
    print("[checkpoint] Saved model to {}".format(file_path))

    if optimizer is not None:
        # Same path as the model. load_parameters reads both states back with
        # a single dg.load_dygraph(file_path) call; presumably save_dygraph
        # gives optimizer state its own file suffix. Confirm against Paddle.
        opt_dict = optimizer.state_dict()
        dg.save_dygraph(opt_dict, file_path)
        print("[checkpoint] Saved optimizer state to {}".format(file_path))