提交 8077b500 作者: Corentin Jemine

Great speedup on the computation of the loss

上级 71256622
......@@ -80,18 +80,17 @@ class SpeakerEncoder(nn.Module):
centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True)
# Similarity matrix
sim_matrix = torch.zeros(speakers_per_batch * utterances_per_speaker,
sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
speakers_per_batch).to(self.loss_device)
for j in range(speakers_per_batch):
for i in range(utterances_per_speaker):
ji = j * utterances_per_speaker + i
for k in range(speakers_per_batch):
centroid = centroids_excl[j, i] if j == k else centroids_incl[k]
# The cosine similarity is the dot product when vectors are normalized
sim_matrix[ji, k] = torch.dot(embeds[j, i], centroid)
for k in range(speakers_per_batch):
centroid = centroids_excl[j] if j == k else centroids_incl[k]
sim_matrix[j, :, k] = (embeds[j] * centroid).sum(dim=1)
sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
# Loss
sim_matrix = sim_matrix.view((speakers_per_batch * utterances_per_speaker,
speakers_per_batch))
ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
target = torch.from_numpy(ground_truth).long().to(self.loss_device)
loss = self.loss_fn(sim_matrix, target)
......
......@@ -61,10 +61,13 @@ def train(run_id: str, clean_data_root: Path, models_dir: Path, vis_every: int,
# Forward pass
inputs = torch.from_numpy(speaker_batch.data).to(device)
torch.cuda.synchronize(device)
profiler.tick("Data to %s" % device)
embeds = model(inputs).to(loss_device)
embeds = model(inputs)
torch.cuda.synchronize(device)
profiler.tick("Forward pass")
loss, eer = model.loss(embeds.view((speakers_per_batch, utterances_per_speaker, -1)))
embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
loss, eer = model.loss(embeds_loss)
profiler.tick("Loss")
# Backward pass
......@@ -84,8 +87,8 @@ def train(run_id: str, clean_data_root: Path, models_dir: Path, vis_every: int,
print("Drawing and saving projections (step %d)" % step)
backup_dir.mkdir(exist_ok=True)
projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
embeds_numpy = embeds.detach().numpy()
vis.draw_projections(embeds_numpy, utterances_per_speaker, step, projection_fpath)
embeds = embeds.detach().cpu().numpy()
vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
vis.save()
# Overwrite the latest version of the model
......
......@@ -45,7 +45,7 @@ class Visualizations:
except ConnectionError:
raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
"start it.")
webbrowser.open("http://localhost:8097/env/" + self.env_name)
# webbrowser.open("http://localhost:8097/env/" + self.env_name)
self.loss_win = None
self.eer_win = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册