diff --git a/paddle/phi/backends/gpu/gpu_launch_config.h b/paddle/phi/backends/gpu/gpu_launch_config.h
index 05d97dc45a4cda93dfa77f294b34564df6e01fed..a7a7ad03ad66479d5d550e73590ff8dc65bb7176 100644
--- a/paddle/phi/backends/gpu/gpu_launch_config.h
+++ b/paddle/phi/backends/gpu/gpu_launch_config.h
@@ -46,6 +46,9 @@ namespace phi {
 namespace backends {
 namespace gpu {
 
+// Upper limit for one dimension of a CUDA grid.
+constexpr int kMultiDimslimit = 65536;
+
 template <typename T>
 inline T DivUp(T a, T b) {
   return (a + b - 1) / b;
diff --git a/paddle/phi/kernels/gpu/stack_grad_kernel.cu b/paddle/phi/kernels/gpu/stack_grad_kernel.cu
index ea61be0abf1a56ca40c6c52d72b20357f6c2e9e7..ccbe2885334a8d50cd4b546094502391ccd8b96c 100644
--- a/paddle/phi/kernels/gpu/stack_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/stack_grad_kernel.cu
@@ -13,15 +13,13 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/stack_grad_kernel.h"
-
 #include "paddle/fluid/memory/memory.h"
-#include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/kernel_registry.h"
 
 namespace phi {
 
-template <typename T, typename IntType>
+template <typename T, typename IndexT>
 __global__ void UnStackHelperCUDAKernel(const T* __restrict__ input,
                                         int pre_dim_size,
                                         int split_dim_size,
@@ -33,101 +31,152 @@ __global__ void UnStackHelperCUDAKernel(const T* __restrict__ input,
   // In this case they are equal
   assert(split_dim_size % num_split == 0);
 
-  IntType size = pre_dim_size * split_dim_size * suf_dim_size;
-  IntType each_dim_size = split_dim_size / num_split;
+  IndexT size = pre_dim_size * split_dim_size * suf_dim_size;
+  IndexT each_dim_size = split_dim_size / num_split;
 
-  for (IntType offset = blockIdx.x * blockDim.x + threadIdx.x; offset < size;
+  for (IndexT offset = blockIdx.x * blockDim.x + threadIdx.x; offset < size;
        offset += blockDim.x * gridDim.x) {
-    IntType i = offset / (split_dim_size * suf_dim_size);
-    IntType j = (offset % (split_dim_size * suf_dim_size)) / suf_dim_size;
-    IntType k = offset % suf_dim_size;
+    IndexT i = offset / (split_dim_size * suf_dim_size);
+    IndexT j = (offset % (split_dim_size * suf_dim_size)) / suf_dim_size;
+    IndexT k = offset % suf_dim_size;
 
     T* output = output_ptrs[j / each_dim_size];
     if (output == nullptr) {
       return;
     }
-    IntType output_ind = i * each_dim_size * suf_dim_size +
-                         (j % each_dim_size) * suf_dim_size + k;
+    IndexT output_ind = i * each_dim_size * suf_dim_size +
+                        (j % each_dim_size) * suf_dim_size + k;
     *(output + output_ind) = input[offset];
   }
 }
 
-template <typename T, typename Context>
-void StackGradKernel(const Context& dev_ctx,
-                     const DenseTensor& out,
-                     int axis,
-                     std::vector<DenseTensor*> x_grad) {
-  if (axis < 0) axis += out.dims().size();
-
-  int n = out.dims()[axis];
-  PADDLE_ENFORCE_EQ(n,
-                    x_grad.size(),
-                    phi::errors::InvalidArgument(
-                        "Output x_grad size should be equal to n, but"
-                        " received n is:%d x_grad size is:%d.",
-                        n,
-                        x_grad.size()));
-
-  // x_grad is output, so save each data address, then copy each dy into dx_data
-  std::vector<T*> outputs(n);
-  for (size_t j = 0; j < x_grad.size(); ++j) {
-    if (x_grad[j] == nullptr) {
-      outputs[j] = nullptr;
-      continue;
+template <typename T, typename IndexT>
+__global__ void StackGradKernelForLastDim(const T* __restrict__ in_data,
+                                          const IndexT cols,
+                                          const IndexT rows,
+                                          const IndexT tile_x_num,
+                                          T** out_datas) {
+  constexpr int buffer_size = 512;
+  __shared__ T s_buf[buffer_size];
+
+  for (IndexT tile_x = blockIdx.x; tile_x < tile_x_num; tile_x += gridDim.x) {
+    IndexT row_idx = tile_x * blockDim.x + threadIdx.x;
+    IndexT col_idx = blockIdx.y * blockDim.y + threadIdx.y;
+    int s_idx = threadIdx.y * blockDim.x + threadIdx.x;
+    bool is_valid = (col_idx < cols && row_idx < rows);
+
+    if (is_valid) {
+      T data = in_data[row_idx * cols + col_idx];
+      s_buf[s_idx] = data;
     }
-    if (x_grad[j]->numel() != 0UL) {
-      T* ptr = dev_ctx.template Alloc<T>(x_grad[j]);
-      outputs[j] = ptr;
-    } else {
-      outputs[j] = nullptr;
+    __syncthreads();
+    if (is_valid) {
+      if (out_datas[col_idx] != nullptr) {
+        out_datas[col_idx][row_idx] = s_buf[s_idx];
+      }
     }
   }
-  auto dy_data = out.data<T>();
-  // each x_grad should have same shape
-  int dy_pre = 1, dy_suf = 1;
-  auto dy_dims = out.dims();
-  int split_dim = n;
-  for (int i = 0; i < axis; ++i) {
-    dy_pre *= dy_dims[i];
+}
+
+template <typename T, typename IndexT, typename Context>
+void LaunchStackGradCUDAKernel(const Context& ctx,
+                               const DenseTensor& out,
+                               std::vector<DenseTensor*>* x_grad_ptr,
+                               const int axis,
+                               const int64_t dy_pre) {
+  auto x_grad = *x_grad_ptr;
+  int out_num = out.dims()[axis];
+  PADDLE_ENFORCE_EQ(
+      out_num,
+      x_grad.size(),
+      phi::errors::InvalidArgument(
+          "Output x_grad size shall be equal to output num, but output num "
+          "received in stack_grad op is:%d, and x_grad size is:%d.",
+          out_num,
+          x_grad.size()));
+  std::vector<T*> outputs(out_num);
+  for (size_t j = 0; j < out_num; ++j) {
+    if (x_grad[j] == nullptr || x_grad[j]->numel() == 0UL) {
+      outputs[j] = nullptr;
+    } else {
+      outputs[j] = ctx.template Alloc<T>(x_grad[j]);
+    }
   }
-  dy_suf = out.numel() / (split_dim * dy_pre);
 
   auto tmp_out_data = paddle::memory::Alloc(
-      dev_ctx.GetPlace(),
-      outputs.size() * sizeof(T*),
-      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
-  paddle::memory::Copy(dev_ctx.GetPlace(),
+      ctx.GetPlace(),
+      out_num * sizeof(T*),
+      phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
+  paddle::memory::Copy(ctx.GetPlace(),
                       tmp_out_data->ptr(),
                       phi::CPUPlace(),
                       reinterpret_cast<void*>(outputs.data()),
-                       outputs.size() * sizeof(T*),
-                       dev_ctx.stream());
-
-  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
-      dev_ctx, dy_pre * split_dim * dy_suf);
-
-  if (out.numel() < std::numeric_limits<int32_t>::max()) {
-    UnStackHelperCUDAKernel<T, int32_t>
-        <<<config.block_per_grid.x,
-           config.thread_per_block.x,
-           0,
-           dev_ctx.stream()>>>(dy_data,
-                               dy_pre,
-                               split_dim,
-                               dy_suf,
-                               split_dim,
-                               reinterpret_cast<T**>(tmp_out_data->ptr()));
+                       out_num * sizeof(T*),
+                       ctx.stream());
+
+  if (axis == (out.dims().size() - 1)) {
+    constexpr int kThreads = 512;
+    constexpr int kWarpSize = 32;
+    constexpr int kMaxOut = 16;
+    int tid_x = 0, tid_y = 0, bid_x = 0, bid_y = 1;
+    bool is_small_num = out_num < kMaxOut;
+
+    if (is_small_num) {
+      tid_y = out_num;
+      tid_x =
+          std::min(backends::gpu::RoundToNextHighPowOfTwo(dy_pre, kWarpSize),
+                   kThreads / backends::gpu::RoundToNextHighPowOfTwo(tid_y));
+    } else {
+      tid_y = kMaxOut;
+      tid_x = kWarpSize;
+      bid_y = backends::gpu::DivUp<int>(out_num, kMaxOut);
+    }
+    int tile_x_num = backends::gpu::DivUp<int>(dy_pre, tid_x);
+    bid_x = std::min(tile_x_num, backends::gpu::kMultiDimslimit);
+    dim3 blocks(tid_x, tid_y, 1);
+    dim3 grids(bid_x, bid_y, 1);
+
+    StackGradKernelForLastDim<T, IndexT><<<grids, blocks, 0, ctx.stream()>>>(
+        out.data<T>(),
+        out_num,
+        dy_pre,
+        tile_x_num,
+        reinterpret_cast<T**>(tmp_out_data->ptr()));
+  } else {
+    int dy_suf = out.numel() / (out_num * dy_pre);
+    auto config =
+        backends::gpu::GetGpuLaunchConfig1D(ctx, dy_pre * out_num * dy_suf);
+
+    UnStackHelperCUDAKernel<T, IndexT>
+        <<<config.block_per_grid.x, config.thread_per_block.x, 0, ctx.stream()>>>(
+            out.data<T>(),
+            dy_pre,
+            out_num,
+            dy_suf,
+            out_num,
+            reinterpret_cast<T**>(tmp_out_data->ptr()));
+  }
+}
+
+template <typename T, typename Context>
+void StackGradKernel(const Context& dev_ctx,
+                     const DenseTensor& out,
+                     int axis,
+                     std::vector<DenseTensor*> x_grad) {
+  const auto& dy_dims = out.dims();
+  int actual_axis = axis < 0 ? axis + dy_dims.size() : axis;
+  bool use_int32 = out.numel() < std::numeric_limits<int32_t>::max();
+
+  int64_t dy_pre = 1;
+  for (int i = 0; i < actual_axis; ++i) {
+    dy_pre *= dy_dims[i];
+  }
+  if (use_int32) {
+    LaunchStackGradCUDAKernel<T, int32_t, Context>(
+        dev_ctx, out, &x_grad, actual_axis, dy_pre);
   } else {
-    UnStackHelperCUDAKernel<T, int64_t>
-        <<<config.block_per_grid.x,
-           config.thread_per_block.x,
-           0,
-           dev_ctx.stream()>>>(dy_data,
-                               dy_pre,
-                               split_dim,
-                               dy_suf,
-                               split_dim,
-                               reinterpret_cast<T**>(tmp_out_data->ptr()));
+    LaunchStackGradCUDAKernel<T, int64_t, Context>(
+        dev_ctx, out, &x_grad, actual_axis, dy_pre);
   }
 }
diff --git a/paddle/phi/kernels/gpu/stack_kernel.cu b/paddle/phi/kernels/gpu/stack_kernel.cu
index d3f24095069555498325fc1bb445f16c77515eb1..92730d15f25a48f0eb884383ae73d955ca8afaac 100644
--- a/paddle/phi/kernels/gpu/stack_kernel.cu
+++ b/paddle/phi/kernels/gpu/stack_kernel.cu
@@ -13,9 +13,7 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/stack_kernel.h"
-
 #include "paddle/fluid/memory/memory.h"
-#include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/fast_divmod.h"
@@ -135,7 +133,7 @@ void LaunchStackCUDAKernelWithIndexType(
     } break;
 
 #define IMPL_STACK_CUDA_KERNEL_HELPER(...)        \
-  IMPL_STACK_CUDA_KERNEL_CASE(2, ##__VA_ARGS__);  \
+  IMPL_STACK_CUDA_KERNEL_CASE(4, ##__VA_ARGS__);  \
   IMPL_STACK_CUDA_KERNEL_CASE(8, ##__VA_ARGS__);  \
   IMPL_STACK_CUDA_KERNEL_CASE(16, ##__VA_ARGS__); \
   IMPL_STACK_CUDA_KERNEL_CASE(32, ##__VA_ARGS__); \
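
Note on the new launch configuration for the last-dim path in LaunchStackGradCUDAKernel: blockDim.y covers stacked outputs (at most kMaxOut = 16), blockDim.x covers rows, grid.y splits any remaining outputs, and grid.x is capped at kMultiDimslimit while the kernel loops over row tiles. The standalone host-side C++ sketch below reproduces only that sizing arithmetic for two example shapes. LocalDivUp and LocalRoundUpPow2 are hypothetical stand-ins for backends::gpu::DivUp and backends::gpu::RoundToNextHighPowOfTwo (assumed here to round n up to the next power of two, with a floor of min_val), so treat it as an illustration of the arithmetic rather than the exact library behavior.

// Illustrative sketch only (not part of the patch). Assumptions:
//  - LocalDivUp mirrors backends::gpu::DivUp (ceiling division).
//  - LocalRoundUpPow2 mirrors backends::gpu::RoundToNextHighPowOfTwo,
//    assumed to round n up to the next power of two, at least min_val.
#include <algorithm>
#include <cstdint>
#include <cstdio>

constexpr int kThreads = 512;         // thread budget per block
constexpr int kWarpSize = 32;
constexpr int kMaxOut = 16;           // max stacked outputs covered by blockDim.y
constexpr int kGridDimLimit = 65536;  // mirrors backends::gpu::kMultiDimslimit

static int LocalDivUp(int64_t a, int64_t b) {
  return static_cast<int>((a + b - 1) / b);
}

static int LocalRoundUpPow2(int64_t n, int min_val = 1) {
  int p = 1;
  while (p < n) p <<= 1;
  return std::max(p, min_val);
}

// dy_pre: rows (product of dims before the stack axis); out_num: stacked outputs.
static void PrintLaunchDims(int64_t dy_pre, int out_num) {
  int tid_x = 0, tid_y = 0, bid_x = 0, bid_y = 1;
  if (out_num < kMaxOut) {
    // Few outputs: one y-lane per output, spend the remaining budget on rows.
    tid_y = out_num;
    tid_x = std::min(LocalRoundUpPow2(dy_pre, kWarpSize),
                     kThreads / LocalRoundUpPow2(tid_y));
  } else {
    // Many outputs: fixed 32x16 tile, extra outputs spread across grid.y.
    tid_y = kMaxOut;
    tid_x = kWarpSize;
    bid_y = LocalDivUp(out_num, kMaxOut);
  }
  int tile_x_num = LocalDivUp(dy_pre, tid_x);   // row tiles the kernel loops over
  bid_x = std::min(tile_x_num, kGridDimLimit);  // grid.x capped at 65536
  std::printf("dy_pre=%lld out_num=%d -> block(%d,%d) grid(%d,%d) row_tiles=%d\n",
              static_cast<long long>(dy_pre), out_num, tid_x, tid_y, bid_x, bid_y,
              tile_x_num);
}

int main() {
  PrintLaunchDims(1000, 4);      // small out_num branch
  PrintLaunchDims(1 << 20, 64);  // large out_num branch
  return 0;
}

In both branches the resulting blockDim.x * blockDim.y stays within 512 threads, which appears to be why StackGradKernelForLastDim can stage one element per thread in its 512-entry shared-memory buffer.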