From 05d2ef83bcb067e63e891dff57120e8b1e5f0402 Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Wed, 15 Jul 2020 16:21:46 +0800 Subject: [PATCH] add transpose for loss computation --- dygraph/models/hrnet.py | 4 +++- dygraph/models/unet.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/dygraph/models/hrnet.py b/dygraph/models/hrnet.py index d5270a10..2dcf2dda 100644 --- a/dygraph/models/hrnet.py +++ b/dygraph/models/hrnet.py @@ -179,6 +179,8 @@ class HRNet(fluid.dygraph.Layer): return pred, score_map def _get_loss(self, logit, label): + logit = fluid.layers.transpose(logit, [0, 2, 3, 1]) + label = fluid.layers.transpose(label, [0, 2, 3, 1]) mask = label != self.ignore_index mask = fluid.layers.cast(mask, 'float32') loss, probs = fluid.layers.softmax_with_cross_entropy( @@ -186,7 +188,7 @@ class HRNet(fluid.dygraph.Layer): label, ignore_index=self.ignore_index, return_softmax=True, - axis=1) + axis=-1) loss = loss * mask avg_loss = fluid.layers.mean(loss) / ( diff --git a/dygraph/models/unet.py b/dygraph/models/unet.py index 9bb92302..970936d0 100644 --- a/dygraph/models/unet.py +++ b/dygraph/models/unet.py @@ -43,6 +43,8 @@ class UNet(fluid.dygraph.Layer): return pred, score_map def _get_loss(self, logit, label): + logit = fluid.layers.transpose(logit, [0, 2, 3, 1]) + label = fluid.layers.transpose(label, [0, 2, 3, 1]) mask = label != self.ignore_index mask = fluid.layers.cast(mask, 'float32') loss, probs = fluid.layers.softmax_with_cross_entropy( -- GitLab