From 790e05ddd3d88152a93a2f5f70bd42487c7df0c3 Mon Sep 17 00:00:00 2001 From: haoyuying <35907364+haoyuying@users.noreply.github.com> Date: Wed, 6 Jan 2021 19:49:42 +0800 Subject: [PATCH] fix cross entropy bug --- paddlehub/module/cv_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddlehub/module/cv_module.py b/paddlehub/module/cv_module.py index 3dacfd3c..e49f173c 100644 --- a/paddlehub/module/cv_module.py +++ b/paddlehub/module/cv_module.py @@ -207,8 +207,8 @@ class ImageColorizeModule(RunModule, ImageServing): out_class, out_reg = self(img['A'], img['hint_B'], img['mask_B']) # loss - criterionCE = nn.loss.CrossEntropyLoss() - loss_ce = criterionCE(out_class, img['real_B_enc'][:, 0, :, :]) + loss_ce = F.cross_entropy(out_class, img['real_B_enc'][:, :1, :, :], axis=1) + loss_ce = paddle.mean(loss_ce) loss_G_L1_reg = paddle.sum(paddle.abs(img['B'] - out_reg), axis=1, keepdim=True) loss_G_L1_reg = paddle.mean(loss_G_L1_reg) loss = loss_ce + loss_G_L1_reg -- GitLab