__init__.py 11.9 KB
Newer Older
1
from toolbox.ui import UI
2
from encoder import inference as encoder
3
from synthesizer.inference import Synthesizer
4
from vocoder import inference as vocoder
U
unknown 已提交
5
from pathlib import Path
6
from time import perf_counter as timer
C
Corentin Jemine 已提交
7
from toolbox.utterance import Utterance
8
import numpy as np
9 10
import traceback
import sys
11

12

13
# Datasets the toolbox browser knows about. Reproduce this directory structure for your own
# datasets, or edit the list to fit your needs.
_libri_subsets = (
    "dev-clean",
    "dev-other",
    "test-clean",
    "test-other",
    "train-clean-100",
    "train-clean-360",
    "train-other-500",
)
recognized_datasets = (
    # LibriSpeech and LibriTTS share the same subset names
    ["LibriSpeech/" + subset for subset in _libri_subsets]
    + ["LibriTTS/" + subset for subset in _libri_subsets]
    + [
        "LJSpeech-1.1",
        "VoxCeleb1/wav",
        "VoxCeleb1/test_wav",
        "VoxCeleb2/dev/aac",
        "VoxCeleb2/test/aac",
        "VCTK-Corpus/wav48",
    ]
)

37 38 39
# Maximum number of generated wavs to keep in memory
MAX_WAVES = 15

40
class Toolbox:
41
    def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, low_mem):
        """
        Set up the toolbox state, build the interface and start it.

        :param datasets_root: root directory under which dataset/speaker/utterance paths are
        resolved when loading from the browser
        :param enc_models_dir: encoder models directory, forwarded to the UI model lists
        :param syn_models_dir: synthesizer models directory, forwarded to the UI model lists
        :param voc_models_dir: vocoder models directory, forwarded to the UI model lists
        :param low_mem: forwarded as `low_mem` when the Synthesizer is instantiated
        """
        # Route uncaught exceptions through the toolbox so that they also reach the UI log
        sys.excepthook = self.excepthook

        # Configuration
        self.low_mem = low_mem
        self.datasets_root = datasets_root

        # Model and utterance state. The synthesizer is only created on the first synthesis.
        self.synthesizer = None # type: Synthesizer
        self.utterances = set()
        # Layout: (speaker_name, spec, breaks, wav)
        self.current_generated = (None,) * 4

        # History of the generated wavs
        self.waves_count = 0
        self.waves_list = []
        self.waves_namelist = []
        self.current_wav = None

        # Build the interface, populate it, connect its events, then start it
        self.ui = UI()
        self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir)
        self.setup_events()
        self.ui.start()
        
60 61 62 63
    def excepthook(self, exc_type, exc_value, exc_tb):
        traceback.print_exception(exc_type, exc_value, exc_tb)
        self.ui.log("Exception: %s" % exc_value)
        
64
    def setup_events(self):
        """Connect the signals of the interface widgets to their handlers."""
        ui = self.ui

        # --- Dataset, speaker and utterance selection ---
        def load_selected():
            self.load_from_browser()
        ui.browser_load_button.clicked.connect(load_selected)

        def make_repopulate(level):
            # Build a handler that repopulates the browser from the given level
            def repopulate():
                self.ui.populate_browser(self.datasets_root, recognized_datasets, level)
            return repopulate
        ui.random_dataset_button.clicked.connect(make_repopulate(0))
        ui.random_speaker_button.clicked.connect(make_repopulate(1))
        ui.random_utterance_button.clicked.connect(make_repopulate(2))
        ui.dataset_box.currentIndexChanged.connect(make_repopulate(1))
        ui.speaker_box.currentIndexChanged.connect(make_repopulate(2))

        # --- Model selection ---
        ui.encoder_box.currentIndexChanged.connect(self.init_encoder)

        def drop_synthesizer():
            # The synthesizer is re-created on the next call to synthesize()
            self.synthesizer = None
        ui.synthesizer_box.currentIndexChanged.connect(drop_synthesizer)
        ui.vocoder_box.currentIndexChanged.connect(self.init_vocoder)

        # --- Utterance selection ---
        def load_browsed_file():
            self.load_from_browser(self.ui.browse_file())
        ui.browser_browse_button.clicked.connect(load_browsed_file)

        def draw_selected():
            self.ui.draw_utterance(self.ui.selected_utterance, "current")
        ui.utterance_history.currentIndexChanged.connect(draw_selected)

        def play_selected():
            self.ui.play(self.ui.selected_utterance.wav, Synthesizer.sample_rate)
        ui.play_button.clicked.connect(play_selected)
        ui.stop_button.clicked.connect(ui.stop)
        ui.record_button.clicked.connect(self.record)

        # --- Audio ---
        ui.setup_audio_devices(Synthesizer.sample_rate)

        # --- Wav playback & save ---
        def replay():
            self.replay_last_wav()
        ui.replay_wav_button.clicked.connect(replay)

        def export():
            self.export_current_wave()
        ui.export_wav_button.clicked.connect(export)
        ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)

        # --- Generation ---
        def generate():
            # Vocode only when synthesize() returns a falsy value (it returns None)
            if not self.synthesize():
                self.vocode()
        ui.generate_button.clicked.connect(generate)
        ui.synthesize_button.clicked.connect(self.synthesize)
        ui.vocode_button.clicked.connect(self.vocode)

        # --- UMAP legend ---
        ui.clear_button.clicked.connect(self.clear_utterances)
U
unknown 已提交
111

112 113 114 115 116 117 118 119 120
    def set_current_wav(self, index):
        self.current_wav = self.waves_list[index]

    def export_current_wave(self):
        """Ask the UI to save the current generated wav at the synthesizer sample rate."""
        sample_rate = Synthesizer.sample_rate
        self.ui.save_audio_file(self.current_wav, sample_rate)

    def replay_last_wav(self):
        """Ask the UI to play the current generated wav at the synthesizer sample rate."""
        sample_rate = Synthesizer.sample_rate
        self.ui.play(self.current_wav, sample_rate)

C
Corentin Jemine 已提交
121
    def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir):
        """
        Populate the dataset browser and the model lists of the interface.

        :param encoder_models_dir: encoder models directory, forwarded to the UI
        :param synthesizer_models_dir: synthesizer models directory, forwarded to the UI
        :param vocoder_models_dir: vocoder models directory, forwarded to the UI
        """
        ui = self.ui
        ui.populate_browser(self.datasets_root, recognized_datasets, 0, True)
        ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir)
        
125 126 127 128 129 130 131 132 133 134 135 136
    def load_from_browser(self, fpath=None):
        if fpath is None:
            fpath = Path(self.datasets_root,
                         self.ui.current_dataset_name,
                         self.ui.current_speaker_name,
                         self.ui.current_utterance_name)
            name = str(fpath.relative_to(self.datasets_root))
            speaker_name = self.ui.current_dataset_name + '_' + self.ui.current_speaker_name
            
            # Select the next utterance
            if self.ui.auto_next_checkbox.isChecked():
                self.ui.browser_select_next()
137 138
        elif fpath == "":
            return 
139 140 141
        else:
            name = fpath.name
            speaker_name = fpath.parent.name
U
unknown 已提交
142
        
C
Corentin Jemine 已提交
143 144
        # Get the wav from the disk. We take the wav with the vocoder/synthesizer format for
        # playback, so as to have a fair comparison with the generated audio
145
        wav = Synthesizer.load_preprocess_wav(fpath)
C
Corentin Jemine 已提交
146 147
        self.ui.log("Loaded %s" % name)

148 149 150
        self.add_real_utterance(wav, name, speaker_name)
        
    def record(self):
        """Record an utterance through the UI, play it back and add it to the toolbox."""
        sample_rate = encoder.sampling_rate
        wav = self.ui.record_one(sample_rate, 5)

        # Nothing was recorded
        if wav is not None:
            self.ui.play(wav, sample_rate)

            # Recordings all belong to the same speaker and get a random numeric suffix
            speaker_name = "user01"
            suffix = "_rec_%05d" % np.random.randint(100000)
            self.add_real_utterance(wav, speaker_name + suffix, speaker_name)
        
    def add_real_utterance(self, wav, name, speaker_name):
        """
        Compute the spectrogram and the embedding of a wav, store it as an utterance and draw it.

        :param wav: waveform of the utterance
        :param name: name under which the utterance is registered and plotted
        :param speaker_name: name of the speaker of the utterance
        """
        ui = self.ui

        # Mel spectrogram
        mel = Synthesizer.make_spectrogram(wav)
        ui.draw_spec(mel, "current")

        # Embedding. The encoder is only loaded when it is first needed.
        if not encoder.is_loaded():
            self.init_encoder()
        embed, partial_embeds, _ = encoder.embed_utterance(
            encoder.preprocess_wav(wav), return_partials=True
        )

        # Store the utterance and make it known to the interface
        # NOTE(review): the last argument presumably flags generated utterances (vocode()
        # passes True); confirm against Utterance
        utterance = Utterance(name, speaker_name, wav, mel, embed, partial_embeds, False)
        self.utterances.add(utterance)
        ui.register_utterance(utterance)

        # Plots
        ui.draw_embed(embed, name, "current")
        ui.draw_umap_projections(self.utterances)
C
Corentin Jemine 已提交
179
        
180 181 182 183 184
    def clear_utterances(self):
        self.utterances.clear()
        self.ui.draw_umap_projections(self.utterances)
        
    def synthesize(self):
185
        self.ui.log("Generating the mel spectrogram...")
186 187 188 189 190 191 192 193 194
        self.ui.set_loading(1)
        
        # Synthesize the spectrogram
        if self.synthesizer is None:
            model_dir = self.ui.current_synthesizer_model_dir
            checkpoints_dir = model_dir.joinpath("taco_pretrained")
            self.synthesizer = Synthesizer(checkpoints_dir, low_mem=self.low_mem)
        if not self.synthesizer.is_loaded():
            self.ui.log("Loading the synthesizer %s" % self.synthesizer.checkpoint_fpath)
C
Corentin Jemine 已提交
195 196
        
        texts = self.ui.text_prompt.toPlainText().split("\n")
C
Corentin Jemine 已提交
197
        embed = self.ui.selected_utterance.embed
C
Corentin Jemine 已提交
198
        embeds = np.stack([embed] * len(texts))
199
        specs = self.synthesizer.synthesize_spectrograms(texts, embeds)
200
        breaks = [spec.shape[1] for spec in specs]
C
Corentin Jemine 已提交
201
        spec = np.concatenate(specs, axis=1)
C
Corentin Jemine 已提交
202
        
203
        self.ui.draw_spec(spec, "generated")
204
        self.current_generated = (self.ui.selected_utterance.speaker_name, spec, breaks, None)
205 206
        self.ui.set_loading(0)

207
    def vocode(self):
        """
        Turn the last generated spectrogram into a waveform, play it and store it.

        The waveform comes from the vocoder when a vocoder model is selected, and from
        Griffin-Lim otherwise. A short silence is inserted after each synthesized line, the
        result is peak-normalized, played, pushed to the history of generated wavs (at most
        MAX_WAVES entries) and added to the utterances with its embedding.

        Raises AssertionError if no spectrogram has been generated yet.
        """
        speaker_name, spec, breaks, _ = self.current_generated
        if spec is None:
            # Explicit raise rather than an `assert` statement, so that the check is not
            # stripped when Python runs with -O
            raise AssertionError("No generated spectrogram to vocode, synthesize one first")

        # Synthesize the waveform
        if not vocoder.is_loaded():
            self.init_vocoder()

        def vocoder_progress(i, seq_len, b_size, gen_rate):
            # Progress callback given to the vocoder: updates the log line and the loading bar
            real_time_factor = (gen_rate / Synthesizer.sample_rate) * 1000
            line = "Waveform generation: %d/%d (batch size: %d, rate: %.1fkHz - %.2fx real time)" \
                   % (i * b_size, seq_len * b_size, b_size, gen_rate, real_time_factor)
            self.ui.log(line, "overwrite")
            self.ui.set_loading(i, seq_len)

        # The loading indicator is reset even when the generation fails
        try:
            if self.ui.current_vocoder_fpath is not None:
                self.ui.log("")
                wav = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
            else:
                self.ui.log("Waveform generation with Griffin-Lim... ")
                wav = Synthesizer.griffin_lim(spec)
        finally:
            self.ui.set_loading(0)
        self.ui.log(" Done!", "append")

        # Add breaks: cut the wav where each synthesized line ends (spectrogram lengths are
        # converted to samples with the hop size) and insert 0.15s of silence after each piece
        b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
        b_starts = np.concatenate(([0], b_ends[:-1]))
        wavs = [wav[start:end] for start, end in zip(b_starts, b_ends)]
        silence = np.zeros(int(0.15 * Synthesizer.sample_rate))
        wav = np.concatenate([piece for w in wavs for piece in (w, silence)])

        # Play it. A fully silent wav has a peak of 0: it is left as is, normalizing it would
        # divide by zero and fill it with NaNs.
        peak = np.abs(wav).max()
        if peak > 0:
            wav = wav / peak * 0.97
        self.ui.play(wav, Synthesizer.sample_rate)

        # Name it (history displayed in combobox)
        # TODO better naming for the combobox items?
        self.waves_count += 1
        wav_name = str(self.waves_count)

        # Update waves combobox: the newest wav goes first and the oldest one is dropped once
        # more than MAX_WAVES have been generated
        if self.waves_count > MAX_WAVES:
            self.waves_list.pop()
            self.waves_namelist.pop()
        self.waves_list.insert(0, wav)
        self.waves_namelist.insert(0, wav_name)

        # The combobox is disconnected while its content is replaced, then connected again
        self.ui.waves_cb.disconnect()
        self.ui.waves_cb_model.setStringList(self.waves_namelist)
        self.ui.waves_cb.setCurrentIndex(0)
        self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)

        # Update current wav
        self.set_current_wav(0)

        # Enable replay and save buttons
        self.ui.replay_wav_button.setDisabled(False)
        self.ui.export_wav_button.setDisabled(False)

        # Compute the embedding
        # TODO: this is problematic with different sampling rates, gotta fix it
        if not encoder.is_loaded():
            self.init_encoder()
        encoder_wav = encoder.preprocess_wav(wav)
        embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

        # Add the utterance
        name = speaker_name + "_gen_%05d" % np.random.randint(100000)
        utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, True)
        self.utterances.add(utterance)

        # Plot it
        self.ui.draw_embed(embed, name, "generated")
        self.ui.draw_umap_projections(self.utterances)
279
        
280
    def init_encoder(self):
        """
        Load the encoder model currently selected in the UI, logging how long it took.

        Exceptions raised while loading are propagated, but the loading indicator of the UI is
        always reset.
        """
        model_fpath = self.ui.current_encoder_fpath

        self.ui.log("Loading the encoder %s... " % model_fpath)
        self.ui.set_loading(1)
        start = timer()
        # Reset the indicator in `finally`: a model that fails to load would otherwise leave
        # it displayed forever
        try:
            encoder.load_model(model_fpath)
            self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
        finally:
            self.ui.set_loading(0)
289
           
C
Corentin Jemine 已提交
290
    def init_vocoder(self):
        """
        Load the vocoder model currently selected in the UI, logging how long it took.

        Nothing is loaded when no vocoder model is selected (Griffin-Lim is used instead).
        Exceptions raised while loading are propagated, but the loading indicator of the UI is
        always reset.
        """
        model_fpath = self.ui.current_vocoder_fpath
        # Case of Griffin-lim
        if model_fpath is None:
            return

        self.ui.log("Loading the vocoder %s... " % model_fpath)
        self.ui.set_loading(1)
        start = timer()
        # Reset the indicator in `finally`: a model that fails to load would otherwise leave
        # it displayed forever
        try:
            vocoder.load_model(model_fpath)
            self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
        finally:
            self.ui.set_loading(0)