From 081e4307b47d2c9e119b470d274886188a14391a Mon Sep 17 00:00:00 2001
From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Date: Fri, 18 Mar 2022 09:40:54 +0800
Subject: [PATCH] Optimize perf of softmax_with_cross_entropy_bwd (#40643)

* Optimize perf of softmax_with_cross_entropy_bwd

* fix

* fix
---
 .../softmax_with_cross_entropy_op.cu          | 31 ++++++++++++-------
 1 file changed, 19 insertions(+), 12 deletions(-)

diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
index 19a395e723..41545a1ca2 100644
--- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
+++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
@@ -760,8 +760,9 @@ static void SoftmaxWithCrossEntropyHardLabel(
 */
 template <typename T, typename LabelT>
 __global__ void SoftmaxWithCrossEntropyGradHardLabel(
-    T* logits_grad, const T* loss_grad, const LabelT* labels, const int64_t n,
-    const int64_t dim, const int64_t d, const int ignore_index) {
+    T* logits_grad, const T* loss_grad, const T* softmax, const LabelT* labels,
+    const int64_t n, const int64_t dim, const int64_t d,
+    const int ignore_index) {
   int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
   int64_t idx_n = idx / (d * dim);
   int64_t idx_dim = (idx / d) % dim;
@@ -773,10 +774,9 @@ __global__ void SoftmaxWithCrossEntropyGradHardLabel(
     if (lbl == ignore_index) {
      logits_grad[idx] = static_cast<T>(0.0);
     } else if (lbl == idx_dim) {
-      logits_grad[idx] =
-          (logits_grad[idx] - static_cast<T>(1.0)) * loss_grad[ids];
+      logits_grad[idx] = (softmax[idx] - static_cast<T>(1.0)) * loss_grad[ids];
     } else {
-      logits_grad[idx] *= loss_grad[ids];
+      logits_grad[idx] = softmax[idx] * loss_grad[ids];
     }
   }
 }
@@ -1395,11 +1395,20 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
     Tensor* logit_grad =
         context.Output<Tensor>(framework::GradVarName("Logits"));
     const Tensor* softmax = context.Input<Tensor>("Softmax");
-    if (logit_grad != softmax) {
+    auto stream = context.cuda_device_context().stream();
+    auto ignore_index = context.Attr<int>("ignore_index");
+    auto use_softmax = context.Attr<bool>("use_softmax");
+
+    T* logit_grad_data = nullptr;
+    bool copy_flag = (logit_grad != softmax && (!use_softmax || soft_label));
+    if (copy_flag) {
       framework::TensorCopy(*softmax, context.GetPlace(),
                             context.device_context(), logit_grad);
+      logit_grad_data = logit_grad->template data<T>();
+    } else {
+      logit_grad_data =
+          logit_grad->template mutable_data<T>(context.GetPlace());
     }
-    T* logit_grad_data = logit_grad->template data<T>();
 
     const int rank = logit_grad->dims().size();
     const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
@@ -1414,9 +1423,6 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 #else
     int block = 512;
 #endif
-    auto stream = context.cuda_device_context().stream();
-    auto ignore_index = context.Attr<int>("ignore_index");
-    auto use_softmax = context.Attr<bool>("use_softmax");
 
     // do not with softmax op, and input is softmax
     if (!use_softmax) {
@@ -1451,11 +1457,12 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
       SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
           logit_grad_data, loss_grad_data, label_data, n, d, remain);
     } else {
+      const T* softmax_data = softmax->template data<T>();
       const auto* label_data = labels.template data<LabelT>();
       int grid = (n * d + block - 1) / block;
       SoftmaxWithCrossEntropyGradHardLabel<T, LabelT><<<grid, block, 0, stream>>>(
-          logit_grad_data, loss_grad_data, label_data, n, d / remain, remain,
-          ignore_index);
+          logit_grad_data, loss_grad_data, softmax_data, label_data, n,
+          d / remain, remain, ignore_index);
     }
   }
 };
-- 
GitLab
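
Note (not part of the patch): the change above makes the hard-label backward kernel read the saved forward "Softmax" tensor as a separate input and write into a freshly allocated logits_grad buffer, so the use_softmax hard-label path no longer needs the device-to-device TensorCopy of softmax into the gradient buffer before the launch. The standalone CUDA sketch below illustrates the same gradient rule in isolation (dL/dlogit = softmax for non-target classes, softmax - 1 for the target class, scaled by the incoming loss gradient, with ignore_index masking). The kernel name HardLabelGradSketch, the float/int64 types, the toy shapes with the inner dimension d collapsed to 1, and the host driver are illustrative assumptions only, not Paddle code.

// Minimal sketch of the hard-label softmax-with-cross-entropy backward rule,
// reading softmax directly instead of a pre-copied gradient buffer.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

__global__ void HardLabelGradSketch(float* logits_grad, const float* loss_grad,
                                    const float* softmax, const int64_t* labels,
                                    int64_t n, int64_t dim, int ignore_index) {
  // One thread per (sample, class) element; layout is [n, dim] with d == 1.
  int64_t idx = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (idx >= n * dim) return;
  int64_t row = idx / dim;  // sample index
  int64_t col = idx % dim;  // class index
  int64_t lbl = labels[row];
  if (lbl == ignore_index) {
    logits_grad[idx] = 0.f;                                    // masked sample
  } else if (col == lbl) {
    logits_grad[idx] = (softmax[idx] - 1.f) * loss_grad[row];  // dL/dz = p - 1
  } else {
    logits_grad[idx] = softmax[idx] * loss_grad[row];          // dL/dz = p
  }
}

int main() {
  const int64_t n = 2, dim = 3;
  std::vector<float> softmax = {0.7f, 0.2f, 0.1f, 0.3f, 0.3f, 0.4f};
  std::vector<float> loss_grad = {1.f, 1.f};
  std::vector<int64_t> labels = {0, 2};
  float *d_grad, *d_softmax, *d_loss;
  int64_t* d_labels;
  cudaMalloc(&d_grad, n * dim * sizeof(float));
  cudaMalloc(&d_softmax, n * dim * sizeof(float));
  cudaMalloc(&d_loss, n * sizeof(float));
  cudaMalloc(&d_labels, n * sizeof(int64_t));
  cudaMemcpy(d_softmax, softmax.data(), n * dim * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_loss, loss_grad.data(), n * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_labels, labels.data(), n * sizeof(int64_t),
             cudaMemcpyHostToDevice);
  HardLabelGradSketch<<<1, 256>>>(d_grad, d_loss, d_softmax, d_labels, n, dim,
                                  /*ignore_index=*/-100);
  cudaDeviceSynchronize();
  std::vector<float> grad(n * dim);
  cudaMemcpy(grad.data(), d_grad, n * dim * sizeof(float),
             cudaMemcpyDeviceToHost);
  for (float g : grad) printf("%f\n", g);  // expect p - onehot per element
  cudaFree(d_grad); cudaFree(d_softmax); cudaFree(d_loss); cudaFree(d_labels);
  return 0;
}

In the patched kernel this is the same computation, with the grid/block mapping over (n, dim, d) and the gradient written via mutable_data instead of overwriting a copied softmax, which removes one full read-modify-write pass over the logits-sized tensor from the backward path.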