From fbf784dd3015ec1195152bbebd375a35de46ab29 Mon Sep 17 00:00:00 2001 From: XYZ916829 <1290573099@qq.com> Date: Tue, 31 Aug 2021 02:26:13 +0000 Subject: [PATCH] 1. optimize the error message of softmax_with_cross_entropy_op;2. add input value check for cross_entropy, if the dimention of input is zero, raise error. test = develop --- .../operators/softmax_with_cross_entropy_op.h | 6 +++--- python/paddle/nn/functional/loss.py | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.h b/paddle/fluid/operators/softmax_with_cross_entropy_op.h index b9eaa9bece8..e513b99b876 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.h +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.h @@ -61,7 +61,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { PADDLE_ENFORCE_GT( n, 0, platform::errors::InvalidArgument( "The size of axis should be larger than 0, but received " - "axis size is %d.", + "SizeToAxis of softmax is %d.", n)); const int d = SizeFromAxis(axis, softmax->dims()); @@ -110,7 +110,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { PADDLE_ENFORCE_GT( n, 0, platform::errors::InvalidArgument( "The size of axis should be larger than 0, but received " - "axis size is %d.", + "SizeToAxis of logits is %d.", n)); const int d = SizeFromAxis(axis, logits->dims()); @@ -162,7 +162,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { PADDLE_ENFORCE_GT( n, 0, platform::errors::InvalidArgument( "The size of axis should be larger than 0, but received " - "axis size is %d.", + "SizeToAxis of logit_grad is %d.", n)); const int d = SizeFromAxis(axis, logit_grad->dims()); diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index fdbc5b9c476..6fcd60ad113 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1644,6 +1644,9 @@ def cross_entropy(input, ignore_index) input_dims = len(list(input.shape)) + if input_dims == 0: + raise ValueError('The dimention of input should be larger than zero!') + label_dims = len(list(label.shape)) if input_dims - 1 != label_dims and input_dims != label_dims: raise ValueError( @@ -2016,3 +2019,20 @@ def sigmoid_focal_loss(logit, loss = paddle.sum(loss, name=name) return loss + + +if __name__ == "__main__": + input_arr = np.array([], dtype=np.float32) + input = paddle.to_tensor(np.reshape(input_arr, (0, 0)), dtype='float32') + + label = paddle.to_tensor([], dtype='float32') + + weight = paddle.to_tensor([], dtype='float32') + + result = cross_entropy( + input, + label, + weight=weight, + ignore_index=-100, + soft_label=False, + axis=-1) -- GitLab