From 8077b500a253c952f9f79e0aa6a0274b83f1c92e Mon Sep 17 00:00:00 2001
From: Corentin Jemine
Date: Sat, 4 May 2019 19:55:39 +0200
Subject: [PATCH] Great speedup on the computation of the loss

---
 SV2TTS/encoder/model.py             | 13 ++++++-------
 SV2TTS/encoder/train.py             | 11 +++++++----
 SV2TTS/encoder/ui/visualizations.py |  2 +-
 3 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/SV2TTS/encoder/model.py b/SV2TTS/encoder/model.py
index ee64d5d..dad2911 100644
--- a/SV2TTS/encoder/model.py
+++ b/SV2TTS/encoder/model.py
@@ -80,18 +80,17 @@ class SpeakerEncoder(nn.Module):
         centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True)
 
         # Similarity matrix
-        sim_matrix = torch.zeros(speakers_per_batch * utterances_per_speaker,
+        sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
                                  speakers_per_batch).to(self.loss_device)
         for j in range(speakers_per_batch):
-            for i in range(utterances_per_speaker):
-                ji = j * utterances_per_speaker + i
-                for k in range(speakers_per_batch):
-                    centroid = centroids_excl[j, i] if j == k else centroids_incl[k]
-                    # The cosine similarity is the dot product when vectors are normalized
-                    sim_matrix[ji, k] = torch.dot(embeds[j, i], centroid)
+            for k in range(speakers_per_batch):
+                centroid = centroids_excl[j] if j == k else centroids_incl[k]
+                sim_matrix[j, :, k] = (embeds[j] * centroid).sum(dim=1)
         sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
 
         # Loss
+        sim_matrix = sim_matrix.view((speakers_per_batch * utterances_per_speaker,
+                                      speakers_per_batch))
         ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
         target = torch.from_numpy(ground_truth).long().to(self.loss_device)
         loss = self.loss_fn(sim_matrix, target)
diff --git a/SV2TTS/encoder/train.py b/SV2TTS/encoder/train.py
index 09ca269..7f5efcc 100644
--- a/SV2TTS/encoder/train.py
+++ b/SV2TTS/encoder/train.py
@@ -61,10 +61,13 @@ def train(run_id: str, clean_data_root: Path, models_dir: Path, vis_every: int,
 
         # Forward pass
         inputs = torch.from_numpy(speaker_batch.data).to(device)
+        torch.cuda.synchronize(device)
         profiler.tick("Data to %s" % device)
-        embeds = model(inputs).to(loss_device)
+        embeds = model(inputs)
+        torch.cuda.synchronize(device)
         profiler.tick("Forward pass")
-        loss, eer = model.loss(embeds.view((speakers_per_batch, utterances_per_speaker, -1)))
+        embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
+        loss, eer = model.loss(embeds_loss)
         profiler.tick("Loss")
 
         # Backward pass
@@ -84,8 +87,8 @@ def train(run_id: str, clean_data_root: Path, models_dir: Path, vis_every: int,
             print("Drawing and saving projections (step %d)" % step)
             backup_dir.mkdir(exist_ok=True)
             projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
-            embeds_numpy = embeds.detach().numpy()
-            vis.draw_projections(embeds_numpy, utterances_per_speaker, step, projection_fpath)
+            embeds = embeds.detach().cpu().numpy()
+            vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
             vis.save()
 
         # Overwrite the latest version of the model
diff --git a/SV2TTS/encoder/ui/visualizations.py b/SV2TTS/encoder/ui/visualizations.py
index 7a1a906..317ab0b 100644
--- a/SV2TTS/encoder/ui/visualizations.py
+++ b/SV2TTS/encoder/ui/visualizations.py
@@ -45,7 +45,7 @@ class Visualizations:
         except ConnectionError:
             raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
                             "start it.")
-        webbrowser.open("http://localhost:8097/env/" + self.env_name)
+        # webbrowser.open("http://localhost:8097/env/" + self.env_name)
         self.loss_win = None
         self.eer_win = None
 
-- 
GitLab
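
Note, not part of the patch: the model.py hunk removes the inner per-utterance
loop by computing each (speaker j, speaker k) column of the similarity matrix
as one batched dot product, which is where the speedup comes from. The same
idea can be taken one step further to drop the remaining Python loops
entirely. The sketch below is a hypothetical illustration rather than code
from this repository; the function name similarity_matrix and the assumption
that embeds is already L2-normalized (so cosine similarity reduces to a dot
product, as the deleted comment noted) with shape
(speakers_per_batch, utterances_per_speaker, embed_dim) are mine.

    import torch

    def similarity_matrix(embeds: torch.Tensor) -> torch.Tensor:
        # embeds: (speakers_per_batch, utterances_per_speaker, embed_dim),
        # assumed L2-normalized.
        s, u, _ = embeds.shape

        # Inclusive centroids: mean of each speaker's utterances, renormalized.
        centroids_incl = embeds.mean(dim=1)
        centroids_incl = centroids_incl / torch.norm(centroids_incl, dim=1, keepdim=True)

        # Exclusive centroids: for utterance i of speaker j, the mean of that
        # speaker's other utterances, renormalized. Shape (s, u, embed_dim).
        centroids_excl = (embeds.sum(dim=1, keepdim=True) - embeds) / (u - 1)
        centroids_excl = centroids_excl / torch.norm(centroids_excl, dim=2, keepdim=True)

        # All utterance-vs-centroid dot products in one batched contraction,
        # then overwrite the same-speaker entries with the exclusive-centroid
        # similarities, as the patched loop does one speaker at a time.
        sim_matrix = torch.einsum("jud,kd->juk", embeds, centroids_incl)
        idx = torch.arange(s)
        sim_matrix[idx, :, idx] = (embeds * centroids_excl).sum(dim=2)
        return sim_matrix

A quick sanity check before trusting the speedup: build random normalized
embeddings, run both the looped version and this one, and compare the outputs
with torch.allclose.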