未验证 提交 2ea15fc9 编写于 作者: Z Zhang Zheng 提交者: GitHub

Optimize performance of softmax_bwd when axis!=-1 (#38609)

* Optimize performance of softmax_bwd when axis!=-1

* fix

* fix

* fix

* fix
上级 a1174973
......@@ -584,6 +584,43 @@ __global__ void NormalSoftmaxForward(T* output, const T* input, int high_dim,
}
}
template <typename T, typename AccT,
template <typename, typename> class Functor>
__global__ void NormalSoftmaxBackward(T* input_grad, const T* output_grad,
const T* output, int high_dim,
int mid_dim, int low_dim) {
using kMode = kps::details::ReduceMode;
const int high_stride = mid_dim * low_dim;
const int mid_stride = low_dim;
for (int high_id = blockIdx.y; high_id < high_dim; high_id += gridDim.y) {
for (int low_id = blockIdx.x * blockDim.x + threadIdx.x; low_id < low_dim;
low_id += blockDim.x * gridDim.x) {
const int grad_offset = high_id * high_stride + low_id;
// 1. reduce sum
AccT sum = 0;
for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
int data_offset = grad_offset + mid_id * mid_stride;
sum += static_cast<AccT>(output_grad[data_offset]) *
static_cast<AccT>(output[data_offset]);
}
if (blockDim.y > 1) {
kps::Reduce<AccT, 1, 1, 1, kps::AddFunctor<AccT>, kMode::kGlobalMode>(
&sum, &sum, kps::AddFunctor<AccT>(), false);
}
// 2. (log)softmax backward
Functor<AccT, T> functor(sum);
for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
int data_offset = grad_offset + mid_id * mid_stride;
input_grad[data_offset] =
functor(static_cast<AccT>(output_grad[data_offset]),
static_cast<AccT>(output[data_offset]));
}
}
}
}
template <typename T, bool LogMode = false>
void LaunchNormalSoftmaxForward(const platform::CUDADeviceContext& dev_ctx,
T* output_data, const T* input_data,
......@@ -603,6 +640,28 @@ void LaunchNormalSoftmaxForward(const platform::CUDADeviceContext& dev_ctx,
}
}
template <typename T, bool LogMode = false>
void LaunchNormalSoftmaxBackward(const platform::CUDADeviceContext& dev_ctx,
T* input_grad_data, const T* output_grad_data,
const T* output_data, int high_dim,
int mid_dim, int low_dim) {
using AccT = typename details::MPTypeTrait<T>::Type;
dim3 grid, block;
GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block);
if (LogMode) {
NormalSoftmaxBackward<
T, AccT,
LogSoftmaxBackwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
input_grad_data, output_grad_data, output_data, high_dim, mid_dim,
low_dim);
} else {
NormalSoftmaxBackward<
T, AccT, SoftmaxBackwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
input_grad_data, output_grad_data, output_data, high_dim, mid_dim,
low_dim);
}
}
template <typename T, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
const Tensor& x, const int input_axis,
......@@ -741,6 +800,9 @@ void SoftmaxBackwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
blocks, threads, dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N,
dim, dim, kDimLog2);
}
} else if (D > 1) {
LaunchNormalSoftmaxBackward<T, LogMode>(dev_ctx, dx_data, dout.data<T>(),
out.data<T>(), N, dim, D);
} else {
ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册