未验证 提交 bc385a29 编写于 作者: C Chen Weihang 提交者: GitHub

fix softmax_with_cross_entropy_fix bug, test=develop (#21810) (#22183)

上级 515b206d
...@@ -415,6 +415,12 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> { ...@@ -415,6 +415,12 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank); const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = logits->dims()[axis]; int axis_dim = logits->dims()[axis];
const int n = SizeToAxis(axis, logits->dims());
const int d = SizeFromAxis(axis, logits->dims());
auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
auto* loss_data = loss->mutable_data<T>(context.GetPlace());
if (axis_dim == 1) { if (axis_dim == 1) {
math::SetConstant<platform::CUDADeviceContext, T> set_constant; math::SetConstant<platform::CUDADeviceContext, T> set_constant;
set_constant(context.cuda_device_context(), softmax, static_cast<T>(1)); set_constant(context.cuda_device_context(), softmax, static_cast<T>(1));
...@@ -422,12 +428,6 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> { ...@@ -422,12 +428,6 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
return; return;
} }
const int n = SizeToAxis(axis, logits->dims());
const int d = SizeFromAxis(axis, logits->dims());
auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
auto* loss_data = loss->mutable_data<T>(context.GetPlace());
auto soft_label = context.Attr<bool>("soft_label"); auto soft_label = context.Attr<bool>("soft_label");
auto ignore_index = context.Attr<int>("ignore_index"); auto ignore_index = context.Attr<int>("ignore_index");
......
...@@ -280,6 +280,23 @@ class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp): ...@@ -280,6 +280,23 @@ class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp):
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
class TestSoftmaxWithCrossEntropyOpAxisDimEqualOne(
TestSoftmaxWithCrossEntropyOp):
"""
Test softmax with cross entropy operator with discrete one-hot labels.
Given axis == -1, where the size of that axis is 1.
"""
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.axis = -1
self.ignore_index = -1
# The last dimension (the one selected by axis = -1) has size 1.
self.shape = [3, 5, 7, 1]
class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1( class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1(
TestSoftmaxWithCrossEntropyOpNoCudnnFp16): TestSoftmaxWithCrossEntropyOpNoCudnnFp16):
def initParams(self): def initParams(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册