Unverified commit 0a21924a, authored by GaoWei8, committed by GitHub

optimize softmax forward (#30217)

* optimize softmax forward
Parent af80859d
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle { namespace paddle {
@@ -31,6 +32,13 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using DataLayout = platform::DataLayout;
using Tensor = framework::Tensor;
#define LAUNCH_SOFTMAX_WARP_FORWARD(Log2Elements) \
case Log2Elements: \
WarpSoftmaxForward<T, float, Log2Elements><<< \
blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \
out_data, x->data<T>(), N, dim, dim); \
break;
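The macro above packs one switch case per supported power-of-two size, so the dispatch code further down can launch the matching WarpSoftmaxForward instantiation without repeating the call. For illustration only (this expansion is not part of the diff), the case for Log2Elements = 7, i.e. dim in (64, 128], expands roughly to:

case 7:
  WarpSoftmaxForward<T, float, 7><<<blocks, threads, 0,
      ctx.cuda_device_context().stream()>>>(out_data, x->data<T>(), N, dim, dim);
  break;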
static inline int SizeOutAxis(const int axis, DDim dims) {
int size = 1;
for (int i = axis + 1; i < dims.size(); i++) {
@@ -39,6 +47,12 @@ static inline int SizeOutAxis(const int axis, DDim dims) {
return size;
}
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
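log2_ceil(value) returns the smallest exponent e with (1 << e) >= value, i.e. the ceiling of log2 for positive inputs; the dispatch code below uses it to round dim up to the next power of two. A tiny self-contained check of a few assumed values (illustrative, not part of the diff, and relying only on log2_ceil as defined above):

#include <cassert>
int main() {
  assert(log2_ceil(1) == 0);    // 1 << 0 == 1
  assert(log2_ceil(128) == 7);  // exact power of two
  assert(log2_ceil(200) == 8);  // rounds up to the next power of two, 256
  return 0;
}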
template <typename T, int VLEN>
union vec_t {
static_assert(sizeof(T) == -1, "vec_t is only available by specialization.");
@@ -84,6 +98,107 @@ __global__ void VecSoftmaxForward(T* dst, const T* src, const int batch_size,
reinterpret_cast<VECT*>(&dst[offset + idx])[0] = buf;
}
template <typename T, int WARP_BATCH, int WARP_SIZE_SOFTMAX>
__device__ __forceinline__ void warp_reduce_sum(T* sum) {
#pragma unroll
for (int offset = WARP_SIZE_SOFTMAX / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
T sum_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset);
sum[i] = sum[i] + sum_val;
}
}
}
template <typename T, int WARP_BATCH, int WARP_SIZE_SOFTMAX>
__device__ __forceinline__ void warp_reduce_max(T* sum) {
#pragma unroll
for (int offset = WARP_SIZE_SOFTMAX / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
T max_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset);
sum[i] = max(sum[i], max_val);
}
}
}
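Both helpers above perform a butterfly (xor-shuffle) reduction: in each round, platform::CudaShuffleXorSync exchanges partial results between lanes whose indices differ in one bit of offset, so after log2(WARP_SIZE_SOFTMAX) rounds every lane holds the warp-wide sum (or max) for each of the WARP_BATCH rows. A minimal host-side sketch of the same access pattern, assuming an 8-lane warp and plain arrays standing in for per-lane registers (illustrative, not part of the diff):

#include <cstdio>

// Simulates the xor-butterfly of warp_reduce_sum on the host:
// lane_val[l] plays the role of lane l's register sum[i].
void butterfly_sum(float lane_val[8]) {
  for (int offset = 8 / 2; offset > 0; offset /= 2) {
    float shuffled[8];
    for (int l = 0; l < 8; ++l) shuffled[l] = lane_val[l ^ offset];  // shfl.xor
    for (int l = 0; l < 8; ++l) lane_val[l] += shuffled[l];
  }
}

int main() {
  float v[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  butterfly_sum(v);
  printf("lane 0 holds %g\n", v[0]);  // 36: every lane ends with the full sum
  return 0;
}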
template <typename T, typename AccT, int Log2Elements>
__global__ void WarpSoftmaxForward(T* dst, const T* src, const int batch_size,
const int stride, const int element_count) {
constexpr int next_power_of_two = 1 << Log2Elements;
constexpr int warp_size_softmax =
(next_power_of_two < 32) ? next_power_of_two : 32;
constexpr int WARP_ITERATIONS = next_power_of_two / warp_size_softmax;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH) {
local_batches = WARP_BATCH;
}
int local_idx = threadIdx.x;
src += first_batch * stride + local_idx;
dst += first_batch * stride + local_idx;
// load data from global memory
AccT elements[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * warp_size_softmax;
if (element_index < batch_element_count) {
elements[i][it] =
static_cast<float>(src[i * element_count + it * warp_size_softmax]);
} else {
elements[i][it] = -std::numeric_limits<AccT>::infinity();
}
}
}
// compute max_value
AccT max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] =
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce_max<AccT, WARP_BATCH, warp_size_softmax>(max_value);
AccT sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = (std::exp((elements[i][it] - max_value[i])));
sum[i] += elements[i][it];
}
}
warp_reduce_sum<AccT, WARP_BATCH, warp_size_softmax>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches) break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * warp_size_softmax;
if (element_index < element_count) {
dst[i * element_count + it * warp_size_softmax] =
elements[i][it] / sum[i];
} else {
break;
}
}
}
}
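Numerically, each warp in the kernel above computes a standard max-shifted softmax over its rows: padding slots are filled with -infinity, so they contribute exp(-inf) = 0 to the sum and do not distort the result. A single-threaded reference of the per-row math (a sketch for comparison only, not code from the diff):

#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable softmax over one row, mirroring the kernel's three phases:
// row max, exp(x - max) with a running sum, then division by the sum.
std::vector<float> softmax_row(const std::vector<float>& x) {
  float max_val = *std::max_element(x.begin(), x.end());
  std::vector<float> out(x.size());
  float sum = 0.0f;
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = std::exp(x[i] - max_val);
    sum += out[i];
  }
  for (float& v : out) v /= sum;
  return out;
}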
template <typename T, int VPT, int WARP_PER_BLOCK>
__global__ void VecSoftmaxBackward(T* dst, const T* grad, const T* src,
const int batch_size,
@@ -130,26 +245,61 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
const int N = SizeToAxis(axis, dims);
const int D = SizeOutAxis(axis, dims);
constexpr int max_dim = 320;
bool optimize = false;
constexpr int warps_per_block = 4;
if (D == 1 && dim <= max_dim && sizeof(T) <= 4) {
if (dim == 128 && N % warps_per_block == 0) {
optimize = true;
// a warp for a batch, 4 elements for a thread, only support the softmax
// dim size = 128 currently
if (sizeof(T) == 2) {
VecSoftmaxForward<T, int2, 4, warps_per_block><<<
N / warps_per_block, warps_per_block * WARP_SIZE, 0,
ctx.cuda_device_context().stream()>>>(out_data, x->data<T>(), N, dim);
} else if (sizeof(T) == 4) {
VecSoftmaxForward<T, int4, 4, warps_per_block><<<
N / warps_per_block, warps_per_block * WARP_SIZE, 0,
ctx.cuda_device_context().stream()>>>(out_data, x->data<T>(), N, dim);
} else {
assert(false && "not support");
}
} else if (dim < max_dim) {
optimize = true;
int log2_elements = static_cast<int>(log2_ceil(dim));
const int next_power_of_two = 1 << log2_elements;
int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (N + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
switch (log2_elements) {
LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1
LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2
LAUNCH_SOFTMAX_WARP_FORWARD(2); // 4
LAUNCH_SOFTMAX_WARP_FORWARD(3); // 8
LAUNCH_SOFTMAX_WARP_FORWARD(4); // 16
LAUNCH_SOFTMAX_WARP_FORWARD(5); // 32
LAUNCH_SOFTMAX_WARP_FORWARD(6); // 64
LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128
LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256
LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512
default:
break;
}
}
}
if (!optimize) {
ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1};
DataLayout layout = DataLayout::kNCHW;
......
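As a worked example of the dispatch above (assumed numbers, not from the diff): with float input, D == 1, dim = 200 and N = 1000, the 128-wide vectorized path is skipped, log2_ceil(200) = 8 selects the LAUNCH_SOFTMAX_WARP_FORWARD(8) case, and the launch configuration works out as sketched below.

// Hypothetical dispatch arithmetic for dim = 200, N = 1000 (illustration only).
int dim = 200, N = 1000;
int log2_elements = log2_ceil(dim);                                  // 8
int next_power_of_two = 1 << log2_elements;                          // 256
int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;   // 32
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;           // 1
constexpr int threads_per_block = 128;
int warps_per_block = threads_per_block / warp_size;                 // 4
int batches_per_block = warps_per_block * batches_per_warp;          // 4 rows per block
int blocks = (N + batches_per_block - 1) / batches_per_block;        // 250
// Inside the kernel, each thread then runs 256 / 32 = 8 WARP_ITERATIONS.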