未验证 提交 790e05dd 编写于 作者: H haoyuying 提交者: GitHub

fix cross entropy bug

上级 6bb3483f
......@@ -207,8 +207,8 @@ class ImageColorizeModule(RunModule, ImageServing):
out_class, out_reg = self(img['A'], img['hint_B'], img['mask_B'])
# loss
criterionCE = nn.loss.CrossEntropyLoss()
loss_ce = criterionCE(out_class, img['real_B_enc'][:, 0, :, :])
loss_ce = F.cross_entropy(out_class, img['real_B_enc'][:, :1, :, :], axis=1)
loss_ce = paddle.mean(loss_ce)
loss_G_L1_reg = paddle.sum(paddle.abs(img['B'] - out_reg), axis=1, keepdim=True)
loss_G_L1_reg = paddle.mean(loss_G_L1_reg)
loss = loss_ce + loss_G_L1_reg
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册