未验证 提交 da10c5cf 编写于 作者: R ronnywang 提交者: GitHub

[ROCM] fix softmax_with_cross_entropy_op, test=develop (#31629)

上级 75433126
......@@ -452,12 +452,7 @@ struct HardLabelCrossEntropyFunctorWithIgnoreIdx {
// labels, loss view as [n, remain]
int idx_lbl = idx_n * remain + idx_remain;
if (idx_axis == ignore_idx_) {
loss_[idx_lbl] = 0;
return;
}
if (idx_axis == labels_[idx_lbl]) {
if (idx_axis == labels_[idx_lbl] && idx_axis != ignore_idx_) {
loss_[idx_lbl] = -log_on_device(logits_data_[idx]);
}
}
......@@ -732,7 +727,7 @@ static void SoftmaxWithCrossEntropyFusedKernel(
template <typename T>
static void CrossEntropyFusedKernel(const T* logits_data, const T* labels_data,
T* loss_data, int n, int d, int axis_dim,
cudaStream_t stream) {
gpuStream_t stream) {
constexpr int kMaxBlockDim = 512;
int block_dim = axis_dim >= kMaxBlockDim
? kMaxBlockDim
......@@ -792,11 +787,11 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
auto* softmax_out_data = softmax_out->mutable_data<T>(context.GetPlace());
auto* loss_data = loss->mutable_data<T>(context.GetPlace());
if (axis_dim == 1) {
math::SetConstant<platform::CUDADeviceContext, T> set_constant;
set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
if (axis_dim == 1) {
set_constant(context.cuda_device_context(), softmax_out,
static_cast<T>(1));
set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
return;
}
......
......@@ -116,7 +116,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_1D(
self.shape = [13, 8]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax"
......@@ -129,7 +129,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_1D(
self.shape = [13, 8]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax"
......@@ -145,7 +145,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D(
self.shape = [3, 5, 7, 11]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax"
......@@ -155,7 +155,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis2(
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = True
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 1
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
......@@ -168,7 +168,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis3(
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = True
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 2
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
......@@ -181,7 +181,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis4(
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = True
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 3
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
......@@ -206,7 +206,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D(
self.shape = [3, 5, 7, 11]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax"
......@@ -216,7 +216,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis2(
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 1
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
......@@ -229,7 +229,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis3(
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 2
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
......@@ -242,7 +242,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis4(
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 3
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
......@@ -267,7 +267,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore(
self.shape = [13, 8]
self.axis = -1
self.ignore_index = 2
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax"
......@@ -280,7 +280,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore_Axis(
self.shape = [13, 8]
self.axis = 1
self.ignore_index = 2
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax"
......@@ -293,7 +293,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore(
self.shape = [3, 5, 7, 11]
self.axis = -1
self.ignore_index = 2
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax"
......@@ -303,7 +303,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore_Axis3(
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 2
self.ignore_index = 2
self.shape = [3, 5, 7, 11]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册