Unverified commit 081e4307, authored by Zhang Zheng, committed by GitHub

Optimize perf of softmax_with_cross_entropy_bwd (#40643)

* Optimize perf of softmax_with_cross_entropy_bwd

* fix

* fix
Parent 1904572a
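This commit speeds up the hard-label backward pass of softmax_with_cross_entropy by letting the gradient kernel read the forward-pass Softmax output directly, instead of first copying that output into the gradient tensor and updating it in place. For reference (a standard identity, not part of the patch), the gradient of the loss with respect to logit z_i for a hard label y is:

\[
  \frac{\partial L}{\partial z_i}
    = \frac{\partial L}{\partial \ell}\,\bigl(p_i - \mathbf{1}[i = y]\bigr),
  \qquad p = \operatorname{softmax}(z),
\]

so the kernel only needs read access to p (the saved softmax) and the incoming loss gradient; priming logit_grad with a copy of the softmax is unnecessary on this path.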
@@ -760,8 +760,9 @@ static void SoftmaxWithCrossEntropyHardLabel(
 */
 template <typename T, typename LabelT>
 __global__ void SoftmaxWithCrossEntropyGradHardLabel(
-    T* logits_grad, const T* loss_grad, const LabelT* labels, const int64_t n,
-    const int64_t dim, const int64_t d, const int ignore_index) {
+    T* logits_grad, const T* loss_grad, const T* softmax, const LabelT* labels,
+    const int64_t n, const int64_t dim, const int64_t d,
+    const int ignore_index) {
   int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
   int64_t idx_n = idx / (d * dim);
   int64_t idx_dim = (idx / d) % dim;
@@ -773,10 +774,9 @@ __global__ void SoftmaxWithCrossEntropyGradHardLabel(
     if (lbl == ignore_index) {
       logits_grad[idx] = static_cast<T>(0.0);
     } else if (lbl == idx_dim) {
-      logits_grad[idx] =
-          (logits_grad[idx] - static_cast<T>(1.0)) * loss_grad[ids];
+      logits_grad[idx] = (softmax[idx] - static_cast<T>(1.0)) * loss_grad[ids];
     } else {
-      logits_grad[idx] *= loss_grad[ids];
+      logits_grad[idx] = softmax[idx] * loss_grad[ids];
     }
   }
 }
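The rewritten kernel no longer reads logits_grad as an input, so the launch site can skip priming it with the softmax values. A minimal CPU reference for the per-element rule the kernel implements (an illustrative sketch, not Paddle code: the function name and flat [n, dim] layout are assumptions, and the d/remain axis handling is reduced to the basic case):

#include <cstdint>
#include <vector>

// Reference for the hard-label gradient rule:
//   grad[i][j] = loss_grad[i] * (softmax[i][j] - (j == label[i] ? 1 : 0)),
// with rows whose label equals ignore_index zeroed out.
void SoftmaxCrossEntropyGradRef(const std::vector<float>& softmax,
                                const std::vector<float>& loss_grad,
                                const std::vector<int64_t>& labels,
                                int64_t n, int64_t dim, int64_t ignore_index,
                                std::vector<float>* logits_grad) {
  logits_grad->assign(n * dim, 0.0f);
  for (int64_t i = 0; i < n; ++i) {
    if (labels[i] == ignore_index) continue;  // gradient stays zero
    for (int64_t j = 0; j < dim; ++j) {
      float one_hot = (j == labels[i]) ? 1.0f : 0.0f;
      (*logits_grad)[i * dim + j] =
          (softmax[i * dim + j] - one_hot) * loss_grad[i];
    }
  }
}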
@@ -1395,11 +1395,20 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
     Tensor* logit_grad =
         context.Output<Tensor>(framework::GradVarName("Logits"));
     const Tensor* softmax = context.Input<Tensor>("Softmax");
-    if (logit_grad != softmax) {
+    auto stream = context.cuda_device_context().stream();
+    auto ignore_index = context.Attr<int>("ignore_index");
+    auto use_softmax = context.Attr<bool>("use_softmax");
+
+    T* logit_grad_data = nullptr;
+    bool copy_flag = (logit_grad != softmax && (!use_softmax || soft_label));
+    if (copy_flag) {
       framework::TensorCopy(*softmax, context.GetPlace(),
                             context.device_context(), logit_grad);
+      logit_grad_data = logit_grad->template data<T>();
+    } else {
+      logit_grad_data =
+          logit_grad->template mutable_data<T>(context.GetPlace());
     }
-    T* logit_grad_data = logit_grad->template data<T>();

     const int rank = logit_grad->dims().size();
     const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
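The host-side change above makes the softmax-to-logit_grad TensorCopy conditional: it now runs only on the paths whose kernels still update logit_grad in place (the soft_label flag is read earlier, in an unchanged part of the function, and in every case the copy is also skipped when logit_grad and softmax alias the same tensor):

  use_softmax | soft_label | copy softmax into logit_grad?
  ------------+------------+-----------------------------------------------
  false       | any        | yes (input already is the softmax; in-place)
  true        | true       | yes (SoftCrossEntropyGradientKernel is in-place)
  true        | false      | no  (hard-label kernel reads softmax directly)

On the hard-label path, logit_grad is now allocated with mutable_data and written from scratch, saving one full device-to-device copy of the softmax tensor per backward call.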
@@ -1414,9 +1423,6 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 #else
     int block = 512;
 #endif
-    auto stream = context.cuda_device_context().stream();
-    auto ignore_index = context.Attr<int>("ignore_index");
-    auto use_softmax = context.Attr<bool>("use_softmax");

     // do not with softmax op, and input is softmax
     if (!use_softmax) {
@@ -1451,11 +1457,12 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
       SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
           logit_grad_data, loss_grad_data, label_data, n, d, remain);
     } else {
+      const T* softmax_data = softmax->template data<T>();
       const auto* label_data = labels.template data<LabelT>();
       int grid = (n * d + block - 1) / block;
       SoftmaxWithCrossEntropyGradHardLabel<T><<<grid, block, 0, stream>>>(
-          logit_grad_data, loss_grad_data, label_data, n, d / remain, remain,
-          ignore_index);
+          logit_grad_data, loss_grad_data, softmax_data, label_data, n,
+          d / remain, remain, ignore_index);
     }
   }
 };
...