Unverified  Commit 0cae5c7f authored by limingshu, committed by GitHub

Optimization for StackGradCUDAKernel for last dimension stack case. (#48992)

* add stack grad kernel optimization

* add basic optimization kernel for stack_grad_kernel

* optimize stack_grad_kernel for the last-dim stack case and reformat code with pre-commit
Parent 05df6973
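Note (illustration, not part of the diff): for the last-dim case optimized here, the stacked gradient can be viewed as a row-major [rows x cols] matrix whose column j belongs to x_grad[j] (rows = product of the leading dims, cols = number of stacked inputs). A minimal CPU reference of that scatter, with hypothetical names, to make the new kernel below easier to follow:

#include <cstdio>
#include <vector>

// CPU reference of the transformation the new last-dim kernel performs:
// column c of the row-major [rows x cols] gradient goes to x_grad[c].
void StackGradLastDimReference(const std::vector<float>& d_out,
                               int rows, int cols,
                               std::vector<std::vector<float>>& x_grad) {
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      x_grad[c][r] = d_out[r * cols + c];
    }
  }
}

int main() {
  // Two tensors of shape [3] stacked along the last axis.
  const int rows = 3, cols = 2;
  std::vector<float> d_out = {0, 1, 2, 3, 4, 5};
  std::vector<std::vector<float>> x_grad(cols, std::vector<float>(rows));
  StackGradLastDimReference(d_out, rows, cols, x_grad);
  for (int c = 0; c < cols; ++c) {
    std::printf("x_grad[%d] = {%g, %g, %g}\n",
                c, x_grad[c][0], x_grad[c][1], x_grad[c][2]);
  }
  return 0;  // prints {0, 2, 4} and {1, 3, 5}
}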
@@ -46,6 +46,9 @@ namespace phi {
 namespace backends {
 namespace gpu {
+
+// Limitation of the setting in one dimension of cuda grid.
+constexpr int kMultiDimslimit = 65536;
 
 template <typename T = int64_t>
 inline T DivUp(T a, T b) {
   return (a + b - 1) / b;
......
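Note (illustration, not part of the diff): the new kMultiDimslimit cap and the existing DivUp helper are combined later in this PR to size the grid for the last-dim kernel. A standalone host-side sketch of that sizing pattern, with the two helpers re-declared locally for illustration:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Local stand-ins for backends::gpu::kMultiDimslimit and backends::gpu::DivUp above.
constexpr int kMultiDimslimit = 65536;
template <typename T = int64_t>
inline T DivUp(T a, T b) { return (a + b - 1) / b; }

int main() {
  int rows = 1 << 20;     // e.g. dy_pre, the number of rows to cover
  int block_dim_x = 256;  // threads per block along x
  int n_tiles = DivUp<int>(rows, block_dim_x);
  // Cap the grid dimension; a grid-stride loop inside the kernel walks the
  // remaining tiles, so clamping grid.x does not drop any work.
  int grid_dim_x = std::min(n_tiles, kMultiDimslimit);
  std::printf("tiles=%d grid.x=%d\n", n_tiles, grid_dim_x);
  return 0;
}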
@@ -13,15 +13,13 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/stack_grad_kernel.h"
 
 #include "paddle/fluid/memory/memory.h"
-#include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/kernel_registry.h"
 
 namespace phi {
 
-template <typename T, typename IntType>
+template <typename T, typename IndexT>
 __global__ void UnStackHelperCUDAKernel(const T* __restrict__ input,
                                         int pre_dim_size,
                                         int split_dim_size,
@@ -33,101 +31,152 @@ __global__ void UnStackHelperCUDAKernel(const T* __restrict__ input,
   // In this case they are equal
   assert(split_dim_size % num_split == 0);
-  IntType size = pre_dim_size * split_dim_size * suf_dim_size;
-  IntType each_dim_size = split_dim_size / num_split;
-  for (IntType offset = blockIdx.x * blockDim.x + threadIdx.x; offset < size;
+  IndexT size = pre_dim_size * split_dim_size * suf_dim_size;
+  IndexT each_dim_size = split_dim_size / num_split;
+  for (IndexT offset = blockIdx.x * blockDim.x + threadIdx.x; offset < size;
        offset += blockDim.x * gridDim.x) {
-    IntType i = offset / (split_dim_size * suf_dim_size);
-    IntType j = (offset % (split_dim_size * suf_dim_size)) / suf_dim_size;
-    IntType k = offset % suf_dim_size;
+    IndexT i = offset / (split_dim_size * suf_dim_size);
+    IndexT j = (offset % (split_dim_size * suf_dim_size)) / suf_dim_size;
+    IndexT k = offset % suf_dim_size;
 
     T* output = output_ptrs[j / each_dim_size];
     if (output == nullptr) {
       return;
     }
-    IntType output_ind = i * each_dim_size * suf_dim_size +
-                         (j % each_dim_size) * suf_dim_size + k;
+    IndexT output_ind = i * each_dim_size * suf_dim_size +
+                        (j % each_dim_size) * suf_dim_size + k;
     *(output + output_ind) = input[offset];
   }
 }
 
-template <typename T, typename Context>
-void StackGradKernel(const Context& dev_ctx,
-                     const DenseTensor& out,
-                     int axis,
-                     std::vector<DenseTensor*> x_grad) {
-  if (axis < 0) axis += out.dims().size();
-
-  int n = out.dims()[axis];
-  PADDLE_ENFORCE_EQ(n,
-                    x_grad.size(),
-                    phi::errors::InvalidArgument(
-                        "Output x_grad size should be equal to n, but"
-                        " received n is:%d x_grad size is:%d.",
-                        n,
-                        x_grad.size()));
-
-  // x_grad is output, so save each data address, then copy each dy into dx_data
-  std::vector<T*> outputs(n);
-  for (size_t j = 0; j < x_grad.size(); ++j) {
-    if (x_grad[j] == nullptr) {
-      outputs[j] = nullptr;
-      continue;
-    }
-    if (x_grad[j]->numel() != 0UL) {
-      T* ptr = dev_ctx.template Alloc<T>(x_grad[j]);
-      outputs[j] = ptr;
-    } else {
-      outputs[j] = nullptr;
-    }
-  }
-  auto dy_data = out.data<T>();
-  // each x_grad should have same shape
-  int dy_pre = 1, dy_suf = 1;
-  auto dy_dims = out.dims();
-  int split_dim = n;
-  for (int i = 0; i < axis; ++i) {
-    dy_pre *= dy_dims[i];
-  }
-  dy_suf = out.numel() / (split_dim * dy_pre);
-
-  auto tmp_out_data = paddle::memory::Alloc(
-      dev_ctx.GetPlace(),
-      outputs.size() * sizeof(T*),
-      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
-  paddle::memory::Copy(dev_ctx.GetPlace(),
-                       tmp_out_data->ptr(),
-                       phi::CPUPlace(),
-                       reinterpret_cast<void*>(outputs.data()),
-                       outputs.size() * sizeof(T*),
-                       dev_ctx.stream());
-
-  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
-      dev_ctx, dy_pre * split_dim * dy_suf);
-
-  if (out.numel() < std::numeric_limits<int32_t>::max()) {
-    UnStackHelperCUDAKernel<T, int32_t>
-        <<<config.block_per_grid.x,
-           config.thread_per_block.x,
-           0,
-           dev_ctx.stream()>>>(dy_data,
-                               dy_pre,
-                               split_dim,
-                               dy_suf,
-                               split_dim,
-                               reinterpret_cast<T**>(tmp_out_data->ptr()));
-  } else {
-    UnStackHelperCUDAKernel<T, int64_t>
-        <<<config.block_per_grid.x,
-           config.thread_per_block.x,
-           0,
-           dev_ctx.stream()>>>(dy_data,
-                               dy_pre,
-                               split_dim,
-                               dy_suf,
-                               split_dim,
-                               reinterpret_cast<T**>(tmp_out_data->ptr()));
-  }
-}
+template <typename T, typename IndexT>
+__global__ void StackGradKernelForLastDim(const T* __restrict__ in_data,
+                                          const IndexT cols,
+                                          const IndexT rows,
+                                          const IndexT tile_x_num,
+                                          T** out_datas) {
+  constexpr int buffer_size = 512;
+  __shared__ T s_buf[buffer_size];
+
+  for (IndexT tile_x = blockIdx.x; tile_x < tile_x_num; tile_x += gridDim.x) {
+    IndexT row_idx = tile_x * blockDim.x + threadIdx.x;
+    IndexT col_idx = blockIdx.y * blockDim.y + threadIdx.y;
+    int s_idx = threadIdx.y * blockDim.x + threadIdx.x;
+    bool is_valid = (col_idx < cols && row_idx < rows);
+
+    if (is_valid) {
+      T data = in_data[row_idx * cols + col_idx];
+      s_buf[s_idx] = data;
+    }
+    __syncthreads();
+    if (is_valid) {
+      if (out_datas[col_idx] != nullptr) {
+        out_datas[col_idx][row_idx] = s_buf[s_idx];
+      }
+    }
+  }
+}
+
+template <typename Context, typename T, typename IndexT>
+void LaunchStackGradCUDAKernel(const Context& ctx,
+                               const DenseTensor& out,
+                               std::vector<DenseTensor*>* x_grad_ptr,
+                               const int axis,
+                               const int64_t dy_pre) {
+  auto x_grad = *x_grad_ptr;
+  int out_num = out.dims()[axis];
+  PADDLE_ENFORCE_EQ(
+      out_num,
+      x_grad.size(),
+      phi::errors::InvalidArgument(
+          "Output x_grad size shall be equal to output num, but output num "
+          "received in stack_grad op is:%d, and x_grad size is:%d.",
+          out_num,
+          x_grad.size()));
+
+  std::vector<T*> outputs(out_num);
+  for (size_t j = 0; j < out_num; ++j) {
+    if (x_grad[j] == nullptr || x_grad[j]->numel() == 0UL) {
+      outputs[j] = nullptr;
+    } else {
+      outputs[j] = ctx.template Alloc<T>(x_grad[j]);
+    }
+  }
+
+  auto tmp_out_data = paddle::memory::Alloc(
+      ctx.GetPlace(),
+      out_num * sizeof(T*),
+      phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
+  paddle::memory::Copy(ctx.GetPlace(),
+                       tmp_out_data->ptr(),
+                       phi::CPUPlace(),
+                       reinterpret_cast<void*>(outputs.data()),
+                       out_num * sizeof(T*),
+                       ctx.stream());
+
+  if (axis == (out.dims().size() - 1)) {
+    constexpr int kThreads = 512;
+    constexpr int kWarpSize = 32;
+    constexpr int kMaxOut = 16;
+    int tid_x = 0, tid_y = 0, bid_x = 0, bid_y = 1;
+    bool is_small_num = out_num < kMaxOut;
+
+    if (is_small_num) {
+      tid_y = out_num;
+      tid_x =
+          std::min(backends::gpu::RoundToNextHighPowOfTwo(dy_pre, kWarpSize),
+                   kThreads / backends::gpu::RoundToNextHighPowOfTwo(tid_y));
+    } else {
+      tid_y = kMaxOut;
+      tid_x = kWarpSize;
+      bid_y = backends::gpu::DivUp<int>(out_num, kMaxOut);
+    }
+    int tile_x_num = backends::gpu::DivUp<int>(dy_pre, tid_x);
+    bid_x = std::min(tile_x_num, backends::gpu::kMultiDimslimit);
+    dim3 blocks(tid_x, tid_y, 1);
+    dim3 grids(bid_x, bid_y, 1);
+
+    StackGradKernelForLastDim<T, IndexT><<<grids, blocks, 0, ctx.stream()>>>(
+        out.data<T>(),
+        out_num,
+        dy_pre,
+        tile_x_num,
+        reinterpret_cast<T**>(tmp_out_data->ptr()));
+  } else {
+    int dy_suf = out.numel() / (out_num * dy_pre);
+    auto config =
+        backends::gpu::GetGpuLaunchConfig1D(ctx, dy_pre * out_num * dy_suf);
+
+    UnStackHelperCUDAKernel<T, IndexT>
+        <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
+            out.data<T>(),
+            dy_pre,
+            out_num,
+            dy_suf,
+            out_num,
+            reinterpret_cast<T**>(tmp_out_data->ptr()));
+  }
+}
+
+template <typename T, typename Context>
+void StackGradKernel(const Context& dev_ctx,
+                     const DenseTensor& out,
+                     int axis,
+                     std::vector<DenseTensor*> x_grad) {
+  const auto& dy_dims = out.dims();
+  int actual_axis = axis < 0 ? axis + dy_dims.size() : axis;
+  bool use_int32 = out.numel() < std::numeric_limits<int32_t>::max();
+
+  int64_t dy_pre = 1;
+  for (int i = 0; i < actual_axis; ++i) {
+    dy_pre *= dy_dims[i];
+  }
+  if (use_int32) {
+    LaunchStackGradCUDAKernel<Context, T, int32_t>(
+        dev_ctx, out, &x_grad, actual_axis, dy_pre);
+  } else {
+    LaunchStackGradCUDAKernel<Context, T, int64_t>(
+        dev_ctx, out, &x_grad, actual_axis, dy_pre);
+  }
+}
......
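Note (illustration, not part of the diff): the last-dim path treats the gradient as a [rows = dy_pre x cols = out_num] matrix and picks the block/grid shape with the two-branch heuristic above. The host-side sketch below restates that heuristic with local stand-ins for the gpu helpers (DivUp, RoundToNextHighPowOfTwo), so the two branches are easy to trace; the authoritative logic is the code in LaunchStackGradCUDAKernel.

#include <algorithm>
#include <cstdio>

// Local stand-ins (illustrative only) for the backends::gpu helpers used above.
static int DivUp(int a, int b) { return (a + b - 1) / b; }
static int RoundToNextHighPowOfTwo(int n, int min_val = 1) {
  int p = min_val;
  while (p < n) p *= 2;
  return p;
}

// Mirrors the launch-shape choice for the last-dim case.
void ChooseLaunchGeometry(int out_num, int dy_pre) {
  constexpr int kThreads = 512;         // also the shared buffer_size in the kernel
  constexpr int kWarpSize = 32;
  constexpr int kMaxOut = 16;           // max columns handled by one block in y
  constexpr int kGridDimLimit = 65536;  // cf. kMultiDimslimit

  int tid_x = 0, tid_y = 0, bid_y = 1;
  if (out_num < kMaxOut) {
    // Few outputs: one block spans every column; remaining threads go to rows,
    // rounded to a warp-friendly power of two.
    tid_y = out_num;
    tid_x = std::min(RoundToNextHighPowOfTwo(dy_pre, kWarpSize),
                     kThreads / RoundToNextHighPowOfTwo(tid_y));
  } else {
    // Many outputs: fixed 32x16 blocks, extra column tiles go to grid.y.
    tid_y = kMaxOut;
    tid_x = kWarpSize;
    bid_y = DivUp(out_num, kMaxOut);
  }
  int tile_x_num = DivUp(dy_pre, tid_x);            // row tiles to cover
  int bid_x = std::min(tile_x_num, kGridDimLimit);  // grid-stride loop covers the rest
  std::printf("block=(%d,%d) grid=(%d,%d) tiles=%d\n",
              tid_x, tid_y, bid_x, bid_y, tile_x_num);
}

int main() {
  ChooseLaunchGeometry(/*out_num=*/4, /*dy_pre=*/10000);   // small-out_num branch
  ChooseLaunchGeometry(/*out_num=*/64, /*dy_pre=*/10000);  // large-out_num branch
  return 0;
}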
@@ -13,9 +13,7 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/stack_kernel.h"
 
 #include "paddle/fluid/memory/memory.h"
-#include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/fast_divmod.h"
@@ -135,7 +133,7 @@ void LaunchStackCUDAKernelWithIndexType(
     } break;
 
 #define IMPL_STACK_CUDA_KERNEL_HELPER(...)        \
-  IMPL_STACK_CUDA_KERNEL_CASE(2, ##__VA_ARGS__);  \
+  IMPL_STACK_CUDA_KERNEL_CASE(4, ##__VA_ARGS__);  \
   IMPL_STACK_CUDA_KERNEL_CASE(8, ##__VA_ARGS__);  \
  IMPL_STACK_CUDA_KERNEL_CASE(16, ##__VA_ARGS__);  \
  IMPL_STACK_CUDA_KERNEL_CASE(32, ##__VA_ARGS__);  \
......
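Note (illustration, not part of the diff): the helper macro above expands into a switch ladder over fixed pointer-array capacities, and this change replaces the capacity-2 case with a capacity-4 case. A self-contained, generic sketch of that dispatch pattern with hypothetical names:

#include <cstdio>

// Generic version of a "dispatch on a compile-time capacity" switch ladder.
template <int kCapacity>
void RunWithCapacity(int num_inputs) {
  std::printf("num_inputs=%d handled with capacity %d\n", num_inputs, kCapacity);
}

#define DISPATCH_CASE(capacity)            \
  case capacity:                           \
    RunWithCapacity<capacity>(num_inputs); \
    break;

void Dispatch(int num_inputs, int rounded) {
  switch (rounded) {  // `rounded` is num_inputs rounded up to the next case
    DISPATCH_CASE(4);
    DISPATCH_CASE(8);
    DISPATCH_CASE(16);
    DISPATCH_CASE(32);
    default:
      std::printf("fall back to a dynamically sized path\n");
  }
}
#undef DISPATCH_CASE

int main() {
  Dispatch(3, 4);    // handled by the smallest (now capacity-4) case
  Dispatch(40, 64);  // exceeds the ladder, takes the fallback path
  return 0;
}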