diff --git a/paddlehub/module/cv_module.py b/paddlehub/module/cv_module.py index 3dacfd3c9ad9cd9b30e01cbc0f86e4d1b041a5d3..e49f173c82d25328e86aad894ea939a444dcdfd7 100644 --- a/paddlehub/module/cv_module.py +++ b/paddlehub/module/cv_module.py @@ -207,8 +207,8 @@ class ImageColorizeModule(RunModule, ImageServing): out_class, out_reg = self(img['A'], img['hint_B'], img['mask_B']) # loss - criterionCE = nn.loss.CrossEntropyLoss() - loss_ce = criterionCE(out_class, img['real_B_enc'][:, 0, :, :]) + loss_ce = F.cross_entropy(out_class, img['real_B_enc'][:, :1, :, :], axis=1) + loss_ce = paddle.mean(loss_ce) loss_G_L1_reg = paddle.sum(paddle.abs(img['B'] - out_reg), axis=1, keepdim=True) loss_G_L1_reg = paddle.mean(loss_G_L1_reg) loss = loss_ce + loss_G_L1_reg