Unverified commit 180877e9 authored by GaoWei8, committed by GitHub

Softmax backward optimize (#30249)

* softmax backward optimize
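
For reference, the new warp-level path computes the standard softmax gradient dx = y * (dy - sum(dy * y)) along the softmax axis: MultiplyCUDAKernel first forms dout * y, and softmax_warp_backward then reduces that product per row and applies the correction. A minimal CPU sketch of the same formula (illustrative only; the function name is hypothetical and not part of this patch):

// Reference softmax gradient for a [rows, cols] row-major layout:
// dx_j = y_j * (dy_j - sum_k(dy_k * y_k)) for each row.
void SoftmaxGradReference(float* dx, const float* dy, const float* y,
                          int rows, int cols) {
  for (int r = 0; r < rows; ++r) {
    float dot = 0.0f;
    for (int c = 0; c < cols; ++c) dot += dy[r * cols + c] * y[r * cols + c];
    for (int c = 0; c < cols; ++c)
      dx[r * cols + c] = y[r * cols + c] * (dy[r * cols + c] - dot);
  }
}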
Parent 342d62de
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace paddle {
namespace platform {
@@ -39,6 +40,13 @@ using Tensor = framework::Tensor;
out_data, x->data<T>(), N, dim, dim); \
break;
#define LAUNCH_SOFTMAX_WARP_BACKWARD(Log2Elements) \
case Log2Elements: \
softmax_warp_backward<T, float, Log2Elements><<< \
blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \
dx_data, mul_grad.data<T>(), out->data<T>(), N, dim, dim); \
break;
static inline int SizeOutAxis(const int axis, DDim dims) {
int size = 1;
for (int i = axis + 1; i < dims.size(); i++) {
@@ -199,6 +207,83 @@ __global__ void WarpSoftmaxForward(T* dst, const T* src, const int batch_size,
}
}
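// Per-warp softmax backward: `grad` is expected to already hold the
// elementwise product dout * output (see MultiplyCUDAKernel below). Each
// warp reduces it along the softmax axis for up to WARP_BATCH rows and
// writes gradInput = grad - output * sum(grad).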
template <typename T, typename AccT, int Log2Elements>
__global__ void softmax_warp_backward(T* gradInput, const T* grad,
const T* output, int batch_size,
int stride, int element_count) {
constexpr int next_power_of_two = 1 << Log2Elements;
constexpr int warp_size_softmax =
(next_power_of_two < 32) ? next_power_of_two : 32;
constexpr int WARP_ITERATIONS = next_power_of_two / warp_size_softmax;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH) {
local_batches = WARP_BATCH;
}
int local_idx = threadIdx.x % warp_size_softmax;
int thread_offset = first_batch * stride + local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
AccT grad_reg[WARP_BATCH][WARP_ITERATIONS];
AccT output_reg[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * warp_size_softmax;
if (element_index < batch_element_count) {
grad_reg[i][it] =
static_cast<AccT>(grad[i * element_count + it * warp_size_softmax]);
output_reg[i][it] = static_cast<AccT>(
output[i * element_count + it * warp_size_softmax]);
} else {
grad_reg[i][it] = AccT(0);
output_reg[i][it] = AccT(0);
}
}
}
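  // Accumulate this thread's partial sum of grad (= dout * y) for each row;
  // the partial sums are combined across the warp below.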
AccT sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce_sum<AccT, WARP_BATCH, warp_size_softmax>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches) break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * warp_size_softmax;
if (element_index < element_count) {
// compute gradients
gradInput[i * element_count + it * warp_size_softmax] =
(grad_reg[i][it] - output_reg[i][it] * sum[i]);
}
}
}
}
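// Elementwise product C = A * B; the multiply is performed in float so that
// float16 inputs do not lose precision.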
template <typename T>
__global__ void MultiplyCUDAKernel(T* C, const T* A, const T* B, int N) {
CUDA_KERNEL_LOOP(i, N) {
C[i] = static_cast<T>(static_cast<float>(A[i]) * static_cast<float>(B[i]));
}
}
template <typename T, int VPT, int WARP_PER_BLOCK>
__global__ void VecSoftmaxBackward(T* dst, const T* grad, const T* src,
const int batch_size,
@@ -340,28 +425,74 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
constexpr bool warp_softmax_available =
std::is_same<T, float>::value ||
std::is_same<T, platform::float16>::value;
if (D == 1 && dim == 128 && N % warps_per_block == 0 &&
warp_softmax_available) {
if (std::is_same<T, float>::value) {
VecSoftmaxBackward<
float, 4,
warps_per_block><<<N / warps_per_block, warps_per_block * WARP_SIZE,
0, ctx.cuda_device_context().stream()>>>(
dx->data<float>(), dout->data<float>(), out->data<float>(), N, dim);
} else if (std::is_same<T, platform::float16>::value) {
VecSoftmaxBackward<
platform::float16, 4,
warps_per_block><<<N / warps_per_block, warps_per_block * WARP_SIZE,
0, ctx.cuda_device_context().stream()>>>(
dx->data<platform::float16>(), dout->data<platform::float16>(),
out->data<platform::float16>(), N, dim);
} else {
PADDLE_ENFORCE_EQ(
warp_softmax_available, true,
platform::errors::Unimplemented(
"Warp softmax backward is only available for fp32 and fp16"));
bool optimize = false;
if (D == 1 && warp_softmax_available) {
if (dim == 128 && N % warps_per_block == 0) {
optimize = true;
if (std::is_same<T, float>::value) {
VecSoftmaxBackward<float, 4, warps_per_block><<<
N / warps_per_block, warps_per_block * WARP_SIZE, 0,
ctx.cuda_device_context().stream()>>>(dx->data<float>(),
dout->data<float>(),
out->data<float>(), N, dim);
} else if (std::is_same<T, platform::float16>::value) {
VecSoftmaxBackward<platform::float16, 4, warps_per_block><<<
N / warps_per_block, warps_per_block * WARP_SIZE, 0,
ctx.cuda_device_context().stream()>>>(
dx->data<platform::float16>(), dout->data<platform::float16>(),
out->data<platform::float16>(), N, dim);
} else {
PADDLE_ENFORCE_EQ(
warp_softmax_available, true,
platform::errors::Unimplemented(
"Warp softmax backward is only available for fp32 and fp16"));
}
} else if (dim < 40 && dim % 32 != 0) {
optimize = true;
Tensor mul_grad;
int numel = N * dim;
mul_grad.mutable_data<T>({numel}, ctx.GetPlace());
auto stream = ctx.cuda_device_context().stream();
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
auto config = GetGpuLaunchConfig1D(dev_ctx, numel);
MultiplyCUDAKernel<T><<<config.block_per_grid.x,
config.thread_per_block.x, 0, stream>>>(
mul_grad.data<T>(), dout->data<T>(), out->data<T>(), numel);
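        // Dispatch the warp-level backward kernel: each warp handles up to
        // batches_per_warp rows, and the switch below selects the
        // instantiation whose Log2Elements matches the next power of two
        // of dim.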
int log2_elements = log2_ceil(dim);
const int next_power_of_two = 1 << log2_elements;
int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (N + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
switch (log2_elements) {
LAUNCH_SOFTMAX_WARP_BACKWARD(0); // 1
LAUNCH_SOFTMAX_WARP_BACKWARD(1); // 2
LAUNCH_SOFTMAX_WARP_BACKWARD(2); // 4
LAUNCH_SOFTMAX_WARP_BACKWARD(3); // 8
LAUNCH_SOFTMAX_WARP_BACKWARD(4); // 16
LAUNCH_SOFTMAX_WARP_BACKWARD(5); // 32
LAUNCH_SOFTMAX_WARP_BACKWARD(6); // 64
LAUNCH_SOFTMAX_WARP_BACKWARD(7); // 128
LAUNCH_SOFTMAX_WARP_BACKWARD(8); // 256
LAUNCH_SOFTMAX_WARP_BACKWARD(9); // 512
default:
break;
}
}
} else {
}
if (!optimize) {
ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1};
DataLayout layout = DataLayout::kNCHW;
......