diff --git a/paddle/phi/kernels/funcs/segmented_array.h b/paddle/phi/kernels/funcs/segmented_array.h
index 0f03dbac591ec745eb55408ad67f267200e74e3a..aa03eb4e9fcd21e1c5ac6901bb7a30d7fd1f895a 100644
--- a/paddle/phi/kernels/funcs/segmented_array.h
+++ b/paddle/phi/kernels/funcs/segmented_array.h
@@ -14,7 +14,7 @@
 
 #pragma once
 
-#include "paddle/phi/kernels/funcs/fast_divmod.h"
+#include "paddle/phi/core/dense_tensor.h"
 
 namespace phi {
 namespace funcs {
@@ -89,12 +89,11 @@ struct ArraySetterBase {
         ctx.GetPlace(),
         num_bytes,
         phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
-    paddle::memory::Copy(ctx.GetPlace(),
-                         allocation->ptr(),
-                         phi::CPUPlace(),
-                         src,
-                         num_bytes,
-                         ctx.stream());
+    phi::backends::gpu::GpuMemcpyAsync(allocation->ptr(),
+                                       src,
+                                       num_bytes,
+                                       phi::gpuMemcpyHostToDevice,
+                                       ctx.stream());
 
     return allocation->ptr();
   }
diff --git a/paddle/phi/kernels/funcs/stack_and_unstack.h b/paddle/phi/kernels/funcs/stack_and_unstack.h
new file mode 100644
index 0000000000000000000000000000000000000000..c516d4892bf629226bed9bb8f93cd3436792d639
--- /dev/null
+++ b/paddle/phi/kernels/funcs/stack_and_unstack.h
@@ -0,0 +1,276 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/kernels/funcs/fast_divmod.h"
+#include "paddle/phi/kernels/funcs/segmented_array.h"
+
+namespace phi {
+namespace funcs {
+
+template <typename T, typename IndexT, typename ArrayT>
+__global__ void StackCudaKernel(ArrayT array,
+                                GeneralDivMod<IndexT> divmoder,
+                                IndexT split_size,
+                                IndexT rows,
+                                IndexT cols,
+                                T* __restrict__ output) {
+  IndexT grid_x = static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x;
+  IndexT grid_x_stride = static_cast<IndexT>(blockDim.x) * gridDim.x;
+  IndexT grid_y_stride = static_cast<IndexT>(blockDim.y) * gridDim.y;
+
+  for (; grid_x < cols; grid_x += grid_x_stride) {
+    IndexT grid_y = static_cast<IndexT>(blockIdx.y) * blockDim.y + threadIdx.y;
+
+    auto divmod_rslt = divmoder.div_mod(grid_x);
+    IndexT split = divmod_rslt[0];       // grid_x / split_size
+    IndexT col_offset = divmod_rslt[1];  // grid_x % split_size
+    const T* input_ptr = array.data[split];
+#pragma unroll
+    for (; grid_y < rows; grid_y += grid_y_stride) {
+      output[grid_y * cols + grid_x] =
+          input_ptr[grid_y * split_size + col_offset];
+    }
+  }
+}
+
+template <typename Context,
+          typename T,
+          typename IndexT,
+          SegmentedArraySize Size>
+void LaunchStackKernel(const Context& ctx,
+                       const IndexT x_col,
+                       const IndexT x_row,
+                       const IndexT out_col,
+                       const std::vector<const DenseTensor*>& x,
+                       DenseTensor* out) {
+  T* out_ptr = ctx.template Alloc<T>(out);
+  auto config = phi::backends::gpu::GetGpuLaunchConfig2D(ctx, out_col, x_row);
+
+  ConstPointerArraySetter<Context, T, Size> setter(ctx, x);
+  GeneralDivMod<IndexT> divmoder(x_col);
+  StackCudaKernel<T, IndexT, decltype(setter.array)>
+      <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
+          setter.array, divmoder, x_col, x_row, out_col, out_ptr);
+}
+
+template <typename T, typename Context>
+void StackRawKernel(const Context& ctx,
+                    const std::vector<const DenseTensor*>& x,
+                    int axis,
+                    DenseTensor* out) {
+  if (axis < 0) axis += (x[0]->dims().size() + 1);
+  int num = static_cast<int>(x.size());
+
+  // Split each x[i] at axis into a matrix of shape [x_row, x_col]; the output
+  // tensor's shape is [x_row, out_col].
+  int64_t x_row = 1;
+  for (int i = 0; i < axis; ++i) {
+    x_row *= x[0]->dims()[i];
+  }
+  int64_t x_col = x[0]->numel() / x_row;
+  int64_t out_col = x_col * num;
+
+  if (out->numel() < std::numeric_limits<int32_t>::max()) {
+    switch (CalcArraySize(num)) {
+      SEGMENTED_ARRAY_KERNEL_HELPER(
+          LaunchStackKernel<Context, T, int32_t, kArraySize>(
+              ctx, x_col, x_row, out_col, x, out));
+    }
+  } else {
+    switch (CalcArraySize(num)) {
+      SEGMENTED_ARRAY_KERNEL_HELPER(
+          LaunchStackKernel<Context, T, int64_t, kArraySize>(
+              ctx, x_col, x_row, out_col, x, out));
+    }
+  }
+}
+
+template <typename T, typename IndexT, typename ArrayT>
+__global__ void UnStackCudaKernel(const T* __restrict__ input,
+                                  IndexT out_row,
+                                  IndexT split_dim,
+                                  IndexT out_col,
+                                  IndexT num_splits,
+                                  GeneralDivMod<IndexT> col_divmoder,
+                                  ArrayT array) {
+  assert(blockDim.y == 1);
+  assert(blockDim.z == 1);
+  // In this case they are equal
+  assert(split_dim % num_splits == 0);
+
+  IndexT numel = out_row * split_dim * out_col;
+  IndexT each_dim_size = split_dim / num_splits;
+  IndexT split_dim_with_out_col = split_dim * out_col;
+
+  IndexT offset = blockIdx.x * blockDim.x + threadIdx.x;
+  if (each_dim_size == 1) {
+    for (; offset < numel; offset += blockDim.x * gridDim.x) {
+      auto col_divmod_rslt = col_divmoder.div_mod(offset);
+
+      IndexT i = offset / split_dim_with_out_col;
+      IndexT j = col_divmod_rslt[0] - i * split_dim;
+      IndexT k = col_divmod_rslt[1];  // offset % out_col
+
+      T* output = array.data[j];
+      if (output) {
+        IndexT output_idx = i * out_col + k;
+        *(output + output_idx) = input[offset];
+      }
+    }
+  } else {
+    for (; offset < numel; offset += blockDim.x * gridDim.x) {
+      auto col_divmod_rslt = col_divmoder.div_mod(offset);
+
+      IndexT i = offset / split_dim_with_out_col;
+      IndexT j = col_divmod_rslt[0] - i * split_dim;
+      IndexT k = col_divmod_rslt[1];  // offset % out_col
+
+      T* output = array.data[j / each_dim_size];
+      if (output) {
+        IndexT output_idx = (i + j % each_dim_size) * out_col + k;
+        *(output + output_idx) = input[offset];
+      }
+    }
+  }
+}
+
+template <typename T, typename IndexT, typename ArrayT>
+__global__ void UnStackCudaKernelForLastDim(const T* __restrict__ in_data,
+                                            const IndexT cols,
+                                            const IndexT rows,
+                                            const IndexT tile_x_num,
+                                            ArrayT array) {
+  constexpr int buffer_size = 512;
+  __shared__ T s_buf[buffer_size];
+
+  for (IndexT tile_x = blockIdx.x; tile_x < tile_x_num; tile_x += gridDim.x) {
+    IndexT row_idx = tile_x * blockDim.x + threadIdx.x;
+    IndexT col_idx = blockIdx.y * blockDim.y + threadIdx.y;
+    int s_idx = threadIdx.y * blockDim.x + threadIdx.x;
+    bool is_valid = (col_idx < cols && row_idx < rows);
+
+    if (is_valid) {
+      T data = in_data[row_idx * cols + col_idx];
+      s_buf[s_idx] = data;
+    }
+    __syncthreads();
+    if (is_valid) {
+      if (array.data[col_idx]) {
+        array.data[col_idx][row_idx] = s_buf[s_idx];
+      }
+    }
+  }
+}
+
+template <typename Context,
+          typename T,
+          typename IndexT,
+          SegmentedArraySize Size>
+void LaunchUnStackKernel(const Context& ctx,
+                         const IndexT out_row,
+                         const IndexT split_dim,
+                         const IndexT out_col,
+                         const IndexT num_splits,
+                         const DenseTensor& x,
+                         std::vector<DenseTensor*>* outs) {
+  // Each tensor in outs should have the same shape.
+  VLOG(6) << "out_row=" << out_row << ", split_dim=" << split_dim
+          << ", out_col=" << out_col << ", num_splits=" << num_splits;
+
+  auto x_ptr = x.data<T>();
+  PointerArraySetter<Context, T, Size> setter(ctx, outs);
+
+  if (out_col == 1) {
+    // For the case axis == (x.dims().size() - 1)
+    constexpr int kThreads = 512;
+    constexpr int kWarpSize = 32;
+    constexpr int kMaxOut = 16;
+
+    int tid_x = 0, tid_y = 0, bid_x = 0, bid_y = 1;
+    if (split_dim < kMaxOut) {
+      tid_y = split_dim;
+      tid_x =
+          std::min(backends::gpu::RoundToNextHighPowOfTwo(out_row, kWarpSize),
+                   kThreads / backends::gpu::RoundToNextHighPowOfTwo(tid_y));
+    } else {
+      tid_y = kMaxOut;
+      tid_x = kWarpSize;
+      bid_y = backends::gpu::DivUp<int>(split_dim, kMaxOut);
+    }
+    int tile_x_num = backends::gpu::DivUp<int>(out_row, tid_x);
+    bid_x = std::min(tile_x_num, backends::gpu::kMultiDimslimit);
+    dim3 blocks(tid_x, tid_y, 1);
+    dim3 grids(bid_x, bid_y, 1);
+
+    UnStackCudaKernelForLastDim<T, IndexT, decltype(setter.array)>
+        <<<grids, blocks, 0, ctx.stream()>>>(
+            x_ptr, split_dim, out_row, tile_x_num, setter.array);
+  } else {
+    GeneralDivMod<IndexT> col_divmoder(out_col);
+    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
+        ctx, out_row * split_dim * out_col);
+
+    UnStackCudaKernel<T, IndexT, decltype(setter.array)>
+        <<<config.block_per_grid.x,
+           config.thread_per_block.x,
+           0,
+           ctx.stream()>>>(x_ptr,
+                           out_row,
+                           split_dim,
+                           out_col,
+                           num_splits,
+                           col_divmoder,
+                           setter.array);
+  }
+}
+
+template <typename T, typename Context>
+void UnStackRawKernel(const Context& ctx,
+                      const DenseTensor& x,
+                      int axis,
+                      std::vector<DenseTensor*>* outs) {
+  auto x_dims = x.dims();
+
+  // Input tensor is split into split_dim tensors along the axis dimension.
+  int64_t split_dim = x_dims[axis];
+
+  // Treat outs[i] as [out_row, out_col], and x as [out_row, split_dim,
+  // out_col].
+  int64_t out_row = 1;
+  for (int i = 0; i < axis; ++i) {
+    out_row *= x_dims[i];
+  }
+
+  int64_t out_col = x.numel() / (split_dim * out_row);
+
+  if (x.numel() < std::numeric_limits<int32_t>::max()) {
+    switch (CalcArraySize(split_dim)) {
+      SEGMENTED_ARRAY_KERNEL_HELPER(
+          LaunchUnStackKernel<Context, T, int32_t, kArraySize>(
+              ctx, out_row, split_dim, out_col, split_dim, x, outs));
+    }
+  } else {
+    switch (CalcArraySize(split_dim)) {
+      SEGMENTED_ARRAY_KERNEL_HELPER(
+          LaunchUnStackKernel<Context, T, int64_t, kArraySize>(
+              ctx, out_row, split_dim, out_col, split_dim, x, outs));
+    }
+  }
+}
+
+}  // namespace funcs
+}  // namespace phi
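Note (illustration only, not part of the patch): the stack path in the new header flattens every input into an [x_row, x_col] matrix and writes it into an [x_row, out_col] output with out_col = x_col * num; GeneralDivMod recovers (input index, column offset) from each output column with a single divide/modulo. Below is a minimal CPU sketch of that same mapping in plain C++, with no Paddle dependencies; the helper name StackReference is made up for this note.

#include <cstdint>
#include <vector>

// CPU reference of StackCudaKernel's index mapping: num inputs, each viewed
// as an [x_row, x_col] matrix, are concatenated column-wise into one
// [x_row, num * x_col] output.
std::vector<float> StackReference(const std::vector<std::vector<float>>& ins,
                                  int64_t x_row,
                                  int64_t x_col) {
  const int64_t num = static_cast<int64_t>(ins.size());
  const int64_t out_col = x_col * num;
  std::vector<float> out(x_row * out_col);
  for (int64_t row = 0; row < x_row; ++row) {
    for (int64_t col = 0; col < out_col; ++col) {
      const int64_t split = col / x_col;       // which input tensor
      const int64_t col_offset = col % x_col;  // column inside that input
      out[row * out_col + col] = ins[split][row * x_col + col_offset];
    }
  }
  return out;
}

The GPU kernel performs exactly this loop nest, with grid-stride loops over rows and columns and the divide/modulo replaced by the precomputed fast divmod.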
#include "paddle/phi/kernels/stack_grad_kernel.h" -#include "paddle/fluid/memory/memory.h" -#include "paddle/phi/backends/gpu/gpu_launch_config.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/funcs/segmented_array.h" +#include "paddle/phi/kernels/funcs/stack_and_unstack.h" namespace phi { -template -__global__ void UnStackCudaKernel(const T* __restrict__ input, - IndexT pre_dim_size, - IndexT split_dim_size, - IndexT suf_dim_size, - IndexT num_split, - ArrayT array) { - assert(blockDim.y == 1); - assert(blockDim.z == 1); - // In this case they are equal - assert(split_dim_size % num_split == 0); - - IndexT size = pre_dim_size * split_dim_size * suf_dim_size; - IndexT each_dim_size = split_dim_size / num_split; - - for (IndexT offset = blockIdx.x * blockDim.x + threadIdx.x; offset < size; - offset += blockDim.x * gridDim.x) { - IndexT i = offset / (split_dim_size * suf_dim_size); - IndexT j = (offset % (split_dim_size * suf_dim_size)) / suf_dim_size; - IndexT k = offset % suf_dim_size; - - T* output = array.data[j / each_dim_size]; - if (output == nullptr) { - return; - } - IndexT output_ind = i * each_dim_size * suf_dim_size + - (j % each_dim_size) * suf_dim_size + k; - *(output + output_ind) = input[offset]; - } -} - -template -__global__ void UnStackCudaKernelForLastDim(const T* __restrict__ in_data, - const IndexT cols, - const IndexT rows, - const IndexT tile_x_num, - ArrayT array) { - constexpr int buffer_size = 512; - __shared__ T s_buf[buffer_size]; - - for (IndexT tile_x = blockIdx.x; tile_x < tile_x_num; tile_x += gridDim.x) { - IndexT row_idx = tile_x * blockDim.x + threadIdx.x; - IndexT col_idx = blockIdx.y * blockDim.y + threadIdx.y; - int s_idx = threadIdx.y * blockDim.x + threadIdx.x; - bool is_valid = (col_idx < cols && row_idx < rows); - - if (is_valid) { - T data = in_data[row_idx * cols + col_idx]; - s_buf[s_idx] = data; - } - __syncthreads(); - if (is_valid) { - if (array.data[col_idx]) { - array.data[col_idx][row_idx] = s_buf[s_idx]; - } - } - } -} - -template -void LaunchUnStackKernel(const Context& ctx, - const IndexT pre_dim, - const IndexT split_dim, - const IndexT suf_dim, - const IndexT num_splits, - const DenseTensor& out_grad, - std::vector* x_grad) { - // each x_grad should have same shape - auto dout_ptr = out_grad.data(); - funcs::PointerArraySetter setter(ctx, x_grad); - - if (suf_dim == 1) { - // For the case axis == (out_grad.dims().size() - 1) - constexpr int kThreads = 512; - constexpr int kWarpSize = 32; - constexpr int kMaxOut = 16; - - int tid_x = 0, tid_y = 0, bid_x = 0, bid_y = 1; - if (split_dim < kMaxOut) { - tid_y = split_dim; - tid_x = - std::min(backends::gpu::RoundToNextHighPowOfTwo(pre_dim, kWarpSize), - kThreads / backends::gpu::RoundToNextHighPowOfTwo(tid_y)); - } else { - tid_y = kMaxOut; - tid_x = kWarpSize; - bid_y = backends::gpu::DivUp(split_dim, kMaxOut); - } - int tile_x_num = backends::gpu::DivUp(pre_dim, tid_x); - bid_x = std::min(tile_x_num, backends::gpu::kMultiDimslimit); - dim3 blocks(tid_x, tid_y, 1); - dim3 grids(bid_x, bid_y, 1); - - UnStackCudaKernelForLastDim - <<>>( - dout_ptr, split_dim, pre_dim, tile_x_num, setter.array); - } else { - auto config = phi::backends::gpu::GetGpuLaunchConfig1D( - ctx, pre_dim * split_dim * suf_dim); - - UnStackCudaKernel - <<>>( - dout_ptr, pre_dim, split_dim, suf_dim, num_splits, setter.array); - } -} - template void StackGradKernel(const Context& ctx, const DenseTensor& out_grad, @@ -144,41 +32,12 @@ void 
       split_dim,
       x_grad.size(),
       phi::errors::InvalidArgument(
-          "Output x_grad size should be equal to the split_dim, but"
-          " received split_dim is:%d x_grad size is:%d.",
+          "Output x_grad's size should be equal to the split_dim, but"
+          " received split_dim is:%d x_grad's size is:%d.",
           split_dim,
           x_grad.size()));
 
-  auto dout_dims = out_grad.dims();
-  int64_t dout_pre = 1;
-  for (int i = 0; i < axis; ++i) {
-    dout_pre *= dout_dims[i];
-  }
-  int64_t dout_suf = out_grad.numel() / (split_dim * dout_pre);
-
-  if (out_grad.numel() < std::numeric_limits<int32_t>::max()) {
-    switch (funcs::CalcArraySize(split_dim)) {
-      SEGMENTED_ARRAY_KERNEL_HELPER(
-          LaunchUnStackKernel<Context, T, int32_t, kArraySize>(ctx,
-                                                               dout_pre,
-                                                               split_dim,
-                                                               dout_suf,
-                                                               split_dim,
-                                                               out_grad,
-                                                               &x_grad));
-    }
-  } else {
-    switch (funcs::CalcArraySize(split_dim)) {
-      SEGMENTED_ARRAY_KERNEL_HELPER(
-          LaunchUnStackKernel<Context, T, int64_t, kArraySize>(ctx,
-                                                               dout_pre,
-                                                               split_dim,
-                                                               dout_suf,
-                                                               split_dim,
-                                                               out_grad,
-                                                               &x_grad));
-    }
-  }
+  funcs::UnStackRawKernel<T, Context>(ctx, out_grad, axis, &x_grad);
 }
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/gpu/stack_kernel.cu b/paddle/phi/kernels/gpu/stack_kernel.cu
index a50396e7c97293e8a3b667f02544bcff93592c5f..e1d7d4e6f389c895198a00ece36c7ecb1da1a5b2 100644
--- a/paddle/phi/kernels/gpu/stack_kernel.cu
+++ b/paddle/phi/kernels/gpu/stack_kernel.cu
@@ -13,89 +13,19 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/stack_kernel.h"
-#include "paddle/fluid/memory/memory.h"
-#include "paddle/phi/backends/gpu/gpu_launch_config.h"
-#include "paddle/phi/core/dense_tensor.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/funcs/segmented_array.h"
+#include "paddle/phi/kernels/funcs/stack_and_unstack.h"
 
 namespace phi {
 
-template <typename T, typename IndexT, typename ArrayT>
-__global__ void StackCUDAKernel(ArrayT array,
-                                funcs::GeneralDivMod<IndexT> divmoder,
-                                IndexT split_size,
-                                IndexT rows,
-                                IndexT cols,
-                                T* __restrict__ output) {
-  IndexT grid_x = static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x;
-  IndexT grid_x_stride = static_cast<IndexT>(blockDim.x) * gridDim.x;
-  IndexT grid_y_stride = static_cast<IndexT>(blockDim.y) * gridDim.y;
-
-  for (; grid_x < cols; grid_x += grid_x_stride) {
-    IndexT grid_y = static_cast<IndexT>(blockIdx.y) * blockDim.y + threadIdx.y;
-
-    auto divmod_rslt = divmoder.div_mod(grid_x);
-    IndexT split = divmod_rslt[0];       // grid_x / split_size
-    IndexT col_offset = divmod_rslt[1];  // grid_x % split_size
-    const T* input_ptr = array.data[split];
-#pragma unroll
-    for (; grid_y < rows; grid_y += grid_y_stride) {
-      output[grid_y * cols + grid_x] =
-          input_ptr[grid_y * split_size + col_offset];
-    }
-  }
-}
-
-template <typename Context,
-          typename T,
-          typename IndexT,
-          funcs::SegmentedArraySize Size>
-void LaunchStackKernel(const Context& ctx,
-                       const IndexT x_col,
-                       const IndexT x_row,
-                       const IndexT out_col,
-                       const std::vector<const DenseTensor*>& x,
-                       DenseTensor* out) {
-  T* out_ptr = ctx.template Alloc<T>(out);
-  auto config = phi::backends::gpu::GetGpuLaunchConfig2D(ctx, out_col, x_row);
-
-  funcs::ConstPointerArraySetter<Context, T, Size> setter(ctx, x);
-  funcs::GeneralDivMod<IndexT> divmoder(x_col);
-  StackCUDAKernel<T, IndexT, decltype(setter.array)>
-      <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
-          setter.array, divmoder, x_col, x_row, out_col, out_ptr);
-}
-
 template <typename T, typename Context>
 void StackKernel(const Context& ctx,
                  const std::vector<const DenseTensor*>& x,
                  int axis,
                  DenseTensor* out) {
-  if (axis < 0) axis += (x[0]->dims().size() + 1);
-  int num = static_cast<int>(x.size());
-
-  // Split x dim from axis to matrix
-  int64_t x_row = 1;
-  for (int i = 0; i < axis; ++i) {
-    x_row *= x[0]->dims()[i];
-  }
-  int64_t x_col = x[0]->numel() / x_row;
-  int64_t out_col = x_col * num;
-
-  if (out->numel() < std::numeric_limits<int32_t>::max()) {
-    switch (funcs::CalcArraySize(num)) {
-      SEGMENTED_ARRAY_KERNEL_HELPER(
-          LaunchStackKernel<Context, T, int32_t, kArraySize>(
-              ctx, x_col, x_row, out_col, x, out));
-    }
-  } else {
-    switch (funcs::CalcArraySize(num)) {
-      SEGMENTED_ARRAY_KERNEL_HELPER(
-          LaunchStackKernel<Context, T, int64_t, kArraySize>(
-              ctx, x_col, x_row, out_col, x, out));
-    }
-  }
+  funcs::StackRawKernel<T, Context>(ctx, x, axis, out);
 }
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/gpu/unstack_grad_kernel.cu b/paddle/phi/kernels/gpu/unstack_grad_kernel.cu
index b7c349de0df32b0a0687264c7504e600306068fd..88bf155606c1b3b5437deac226752941bcbe6906 100644
--- a/paddle/phi/kernels/gpu/unstack_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/unstack_grad_kernel.cu
@@ -16,7 +16,19 @@
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/unstack_grad_kernel_impl.h"
+#include "paddle/phi/kernels/funcs/stack_and_unstack.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void UnStackGradKernel(const Context& ctx,
+                       const std::vector<const DenseTensor*>& out_grad,
+                       int axis,
+                       DenseTensor* x_grad) {
+  funcs::StackRawKernel<T, Context>(ctx, out_grad, axis, x_grad);
+}
+
+}  // namespace phi
 
 PD_REGISTER_KERNEL(unstack_grad,
                    GPU,
@@ -26,4 +38,5 @@ PD_REGISTER_KERNEL(unstack_grad,
                    double,
                    int64_t,
                    int,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/unstack_kernel.cu b/paddle/phi/kernels/gpu/unstack_kernel.cu
index f147f4c0f0edfd4af38070f3c221cdd70aef7c9c..4331322bdc202c2c3641fe7653d88377050baaea 100644
--- a/paddle/phi/kernels/gpu/unstack_kernel.cu
+++ b/paddle/phi/kernels/gpu/unstack_kernel.cu
@@ -16,7 +16,33 @@
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/unstack_kernel_impl.h"
+#include "paddle/phi/kernels/funcs/stack_and_unstack.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void UnStackKernel(const Context& ctx,
+                   const DenseTensor& x,
+                   int axis,
+                   int num,
+                   std::vector<DenseTensor*> outs) {
+  if (x.numel() == 0) return;
+  if (axis < 0) axis += x.dims().size();
+
+  int64_t split_dim = x.dims()[axis];
+  PADDLE_ENFORCE_EQ(
+      split_dim,
+      outs.size(),
+      phi::errors::InvalidArgument(
+          "Output outs's size should be equal to the split_dim, but"
+          " received split_dim is:%d outs's size is:%d.",
+          split_dim,
+          outs.size()));
+
+  funcs::UnStackRawKernel<T, Context>(ctx, x, axis, &outs);
+}
+
+}  // namespace phi
 
 PD_REGISTER_KERNEL(unstack,
                    GPU,
@@ -26,4 +52,5 @@ PD_REGISTER_KERNEL(unstack,
                    double,
                    int64_t,
                    int,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/stack_grad_kernel.h b/paddle/phi/kernels/stack_grad_kernel.h
index 32451e606f26ab207c932e39846fc3ad91960bc7..1e8f2d68399f87b240fd530defe8fd18c1b92b04 100644
--- a/paddle/phi/kernels/stack_grad_kernel.h
+++ b/paddle/phi/kernels/stack_grad_kernel.h
@@ -20,7 +20,7 @@ namespace phi {
 
 template <typename T, typename Context>
 void StackGradKernel(const Context& dev_ctx,
-                     const DenseTensor& out,
+                     const DenseTensor& out_grad,
                      int axis,
                      std::vector<DenseTensor*> x_grad);
 
diff --git a/paddle/phi/kernels/unstack_grad_kernel.h b/paddle/phi/kernels/unstack_grad_kernel.h
index de0e3004d8038dea3114920739ee099546fbcb68..cb50f5ec9240c0de0f4c2f44672d9ca8c7818019 100644
--- a/paddle/phi/kernels/unstack_grad_kernel.h
+++ b/paddle/phi/kernels/unstack_grad_kernel.h
@@ -20,7 +20,7 @@ namespace phi {
 
 template <typename T, typename Context>
 void UnStackGradKernel(const Context& dev_ctx,
-                       const std::vector<const DenseTensor*>& x,
+                       const std::vector<const DenseTensor*>& out_grad,
                        int axis,
                        DenseTensor* x_grad);
 
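Note (illustration only, not part of the patch): after this change, StackKernel and UnStackGradKernel both reduce to funcs::StackRawKernel, while StackGradKernel and UnStackKernel both reduce to funcs::UnStackRawKernel, which views the input as an [out_row, split_dim, out_col] block and scatters slice j of the middle dimension into outs[j]. A minimal CPU sketch of that inverse mapping, again with made-up names and no Paddle dependencies:

#include <cstdint>
#include <vector>

// CPU reference of UnStackRawKernel's index mapping: x is viewed as
// [out_row, split_dim, out_col]; slice j along the middle dimension becomes
// outs[j], an [out_row, out_col] matrix.
void UnStackReference(const std::vector<float>& x,
                      int64_t out_row,
                      int64_t split_dim,
                      int64_t out_col,
                      std::vector<std::vector<float>>* outs) {
  outs->assign(split_dim, std::vector<float>(out_row * out_col));
  for (int64_t i = 0; i < out_row; ++i) {
    for (int64_t j = 0; j < split_dim; ++j) {
      for (int64_t k = 0; k < out_col; ++k) {
        (*outs)[j][i * out_col + k] = x[(i * split_dim + j) * out_col + k];
      }
    }
  }
}

The two mappings are inverses of each other: stacking the slices produced by UnStackReference along the same axis reproduces the original block, which is why the unstack gradient can simply call the stack helper and the stack gradient can call the unstack helper.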