diff --git a/SV2TTS/encoder/model.py b/SV2TTS/encoder/model.py
index ee64d5d99872eb8b0f6ee9af92647615f3ed71d4..dad29115f9b8fc271d304bc2199f4c5c3ca65f3e 100644
--- a/SV2TTS/encoder/model.py
+++ b/SV2TTS/encoder/model.py
@@ -80,18 +80,17 @@ class SpeakerEncoder(nn.Module):
         centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True)
 
         # Similarity matrix
-        sim_matrix = torch.zeros(speakers_per_batch * utterances_per_speaker,
+        sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
                                  speakers_per_batch).to(self.loss_device)
         for j in range(speakers_per_batch):
-            for i in range(utterances_per_speaker):
-                ji = j * utterances_per_speaker + i
-                for k in range(speakers_per_batch):
-                    centroid = centroids_excl[j, i] if j == k else centroids_incl[k]
-                    # The cosine similarity is the dot product when vectors are normalized
-                    sim_matrix[ji, k] = torch.dot(embeds[j, i], centroid)
+            for k in range(speakers_per_batch):
+                centroid = centroids_excl[j] if j == k else centroids_incl[k]
+                sim_matrix[j, :, k] = (embeds[j] * centroid).sum(dim=1)
         sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
 
         # Loss
+        sim_matrix = sim_matrix.view((speakers_per_batch * utterances_per_speaker,
+                                      speakers_per_batch))
         ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
         target = torch.from_numpy(ground_truth).long().to(self.loss_device)
         loss = self.loss_fn(sim_matrix, target)
diff --git a/SV2TTS/encoder/train.py b/SV2TTS/encoder/train.py
index 09ca269e6f1eb0d5b3fec15eeea811ad6377de7c..7f5efcc7c416d8111d58e93b3a6db52d7b178b6b 100644
--- a/SV2TTS/encoder/train.py
+++ b/SV2TTS/encoder/train.py
@@ -61,10 +61,13 @@ def train(run_id: str, clean_data_root: Path, models_dir: Path, vis_every: int,
 
         # Forward pass
         inputs = torch.from_numpy(speaker_batch.data).to(device)
+        torch.cuda.synchronize(device)
         profiler.tick("Data to %s" % device)
-        embeds = model(inputs).to(loss_device)
+        embeds = model(inputs)
+        torch.cuda.synchronize(device)
         profiler.tick("Forward pass")
-        loss, eer = model.loss(embeds.view((speakers_per_batch, utterances_per_speaker, -1)))
+        embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
+        loss, eer = model.loss(embeds_loss)
         profiler.tick("Loss")
 
         # Backward pass
@@ -84,8 +87,8 @@ def train(run_id: str, clean_data_root: Path, models_dir: Path, vis_every: int,
             print("Drawing and saving projections (step %d)" % step)
             backup_dir.mkdir(exist_ok=True)
             projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
-            embeds_numpy = embeds.detach().numpy()
-            vis.draw_projections(embeds_numpy, utterances_per_speaker, step, projection_fpath)
+            embeds = embeds.detach().cpu().numpy()
+            vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
             vis.save()
 
         # Overwrite the latest version of the model
diff --git a/SV2TTS/encoder/ui/visualizations.py b/SV2TTS/encoder/ui/visualizations.py
index 7a1a906573dbe04aece3b2742f6befeeb3d4934b..317ab0be1d217ccd448201003f47391c5d06cdf5 100644
--- a/SV2TTS/encoder/ui/visualizations.py
+++ b/SV2TTS/encoder/ui/visualizations.py
@@ -45,7 +45,7 @@ class Visualizations:
         except ConnectionError:
             raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
                             "start it.")
-        webbrowser.open("http://localhost:8097/env/" + self.env_name)
+        # webbrowser.open("http://localhost:8097/env/" + self.env_name)
 
         self.loss_win = None
         self.eer_win = None
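
Note on the model.py hunk: the vectorized similarity matrix should produce the same values as the removed triple loop, just batched over utterances and reshaped afterwards. Below is a standalone sanity-check sketch (not part of the patch); the tensor shapes, the centroid definitions, and the randomly generated, L2-normalized `embeds` are assumptions chosen to mirror the GE2E-style loss in this encoder.

    # Standalone sanity check: compares the removed triple loop against the
    # vectorized similarity matrix from the patch. Assumes embeds has shape
    # (speakers, utterances, dim) and is L2-normalized per utterance.
    import torch

    speakers_per_batch, utterances_per_speaker, embed_dim = 4, 5, 16
    embeds = torch.randn(speakers_per_batch, utterances_per_speaker, embed_dim)
    embeds = embeds / torch.norm(embeds, dim=2, keepdim=True)

    # Inclusive centroids (one per speaker) and exclusive centroids (one per
    # utterance), both normalized
    centroids_incl = torch.mean(embeds, dim=1)
    centroids_incl = centroids_incl / torch.norm(centroids_incl, dim=1, keepdim=True)
    centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds) / (utterances_per_speaker - 1)
    centroids_excl = centroids_excl / torch.norm(centroids_excl, dim=2, keepdim=True)

    # Removed version: one dot product per (speaker, utterance, speaker) triple
    sim_old = torch.zeros(speakers_per_batch * utterances_per_speaker, speakers_per_batch)
    for j in range(speakers_per_batch):
        for i in range(utterances_per_speaker):
            ji = j * utterances_per_speaker + i
            for k in range(speakers_per_batch):
                centroid = centroids_excl[j, i] if j == k else centroids_incl[k]
                sim_old[ji, k] = torch.dot(embeds[j, i], centroid)

    # Patched version: one element-wise multiply + sum per (speaker, speaker)
    # pair, then a reshape so rows are speaker-major, exactly like the old matrix
    sim_new = torch.zeros(speakers_per_batch, utterances_per_speaker, speakers_per_batch)
    for j in range(speakers_per_batch):
        for k in range(speakers_per_batch):
            centroid = centroids_excl[j] if j == k else centroids_incl[k]
            sim_new[j, :, k] = (embeds[j] * centroid).sum(dim=1)
    sim_new = sim_new.view(speakers_per_batch * utterances_per_speaker, speakers_per_batch)

    assert torch.allclose(sim_old, sim_new, atol=1e-6)

On the train.py hunk: the forward pass now stays on `device` and only the reshaped embeddings are moved to `loss_device`; the added `torch.cuda.synchronize(device)` calls do not change results, they only make the profiler ticks measure the actual kernel time, since CUDA launches are asynchronous.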