diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index 98e0464134f05bff46b8962a13a01c8761d7c9b3..8f725be665b38ac75c6e0e427fef6403fe8cabd0 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -419,6 +419,12 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { const int axis = CanonicalAxis(context.Attr("axis"), rank); int axis_dim = logits->dims()[axis]; + const int n = SizeToAxis(axis, logits->dims()); + const int d = SizeFromAxis(axis, logits->dims()); + + auto* softmax_data = softmax->mutable_data(context.GetPlace()); + auto* loss_data = loss->mutable_data(context.GetPlace()); + if (axis_dim == 1) { math::SetConstant set_constant; set_constant(context.cuda_device_context(), softmax, static_cast(1)); @@ -426,12 +432,6 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { return; } - const int n = SizeToAxis(axis, logits->dims()); - const int d = SizeFromAxis(axis, logits->dims()); - - auto* softmax_data = softmax->mutable_data(context.GetPlace()); - auto* loss_data = loss->mutable_data(context.GetPlace()); - auto soft_label = context.Attr("soft_label"); auto ignore_index = context.Attr("ignore_index"); diff --git a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py index ce28253491c5050602bb768c4249f6eb2f75bd3c..df2a0a523ad1ef05b462c8f8044c83e76d91f8a3 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py @@ -280,6 +280,23 @@ class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp): self.shape = [3, 5, 7, 11] +class TestSoftmaxWithCrossEntropyOpAxisDimEqualOne( + TestSoftmaxWithCrossEntropyOp): + """ + Test softmax with cross entropy operator 
with discrete one-hot labels. + Given axis = -1 and a softmax axis of size 1 + """ + + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = False + self.dtype = np.float64 + self.axis = -1 + self.ignore_index = -1 + self.shape = [3, 5, 7, 1] + + class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1( TestSoftmaxWithCrossEntropyOpNoCudnnFp16): def initParams(self):