diff --git a/contrib/HumanSeg/nets/seg_modules.py b/contrib/HumanSeg/nets/seg_modules.py index fb59dce486420585edd47559c6fdd3cf88e59350..902c903fcc426e68de483ee249af16d21788a86c 100644 --- a/contrib/HumanSeg/nets/seg_modules.py +++ b/contrib/HumanSeg/nets/seg_modules.py @@ -34,6 +34,7 @@ def softmax_with_loss(logit, loss, probs = fluid.layers.softmax_with_cross_entropy( logit, label, ignore_index=ignore_index, return_softmax=True) else: + label = fluid.layers.squeeze(label, axes=[-1]) label_one_hot = fluid.one_hot(input=label, depth=num_classes) if isinstance(weight, list): assert len( diff --git a/contrib/RemoteSensing/nets/loss.py b/contrib/RemoteSensing/nets/loss.py index fb59dce486420585edd47559c6fdd3cf88e59350..902c903fcc426e68de483ee249af16d21788a86c 100644 --- a/contrib/RemoteSensing/nets/loss.py +++ b/contrib/RemoteSensing/nets/loss.py @@ -34,6 +34,7 @@ def softmax_with_loss(logit, loss, probs = fluid.layers.softmax_with_cross_entropy( logit, label, ignore_index=ignore_index, return_softmax=True) else: + label = fluid.layers.squeeze(label, axes=[-1]) label_one_hot = fluid.one_hot(input=label, depth=num_classes) if isinstance(weight, list): assert len( diff --git a/pdseg/loss.py b/pdseg/loss.py index 92638a9caaa15d749a8dfd3abc7cc8dec550b7fe..bfd1c83f0a4673cc627c5296d586d9a6f2f31c40 100644 --- a/pdseg/loss.py +++ b/pdseg/loss.py @@ -40,6 +40,7 @@ def softmax_with_loss(logit, ignore_index=cfg.DATASET.IGNORE_INDEX, return_softmax=True) else: + label = fluid.layers.squeeze(label, axes=[-1]) label_one_hot = fluid.one_hot(input=label, depth=num_classes) if isinstance(weight, list): assert len(