Unverified · Commit eaf5ec44 authored by blue-fish, committed by GitHub

Add option to make toolbox deterministic (#432)

Parent 5e400d47
@@ -32,12 +32,13 @@ if __name__ == '__main__':
        "overhead but allows to save some GPU memory for lower-end GPUs.")
    parser.add_argument("--no_sound", action="store_true", help=\
        "If True, audio won't be played.")
+    parser.add_argument("--seed", type=int, default=None, help=\
+        "Optional random number seed value to make toolbox deterministic.")
    args = parser.parse_args()
    print_args(args, parser)
    if not args.no_sound:
        import sounddevice as sd

    print("Running a test of your configuration...\n")

    if torch.cuda.is_available():
        device_id = torch.cuda.current_device()

@@ -61,7 +62,7 @@ if __name__ == '__main__':
    ## Load the models one by one.
    print("Preparing the encoder, the synthesizer and the vocoder...")
    encoder.load_model(args.enc_model_fpath)
-    synthesizer = Synthesizer(args.syn_model_dir.joinpath("taco_pretrained"), low_mem=args.low_mem)
+    synthesizer = Synthesizer(args.syn_model_dir.joinpath("taco_pretrained"), low_mem=args.low_mem, seed=args.seed)
    vocoder.load_model(args.voc_model_fpath)

@@ -158,6 +159,12 @@ if __name__ == '__main__':
            ## Generating the waveform
            print("Synthesizing the waveform:")

+            # If seed is specified, reset torch seed and reload vocoder
+            if args.seed is not None:
+                torch.manual_seed(args.seed)
+                vocoder.load_model(args.voc_model_fpath)
+
            # Synthesizing the waveform is fairly straightforward. Remember that the longer the
            # spectrogram, the more time-efficient the vocoder.
            generated_wav = vocoder.infer_waveform(spec)

@@ -167,6 +174,9 @@ if __name__ == '__main__':
            # There's a bug with sounddevice that makes the audio cut one second earlier, so we
            # pad it.
            generated_wav = np.pad(generated_wav, (0, synthesizer.sample_rate), mode="constant")

+            # Trim excess silences to compensate for gaps in spectrograms (issue #53)
+            generated_wav = encoder.preprocess_wav(generated_wav)
+
            # Play the audio (non-blocking)
            if not args.no_sound:
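As a usage note on the CLI change above: the deterministic path boils down to reseeding PyTorch's global RNG and reloading the vocoder weights right before inference, so the random draws made during waveform sampling repeat exactly. Below is a minimal sketch of that pattern, assuming the vocoder inference module is importable as `vocoder` (as the calls above suggest) and treating the model path, spectrogram and seed value as placeholders.

    from pathlib import Path

    import numpy as np
    import torch
    from vocoder import inference as vocoder  # import path assumed from the calls shown above


    def deterministic_vocode(spec: np.ndarray, voc_model_fpath: Path, seed: int) -> np.ndarray:
        # Reseed torch, then reload the model so the vocoder starts from a known RNG state,
        # mirroring the order of operations in the diff above.
        torch.manual_seed(seed)
        vocoder.load_model(voc_model_fpath)
        return vocoder.infer_waveform(spec)

Calling this twice with the same spectrogram and seed should produce matching waveforms (subject to the usual caveats about GPU non-determinism), which is what the `--seed` flag is meant to guarantee.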
@@ -26,6 +26,8 @@ if __name__ == '__main__':
    parser.add_argument("--low_mem", action="store_true", help=\
        "If True, the memory used by the synthesizer will be freed after each use. Adds large "
        "overhead but allows to save some GPU memory for lower-end GPUs.")
+    parser.add_argument("--seed", type=int, default=None, help=\
+        "Optional random number seed value to make toolbox deterministic.")
    args = parser.parse_args()
    print_args(args, parser)
@@ -15,7 +15,7 @@ class Synthesizer:
    sample_rate = hparams.sample_rate
    hparams = hparams

-    def __init__(self, checkpoints_dir: Path, verbose=True, low_mem=False):
+    def __init__(self, checkpoints_dir: Path, verbose=True, low_mem=False, seed=None):
        """
        Creates a synthesizer ready for inference. The actual model isn't loaded in memory until
        needed or until load() is called.

@@ -26,10 +26,14 @@ class Synthesizer:
        :param low_mem: if True, the model will be loaded in a separate process and its resources
        will be released after each usage. Adds a large overhead, only recommended if your GPU
        memory is low (<= 2gb)
+        :param seed: optional integer for seeding the random number generators when initializing
+        the model. This makes the synthesizer output consistent for a given embedding and input
+        text. However, it requires the model to be reloaded whenever a text is synthesized.
        """
        self.verbose = verbose
        self._low_mem = low_mem
+        self._seed = seed

        # Prepare the model
        self._model = None # type: Tacotron2
        checkpoint_state = tf.train.get_checkpoint_state(checkpoints_dir)

@@ -40,7 +44,19 @@ class Synthesizer:
            model_name = checkpoints_dir.parent.name.replace("logs-", "")
            step = int(self.checkpoint_fpath[self.checkpoint_fpath.rfind('-') + 1:])
            print("Found synthesizer \"%s\" trained to step %d" % (model_name, step))

+    def set_seed(self, new_seed):
+        """
+        Updates the seed that initializes the random number generators associated with Tacotron2.
+        Returns the new seed value as confirmation.
+        """
+        try:
+            self._seed = int(new_seed)
+        except (TypeError, ValueError):
+            self._seed = None
+        return self._seed
+
    def is_loaded(self):
        """
        Whether the model is loaded in GPU memory.

@@ -55,7 +71,7 @@ class Synthesizer:
        if self._low_mem:
            raise Exception("Cannot load the synthesizer permanently in low mem mode")
        tf.compat.v1.reset_default_graph()
-        self._model = Tacotron2(self.checkpoint_fpath, hparams)
+        self._model = Tacotron2(self.checkpoint_fpath, hparams, seed=self._seed)

    def synthesize_spectrograms(self, texts: List[str],
                                embeddings: Union[np.ndarray, List[np.ndarray]],

@@ -73,7 +89,8 @@ class Synthesizer:
        """
        if not self._low_mem:
            # Usual inference mode: load the model on the first request and keep it loaded.
-            if not self.is_loaded():
+            # Reload it every time for deterministic operation if a seed is specified.
+            if not self.is_loaded() or self._seed is not None:
                self.load()
            specs, alignments = self._model.my_synthesize(embeddings, texts)
        else:

@@ -89,7 +106,7 @@ class Synthesizer:
    def _one_shot_synthesize_spectrograms(checkpoint_fpath, embeddings, texts):
        # Load the model and forward the inputs
        tf.compat.v1.reset_default_graph()
-        model = Tacotron2(checkpoint_fpath, hparams)
+        model = Tacotron2(checkpoint_fpath, hparams, seed=self._seed)
        specs, alignments = model.my_synthesize(embeddings, texts)

        # Detach the outputs (not doing so will cause the process to hang)
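To make the reload-on-every-call behaviour concrete, here is a hedged usage sketch of the new `seed` parameter on `Synthesizer`. The module path, the checkpoint directory and the embedding are placeholders, and the default return value of `synthesize_spectrograms` is assumed to be the list of spectrograms.

    from pathlib import Path

    import numpy as np
    from synthesizer.inference import Synthesizer  # module path assumed from the class shown above

    synthesizer = Synthesizer(Path("<syn_model_dir>/taco_pretrained"), low_mem=False, seed=1234)

    texts = ["This output should be reproducible."]
    embed = np.random.RandomState(0).rand(256)  # placeholder speaker embedding, not a real one

    # With a seed set, the model is rebuilt (and reseeded) on each call, trading speed for
    # repeatability: the same text and embedding should map to the same spectrogram.
    spec_a = synthesizer.synthesize_spectrograms(texts, [embed])[0]
    spec_b = synthesizer.synthesize_spectrograms(texts, [embed])[0]
    assert np.allclose(spec_a, spec_b)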
@@ -9,8 +9,13 @@ import os

class Tacotron2:
-    def __init__(self, checkpoint_path, hparams, gta=False, model_name="Tacotron"):
+    def __init__(self, checkpoint_path, hparams, gta=False, model_name="Tacotron", seed=None):
        log("Constructing model: %s" % model_name)

+        # Initialize tensorflow random number seed for deterministic operation if provided
+        if seed is not None:
+            tf.compat.v1.set_random_seed(seed)
+
        #Force the batch size to be known in order to use attention masking in batch synthesis
        inputs = tf.compat.v1.placeholder(tf.int32, (None, None), name="inputs")
        input_lengths = tf.compat.v1.placeholder(tf.int32, (None,), name="input_lengths")
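The Tacotron2 change relies on TensorFlow's graph-level seed: `tf.compat.v1.set_random_seed` has to run after `reset_default_graph()` and before any random ops are added, otherwise those ops keep their own seeds. A small standalone illustration of that ordering follows; it is not project code, only a sketch of the mechanism the commit exploits.

    import tensorflow as tf

    if hasattr(tf.compat.v1, "disable_eager_execution"):
        tf.compat.v1.disable_eager_execution()  # needed for graph mode under TF 2.x


    def sample_noise(seed):
        # Same ordering as the diff: reset the graph, set the graph-level seed, then build ops.
        tf.compat.v1.reset_default_graph()
        tf.compat.v1.set_random_seed(seed)
        noise = tf.compat.v1.random_normal([3])
        with tf.compat.v1.Session() as sess:
            return sess.run(noise)


    # Rebuilding the graph with the same seed reproduces the same draws, which is why the
    # synthesizer reloads Tacotron2 whenever a seed is set.
    assert (sample_noise(42) == sample_noise(42)).all()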
@@ -8,6 +8,7 @@ from toolbox.utterance import Utterance
import numpy as np
import traceback
import sys
+import torch

# Use this directory structure for your datasets, or modify it to fit your needs

@@ -38,7 +39,7 @@ recognized_datasets = [
MAX_WAVES = 15

class Toolbox:
-    def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, low_mem):
+    def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, low_mem, seed):
        sys.excepthook = self.excepthook
        self.datasets_root = datasets_root
        self.low_mem = low_mem

@@ -50,10 +51,17 @@ class Toolbox:
        self.waves_list = []
        self.waves_count = 0
        self.waves_namelist = []

+        # Check for webrtcvad (enables removal of silences in vocoder output)
+        try:
+            import webrtcvad
+            self.trim_silences = True
+        except ImportError:
+            self.trim_silences = False
+
        # Initialize the events and the interface
        self.ui = UI()
-        self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir)
+        self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, seed)
        self.setup_events()
        self.ui.start()

@@ -105,7 +113,8 @@ class Toolbox:
        self.ui.generate_button.clicked.connect(func)
        self.ui.synthesize_button.clicked.connect(self.synthesize)
        self.ui.vocode_button.clicked.connect(self.vocode)
+        self.ui.random_seed_checkbox.clicked.connect(self.update_seed_textbox)

        # UMAP legend
        self.ui.clear_button.clicked.connect(self.clear_utterances)

@@ -118,9 +127,10 @@ class Toolbox:
    def replay_last_wav(self):
        self.ui.play(self.current_wav, Synthesizer.sample_rate)

-    def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir):
+    def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, seed):
        self.ui.populate_browser(self.datasets_root, recognized_datasets, 0, True)
        self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir)
+        self.ui.populate_gen_options(seed, self.trim_silences)

    def load_from_browser(self, fpath=None):
        if fpath is None:

@@ -192,6 +202,13 @@ class Toolbox:
            self.synthesizer = Synthesizer(checkpoints_dir, low_mem=self.low_mem)
        if not self.synthesizer.is_loaded():
            self.ui.log("Loading the synthesizer %s" % self.synthesizer.checkpoint_fpath)

+        # Update the synthesizer random seed
+        if self.ui.random_seed_checkbox.isChecked():
+            seed = self.synthesizer.set_seed(int(self.ui.seed_textbox.text()))
+            self.ui.populate_gen_options(seed, self.trim_silences)
+        else:
+            seed = self.synthesizer.set_seed(None)
+
        texts = self.ui.text_prompt.toPlainText().split("\n")
        embed = self.ui.selected_utterance.embed

@@ -208,9 +225,20 @@ class Toolbox:
        speaker_name, spec, breaks, _ = self.current_generated
        assert spec is not None

+        # Initialize the vocoder model and make it deterministic, if the user provides a seed
+        if self.ui.random_seed_checkbox.isChecked():
+            seed = self.synthesizer.set_seed(int(self.ui.seed_textbox.text()))
+            self.ui.populate_gen_options(seed, self.trim_silences)
+        else:
+            seed = None
+
+        if seed is not None:
+            torch.manual_seed(seed)
+
        # Synthesize the waveform
-        if not vocoder.is_loaded():
+        if not vocoder.is_loaded() or seed is not None:
            self.init_vocoder()

        def vocoder_progress(i, seq_len, b_size, gen_rate):
            real_time_factor = (gen_rate / Synthesizer.sample_rate) * 1000
            line = "Waveform generation: %d/%d (batch size: %d, rate: %.1fkHz - %.2fx real time)" \

@@ -233,6 +261,10 @@ class Toolbox:
        breaks = [np.zeros(int(0.15 * Synthesizer.sample_rate))] * len(breaks)
        wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])

+        # Trim excessive silences
+        if self.ui.trim_silences_checkbox.isChecked():
+            wav = encoder.preprocess_wav(wav)
+
        # Play it
        wav = wav / np.abs(wav).max() * 0.97
        self.ui.play(wav, Synthesizer.sample_rate)

@@ -299,3 +331,6 @@ class Toolbox:
        vocoder.load_model(model_fpath)
        self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
        self.ui.set_loading(0)
+
+    def update_seed_textbox(self):
+        self.ui.update_seed_textbox()
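Seen end to end, the toolbox changes above follow one pattern in both `synthesize()` and `vocode()`: read the seed from the UI, hand it to the synthesizer via `set_seed`, reseed torch, and force a reload of whichever model is about to run. Below is a condensed, hypothetical helper showing that flow; only the attribute and function names visible in the diff are real, the helper itself is not part of the commit.

    import torch


    def read_seed(ui):
        # Mirrors the checkbox/textbox handling used in synthesize() and vocode().
        if not ui.random_seed_checkbox.isChecked():
            return None
        try:
            return int(ui.seed_textbox.text())
        except ValueError:
            return None


    def seeded_vocode(toolbox, vocoder, spec):
        # Sketch of the vocode() path: with a seed, reseed torch and reload the vocoder so its
        # sampling is repeatable; without one, keep reusing the already-loaded model.
        seed = read_seed(toolbox.ui)
        if seed is not None:
            torch.manual_seed(seed)
        if not vocoder.is_loaded() or seed is not None:
            toolbox.init_vocoder()
        return vocoder.infer_waveform(spec)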
@@ -389,6 +389,26 @@ class UI(QDialog):
        self.loading_bar.setTextVisible(value != 0)
        self.app.processEvents()

+    def populate_gen_options(self, seed, trim_silences):
+        if seed is not None:
+            self.random_seed_checkbox.setChecked(True)
+            self.seed_textbox.setText(str(seed))
+            self.seed_textbox.setEnabled(True)
+        else:
+            self.random_seed_checkbox.setChecked(False)
+            self.seed_textbox.setText(str(0))
+            self.seed_textbox.setEnabled(False)
+        if not trim_silences:
+            self.trim_silences_checkbox.setChecked(False)
+            self.trim_silences_checkbox.setDisabled(True)
+
+    def update_seed_textbox(self):
+        if self.random_seed_checkbox.isChecked():
+            self.seed_textbox.setEnabled(True)
+        else:
+            self.seed_textbox.setEnabled(False)
+
    def reset_interface(self):
        self.draw_embed(None, None, "current")
        self.draw_embed(None, None, "generated")

@@ -555,6 +575,19 @@ class UI(QDialog):
        layout.addWidget(self.vocode_button)
        gen_layout.addLayout(layout)

+        layout_seed = QGridLayout()
+        self.random_seed_checkbox = QCheckBox("Random seed:")
+        self.random_seed_checkbox.setToolTip("When checked, makes the synthesizer and vocoder deterministic.")
+        layout_seed.addWidget(self.random_seed_checkbox, 0, 0)
+        self.seed_textbox = QLineEdit()
+        self.seed_textbox.setMaximumWidth(80)
+        layout_seed.addWidget(self.seed_textbox, 0, 1)
+        self.trim_silences_checkbox = QCheckBox("Enhance vocoder output")
+        self.trim_silences_checkbox.setToolTip("When checked, trims excess silence in vocoder output."
+            " This feature requires `webrtcvad` to be installed.")
+        layout_seed.addWidget(self.trim_silences_checkbox, 0, 2, 1, 2)
+        gen_layout.addLayout(layout_seed)
+
        self.loading_bar = QProgressBar()
        gen_layout.addWidget(self.loading_bar)