提交 bc385b53 编写于 作者: M Megvii Engine Team

feat(cuda): support float16 depthwise large kernel conv

GitOrigin-RevId: fdc1b15fbcb3968e695601bff6b6a953bf66f115
上级 7d2063e3
......@@ -98,9 +98,11 @@ struct ConvTrait {
static int const smem_buff_h = FilterTileConfig::unroll_h;
static int const smem_load_h = smem_src_h + smem_buff_h;
static int const smem_h = smem_load_h + smem_buff_h;
static int const smem_w = OutTileConfig::block_w +
FilterTileConfig::unroll_w * ThreadConfig::thread_x -
1;
static int const smem_w =
DIVUP(OutTileConfig::block_w +
FilterTileConfig::unroll_w * ThreadConfig::thread_x - 1,
2) *
2;
static int const smem_size = smem_h * smem_w;
static int const load_w =
smem_w > ThreadConfig::nr_threads ? ThreadConfig::nr_threads : smem_w;
......@@ -266,9 +268,11 @@ __device__ __forceinline__ void Global2SharedMem<
// one each in the lower and upper half of a tile.
// Backprop input direction is the same as forward direction with the filter
// rotated by 180°.
template <typename T, typename ConvTrait, DepthwiseConv2dDirection kDirection>
template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
__global__ void DepthwiseConv2dGPUKernelNCHWSmall(
const Param param, const T* input, const T* filter, T* output) {
const Param param, const __half* input, const __half* filter, __half* output) {
using T = __half;
using T2 = __half2;
using ThreadConfig = typename ConvTrait::ThreadConfig;
using SrcTileConfig = typename ConvTrait::SrcTileConfig;
using FilterTileConfig = typename ConvTrait::FilterTileConfig;
......@@ -282,6 +286,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x;
const int t2_src_unroll_w = (SrcTileConfig::unroll_w + 1) / 2;
const int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2;
const int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2;
extern __shared__ __align__(8) unsigned char smem[];
static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem);
......@@ -315,10 +323,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
__syncthreads();
T reg_src[SrcTileConfig::unroll_h * SrcTileConfig::unroll_w],
reg_flt[FilterTileConfig::unroll_h * FilterTileConfig::unroll_w];
T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w],
reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w];
T sum[OutTileConfig::unroll_size] = {0.0};
T2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}};
for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
gl2sh_src.copy();
......@@ -326,23 +334,34 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
reg_src[s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
[(off_oh + fh + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
s_w];
for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
int src_offset = (off_oh + fh + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
s_w * 2;
reg_src[s_h * t2_src_unroll_w + s_w] =
*reinterpret_cast<T2*>(smem_src_ptr + src_offset);
}
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
reg_flt[f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
[(fh + f_h) % FilterTileCount::smem_h *
FilterTileCount::smem_w +
f_w];
for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) {
int flt_offset =
(fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w +
f_w * 2;
reg_flt[0][f_h * t2_flt_unroll_w + f_w] =
*reinterpret_cast<T2*>(smem_flt_ptr + flt_offset);
reg_flt[1][f_h * t2_flt_unroll_w + f_w] = {
f_w > 0 ? reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y
: static_cast<T>(0.0),
reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
}
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {
static_cast<T>(0.0), static_cast<T>(0.0)};
reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y,
static_cast<T>(0.0)};
}
#pragma unroll
......@@ -350,13 +369,14 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
#pragma unroll
for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
for (int fw = 0; fw < t2_flt_unroll_w; ++fw) {
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
sum[oh * OutTileConfig::unroll_w + ow] +=
reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] *
reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw +
ow];
sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
reg_flt[ow % 2][inner_fh * t2_flt_unroll_w + fw],
reg_src[(inner_fh + oh) * t2_src_unroll_w + fw +
ow / 2],
sum[oh * t2_out_unroll_w + ow]);
}
}
}
......@@ -387,7 +407,156 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
if (out_w_idx >= param.out_w)
return;
out_base_ptr[out_h_idx * param.out_w + out_w_idx] =
sum[i * OutTileConfig::unroll_w + j];
sum[i * OutTileConfig::unroll_w + j].x +
sum[i * OutTileConfig::unroll_w + j].y;
}
}
}
}
}
template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
__global__ void DepthwiseConv2dGPUKernelNCHWSmall(
const Param param, const float* input, const float* filter, float* output) {
using T = float;
using T2 = float2;
using ThreadConfig = typename ConvTrait::ThreadConfig;
using SrcTileConfig = typename ConvTrait::SrcTileConfig;
using FilterTileConfig = typename ConvTrait::FilterTileConfig;
using OutTileConfig = typename ConvTrait::OutTileConfig;
using SrcTileCount = typename ConvTrait::SrcTileCount;
using FilterTileCount = typename ConvTrait::FilterTileCount;
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x;
const int t2_src_unroll_w = (SrcTileConfig::unroll_w + 1) / 2;
const int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2;
const int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2;
extern __shared__ __align__(8) unsigned char smem[];
static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem);
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
int off_ichannel = off_ochannel / param.chl_mul,
off_fchannel = off_ichannel % param.src_chl,
out_start_h = off_obh * OutTileConfig::block_h,
out_start_w = off_obw * OutTileConfig::block_w,
src_start_h = out_start_h - param.pad_h,
src_start_w = out_start_w - param.pad_w,
out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h;
T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w;
T* smem_flt_ptr = smem_flt + off_ow * FilterTileConfig::unroll_w;
T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w;
SrcGlobal2ShareVisitor gl2sh_src(
smem_src, param.src_w, src_start_h, src_start_w, param.src_h, param.src_w);
FilterGlobal2ShareVisitor gl2sh_flt = {
smem_flt, param.flt_w, is_fwd ? 0 : param.flt_h - 2,
0, param.flt_h, param.flt_w};
gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w;
gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w;
gl2sh_src.first_copy();
gl2sh_flt.first_copy();
__syncthreads();
T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w],
reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w];
T2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}};
for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
gl2sh_src.copy();
gl2sh_flt.copy();
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
int src_offset = (off_oh + fh + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
s_w * 2;
reg_src[s_h * t2_src_unroll_w + s_w] =
*reinterpret_cast<T2*>(smem_src_ptr + src_offset);
}
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) {
int flt_offset =
(fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w +
f_w * 2;
reg_flt[0][f_h * t2_flt_unroll_w + f_w] =
*reinterpret_cast<T2*>(smem_flt_ptr + flt_offset);
reg_flt[1][f_h * t2_flt_unroll_w + f_w] = {
f_w > 0 ? reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y
: static_cast<T>(0.0),
reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
}
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {
static_cast<T>(0.0), static_cast<T>(0.0)};
reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y,
static_cast<T>(0.0)};
}
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
#pragma unroll
for (int fw = 0; fw < t2_flt_unroll_w; ++fw) {
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
reg_flt[ow % 2][inner_fh * t2_flt_unroll_w + fw],
reg_src[(inner_fh + oh) * t2_src_unroll_w + fw +
ow / 2],
sum[oh * t2_out_unroll_w + ow]);
}
}
}
}
__syncthreads();
gl2sh_src.commit();
gl2sh_flt.commit();
gl2sh_src.iter_forward();
gl2sh_flt.iter_forward();
__syncthreads();
}
for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
sum[o].x += __shfl_xor(sum[o].x, i, 32);
sum[o].y += __shfl_xor(sum[o].y, i, 32);
}
}
if (threadIdx.x == 0) {
#pragma unroll
for (int i = 0; i < OutTileConfig::unroll_h; ++i) {
int out_h_idx = out_base_h_idx + i;
if (out_h_idx < param.out_h) {
#pragma unroll
for (int j = 0; j < OutTileConfig::unroll_w; ++j) {
int out_w_idx = out_start_w + j;
if (out_w_idx >= param.out_w)
return;
out_base_ptr[out_h_idx * param.out_w + out_w_idx] =
sum[i * OutTileConfig::unroll_w + j].x +
sum[i * OutTileConfig::unroll_w + j].y;
}
}
}
......@@ -419,28 +588,27 @@ void LaunchDepthwiseConv2dGPUSmall(
(SrcTileCount::smem_size + FilterTileCount::smem_size) * sizeof(T);
void (*kernel)(const Param, const T*, const T*, T*);
kernel = DepthwiseConv2dGPUKernelNCHWSmall<T, IConvTrait, kDirection>;
kernel = DepthwiseConv2dGPUKernelNCHWSmall<IConvTrait, kDirection>;
kernel<<<grid, block, shared_storage, stream>>>(param, input, filter, output);
after_kernel_launch();
}
#define INSTANCE_AB(a, b, direction) \
if (param.out_w > b * 4) { \
LaunchDepthwiseConv2dGPUSmall<float, float2, direction, a + 1, b + 1>( \
param, src, flt, dst, stream); \
#define INSTANCE_AB(type1, type2, a, b, direction) \
if (param.out_w > b * 4) { \
LaunchDepthwiseConv2dGPUSmall<type1, type2, direction, a + 2, b + 1>( \
param, src, flt, dst, stream); \
}
#define INSTANCE_A(a, direction) \
if (param.flt_w > 0) { \
INSTANCE_AB(a, 15, direction) \
else INSTANCE_AB(a, 14, direction) else INSTANCE_AB(a, 13, direction) else INSTANCE_AB( \
a, 12, direction) else INSTANCE_AB(a, 11, direction) else INSTANCE_AB(a, 10, direction) else INSTANCE_AB(a, 9, direction) else INSTANCE_AB(a, 8, direction) else INSTANCE_AB(a, 7, direction) else INSTANCE_AB(a, 6, direction) else INSTANCE_AB(a, 5, direction) else INSTANCE_AB(a, 4, direction) else INSTANCE_AB(a, 3, direction) else INSTANCE_AB(a, 2, direction) else INSTANCE_AB(a, 1, direction) else INSTANCE_AB(a, 0, direction) \
#define INSTANCE_A(type1, type2, a, direction) \
if (param.flt_w > a * 4) { \
INSTANCE_AB(type1, type2, a, 15, direction) \
else INSTANCE_AB(type1, type2, a, 14, direction) else INSTANCE_AB(type1, type2, a, 13, direction) else INSTANCE_AB(type1, type2, a, 12, direction) else INSTANCE_AB(type1, type2, a, 11, direction) else INSTANCE_AB(type1, type2, a, 10, direction) else INSTANCE_AB( \
type1, type2, \
a, 9, direction) else INSTANCE_AB(type1, type2, a, 8, direction) else INSTANCE_AB(type1, type2, a, 7, direction) else INSTANCE_AB(type1, type2, a, 6, direction) else INSTANCE_AB(type1, type2, a, 5, direction) else INSTANCE_AB(type1, type2, a, 4, direction) else INSTANCE_AB(type1, type2, a, 3, direction) else INSTANCE_AB(type1, type2, a, 2, direction) else INSTANCE_AB(type1, type2, a, 1, direction) else INSTANCE_AB(type1, type2, a, 0, direction) \
}
#define INSTANCE(direction) \
INSTANCE_A(7, direction) \
else INSTANCE_A(6, direction) else INSTANCE_A(5, direction) else INSTANCE_A(4, direction) else INSTANCE_A( \
3, \
direction) else INSTANCE_A(2, direction) else INSTANCE_A(1, direction) else INSTANCE_A(0, direction)
#define INSTANCE(type1, type2, direction) \
INSTANCE_A(type1, type2, 6, direction) \
else INSTANCE_A(type1, type2, 4, direction) else INSTANCE_A( \
type1, type2, 2, direction) else INSTANCE_A(type1, type2, 0, direction)
} // anonymous namespace
......@@ -37,9 +37,18 @@ template <>
void run_fwd_depthwise_large_filter(
float* dst, const float* src, const float* flt, const Param& param,
cudaStream_t stream) {
INSTANCE(DepthwiseConv2dDirection::DIRECTION_FORWARD)
INSTANCE(float, float2, DepthwiseConv2dDirection::DIRECTION_FORWARD)
}
#if CUDA_VERSION >= 9000
template <>
void run_fwd_depthwise_large_filter(
__half* dst, const __half* src, const __half* flt, const Param& param,
cudaStream_t stream) {
INSTANCE(__half, __half2, DepthwiseConv2dDirection::DIRECTION_FORWARD)
}
#endif
} // namespace chanwise
} // namespace conv_bias
} // namespace cuda
......
......@@ -50,7 +50,11 @@ bool ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::is_available(
return false;
}
if (args.src_layout->dtype != args.filter_layout->dtype &&
args.src_layout->dtype != dtype::Float32()) {
(args.src_layout->dtype != dtype::Float32()
#if CUDA_VERSION >= 9000
|| args.src_layout->dtype != dtype::Float16()
#endif
)) {
return false;
}
if (args.z_layout->ndim > 0)
......@@ -97,6 +101,15 @@ void ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::exec(const ExecArgs& args) c
conv_dst_tensor.ptr<float>(), args.src_tensor->ptr<float>(),
args.filter_tensor->ptr<float>(), kparam, stream);
break;
#if CUDA_VERSION >= 9000
case DTypeEnum::Float16:
chanwise::run_fwd_depthwise_large_filter(
static_cast<half*>(conv_dst_tensor.raw_ptr()),
static_cast<half*>(args.src_tensor->raw_ptr()),
static_cast<half*>(args.filter_tensor->raw_ptr()), kparam,
stream);
break;
#endif
default:
megdnn_assert_internal(0);
}
......
......@@ -49,7 +49,11 @@ bool ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::is_available(
return false;
}
if (args.diff_layout->dtype != args.filter_layout->dtype &&
args.diff_layout->dtype != dtype::Float32()) {
(args.diff_layout->dtype != dtype::Float32()
#if CUDA_VERSION >= 9000
|| args.diff_layout->dtype != dtype::Float16()
#endif
)) {
return false;
}
......@@ -78,6 +82,14 @@ void ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::exec(
args.grad_tensor->ptr<float>(), args.diff_tensor->ptr<float>(),
args.filter_tensor->ptr<float>(), kparam, stream);
break;
#if CUDA_VERSION >= 9000
case DTypeEnum::Float16:
chanwise::run_bwd_depthwise_large_filter(
static_cast<half*>(args.grad_tensor->raw_ptr()),
static_cast<half*>(args.diff_tensor->raw_ptr()),
static_cast<half*>(args.filter_tensor->raw_ptr()), kparam, stream);
break;
#endif
default:
megdnn_assert_internal(0);
}
......
......@@ -34,9 +34,18 @@ template <>
void run_bwd_depthwise_large_filter(
float* dst, const float* src, const float* flt, const Param& param,
cudaStream_t stream) {
INSTANCE(DepthwiseConv2dDirection::DIRECTION_BACKWARD)
INSTANCE(float, float2, DepthwiseConv2dDirection::DIRECTION_BACKWARD)
}
#if CUDA_VERSION >= 9000
template <>
void run_bwd_depthwise_large_filter(
__half* dst, const __half* src, const __half* flt, const Param& param,
cudaStream_t stream) {
INSTANCE(__half, __half2, DepthwiseConv2dDirection::DIRECTION_BACKWARD)
}
#endif
} // namespace chanwise
} // namespace convolution
} // namespace cuda
......
......@@ -701,51 +701,53 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
ConvBiasForward::algo_name<ConvBias::DirectParam>(
"DEPTHWISE_LARGE_FILTER", {})
.c_str()));
auto run = [&checker](size_t n, size_t g, size_t h, size_t fh) {
param::ConvBias cur_param;
cur_param.mode = param::ConvBias::Mode::CROSS_CORRELATION;
cur_param.sparse = ConvBias::Param::Sparse::GROUP;
checker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(3, dtype::Float32())
.set_dtype(4, dtype::Float32());
for (auto dtype : std::vector<DType>{dtype::Float16()}) {
auto run = [&checker, &dtype](size_t n, size_t g, size_t h, size_t fh) {
param::ConvBias cur_param;
cur_param.mode = param::ConvBias::Mode::CROSS_CORRELATION;
cur_param.sparse = ConvBias::Param::Sparse::GROUP;
checker.set_dtype(0, dtype)
.set_dtype(1, dtype)
.set_dtype(2, dtype)
.set_dtype(3, dtype)
.set_dtype(4, dtype);
cur_param.pad_h = cur_param.pad_w = fh / 2;
cur_param.stride_h = cur_param.stride_w = 1;
checker.set_param(cur_param).execs(
{{n, g, h, h}, {g, 1, 1, fh, fh}, {}, {}, {}});
};
run(4, 8, 32, 5);
run(4, 8, 32, 7);
run(4, 8, 32, 9);
run(4, 8, 32, 11);
run(4, 8, 32, 13);
run(4, 8, 32, 15);
run(4, 8, 32, 17);
run(4, 8, 32, 19);
run(4, 8, 32, 21);
run(4, 8, 32, 23);
run(4, 8, 32, 25);
run(4, 8, 32, 27);
run(4, 8, 32, 29);
run(4, 8, 32, 31);
run(4, 8, 64, 5);
run(4, 8, 64, 7);
run(4, 8, 64, 9);
run(4, 8, 64, 11);
run(4, 8, 64, 13);
run(4, 8, 64, 15);
run(4, 8, 64, 17);
run(4, 8, 64, 19);
run(4, 8, 64, 21);
run(4, 8, 64, 23);
run(4, 8, 64, 25);
run(4, 8, 64, 27);
run(4, 8, 64, 29);
run(4, 8, 64, 31);
run(1, 2, 128, 31);
run(1, 2, 256, 31);
cur_param.pad_h = cur_param.pad_w = fh / 2;
cur_param.stride_h = cur_param.stride_w = 1;
checker.set_param(cur_param).execs(
{{n, g, h, h}, {g, 1, 1, fh, fh}, {}, {}, {}});
};
run(4, 8, 32, 5);
run(4, 8, 32, 7);
run(4, 8, 32, 9);
run(4, 8, 32, 11);
run(4, 8, 32, 13);
run(4, 8, 32, 15);
run(4, 8, 32, 17);
run(4, 8, 32, 19);
run(4, 8, 32, 21);
run(4, 8, 32, 23);
run(4, 8, 32, 25);
run(4, 8, 32, 27);
run(4, 8, 32, 29);
run(4, 8, 32, 31);
run(4, 8, 64, 5);
run(4, 8, 64, 7);
run(4, 8, 64, 9);
run(4, 8, 64, 11);
run(4, 8, 64, 13);
run(4, 8, 64, 15);
run(4, 8, 64, 17);
run(4, 8, 64, 19);
run(4, 8, 64, 21);
run(4, 8, 64, 23);
run(4, 8, 64, 25);
run(4, 8, 64, 27);
run(4, 8, 64, 29);
run(4, 8, 64, 31);
run(1, 2, 128, 31);
run(1, 2, 256, 31);
}
}
TEST_F(CUDA, CONV_BIAS_FORWARD_CHANWISE_8x8x32) {
......@@ -1550,11 +1552,81 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
param.stride_h = sh;
param.stride_w = sw;
bencher.set_times(nr_times);
size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
TensorShape inp{batch, g, hi, wi}, kern{g, 1, 1, fh, fw}, out{batch, g, ho, wo};
float bandwith = static_cast<float>(
inp.total_nr_elems() + kern.total_nr_elems() +
out.total_nr_elems()) /
(1024 * 1024 * 1024) * 1e3;
bencher.set_param(param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(4, dtype::Float32());
auto fp32_time_in_ms = bencher.execs({inp, kern, {}, {}, out}) / nr_times;
bencher.set_param(param)
.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.set_dtype(2, dtype::Float16())
.set_dtype(4, dtype::Float16());
auto fp16_time_in_ms = bencher.execs({inp, kern, {}, {}, out}) / nr_times;
printf("chanwise_depthwise_large_filter: inp=%s, kern=%s, out=%s, fp32_time: "
"%.2fms, fp16_time: %.2fms, speedup: %0.2f (fp16/fp32) "
"fp32_bandwidth: %.2fGB/s fp16_bandwidth: %.2fGB/s.\n",
inp.to_string().c_str(), kern.to_string().c_str(),
out.to_string().c_str(), fp32_time_in_ms, fp16_time_in_ms,
fp32_time_in_ms / fp16_time_in_ms, bandwith * 4 / fp32_time_in_ms,
bandwith * 2 / fp16_time_in_ms);
};
run_bench(64, 384, 32, 32, 3, 3, 1, 1, 10);
run_bench(64, 384, 32, 32, 5, 5, 1, 1, 10);
run_bench(64, 384, 32, 32, 7, 7, 1, 1, 10);
run_bench(64, 384, 32, 32, 9, 9, 1, 1, 10);
run_bench(64, 384, 32, 32, 11, 11, 1, 1, 10);
run_bench(64, 384, 32, 32, 13, 13, 1, 1, 10);
run_bench(64, 384, 32, 32, 15, 15, 1, 1, 10);
run_bench(64, 384, 32, 32, 17, 17, 1, 1, 10);
run_bench(64, 384, 32, 32, 19, 19, 1, 1, 10);
run_bench(64, 384, 32, 32, 21, 21, 1, 1, 10);
run_bench(64, 384, 32, 32, 23, 23, 1, 1, 10);
run_bench(64, 384, 32, 32, 25, 25, 1, 1, 10);
run_bench(64, 384, 32, 32, 27, 27, 1, 1, 10);
run_bench(64, 384, 32, 32, 29, 29, 1, 1, 10);
run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10);
}
TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP16) {
require_compute_capability(7, 5);
Benchmarker<ConvBiasForward> bencher(handle_cuda());
bencher.set_display(false);
bencher.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
"DEPTHWISE_LARGE_FILTER", {})
.c_str()));
ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW;
using NonlineMode = ConvBias::Param::NonlineMode;
param.nonlineMode = NonlineMode::IDENTITY;
param.sparse = ConvBias::Param::Sparse::GROUP;
auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh,
size_t fw, size_t sh, size_t sw, size_t nr_times) {
param.pad_h = fh / 2;
param.pad_w = fw / 2;
param.stride_h = sh;
param.stride_w = sw;
bencher.set_param(param)
.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.set_dtype(2, dtype::Float16())
.set_dtype(4, dtype::Float16());
bencher.set_times(nr_times);
size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
......
......@@ -728,7 +728,7 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DEPTHWISE_LARGE_FILTER) {
Checker<ConvolutionBackwardData> checker(handle_cuda());
checker.set_before_exec_callback(
AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER"));
for (auto dtype : std::vector<DType>{dtype::Float32()}) {
for (auto dtype : std::vector<DType>{dtype::Float16()}) {
auto run = [&checker, &dtype](size_t n, size_t g, size_t h, size_t fh) {
param::Convolution param;
param.stride_h = param.stride_w = 1;
......@@ -999,6 +999,55 @@ TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_DEPTHWISE_LARGE_FILTER) {
run(64, 384, 384, 32, 32, 31, 1, 10);
}
TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_DEPTHWISE_LARGE_FILTER_FP16) {
CUBenchmarker<ConvolutionBackwardData> bencher{handle_cuda()};
bencher.set_display(false);
bencher.set_before_exec_callback(
AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER"));
auto run = [&](size_t N, size_t OC, size_t g, size_t IH, size_t IW, size_t FH,
size_t SH, size_t nr_times) {
bencher.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.set_dtype(2, dtype::Float16());
param::Convolution param;
param.stride_h = param.stride_w = SH;
param.pad_h = param.pad_w = FH / 2;
param.sparse = param::Convolution::Sparse::GROUP;
bencher.set_param(param);
bencher.set_times(nr_times);
TensorLayout src{{N, g, IH, IW}, dtype::Float16()},
filter{{g, 1, 1, FH, FH}, dtype::Float16()};
TensorLayout dst;
{
auto&& opr = handle_cuda()->create_operator<Convolution>();
opr->param() = param;
opr->deduce_layout(src, filter, dst);
}
auto time_ms_fp16 = bencher.execl({filter, dst, src}) / nr_times;
float flo = 2.0 * N * g * dst[2] * dst[3] * FH * FH;
printf("inp=%s, kern=%s, dst=%s ", src.to_string().c_str(),
filter.to_string().c_str(), dst.to_string().c_str());
printf("time_fp16=%.2fms, flops=%.3fTFLOPS\n", time_ms_fp16,
(flo / (time_ms_fp16 * 1e9)));
};
run(64, 384, 384, 32, 32, 3, 1, 10);
run(64, 384, 384, 32, 32, 5, 1, 10);
run(64, 384, 384, 32, 32, 7, 1, 10);
run(64, 384, 384, 32, 32, 9, 1, 10);
run(64, 384, 384, 32, 32, 11, 1, 10);
run(64, 384, 384, 32, 32, 13, 1, 10);
run(64, 384, 384, 32, 32, 15, 1, 10);
run(64, 384, 384, 32, 32, 17, 1, 10);
run(64, 384, 384, 32, 32, 19, 1, 10);
run(64, 384, 384, 32, 32, 21, 1, 10);
run(64, 384, 384, 32, 32, 23, 1, 10);
run(64, 384, 384, 32, 32, 25, 1, 10);
run(64, 384, 384, 32, 32, 27, 1, 10);
run(64, 384, 384, 32, 32, 29, 1, 10);
run(64, 384, 384, 32, 32, 31, 1, 10);
}
TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_BF16) {
CUBenchmarker<ConvolutionBackwardData> bench{handle_cuda()};
std::unique_ptr<OprProxy<ConvolutionBackwardData>> proxy{
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册