Commit 4c3e57a2 authored by 小湉湉

align preprocess of wavernn, test=tts

Parent fb0acd40
@@ -6,10 +6,50 @@ stop_stage=100
config_path=$1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
python3 ${BIN_DIR}/preprocess.py \
--input=~/datasets/BZNSYP/ \
--output=dump \
--dataset=csmsc \
# get durations from MFA's result
echo "Generate durations.txt from MFA results ..."
python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
--inputdir=./baker_alignment_tone \
--output=durations.txt \
--config=${config_path}
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# extract features
echo "Extract features ..."
python3 ${BIN_DIR}/../gan_vocoder/preprocess.py \
--rootdir=~/datasets/BZNSYP/ \
--dataset=baker \
--dumpdir=dump \
--dur-file=durations.txt \
--config=${config_path} \
--cut-sil=True \
--num-cpu=20
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# get features' stats (mean and std)
echo "Get features' stats ..."
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
--metadata=dump/train/raw/metadata.jsonl \
--field-name="feats"
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# normalize, dev and test should use train's stats
echo "Normalize ..."
python3 ${BIN_DIR}/../gan_vocoder/normalize.py \
--metadata=dump/train/raw/metadata.jsonl \
--dumpdir=dump/train/norm \
--stats=dump/train/feats_stats.npy
python3 ${BIN_DIR}/../gan_vocoder/normalize.py \
--metadata=dump/dev/raw/metadata.jsonl \
--dumpdir=dump/dev/norm \
--stats=dump/train/feats_stats.npy
python3 ${BIN_DIR}/../gan_vocoder/normalize.py \
--metadata=dump/test/raw/metadata.jsonl \
--dumpdir=dump/test/norm \
--stats=dump/train/feats_stats.npy
fi
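For reference, a minimal sketch of what the normalize step applies, assuming feats_stats.npy stores the per-dimension mean and scale stacked as a (2, C) array (the layout written by compute_statistics.py):

import numpy as np

stats = np.load("dump/train/feats_stats.npy")  # assumed (2, C): [mean, scale]
mean, scale = stats[0], stats[1]

def normalize(feats):
    # z-score each feature dimension with the *training* statistics, which
    # is why the dev and test runs above reuse dump/train/feats_stats.npy
    return (feats - mean) / scale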
@@ -3,12 +3,11 @@
config_path=$1
train_output_path=$2
ckpt_name=$3
test_input=$4
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/synthesize.py \
--config=${config_path} \
--checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
--input=${test_input} \
--test-metadata=dump/test/norm/metadata.jsonl \
--output-dir=${train_output_path}/test
@@ -2,8 +2,12 @@
config_path=$1
train_output_path=$2
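# Both FLAGS_* settings below are Paddle runtime flags: exhaustive cuDNN
# search autotunes conv algorithms, and conv_workspace_size_limit (in MB)
# caps the workspace cuDNN may use while searching.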
FLAGS_cudnn_exhaustive_search=true \
FLAGS_conv_workspace_size_limit=4000 \
python ${BIN_DIR}/train.py \
--train-metadata=dump/train/norm/metadata.jsonl \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--data=dump/ \
--output-dir=${train_output_path} \
--ngpu=1
@@ -9,7 +9,7 @@ stop_stage=100
conf_path=conf/default.yaml
train_output_path=exp/default
test_input=dump/mel_test
test_input=dump/dump_gta_test
ckpt_name=snapshot_iter_100000.pdz
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
@@ -25,9 +25,6 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# copy some test mels from dump
mkdir -p ${test_input}
cp -r dump/mel/00995*.npy ${test_input}
# synthesize
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} ${test_input}|| exit -1
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
@@ -12,11 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from pathlib import Path
import numpy as np
import paddle
from paddle.io import Dataset
def label_2_float(x, bits):
@@ -44,102 +42,6 @@ def decode_mu_law(y, mu, from_labels=True):
return x
class WaveRNNDataset(Dataset):
"""A simple dataset adaptor for the processed ljspeech dataset."""
def __init__(self, root):
self.root = Path(root).expanduser()
records = []
with open(self.root / "metadata.csv", 'r') as rf:
for line in rf:
name = line.split("\t")[0]
mel_path = str(self.root / "mel" / (str(name) + ".npy"))
wav_path = str(self.root / "wav" / (str(name) + ".npy"))
records.append((mel_path, wav_path))
self.records = records
def __getitem__(self, i):
mel_name, wav_name = self.records[i]
mel = np.load(mel_name)
wav = np.load(wav_name)
return mel, wav
def __len__(self):
return len(self.records)
class WaveRNNClip(object):
def __init__(self,
mode: str='RAW',
batch_max_steps: int=4500,
hop_size: int=300,
aux_context_window: int=2,
bits: int=9):
self.mode = mode
self.mel_win = batch_max_steps // hop_size + 2 * aux_context_window
self.batch_max_steps = batch_max_steps
self.hop_size = hop_size
self.aux_context_window = aux_context_window
if self.mode == 'MOL':
self.bits = 16
else:
self.bits = bits
def __call__(self, batch):
# batch: [mel, quant]
# voc_pad = 2: pad the input so that the resnet can 'see' wider than the input length
# max_offsets = n_frames - 2 - (mel_win + 2 * aux_context_window)
max_offsets = [
x[0].shape[-1] - 2 - (self.mel_win + 2 * self.aux_context_window)
for x in batch
]
# randomly select the slice point of mel
mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
# randomly select the slice point of wav, which starts 2 (= aux_context_window) frames after the mel slice point
sig_offsets = [(offset + self.aux_context_window) * self.hop_size
for offset in mel_offsets]
# mels.shape[1] = voc_seq_len // hop_length + 2 * voc_pad
mels = [
x[0][:, mel_offsets[i]:mel_offsets[i] + self.mel_win]
for i, x in enumerate(batch)
]
# label.shape[1] = voc_seq_len + 1
labels = [
x[1][sig_offsets[i]:sig_offsets[i] + self.batch_max_steps + 1]
for i, x in enumerate(batch)
]
mels = np.stack(mels).astype(np.float32)
labels = np.stack(labels).astype(np.int64)
mels = paddle.to_tensor(mels)
labels = paddle.to_tensor(labels, dtype='int64')
# x is input, y is label
x = labels[:, :self.batch_max_steps]
y = labels[:, 1:]
'''
mode = RAW:
mu_law = True:
quant: bits = 9 0, 1, 2, ..., 509, 510, 511 int
mu_law = False:
quant bits = 9 [0, 511] float
mode = MOL:
quant: bits = 16 [0, 65535] float
'''
# x should be normalized to [-1, 1] in RAW mode
x = label_2_float(paddle.cast(x, dtype='float32'), self.bits)
# y should be normalized to [-1, 1] in MOL mode
if self.mode == 'MOL':
y = label_2_float(paddle.cast(y, dtype='float32'), self.bits)
return x, y, mels
class Clip(object):
"""Collate functor for training vocoders.
"""
@@ -174,7 +76,7 @@ class Clip(object):
self.end_offset = -(self.batch_max_frames + aux_context_window)
self.mel_threshold = self.batch_max_frames + 2 * aux_context_window
def __call__(self, examples):
def __call__(self, batch):
"""Convert into batch tensors.
Parameters
@@ -192,11 +94,11 @@
"""
# check length
examples = [
self._adjust_length(b['wave'], b['feats']) for b in examples
batch = [
self._adjust_length(b['wave'], b['feats']) for b in batch
if b['feats'].shape[0] > self.mel_threshold
]
xs, cs = [b[0] for b in examples], [b[1] for b in examples]
xs, cs = [b[0] for b in batch], [b[1] for b in batch]
# make batch with random cut
c_lengths = [c.shape[0] for c in cs]
@@ -214,7 +116,7 @@
c_batch = np.stack(
[c[start:end] for c, start, end in zip(cs, c_starts, c_ends)])
# convert each batch to tensor, asuume that each item in batch has the same length
# convert each batch to tensor, assume that each item in batch has the same length
y_batch = paddle.to_tensor(
y_batch, dtype=paddle.float32).unsqueeze(1) # (B, 1, T)
c_batch = paddle.to_tensor(
@@ -245,3 +147,111 @@ class Clip(object):
0] * self.hop_size, f"wave length: ({len(x)}), mel length: ({c.shape[0]})"
return x, c
class WaveRNNClip(Clip):
def __init__(self,
mode: str='RAW',
batch_max_steps: int=4500,
hop_size: int=300,
aux_context_window: int=2,
bits: int=9,
mu_law: bool=True):
self.mode = mode
self.mel_win = batch_max_steps // hop_size + 2 * aux_context_window
self.batch_max_steps = batch_max_steps
self.hop_size = hop_size
self.aux_context_window = aux_context_window
self.mu_law = mu_law
self.batch_max_frames = batch_max_steps // hop_size
self.mel_threshold = self.batch_max_frames + 2 * aux_context_window
if self.mode == 'MOL':
self.bits = 16
else:
self.bits = bits
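# Worked example with the defaults above:
#   mel_win          = 4500 // 300 + 2 * 2 = 19 mel frames per clip
#   batch_max_frames = 4500 // 300 = 15 frames actually covered by audio
#   mel_threshold    = 15 + 2 * 2 = 19, the shortest usable utterance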
def to_quant(self, wav):
if self.mode == 'RAW':
if self.mu_law:
quant = encode_mu_law(wav, mu=2**self.bits)
else:
quant = float_2_label(wav, bits=self.bits)
elif self.mode == 'MOL':
quant = float_2_label(wav, bits=16)
quant = quant.astype(np.int64)
return quant
def __call__(self, batch):
# voc_pad = 2: pad the input so that the resnet can 'see' wider than the input length
# max_offsets = n_frames - 2 - (mel_win + 2 * aux_context_window)
"""Convert into batch tensors.
Parameters
----------
batch : list
list of dict of the pair of audio and features.
Audio shape (T, ), features shape (T', C).
Returns
----------
Tensor
Input signal batch x (B, T).
Tensor
Target signal batch y (B, T), x shifted ahead by one sample.
Tensor
Auxiliary feature batch (B, C, T'), where
T = (T' - 2 * aux_context_window) * hop_size.
"""
# check length
batch = [
self._adjust_length(b['wave'], b['feats']) for b in batch
if b['feats'].shape[0] > self.mel_threshold
]
wav, mel = [b[0] for b in batch], [b[1] for b in batch]
# transpose mel here: (T', C) -> (C, T') so frames are on the last axis
mel = [x.T for x in mel]
max_offsets = [
x.shape[-1] - 2 - (self.mel_win + 2 * self.aux_context_window)
for x in mel
]
# randomly select the slice point of mel
mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
# randomly select the slice point of wav, which starts 2 (= aux_context_window) frames after the mel slice point
sig_offsets = [(offset + self.aux_context_window) * self.hop_size
for offset in mel_offsets]
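# e.g. with hop_size=300 and aux_context_window=2, mel offset 10 maps to
# sample offset (10 + 2) * 300 = 3600, so the wav slice starts under the
# first non-padding mel frame of the clip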
# mels.shape[1] = voc_seq_len // hop_length + 2 * voc_pad
mels = [
x[:, mel_offsets[i]:mel_offsets[i] + self.mel_win]
for i, x in enumerate(mel)
]
# label.shape[1] = voc_seq_len + 1
wav = [self.to_quant(x) for x in wav]
labels = [
x[sig_offsets[i]:sig_offsets[i] + self.batch_max_steps + 1]
for i, x in enumerate(wav)
]
mels = np.stack(mels).astype(np.float32)
labels = np.stack(labels).astype(np.int64)
mels = paddle.to_tensor(mels)
labels = paddle.to_tensor(labels, dtype='int64')
# x is input, y is label
x = labels[:, :self.batch_max_steps]
y = labels[:, 1:]
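# i.e. y is x shifted left by one sample: the model learns to predict the
# next quantized sample from the current one plus the local mel condition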
'''
mode = RAW:
mu_law = True:
quant: bits = 9 0, 1, 2, ..., 509, 510, 511 int
mu_law = False:
quant bits = 9 [0, 511] float
mode = MOL:
quant: bits = 16 [0, 65535] float
'''
# x should be normalized to [-1, 1] in RAW mode
x = label_2_float(paddle.cast(x, dtype='float32'), self.bits)
# y should be normalized to [-1, 1] in MOL mode
if self.mode == 'MOL':
y = label_2_float(paddle.cast(y, dtype='float32'), self.bits)
return x, y, mels
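For reference, the quantization helpers used above (label_2_float, float_2_label, encode_mu_law, imported from vocoder_batch_fn) follow the standard WaveRNN recipe; a minimal self-contained sketch, which may differ in detail from the actual paddlespeech implementations:

import numpy as np

def label_2_float(x, bits):
    # integer labels in [0, 2**bits - 1] -> floats in [-1, 1]
    return 2 * x / (2**bits - 1.) - 1.

def float_2_label(x, bits):
    # floats in [-1, 1] -> labels in [0, 2**bits - 1]
    assert np.abs(x).max() <= 1.0
    x = (x + 1.) * (2**bits - 1) / 2
    return x.clip(0, 2**bits - 1)

def encode_mu_law(x, mu):
    # mu-law compand, then quantize to integer labels in [0, mu - 1]
    mu = mu - 1
    fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu)
    return np.floor((fx + 1) / 2 * mu + 0.5)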
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from multiprocessing import cpu_count
from multiprocessing import Pool
from pathlib import Path
import librosa
import numpy as np
import pandas as pd
import tqdm
import yaml
from yacs.config import CfgNode
from paddlespeech.t2s.data.get_feats import LogMelFBank
from paddlespeech.t2s.datasets import CSMSCMetaData
from paddlespeech.t2s.datasets import LJSpeechMetaData
from paddlespeech.t2s.datasets.vocoder_batch_fn import encode_mu_law
from paddlespeech.t2s.datasets.vocoder_batch_fn import float_2_label
class Transform(object):
def __init__(self, output_dir: Path, config):
self.fs = config.fs
self.peak_norm = config.peak_norm
self.bits = config.model.bits
self.mode = config.model.mode
self.mu_law = config.mu_law
self.wav_dir = output_dir / "wav"
self.mel_dir = output_dir / "mel"
self.wav_dir.mkdir(exist_ok=True)
self.mel_dir.mkdir(exist_ok=True)
self.mel_extractor = LogMelFBank(
sr=config.fs,
n_fft=config.n_fft,
hop_length=config.n_shift,
win_length=config.win_length,
window=config.window,
n_mels=config.n_mels,
fmin=config.fmin,
fmax=config.fmax)
if self.mode not in ('RAW', 'MOL'):
raise RuntimeError(f'Unknown mode value - {self.mode}')
def __call__(self, example):
wav_path, _, _ = example
base_name = os.path.splitext(os.path.basename(wav_path))[0]
# print("self.sample_rate:",self.sample_rate)
wav, _ = librosa.load(wav_path, sr=self.fs)
peak = np.abs(wav).max()
if self.peak_norm or peak > 1.0:
wav /= peak
mel = self.mel_extractor.get_log_mel_fbank(wav).T
if self.mode == 'RAW':
if self.mu_law:
quant = encode_mu_law(wav, mu=2**self.bits)
else:
quant = float_2_label(wav, bits=self.bits)
elif self.mode == 'MOL':
quant = float_2_label(wav, bits=16)
mel = mel.astype(np.float32)
audio = quant.astype(np.int64)
np.save(str(self.wav_dir / base_name), audio)
np.save(str(self.mel_dir / base_name), mel)
return base_name, mel.shape[-1], audio.shape[-1]
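# Illustrative sizes, assuming the csmsc defaults (fs=24000, n_shift=300):
# a 3 s clip yields a quantized wav of 72000 samples and a mel of roughly
# 72000 / 300 = 240 frames, so audio_len ~= mel_len * n_shift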
def create_dataset(config,
input_dir,
output_dir,
nprocs: int=1,
dataset_type: str="ljspeech"):
input_dir = Path(input_dir).expanduser()
'''
LJSpeechMetaData.records: [filename, normalized text, speaker name(ljspeech)]
CSMSCMetaData.records: [filename, normalized text, pinyin]
'''
if dataset_type == 'ljspeech':
dataset = LJSpeechMetaData(input_dir)
else:
dataset = CSMSCMetaData(input_dir)
output_dir = Path(output_dir).expanduser()
output_dir.mkdir(exist_ok=True)
transform = Transform(output_dir, config)
file_names = []
pool = Pool(processes=nprocs)
for info in tqdm.tqdm(pool.imap(transform, dataset), total=len(dataset)):
base_name, mel_len, audio_len = info
file_names.append((base_name, mel_len, audio_len))
meta_data = pd.DataFrame.from_records(file_names)
meta_data.to_csv(
str(output_dir / "metadata.csv"), sep="\t", index=None, header=None)
print("saved meta data in to {}".format(
os.path.join(output_dir, "metadata.csv")))
print("Done!")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="create dataset")
parser.add_argument(
"--config", type=str, help="config file to overwrite default config.")
parser.add_argument(
"--input", type=str, help="path of the ljspeech dataset")
parser.add_argument(
"--output", type=str, help="path to save output dataset")
parser.add_argument(
"--num-cpu",
type=int,
default=cpu_count() // 2,
help="number of process.")
parser.add_argument(
"--dataset",
type=str,
default="ljspeech",
help="The dataset to preprocess, ljspeech or csmsc")
args = parser.parse_args()
with open(args.config, 'rt') as f:
config = CfgNode(yaml.safe_load(f))
if args.dataset not in ("ljspeech", "csmsc"):
raise RuntimeError(f'Unknown dataset - {args.dataset}')
create_dataset(
config,
input_dir=args.input,
output_dir=args.output,
nprocs=args.num_cpu,
dataset_type=args.dataset)
@@ -15,13 +15,16 @@ import argparse
import os
from pathlib import Path
import jsonlines
import numpy as np
import paddle
import soundfile as sf
import yaml
from paddle import distributed as dist
from timer import timer
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.models.wavernn import WaveRNN
@@ -30,10 +33,7 @@ def main():
parser.add_argument("--config", type=str, help="GANVocoder config file.")
parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
parser.add_argument(
"--input",
type=str,
help="path of directory containing mel spectrogram (in .npy format)")
parser.add_argument("--test-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
@@ -65,24 +65,43 @@ def main():
model.eval()
mel_dir = Path(args.input).expanduser()
output_dir = Path(args.output_dir).expanduser()
with jsonlines.open(args.test_metadata, 'r') as reader:
metadata = list(reader)
test_dataset = DataTable(
metadata,
fields=['utt_id', 'feats'],
converters={
'utt_id': None,
'feats': np.load,
})
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
for file_path in sorted(mel_dir.iterdir()):
mel = np.load(str(file_path))
mel = paddle.to_tensor(mel)
mel = mel.transpose([1, 0])
# input shape is (T', C_aux)
audio = model.generate(
c=mel,
batched=config.inference.gen_batched,
target=config.inference.target,
overlap=config.inference.overlap,
mu_law=config.mu_law,
gen_display=True)
audio_path = output_dir / (os.path.splitext(file_path.name)[0] + ".wav")
sf.write(audio_path, audio.numpy(), samplerate=config.fs)
print("[synthesize] {} -> {}".format(file_path, audio_path))
N = 0
T = 0
for example in test_dataset:
utt_id = example['utt_id']
mel = example['feats']
mel = paddle.to_tensor(mel) # (T, C)
with timer() as t:
with paddle.no_grad():
wav = model.generate(
c=mel,
batched=config.inference.gen_batched,
target=config.inference.target,
overlap=config.inference.overlap,
mu_law=config.mu_law,
gen_display=True)
wav = wav.numpy()
N += wav.size
T += t.elapse
speed = wav.size / t.elapse
rtf = config.fs / speed
print(
f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
)
sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs)
print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }")
if __name__ == "__main__":
@@ -16,6 +16,8 @@ import os
import shutil
from pathlib import Path
import jsonlines
import numpy as np
import paddle
import yaml
from paddle import DataParallel
......@@ -25,9 +27,8 @@ from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam
from yacs.config import CfgNode
from paddlespeech.t2s.data import dataset
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.datasets.vocoder_batch_fn import WaveRNNClip
from paddlespeech.t2s.datasets.vocoder_batch_fn import WaveRNNDataset
from paddlespeech.t2s.models.wavernn import WaveRNN
from paddlespeech.t2s.models.wavernn import WaveRNNEvaluator
from paddlespeech.t2s.models.wavernn import WaveRNNUpdater
@@ -56,10 +57,26 @@ def train_sp(args, config):
f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
)
wavernn_dataset = WaveRNNDataset(args.data)
train_dataset, dev_dataset = dataset.split(
wavernn_dataset, len(wavernn_dataset) - config.valid_size)
# construct dataset for training and validation
with jsonlines.open(args.train_metadata, 'r') as reader:
train_metadata = list(reader)
train_dataset = DataTable(
data=train_metadata,
fields=["wave", "feats"],
converters={
"wave": np.load,
"feats": np.load,
}, )
with jsonlines.open(args.dev_metadata, 'r') as reader:
dev_metadata = list(reader)
dev_dataset = DataTable(
data=dev_metadata,
fields=["wave", "feats"],
converters={
"wave": np.load,
"feats": np.load,
}, )
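# each metadata.jsonl row carries utt_id plus .npy paths, e.g. (illustrative):
#   {"utt_id": "009901", "wave": ".../009901_wave.npy", "feats": ".../009901_feats.npy"}
# the np.load converters run on access, so arrays are read lazily per batch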
batch_fn = WaveRNNClip(
mode=config.model.mode,
@@ -92,7 +109,9 @@ def train_sp(args, config):
collate_fn=batch_fn,
batch_sampler=dev_sampler,
num_workers=config.num_workers)
valid_generate_loader = DataLoader(dev_dataset, batch_size=1)
print("dataloaders done!")
model = WaveRNN(
@@ -160,10 +179,11 @@
def main():
# parse args and config and redirect to train_sp
parser = argparse.ArgumentParser(description="Train a WaveRNN model.")
parser = argparse.ArgumentParser(description="Train a HiFiGAN model.")
parser.add_argument(
"--config", type=str, help="config file to overwrite default config.")
parser.add_argument("--data", type=str, help="input")
parser.add_argument("--train-metadata", type=str, help="training data.")
parser.add_argument("--dev-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
@@ -21,8 +21,6 @@ from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddlespeech.t2s.datasets.vocoder_batch_fn import decode_mu_law
from paddlespeech.t2s.datasets.vocoder_batch_fn import label_2_float
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
from paddlespeech.t2s.training.reporter import report
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
@@ -156,31 +154,22 @@ class WaveRNNEvaluator(StandardEvaluator):
losses_dict["loss"] = float(loss)
self.iteration = ITERATION
if self.iteration % self.config.gen_eval_samples_interval_steps == 0:
self.gen_valid_samples()
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.logger.info(self.msg)
def gen_valid_samples(self):
for i, (mel, wav) in enumerate(self.valid_generate_loader):
for i, item in enumerate(self.valid_generate_loader):
if i >= self.config.generate_num:
print("before break")
break
print(
'\n| Generating: {}/{}'.format(i + 1, self.config.generate_num))
wav = wav[0]
if self.mode == 'MOL':
bits = 16
else:
bits = self.config.model.bits
if self.config.mu_law and self.mode != 'MOL':
wav = decode_mu_law(wav, 2**bits, from_labels=True)
else:
wav = label_2_float(wav, bits)
mel = item['feats']
wav = item['wave']
wav = wav.squeeze(0)
origin_save_path = self.valid_samples_dir / '{}_steps_{}_target.wav'.format(
self.iteration, i)
sf.write(origin_save_path, wav.numpy(), samplerate=self.config.fs)
@@ -193,11 +182,20 @@ class WaveRNNEvaluator(StandardEvaluator):
gen_save_path = str(self.valid_samples_dir /
'{}_steps_{}_{}.wav'.format(self.iteration, i,
batch_str))
# (1, C_aux, T) -> (T, C_aux)
mel = mel.squeeze(0).transpose([1, 0])
# (1, T, C_aux) -> (T, C_aux)
mel = mel.squeeze(0)
gen_sample = self.model.generate(
mel, self.config.inference.gen_batched,
self.config.inference.target, self.config.inference.overlap,
self.config.mu_law)
sf.write(
gen_save_path, gen_sample.numpy(), samplerate=self.config.fs)
def __call__(self, trainer=None):
summary = self.evaluate()
for k, v in summary.items():
report(k, v)
# generate samples at the end of evaluation
self.iteration = ITERATION
if self.iteration % self.config.gen_eval_samples_interval_steps == 0:
self.gen_valid_samples()