提交 79e7a4d4 编写于 作者: 小湉湉's avatar 小湉湉

align ouput of dygraph and static graph

上级 f652ba3a
......@@ -215,6 +215,7 @@ python3 ${BIN_DIR}/synthesize_e2e.py \
--pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
--text=${BIN_DIR}/../sentences.txt \
--output-dir=exp/default/test_e2e \
--inference-dir=exp/default/inference \
--device="gpu" \
--phones-dict=fastspeech2_nosil_baker_ckpt_0.4/phone_id_map.txt
```
#!/bin/bash
train_output_path=$1
python3 ${BIN_DIR}/inference.py \
--inference-dir=${train_output_path}/inference \
--text=${BIN_DIR}/../sentences.txt \
--output-dir=${train_output_path}/pd_infer_out \
--phones-dict=dump/phone_id_map.txt
......@@ -15,5 +15,6 @@ python3 ${BIN_DIR}/synthesize_e2e.py \
--pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
--text=${BIN_DIR}/../sentences.txt \
--output-dir=${train_output_path}/test_e2e \
--inference-dir=${train_output_path}/inference \
--device="gpu" \
--phones-dict=dump/phone_id_map.txt
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from pathlib import Path
import soundfile as sf
from paddle import inference
from parakeet.frontend.zh_frontend import Frontend
def main():
parser = argparse.ArgumentParser(
description="Paddle Infernce with speedyspeech & parallel wavegan.")
parser.add_argument(
"--inference-dir", type=str, help="dir to save inference models")
parser.add_argument(
"--text",
type=str,
help="text to synthesize, a 'utt_id sentence' pair per line")
parser.add_argument("--output-dir", type=str, help="output dir")
parser.add_argument(
"--enable-auto-log", action="store_true", help="use auto log")
parser.add_argument(
"--phones-dict",
type=str,
default="phones.txt",
help="phone vocabulary file.")
args, _ = parser.parse_known_args()
frontend = Frontend(phone_vocab_path=args.phones_dict)
print("frontend done!")
fastspeech2_config = inference.Config(
str(Path(args.inference_dir) / "fastspeech2.pdmodel"),
str(Path(args.inference_dir) / "fastspeech2.pdiparams"))
fastspeech2_config.enable_use_gpu(50, 0)
fastspeech2_config.enable_memory_optim()
fastspeech2_predictor = inference.create_predictor(fastspeech2_config)
pwg_config = inference.Config(
str(Path(args.inference_dir) / "pwg.pdmodel"),
str(Path(args.inference_dir) / "pwg.pdiparams"))
pwg_config.enable_use_gpu(100, 0)
pwg_config.enable_memory_optim()
pwg_predictor = inference.create_predictor(pwg_config)
if args.enable_auto_log:
import auto_log
os.makedirs("output", exist_ok=True)
pid = os.getpid()
logger = auto_log.AutoLogger(
model_name="fastspeech2",
model_precision='float32',
batch_size=1,
data_shape="dynamic",
save_path="./output/auto_log.log",
inference_config=fastspeech2_config,
pids=pid,
process_name=None,
gpu_ids=0,
time_keys=['preprocess_time', 'inference_time', 'postprocess_time'],
warmup=0)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
sentences = []
with open(args.text, 'rt') as f:
for line in f:
utt_id, sentence = line.strip().split()
sentences.append((utt_id, sentence))
for utt_id, sentence in sentences:
if args.enable_auto_log:
logger.times.start()
input_ids = frontend.get_input_ids(sentence, merge_sentences=True)
phone_ids = input_ids["phone_ids"]
phones = phone_ids[0].numpy()
if args.enable_auto_log:
logger.times.stamp()
input_names = fastspeech2_predictor.get_input_names()
phones_handle = fastspeech2_predictor.get_input_handle(input_names[0])
phones_handle.reshape(phones.shape)
phones_handle.copy_from_cpu(phones)
fastspeech2_predictor.run()
output_names = fastspeech2_predictor.get_output_names()
output_handle = fastspeech2_predictor.get_output_handle(output_names[0])
output_data = output_handle.copy_to_cpu()
input_names = pwg_predictor.get_input_names()
mel_handle = pwg_predictor.get_input_handle(input_names[0])
mel_handle.reshape(output_data.shape)
mel_handle.copy_from_cpu(output_data)
pwg_predictor.run()
output_names = pwg_predictor.get_output_names()
output_handle = pwg_predictor.get_output_handle(output_names[0])
wav = output_data = output_handle.copy_to_cpu()
if args.enable_auto_log:
logger.times.stamp()
sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=24000)
if args.enable_auto_log:
logger.times.end(stamp=True)
print(f"{utt_id} done!")
if args.enable_auto_log:
logger.report()
if __name__ == "__main__":
main()
......@@ -13,12 +13,15 @@
# limitations under the License.
import argparse
import logging
import os
from pathlib import Path
import numpy as np
import paddle
import soundfile as sf
import yaml
from paddle import jit
from paddle.static import InputSpec
from yacs.config import CfgNode
from parakeet.frontend.zh_frontend import Frontend
......@@ -74,7 +77,21 @@ def evaluate(args, fastspeech2_config, pwg_config):
pwg_normalizer = ZScore(mu, std)
fastspeech2_inference = FastSpeech2Inference(fastspeech2_normalizer, model)
fastspeech2_inference.eval()
fastspeech2_inference = jit.to_static(
fastspeech2_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
paddle.jit.save(fastspeech2_inference,
os.path.join(args.inference_dir, "fastspeech2"))
fastspeech2_inference = paddle.jit.load(
os.path.join(args.inference_dir, "fastspeech2"))
pwg_inference = PWGInference(pwg_normalizer, vocoder)
pwg_inference.eval()
pwg_inference = jit.to_static(
pwg_inference, input_spec=[
InputSpec([-1, 80], dtype=paddle.float32),
])
paddle.jit.save(pwg_inference, os.path.join(args.inference_dir, "pwg"))
pwg_inference = paddle.jit.load(os.path.join(args.inference_dir, "pwg"))
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
......@@ -135,6 +152,8 @@ def main():
type=str,
help="text to synthesize, a 'utt_id sentence' pair per line.")
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--inference-dir", type=str, help="dir to save inference models")
parser.add_argument(
"--device", type=str, default="gpu", help="device type to use.")
parser.add_argument("--verbose", type=int, default=1, help="verbose.")
......
......@@ -388,7 +388,6 @@ class FastSpeech2(nn.Layer):
spk_id=None,
tone_id=None) -> Sequence[paddle.Tensor]:
# forward encoder
bs = xs.shape[0]
x_masks = self._source_mask(ilens)
# (B, Tmax, adim)
hs, _ = self.encoder(xs, x_masks)
......@@ -428,6 +427,7 @@ class FastSpeech2(nn.Layer):
e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
(0, 2, 1))
hs = hs + e_embs + p_embs
# (B, Lmax, adim)
hs = self.length_regulator(hs, d_outs, alpha)
else:
......@@ -438,6 +438,7 @@ class FastSpeech2(nn.Layer):
e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
(0, 2, 1))
hs = hs + e_embs + p_embs
# (B, Lmax, adim)
hs = self.length_regulator(hs, ds)
......@@ -455,7 +456,8 @@ class FastSpeech2(nn.Layer):
zs, _ = self.decoder(hs, h_masks)
# (B, Lmax, odim)
before_outs = self.feat_out(zs).reshape((bs, -1, self.odim))
before_outs = self.feat_out(zs).reshape(
(paddle.shape(zs)[0], -1, self.odim))
# postnet -> (B, Lmax//r * r, odim)
if self.postnet is None:
......@@ -463,6 +465,7 @@ class FastSpeech2(nn.Layer):
else:
after_outs = before_outs + self.postnet(
before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
return before_outs, after_outs, d_outs, p_outs, e_outs
def inference(
......
......@@ -48,10 +48,9 @@ class LengthRegulator(nn.Layer):
encodings: (B, T, C)
durations: (B, T)
"""
batch_size, t_enc = durations.shape
# durations = durations.numpy()
slens = paddle.sum(durations, -1)
t_dec = paddle.max(slens)
batch_size, t_enc = paddle.shape(durations)
slens = durations.sum(-1)
t_dec = slens.max()
M = paddle.zeros([batch_size, t_dec, t_enc])
for i in range(batch_size):
k = 0
......@@ -60,7 +59,6 @@ class LengthRegulator(nn.Layer):
if d >= 1:
M[i, k:k + d, j] = 1
k += d
M = paddle.to_tensor(M, dtype=encodings.dtype)
encodings = paddle.matmul(M, encodings)
return encodings
......
......@@ -37,7 +37,7 @@ class MultiHeadedAttention(nn.Layer):
def __init__(self, n_head, n_feat, dropout_rate):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttention, self).__init__()
# assert n_feat % n_head == 0
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
......@@ -108,7 +108,9 @@ class MultiHeadedAttention(nn.Layer):
if mask is not None:
mask = mask.unsqueeze(1)
mask = paddle.logical_not(mask)
min_value = float(numpy.finfo("float32").min)
# assume scores.dtype==paddle.float32, we only use "float32" here
dtype = str(scores.dtype).split(".")[-1]
min_value = numpy.finfo(dtype).min
scores = masked_fill(scores, mask, min_value)
# (batch, head, time1, time2)
self.attn = softmax(scores)
......
......@@ -31,9 +31,16 @@ class PositionalEncoding(nn.Layer):
Maximum input length.
reverse : bool
Whether to reverse the input position.
type : str
dtype of param
"""
def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
def __init__(self,
d_model,
dropout_rate,
max_len=5000,
dtype="float32",
reverse=False):
"""Construct an PositionalEncoding object."""
super(PositionalEncoding, self).__init__()
self.d_model = d_model
......@@ -41,21 +48,21 @@ class PositionalEncoding(nn.Layer):
self.xscale = math.sqrt(self.d_model)
self.dropout = nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(paddle.expand(paddle.to_tensor(0.0), (1, max_len)))
self.dtype = dtype
self.extend_pe(paddle.expand(paddle.zeros([1]), (1, max_len)))
def extend_pe(self, x):
"""Reset the positional encodings."""
pe = paddle.zeros([paddle.shape(x)[1], self.d_model])
x_shape = paddle.shape(x)
pe = paddle.zeros([x_shape[1], self.d_model])
if self.reverse:
position = paddle.arange(
paddle.shape(x)[1] - 1, -1, -1.0,
dtype=paddle.float32).unsqueeze(1)
x_shape[1] - 1, -1, -1.0, dtype=self.dtype).unsqueeze(1)
else:
position = paddle.arange(
0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1)
0, x_shape[1], dtype=self.dtype).unsqueeze(1)
div_term = paddle.exp(
paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
paddle.arange(0, self.d_model, 2, dtype=self.dtype) *
-(math.log(10000.0) / self.d_model))
pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term)
......@@ -76,8 +83,8 @@ class PositionalEncoding(nn.Layer):
Encoded tensor (batch, time, `*`).
"""
self.extend_pe(x)
x = x * self.xscale + self.pe[:, :paddle.shape(x)[1]]
T = paddle.shape(x)[1]
x = x * self.xscale + self.pe[:, :T]
return self.dropout(x)
......@@ -94,21 +101,26 @@ class ScaledPositionalEncoding(PositionalEncoding):
Dropout rate.
max_len : int
Maximum input length.
dtype : str
dtype of param
"""
def __init__(self, d_model, dropout_rate, max_len=5000):
def __init__(self, d_model, dropout_rate, max_len=5000, dtype="float32"):
"""Initialize class."""
super().__init__(
d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
x = paddle.ones([1], dtype="float32")
d_model=d_model,
dropout_rate=dropout_rate,
max_len=max_len,
dtype=dtype)
x = paddle.ones([1], dtype=self.dtype)
self.alpha = paddle.create_parameter(
shape=x.shape,
dtype="float32",
dtype=self.dtype,
default_initializer=paddle.nn.initializer.Assign(x))
def reset_parameters(self):
"""Reset parameters."""
self.alpha = paddle.to_tensor(1.0)
self.alpha = paddle.ones([1])
def forward(self, x):
"""Add positional encoding.
......@@ -123,5 +135,6 @@ class ScaledPositionalEncoding(PositionalEncoding):
Encoded tensor (batch, time, `*`).
"""
self.extend_pe(x)
x = x + self.alpha * self.pe[:, :paddle.shape(x)[1]]
T = paddle.shape(x)[1]
x = x + self.alpha * self.pe[:, :T]
return self.dropout(x)
......@@ -87,7 +87,7 @@ class EncoderLayer(nn.Layer):
if cache is None:
x_q = x
else:
# assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
x_q = x[:, -1:, :]
residual = residual[:, -1:, :]
mask = None if mask is None else mask[:, -1:, :]
......
......@@ -55,10 +55,12 @@ class LayerNorm(paddle.nn.LayerNorm):
orig_perm = list(range(len_dim))
new_perm = orig_perm[:]
# Python style item change is not able when converting dygraph to static graph.
# new_perm[self.dim], new_perm[len_dim -1] = new_perm[len_dim -1], new_perm[self.dim]
# use C++ style item change here
temp = new_perm[self.dim]
new_perm[self.dim] = new_perm[len_dim - 1]
new_perm[len_dim - 1] = temp
# new_perm[self.dim], new_perm[len_dim -1] = new_perm[len_dim -1], new_perm[self.dim]
return paddle.transpose(
super(LayerNorm, self).forward(paddle.transpose(x, new_perm)),
......
......@@ -25,6 +25,7 @@ def is_broadcastable(shp1, shp2):
return True
# assume that len(shp1) == len(shp2)
def broadcast_shape(shp1, shp2):
result = []
for a, b in zip(shp1[::-1], shp2[::-1]):
......@@ -35,6 +36,7 @@ def broadcast_shape(shp1, shp2):
def masked_fill(xs: paddle.Tensor,
mask: paddle.Tensor,
value: Union[float, int]):
# comment following line for converting dygraph to static graph.
# assert is_broadcastable(xs.shape, mask.shape) is True
# bshape = paddle.broadcast_shape(xs.shape, mask.shape)
bshape = broadcast_shape(xs.shape, mask.shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册