提交 ac26bdce 编写于 作者: M Megvii Engine Team

fix(cuda): fix direct conv speed and memory problem

GitOrigin-RevId: 6faeeff3b80b9cd2245268bfaa9c017b1d3bac58
上级 f7994683
......@@ -142,7 +142,7 @@ struct ConvTraitInner {
}
#define CHECK_AB_BWD(a, b) \
if (param.out_w > b * 4) { \
if (param.out_w > b * 4 || b == 3) { \
using FilterTileConfig_ = FilterTileConfig<unroll_fh, a + 2>; \
using ThreadConfig_ = ThreadConfig<4, 32>; \
using OutTileConfig_ = OutTileConfig<ThreadConfig_, unroll_oh, b + 1>; \
......@@ -165,11 +165,9 @@ struct ConvTraitInner {
return true; \
}
#define CHECK_A(a, cb) \
if (param.flt_w > a * 4) { \
CHECK_AB_##cb( \
a, \
15) else CHECK_AB_##cb(a, 14) else CHECK_AB_##cb(a, 13) else CHECK_AB_##cb(a, 12) else CHECK_AB_##cb(a, 11) else CHECK_AB_##cb(a, 10) else CHECK_AB_##cb(a, 9) else CHECK_AB_##cb(a, 8) else CHECK_AB_##cb(a, 7) else CHECK_AB_##cb(a, 6) else CHECK_AB_##cb(a, 5) else CHECK_AB_##cb(a, 4) else CHECK_AB_##cb(a, 3) else CHECK_AB_##cb(a, 2) else CHECK_AB_##cb(a, 1) else CHECK_AB_##cb(a, 0) \
#define CHECK_A(a, cb) \
if (param.flt_w > a * 4) { \
CHECK_AB_##cb(a, 15) else CHECK_AB_##cb(a, 7) else CHECK_AB_##cb(a, 3) \
}
#define CHECK(cb) \
......
......@@ -217,7 +217,7 @@ __device__ __forceinline__ void Global2SharedMem<
// Backprop input direction is the same as forward direction with the filter
// rotated by 180°.
#if CUDA_VERSION >= 9000
template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
template <typename ConvTrait, DepthwiseConv2dDirection kDirection, int stride>
__global__ void DepthwiseConv2dGPUKernelNCHW(
const Param param, const __half* input, const __half* filter, __half* output) {
using T = __half;
......@@ -230,7 +230,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
using FilterTileCount = typename ConvTrait::FilterTileCount;
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x;
......@@ -243,8 +243,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem);
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
int stride_h = is_fwd ? param.stride_h : 1;
int stride_w = is_fwd ? param.stride_w : 1;
constexpr int stride_h = is_fwd ? stride : 1;
constexpr int stride_w = is_fwd ? stride : 1;
int off_ichannel = off_ochannel / param.chl_mul,
off_fchannel = off_ichannel % param.src_chl,
......@@ -385,7 +385,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
}
}
template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
template <typename ConvTrait, DepthwiseConv2dDirection kDirection, int stride>
__global__ void DepthwiseConv2dGPUKernelNCHWC32(
const Param param, const __half* input, const __half* filter, __half* output) {
using T = __half;
......@@ -398,7 +398,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
using FilterTileCount = typename ConvTrait::FilterTileCount;
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x;
......@@ -411,8 +411,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem);
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
int stride_h = is_fwd ? param.stride_h : 1;
int stride_w = is_fwd ? param.stride_w : 1;
constexpr int stride_h = is_fwd ? stride : 1;
constexpr int stride_w = is_fwd ? stride : 1;
int off_ichannel = off_ochannel / param.chl_mul,
off_fchannel = off_ichannel % param.src_chl,
......@@ -555,7 +555,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
}
#endif
template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
template <typename ConvTrait, DepthwiseConv2dDirection kDirection, int stride>
__global__ void DepthwiseConv2dGPUKernelNCHW(
const Param param, const float* input, const float* filter, float* output) {
using T = float;
......@@ -568,7 +568,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
using FilterTileCount = typename ConvTrait::FilterTileCount;
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x;
......@@ -577,8 +577,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem);
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
int stride_h = is_fwd ? param.stride_h : 1;
int stride_w = is_fwd ? param.stride_w : 1;
constexpr int stride_h = is_fwd ? stride : 1;
constexpr int stride_w = is_fwd ? stride : 1;
int off_ichannel = off_ochannel / param.chl_mul,
off_fchannel = off_ichannel % param.src_chl,
......@@ -703,7 +703,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
}
}
template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
template <typename ConvTrait, DepthwiseConv2dDirection kDirection, int stride>
__global__ void DepthwiseConv2dGPUKernelNCHWC32(
const Param param, const float* input, const float* filter, float* output) {
using T = float;
......@@ -716,7 +716,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
using FilterTileCount = typename ConvTrait::FilterTileCount;
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x;
......@@ -725,8 +725,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem);
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
int stride_h = is_fwd ? param.stride_h : 1;
int stride_w = is_fwd ? param.stride_w : 1;
constexpr int stride_h = is_fwd ? stride : 1;
constexpr int stride_w = is_fwd ? stride : 1;
int off_ichannel = off_ochannel / param.chl_mul,
off_fchannel = off_ichannel % param.src_chl,
......@@ -879,16 +879,16 @@ void LaunchDepthwiseConv2dGPU(
void (*kernel)(const Param, const T*, const T*, T*);
if (param.is_compute_deafult) {
kernel = DepthwiseConv2dGPUKernelNCHW<IConvTrait, kDirection>;
kernel = DepthwiseConv2dGPUKernelNCHW<IConvTrait, kDirection, stride>;
} else {
kernel = DepthwiseConv2dGPUKernelNCHWC32<IConvTrait, kDirection>;
kernel = DepthwiseConv2dGPUKernelNCHWC32<IConvTrait, kDirection, stride>;
}
kernel<<<grid, block, shared_storage, stream>>>(param, input, filter, output);
after_kernel_launch();
}
#define INSTANCE_AB(type1, type2, a, b, direction) \
if (param.out_w > b * 4) { \
if (param.out_w > b * 4 || b == 3) { \
if (direction == DepthwiseConv2dDirection::DIRECTION_BACKWARD || \
(param.stride_h == 1 && param.stride_w == 1)) { \
LaunchDepthwiseConv2dGPU<type1, type2, direction, a + 2, b + 1, 1>( \
......@@ -899,12 +899,11 @@ void LaunchDepthwiseConv2dGPU(
} \
}
#define INSTANCE_A(type1, type2, a, direction) \
if (param.flt_w > a * 4) { \
INSTANCE_AB(type1, type2, a, 15, direction) \
else INSTANCE_AB(type1, type2, a, 14, direction) else INSTANCE_AB(type1, type2, a, 13, direction) else INSTANCE_AB(type1, type2, a, 12, direction) else INSTANCE_AB(type1, type2, a, 11, direction) else INSTANCE_AB(type1, type2, a, 10, direction) else INSTANCE_AB( \
type1, type2, \
a, 9, direction) else INSTANCE_AB(type1, type2, a, 8, direction) else INSTANCE_AB(type1, type2, a, 7, direction) else INSTANCE_AB(type1, type2, a, 6, direction) else INSTANCE_AB(type1, type2, a, 5, direction) else INSTANCE_AB(type1, type2, a, 4, direction) else INSTANCE_AB(type1, type2, a, 3, direction) else INSTANCE_AB(type1, type2, a, 2, direction) else INSTANCE_AB(type1, type2, a, 1, direction) else INSTANCE_AB(type1, type2, a, 0, direction) \
#define INSTANCE_A(type1, type2, a, direction) \
if (param.flt_w > a * 4) { \
INSTANCE_AB(type1, type2, a, 15, direction) \
else INSTANCE_AB(type1, type2, a, 7, direction) else INSTANCE_AB( \
type1, type2, a, 3, direction) \
}
#define INSTANCE(type1, type2, direction) \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册