提交 737d142a 编写于 作者: L liuyibing01

Enable the fp16 inference for waveflow

上级 1c6cd10a
...@@ -36,14 +36,16 @@ nltk.download("cmudict") ...@@ -36,14 +36,16 @@ nltk.download("cmudict")
``` ```
## Supported models ## Related Research
- [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](https://arxiv.org/abs/1710.07654) - [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](https://arxiv.org/abs/1710.07654)
- [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895) - [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895)
- [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263). - [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263).
- [WaveFlow: A Compact Flow-based Model for Raw Audio](https://arxiv.org/abs/1912.01219)
## Examples ## Examples
- [Train a deepvoice 3 model with ljspeech dataset](./parakeet/examples/deepvoice3) - [Train a DeepVoice3 model with ljspeech dataset](./parakeet/examples/deepvoice3)
- [Train a transformer_tts model with ljspeech dataset](./parakeet/examples/transformer_tts) - [Train a TransformerTTS model with ljspeech dataset](./parakeet/examples/transformer_tts)
- [Train a fastspeech model with ljspeech dataset](./parakeet/examples/fastspeech) - [Train a FastSpeech model with ljspeech dataset](./parakeet/examples/fastspeech)
- [Train a WaveFlow model with ljspeech dataset](./parakeet/examples/waveflow)
...@@ -109,3 +109,13 @@ python -u benchmark.py \ ...@@ -109,3 +109,13 @@ python -u benchmark.py \
--root=./data/LJSpeech-1.1 \ --root=./data/LJSpeech-1.1 \
--name=${ModelName} --use_gpu=true --name=${ModelName} --use_gpu=true
``` ```
### Low-precision inference
This model supports float16 low-precision inference. By appending the argument
```bash
--use_fp16=true
```
to the synthesis and benchmarking commands, one can benefit from the faster speed of low-precision inference.
...@@ -24,9 +24,14 @@ def add_options_to_parser(parser): ...@@ -24,9 +24,14 @@ def add_options_to_parser(parser):
parser.add_argument( parser.add_argument(
'--use_gpu', '--use_gpu',
type=bool, type=utils.str2bool,
default=True, default=True,
help="option to use gpu training") help="option to use gpu training")
parser.add_argument(
'--use_fp16',
type=utils.str2bool,
default=True,
help="option to use fp16 for inference")
parser.add_argument( parser.add_argument(
'--iteration', '--iteration',
......
...@@ -24,9 +24,14 @@ def add_options_to_parser(parser): ...@@ -24,9 +24,14 @@ def add_options_to_parser(parser):
parser.add_argument( parser.add_argument(
'--use_gpu', '--use_gpu',
type=bool, type=utils.str2bool,
default=True, default=True,
help="option to use gpu training") help="option to use gpu training")
parser.add_argument(
'--use_fp16',
type=utils.str2bool,
default=True,
help="option to use fp16 for inference")
parser.add_argument( parser.add_argument(
'--iteration', '--iteration',
...@@ -74,7 +79,6 @@ def synthesize(config): ...@@ -74,7 +79,6 @@ def synthesize(config):
# Build model. # Build model.
model = WaveFlow(config, checkpoint_dir) model = WaveFlow(config, checkpoint_dir)
model.build(training=False) model.build(training=False)
# Obtain the current iteration. # Obtain the current iteration.
if config.checkpoint is None: if config.checkpoint is None:
if config.iteration is None: if config.iteration is None:
......
...@@ -127,4 +127,6 @@ if __name__ == "__main__": ...@@ -127,4 +127,6 @@ if __name__ == "__main__":
# the preceding update will be overwritten by the following one. # the preceding update will be overwritten by the following one.
config = parser.parse_args() config = parser.parse_args()
config = utils.add_yaml_config(config) config = utils.add_yaml_config(config)
# Force to use fp32 in model training
vars(config)["use_fp16"] = False
train(config) train(config)
...@@ -126,7 +126,8 @@ def load_parameters(checkpoint_dir, ...@@ -126,7 +126,8 @@ def load_parameters(checkpoint_dir,
model, model,
optimizer=None, optimizer=None,
iteration=None, iteration=None,
file_path=None): file_path=None,
dtype="float32"):
if file_path is None: if file_path is None:
if iteration is None: if iteration is None:
iteration = load_latest_checkpoint(checkpoint_dir, rank) iteration = load_latest_checkpoint(checkpoint_dir, rank)
...@@ -135,6 +136,12 @@ def load_parameters(checkpoint_dir, ...@@ -135,6 +136,12 @@ def load_parameters(checkpoint_dir,
file_path = "{}/step-{}".format(checkpoint_dir, iteration) file_path = "{}/step-{}".format(checkpoint_dir, iteration)
model_dict, optimizer_dict = dg.load_dygraph(file_path) model_dict, optimizer_dict = dg.load_dygraph(file_path)
if dtype == "float16":
for k, v in model_dict.items():
if "conv2d_transpose" in k:
model_dict[k] = v.astype("float32")
else:
model_dict[k] = v.astype(dtype)
model.set_dict(model_dict) model.set_dict(model_dict)
print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path)) print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path))
if optimizer and optimizer_dict: if optimizer and optimizer_dict:
......
...@@ -8,6 +8,7 @@ from paddle import fluid ...@@ -8,6 +8,7 @@ from paddle import fluid
from scipy.io.wavfile import write from scipy.io.wavfile import write
import utils import utils
from parakeet.modules import weight_norm
from .data import LJSpeech from .data import LJSpeech
from .waveflow_modules import WaveFlowLoss, WaveFlowModule from .waveflow_modules import WaveFlowLoss, WaveFlowModule
...@@ -26,6 +27,7 @@ class WaveFlow(): ...@@ -26,6 +27,7 @@ class WaveFlow():
self.rank = rank self.rank = rank
self.nranks = nranks self.nranks = nranks
self.tb_logger = tb_logger self.tb_logger = tb_logger
self.dtype = "float16" if config.use_fp16 else "float32"
def build(self, training=True): def build(self, training=True):
config = self.config config = self.config
...@@ -36,9 +38,9 @@ class WaveFlow(): ...@@ -36,9 +38,9 @@ class WaveFlow():
waveflow = WaveFlowModule(config) waveflow = WaveFlowModule(config)
# Dry run once to create and initialize all necessary parameters. # Dry run once to create and initialize all necessary parameters.
audio = dg.to_variable(np.random.randn(1, 16000).astype(np.float32)) audio = dg.to_variable(np.random.randn(1, 16000).astype(self.dtype))
mel = dg.to_variable( mel = dg.to_variable(
np.random.randn(1, config.mel_bands, 63).astype(np.float32)) np.random.randn(1, config.mel_bands, 63).astype(self.dtype))
waveflow(audio, mel) waveflow(audio, mel)
if training: if training:
...@@ -72,9 +74,14 @@ class WaveFlow(): ...@@ -72,9 +74,14 @@ class WaveFlow():
self.rank, self.rank,
waveflow, waveflow,
iteration=config.iteration, iteration=config.iteration,
file_path=config.checkpoint) file_path=config.checkpoint,
dtype=self.dtype)
print("Rank {}: checkpoint loaded.".format(self.rank)) print("Rank {}: checkpoint loaded.".format(self.rank))
for layer in waveflow.sublayers():
if isinstance(layer, weight_norm.WeightNormWrapper):
layer.remove_weight_norm()
self.waveflow = waveflow self.waveflow = waveflow
def train_step(self, iteration): def train_step(self, iteration):
...@@ -173,7 +180,7 @@ class WaveFlow(): ...@@ -173,7 +180,7 @@ class WaveFlow():
syn_time)) syn_time))
# Denormalize audio from [-1, 1] to [-32768, 32768] int16 range. # Denormalize audio from [-1, 1] to [-32768, 32768] int16 range.
audio = audio.numpy() * 32768.0 audio = audio.numpy().astype("float32") * 32768.0
audio = audio.astype('int16') audio = audio.astype('int16')
write(filename, config.sample_rate, audio) write(filename, config.sample_rate, audio)
......
import itertools import itertools
import numpy as np import numpy as np
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
from paddle import fluid from paddle import fluid
...@@ -49,7 +48,7 @@ class WaveFlowLoss: ...@@ -49,7 +48,7 @@ class WaveFlowLoss:
class Conditioner(dg.Layer): class Conditioner(dg.Layer):
def __init__(self): def __init__(self, dtype):
super(Conditioner, self).__init__() super(Conditioner, self).__init__()
upsample_factors = [16, 16] upsample_factors = [16, 16]
...@@ -65,7 +64,8 @@ class Conditioner(dg.Layer): ...@@ -65,7 +64,8 @@ class Conditioner(dg.Layer):
padding=(1, s // 2), padding=(1, s // 2),
stride=(1, s), stride=(1, s),
param_attr=param_attr, param_attr=param_attr,
bias_attr=bias_attr) bias_attr=bias_attr,
dtype="float32")
self.upsample_conv2d.append(conv_trans2d) self.upsample_conv2d.append(conv_trans2d)
for i, layer in enumerate(self.upsample_conv2d): for i, layer in enumerate(self.upsample_conv2d):
...@@ -74,19 +74,30 @@ class Conditioner(dg.Layer): ...@@ -74,19 +74,30 @@ class Conditioner(dg.Layer):
def forward(self, x): def forward(self, x):
x = fluid.layers.unsqueeze(x, 1) x = fluid.layers.unsqueeze(x, 1)
for layer in self.upsample_conv2d: for layer in self.upsample_conv2d:
x = fluid.layers.leaky_relu(layer(x), alpha=0.4) in_dtype = x.dtype
if in_dtype == fluid.core.VarDesc.VarType.FP16:
x = fluid.layers.cast(x, "float32")
x = layer(x)
if in_dtype == fluid.core.VarDesc.VarType.FP16:
x = fluid.layers.cast(x, "float16")
x = fluid.layers.leaky_relu(x, alpha=0.4)
return fluid.layers.squeeze(x, [1]) return fluid.layers.reshape(x, [x.shape[0], x.shape[2], x.shape[3]])
def infer(self, x): def infer(self, x):
x = fluid.layers.unsqueeze(x, 1) x = fluid.layers.unsqueeze(x, 1)
for layer in self.upsample_conv2d: for layer in self.upsample_conv2d:
in_dtype = x.dtype
if in_dtype == fluid.core.VarDesc.VarType.FP16:
x = fluid.layers.cast(x, "float32")
x = layer(x) x = layer(x)
if in_dtype == fluid.core.VarDesc.VarType.FP16:
x = fluid.layers.cast(x, "float16")
# Trim conv artifacts. # Trim conv artifacts.
time_cutoff = layer._filter_size[1] - layer._stride[1] time_cutoff = layer._filter_size[1] - layer._stride[1]
x = fluid.layers.leaky_relu(x[:, :, :, :-time_cutoff], alpha=0.4) x = fluid.layers.leaky_relu(x[:, :, :, :-time_cutoff], alpha=0.4)
return fluid.layers.squeeze(x, [1]) return fluid.layers.reshape(x, [x.shape[0], x.shape[2], x.shape[3]])
class Flow(dg.Layer): class Flow(dg.Layer):
...@@ -96,6 +107,7 @@ class Flow(dg.Layer): ...@@ -96,6 +107,7 @@ class Flow(dg.Layer):
self.n_channels = config.n_channels self.n_channels = config.n_channels
self.kernel_h = config.kernel_h self.kernel_h = config.kernel_h
self.kernel_w = config.kernel_w self.kernel_w = config.kernel_w
self.dtype = "float16" if config.use_fp16 else "float32"
# Transform audio: [batch, 1, n_group, time/n_group] # Transform audio: [batch, 1, n_group, time/n_group]
# => [batch, n_channels, n_group, time/n_group] # => [batch, n_channels, n_group, time/n_group]
...@@ -105,7 +117,8 @@ class Flow(dg.Layer): ...@@ -105,7 +117,8 @@ class Flow(dg.Layer):
num_filters=self.n_channels, num_filters=self.n_channels,
filter_size=(1, 1), filter_size=(1, 1),
param_attr=param_attr, param_attr=param_attr,
bias_attr=bias_attr) bias_attr=bias_attr,
dtype=self.dtype)
# Initializing last layer to 0 makes the affine coupling layers # Initializing last layer to 0 makes the affine coupling layers
# do nothing at first. This helps with training stability # do nothing at first. This helps with training stability
...@@ -117,7 +130,8 @@ class Flow(dg.Layer): ...@@ -117,7 +130,8 @@ class Flow(dg.Layer):
num_filters=2, num_filters=2,
filter_size=(1, 1), filter_size=(1, 1),
param_attr=param_attr, param_attr=param_attr,
bias_attr=bias_attr) bias_attr=bias_attr,
dtype=self.dtype)
# receptive fields: (kernel - 1) * sum(dilations) + 1 >= squeeze # receptive fields: (kernel - 1) * sum(dilations) + 1 >= squeeze
dilation_dict = { dilation_dict = {
...@@ -145,7 +159,8 @@ class Flow(dg.Layer): ...@@ -145,7 +159,8 @@ class Flow(dg.Layer):
filter_size=(self.kernel_h, self.kernel_w), filter_size=(self.kernel_h, self.kernel_w),
dilation=(dilation_h, dilation_w), dilation=(dilation_h, dilation_w),
param_attr=param_attr, param_attr=param_attr,
bias_attr=bias_attr) bias_attr=bias_attr,
dtype=self.dtype)
self.in_layers.append(in_layer) self.in_layers.append(in_layer)
param_attr, bias_attr = get_param_attr( param_attr, bias_attr = get_param_attr(
...@@ -155,7 +170,8 @@ class Flow(dg.Layer): ...@@ -155,7 +170,8 @@ class Flow(dg.Layer):
num_filters=2 * self.n_channels, num_filters=2 * self.n_channels,
filter_size=(1, 1), filter_size=(1, 1),
param_attr=param_attr, param_attr=param_attr,
bias_attr=bias_attr) bias_attr=bias_attr,
dtype=self.dtype)
self.cond_layers.append(cond_layer) self.cond_layers.append(cond_layer)
if i < self.n_layers - 1: if i < self.n_layers - 1:
...@@ -169,7 +185,8 @@ class Flow(dg.Layer): ...@@ -169,7 +185,8 @@ class Flow(dg.Layer):
num_filters=res_skip_channels, num_filters=res_skip_channels,
filter_size=(1, 1), filter_size=(1, 1),
param_attr=param_attr, param_attr=param_attr,
bias_attr=bias_attr) bias_attr=bias_attr,
dtype=self.dtype)
self.res_skip_layers.append(res_skip_layer) self.res_skip_layers.append(res_skip_layer)
self.add_sublayer("in_layer_{}".format(i), in_layer) self.add_sublayer("in_layer_{}".format(i), in_layer)
...@@ -191,7 +208,6 @@ class Flow(dg.Layer): ...@@ -191,7 +208,6 @@ class Flow(dg.Layer):
pad_left = pad_right = int((self.kernel_w - 1) * dilation_w / 2) pad_left = pad_right = int((self.kernel_w - 1) * dilation_w / 2)
audio_pad = fluid.layers.pad2d( audio_pad = fluid.layers.pad2d(
audio, paddings=[pad_top, pad_bottom, pad_left, pad_right]) audio, paddings=[pad_top, pad_bottom, pad_left, pad_right])
hidden = self.in_layers[i](audio_pad) hidden = self.in_layers[i](audio_pad)
cond_hidden = self.cond_layers[i](mel) cond_hidden = self.cond_layers[i](mel)
in_acts = hidden + cond_hidden in_acts = hidden + cond_hidden
...@@ -239,12 +255,11 @@ class Flow(dg.Layer): ...@@ -239,12 +255,11 @@ class Flow(dg.Layer):
pad_right = int((self.kernel_w - 1) * dilation_w / 2) pad_right = int((self.kernel_w - 1) * dilation_w / 2)
state = fluid.layers.pad2d( state = fluid.layers.pad2d(
state, paddings=[pad_top, pad_bottom, pad_left, pad_right]) state, paddings=[pad_top, pad_bottom, pad_left, pad_right])
hidden = self.in_layers[i](state) hidden = self.in_layers[i](state)
cond_hidden = self.cond_layers[i](mel) cond_hidden = self.cond_layers[i](mel)
in_acts = hidden + cond_hidden in_acts = hidden + cond_hidden
out_acts = fluid.layers.tanh(in_acts[:, :self.n_channels, :]) * \ out_acts = fluid.layers.tanh(in_acts[:, :self.n_channels, :]) * \
fluid.layers.sigmoid(in_acts[:, self.n_channels:, :]) fluid.layers.sigmoid(in_acts[:, self.n_channels:, :])
res_skip_acts = self.res_skip_layers[i](out_acts) res_skip_acts = self.res_skip_layers[i](out_acts)
if i < self.n_layers - 1: if i < self.n_layers - 1:
...@@ -270,7 +285,8 @@ class WaveFlowModule(dg.Layer): ...@@ -270,7 +285,8 @@ class WaveFlowModule(dg.Layer):
assert self.n_group % 2 == 0 assert self.n_group % 2 == 0
assert self.n_flows % 2 == 0 assert self.n_flows % 2 == 0
self.conditioner = Conditioner() self.dtype = "float16" if config.use_fp16 else "float32"
self.conditioner = Conditioner(self.dtype)
self.flows = [] self.flows = []
for i in range(self.n_flows): for i in range(self.n_flows):
flow = Flow(config) flow = Flow(config)
...@@ -324,17 +340,21 @@ class WaveFlowModule(dg.Layer): ...@@ -324,17 +340,21 @@ class WaveFlowModule(dg.Layer):
mel_slices = [mel[:, :, j, :] for j in self.perms[i]] mel_slices = [mel[:, :, j, :] for j in self.perms[i]]
mel = fluid.layers.stack(mel_slices, axis=2) mel = fluid.layers.stack(mel_slices, axis=2)
z = fluid.layers.squeeze(audio, [1]) z = fluid.layers.reshape(
audio, [audio.shape[0], audio.shape[2], audio.shape[3]])
return z, log_s_list return z, log_s_list
def synthesize(self, mel, sigma=1.0): def synthesize(self, mel, sigma=1.0):
if self.dtype == "float16":
mel = fluid.layers.cast(mel, self.dtype)
mel = self.conditioner.infer(mel) mel = self.conditioner.infer(mel)
# From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group] # From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group]
mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2]) mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2])
audio = fluid.layers.gaussian_random( audio = fluid.layers.gaussian_random(
shape=[mel.shape[0], 1, mel.shape[2], mel.shape[3]], std=sigma) shape=[mel.shape[0], 1, mel.shape[2], mel.shape[3]], std=sigma)
if self.dtype == "float16":
audio = fluid.layers.cast(audio, self.dtype)
for i in reversed(range(self.n_flows)): for i in reversed(range(self.n_flows)):
# Permute over the height dimension. # Permute over the height dimension.
audio_slices = [audio[:, :, j, :] for j in self.perms[i]] audio_slices = [audio[:, :, j, :] for j in self.perms[i]]
...@@ -362,9 +382,9 @@ class WaveFlowModule(dg.Layer): ...@@ -362,9 +382,9 @@ class WaveFlowModule(dg.Layer):
audio = fluid.layers.concat(audio_list, axis=2) audio = fluid.layers.concat(audio_list, axis=2)
# audio: [bs, n_group, time/n_group] # audio: [bs, n_group, time/n_group]
audio = fluid.layers.squeeze(audio, [1]) audio = fluid.layers.reshape(
audio, [audio.shape[0], audio.shape[2], audio.shape[3]])
# audio: [bs, time] # audio: [bs, time]
audio = fluid.layers.reshape( audio = fluid.layers.reshape(
fluid.layers.transpose(audio, [0, 2, 1]), [audio.shape[0], -1]) fluid.layers.transpose(audio, [0, 2, 1]), [audio.shape[0], -1])
return audio return audio
...@@ -8,8 +8,13 @@ from parakeet.modules import customized as L ...@@ -8,8 +8,13 @@ from parakeet.modules import customized as L
def norm(param, dim, power): def norm(param, dim, power):
powered = F.pow(param, power) powered = F.pow(param, power)
in_dtype = powered.dtype
if in_dtype == fluid.core.VarDesc.VarType.FP16:
powered = F.cast(powered, "float32")
powered_norm = F.reduce_sum(powered, dim=dim, keep_dim=False) powered_norm = F.reduce_sum(powered, dim=dim, keep_dim=False)
norm_ = F.pow(powered_norm, 1. / power) norm_ = F.pow(powered_norm, 1. / power)
if in_dtype == fluid.core.VarDesc.VarType.FP16:
norm_ = F.cast(norm_, "float16")
return norm_ return norm_
...@@ -46,6 +51,15 @@ def compute_weight(v, g, dim, power): ...@@ -46,6 +51,15 @@ def compute_weight(v, g, dim, power):
return weight return weight
def assign_by_cast(i, o):
fluid.default_main_program().current_block().append_op(
type="cast",
inputs={"X": i},
outputs={"Out": o},
attrs={"in_dtype": i.dtype,
"out_dtype": o.dtype})
class WeightNormWrapper(dg.Layer): class WeightNormWrapper(dg.Layer):
def __init__(self, layer, param_name="weight", dim=0, power=2): def __init__(self, layer, param_name="weight", dim=0, power=2):
super(WeightNormWrapper, self).__init__() super(WeightNormWrapper, self).__init__()
...@@ -65,13 +79,13 @@ class WeightNormWrapper(dg.Layer): ...@@ -65,13 +79,13 @@ class WeightNormWrapper(dg.Layer):
w_v, w_v,
self.create_parameter( self.create_parameter(
shape=original_weight.shape, dtype=original_weight.dtype)) shape=original_weight.shape, dtype=original_weight.dtype))
F.assign(original_weight, getattr(self, w_v)) assign_by_cast(original_weight, getattr(self, w_v))
delattr(layer, param_name) delattr(layer, param_name)
temp = norm_except(getattr(self, w_v), self.dim, self.power) temp = norm_except(getattr(self, w_v), self.dim, self.power)
self.add_parameter( self.add_parameter(
w_g, self.create_parameter( w_g, self.create_parameter(
shape=temp.shape, dtype=temp.dtype)) shape=temp.shape, dtype=temp.dtype))
F.assign(temp, getattr(self, w_g)) assign_by_cast(temp, getattr(self, w_g))
# also set this when setting up # also set this when setting up
setattr(self.layer, self.param_name, setattr(self.layer, self.param_name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册