Commit 91c3d5fe authored by Megvii Engine Team, committed by Xinran Xu

fix(dnn/cuda): add sharemem to prefer and refactor code

GitOrigin-RevId: 0dba39f56ba0bf61ecb591f866641e7f51937ab9
Parent abf82cfb
/**
* \file dnn/src/cuda/group_local/cuda_interface.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./cuda_interface.h"
#include "src/cuda/utils.cuh"
namespace megdnn {
namespace cuda {
// src layout is (N, G, IC, IH, IW)
// filter layout is (G, OH, OW, IC, FH, FW, OC)
// dst layout is (N, G, OC, OH, OW)
// NR_THREADS is 256
// gridDim.z is G
// gridDim.y is OC*OH*OW/NR_THREADS
// gridDim.x is N/NB
// blockDim.x is NR_THREADS
// INs and ONs are the stride on the src/dst batch size dim
// IC and OC are nr. channels per group
// Each thread handles NB batch samples (NB_cur when N is not a multiple of NB).
// Let oid = blockIdx.y*NR_THREADS + threadIdx.x (global thread ID along grid
// axis y); the (OC, OH, OW) output dims are flattened into one index oid (OC
// varying fastest), and each thread computes the dst values at position
// (n, blockIdx.z, oid), where n ranges from blockIdx.x*NB + 0 to blockIdx.x*NB + (NB-1).
// IC is processed at stride of ICB. On entrance of each iteration of the loop,
// NB * ICB spatial src planes are loaded into shared memory (presumably src
// spatial size is small).
template <uint32_t NB, uint32_t ICB, bool is_xcorr>
__global__ void forward_kernel(const float * __restrict__ src,
const float * __restrict__ filter,
float * __restrict__ dst,
uint32_t N,
uint32_t IC, uint32_t IH, uint32_t IW,
uint32_t OC, uint32_t OH, uint32_t OW,
uint32_t FH, uint32_t FW,
uint32_t G,
uint32_t INs, uint32_t ONs,
uint32_t PH, uint32_t PW,
uint32_t SH, uint32_t SW)
{
// NB * ICB * sizeof(float) * IH * IW
extern __shared__ float shared_mem[];
float *src_cache = shared_mem;
uint32_t tid = threadIdx.x;
uint32_t tstride = blockDim.x;
uint32_t oid = tid + blockIdx.y * tstride;
src += blockIdx.x*NB * INs + blockIdx.z*IC*IH*IW;
dst += blockIdx.x*NB * ONs + blockIdx.z*OC*OH*OW;
filter += blockIdx.z*OH*OW*IC*FH*FW*OC;
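// Decode oid: oc = oid % OC is the output channel (fastest-varying) and
// op = oid / OC is the flattened spatial position, with oh = op / OW and
// ow = op % OW; the dst offset inside one (n, g) slice is oc*OH*OW + op.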
uint32_t op = oid / OC;
uint32_t oc = oid % OC;
uint32_t oh = op / OW;
uint32_t ow = op % OW;
float dst_reg[NB];
for (uint32_t nb = 0; nb < NB; ++nb) dst_reg[nb] = 0.0f;
uint32_t NB_cur = min(N-blockIdx.x*NB, NB);
for (uint32_t ic = 0; ic < IC; ic += ICB) {
// read ICB-channel src
// (NB, ICB, IHs, IWs)
uint32_t ICB_cur = min(ICB, IC-ic);
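// All threads of the block cooperatively stage NB_cur*ICB spatial planes into
// shared memory; (icb < ICB_cur) zeroes channels beyond the valid range without
// branching, and min(IC-1, ic+icb) keeps the global read in bounds.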
for (uint32_t i = tid; i < NB_cur*ICB*IH*IW; i += tstride) {
uint32_t ip = i % (IH*IW);
uint32_t icb = i / (IH*IW) % ICB;
uint32_t nb = i / (IH*IW) / ICB;
src_cache[i] =
(icb < ICB_cur) * src[nb*INs + min(IC-1, (ic+icb))*IH*IW + ip];
}
__syncthreads();
if (oid < OC*OH*OW)
for (uint32_t fh = 0; fh < FH; ++fh)
{
uint32_t ih;
if (is_xcorr) ih = oh*SH + fh - PH; else ih = oh*SH + (FH-fh-1) - PH;
if (ih < IH)
for (uint32_t fw = 0; fw < FW; ++fw)
{
uint32_t iw;
if (is_xcorr) iw = ow*SW + fw - PW; else iw = ow*SW + (FW-fw-1) - PW;
if (iw < IW)
for (uint32_t icb = 0; icb < ICB_cur; ++icb) {
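// Per-group filter is laid out as (OH, OW, IC, FH, FW, OC); op = oh*OW + ow
// selects the spatial position of this output element.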
uint32_t fid = op*IC*FH*FW*OC + (ic+icb)*FH*FW*OC +
fh*FW*OC + fw*OC + oc;
float fval = filter[fid];
float src_reg[NB];
#pragma unroll
for (uint32_t nb = 0; nb < NB; ++nb) {
src_reg[nb] = src_cache[nb*ICB*IH*IW + icb*IH*IW + ih*IW + iw];
}
#pragma unroll
for (uint32_t nb = 0; nb < NB; ++nb) {
dst_reg[nb] += src_reg[nb]*fval;
}
}
}
}
__syncthreads();
}
if (oid < OC*OH*OW) {
for (uint32_t nb = 0; nb < NB_cur; ++nb) {
dst[nb*ONs + oc*OH*OW + op] = dst_reg[nb];
}
}
}
void run_inference_kernel(const float *src, const float *filter, float *dst,
float *wptr,
uint32_t N, uint32_t IC, uint32_t IH, uint32_t IW,
uint32_t OC, uint32_t OH, uint32_t OW,
uint32_t FH, uint32_t FW,
uint32_t G,
uint32_t PH, uint32_t PW,
uint32_t SH, uint32_t SW,
cudaStream_t stream)
{
MEGDNN_MARK_USED_VAR(wptr);
size_t threads = 256;
const size_t NB = 4, ICB = 4;
dim3 blocks = dim3(DIVUP(N, NB), DIVUP(OC*OH*OW, threads), G);
uint32_t INs = G*IC*IH*IW, ONs = G*OC*OH*OW;
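// Dynamic shared memory: NB*ICB full IH*IW planes of floats, matching the
// per-iteration src cache used inside forward_kernel.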
forward_kernel<NB, ICB, true><<<blocks, threads,
NB*ICB*sizeof(float)*IH*IW, stream>>>(src, filter, dst,
N,
IC, IH, IW,
OC, OH, OW,
FH, FW,
G,
INs, ONs,
PH, PW,
SH, SW);
after_kernel_launch();
}
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cuda/group_local/forward/kern.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/group_local/forward/kern.cuh"
#include "src/cuda/utils.cuh"
using namespace megdnn;
using namespace cuda;
namespace {
constexpr size_t NB = 4, ICB = 4;
// src layout is (N, G, IC, IH, IW)
// filter layout is (G, OH, OW, IC, FH, FW, OC)
// dst layout is (N, G, OC, OH, OW)
// NR_THREADS is 256
// gridDim.z is G
// gridDim.y is OC*OH*OW/NR_THREADS
// gridDim.x is N/NB
// blockDim.x is NR_THREADS
// INs and ONs are the stride on the src/dst batch size dim
// IC and OC are nr. channels per group
// Each thread handles NB batch samples (NB_cur when N is not a multiple of
// NB). Let oid = blockIdx.y*NR_THREADS + threadIdx.x (global thread ID along
// grid axis y); the (OC, OH, OW) output dims are flattened into one index oid
// (OC varying fastest), and each thread computes the dst values at position
// (n, blockIdx.z, oid), where n ranges from blockIdx.x*NB + 0 to
// blockIdx.x*NB + (NB-1). IC is processed at
// stride of ICB. On entrance of each iteration of the loop, NB * ICB spatial
// src planes are loaded into shared memory (presumably src spatial size is
// small).
template <uint32_t NB, uint32_t ICB, bool is_xcorr>
__global__ void forward_kernel(const float* __restrict__ src,
const float* __restrict__ filter,
float* __restrict__ dst, uint32_t N, uint32_t IC,
uint32_t IH, uint32_t IW, uint32_t OC,
uint32_t OH, uint32_t OW, uint32_t FH,
uint32_t FW, uint32_t INs, uint32_t ONs,
uint32_t PH, uint32_t PW, uint32_t SH,
uint32_t SW) {
// NB * ICB * sizeof(float) * IH * IW
extern __shared__ float shared_mem[];
float* src_cache = shared_mem;
uint32_t tid = threadIdx.x;
uint32_t tstride = blockDim.x;
uint32_t oid = tid + blockIdx.y * tstride;
src += blockIdx.x * NB * INs + blockIdx.z * IC * IH * IW;
dst += blockIdx.x * NB * ONs + blockIdx.z * OC * OH * OW;
filter += blockIdx.z * OH * OW * IC * FH * FW * OC;
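    // Decode oid: oc = oid % OC is the output channel (fastest-varying) and
    // op = oid / OC is the flattened spatial position, with oh = op / OW and
    // ow = op % OW; the dst offset inside one (n, g) slice is oc*OH*OW + op.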
uint32_t op = oid / OC;
uint32_t oc = oid % OC;
uint32_t oh = op / OW;
uint32_t ow = op % OW;
float dst_reg[NB];
for (uint32_t nb = 0; nb < NB; ++nb)
dst_reg[nb] = 0.0f;
uint32_t NB_cur = min(N - blockIdx.x * NB, NB);
for (uint32_t ic = 0; ic < IC; ic += ICB) {
// read ICB-channel src
// (NB, ICB, IHs, IWs)
uint32_t ICB_cur = min(ICB, IC - ic);
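        // All threads of the block cooperatively stage NB_cur*ICB spatial
        // planes into shared memory; (icb < ICB_cur) zeroes channels beyond
        // the valid range without branching, and min(IC - 1, ic + icb) keeps
        // the global read in bounds.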
for (uint32_t i = tid; i < NB_cur * ICB * IH * IW; i += tstride) {
uint32_t ip = i % (IH * IW);
uint32_t icb = i / (IH * IW) % ICB;
uint32_t nb = i / (IH * IW) / ICB;
src_cache[i] =
(icb < ICB_cur) *
src[nb * INs + min(IC - 1, (ic + icb)) * IH * IW + ip];
}
__syncthreads();
if (oid < OC * OH * OW)
for (uint32_t fh = 0; fh < FH; ++fh) {
uint32_t ih;
if (is_xcorr)
ih = oh * SH + fh - PH;
else
ih = oh * SH + (FH - fh - 1) - PH;
if (ih < IH)
for (uint32_t fw = 0; fw < FW; ++fw) {
uint32_t iw;
if (is_xcorr)
iw = ow * SW + fw - PW;
else
iw = ow * SW + (FW - fw - 1) - PW;
if (iw < IW)
for (uint32_t icb = 0; icb < ICB_cur; ++icb) {
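                                // Per-group filter is laid out as
                                // (OH, OW, IC, FH, FW, OC); op = oh*OW + ow
                                // selects the spatial position of this output.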
uint32_t fid = op * IC * FH * FW * OC +
(ic + icb) * FH * FW * OC +
fh * FW * OC + fw * OC + oc;
float fval = filter[fid];
float src_reg[NB];
#pragma unroll
for (uint32_t nb = 0; nb < NB; ++nb) {
src_reg[nb] = src_cache[nb * ICB * IH * IW +
icb * IH * IW +
ih * IW + iw];
}
#pragma unroll
for (uint32_t nb = 0; nb < NB; ++nb) {
dst_reg[nb] += src_reg[nb] * fval;
}
}
}
}
__syncthreads();
}
if (oid < OC * OH * OW) {
for (uint32_t nb = 0; nb < NB_cur; ++nb) {
dst[nb * ONs + oc * OH * OW + op] = dst_reg[nb];
}
}
}
}
void group_local::exec(const float* src, const float* filter, float* dst,
float* wptr, uint32_t N, uint32_t IC, uint32_t IH,
uint32_t IW, uint32_t OC, uint32_t OH, uint32_t OW,
uint32_t FH, uint32_t FW, uint32_t G, uint32_t PH,
uint32_t PW, uint32_t SH, uint32_t SW,
cudaStream_t stream) {
MEGDNN_MARK_USED_VAR(wptr);
size_t threads = 256;
dim3 blocks = dim3(DIVUP(N, NB), DIVUP(OC * OH * OW, threads), G);
uint32_t INs = G * IC * IH * IW, ONs = G * OC * OH * OW;
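    // Dynamic shared memory: NB*ICB full IH*IW planes of floats; this size
    // must match group_local::get_share_mem_in_bytes(IH, IW) below.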
forward_kernel<NB, ICB, true>
<<<blocks, threads, NB * ICB * sizeof(float) * IH * IW, stream>>>(
src, filter, dst, N, IC, IH, IW, OC, OH, OW, FH, FW, INs,
ONs, PH, PW, SH, SW);
after_kernel_launch();
}
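// Reports the dynamic shared memory the kernel above requests, so callers
// (e.g. prefer_inference_kernel) can check it against the device limit.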
size_t group_local::get_share_mem_in_bytes(uint32_t IH, uint32_t IW) {
return NB * ICB * sizeof(float) * IH * IW;
}
@@ -14,8 +14,9 @@
namespace megdnn {
namespace cuda {
namespace group_local {
void run_inference_kernel(const float *src, const float *filter, float *dst,
void exec(const float *src, const float *filter, float *dst,
float *wptr,
uint32_t N, uint32_t IC, uint32_t IH, uint32_t IW,
uint32_t OC, uint32_t OH, uint32_t OW,
@@ -25,6 +26,10 @@ void run_inference_kernel(const float *src, const float *filter, float *dst,
uint32_t SH, uint32_t SW,
cudaStream_t stream);
size_t get_share_mem_in_bytes(uint32_t IH, uint32_t IW);
} // namespace group_local
} // namespace cuda
} // namespace megdnn
/**
* \file dnn/src/cuda/group_local/fwd.cpp
* \file dnn/src/cuda/group_local/forward/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -14,7 +14,7 @@
#include "src/cuda/local/local.cuh"
#include "src/cuda/utils.h"
#include "./cuda_interface.h"
#include "src/cuda/group_local/forward/kern.cuh"
namespace megdnn {
namespace cuda {
@@ -46,7 +46,7 @@ void GroupLocalForwardImpl::exec(_megdnn_tensor_in src,
auto one = handle->one_device();
auto zero = handle->zero_device();
if (prefer_inference_kernel(src.layout, filter.layout, dst.layout)) {
run_inference_kernel(sptr, fptr, dptr, wptr,
group_local::exec(sptr, fptr, dptr, wptr,
N, IC, IH, IW,
OC, OH, OW,
FH, FW,
@@ -141,11 +141,14 @@ bool GroupLocalForwardImpl::prefer_inference_kernel(const TensorLayout &src,
const TensorLayout &filter,
const TensorLayout &dst)
{
megdnn_ignore(filter);
megdnn_ignore(dst);
return src.shape[0] <= 8;
MEGDNN_MARK_USED_VAR(filter);
MEGDNN_MARK_USED_VAR(dst);
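    // Prefer the small-batch inference kernel only when the dynamic shared
    // memory it needs (group_local::get_share_mem_in_bytes) fits within this
    // device's per-block shared memory limit.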
auto handle = concrete_handle(this->handle());
size_t N = src.shape[0], IH = src.shape[2], IW = src.shape[3];
return N <= 8 && handle->device_prop().sharedMemPerBlock >=
group_local::get_share_mem_in_bytes(IH, IW);
}
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen