提交 484465ca 编写于 作者: K Kexin Zhao 提交者: lifuchen

add docstring

上级 e82115bc
...@@ -109,6 +109,16 @@ def add_yaml_config(config): ...@@ -109,6 +109,16 @@ def add_yaml_config(config):
def load_latest_checkpoint(checkpoint_dir, rank=0): def load_latest_checkpoint(checkpoint_dir, rank=0):
"""Get the iteration number corresponding to the latest saved checkpoint
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
rank (int, optional): the rank of the process in multi-process setting.
Defaults to 0.
Returns:
int: the latest iteration number.
"""
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
# Create checkpoint index file if not exist. # Create checkpoint index file if not exist.
if (not os.path.isfile(checkpoint_path)) and rank == 0: if (not os.path.isfile(checkpoint_path)) and rank == 0:
...@@ -129,6 +139,15 @@ def load_latest_checkpoint(checkpoint_dir, rank=0): ...@@ -129,6 +139,15 @@ def load_latest_checkpoint(checkpoint_dir, rank=0):
def save_latest_checkpoint(checkpoint_dir, iteration): def save_latest_checkpoint(checkpoint_dir, iteration):
"""Save the iteration number of the latest model to be checkpointed.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
Returns:
None
"""
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
# Update the latest checkpoint index. # Update the latest checkpoint index.
with open(checkpoint_path, "w") as handle: with open(checkpoint_path, "w") as handle:
...@@ -142,6 +161,24 @@ def load_parameters(checkpoint_dir, ...@@ -142,6 +161,24 @@ def load_parameters(checkpoint_dir,
iteration=None, iteration=None,
file_path=None, file_path=None,
dtype="float32"): dtype="float32"):
"""Load a specific model checkpoint from disk.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
rank (int): the rank of the process in multi-process setting.
model (obj): model to load parameters.
optimizer (obj, optional): optimizer to load states if needed.
Defaults to None.
iteration (int, optional): if specified, load the specific checkpoint,
if not specified, load the latest one. Defaults to None.
file_path (str, optional): if specified, load the checkpoint
stored in the file_path. Defaults to None.
dtype (str, optional): precision of the model parameters.
Defaults to float32.
Returns:
None
"""
if file_path is None: if file_path is None:
if iteration is None: if iteration is None:
iteration = load_latest_checkpoint(checkpoint_dir, rank) iteration = load_latest_checkpoint(checkpoint_dir, rank)
...@@ -165,6 +202,18 @@ def load_parameters(checkpoint_dir, ...@@ -165,6 +202,18 @@ def load_parameters(checkpoint_dir,
def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None): def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
"""Checkpoint the latest trained model parameters.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
model (obj): model to be checkpointed.
optimizer (obj, optional): optimizer to be checkpointed.
Defaults to None.
Returns:
None
"""
file_path = "{}/step-{}".format(checkpoint_dir, iteration) file_path = "{}/step-{}".format(checkpoint_dir, iteration)
model_dict = model.state_dict() model_dict = model.state_dict()
dg.save_dygraph(model_dict, file_path) dg.save_dygraph(model_dict, file_path)
......
...@@ -80,6 +80,7 @@ class Subset(DatasetMixin): ...@@ -80,6 +80,7 @@ class Subset(DatasetMixin):
# whole audio for valid set # whole audio for valid set
pass pass
else: else:
# Randomly crop segment_length from audios in the training set.
# audio shape: [len] # audio shape: [len]
if audio.shape[0] >= segment_length: if audio.shape[0] >= segment_length:
max_audio_start = audio.shape[0] - segment_length max_audio_start = audio.shape[0] - segment_length
......
...@@ -28,6 +28,25 @@ from .waveflow_modules import WaveFlowLoss, WaveFlowModule ...@@ -28,6 +28,25 @@ from .waveflow_modules import WaveFlowLoss, WaveFlowModule
class WaveFlow(): class WaveFlow():
"""Wrapper class of WaveFlow model that supports multiple APIs.
This module provides APIs for model building, training, validation,
inference, benchmarking, and saving.
Args:
config (obj): config info.
checkpoint_dir (str): path for checkpointing.
parallel (bool, optional): whether use multiple GPUs for training.
Defaults to False.
rank (int, optional): the rank of the process in a multi-process
scenario. Defaults to 0.
nranks (int, optional): the total number of processes. Defaults to 1.
tb_logger (obj, optional): logger to visualize metrics.
Defaults to None.
Returns:
WaveFlow
"""
def __init__(self, def __init__(self,
config, config,
checkpoint_dir, checkpoint_dir,
...@@ -44,6 +63,15 @@ class WaveFlow(): ...@@ -44,6 +63,15 @@ class WaveFlow():
self.dtype = "float16" if config.use_fp16 else "float32" self.dtype = "float16" if config.use_fp16 else "float32"
def build(self, training=True): def build(self, training=True):
"""Initialize the model.
Args:
training (bool, optional): Whether the model is built for training or inference.
Defaults to True.
Returns:
None
"""
config = self.config config = self.config
dataset = LJSpeech(config, self.nranks, self.rank) dataset = LJSpeech(config, self.nranks, self.rank)
self.trainloader = dataset.trainloader self.trainloader = dataset.trainloader
...@@ -99,6 +127,14 @@ class WaveFlow(): ...@@ -99,6 +127,14 @@ class WaveFlow():
self.waveflow = waveflow self.waveflow = waveflow
def train_step(self, iteration): def train_step(self, iteration):
"""Train the model for one step.
Args:
iteration (int): current iteration number.
Returns:
None
"""
self.waveflow.train() self.waveflow.train()
start_time = time.time() start_time = time.time()
...@@ -135,6 +171,14 @@ class WaveFlow(): ...@@ -135,6 +171,14 @@ class WaveFlow():
@dg.no_grad @dg.no_grad
def valid_step(self, iteration): def valid_step(self, iteration):
"""Run the model on the validation dataset.
Args:
iteration (int): current iteration number.
Returns:
None
"""
self.waveflow.eval() self.waveflow.eval()
tb = self.tb_logger tb = self.tb_logger
...@@ -167,6 +211,14 @@ class WaveFlow(): ...@@ -167,6 +211,14 @@ class WaveFlow():
@dg.no_grad @dg.no_grad
def infer(self, iteration): def infer(self, iteration):
"""Run the model to synthesize audios.
Args:
iteration (int): iteration number of the loaded checkpoint.
Returns:
None
"""
self.waveflow.eval() self.waveflow.eval()
config = self.config config = self.config
...@@ -203,6 +255,14 @@ class WaveFlow(): ...@@ -203,6 +255,14 @@ class WaveFlow():
@dg.no_grad @dg.no_grad
def benchmark(self): def benchmark(self):
"""Run the model to benchmark synthesis speed.
Args:
None
Returns:
None
"""
self.waveflow.eval() self.waveflow.eval()
mels_list = [mels for _, mels in self.validloader()] mels_list = [mels for _, mels in self.validloader()]
...@@ -223,6 +283,14 @@ class WaveFlow(): ...@@ -223,6 +283,14 @@ class WaveFlow():
print("{} X real-time".format(audio_time / syn_time)) print("{} X real-time".format(audio_time / syn_time))
def save(self, iteration): def save(self, iteration):
"""Save model checkpoint.
Args:
iteration (int): iteration number of the model to be saved.
Returns:
None
"""
utils.save_latest_parameters(self.checkpoint_dir, iteration, utils.save_latest_parameters(self.checkpoint_dir, iteration,
self.waveflow, self.optimizer) self.waveflow, self.optimizer)
utils.save_latest_checkpoint(self.checkpoint_dir, iteration) utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
...@@ -293,6 +293,14 @@ class Flow(dg.Layer): ...@@ -293,6 +293,14 @@ class Flow(dg.Layer):
class WaveFlowModule(dg.Layer): class WaveFlowModule(dg.Layer):
"""WaveFlow model implementation.
Args:
config (obj): model configuration parameters.
Returns:
WaveFlowModule
"""
def __init__(self, config): def __init__(self, config):
super(WaveFlowModule, self).__init__() super(WaveFlowModule, self).__init__()
self.n_flows = config.n_flows self.n_flows = config.n_flows
...@@ -321,6 +329,22 @@ class WaveFlowModule(dg.Layer): ...@@ -321,6 +329,22 @@ class WaveFlowModule(dg.Layer):
self.perms.append(perm) self.perms.append(perm)
def forward(self, audio, mel): def forward(self, audio, mel):
"""Training forward pass.
Use a conditioner to upsample mel spectrograms into hidden states.
These hidden states along with the audio are passed to a stack of Flow
modules to obtain the final latent variable z and a list of log scaling
variables, which are then passed to the WaveFlowLoss module to calculate
the negative log likelihood.
Args:
audio (obj): audio samples.
mel (obj): mel spectrograms.
Returns:
z (obj): latent variable.
log_s_list(list): list of log scaling variables.
"""
mel = self.conditioner(mel) mel = self.conditioner(mel)
assert mel.shape[2] >= audio.shape[1] assert mel.shape[2] >= audio.shape[1]
# Prune out the tail of audio/mel so that time/n_group == 0. # Prune out the tail of audio/mel so that time/n_group == 0.
...@@ -361,6 +385,20 @@ class WaveFlowModule(dg.Layer): ...@@ -361,6 +385,20 @@ class WaveFlowModule(dg.Layer):
return z, log_s_list return z, log_s_list
def synthesize(self, mel, sigma=1.0): def synthesize(self, mel, sigma=1.0):
"""Use model to synthesize waveform.
Use a conditioner to upsample mel spectrograms into hidden states.
These hidden states along with initial random gaussian latent variable
are passed to a stack of Flow modules to obtain the audio output.
Args:
mel (obj): mel spectrograms.
sigma (float, optional): standard deviation of the guassian latent
variable. Defaults to 1.0.
Returns:
audio (obj): synthesized audio.
"""
if self.dtype == "float16": if self.dtype == "float16":
mel = fluid.layers.cast(mel, self.dtype) mel = fluid.layers.cast(mel, self.dtype)
mel = self.conditioner.infer(mel) mel = self.conditioner.infer(mel)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册