diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index 19b4698aca8a80fd4a0845dee1122cafb79c6b87..12b64052a7cd63be5bcd6be7c313111fb0727b5f 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -10,6 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include #include "paddle/fluid/operators/math/cross_entropy.h" +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/softmax_with_cross_entropy_op.h" #include "paddle/fluid/platform/for_range.h" @@ -309,12 +310,6 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx { int ignore_idx_; }; -template -static __global__ void SetSoftmaxToOneWhenFeatureSizeIsOne(T* out, int n) { - auto idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < n) out[idx] = static_cast(1); -} - template static void HardLabelSoftmaxWithCrossEntropy( const platform::CUDADeviceContext& ctx, const T* logits_data, @@ -354,13 +349,6 @@ static void HardLabelSoftmaxWithCrossEntropy( CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8); CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4); CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2); - case 1: - SetSoftmaxToOneWhenFeatureSizeIsOne<<<(grid_dim + kMaxBlockDim - 1) / - kMaxBlockDim, - kMaxBlockDim, 0, stream>>>( - softmax_data, grid_dim); - cudaMemsetAsync(loss_data, 0, grid_dim * sizeof(T), stream); - break; default: PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op"); break; @@ -401,13 +389,6 @@ static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data, CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8); CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4); CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2); - case 1: - SetSoftmaxToOneWhenFeatureSizeIsOne<<<(grid_dim + kMaxBlockDim - 1) / - kMaxBlockDim, - kMaxBlockDim, 0, stream>>>( - softmax_data, n); - cudaMemsetAsync(loss_data, 0, grid_dim * sizeof(T), stream); - break; default: PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op"); break; @@ -431,6 +412,13 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { const int axis = CanonicalAxis(context.Attr("axis"), rank); int axis_dim = logits->dims()[axis]; + if (axis_dim == 1) { + math::SetConstant set_constant; + set_constant(context.cuda_device_context(), softmax, static_cast(1)); + set_constant(context.cuda_device_context(), loss, static_cast(0)); + return; + } + const int n = SizeToAxis(axis, logits->dims()); const int d = SizeFromAxis(axis, logits->dims());