Unverified commit 3586e856, authored by Yiqun Liu, committed by GitHub

Unify the gpu implementation of stack and unstack to reuse the optimization. (#49748)

* Unify the gpu implementation of stack and unstack to reuse the optimization.

* Optimize the cuda implementation of unstack.

* Use GpuMemcpyAsync instead of memory::Copy.

* Fix error of calculating the index.

* Use FastDivMod to further improve the performance of unstack (a minimal sketch of the underlying trick follows below).
Parent e7deae21
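The FastDivMod bullet refers to replacing per-element integer `/` and `%` on a divisor that is only known at run time with a precomputed multiply-and-shift. Below is a minimal, host-side sketch of that classic trick (the round-up variant of Granlund and Montgomery); the struct name, file name, and main() are illustrative and are not copied from paddle/phi/kernels/funcs/fast_divmod.h, which may differ in detail. A device-side variant would compute the high 32 bits of the product with __umulhi.

// fastdiv_sketch.cc -- illustrative only; compile with: g++ -std=c++17 fastdiv_sketch.cc
#include <cstdint>
#include <cstdio>

// Precompute a (multiplier, shift) pair for a fixed divisor d (1 <= d < 2^31)
// so that n / d and n % d can be evaluated for any n < 2^31 with one
// 32x32->64 multiply and a shift instead of a hardware divide.
struct FastDivModSketch {
  uint32_t divisor;
  uint32_t shift;
  uint32_t multiplier;

  explicit FastDivModSketch(uint32_t d) : divisor(d), shift(0) {
    while ((1u << shift) < d) ++shift;  // smallest shift with 2^shift >= d
    uint64_t one = 1;
    multiplier = static_cast<uint32_t>(
        ((one << 32) * ((one << shift) - d)) / d + 1);
  }

  uint32_t Div(uint32_t n) const {  // n / divisor
    uint32_t hi = static_cast<uint32_t>(
        (static_cast<uint64_t>(n) * multiplier) >> 32);  // __umulhi on device
    return (hi + n) >> shift;
  }

  uint32_t Mod(uint32_t n) const {  // n % divisor
    return n - Div(n) * divisor;
  }
};

int main() {
  FastDivModSketch dm(48);  // e.g. a split_size of 48 columns
  for (uint32_t n = 0; n < 1000000; ++n) {
    if (dm.Div(n) != n / 48 || dm.Mod(n) != n % 48) {
      std::printf("mismatch at n = %u\n", n);
      return 1;
    }
  }
  std::printf("all checks passed\n");
  return 0;
}

In the diff below, the same idea appears as GeneralDivMod<IndexT>::div_mod(), which returns the quotient/remainder pair in a single call inside the stack and unstack kernels.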
@@ -14,7 +14,7 @@
#pragma once
#include "paddle/phi/kernels/funcs/fast_divmod.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
namespace funcs {
@@ -89,11 +89,10 @@ struct ArraySetterBase {
ctx.GetPlace(),
num_bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
paddle::memory::Copy(ctx.GetPlace(),
allocation->ptr(),
phi::CPUPlace(),
phi::backends::gpu::GpuMemcpyAsync(allocation->ptr(),
src,
num_bytes,
phi::gpuMemcpyHostToDevice,
ctx.stream());
return allocation->ptr();
}
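The hunk above replaces paddle::memory::Copy with a direct phi::backends::gpu::GpuMemcpyAsync when uploading the per-tensor pointer array to the device. A minimal sketch of that host-to-device pattern with the plain CUDA runtime is shown below; GpuMemcpyAsync is assumed to be a thin wrapper with the same argument order (dst, src, bytes, kind, stream), and the function and buffer names here are illustrative, not Paddle APIs.

#include <cuda_runtime.h>

#include <vector>

// Upload a small host-side array of device pointers so a kernel can index it.
// With pageable host memory, cudaMemcpyAsync returns only after the source has
// been staged for DMA, so host_ptrs may go out of scope right afterwards.
// (Error checking omitted for brevity.)
inline void UploadPointerArray(const std::vector<const float*>& host_ptrs,
                               const float** dev_array,  // device buffer of host_ptrs.size() slots
                               cudaStream_t stream) {
  cudaMemcpyAsync(dev_array,
                  host_ptrs.data(),
                  host_ptrs.size() * sizeof(const float*),
                  cudaMemcpyHostToDevice,
                  stream);
}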
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/fast_divmod.h"
#include "paddle/phi/kernels/funcs/segmented_array.h"
namespace phi {
namespace funcs {
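// StackCudaKernel views the stacked output as a [rows, cols] matrix with
// cols == num_inputs * split_size. Each output column is mapped back to its
// source tensor and to the column inside that tensor with a single
// precomputed div/mod (GeneralDivMod), so the grid-stride loops contain no
// hardware integer division.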
template <typename T, typename IndexT, typename ArrayT>
__global__ void StackCudaKernel(ArrayT array,
GeneralDivMod<IndexT> divmoder,
IndexT split_size,
IndexT rows,
IndexT cols,
T* __restrict__ output) {
IndexT grid_x = static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x;
IndexT grid_x_stride = static_cast<IndexT>(blockDim.x) * gridDim.x;
IndexT grid_y_stride = static_cast<IndexT>(blockDim.y) * gridDim.y;
for (; grid_x < cols; grid_x += grid_x_stride) {
IndexT grid_y = static_cast<IndexT>(blockIdx.y) * blockDim.y + threadIdx.y;
auto divmod_rslt = divmoder.div_mod(grid_x);
IndexT split = divmod_rslt[0]; // grid_x / split_size
IndexT col_offset = divmod_rslt[1]; // grid_x % split_size
const T* input_ptr = array.data[split];
#pragma unroll
for (; grid_y < rows; grid_y += grid_y_stride) {
output[grid_y * cols + grid_x] =
input_ptr[grid_y * split_size + col_offset];
}
}
}
template <typename Context,
typename T,
typename IndexT,
SegmentedArraySize Size>
void LaunchStackKernel(const Context& ctx,
const IndexT x_col,
const IndexT x_row,
const IndexT out_col,
const std::vector<const DenseTensor*>& x,
DenseTensor* out) {
T* out_ptr = ctx.template Alloc<T>(out);
auto config = phi::backends::gpu::GetGpuLaunchConfig2D(ctx, out_col, x_row);
ConstPointerArraySetter<Context, T, Size> setter(ctx, x);
GeneralDivMod<IndexT> divmoder(x_col);
StackCudaKernel<T, IndexT, decltype(setter.array)>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
setter.array, divmoder, x_col, x_row, out_col, out_ptr);
}
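// StackRawKernel flattens every input to a [x_row, x_col] matrix around the
// stack axis and writes a [x_row, x_col * num] output. It dispatches on the
// total element count (32-bit vs. 64-bit indexing) and on the number of
// inputs (SegmentedArraySize); both the forward stack kernel and the
// unstack_grad kernel call it.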
template <typename T, typename Context>
void StackRawKernel(const Context& ctx,
const std::vector<const DenseTensor*>& x,
int axis,
DenseTensor* out) {
if (axis < 0) axis += (x[0]->dims().size() + 1);
int num = static_cast<int>(x.size());
// Flatten each input at axis into a matrix of shape [x_row, x_col]; the
// output is then a matrix of shape [x_row, out_col].
int64_t x_row = 1;
for (int i = 0; i < axis; ++i) {
x_row *= x[0]->dims()[i];
}
int64_t x_col = x[0]->numel() / x_row;
int64_t out_col = x_col * num;
if (out->numel() < std::numeric_limits<int32_t>::max()) {
switch (CalcArraySize(num)) {
SEGMENTED_ARRAY_KERNEL_HELPER(
LaunchStackKernel<Context, T, int32_t, kArraySize>(
ctx, x_col, x_row, out_col, x, out));
}
} else {
switch (CalcArraySize(num)) {
SEGMENTED_ARRAY_KERNEL_HELPER(
LaunchStackKernel<Context, T, int64_t, kArraySize>(
ctx, x_col, x_row, out_col, x, out));
}
}
}
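// UnStackCudaKernel treats the input as an [out_row, split_dim, out_col]
// tensor and decomposes every linear offset into (i, j, k): k comes from the
// precomputed div/mod on out_col, i from one division by
// split_dim * out_col, and j from their difference. Element (i, j, k) is then
// written to output j / each_dim_size; in the unstack and stack_grad callers
// num_splits == split_dim, so each_dim_size == 1 and the branch without the
// extra division on j is taken.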
template <typename T, typename IndexT, typename ArrayT>
__global__ void UnStackCudaKernel(const T* __restrict__ input,
IndexT out_row,
IndexT split_dim,
IndexT out_col,
IndexT num_splits,
GeneralDivMod<IndexT> col_divmoder,
ArrayT array) {
assert(blockDim.y == 1);
assert(blockDim.z == 1);
// In the unstack and stack_grad callers, num_splits equals split_dim, so
// each_dim_size is 1 in that case.
assert(split_dim % num_splits == 0);
IndexT numel = out_row * split_dim * out_col;
IndexT each_dim_size = split_dim / num_splits;
IndexT split_dim_with_out_col = split_dim * out_col;
IndexT offset = blockIdx.x * blockDim.x + threadIdx.x;
if (each_dim_size == 1) {
for (; offset < numel; offset += blockDim.x * gridDim.x) {
auto col_divmod_rslt = col_divmoder.div_mod(offset);
IndexT i = offset / split_dim_with_out_col;
IndexT j = col_divmod_rslt[0] - i * split_dim;
IndexT k = col_divmod_rslt[1]; // offset % out_col
T* output = array.data[j];
if (output) {
IndexT output_idx = i * out_col + k;
*(output + output_idx) = input[offset];
}
}
} else {
for (; offset < numel; offset += blockDim.x * gridDim.x) {
auto col_divmod_rslt = col_divmoder.div_mod(offset);
IndexT i = offset / split_dim_with_out_col;
IndexT j = col_divmod_rslt[0] - i * split_dim;
IndexT k = col_divmod_rslt[1]; // offset % out_col
T* output = array.data[j / each_dim_size];
if (output) {
IndexT output_idx = (i + j % each_dim_size) * out_col + k;
*(output + output_idx) = input[offset];
}
}
}
}
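// UnStackCudaKernelForLastDim handles the out_col == 1 case (unstacking along
// the last dimension). The input is viewed as a [rows, cols] matrix with
// cols == split_dim, and each column becomes one output tensor: a thread
// loads one element, stages it in shared memory, and stores it to
// array.data[col_idx][row_idx], so the stores to every output are contiguous
// along threadIdx.x.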
template <typename T, typename IndexT, typename ArrayT>
__global__ void UnStackCudaKernelForLastDim(const T* __restrict__ in_data,
const IndexT cols,
const IndexT rows,
const IndexT tile_x_num,
ArrayT array) {
constexpr int buffer_size = 512;
__shared__ T s_buf[buffer_size];
for (IndexT tile_x = blockIdx.x; tile_x < tile_x_num; tile_x += gridDim.x) {
IndexT row_idx = tile_x * blockDim.x + threadIdx.x;
IndexT col_idx = blockIdx.y * blockDim.y + threadIdx.y;
int s_idx = threadIdx.y * blockDim.x + threadIdx.x;
bool is_valid = (col_idx < cols && row_idx < rows);
if (is_valid) {
T data = in_data[row_idx * cols + col_idx];
s_buf[s_idx] = data;
}
__syncthreads();
if (is_valid) {
if (array.data[col_idx]) {
array.data[col_idx][row_idx] = s_buf[s_idx];
}
}
}
}
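// LaunchUnStackKernel chooses between the two kernels above: when the split
// axis is the last dimension (out_col == 1) it launches the shared-memory
// kernel with a 2-D block shaped from split_dim (at most kMaxOut outputs per
// block); otherwise it launches the 1-D grid-stride kernel driven by the
// precomputed div/mod on out_col.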
template <typename Context,
typename T,
typename IndexT,
SegmentedArraySize Size>
void LaunchUnStackKernel(const Context& ctx,
const IndexT out_row,
const IndexT split_dim,
const IndexT out_col,
const IndexT num_splits,
const DenseTensor& x,
std::vector<DenseTensor*>* outs) {
// Each tensor in outs should have the same shape.
VLOG(6) << "out_row=" << out_row << ", split_dim=" << split_dim
<< ", out_col=" << out_col << ", num_splits=" << num_splits;
auto x_ptr = x.data<T>();
PointerArraySetter<Context, T, Size> setter(ctx, outs);
if (out_col == 1) {
// For the case axis == (x.dims().size() - 1)
constexpr int kThreads = 512;
constexpr int kWarpSize = 32;
constexpr int kMaxOut = 16;
int tid_x = 0, tid_y = 0, bid_x = 0, bid_y = 1;
if (split_dim < kMaxOut) {
tid_y = split_dim;
tid_x =
std::min(backends::gpu::RoundToNextHighPowOfTwo(out_row, kWarpSize),
kThreads / backends::gpu::RoundToNextHighPowOfTwo(tid_y));
} else {
tid_y = kMaxOut;
tid_x = kWarpSize;
bid_y = backends::gpu::DivUp<int>(split_dim, kMaxOut);
}
int tile_x_num = backends::gpu::DivUp<int>(out_row, tid_x);
bid_x = std::min(tile_x_num, backends::gpu::kMultiDimslimit);
dim3 blocks(tid_x, tid_y, 1);
dim3 grids(bid_x, bid_y, 1);
UnStackCudaKernelForLastDim<T, IndexT, decltype(setter.array)>
<<<grids, blocks, 0, ctx.stream()>>>(
x_ptr, split_dim, out_row, tile_x_num, setter.array);
} else {
GeneralDivMod<IndexT> col_divmoder(out_col);
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
ctx, out_row * split_dim * out_col);
UnStackCudaKernel<T, IndexT, decltype(setter.array)>
<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
ctx.stream()>>>(x_ptr,
out_row,
split_dim,
out_col,
num_splits,
col_divmoder,
setter.array);
}
}
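// UnStackRawKernel flattens x to [out_row, split_dim, out_col] around the
// split axis and, like StackRawKernel, dispatches on x.numel() (32-bit vs.
// 64-bit indexing) and on split_dim (SegmentedArraySize). Both the unstack
// kernel and the stack_grad kernel reuse it, which is the unification this
// commit performs.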
template <typename T, typename Context>
void UnStackRawKernel(const Context& ctx,
const DenseTensor& x,
int axis,
std::vector<DenseTensor*>* outs) {
auto x_dims = x.dims();
// The input tensor is split into split_dim tensors along dimension axis.
int64_t split_dim = x_dims[axis];
// Treat outs[i] as [out_row, out_col], and x as [out_row, split_dim,
// out_col].
int64_t out_row = 1;
for (int i = 0; i < axis; ++i) {
out_row *= x_dims[i];
}
int64_t out_col = x.numel() / (split_dim * out_row);
if (x.numel() < std::numeric_limits<int32_t>::max()) {
switch (CalcArraySize(split_dim)) {
SEGMENTED_ARRAY_KERNEL_HELPER(
LaunchUnStackKernel<Context, T, int32_t, kArraySize>(
ctx, out_row, split_dim, out_col, split_dim, x, outs));
}
} else {
switch (CalcArraySize(split_dim)) {
SEGMENTED_ARRAY_KERNEL_HELPER(
LaunchUnStackKernel<Context, T, int64_t, kArraySize>(
ctx, out_row, split_dim, out_col, split_dim, x, outs));
}
}
}
} // namespace funcs
} // namespace phi
@@ -13,125 +13,13 @@
// limitations under the License.
#include "paddle/phi/kernels/stack_grad_kernel.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/segmented_array.h"
#include "paddle/phi/kernels/funcs/stack_and_unstack.h"
namespace phi {
template <typename T, typename IndexT, typename ArrayT>
__global__ void UnStackCudaKernel(const T* __restrict__ input,
IndexT pre_dim_size,
IndexT split_dim_size,
IndexT suf_dim_size,
IndexT num_split,
ArrayT array) {
assert(blockDim.y == 1);
assert(blockDim.z == 1);
// In this case they are equal
assert(split_dim_size % num_split == 0);
IndexT size = pre_dim_size * split_dim_size * suf_dim_size;
IndexT each_dim_size = split_dim_size / num_split;
for (IndexT offset = blockIdx.x * blockDim.x + threadIdx.x; offset < size;
offset += blockDim.x * gridDim.x) {
IndexT i = offset / (split_dim_size * suf_dim_size);
IndexT j = (offset % (split_dim_size * suf_dim_size)) / suf_dim_size;
IndexT k = offset % suf_dim_size;
T* output = array.data[j / each_dim_size];
if (output == nullptr) {
return;
}
IndexT output_ind = i * each_dim_size * suf_dim_size +
(j % each_dim_size) * suf_dim_size + k;
*(output + output_ind) = input[offset];
}
}
template <typename T, typename IndexT, typename ArrayT>
__global__ void UnStackCudaKernelForLastDim(const T* __restrict__ in_data,
const IndexT cols,
const IndexT rows,
const IndexT tile_x_num,
ArrayT array) {
constexpr int buffer_size = 512;
__shared__ T s_buf[buffer_size];
for (IndexT tile_x = blockIdx.x; tile_x < tile_x_num; tile_x += gridDim.x) {
IndexT row_idx = tile_x * blockDim.x + threadIdx.x;
IndexT col_idx = blockIdx.y * blockDim.y + threadIdx.y;
int s_idx = threadIdx.y * blockDim.x + threadIdx.x;
bool is_valid = (col_idx < cols && row_idx < rows);
if (is_valid) {
T data = in_data[row_idx * cols + col_idx];
s_buf[s_idx] = data;
}
__syncthreads();
if (is_valid) {
if (array.data[col_idx]) {
array.data[col_idx][row_idx] = s_buf[s_idx];
}
}
}
}
template <typename Context,
typename T,
typename IndexT,
funcs::SegmentedArraySize Size>
void LaunchUnStackKernel(const Context& ctx,
const IndexT pre_dim,
const IndexT split_dim,
const IndexT suf_dim,
const IndexT num_splits,
const DenseTensor& out_grad,
std::vector<DenseTensor*>* x_grad) {
// each x_grad should have same shape
auto dout_ptr = out_grad.data<T>();
funcs::PointerArraySetter<Context, T, Size> setter(ctx, x_grad);
if (suf_dim == 1) {
// For the case axis == (out_grad.dims().size() - 1)
constexpr int kThreads = 512;
constexpr int kWarpSize = 32;
constexpr int kMaxOut = 16;
int tid_x = 0, tid_y = 0, bid_x = 0, bid_y = 1;
if (split_dim < kMaxOut) {
tid_y = split_dim;
tid_x =
std::min(backends::gpu::RoundToNextHighPowOfTwo(pre_dim, kWarpSize),
kThreads / backends::gpu::RoundToNextHighPowOfTwo(tid_y));
} else {
tid_y = kMaxOut;
tid_x = kWarpSize;
bid_y = backends::gpu::DivUp<int>(split_dim, kMaxOut);
}
int tile_x_num = backends::gpu::DivUp<int>(pre_dim, tid_x);
bid_x = std::min(tile_x_num, backends::gpu::kMultiDimslimit);
dim3 blocks(tid_x, tid_y, 1);
dim3 grids(bid_x, bid_y, 1);
UnStackCudaKernelForLastDim<T, IndexT, decltype(setter.array)>
<<<grids, blocks, 0, ctx.stream()>>>(
dout_ptr, split_dim, pre_dim, tile_x_num, setter.array);
} else {
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
ctx, pre_dim * split_dim * suf_dim);
UnStackCudaKernel<T, IndexT, decltype(setter.array)>
<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
ctx.stream()>>>(
dout_ptr, pre_dim, split_dim, suf_dim, num_splits, setter.array);
}
}
template <typename T, typename Context>
void StackGradKernel(const Context& ctx,
const DenseTensor& out_grad,
@@ -144,41 +32,12 @@ void StackGradKernel(const Context& ctx,
split_dim,
x_grad.size(),
phi::errors::InvalidArgument(
"Output x_grad size should be equal to the split_dim, but"
" received split_dim is:%d x_grad size is:%d.",
"Output x_grad's size should be equal to the split_dim, but"
" received split_dim is:%d x_grad's size is:%d.",
split_dim,
x_grad.size()));
auto dout_dims = out_grad.dims();
int64_t dout_pre = 1;
for (int i = 0; i < axis; ++i) {
dout_pre *= dout_dims[i];
}
int64_t dout_suf = out_grad.numel() / (split_dim * dout_pre);
if (out_grad.numel() < std::numeric_limits<int32_t>::max()) {
switch (funcs::CalcArraySize(split_dim)) {
SEGMENTED_ARRAY_KERNEL_HELPER(
LaunchUnStackKernel<Context, T, int32_t, kArraySize>(ctx,
dout_pre,
split_dim,
dout_suf,
split_dim,
out_grad,
&x_grad));
}
} else {
switch (funcs::CalcArraySize(split_dim)) {
SEGMENTED_ARRAY_KERNEL_HELPER(
LaunchUnStackKernel<Context, T, int64_t, kArraySize>(ctx,
dout_pre,
split_dim,
dout_suf,
split_dim,
out_grad,
&x_grad));
}
}
funcs::UnStackRawKernel<T, Context>(ctx, out_grad, axis, &x_grad);
}
} // namespace phi
@@ -13,89 +13,19 @@
// limitations under the License.
#include "paddle/phi/kernels/stack_kernel.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/segmented_array.h"
#include "paddle/phi/kernels/funcs/stack_and_unstack.h"
namespace phi {
template <typename T, typename IndexT, typename ArrayT>
__global__ void StackCUDAKernel(ArrayT array,
funcs::GeneralDivMod<IndexT> divmoder,
IndexT split_size,
IndexT rows,
IndexT cols,
T* __restrict__ output) {
IndexT grid_x = static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x;
IndexT grid_x_stride = static_cast<IndexT>(blockDim.x) * gridDim.x;
IndexT grid_y_stride = static_cast<IndexT>(blockDim.y) * gridDim.y;
for (; grid_x < cols; grid_x += grid_x_stride) {
IndexT grid_y = static_cast<IndexT>(blockIdx.y) * blockDim.y + threadIdx.y;
auto divmod_rslt = divmoder.div_mod(grid_x);
IndexT split = divmod_rslt[0]; // grid_x / split_size
IndexT col_offset = divmod_rslt[1]; // grid_x % split_size
const T* input_ptr = array.data[split];
#pragma unroll
for (; grid_y < rows; grid_y += grid_y_stride) {
output[grid_y * cols + grid_x] =
input_ptr[grid_y * split_size + col_offset];
}
}
}
template <typename Context,
typename T,
typename IndexT,
funcs::SegmentedArraySize Size>
void LaunchStackKernel(const Context& ctx,
const IndexT x_col,
const IndexT x_row,
const IndexT out_col,
const std::vector<const DenseTensor*>& x,
DenseTensor* out) {
T* out_ptr = ctx.template Alloc<T>(out);
auto config = phi::backends::gpu::GetGpuLaunchConfig2D(ctx, out_col, x_row);
funcs::ConstPointerArraySetter<Context, T, Size> setter(ctx, x);
funcs::GeneralDivMod<IndexT> divmoder(x_col);
StackCUDAKernel<T, IndexT, decltype(setter.array)>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
setter.array, divmoder, x_col, x_row, out_col, out_ptr);
}
template <typename T, typename Context>
void StackKernel(const Context& ctx,
const std::vector<const DenseTensor*>& x,
int axis,
DenseTensor* out) {
if (axis < 0) axis += (x[0]->dims().size() + 1);
int num = static_cast<int>(x.size());
// Split x dim from axis to matrix
int64_t x_row = 1;
for (int i = 0; i < axis; ++i) {
x_row *= x[0]->dims()[i];
}
int64_t x_col = x[0]->numel() / x_row;
int64_t out_col = x_col * num;
if (out->numel() < std::numeric_limits<int32_t>::max()) {
switch (funcs::CalcArraySize(num)) {
SEGMENTED_ARRAY_KERNEL_HELPER(
LaunchStackKernel<Context, T, int32_t, kArraySize>(
ctx, x_col, x_row, out_col, x, out));
}
} else {
switch (funcs::CalcArraySize(num)) {
SEGMENTED_ARRAY_KERNEL_HELPER(
LaunchStackKernel<Context, T, int64_t, kArraySize>(
ctx, x_col, x_row, out_col, x, out));
}
}
funcs::StackRawKernel<T, Context>(ctx, x, axis, out);
}
} // namespace phi
@@ -16,7 +16,19 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/unstack_grad_kernel_impl.h"
#include "paddle/phi/kernels/funcs/stack_and_unstack.h"
namespace phi {
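// The gradient of unstack is exactly a stack of the incoming gradients along
// the same axis, so UnStackGradKernel can delegate to funcs::StackRawKernel.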
template <typename T, typename Context>
void UnStackGradKernel(const Context& ctx,
const std::vector<const DenseTensor*>& out_grad,
int axis,
DenseTensor* x_grad) {
funcs::StackRawKernel<T, Context>(ctx, out_grad, axis, x_grad);
}
} // namespace phi
PD_REGISTER_KERNEL(unstack_grad,
GPU,
@@ -26,4 +38,5 @@ PD_REGISTER_KERNEL(unstack_grad,
double,
int64_t,
int,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
@@ -16,7 +16,33 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/unstack_kernel_impl.h"
#include "paddle/phi/kernels/funcs/stack_and_unstack.h"
namespace phi {
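// unstack reuses funcs::UnStackRawKernel (the same routine that backs
// stack_grad) after validating that outs.size() matches the split dimension.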
template <typename T, typename Context>
void UnStackKernel(const Context& ctx,
const DenseTensor& x,
int axis,
int num,
std::vector<DenseTensor*> outs) {
if (x.numel() == 0) return;
if (axis < 0) axis += x.dims().size();
int64_t split_dim = x.dims()[axis];
PADDLE_ENFORCE_EQ(
split_dim,
outs.size(),
phi::errors::InvalidArgument(
"Output outs's size should be equal to the split_dim, but"
" received split_dim is:%d outs's size is:%d.",
split_dim,
outs.size()));
funcs::UnStackRawKernel<T, Context>(ctx, x, axis, &outs);
}
} // namespace phi
PD_REGISTER_KERNEL(unstack,
GPU,
@@ -26,4 +52,5 @@ PD_REGISTER_KERNEL(unstack,
double,
int64_t,
int,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
@@ -20,7 +20,7 @@ namespace phi {
template <typename T, typename Context>
void StackGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& out_grad,
int axis,
std::vector<DenseTensor*> x_grad);
@@ -20,7 +20,7 @@ namespace phi {
template <typename T, typename Context>
void UnStackGradKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<const DenseTensor*>& out_grad,
int axis,
DenseTensor* x_grad);