未验证 提交 c764710a 编写于 作者: 章宏彬 提交者: GitHub

[TTS]Avoid using variable "attn_loss" before assignment (#2860)

* Avoid using variable "attn_loss" before assignment

* Update tacotron2_updater.py

---------
Co-authored-by: 小湉湉's avatarTianYuan <white-sky@qq.com>
上级 a283f8a5
......@@ -113,16 +113,18 @@ class Tacotron2Updater(StandardUpdater):
loss.backward()
optimizer.step()
if self.use_guided_attn_loss:
report("train/attn_loss", float(attn_loss))
losses_dict["attn_loss"] = float(attn_loss)
report("train/l1_loss", float(l1_loss))
report("train/mse_loss", float(mse_loss))
report("train/bce_loss", float(bce_loss))
report("train/attn_loss", float(attn_loss))
report("train/loss", float(loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["mse_loss"] = float(mse_loss)
losses_dict["bce_loss"] = float(bce_loss)
losses_dict["attn_loss"] = float(attn_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
......@@ -202,17 +204,19 @@ class Tacotron2Evaluator(StandardEvaluator):
attn_loss = self.attn_loss(
att_ws=att_ws, ilens=batch["text_lengths"] + 1, olens=olens_in)
loss = loss + attn_loss
if self.use_guided_attn_loss:
report("eval/attn_loss", float(attn_loss))
losses_dict["attn_loss"] = float(attn_loss)
report("eval/l1_loss", float(l1_loss))
report("eval/mse_loss", float(mse_loss))
report("eval/bce_loss", float(bce_loss))
report("eval/attn_loss", float(attn_loss))
report("eval/loss", float(loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["mse_loss"] = float(mse_loss)
losses_dict["bce_loss"] = float(bce_loss)
losses_dict["attn_loss"] = float(attn_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册