diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index e7682d78a14a1ecef09ba97d64125cd268fa86e5..7282c0695086a0a1f85a48004b40be9153ebf6a5 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -14,6 +14,8 @@ limitations under the License. */ #include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" + namespace phi { void BilinearTensorProductGradInferMeta(const MetaTensor& x, @@ -103,6 +105,69 @@ void Conv2dTransposeDoubleGradInferMeta(const MetaTensor& x, } } +void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label, + const MetaTensor& softmax, + const MetaTensor& loss_grad, + bool soft_label, + bool use_softmax, + bool numeric_stable_mode, + int ignore_index, + int axis, + MetaTensor* logits_grad, + MetaConfig config) { + auto softmax_dims = softmax.dims(); + auto labels_dims = label.dims(); + auto softmax_rank = softmax_dims.size(); + PADDLE_ENFORCE_GE(axis, + -softmax_rank, + phi::errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], " + "where R is the rank of Input(Logits).")); + PADDLE_ENFORCE_LT(axis, + softmax_rank, + phi::errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], " + "where R is the rank of Input(Logits).")); + + axis = phi::funcs::CanonicalAxis(axis, softmax_rank); + for (int i = 0; i < softmax_rank; i++) { + if (i != axis) { + if (config.is_runtime || (softmax_dims[i] > 0 && labels_dims[i] > 0)) { + PADDLE_ENFORCE_EQ( + softmax_dims[i], + labels_dims[i], + phi::errors::InvalidArgument( + "Input(Logits) and Input(Label) should be in the same shape " + "in dimensions except axis.")); + } + } + } + + if (soft_label) { + if (config.is_runtime || + (softmax_dims[axis] > 0 && labels_dims[axis] > 0)) { + PADDLE_ENFORCE_EQ(softmax_dims[axis], + labels_dims[axis], + phi::errors::InvalidArgument( + "If Attr(soft_label) == true, " + "the axis dimension of " + "Input(X) and Input(Label) should be equal.")); + } + } else { + if (config.is_runtime || labels_dims[axis] > 0) { + PADDLE_ENFORCE_EQ( + labels_dims[axis], + 1UL, + phi::errors::InvalidArgument("If Attr(soft_label) == false, " + "the axis dimension of " + "Input(Label) should be 1.")); + } + } + + logits_grad->set_dims(softmax.dims()); + logits_grad->set_dtype(softmax.dtype()); +} + void GatherNdGradInferMeta(const MetaTensor& x, const MetaTensor& index, const MetaTensor& out_grad, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 4cdc048b24964796bfd2ebeace0474e8e10f31b5..92266811de0576a420e476a6752fa2771c8e7823 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -68,6 +68,17 @@ void Conv2dTransposeDoubleGradInferMeta(const MetaTensor& x, MetaTensor* dfilter, MetaTensor* ddout); +void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label, + const MetaTensor& softmax, + const MetaTensor& loss_grad, + bool soft_label, + bool use_softmax, + bool numeric_stable_mode, + int ignore_index, + int axis, + MetaTensor* logits_grad, + MetaConfig config = MetaConfig()); + void GatherNdGradInferMeta(const MetaTensor& x, const MetaTensor& index, const MetaTensor& out_grad, diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 60db5d342b8b371653317a1cb0736f56f57af293..298ad14f9e04b66147932d4e1960f6e3bb58c45c 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -21,6 +21,7 @@ limitations under the License.
*/ #include "paddle/phi/core/ddim.h" #include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/kernels/cpu/conv_util.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/common_shape.h" namespace phi { @@ -753,6 +754,82 @@ void CrossInferMeta(const MetaTensor& x, out->share_lod(x); } +void CrossEntropyWithSoftmaxInferMeta(const MetaTensor& logits, + const MetaTensor& label, + bool soft_label, + bool use_softmax, + bool numeric_stable_mode, + int ignore_index, + int axis, + MetaTensor* softmax, + MetaTensor* loss, + MetaConfig config) { + auto logits_dims = logits.dims(); + auto labels_dims = label.dims(); + auto logits_rank = logits_dims.size(); + PADDLE_ENFORCE_GE(axis, + -logits_rank, + phi::errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], " + "R is the rank of Input(Logits).")); + PADDLE_ENFORCE_LT(axis, + logits_rank, + phi::errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], " + "R is the rank of Input(Logits).")); + + axis = phi::funcs::CanonicalAxis(axis, logits_rank); + for (int i = 0; i < logits_rank; i++) { + if (i != axis) { + if (config.is_runtime || (logits_dims[i] > 0 && labels_dims[i] > 0)) { + PADDLE_ENFORCE_EQ(logits_dims[i], + labels_dims[i], + phi::errors::InvalidArgument( + "Input(Logits) and Input(Label) should in " + "same shape in dimensions except axis.")); + } + } + } + + if (axis != logits_rank - 1) { + PADDLE_ENFORCE_EQ( + numeric_stable_mode, + true, + phi::errors::InvalidArgument("Attr(axis) can only be -1 " + "when not in numeric_stable_mode.")); + } + + if (soft_label) { + if (config.is_runtime || (logits_dims[axis] > 0 && labels_dims[axis] > 0)) { + PADDLE_ENFORCE_EQ(logits_dims[axis], + labels_dims[axis], + phi::errors::InvalidArgument( + "If Attr(soft_label) == true, " + "the axis dimension of " + "Input(X) and Input(Label) should be equal.")); + } + } else { + if (config.is_runtime || labels_dims[axis] > 0) { + PADDLE_ENFORCE_EQ( + labels_dims[axis], + 1UL, + phi::errors::InvalidArgument("If Attr(soft_label) == false, " + "the axis dimension of " + "Input(Label) should be 1.")); + } + } + + softmax->set_dims(logits_dims); + softmax->set_dtype(logits.dtype()); + + logits_dims[axis] = 1; + loss->set_dims(logits_dims); + loss->set_dtype(logits.dtype()); + + softmax->share_lod(logits); + loss->share_lod(logits); +} + void DistInferMeta(const MetaTensor& x, const MetaTensor& y, float p, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 296c05756f29138e2c2209e389faf9a705e55e98..70c3c9dfe849dee15242674d70af95d1932f9e02 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -117,6 +117,17 @@ void CrossInferMeta(const MetaTensor& x, int axis, MetaTensor* out); +void CrossEntropyWithSoftmaxInferMeta(const MetaTensor& logits, + const MetaTensor& label, + bool soft_label, + bool use_softmax, + bool numeric_stable_mode, + int ignore_index, + int axis, + MetaTensor* softmax, + MetaTensor* loss, + MetaConfig config = MetaConfig()); + void DistInferMeta(const MetaTensor& x, const MetaTensor& y, float p, diff --git a/python/paddle/fluid/layers/loss.py b/python/paddle/fluid/layers/loss.py index a1cebc2f369bdab23e17bd0a06f16836d04acf1c..1efcbe4ee88712462c2feb725bcd00c7e648376c 100644 --- a/python/paddle/fluid/layers/loss.py +++ b/python/paddle/fluid/layers/loss.py @@ -21,7 +21,7 @@ from paddle.utils import deprecated from . 
import nn from .layer_function_generator import templatedoc from ..layer_helper import LayerHelper -from ..framework import Variable, _non_static_mode, static_only, _in_legacy_dygraph +from ..framework import Variable, _non_static_mode, static_only, _in_legacy_dygraph, in_dygraph_mode from .. import core from ..data_feeder import check_variable_and_dtype, check_type from ..param_attr import ParamAttr @@ -1267,10 +1267,15 @@ def softmax_with_cross_entropy(logits, ignore_index, 'numeric_stable_mode', numeric_stable_mode, 'axis', axis) else: - softmax, loss = _C_ops.softmax_with_cross_entropy( - logits, label, 'soft_label', soft_label, 'ignore_index', - ignore_index, 'numeric_stable_mode', numeric_stable_mode, - 'axis', axis) + if in_dygraph_mode(): + softmax, loss = _C_ops.final_state_cross_entropy_with_softmax( + logits, label, soft_label, True, numeric_stable_mode, + ignore_index, axis) + if _in_legacy_dygraph(): + softmax, loss = _C_ops.softmax_with_cross_entropy( + logits, label, 'soft_label', soft_label, 'ignore_index', + ignore_index, 'numeric_stable_mode', numeric_stable_mode, + 'axis', axis) if not return_softmax: return loss else: diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 4a771990d91e10f6b7013aa201c85fd6e4a9f3ef..81849606370d60a5cbead53d87498db61f90d763 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -969,6 +969,7 @@ set_tests_properties(test_nearest_interp_op PROPERTIES TIMEOUT 120) set_tests_properties(test_profiler PROPERTIES TIMEOUT 120) set_tests_properties(test_inplace_softmax_with_cross_entropy PROPERTIES TIMEOUT 120) set_tests_properties(test_cross_entropy2_op PROPERTIES TIMEOUT 120) +set_tests_properties(test_cross_entropy_loss PROPERTIES TIMEOUT 150) set_tests_properties(test_fetch_unmerged PROPERTIES TIMEOUT 120) set_tests_properties(test_gru_unit_op PROPERTIES TIMEOUT 120) set_tests_properties(test_activation_nn_grad PROPERTIES TIMEOUT 200) diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py index d3ed76e34a614d6ea3b9d9454a330cf619635224..4402d875a41f67cf098dc6ba8758658f20668add 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py @@ -21,6 +21,7 @@ import unittest from test_softmax_op import stable_softmax from test_softmax_with_cross_entropy_op import cross_entropy from paddle.fluid import Program, program_guard +from paddle.fluid.framework import _test_eager_guard def log_softmax(x, axis=-1): @@ -1447,6 +1448,43 @@ class CrossEntropyLoss(unittest.TestCase): self.assertTrue(np.allclose(static_ret, expected)) self.assertTrue(np.allclose(dy_ret_value, expected)) + def test_soft_1d_dygraph_final_state_api(self): + with _test_eager_guard(): + self.test_cross_entropy_loss_soft_1d() + self.test_cross_entropy_loss_soft_1d_weight() + self.test_cross_entropy_loss_soft_1d_mean() + self.test_cross_entropy_loss_soft_1d_weight_mean() + + # putting all the test cases in one test would fail + def test_soft_2d_dygraph_final_state_api(self): + with _test_eager_guard(): + self.test_cross_entropy_loss_soft_2d() + self.test_cross_entropy_loss_soft_2d_weight_mean() + + def test_other_dygraph_final_state_api(self): + with _test_eager_guard(): + self.test_cross_entropy_loss_1d_with_mean_ignore() + self.test_cross_entropy_loss_1d_with_mean_ignore_negative() + 
self.test_cross_entropy_loss_1d_with_weight_mean_ignore() + self.test_cross_entropy_loss_1d_with_weight_mean_ignore_exceedlabel( + ) + self.test_cross_entropy_loss_1d_with_weight_mean() + self.test_cross_entropy_loss_1d_with_weight_sum() + self.test_cross_entropy_loss_1d_with_weight_none() + self.test_cross_entropy_loss_1d_with_weight_none_func() + self.test_cross_entropy_loss_1d_mean() + self.test_cross_entropy_loss_1d_sum() + self.test_cross_entropy_loss_1d_none() + self.test_cross_entropy_loss_2d_with_weight_none() + self.test_cross_entropy_loss_2d_with_weight_axis_change_mean() + self.test_cross_entropy_loss_2d_with_weight_mean_ignore_exceedlabel( + ) + self.test_cross_entropy_loss_2d_with_weight_mean() + self.test_cross_entropy_loss_2d_with_weight_sum() + self.test_cross_entropy_loss_2d_none() + self.test_cross_entropy_loss_2d_mean() + self.test_cross_entropy_loss_2d_sum() + class TestCrossEntropyFAPIError(unittest.TestCase): def test_errors(self): diff --git a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py index 69f6a87dd9ed18909741fa6da0bd9acb6b6dd831..75d09e3df0c30361f53c7645ec6448afac00c99a 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py @@ -26,7 +26,6 @@ from test_softmax_op import stable_softmax def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1): if soft_label: return (-label * np.log(softmax)).sum(axis=axis, keepdims=True) - shape = softmax.shape axis %= len(shape) n = int(np.prod(shape[:axis])) @@ -43,6 +42,41 @@ def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1): return result.reshape(label.shape) +def python_api(logits, + label, + soft_label=False, + use_softmax=True, + numeric_stable_mode=True, + ignore_index=-100, + axis=-1): + # here we can only test paddle.nn.functional.softmax_with_cross_entropy, + # since paddle.nn.functional.cross_entropy contains other math ops + return paddle.nn.functional.softmax_with_cross_entropy( + logits, + label, + soft_label=soft_label, + ignore_index=ignore_index, + numeric_stable_mode=numeric_stable_mode, + return_softmax=use_softmax, + axis=axis) + + +def python_core_api_without_softmax(logits, + label, + soft_label=False, + use_softmax=False, + numeric_stable_mode=True, + ignore_index=-100, + axis=-1): + # the API paddle.nn.functional.softmax_with_cross_entropy cannot + # set use_softmax=False, so we call the core API directly + assert use_softmax is False + _, loss = paddle._C_ops.final_state_cross_entropy_with_softmax( + logits, label, soft_label, use_softmax, numeric_stable_mode, + ignore_index, axis) + return loss + + class TestSoftmaxWithCrossEntropyOp(OpTest): """ Test softmax with cross entropy operator with discrete one-hot labels. 
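The two wrappers above map the legacy operator onto Python entry points for OpTest's eager checks: `python_api` goes through `paddle.nn.functional.softmax_with_cross_entropy`, while `python_core_api_without_softmax` calls the final-state core op directly because the public API cannot express `use_softmax=False`. As a sanity check on the reference `cross_entropy` above, here is a minimal numpy-only sketch of the hard-label case along the last axis (the array values are illustrative, not taken from the test suite):

```python
import numpy as np

def stable_softmax(x):
    # subtract the row max first, mirroring the numerically stable kernel
    shifted = x - x.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)

logits = np.array([[2.0, 1.0, 0.1]])
label = np.array([[0]])                 # hard label, shape [N, 1]
probs = stable_softmax(logits)
loss = -np.log(probs[0, label[0, 0]])   # -log p(true class)
print(round(float(loss), 5))            # 0.41703
```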
@@ -50,6 +84,8 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = False self.soft_label = False # explicitly use float32 for ROCm, as MIOpen does not yet support float64 @@ -102,13 +138,27 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): self.attrs['axis'] = self.axis def test_check_output(self): + if self.python_api is not None: + self.check_output(check_eager=True) self.check_output() def test_check_grad(self): if core.is_compiled_with_rocm(): + if self.python_api is not None: + self.check_grad( + ["Logits"], + "Loss", + max_relative_error=5e-1, + check_eager=True) # HIP will have accuracy fail when using float32 in CPU place self.check_grad(["Logits"], "Loss", max_relative_error=5e-1) else: + if self.python_api is not None: + self.check_grad( + ["Logits"], + "Loss", + numeric_grad_delta=0.001, + check_eager=True) self.check_grad(["Logits"], "Loss", numeric_grad_delta=0.001) @@ -136,6 +186,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_1D( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = True self.shape = [13, 8] @@ -149,6 +201,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_1D( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = False self.shape = [13, 8] @@ -165,6 +219,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = True self.shape = [3, 5, 7, 11] @@ -178,6 +234,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis2( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = True self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -191,6 +249,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis3( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = True self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -204,6 +264,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis4( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = True self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -226,6 +288,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = 
["Loss"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -239,6 +303,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis2( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = False self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -252,6 +318,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis3( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = False self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -265,6 +333,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis4( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = False self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -287,6 +357,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = False self.soft_label = False self.shape = [13, 8] @@ -300,6 +372,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore_Axis( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = False self.soft_label = False self.shape = [13, 8] @@ -313,6 +387,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -326,6 +402,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore_Axis3( TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_core_api_without_softmax + self.python_out_sig = ["Loss"] self.numeric_stable_mode = True self.soft_label = False self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -343,6 +421,8 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore_Axis3( class TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -357,6 +437,8 @@ class TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp): class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = False self.soft_label = False self.shape = [3, 5, 7, 
11] @@ -394,9 +476,14 @@ class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp): self.attrs['axis'] = self.axis def test_check_output(self): + if self.python_api is not None: + self.check_output(atol=1e-2, check_eager=True) self.check_output(atol=1e-2) def test_check_grad(self): + if self.python_api is not None: + self.check_grad( + ["Logits"], "Loss", max_relative_error=0.1, check_eager=True) self.check_grad(["Logits"], "Loss", max_relative_error=0.1) @@ -404,6 +491,8 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16( TestSoftmaxWithCrossEntropyOpFp16): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -412,6 +501,9 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16( self.dtype = np.float16 def test_check_grad(self): + if self.python_api is not None: + self.check_grad( + ["Logits"], "Loss", max_relative_error=0.1, check_eager=True) self.check_grad(["Logits"], "Loss", max_relative_error=0.1) @@ -422,6 +514,8 @@ class TestSoftmaxWithCrossEntropyOp2(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = True self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -431,13 +525,23 @@ class TestSoftmaxWithCrossEntropyOp2(TestSoftmaxWithCrossEntropyOp): self.use_softmax = True def test_check_output(self): + if self.python_api is not None: + self.check_output(check_eager=True) self.check_output() def test_check_grad(self): if core.is_compiled_with_rocm(): # HIP will have accuracy fail when using float32 in CPU place + if self.python_api is not None: + self.check_grad( + ["Logits"], + "Loss", + max_relative_error=0.1, + check_eager=True) self.check_grad(["Logits"], "Loss", max_relative_error=0.1) else: + if self.python_api is not None: + self.check_grad(["Logits"], "Loss", check_eager=True) self.check_grad(["Logits"], "Loss") @@ -448,6 +552,8 @@ class TestSoftmaxWithCrossEntropyOp3(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = False self.soft_label = False self.shape = [41, 37] @@ -460,6 +566,8 @@ class TestSoftmaxWithCrossEntropyOp3(TestSoftmaxWithCrossEntropyOp): class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -477,6 +585,8 @@ class TestSoftmaxWithCrossEntropyOpAxis1(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -494,6 +604,8 @@ class TestSoftmaxWithCrossEntropyOpAxis2(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ 
-511,6 +623,8 @@ class TestSoftmaxWithCrossEntropyOpAxis3(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -528,6 +642,8 @@ class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -546,6 +662,8 @@ class TestSoftmaxWithCrossEntropyOpAxisDimEqualOne( def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64 @@ -559,6 +677,8 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1( TestSoftmaxWithCrossEntropyOpNoCudnnFp16): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -572,6 +692,8 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis2( TestSoftmaxWithCrossEntropyOpNoCudnnFp16): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -585,6 +707,8 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis3( TestSoftmaxWithCrossEntropyOpNoCudnnFp16): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -598,6 +722,8 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis1( TestSoftmaxWithCrossEntropyOp2): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = True self.shape = [3, 5, 7, 11] @@ -611,6 +737,8 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis2( TestSoftmaxWithCrossEntropyOp2): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = True self.shape = [3, 5, 7, 11] @@ -624,6 +752,8 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis3( TestSoftmaxWithCrossEntropyOp2): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = True self.shape = [3, 5, 7, 11] @@ -637,6 +767,8 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis4( TestSoftmaxWithCrossEntropyOp2): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = True self.shape = [3, 5, 7, 11] @@ -650,6 +782,8 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis1( TestSoftmaxWithCrossEntropyOp3): def initParams(self): self.op_type = 
"softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -663,6 +797,8 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis2( TestSoftmaxWithCrossEntropyOp3): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -676,6 +812,8 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis3( TestSoftmaxWithCrossEntropyOp3): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -689,6 +827,8 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis4( TestSoftmaxWithCrossEntropyOp3): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -706,6 +846,8 @@ class TestSoftmaxWithCrossEntropyOpBoundary0(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] @@ -724,6 +866,8 @@ class TestSoftmaxWithCrossEntropyOpBoundary1(TestSoftmaxWithCrossEntropyOp): def initParams(self): self.op_type = "softmax_with_cross_entropy" + self.python_api = python_api + self.python_out_sig = ["Loss", "Softmax"] self.numeric_stable_mode = True self.soft_label = False self.shape = [3, 5, 7, 11] diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 3748a5904ba9614cafafcddc557175ac07f43018..8a2b5cbb8b334590ad05140bfd40f5c54d752697 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1700,7 +1700,8 @@ def cross_entropy(input, (got nput_dims{}, label_dims{})'.format(input_dims, label_dims)) if input_dims - 1 == label_dims: label = paddle.unsqueeze(label, axis=axis) - if in_dynamic_mode(): + + if _non_static_mode(): if soft_label == False: valid_label = paddle.cast( label != ignore_index, dtype=label.dtype) * label @@ -1718,10 +1719,15 @@ def cross_entropy(input, ignore_index, 'numeric_stable_mode', True, 'axis', axis, 'use_softmax', use_softmax) else: - _, out = _C_ops.softmax_with_cross_entropy( - input, label, 'soft_label', soft_label, 'ignore_index', - ignore_index, 'numeric_stable_mode', True, 'axis', axis, - 'use_softmax', use_softmax) + if in_dygraph_mode(): + _, out = _C_ops.final_state_cross_entropy_with_softmax( + input, label, soft_label, use_softmax, True, ignore_index, + axis) + if _in_legacy_dygraph(): + _, out = _C_ops.softmax_with_cross_entropy( + input, label, 'soft_label', soft_label, 'ignore_index', + ignore_index, 'numeric_stable_mode', True, 'axis', axis, + 'use_softmax', use_softmax) if weight is not None: diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index b137399b71c88054b9c426c8d73bc178707adb3b..af4e7a5b3bb32175f7e504d7318c9a483ab91a97 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -382,6 +382,17 @@ func : cross backward : cross_grad +# Part of python API 
paddle.nn.functional.cross_entropy +- api : cross_entropy_with_softmax + args : (Tensor input, Tensor label, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis) + output : Tensor(softmax), Tensor(loss) + infer_meta : + func : CrossEntropyWithSoftmaxInferMeta + kernel : + func : cross_entropy_with_softmax + data_type : input + backward : cross_entropy_with_softmax_grad + - api : cumprod args : (Tensor x, int dim) output : Tensor(out) diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index c6951fa8fc1d4b9f25d48a28adba07e3c08fca25..f94d0a9e50523b1e35179a151fdb15d1e98110e9 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -223,6 +223,16 @@ kernel : func : cosh_grad +- backward_api : cross_entropy_with_softmax_grad + forward : cross_entropy_with_softmax (Tensor input, Tensor label, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis) -> Tensor(softmax), Tensor(loss) + args : (Tensor label, Tensor softmax, Tensor loss_grad, bool soft_label, bool use_softmax, bool numeric_stable_mode, int ignore_index, int axis) + output : Tensor(input_grad) + infer_meta : + func : CrossEntropyWithSoftmaxGradInferMeta + kernel : + func : cross_entropy_with_softmax_grad + data_type : softmax + - backward_api : cross_grad forward : cross (Tensor x, Tensor y, int axis = 9) -> Tensor(out) args : (Tensor x, Tensor y, Tensor out_grad, int axis)
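For readers new to the yaml code generation: the `cross_entropy_with_softmax` entry above generates the eager `final_state_cross_entropy_with_softmax` binding used by the Python changes, and `CrossEntropyWithSoftmaxInferMeta` fixes the output shapes: `softmax` keeps the logits shape, while `loss` collapses the class axis to 1. A hedged Python sketch of that shape rule (`infer_shapes` is an illustrative helper, not part of the patch):

```python
def infer_shapes(logits_shape, axis=-1):
    # CanonicalAxis: fold a negative axis into [0, rank)
    rank = len(logits_shape)
    axis = axis + rank if axis < 0 else axis
    softmax_shape = list(logits_shape)  # softmax keeps the input shape
    loss_shape = list(logits_shape)
    loss_shape[axis] = 1                # loss reduces the class axis to 1
    return softmax_shape, loss_shape

print(infer_shapes([3, 5, 7, 11]))     # ([3, 5, 7, 11], [3, 5, 7, 1])
```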
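Finally, a short usage sketch of the new dispatch path: in eager mode, `paddle.nn.functional.softmax_with_cross_entropy` now routes to the final-state op via the `in_dygraph_mode()` branch (this assumes a build in which the generated binding exists; the tensor values are illustrative):

```python
import paddle
import paddle.nn.functional as F

logits = paddle.to_tensor([[2.0, 1.0, 0.1]])
label = paddle.to_tensor([[0]], dtype='int64')
# Under in_dygraph_mode() this call reaches
# _C_ops.final_state_cross_entropy_with_softmax(...).
loss, softmax = F.softmax_with_cross_entropy(
    logits, label, return_softmax=True, axis=-1)
print(loss.shape, softmax.shape)  # [1, 1] [1, 3]
```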