提交 95d07675 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix kldiv input

上级 e45ff483
...@@ -79,7 +79,7 @@ class DMLLoss(nn.Layer): ...@@ -79,7 +79,7 @@ class DMLLoss(nn.Layer):
log_out2 = paddle.log(out2) log_out2 = paddle.log(out2)
loss = (F.kl_div( loss = (F.kl_div(
log_out1, out2, reduction='batchmean') + F.kl_div( log_out1, out2, reduction='batchmean') + F.kl_div(
log_out2, log_out1, reduction='batchmean')) / 2.0 log_out2, out1, reduction='batchmean')) / 2.0
return loss return loss
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册