Unverified commit 5ac8c040, authored by zhoutianzi666 and committed by GitHub

[Paddle Inference] Fix transfer_layout when input size is too big (#53881)

* fix transfer_layout when input size is too big
* do not add TransferLayoutKernelGPU
* add int64 and add check
Parent 934d8b89
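The fix targets the CUDA grid limit: gridDim.y is capped at the device's maxGridSize[1] (typically 65535), so a logical grid of ceil(m / 32) row tiles can exceed it for very large inputs. The patch folds the row-tile index into gridDim.x via a swizzle factor, widens the shape arguments to int64_t, and rejects inputs whose element count would overflow int. Below is a minimal host-side sketch of the remapping; TileGrid and MakePhysicalGrid are illustrative placeholder names, not Paddle APIs, and the authoritative logic is in the diff that follows.

// Sketch of the grid-swizzle remapping used by this patch.
// TileGrid and MakePhysicalGrid are placeholder names, not Paddle APIs.
#include <algorithm>
#include <cstdint>

struct TileGrid {
  unsigned x, y, z;
};

// Map a logical (n-tiles, m-tiles, batch) grid to a physical grid whose
// y extent stays within max_grid_y; the kernel later recovers the logical
// row/column tile from blockIdx and the swizzle factor.
inline TileGrid MakePhysicalGrid(
    int64_t m, int64_t n, int64_t batch, int max_grid_y, int* swizzle) {
  const int64_t tiles_n = (n + 31) / 32;
  const int64_t tiles_m = (m + 31) / 32;
  *swizzle = static_cast<int>(
      std::max<int64_t>((tiles_m + max_grid_y - 1) / max_grid_y, 2));
  return TileGrid{static_cast<unsigned>(tiles_n * *swizzle),
                  static_cast<unsigned>((tiles_m + *swizzle - 1) / *swizzle),
                  static_cast<unsigned>(batch)};
}

// Inside the kernel, each block then recovers its logical 32x32 tile origin:
//   mi0 = (blockIdx.y * swizzle + blockIdx.x % swizzle) * 32;
//   ni0 = (blockIdx.x / swizzle) * 32;

Keeping swizzle at a minimum of 2 mirrors the patch's default, so gridDim.y is roughly halved even when the logical grid already fits the device limit.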
@@ -31,8 +31,12 @@ namespace funcs {
// Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
// reserved. SPDX-License-Identifier: BSD-3-Clause
template <typename T>
__global__ void batch_transpose_kernel(
T* output, const T* input, const int batch, const int M, const int N) {
__global__ void batch_transpose_kernel(T* output,
const T* input,
const int batch,
const int M,
const int N,
int swizzle) {
const int num = M * N;
// "+1" to avoid smem bank conflict
__shared__ T shbuf[32 * (32 + 1)];
@@ -40,8 +44,8 @@ __global__ void batch_transpose_kernel(
const int32_t wid = tid / 32;
const int32_t lid = tid % 32;
const int32_t batch_i = blockIdx.z;
const int32_t mi0 = blockIdx.y * 32;
const int32_t ni0 = blockIdx.x * 32;
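  // With swizzling, blockIdx.x carries both the column-tile index and a slice
  // of the row-tile index, so large M no longer requires gridDim.y alone
  // (capped at maxGridSize[1]) to cover every row tile.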
const int32_t mi0 = (blockIdx.y * swizzle + blockIdx.x % swizzle) * 32;
const int32_t ni0 = blockIdx.x / swizzle * 32;
const size_t input_idx = batch_i * num + (mi0 + wid) * N + ni0;
const T* A = input + input_idx;
@@ -87,19 +91,55 @@ __global__ void batch_transpose_kernel(
}
template <typename T>
void BatchTranspose(T* output, const T* input, int batch, int m, int n) {
dim3 grid((n + 31) / 32, (m + 31) / 32, batch);
void BatchTranspose(T* output,
const T* input,
int64_t batch,
int64_t m,
int64_t n,
const phi::GPUContext* dev_ctx) {
int64_t device_id = dev_ctx->GetPlace().GetDeviceId();
const auto& prop = phi::backends::gpu::GetDeviceProperties(device_id);
int max_grid_y = prop.maxGridSize[1];
int64_t input_num = batch * m * n;
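  // The transpose kernel still does part of its index math in 32-bit ints,
  // so reject tensors whose total element count does not fit in int.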
if (input_num >= std::numeric_limits<int>::max()) {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported input size, batch: %ld,m: %ld, n: %ld", batch, m, n));
}
dim3 logical_grid((n + 31) / 32, (m + 31) / 32, batch);
dim3 block(32, 8);
batch_transpose_kernel<<<grid, block>>>(output, input, batch, m, n);
  // swizzle defaults to 2 and grows when logical_grid.y exceeds max_grid_y.
int swizzle = (logical_grid.y + max_grid_y - 1) / max_grid_y;
swizzle = std::max(swizzle, 2);
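  // Fold the oversized y dimension into x: grid.x grows by the swizzle factor
  // while grid.y shrinks to stay within the device's maxGridSize[1].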
dim3 physical_grid(logical_grid.x * swizzle,
(logical_grid.y + swizzle - 1) / swizzle,
batch);
batch_transpose_kernel<<<physical_grid, block>>>(
output, input, batch, m, n, swizzle);
}
using float16 = phi::dtype::float16;
using bfloat16 = phi::dtype::bfloat16;
template void BatchTranspose(
float16* output, const float16* input, int batch, int m, int n);
template void BatchTranspose(
float* output, const float* input, int batch, int m, int n);
template void BatchTranspose(float16* output,
const float16* input,
int64_t batch,
int64_t m,
int64_t n,
const phi::GPUContext* dev_ctx);
template void BatchTranspose(float* output,
const float* input,
int64_t batch,
int64_t m,
int64_t n,
const phi::GPUContext* dev_ctx);
template void BatchTranspose(bfloat16* output,
const bfloat16* input,
int64_t batch,
int64_t m,
int64_t n,
const phi::GPUContext* dev_ctx);
template struct SetConstant<phi::GPUContext, float16>;
template struct SetConstant<phi::GPUContext, bfloat16>;
@@ -26,7 +26,12 @@ namespace phi {
namespace funcs {
template <typename T>
void BatchTranspose(T* output, const T* input, int batch, int m, int n);
void BatchTranspose(T* output,
const T* input,
int64_t batch,
int64_t m,
int64_t n,
const phi::GPUContext* dev_ctx);
template <typename DeviceContext, typename T>
struct TransposeNormal {
@@ -71,15 +71,16 @@ void TransferLayoutGeneral(const Context& dev_ctx,
out->Resize(phi::make_ddim(dst_dim));
dev_ctx.Alloc(out, x.dtype());
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  // In GPU fp16 models, conv2d_fusion_layout_transfer_pass inserts many
  // transfer_layout ops, so this kernel is optimized for GPU.
if (std::is_same<Context, phi::GPUContext>::value) {
std::vector<int> axis_nchw_nhwc = {0, 2, 3, 1};
std::vector<int> axis_nhwc_nchw = {0, 3, 1, 2};
const int batch = src_dim[0];
int row_len = src_dim[1];
int col_len = src_dim[2] * src_dim[3];
auto* gpu_ctx = reinterpret_cast<const phi::GPUContext*>(&dev_ctx);
const int64_t batch = src_dim[0];
int64_t row_len = src_dim[1];
int64_t col_len = src_dim[2] * src_dim[3];
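    // Each sample is treated as a row_len x col_len matrix; a batched
    // transpose of that matrix realizes the NCHW <-> NHWC layout switch.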
if (axis == axis_nhwc_nchw) {
row_len = src_dim[1] * src_dim[2];
col_len = src_dim[3];
@@ -89,15 +90,28 @@ void TransferLayoutGeneral(const Context& dev_ctx,
x.data<phi::dtype::float16>(),
batch,
row_len,
col_len);
col_len,
gpu_ctx);
return;
} else if (x.dtype() == phi::DataType::FLOAT32) {
funcs::BatchTranspose(
out->data<float>(), x.data<float>(), batch, row_len, col_len);
funcs::BatchTranspose(out->data<float>(),
x.data<float>(),
batch,
row_len,
col_len,
gpu_ctx);
return;
} else if (x.dtype() == phi::DataType::BFLOAT16) {
funcs::BatchTranspose(out->data<phi::dtype::bfloat16>(),
x.data<phi::dtype::bfloat16>(),
batch,
row_len,
col_len,
gpu_ctx);
return;
}
}
#endif
PD_VISIT_ALL_TYPES(x.dtype(), "CastDataLayout", ([&] {
CastDataLayout<data_t, Context>(dev_ctx, x, axis, out);
}));