From 2b941736f31bfd5fc6891d329df0e01c75928fe4 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Fri, 20 Dec 2019 14:27:07 +0800 Subject: [PATCH] fix softmax_with_cross_entropy_fix bug, test=develop (#21810) --- .../operators/softmax_with_cross_entropy_op.cu | 12 ++++++------ .../test_softmax_with_cross_entropy_op.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index 98e0464134..8f725be665 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -419,6 +419,12 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { const int axis = CanonicalAxis(context.Attr("axis"), rank); int axis_dim = logits->dims()[axis]; + const int n = SizeToAxis(axis, logits->dims()); + const int d = SizeFromAxis(axis, logits->dims()); + + auto* softmax_data = softmax->mutable_data(context.GetPlace()); + auto* loss_data = loss->mutable_data(context.GetPlace()); + if (axis_dim == 1) { math::SetConstant set_constant; set_constant(context.cuda_device_context(), softmax, static_cast(1)); @@ -426,12 +432,6 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { return; } - const int n = SizeToAxis(axis, logits->dims()); - const int d = SizeFromAxis(axis, logits->dims()); - - auto* softmax_data = softmax->mutable_data(context.GetPlace()); - auto* loss_data = loss->mutable_data(context.GetPlace()); - auto soft_label = context.Attr("soft_label"); auto ignore_index = context.Attr("ignore_index"); diff --git a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py index ce28253491..df2a0a523a 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py +++ 
b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py @@ -280,6 +280,23 @@ class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp): self.shape = [3, 5, 7, 11] +class TestSoftmaxWithCrossEntropyOpAxisDimEqualOne( + TestSoftmaxWithCrossEntropyOp): + """ + Test softmax with cross entropy operator with discrete one-hot labels. + Given axis = -1 and the size of that axis equal to 1 + """ + + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = False + self.dtype = np.float64 + self.axis = -1 + self.ignore_index = -1 + self.shape = [3, 5, 7, 1] + + class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1( TestSoftmaxWithCrossEntropyOpNoCudnnFp16): def initParams(self): -- GitLab