Unverified commit 16c36465, authored by lzydev, committed by GitHub

Fix bug in cross_entropy in static mode (#52771)

* fix bug in cross_entropy in static mode

* fix ci-coverage
Parent: 84d34ddd
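For context, a minimal static-mode sketch of the behavior this commit fixes (variable names and shapes are illustrative, mirroring the new tests): a label whose rank is neither the input rank nor the input rank minus one now fails fast with a Python-side ValueError instead of erroring later inside the C++ InferMeta check.

```python
import paddle

paddle.enable_static()

# Illustrative shapes: the input has rank 4 and the label rank 5, so the
# rank rule (label rank == input rank, or input rank - 1) is violated.
x = paddle.static.data(name='x', shape=[-1, 3, 4, 5], dtype='float32')
label = paddle.static.data(name='label', shape=[-1, 3, 4, 5, 6], dtype='int64')

try:
    paddle.nn.functional.cross_entropy(
        x, label, reduction='none', use_softmax=False
    )
except ValueError as e:
    # Raised by the Python-side rank check added in this commit.
    print(e)
```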
@@ -886,7 +886,6 @@ void CrossEntropyWithSoftmaxInferMeta(const MetaTensor& logits,
auto logits_dims = logits.dims();
auto labels_dims = label.dims();
auto logits_rank = logits_dims.size();
-  auto labels_rank = labels_dims.size();
PADDLE_ENFORCE_GE(axis,
-logits_rank,
phi::errors::InvalidArgument(
@@ -919,12 +918,6 @@ void CrossEntropyWithSoftmaxInferMeta(const MetaTensor& logits,
"when not in numeric_stable_mode."));
}
-  PADDLE_ENFORCE_EQ(
-      (logits_rank - 1 != labels_rank) && (logits_rank != labels_rank),
-      false,
-      phi::errors::InvalidArgument("Expected input_dims - 1 == label_dims "
-                                   "or input_dims == label_dims."));
if (soft_label) {
if (config.is_runtime || (logits_dims[axis] > 0 && labels_dims[axis] > 0)) {
PADDLE_ENFORCE_EQ(logits_dims[axis],
......
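The PADDLE_ENFORCE_EQ removed above enforced the same rank rule that this commit moves to the Python layer. A condensed sketch of the equivalent check (the helper name check_label_rank is hypothetical, shown only to restate the rule):

```python
def check_label_rank(input_dims: int, label_dims: int) -> None:
    # Same rule as the removed C++ check: the label must have the same
    # rank as the logits, or exactly one rank less.
    if input_dims - 1 != label_dims and input_dims != label_dims:
        raise ValueError(
            'Expected input_dims - 1 == label_dims or input_dims == label_dims '
            '(got input_dims: {}, label_dims: {})'.format(input_dims, label_dims)
        )
```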
@@ -441,6 +441,21 @@ class TestCrossEntropyOpError(unittest.TestCase):
self.assertRaises(TypeError, test_dtype)
+        def test_input_dims():
+            with paddle_static_guard():
+                # label rank must equal input rank or input rank - 1;
+                # here the input has rank 4 and the label rank 5.
+                x3 = paddle.static.data(
+                    name='x3', shape=[-1, 3, 4, 5], dtype="int32"
+                )
+                lab3 = paddle.static.data(
+                    name='lab3', shape=[-1, 3, 4, 5, 6], dtype="int32"
+                )
+                paddle.nn.functional.cross_entropy(
+                    x3, lab3, reduction='none', use_softmax=False
+                )
+
+        self.assertRaises(ValueError, test_input_dims)
if __name__ == "__main__":
unittest.main()
@@ -15,11 +15,11 @@
import unittest
import numpy as np
-from eager_op_test import OpTest, convert_float_to_uint16
+from eager_op_test import OpTest, convert_float_to_uint16, paddle_static_guard
from test_softmax_op import stable_softmax
import paddle
-from paddle.fluid import core
+from paddle.fluid import Program, core, program_guard
def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1):
@@ -974,6 +974,35 @@ class TestSoftmaxWithCrossEntropyOpBF16(TestSoftmaxWithCrossEntropyOp):
)
+class TestSoftmaxWithCrossEntropyOpError(unittest.TestCase):
+    def test_errors(self):
+        with program_guard(Program(), Program()):
+
+            def test_input_dims1():
+                with paddle_static_guard():
+                    # a 0-D input is rejected: the input rank must be
+                    # larger than zero.
+                    x1 = paddle.static.data(name='x1', shape=[], dtype="int32")
+                    lab1 = paddle.static.data(
+                        name='lab1', shape=[-1, 3, 4, 5, 6], dtype="int32"
+                    )
+                    paddle.nn.functional.softmax_with_cross_entropy(x1, lab1)
+
+            self.assertRaises(ValueError, test_input_dims1)
+
+            def test_input_dims2():
+                with paddle_static_guard():
+                    # label rank must equal input rank or input rank - 1;
+                    # here the input has rank 4 and the label rank 5.
+                    x2 = paddle.static.data(
+                        name='x2', shape=[-1, 3, 4, 5], dtype="int32"
+                    )
+                    lab2 = paddle.static.data(
+                        name='lab2', shape=[-1, 3, 4, 5, 6], dtype="int32"
+                    )
+                    paddle.nn.functional.softmax_with_cross_entropy(x2, lab2)
+
+            self.assertRaises(ValueError, test_input_dims2)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
@@ -253,6 +253,20 @@ def fluid_softmax_with_cross_entropy(
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [1.15328646])
"""
+    input_dims = len(list(logits.shape))
+    if input_dims == 0:
+        raise ValueError('The dimension of input should be larger than zero!')
+
+    label_dims = len(list(label.shape))
+    if input_dims - 1 != label_dims and input_dims != label_dims:
+        raise ValueError(
+            'Expected input_dims - 1 == label_dims or input_dims == label_dims '
+            '(got input_dims: {}, label_dims: {})'.format(input_dims, label_dims)
+        )
+
+    if input_dims - 1 == label_dims:
+        label = paddle.unsqueeze(label, axis=axis)
if in_dygraph_mode():
if core.is_compiled_with_custom_device("npu"):
if not soft_label:
@@ -2700,6 +2714,14 @@ def cross_entropy(
    if input_dims - 1 == label_dims:
        label = paddle.unsqueeze(label, axis=axis)

+    if input_dims - 1 != label_dims and input_dims != label_dims:
+        raise ValueError(
+            'Expected input_dims - 1 == label_dims or input_dims == label_dims '
+            '(got input_dims: {}, label_dims: {})'.format(input_dims, label_dims)
+        )
+
if in_dygraph_mode():
if not soft_label:
valid_label = (
......
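With the checks in place, the rank combinations cross_entropy accepts look like this (a dygraph sketch with illustrative shapes and values; a hard label drops the class axis, a soft label keeps it):

```python
import paddle

logits = paddle.rand([4, 10])                  # rank 2
hard_label = paddle.randint(0, 10, shape=[4])  # rank 1 == rank(logits) - 1: accepted
soft_label = paddle.rand([4, 10])              # rank 2 == rank(logits): accepted

loss_hard = paddle.nn.functional.cross_entropy(logits, hard_label)
loss_soft = paddle.nn.functional.cross_entropy(logits, soft_label, soft_label=True)
```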