From 17879045f17c762592efaf29d09b565422d2b130 Mon Sep 17 00:00:00 2001
From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com>
Date: Wed, 7 Dec 2022 13:09:50 +0800
Subject: [PATCH] optimize nchw<->nhwc kernel in fp16 model (#48692)

---
 paddle/phi/kernels/funcs/math_function.cu    | 81 ++++++++++++++++++--
 paddle/phi/kernels/funcs/math_function.h     |  3 +
 paddle/phi/kernels/transfer_layout_kernel.cc | 26 +++++++
 3 files changed, 105 insertions(+), 5 deletions(-)

diff --git a/paddle/phi/kernels/funcs/math_function.cu b/paddle/phi/kernels/funcs/math_function.cu
index a0e59f8f3f..e1ab8922fd 100644
--- a/paddle/phi/kernels/funcs/math_function.cu
+++ b/paddle/phi/kernels/funcs/math_function.cu
@@ -27,11 +27,83 @@ limitations under the License. */
 namespace phi {
 namespace funcs {
 
+// The following part of the code refers to NVIDIA-cutlass
+// https://github.com/NVIDIA/cutlass/blob/master/tools/util/include/cutlass/util/device_nchw_to_nhwc.h
+// Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
+// reserved. SPDX-License-Identifier: BSD-3-Clause
+template <typename T>
+__global__ void batch_transpose_kernel(
+    T* output, const T* input, const int batch, const int M, const int N) {
+  const int num = M * N;
+  // "+1" to avoid smem bank conflict
+  __shared__ T shbuf[32 * (32 + 1)];
+  const int32_t tid = threadIdx.y * blockDim.x + threadIdx.x;
+  const int32_t wid = tid / 32;
+  const int32_t lid = tid % 32;
+  const int32_t batch_i = blockIdx.z;
+  const int32_t mi0 = blockIdx.y * 32;
+  const int32_t ni0 = blockIdx.x * 32;
+
+  const size_t input_idx = batch_i * num + (mi0 + wid) * N + ni0;
+  const T* A = input + input_idx;
+  if (ni0 + lid < N) {
+    const int lid_x_33 = lid * 33;
+    if ((mi0 + 32) <= M) {
+      int mi = wid;  // between 0 and 7
+#pragma unroll
+      for (int mLoopIdx = 0; mLoopIdx < 4; mLoopIdx++) {
+        shbuf[lid_x_33 + mi] = A[lid];
+        A = &A[8 * N];
+        mi += 8;
+      }
+    } else {
+      for (int mi = wid; mi < 32; mi += 8) {
+        if ((mi + mi0) < M) {
+          shbuf[lid_x_33 + mi] = A[lid];
+        }
+        A = &A[8 * N];
+      }
+    }
+  }
+  __syncthreads();
+
+  const int32_t miOut = mi0 + lid;
+  output = &output[batch_i * num + miOut];
+  if (miOut < M) {
+    if (ni0 + 32 < N) {
+      int nI = wid;
+#pragma unroll
+      for (int nLoopIdx = 0; nLoopIdx < 4; ++nLoopIdx) {
+        output[(ni0 + nI) * M] = shbuf[(nI)*33 + lid];
+        nI += 8;
+      }
+    } else {
+      for (int nI = wid; nI < 32; nI += 8) {
+        if (ni0 + nI < N) {
+          output[(ni0 + nI) * M] = shbuf[(nI)*33 + lid];
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void BatchTranspose(T* output, const T* input, int batch, int m, int n) {
+  dim3 grid((n + 31) / 32, (m + 31) / 32, batch);
+  dim3 block(32, 8);
+  batch_transpose_kernel<<<grid, block>>>(output, input, batch, m, n);
+}
+
 using float16 = phi::dtype::float16;
 using bfloat16 = phi::dtype::bfloat16;
 
-template struct SetConstant<phi::GPUContext, phi::dtype::float16>;
-template struct SetConstant<phi::GPUContext, phi::dtype::bfloat16>;
+template void BatchTranspose(
+    float16* output, const float16* input, int batch, int m, int n);
+template void BatchTranspose(
+    float* output, const float* input, int batch, int m, int n);
+
+template struct SetConstant<phi::GPUContext, float16>;
+template struct SetConstant<phi::GPUContext, bfloat16>;
 template struct SetConstant<phi::GPUContext, float>;
 template struct SetConstant<phi::GPUContext, double>;
 template struct SetConstant<phi::GPUContext, uint8_t>;
@@ -42,10 +114,9 @@ template struct SetConstant<phi::GPUContext, bool>;
 template struct SetConstant<phi::GPUContext, phi::dtype::complex<float>>;
 template struct SetConstant<phi::GPUContext, phi::dtype::complex<double>>;
 
+template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, float16>;
 template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
-                            phi::dtype::float16>;
-template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
-                            phi::dtype::bfloat16>;
+                            bfloat16>;
 template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, float>;
 template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, double>;
 template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, uint8_t>;
diff --git a/paddle/phi/kernels/funcs/math_function.h b/paddle/phi/kernels/funcs/math_function.h
index 48649a454a..6f1cac4935 100644
--- a/paddle/phi/kernels/funcs/math_function.h
+++ b/paddle/phi/kernels/funcs/math_function.h
@@ -29,6 +29,9 @@ limitations under the License. */
 namespace phi {
 namespace funcs {
 
+template <typename T>
+void BatchTranspose(T* output, const T* input, int batch, int m, int n);
+
 template <typename DeviceContext, typename T>
 struct TransposeNormal {
   // for dims >= 7 situation
diff --git a/paddle/phi/kernels/transfer_layout_kernel.cc b/paddle/phi/kernels/transfer_layout_kernel.cc
index d7b8d55707..f2c57150c6 100644
--- a/paddle/phi/kernels/transfer_layout_kernel.cc
+++ b/paddle/phi/kernels/transfer_layout_kernel.cc
@@ -70,6 +70,32 @@ void TransferLayoutGeneral(const Context& dev_ctx,
   out->Resize(phi::make_ddim(dst_dim));
   dev_ctx.Alloc(out, x.dtype());
 
+  // In GPU fp16 model, we will insert many transfer_layout ops in
+  // conv2d_fusion_layout_transfer_pass, so we optimize this kernel on GPU
+  if (std::is_same<Context, phi::GPUContext>::value) {
+    std::vector<int> axis_nchw_nhwc = {0, 2, 3, 1};
+    std::vector<int> axis_nhwc_nchw = {0, 3, 1, 2};
+    const int batch = src_dim[0];
+    int row_len = src_dim[1];
+    int col_len = src_dim[2] * src_dim[3];
+    if (axis == axis_nhwc_nchw) {
+      row_len = src_dim[1] * src_dim[2];
+      col_len = src_dim[3];
+    }
+    if (x.dtype() == phi::DataType::FLOAT16) {
+      funcs::BatchTranspose(out->data<phi::dtype::float16>(),
+                            x.data<phi::dtype::float16>(),
+                            batch,
+                            row_len,
+                            col_len);
+      return;
+    } else if (x.dtype() == phi::DataType::FLOAT32) {
+      funcs::BatchTranspose(
+          out->data<float>(), x.data<float>(), batch, row_len, col_len);
+      return;
+    }
+  }
+
   PD_VISIT_ALL_TYPES(x.dtype(), "CastDataLayout", ([&] {
                        CastDataLayout<data_t, Context>(dev_ctx, x, axis, out);
                      }));
--
GitLab
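
For readers who want to check the layout math without building Paddle, the sketch below is a minimal CPU-side reference (not part of the patch; BatchTransposeRef and the main routine are hypothetical names) of the semantics that BatchTranspose implements on the GPU: for every batch index, an m x n row-major matrix is transposed into an n x m matrix. With m = C and n = H * W this turns an NCHW tensor into NHWC (axis {0, 2, 3, 1}); with m = H * W and n = C it converts NHWC back to NCHW (axis {0, 3, 1, 2}), which is how TransferLayoutGeneral picks row_len and col_len above.

// Minimal CPU reference for the batched transpose semantics (illustration
// only; BatchTransposeRef is a hypothetical name, not part of the patch).
#include <cstdio>
#include <vector>

template <typename T>
void BatchTransposeRef(T* output, const T* input, int batch, int m, int n) {
  for (int b = 0; b < batch; ++b) {
    const T* src = input + static_cast<size_t>(b) * m * n;
    T* dst = output + static_cast<size_t>(b) * m * n;
    for (int i = 0; i < m; ++i) {
      for (int j = 0; j < n; ++j) {
        dst[j * m + i] = src[i * n + j];  // element (i, j) -> (j, i)
      }
    }
  }
}

int main() {
  // One NCHW tensor with N=1, C=2, H=2, W=2, transposed to NHWC:
  // m = C = 2, n = H * W = 4.
  std::vector<float> nchw = {0, 1, 2, 3,   // channel 0
                             4, 5, 6, 7};  // channel 1
  std::vector<float> nhwc(nchw.size());
  BatchTransposeRef(nhwc.data(), nchw.data(), /*batch=*/1, /*m=*/2, /*n=*/4);
  for (float v : nhwc) std::printf("%g ", v);  // prints: 0 4 1 5 2 6 3 7
  std::printf("\n");
  return 0;
}

The CUDA kernel in the patch produces the same result tile by tile: each 32 x 32 tile is staged through shared memory padded to 33 columns ("+1" to avoid smem bank conflict), so both the global-memory reads and writes stay coalesced, which is what makes the layout transfers inserted by conv2d_fusion_layout_transfer_pass cheap in an fp16 model.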