Unverified commit 081e4307, authored by Zhang Zheng, committed by GitHub

Optimize perf of softmax_with_cross_entropy_bwd (#40643)

* Optimize perf of softmax_with_cross_entropy_bwd

* fix

* fix
Parent: 1904572a
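This change lets the hard-label backward kernel read the saved softmax output directly instead of requiring the caller to copy it into logit_grad first, which removes a full-tensor TensorCopy from that path. For reference, the gradient the kernel computes, written in notation introduced here rather than taken from the source (p is the softmax output, g the incoming loss gradient, y_i the hard label of sample i):

% Hard-label gradient of softmax cross entropy, as implemented by
% SoftmaxWithCrossEntropyGradHardLabel below:
\frac{\partial L}{\partial z_{i,j}} =
  \begin{cases}
    0                   & \text{if } y_i = \texttt{ignore\_index}, \\
    (p_{i,j} - 1)\, g_i & \text{if } j = y_i, \\
    p_{i,j}\, g_i       & \text{otherwise.}
  \end{cases}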
@@ -760,8 +760,9 @@ static void SoftmaxWithCrossEntropyHardLabel(
  */
 template <typename T, typename LabelT>
 __global__ void SoftmaxWithCrossEntropyGradHardLabel(
-    T* logits_grad, const T* loss_grad, const LabelT* labels, const int64_t n,
-    const int64_t dim, const int64_t d, const int ignore_index) {
+    T* logits_grad, const T* loss_grad, const T* softmax, const LabelT* labels,
+    const int64_t n, const int64_t dim, const int64_t d,
+    const int ignore_index) {
   int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
   int64_t idx_n = idx / (d * dim);
   int64_t idx_dim = (idx / d) % dim;
@@ -773,10 +774,9 @@ __global__ void SoftmaxWithCrossEntropyGradHardLabel(
     if (lbl == ignore_index) {
       logits_grad[idx] = static_cast<T>(0.0);
     } else if (lbl == idx_dim) {
-      logits_grad[idx] =
-          (logits_grad[idx] - static_cast<T>(1.0)) * loss_grad[ids];
+      logits_grad[idx] = (softmax[idx] - static_cast<T>(1.0)) * loss_grad[ids];
     } else {
-      logits_grad[idx] *= loss_grad[ids];
+      logits_grad[idx] = softmax[idx] * loss_grad[ids];
     }
   }
 }
@@ -1395,11 +1395,20 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
     Tensor* logit_grad =
         context.Output<Tensor>(framework::GradVarName("Logits"));
     const Tensor* softmax = context.Input<Tensor>("Softmax");
-    if (logit_grad != softmax) {
+    auto stream = context.cuda_device_context().stream();
+    auto ignore_index = context.Attr<int>("ignore_index");
+    auto use_softmax = context.Attr<bool>("use_softmax");
+
+    T* logit_grad_data = nullptr;
+    bool copy_flag = (logit_grad != softmax && (!use_softmax || soft_label));
+    if (copy_flag) {
       framework::TensorCopy(*softmax, context.GetPlace(),
                             context.device_context(), logit_grad);
+      logit_grad_data = logit_grad->template data<T>();
+    } else {
+      logit_grad_data =
+          logit_grad->template mutable_data<T>(context.GetPlace());
     }
-    T* logit_grad_data = logit_grad->template data<T>();

     const int rank = logit_grad->dims().size();
     const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
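The copy_flag condition above keeps the TensorCopy of the softmax output into logit_grad only for the branches that still update logit_grad in place (the soft_label path and the !use_softmax path); the hard-label path now only needs logit_grad allocated as an output buffer. Assuming the standard soft-label cross-entropy gradient (q denotes the soft target distribution; notation introduced here, not taken from the source), the in-place update on that path is why the softmax values must already reside in logit_grad:

% Soft-label gradient, standard form:
\frac{\partial L}{\partial z_{i,j}} = \left(p_{i,j} - q_{i,j}\right)\, g_i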
@@ -1414,9 +1423,6 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 #else
     int block = 512;
 #endif
-    auto stream = context.cuda_device_context().stream();
-    auto ignore_index = context.Attr<int>("ignore_index");
-    auto use_softmax = context.Attr<bool>("use_softmax");

     // do not with softmax op, and input is softmax
     if (!use_softmax) {
@@ -1451,11 +1457,12 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
       SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
           logit_grad_data, loss_grad_data, label_data, n, d, remain);
     } else {
+      const T* softmax_data = softmax->template data<T>();
       const auto* label_data = labels.template data<LabelT>();
       int grid = (n * d + block - 1) / block;
       SoftmaxWithCrossEntropyGradHardLabel<T><<<grid, block, 0, stream>>>(
-          logit_grad_data, loss_grad_data, label_data, n, d / remain, remain,
-          ignore_index);
+          logit_grad_data, loss_grad_data, softmax_data, label_data, n,
+          d / remain, remain, ignore_index);
     }
   }
 };
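To make the new kernel contract easier to follow, here is a minimal standalone sketch. It is not the Paddle source: it hard-codes float data, int labels, and d == 1 (a plain [n, dim] problem), and the kernel and variable names (HardLabelGradSketch and friends) are illustrative only.

// Minimal standalone sketch of the hard-label gradient, assuming float data,
// int labels, and d == 1; names and the tiny harness are illustrative, not
// the Paddle implementation.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void HardLabelGradSketch(float* logits_grad, const float* loss_grad,
                                    const float* softmax, const int* labels,
                                    int64_t n, int64_t dim, int ignore_index) {
  int64_t idx = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (idx >= n * dim) return;
  int64_t row = idx / dim;  // sample index
  int64_t col = idx % dim;  // class index
  int lbl = labels[row];
  if (lbl == ignore_index) {
    logits_grad[idx] = 0.0f;  // ignored sample: zero gradient
  } else if (col == lbl) {
    logits_grad[idx] = (softmax[idx] - 1.0f) * loss_grad[row];  // target class
  } else {
    logits_grad[idx] = softmax[idx] * loss_grad[row];  // non-target classes
  }
}

int main() {
  const int64_t n = 2, dim = 3;
  const int ignore_index = -100;
  float h_softmax[n * dim] = {0.7f, 0.2f, 0.1f, 0.1f, 0.3f, 0.6f};
  float h_loss_grad[n] = {1.0f, 1.0f};
  int h_labels[n] = {0, ignore_index};  // second sample is ignored
  float h_grad[n * dim];

  float *d_grad, *d_loss, *d_soft;
  int* d_lbl;
  cudaMalloc(&d_grad, sizeof(h_grad));
  cudaMalloc(&d_loss, sizeof(h_loss_grad));
  cudaMalloc(&d_soft, sizeof(h_softmax));
  cudaMalloc(&d_lbl, sizeof(h_labels));
  cudaMemcpy(d_loss, h_loss_grad, sizeof(h_loss_grad), cudaMemcpyHostToDevice);
  cudaMemcpy(d_soft, h_softmax, sizeof(h_softmax), cudaMemcpyHostToDevice);
  cudaMemcpy(d_lbl, h_labels, sizeof(h_labels), cudaMemcpyHostToDevice);

  const int block = 128;
  const int grid = static_cast<int>((n * dim + block - 1) / block);
  HardLabelGradSketch<<<grid, block>>>(d_grad, d_loss, d_soft, d_lbl, n, dim,
                                       ignore_index);
  cudaMemcpy(h_grad, d_grad, sizeof(h_grad), cudaMemcpyDeviceToHost);

  // Expected output: -0.30 0.20 0.10 0.00 0.00 0.00
  for (int64_t i = 0; i < n * dim; ++i) printf("%.2f ", h_grad[i]);
  printf("\n");
  cudaFree(d_grad);
  cudaFree(d_loss);
  cudaFree(d_soft);
  cudaFree(d_lbl);
  return 0;
}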