diff --git a/paddle/phi/kernels/funcs/segmented_array.h b/paddle/phi/kernels/funcs/segmented_array.h
new file mode 100644
index 0000000000000000000000000000000000000000..51059c2ece7c91aa1b1dad9264e432b3f76676a9
--- /dev/null
+++ b/paddle/phi/kernels/funcs/segmented_array.h
@@ -0,0 +1,230 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/funcs/fast_divmod.h"
+
+namespace phi {
+namespace funcs {
+
+template <typename IndexT>
+struct GeneralDivMod {
+ public:
+  explicit GeneralDivMod(IndexT d) { divmoder = phi::funcs::FastDivMod(d); }
+  __device__ inline phi::funcs::FastDivMod::DivModT div_mod(IndexT val) {
+    return divmoder.Divmod(val);
+  }
+
+  phi::funcs::FastDivMod divmoder;
+};
+
+template <>
+struct GeneralDivMod<int64_t> {
+ public:
+  using DivModT = phi::AlignedVector<int64_t, 2>;
+
+  explicit GeneralDivMod(int64_t d) { divisor = d; }
+  __device__ inline DivModT div_mod(int64_t val) {
+    DivModT data;
+    data[0] = val / divisor;
+    data[1] = val - data[0] * divisor;
+    return data;
+  }
+
+  int64_t divisor;
+};
+
+#if !defined(_WIN32)
+#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
+#else
+#define PADDLE_ALIGN(x)
+#endif
+
+enum class SegmentedArraySize {
+  kVariableLength = 0,
+  kFixed4 = 4,
+  kFixed8 = 8,
+  kFixed16 = 16,
+  kFixed32 = 32,
+  kFixed64 = 64,
+};
+
+template <typename T, SegmentedArraySize Size>
+struct PADDLE_ALIGN(256) ConstPointerArray {
+ public:
+  const T* data[static_cast<int>(Size)];
+
+  void Set(const std::vector<const T*>& ptrs, const T** dev_ptr = nullptr) {
+    for (auto i = 0; i < ptrs.size(); ++i) {
+      data[i] = ptrs[i];
+    }
+  }
+};
+
+template <typename T>
+struct PADDLE_ALIGN(256)
+    ConstPointerArray<T, SegmentedArraySize::kVariableLength> {
+ public:
+  const T** data{nullptr};
+
+  void Set(const std::vector<const T*>& ptrs, const T** dev_ptr = nullptr) {
+    data = dev_ptr;
+  }
+};
+
+template <typename T, SegmentedArraySize Size>
+struct PADDLE_ALIGN(256) PointerArray {
+ public:
+  T* data[static_cast<int>(Size)];
+
+  void Set(const std::vector<T*>& ptrs, T** dev_ptr = nullptr) {
+    for (auto i = 0; i < ptrs.size(); ++i) {
+      data[i] = ptrs[i];
+    }
+  }
+};
+
+template <typename T>
+struct PADDLE_ALIGN(256) PointerArray<T, SegmentedArraySize::kVariableLength> {
+ public:
+  T** data{nullptr};
+
+  void Set(const std::vector<T*>& ptrs, T** dev_ptr = nullptr) {
+    data = dev_ptr;
+  }
+};
+
+#undef PADDLE_ALIGN
+
+template <typename Context>
+struct ArraySetterBase {
+ protected:
+  void* AllocAndCopy(const Context& ctx, void* src, size_t num_bytes) {
+    allocation = paddle::memory::Alloc(
+        ctx.GetPlace(),
+        num_bytes,
+        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
+    paddle::memory::Copy(ctx.GetPlace(),
+                         allocation->ptr(),
+                         phi::CPUPlace(),
+                         src,
+                         num_bytes,
+                         ctx.stream());
+    return allocation->ptr();
+  }
+
+  phi::Allocator::AllocationPtr allocation{nullptr};
+};
+
+template <typename Context, typename T, SegmentedArraySize Size>
+struct ConstPointerArraySetter : public ArraySetterBase<Context> {
+ public:
+  ConstPointerArray<T, Size> array;
+
+  ConstPointerArraySetter(const Context& ctx,
+                          const std::vector<const DenseTensor*>& t) {
+    ptrs.resize(t.size());
+    for (int i = 0; i < t.size(); ++i) {
+      ptrs[i] = t[i]->data<T>();
+    }
+
+    const T** dev_ptr = nullptr;
+    if (Size == SegmentedArraySize::kVariableLength) {
+      size_t num_bytes = t.size() * sizeof(T*);
+      dev_ptr =
+          reinterpret_cast<const T**>(ArraySetterBase<Context>::AllocAndCopy(
+              ctx, reinterpret_cast<void*>(ptrs.data()), num_bytes));
+    }
+
+    array.Set(ptrs, dev_ptr);
+  }
+
+ private:
+  std::vector<const T*> ptrs;
+};
+
+template <typename Context, typename T, SegmentedArraySize Size>
+struct PointerArraySetter : public ArraySetterBase<Context> {
+ public:
+  PointerArray<T, Size> array;
+
+  PointerArraySetter(const Context& ctx, std::vector<DenseTensor*>* t) {
+    ptrs.resize(t->size());
+    for (int i = 0; i < t->size(); ++i) {
+      if (t->at(i) && (t->at(i)->numel() > 0)) {
+        ptrs[i] = ctx.template Alloc<T>(t->at(i));
+      } else {
+        ptrs[i] = nullptr;
+      }
+    }
+
+    T** dev_ptr = nullptr;
+    if (Size == SegmentedArraySize::kVariableLength) {
+      size_t num_bytes = t->size() * sizeof(T*);
+      dev_ptr = reinterpret_cast<T**>(ArraySetterBase<Context>::AllocAndCopy(
+          ctx, reinterpret_cast<void*>(ptrs.data()), num_bytes));
+    }
+
+    array.Set(ptrs, dev_ptr);
+  }
+
+ private:
+  std::vector<T*> ptrs;
+};
+
+inline SegmentedArraySize CalcArraySize(int n) {
+  if (n <= 4) {
+    return SegmentedArraySize::kFixed4;
+  } else if (n <= 8) {
+    return SegmentedArraySize::kFixed8;
+  } else if (n <= 16) {
+    return SegmentedArraySize::kFixed16;
+  } else if (n <= 32) {
+    return SegmentedArraySize::kFixed32;
+  } else if (n <= 64) {
+    return SegmentedArraySize::kFixed64;
+  } else {
+    return SegmentedArraySize::kVariableLength;
+  }
+}
+}  // namespace funcs
+
+#define _SEGMENTED_ARRAY_KERNEL_CASE(size, ...) \
+  case (size): {                                \
+    constexpr auto kArraySize = (size);         \
+    __VA_ARGS__;                                \
+  } break
+
+#define _SEGMENTED_ARRAY_KERNEL_DEFAULT(size, ...) \
+  default: {                                       \
+    constexpr auto kArraySize = (size);            \
+    __VA_ARGS__;                                   \
+  } break
+
+#define SEGMENTED_ARRAY_KERNEL_HELPER(...)                                    \
+  _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed4,            \
+                               ##__VA_ARGS__);                                \
+  _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed8,            \
+                               ##__VA_ARGS__);                                \
+  _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed16,           \
+                               ##__VA_ARGS__);                                \
+  _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed32,           \
+                               ##__VA_ARGS__);                                \
+  _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed64,           \
+                               ##__VA_ARGS__);                                \
+  _SEGMENTED_ARRAY_KERNEL_DEFAULT(funcs::SegmentedArraySize::kVariableLength, \
+                                  ##__VA_ARGS__);
+
+}  // namespace phi
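Note: a minimal usage sketch of how the pieces in the new header are meant to compose. CalcArraySize() picks a fixed-capacity bucket at runtime, SEGMENTED_ARRAY_KERNEL_HELPER expands to the case labels that bind that choice to the compile-time constant kArraySize, and a setter materializes the pointer array for the chosen bucket. The names DemoReadKernel, LaunchDemoRead and DemoReadFirstElements are hypothetical and not part of this patch; the sketch assumes it lives in namespace phi so that the macro's funcs:: references resolve.

// Hypothetical sketch, assuming the usual kernel .cu includes.
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/segmented_array.h"

namespace phi {

template <typename T, typename ArrayT>
__global__ void DemoReadKernel(ArrayT array, int num, T* out) {
  // Copy element 0 of every segment; nullptr entries are skipped.
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < num && array.data[i] != nullptr) {
    out[i] = array.data[i][0];
  }
}

template <typename Context, typename T, funcs::SegmentedArraySize Size>
void LaunchDemoRead(const Context& ctx,
                    const std::vector<const DenseTensor*>& ins,
                    T* out) {  // out: device buffer with ins.size() elements
  // Fixed-size buckets pass the pointers by value as a kernel argument;
  // kVariableLength copies them to device memory inside the setter.
  funcs::ConstPointerArraySetter<Context, T, Size> setter(ctx, ins);
  int num = static_cast<int>(ins.size());
  DemoReadKernel<T, decltype(setter.array)>
      <<<1, 128, 0, ctx.stream()>>>(setter.array, num, out);
}

template <typename T, typename Context>
void DemoReadFirstElements(const Context& ctx,
                           const std::vector<const DenseTensor*>& ins,
                           T* out) {
  // CalcArraySize picks a bucket at runtime; the macro turns it into the
  // compile-time constant kArraySize used as the template argument.
  switch (funcs::CalcArraySize(static_cast<int>(ins.size()))) {
    SEGMENTED_ARRAY_KERNEL_HELPER(
        LaunchDemoRead<Context, T, kArraySize>(ctx, ins, out));
  }
}

}  // namespace phi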
diff --git a/paddle/phi/kernels/gpu/stack_grad_kernel.cu b/paddle/phi/kernels/gpu/stack_grad_kernel.cu
index ccbe2885334a8d50cd4b546094502391ccd8b96c..572ed4a361b4e6730a743c811d5292dd2728d50e 100644
--- a/paddle/phi/kernels/gpu/stack_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/stack_grad_kernel.cu
@@ -16,16 +16,17 @@
 #include "paddle/fluid/memory/memory.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/segmented_array.h"
 
 namespace phi {
 
-template <typename T, typename IndexT>
-__global__ void UnStackHelperCUDAKernel(const T* __restrict__ input,
-                                        int pre_dim_size,
-                                        int split_dim_size,
-                                        int suf_dim_size,
-                                        int num_split,
-                                        T** output_ptrs) {
+template <typename T, typename IndexT, typename ArrayT>
+__global__ void UnStackCudaKernel(const T* __restrict__ input,
+                                  IndexT pre_dim_size,
+                                  IndexT split_dim_size,
+                                  IndexT suf_dim_size,
+                                  IndexT num_split,
+                                  ArrayT array) {
   assert(blockDim.y == 1);
   assert(blockDim.z == 1);
   // In this case they are equal
@@ -40,7 +41,7 @@ __global__ void UnStackHelperCUDAKernel(const T* __restrict__ input,
     IndexT j = (offset % (split_dim_size * suf_dim_size)) / suf_dim_size;
     IndexT k = offset % suf_dim_size;
 
-    T* output = output_ptrs[j / each_dim_size];
+    T* output = array.data[j / each_dim_size];
     if (output == nullptr) {
       return;
     }
@@ -50,12 +51,12 @@ __global__ void UnStackHelperCUDAKernel(const T* __restrict__ input,
   }
 }
 
-template <typename T, typename IndexT>
-__global__ void StackGradKernelForLastDim(const T* __restrict__ in_data,
-                                          const IndexT cols,
-                                          const IndexT rows,
-                                          const IndexT tile_x_num,
-                                          T** out_datas) {
+template <typename T, typename IndexT, typename ArrayT>
+__global__ void UnStackCudaKernelForLastDim(const T* __restrict__ in_data,
+                                            const IndexT cols,
+                                            const IndexT rows,
+                                            const IndexT tile_x_num,
+                                            ArrayT array) {
   constexpr int buffer_size = 512;
   __shared__ T s_buf[buffer_size];
 
@@ -71,112 +72,112 @@ __global__ void StackGradKernelForLastDim(const T* __restrict__ in_data,
     }
     __syncthreads();
     if (is_valid) {
-      if (out_datas[col_idx] != nullptr) {
-        out_datas[col_idx][row_idx] = s_buf[s_idx];
+      if (array.data[col_idx]) {
+        array.data[col_idx][row_idx] = s_buf[s_idx];
       }
     }
   }
 }
 
-template <typename T, typename IndexT, typename Context>
-void LaunchStackGradCUDAKernel(const Context& ctx,
-                               const DenseTensor& out,
-                               std::vector<DenseTensor*>* x_grad_ptr,
-                               const int axis,
-                               const int64_t dy_pre) {
-  auto x_grad = *x_grad_ptr;
-  int out_num = out.dims()[axis];
-  PADDLE_ENFORCE_EQ(
-      out_num,
-      x_grad.size(),
-      phi::errors::InvalidArgument(
-          "Output x_grad size shall be equal to output num, but output num "
-          "received in stack_grad op is:%d, and x_grad size is:%d.",
-          out_num,
-          x_grad.size()));
-  std::vector<T*> outputs(out_num);
-  for (size_t j = 0; j < out_num; ++j) {
-    if (x_grad[j] == nullptr || x_grad[j]->numel() == 0UL) {
-      outputs[j] = nullptr;
-    } else {
-      outputs[j] = ctx.template Alloc<T>(x_grad[j]);
-    }
-  }
-
-  auto tmp_out_data = paddle::memory::Alloc(
-      ctx.GetPlace(),
-      out_num * sizeof(T*),
-      phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
-  paddle::memory::Copy(ctx.GetPlace(),
-                       tmp_out_data->ptr(),
-                       phi::CPUPlace(),
-                       reinterpret_cast<void*>(outputs.data()),
-                       out_num * sizeof(T*),
-                       ctx.stream());
-
-  if (axis == (out.dims().size() - 1)) {
+template <typename Context,
+          typename T,
+          typename IndexT,
+          funcs::SegmentedArraySize Size>
+void LaunchUnStackKernel(const Context& ctx,
+                         const IndexT pre_dim,
+                         const IndexT split_dim,
+                         const IndexT suf_dim,
+                         const IndexT num_splits,
+                         const DenseTensor& out_grad,
+                         std::vector<DenseTensor*>* x_grad) {
+  // each x_grad should have same shape
+  auto dout_ptr = out_grad.data<T>();
+  funcs::PointerArraySetter<Context, T, Size> setter(ctx, x_grad);
+
+  if (suf_dim == 1) {
+    // For the case axis == (out_grad.dims().size() - 1)
     constexpr int kThreads = 512;
     constexpr int kWarpSize = 32;
     constexpr int kMaxOut = 16;
-    int tid_x = 0, tid_y = 0, bid_x = 0, bid_y = 1;
-    bool is_small_num = out_num < kMaxOut;
 
-    if (is_small_num) {
-      tid_y = out_num;
+    int tid_x = 0, tid_y = 0, bid_x = 0, bid_y = 1;
+    if (split_dim < kMaxOut) {
+      tid_y = split_dim;
       tid_x =
-          std::min(backends::gpu::RoundToNextHighPowOfTwo(dy_pre, kWarpSize),
+          std::min(backends::gpu::RoundToNextHighPowOfTwo(pre_dim, kWarpSize),
                    kThreads / backends::gpu::RoundToNextHighPowOfTwo(tid_y));
     } else {
       tid_y = kMaxOut;
       tid_x = kWarpSize;
-      bid_y = backends::gpu::DivUp(out_num, kMaxOut);
+      bid_y = backends::gpu::DivUp(split_dim, kMaxOut);
     }
-    int tile_x_num = backends::gpu::DivUp(dy_pre, tid_x);
+    int tile_x_num = backends::gpu::DivUp(pre_dim, tid_x);
     bid_x = std::min(tile_x_num, backends::gpu::kMultiDimslimit);
     dim3 blocks(tid_x, tid_y, 1);
     dim3 grids(bid_x, bid_y, 1);
 
-    StackGradKernelForLastDim<T, IndexT><<<grids, blocks, 0, ctx.stream()>>>(
-        out.data<T>(),
-        out_num,
-        dy_pre,
-        tile_x_num,
-        reinterpret_cast<T**>(tmp_out_data->ptr()));
+    UnStackCudaKernelForLastDim<T, IndexT, decltype(setter.array)>
+        <<<grids, blocks, 0, ctx.stream()>>>(
+            dout_ptr, split_dim, pre_dim, tile_x_num, setter.array);
   } else {
-    int dy_suf = out.numel() / (out_num * dy_pre);
-    auto config =
-        backends::gpu::GetGpuLaunchConfig1D(ctx, dy_pre * out_num * dy_suf);
-
-    UnStackHelperCUDAKernel<T, IndexT>
-        <<<config.block_per_grid.x, config.thread_per_block.x, 0, ctx.stream()>>>(
-            out.data<T>(),
-            dy_pre,
-            out_num,
-            dy_suf,
-            out_num,
-            reinterpret_cast<T**>(tmp_out_data->ptr()));
+    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
+        ctx, pre_dim * split_dim * suf_dim);
+
+    UnStackCudaKernel<T, IndexT, decltype(setter.array)>
+        <<<config.block_per_grid.x, config.thread_per_block.x, 0, ctx.stream()>>>(
+            dout_ptr, pre_dim, split_dim, suf_dim, num_splits, setter.array);
   }
 }
 
 template <typename T, typename Context>
-void StackGradKernel(const Context& dev_ctx,
-                     const DenseTensor& out,
+void StackGradKernel(const Context& ctx,
+                     const DenseTensor& out_grad,
                      int axis,
                      std::vector<DenseTensor*> x_grad) {
-  const auto& dy_dims = out.dims();
-  int actual_axis = axis < 0 ? axis + dy_dims.size() : axis;
-  bool use_int32 = out.numel() < std::numeric_limits<int32_t>::max();
+  if (axis < 0) axis += out_grad.dims().size();
+
+  int64_t split_dim = out_grad.dims()[axis];
+  PADDLE_ENFORCE_EQ(
+      split_dim,
+      x_grad.size(),
+      phi::errors::InvalidArgument(
+          "Output x_grad size should be equal to the split_dim, but"
+          " received split_dim is:%d x_grad size is:%d.",
+          split_dim,
+          x_grad.size()));
 
-  int64_t dy_pre = 1;
-  for (int i = 0; i < actual_axis; ++i) {
-    dy_pre *= dy_dims[i];
+  auto dout_dims = out_grad.dims();
+  int64_t dout_pre = 1;
+  for (int i = 0; i < axis; ++i) {
+    dout_pre *= dout_dims[i];
   }
-  if (use_int32) {
-    LaunchStackGradCUDAKernel<T, int32_t, Context>(
-        dev_ctx, out, &x_grad, actual_axis, dy_pre);
+  int64_t dout_suf = out_grad.numel() / (split_dim * dout_pre);
+
+  if (out_grad.numel() < std::numeric_limits<int32_t>::max()) {
+    switch (funcs::CalcArraySize(split_dim)) {
+      SEGMENTED_ARRAY_KERNEL_HELPER(
+          LaunchUnStackKernel<Context, T, int32_t, kArraySize>(ctx,
+                                                               dout_pre,
+                                                               split_dim,
+                                                               dout_suf,
+                                                               split_dim,
+                                                               out_grad,
+                                                               &x_grad));
+    }
   } else {
-    LaunchStackGradCUDAKernel<T, int64_t, Context>(
-        dev_ctx, out, &x_grad, actual_axis, dy_pre);
+    switch (funcs::CalcArraySize(split_dim)) {
+      SEGMENTED_ARRAY_KERNEL_HELPER(
+          LaunchUnStackKernel<Context, T, int64_t, kArraySize>(ctx,
+                                                               dout_pre,
+                                                               split_dim,
+                                                               dout_suf,
+                                                               split_dim,
+                                                               out_grad,
+                                                               &x_grad));
+    }
   }
 }
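For orientation, a small worked example of the decomposition that drives the launcher above (the shape is illustrative only): with out_grad of shape [6, 3, 4] and axis = 1, StackGradKernel computes dout_pre = 6, split_dim = 3 and dout_suf = 4, so each of the three x_grad tensors receives a [6, 4] slice. When axis is the last dimension, dout_suf collapses to 1 and LaunchUnStackKernel takes the shared-memory UnStackCudaKernelForLastDim path instead.

// Illustrative only: the pre/split/suf arithmetic for shape [6, 3, 4], axis = 1.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t dims[] = {6, 3, 4};
  const int axis = 1;
  const int64_t numel = 6 * 3 * 4;  // 72

  int64_t dout_pre = 1;
  for (int i = 0; i < axis; ++i) dout_pre *= dims[i];        // 6
  const int64_t split_dim = dims[axis];                      // 3
  const int64_t dout_suf = numel / (split_dim * dout_pre);   // 72 / 18 = 4

  assert(dout_pre == 6 && split_dim == 3 && dout_suf == 4);
  return 0;
}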
diff --git a/paddle/phi/kernels/gpu/stack_kernel.cu b/paddle/phi/kernels/gpu/stack_kernel.cu
index 92730d15f25a48f0eb884383ae73d955ca8afaac..a50396e7c97293e8a3b667f02544bcff93592c5f 100644
--- a/paddle/phi/kernels/gpu/stack_kernel.cu
+++ b/paddle/phi/kernels/gpu/stack_kernel.cu
@@ -15,86 +15,15 @@
 #include "paddle/phi/kernels/stack_kernel.h"
 #include "paddle/fluid/memory/memory.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/funcs/fast_divmod.h"
+#include "paddle/phi/kernels/funcs/segmented_array.h"
 
 namespace phi {
 
-template <typename IndexT>
-struct DivmodWarpper {
- public:
-  void SetDivisor(IndexT divisor) {
-    divmoder = phi::funcs::FastDivMod(divisor);
-  }
-  __device__ inline phi::funcs::FastDivMod::DivModT div_mod(IndexT val) {
-    return divmoder.Divmod(val);
-  }
-
- private:
-  phi::funcs::FastDivMod divmoder;
-};
-
-template <>
-struct DivmodWarpper<int64_t> {
- public:
-  using DivModT = phi::AlignedVector<int64_t, 2>;
-
-  void SetDivisor(int64_t divisor) { dividen_ = divisor; }
-  __device__ inline DivModT div_mod(int64_t val) {
-    DivModT data;
-    data[0] = val / dividen_;
-    data[1] = val - data[0] * dividen_;
-    return data;
-  }
-
- private:
-  int64_t dividen_;
-};
-
-template <typename T, typename IndexT, int Size>
-struct PointerArray : public DivmodWarpper<IndexT> {
- public:
-  const T* data[Size];
-  PointerArray(const std::vector<const DenseTensor*>& x,
-               int num,
-               IndexT divisor) {
-    this->SetDivisor(divisor);
-    for (auto i = 0; i < num; ++i) {
-      data[i] = x[i]->data<T>();
-    }
-  }
-};
-
-template <typename Context, typename T, typename IndexT>
-struct PointerToPointer : public DivmodWarpper<IndexT> {
- public:
-  T** data{nullptr};
-  PointerToPointer(const Context& ctx,
-                   const std::vector<const DenseTensor*>& x,
-                   IndexT num,
-                   IndexT divisor,
-                   paddle::memory::AllocationPtr* dev_ins_ptr) {
-    this->SetDivisor(divisor);
-    std::vector<const T*> x_datas(num);
-    for (int i = 0; i < num; ++i) {
-      x_datas[i] = x[i]->data<T>();
-    }
-    *dev_ins_ptr = paddle::memory::Alloc(
-        ctx.GetPlace(),
-        num * sizeof(T*),
-        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
-    paddle::memory::Copy(ctx.GetPlace(),
-                         (*dev_ins_ptr)->ptr(),
-                         phi::CPUPlace(),
-                         reinterpret_cast<void*>(x_datas.data()),
-                         num * sizeof(T*),
-                         ctx.stream());
-    data = reinterpret_cast<T**>((*dev_ins_ptr)->ptr());
-  }
-};
-
-template <typename T, typename IndexT, typename WrapT>
-__global__ void StackCUDAKernel(WrapT input_warpper,
+template <typename T, typename IndexT, typename ArrayT>
+__global__ void StackCUDAKernel(ArrayT array,
+                                funcs::GeneralDivMod<IndexT> divmoder,
                                 IndexT split_size,
                                 IndexT rows,
                                 IndexT cols,
@@ -106,85 +35,69 @@ __global__ void StackCUDAKernel(WrapT input_warpper,
   for (; grid_x < cols; grid_x += grid_x_stride) {
     IndexT grid_y = static_cast<IndexT>(blockIdx.y) * blockDim.y + threadIdx.y;
 
-    auto divmod_rslt = input_warpper.div_mod(grid_x);
-    const T* input_ptr = input_warpper.data[divmod_rslt[0]];
+    auto divmod_rslt = divmoder.div_mod(grid_x);
+    IndexT split = divmod_rslt[0];       // grid_x / split_size
+    IndexT col_offset = divmod_rslt[1];  // grid_x % split_size
+    const T* input_ptr = array.data[split];
 #pragma unroll
     for (; grid_y < rows; grid_y += grid_y_stride) {
       output[grid_y * cols + grid_x] =
-          input_ptr[grid_y * split_size + divmod_rslt[1]];
+          input_ptr[grid_y * split_size + col_offset];
     }
   }
 }
 
-template <typename T, typename IndexT, typename Context>
-void LaunchStackCUDAKernelWithIndexType(
-    const Context& ctx,
-    const IndexT x_col,
-    const IndexT x_row,
-    const IndexT out_col,
-    const phi::backends::gpu::GpuLaunchConfig& cfg,
-    const std::vector<const DenseTensor*>& x,
-    T* dst_data) {
-  int num = static_cast<int>(x.size());
-#define IMPL_STACK_CUDA_KERNEL_CASE(size_, ...)              \
-  case size_: {                                              \
-    PointerArray<T, IndexT, size_> ptr_array(x, num, x_col); \
-    __VA_ARGS__;                                             \
-  } break;
-
-#define IMPL_STACK_CUDA_KERNEL_HELPER(...)         \
-  IMPL_STACK_CUDA_KERNEL_CASE(4, ##__VA_ARGS__);   \
-  IMPL_STACK_CUDA_KERNEL_CASE(8, ##__VA_ARGS__);   \
-  IMPL_STACK_CUDA_KERNEL_CASE(16, ##__VA_ARGS__);  \
-  IMPL_STACK_CUDA_KERNEL_CASE(32, ##__VA_ARGS__);  \
-  IMPL_STACK_CUDA_KERNEL_CASE(64, ##__VA_ARGS__);  \
-  IMPL_STACK_CUDA_KERNEL_CASE(128, ##__VA_ARGS__);
-
-  switch (phi::backends::gpu::RoundToNextHighPowOfTwo(num, 4)) {
-    IMPL_STACK_CUDA_KERNEL_HELPER(
-        StackCUDAKernel<T, IndexT, decltype(ptr_array)>
-        <<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
-            ptr_array, x_col, x_row, out_col, dst_data));
-    default: {
-      paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
-      PointerToPointer<Context, T, IndexT> ptr_array(
-          ctx, x, num, x_col, &dev_ins_ptr);
-      StackCUDAKernel<T, IndexT, decltype(ptr_array)>
-          <<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
-              ptr_array, x_col, x_row, out_col, dst_data);
-    }
-  }
-#undef IMPL_STACK_CUDA_KERNEL_HELPER
-#undef IMPL_STACK_CUDA_KERNEL_CASE
+template <typename Context,
+          typename T,
+          typename IndexT,
+          funcs::SegmentedArraySize Size>
+void LaunchStackKernel(const Context& ctx,
+                       const IndexT x_col,
+                       const IndexT x_row,
+                       const IndexT out_col,
+                       const std::vector<const DenseTensor*>& x,
+                       DenseTensor* out) {
+  T* out_ptr = ctx.template Alloc<T>(out);
+  auto config = phi::backends::gpu::GetGpuLaunchConfig2D(ctx, out_col, x_row);
+
+  funcs::ConstPointerArraySetter<Context, T, Size> setter(ctx, x);
+  funcs::GeneralDivMod<IndexT> divmoder(x_col);
+  StackCUDAKernel<T, IndexT, decltype(setter.array)>
+      <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
+          setter.array, divmoder, x_col, x_row, out_col, out_ptr);
 }
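As a side note, here is a host-side sketch (illustrative, not part of the patch) of the column mapping StackCUDAKernel performs: every output column grid_x is decomposed into (which input, column inside that input) by dividing by x_col, which is exactly the pair GeneralDivMod::div_mod returns; the 32-bit index path uses FastDivMod while the int64_t specialization falls back to plain division.

// Illustrative check of the divmod mapping with x_col = 5 and 3 stacked
// inputs, so out_col = 15.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t x_col = 5;
  const int64_t out_col = 15;
  for (int64_t grid_x = 0; grid_x < out_col; ++grid_x) {
    const int64_t split = grid_x / x_col;       // index of the stacked input
    const int64_t col_offset = grid_x % x_col;  // column inside that input
    // StackCUDAKernel then copies:
    //   output[row * out_col + grid_x] = input[split][row * x_col + col_offset]
    assert(split * x_col + col_offset == grid_x);
  }
  return 0;
}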
 
 template <typename T, typename Context>
-void StackKernel(const Context& dev_ctx,
+void StackKernel(const Context& ctx,
                  const std::vector<const DenseTensor*>& x,
                  int axis,
                  DenseTensor* out) {
   if (axis < 0) axis += (x[0]->dims().size() + 1);
   int num = static_cast<int>(x.size());
-  T* dst_data = dev_ctx.template Alloc<T>(out);
 
   // Split x dim from axis to matrix
-  int64_t x_row = 1, x_col = 1;
+  int64_t x_row = 1;
   for (int i = 0; i < axis; ++i) {
     x_row *= x[0]->dims()[i];
   }
-  x_col = x[0]->numel() / x_row;
+  int64_t x_col = x[0]->numel() / x_row;
   int64_t out_col = x_col * num;
-  auto config =
-      phi::backends::gpu::GetGpuLaunchConfig2D(dev_ctx, out_col, x_row);
 
   if (out->numel() < std::numeric_limits<int32_t>::max()) {
-    LaunchStackCUDAKernelWithIndexType<T, int32_t, Context>(
-        dev_ctx, x_col, x_row, out_col, config, x, dst_data);
+    switch (funcs::CalcArraySize(num)) {
+      SEGMENTED_ARRAY_KERNEL_HELPER(
+          LaunchStackKernel<Context, T, int32_t, kArraySize>(
+              ctx, x_col, x_row, out_col, x, out));
+    }
   } else {
-    LaunchStackCUDAKernelWithIndexType<T, int64_t, Context>(
-        dev_ctx, x_col, x_row, out_col, config, x, dst_data);
+    switch (funcs::CalcArraySize(num)) {
+      SEGMENTED_ARRAY_KERNEL_HELPER(
+          LaunchStackKernel<Context, T, int64_t, kArraySize>(
+              ctx, x_col, x_row, out_col, x, out));
+    }
   }
 }
+
 }  // namespace phi
 
 PD_REGISTER_KERNEL(stack,