提交 fbf784dd 编写于 作者: XYZ_916's avatar XYZ_916 提交者: chajchaj

1. optimize the error message of softmax_with_cross_entropy_op;2. add input...

1. optimize the error message of softmax_with_cross_entropy_op;2. add input value check for cross_entropy, if the dimention of input is zero, raise error. test = develop
上级 23170e21
......@@ -61,7 +61,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_GT(
n, 0, platform::errors::InvalidArgument(
"The size of axis should be larger than 0, but received "
"axis size is %d.",
"SizeToAxis of softmax is %d.",
n));
const int d = SizeFromAxis(axis, softmax->dims());
......@@ -110,7 +110,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_GT(
n, 0, platform::errors::InvalidArgument(
"The size of axis should be larger than 0, but received "
"axis size is %d.",
"SizeToAxis of logits is %d.",
n));
const int d = SizeFromAxis(axis, logits->dims());
......@@ -162,7 +162,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_GT(
n, 0, platform::errors::InvalidArgument(
"The size of axis should be larger than 0, but received "
"axis size is %d.",
"SizeToAxis of logit_grad is %d.",
n));
const int d = SizeFromAxis(axis, logit_grad->dims());
......
......@@ -1644,6 +1644,9 @@ def cross_entropy(input,
ignore_index)
input_dims = len(list(input.shape))
if input_dims == 0:
raise ValueError('The dimention of input should be larger than zero!')
label_dims = len(list(label.shape))
if input_dims - 1 != label_dims and input_dims != label_dims:
raise ValueError(
......@@ -2016,3 +2019,20 @@ def sigmoid_focal_loss(logit,
loss = paddle.sum(loss, name=name)
return loss
if __name__ == "__main__":
input_arr = np.array([], dtype=np.float32)
input = paddle.to_tensor(np.reshape(input_arr, (0, 0)), dtype='float32')
label = paddle.to_tensor([], dtype='float32')
weight = paddle.to_tensor([], dtype='float32')
result = cross_entropy(
input,
label,
weight=weight,
ignore_index=-100,
soft_label=False,
axis=-1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册