Unverified commit 17879045, authored by zhoutianzi666, committed by GitHub

optimize nchw<->nhwc kernel in fp16 model (#48692)

Parent e5bc2eec
@@ -27,11 +27,83 @@ limitations under the License. */
namespace phi {
namespace funcs {
// The following part of the code refers to NVIDIA-cutlass
// https://github.com/NVIDIA/cutlass/blob/master/tools/util/include/cutlass/util/device_nchw_to_nhwc.h
// Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
// reserved. SPDX-License-Identifier: BSD-3-Clause
// Each thread block transposes one 32x32 tile of a single matrix in the
// batch through shared memory; the block is 32x8 threads, i.e. 8 warps.
template <typename T>
__global__ void batch_transpose_kernel(
    T* output, const T* input, const int batch, const int M, const int N) {
  const int num = M * N;
  // "+1" to avoid smem bank conflict
  __shared__ T shbuf[32 * (32 + 1)];
  const int32_t tid = threadIdx.y * blockDim.x + threadIdx.x;
  const int32_t wid = tid / 32;
  const int32_t lid = tid % 32;
  const int32_t batch_i = blockIdx.z;
  const int32_t mi0 = blockIdx.y * 32;
  const int32_t ni0 = blockIdx.x * 32;

  // Stage 1: load a 32x32 input tile into shared memory; each of the 8
  // warps covers 4 of the 32 rows.
  const size_t input_idx = batch_i * num + (mi0 + wid) * N + ni0;
  const T* A = input + input_idx;
  if (ni0 + lid < N) {
    const int lid_x_33 = lid * 33;
    if ((mi0 + 32) <= M) {
      // Fast path: the whole tile is in bounds, so the loop can be unrolled.
      int mi = wid;  // between 0 and 7
#pragma unroll
      for (int mLoopIdx = 0; mLoopIdx < 4; mLoopIdx++) {
        shbuf[lid_x_33 + mi] = A[lid];
        A = &A[8 * N];
        mi += 8;
      }
    } else {
      // Boundary tile: guard each row against running past M.
      for (int mi = wid; mi < 32; mi += 8) {
        if ((mi + mi0) < M) {
          shbuf[lid_x_33 + mi] = A[lid];
        }
        A = &A[8 * N];
      }
    }
  }
  __syncthreads();

  // Stage 2: write the tile back transposed; the lane id now walks the M
  // dimension of the output.
  const int32_t miOut = mi0 + lid;
  output = &output[batch_i * num + miOut];
  if (miOut < M) {
    if (ni0 + 32 < N) {
      int nI = wid;
#pragma unroll
      for (int nLoopIdx = 0; nLoopIdx < 4; ++nLoopIdx) {
        output[(ni0 + nI) * M] = shbuf[(nI)*33 + lid];
        nI += 8;
      }
    } else {
      for (int nI = wid; nI < 32; nI += 8) {
        if (ni0 + nI < N) {
          output[(ni0 + nI) * M] = shbuf[(nI)*33 + lid];
        }
      }
    }
  }
}
template <typename T>
void BatchTranspose(T* output, const T* input, int batch, int m, int n) {
  // One block per 32x32 tile per batch element; ceil-divide m and n by 32.
  dim3 grid((n + 31) / 32, (m + 31) / 32, batch);
  dim3 block(32, 8);
  batch_transpose_kernel<<<grid, block>>>(output, input, batch, m, n);
}
using float16 = phi::dtype::float16;
using bfloat16 = phi::dtype::bfloat16;
template void BatchTranspose(
    float16* output, const float16* input, int batch, int m, int n);
template void BatchTranspose(
    float* output, const float* input, int batch, int m, int n);
template struct SetConstant<phi::GPUContext, float16>;
template struct SetConstant<phi::GPUContext, bfloat16>;
template struct SetConstant<phi::GPUContext, float>;
template struct SetConstant<phi::GPUContext, double>;
template struct SetConstant<phi::GPUContext, uint8_t>;
@@ -42,10 +114,9 @@ template struct SetConstant<phi::GPUContext, bool>;
template struct SetConstant<phi::GPUContext, phi::dtype::complex<float>>;
template struct SetConstant<phi::GPUContext, phi::dtype::complex<double>>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, float16>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
                            bfloat16>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, float>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, double>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, uint8_t>;
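
For orientation (not part of the diff): BatchTranspose treats the input as batch row-major M x N matrices and writes batch N x M matrices, so NCHW->NHWC is a single call with M = C and N = H * W. Below is a minimal sanity-check harness, assuming the kernel and wrapper above are compiled into the same .cu file; the driver, sizes, and fill pattern are illustrative, not part of the PR.

#include <cuda_runtime.h>
#include <cassert>
#include <vector>

int main() {
  // NCHW -> NHWC viewed as a batched transpose: M = C, N = H * W.
  const int batch = 2, M = 64, N = 56 * 56;
  std::vector<float> h_in(static_cast<size_t>(batch) * M * N);
  std::vector<float> h_out(h_in.size());
  for (size_t i = 0; i < h_in.size(); ++i) {
    h_in[i] = static_cast<float>(i % 97);
  }

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, h_in.size() * sizeof(float));
  cudaMalloc(&d_out, h_out.size() * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float),
             cudaMemcpyHostToDevice);

  phi::funcs::BatchTranspose(d_out, d_in, batch, M, N);  // T deduced as float

  cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float),
             cudaMemcpyDeviceToHost);

  // Reference semantics: out[b][n][m] == in[b][m][n].
  for (int b = 0; b < batch; ++b)
    for (int m = 0; m < M; ++m)
      for (int n = 0; n < N; ++n)
        assert(h_out[(static_cast<size_t>(b) * N + n) * M + m] ==
               h_in[(static_cast<size_t>(b) * M + m) * N + n]);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}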
@@ -29,6 +29,9 @@ limitations under the License. */
namespace phi {
namespace funcs {
template <typename T>
void BatchTranspose(T* output, const T* input, int batch, int m, int n);
template <typename DeviceContext, typename T>
struct TransposeNormal {
// for dims >= 7 situation
@@ -70,6 +70,32 @@ void TransferLayoutGeneral(const Context& dev_ctx,
  out->Resize(phi::make_ddim(dst_dim));
  dev_ctx.Alloc(out, x.dtype());

  // In GPU fp16 models, conv2d_fusion_layout_transfer_pass inserts many
  // transfer_layout ops, so this kernel is specialized on GPU: the 4-D
  // permutation is flattened into one batched 2-D transpose.
  if (std::is_same<Context, phi::GPUContext>::value) {
    std::vector<int> axis_nchw_nhwc = {0, 2, 3, 1};
    std::vector<int> axis_nhwc_nchw = {0, 3, 1, 2};
    const int batch = src_dim[0];
    int row_len = src_dim[1];
    int col_len = src_dim[2] * src_dim[3];
    if (axis == axis_nhwc_nchw) {
      row_len = src_dim[1] * src_dim[2];
      col_len = src_dim[3];
    }
    if (x.dtype() == phi::DataType::FLOAT16) {
      funcs::BatchTranspose(out->data<phi::dtype::float16>(),
                            x.data<phi::dtype::float16>(),
                            batch,
                            row_len,
                            col_len);
      return;
    } else if (x.dtype() == phi::DataType::FLOAT32) {
      funcs::BatchTranspose(
          out->data<float>(), x.data<float>(), batch, row_len, col_len);
      return;
    }
  }

  PD_VISIT_ALL_TYPES(x.dtype(), "CastDataLayout", ([&] {
                       CastDataLayout<data_t, Context>(dev_ctx, x, axis, out);
                     }));
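
Worked example (illustrative, not from the PR): the flattening above reduces each supported 4-D permutation to a single batched 2-D transpose. The standalone sketch below mirrors the same arithmetic; TransposeShape and ShapeFor are hypothetical names introduced only for illustration.

#include <array>
#include <cstdio>

struct TransposeShape {
  int batch, row_len, col_len;
};

// Mirrors the row/col derivation in TransferLayoutGeneral above.
TransposeShape ShapeFor(const std::array<int, 4>& src_dim, bool nhwc_to_nchw) {
  if (nhwc_to_nchw) {
    // axis = {0, 3, 1, 2}: src is NHWC, so rows = H * W, cols = C.
    return {src_dim[0], src_dim[1] * src_dim[2], src_dim[3]};
  }
  // axis = {0, 2, 3, 1}: src is NCHW, so rows = C, cols = H * W.
  return {src_dim[0], src_dim[1], src_dim[2] * src_dim[3]};
}

int main() {
  // An 8-image NCHW activation of shape [8, 64, 56, 56]:
  const TransposeShape s = ShapeFor({{8, 64, 56, 56}}, /*nhwc_to_nchw=*/false);
  // One 64 x 3136 transpose per image: prints "batch=8 row=64 col=3136".
  std::printf("batch=%d row=%d col=%d\n", s.batch, s.row_len, s.col_len);
  return 0;
}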