未验证 提交 da10c5cf 编写于 作者: R ronnywang 提交者: GitHub

[ROCM] fix softmax_with_cross_entropy_op, test=develop (#31629)

上级 75433126
...@@ -452,12 +452,7 @@ struct HardLabelCrossEntropyFunctorWithIgnoreIdx { ...@@ -452,12 +452,7 @@ struct HardLabelCrossEntropyFunctorWithIgnoreIdx {
// labels, loss view as [n, remain] // labels, loss view as [n, remain]
int idx_lbl = idx_n * remain + idx_remain; int idx_lbl = idx_n * remain + idx_remain;
if (idx_axis == ignore_idx_) { if (idx_axis == labels_[idx_lbl] && idx_axis != ignore_idx_) {
loss_[idx_lbl] = 0;
return;
}
if (idx_axis == labels_[idx_lbl]) {
loss_[idx_lbl] = -log_on_device(logits_data_[idx]); loss_[idx_lbl] = -log_on_device(logits_data_[idx]);
} }
} }
...@@ -732,7 +727,7 @@ static void SoftmaxWithCrossEntropyFusedKernel( ...@@ -732,7 +727,7 @@ static void SoftmaxWithCrossEntropyFusedKernel(
template <typename T> template <typename T>
static void CrossEntropyFusedKernel(const T* logits_data, const T* labels_data, static void CrossEntropyFusedKernel(const T* logits_data, const T* labels_data,
T* loss_data, int n, int d, int axis_dim, T* loss_data, int n, int d, int axis_dim,
cudaStream_t stream) { gpuStream_t stream) {
constexpr int kMaxBlockDim = 512; constexpr int kMaxBlockDim = 512;
int block_dim = axis_dim >= kMaxBlockDim int block_dim = axis_dim >= kMaxBlockDim
? kMaxBlockDim ? kMaxBlockDim
...@@ -792,11 +787,11 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> { ...@@ -792,11 +787,11 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
auto* softmax_out_data = softmax_out->mutable_data<T>(context.GetPlace()); auto* softmax_out_data = softmax_out->mutable_data<T>(context.GetPlace());
auto* loss_data = loss->mutable_data<T>(context.GetPlace()); auto* loss_data = loss->mutable_data<T>(context.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> set_constant;
set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
if (axis_dim == 1) { if (axis_dim == 1) {
math::SetConstant<platform::CUDADeviceContext, T> set_constant;
set_constant(context.cuda_device_context(), softmax_out, set_constant(context.cuda_device_context(), softmax_out,
static_cast<T>(1)); static_cast<T>(1));
set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
return; return;
} }
......
...@@ -116,7 +116,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_1D( ...@@ -116,7 +116,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_1D(
self.shape = [13, 8] self.shape = [13, 8]
self.axis = -1 self.axis = -1
self.ignore_index = -1 self.ignore_index = -1
self.dtype = np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax" self.softmax_switch = False #default is true, means "with softmax"
...@@ -129,7 +129,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_1D( ...@@ -129,7 +129,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_1D(
self.shape = [13, 8] self.shape = [13, 8]
self.axis = -1 self.axis = -1
self.ignore_index = -1 self.ignore_index = -1
self.dtype = np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax" self.softmax_switch = False #default is true, means "with softmax"
...@@ -145,7 +145,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D( ...@@ -145,7 +145,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D(
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
self.axis = -1 self.axis = -1
self.ignore_index = -1 self.ignore_index = -1
self.dtype = np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax" self.softmax_switch = False #default is true, means "with softmax"
...@@ -155,7 +155,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis2( ...@@ -155,7 +155,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis2(
self.op_type = "softmax_with_cross_entropy" self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True self.numeric_stable_mode = True
self.soft_label = True self.soft_label = True
self.dtype = np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 1 self.axis = 1
self.ignore_index = -1 self.ignore_index = -1
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
...@@ -168,7 +168,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis3( ...@@ -168,7 +168,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis3(
self.op_type = "softmax_with_cross_entropy" self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True self.numeric_stable_mode = True
self.soft_label = True self.soft_label = True
self.dtype = np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 2 self.axis = 2
self.ignore_index = -1 self.ignore_index = -1
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
...@@ -181,7 +181,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis4( ...@@ -181,7 +181,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis4(
self.op_type = "softmax_with_cross_entropy" self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True self.numeric_stable_mode = True
self.soft_label = True self.soft_label = True
self.dtype = np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 3 self.axis = 3
self.ignore_index = -1 self.ignore_index = -1
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
...@@ -206,7 +206,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D( ...@@ -206,7 +206,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D(
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
self.axis = -1 self.axis = -1
self.ignore_index = -1 self.ignore_index = -1
self.dtype = np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax" self.softmax_switch = False #default is true, means "with softmax"
...@@ -216,7 +216,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis2( ...@@ -216,7 +216,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis2(
self.op_type = "softmax_with_cross_entropy" self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True self.numeric_stable_mode = True
self.soft_label = False self.soft_label = False
self.dtype = np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 1 self.axis = 1
self.ignore_index = -1 self.ignore_index = -1
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
...@@ -229,7 +229,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis3( ...@@ -229,7 +229,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis3(
self.op_type = "softmax_with_cross_entropy" self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True self.numeric_stable_mode = True
self.soft_label = False self.soft_label = False
self.dtype = np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 2 self.axis = 2
self.ignore_index = -1 self.ignore_index = -1
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
...@@ -242,7 +242,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis4( ...@@ -242,7 +242,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis4(
self.op_type = "softmax_with_cross_entropy" self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True self.numeric_stable_mode = True
self.soft_label = False self.soft_label = False
self.dtype = np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 3 self.axis = 3
self.ignore_index = -1 self.ignore_index = -1
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
...@@ -267,7 +267,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore( ...@@ -267,7 +267,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore(
self.shape = [13, 8] self.shape = [13, 8]
self.axis = -1 self.axis = -1
self.ignore_index = 2 self.ignore_index = 2
self.dtype = np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax" self.softmax_switch = False #default is true, means "with softmax"
...@@ -280,7 +280,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore_Axis( ...@@ -280,7 +280,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore_Axis(
self.shape = [13, 8] self.shape = [13, 8]
self.axis = 1 self.axis = 1
self.ignore_index = 2 self.ignore_index = 2
self.dtype = np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax" self.softmax_switch = False #default is true, means "with softmax"
...@@ -293,7 +293,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore( ...@@ -293,7 +293,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore(
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
self.axis = -1 self.axis = -1
self.ignore_index = 2 self.ignore_index = 2
self.dtype = np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax" self.softmax_switch = False #default is true, means "with softmax"
...@@ -303,7 +303,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore_Axis3( ...@@ -303,7 +303,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore_Axis3(
self.op_type = "softmax_with_cross_entropy" self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True self.numeric_stable_mode = True
self.soft_label = False self.soft_label = False
self.dtype = np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 2 self.axis = 2
self.ignore_index = 2 self.ignore_index = 2
self.shape = [3, 5, 7, 11] self.shape = [3, 5, 7, 11]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册