Commit e698ec20

Authored on Feb 15, 2022 by Megvii Engine Team
Committed by 王彪 on Feb 27, 2022
feat(cuda): float16 depthwise large kernel conv compute fp32
GitOrigin-RevId: 3050d48f2691faeeda4fb054134041cc620b5a35
Parent: 48406382
Showing 8 changed files with 264 additions and 191 deletions (+264, -191)
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl    +137  -85
dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu                  +0   -1
dnn/src/cuda/conv_bias/depthwise_large_filter.cpp                    +4   -5
dnn/src/cuda/convolution/backward_data/algo.h                        +1   -1
dnn/src/cuda/convolution/backward_data/depthwise_large_filter.cpp    +7   -5
dnn/src/cuda/fp16_help.cuh                                           +6   -0
dnn/test/cuda/conv_bias.cpp                                         +63  -58
dnn/test/cuda/convolution.cpp                                       +46  -36
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl
@@ -57,10 +57,13 @@ struct Global2SharedMem {
     T* smem;
     int stride;
     int start_h, start_w, bound_h, bound_w, ring_smem_h, ring_src_h;
+    // just used in backward src data
+    int stride_h, stride_w;
     const T* g_ptr;
-    __device__ __forceinline__
-    Global2SharedMem(T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w);
+    __device__ __forceinline__ Global2SharedMem(
+            T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w, int stride_h_,
+            int stride_w_);
     __device__ __forceinline__ void first_copy();
     __device__ __forceinline__ void copy();
@@ -77,7 +80,7 @@ struct Global2SharedMem {
 template <
         typename ldg_dtype, DepthwiseConv2dDirection kDirection, typename ThreadConfig_,
-        typename OutTileConfig_, typename FilterTileConfig_>
+        typename OutTileConfig_, typename FilterTileConfig_, int stride_w, int stride_h>
 struct ConvTrait {
     using ThreadConfig = ThreadConfig_;
     using OutTileConfig = OutTileConfig_;
@@ -88,19 +91,19 @@ struct ConvTrait {
         static int const unroll_h =
                 OutTileConfig::unroll_h + FilterTileConfig::unroll_h - 1;
-        static int const unroll_w =
-                OutTileConfig::unroll_w + FilterTileConfig::unroll_w - 1;
+        static int const unroll_w =
+                (OutTileConfig::unroll_w - 1) * stride_w + FilterTileConfig::unroll_w;
         static int const unroll_size = unroll_h * unroll_w;
     };

     struct SrcTileCount {
-        static int const smem_src_h =
-                OutTileConfig::block_h + FilterTileConfig::unroll_h - 1;
+        static int const smem_src_h =
+                (OutTileConfig::block_h - 1) * stride_h + FilterTileConfig::unroll_h;
         static int const smem_buff_h = FilterTileConfig::unroll_h;
         static int const smem_load_h = smem_src_h + smem_buff_h;
         static int const smem_h = smem_load_h + smem_buff_h;
-        static int const smem_w =
-                DIVUP(OutTileConfig::block_w +
-                              FilterTileConfig::unroll_w * ThreadConfig::thread_x - 1,
-                      2) *
-                2;
+        static int const smem_w =
+                DIVUP((OutTileConfig::block_w - 1) * stride_w +
+                              FilterTileConfig::unroll_w * ThreadConfig::thread_x,
+                      2) *
+                2;
         static int const smem_size = smem_h * smem_w;
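The new tile extents encode the strided receptive-field arithmetic: producing n outputs at stride s with a k-tap window touches (n - 1) * s + k inputs, which collapses to the old n + k - 1 expression when s = 1. A minimal standalone check of that formula; the concrete numbers (32 outputs, 4 taps, stride 2) are illustrative, not taken from this diff:

    #include <cstdio>

    // Input span needed to produce `n` outputs at `stride` with a `k`-tap window.
    constexpr int input_span(int n, int stride, int k) {
        return (n - 1) * stride + k;
    }

    int main() {
        // stride 1 reduces to the familiar n + k - 1 used by the old tile sizes.
        static_assert(input_span(32, 1, 4) == 32 + 4 - 1, "stride-1 case");
        // stride 2: 31 strided steps of 2, plus the 4-tap window at the end.
        static_assert(input_span(32, 2, 4) == 66, "stride-2 case");
        std::printf("%d\n", input_span(32, 2, 4));  // prints 66
    }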
@@ -140,20 +143,25 @@ template <
         typename TileCount_>
 __device__ __forceinline__
 Global2SharedMem<T, kDirection, ThreadConfig_, TileCount_>::Global2SharedMem(
-        T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w)
+        T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w, int stride_h_,
+        int stride_w_)
         : smem(smem_),
           stride(stride_),
           start_h(s_h),
           start_w(s_w),
           bound_h(b_h),
           bound_w(b_w),
-          ring_smem_h(TileCount::smem_load_h) {
+          ring_smem_h(TileCount::smem_load_h),
+          stride_h(stride_h_),
+          stride_w(stride_w_) {
     if (is_fwd) {
         ring_src_h = s_h + TileCount::smem_load_h;
         w_offset = 0;
     } else {
         ring_src_h = s_h - 1;
         w_offset = TileCount::smem_w - b_w;
+        // stride_h and stride_w just used in backward src data.
+        stride_h = stride_w = 1;
     }
 }
@@ -195,9 +203,10 @@ __device__ __forceinline__ void Global2SharedMem<
             T val = 0.0f;
             if (src_h_idx >= 0 && src_h_idx < bound_h && src_w_idx >= 0 &&
                 src_w_idx < bound_w &&
-                (is_fwd || (TileCount::smem_load_h - smem_h_idx - 1 >= 0 &&
-                            TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0))) {
-                val = g_ptr[src_h_idx * stride + src_w_idx];
+                ((is_fwd && src_h_idx % stride_h == 0 && src_w_idx % stride_w == 0) ||
+                 (!is_fwd && TileCount::smem_load_h - smem_h_idx - 1 >= 0 &&
+                  TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0))) {
+                val = g_ptr[src_h_idx / stride_h * stride + src_w_idx / stride_w];
             }
             *(sh_ptr_as_copy_t(smem_h_idx, smem_w_idx)) = val;
@@ -223,8 +232,9 @@ __device__ __forceinline__ void Global2SharedMem<
         T val = 0.0f;
         if (ring_src_h >= 0 && ring_src_h < bound_h && src_w_idx >= 0 &&
             src_w_idx < bound_w &&
-            (is_fwd || TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0)) {
-            val = g_ptr[ring_src_h * stride + src_w_idx];
+            ((is_fwd && ring_src_h % stride_h == 0 && src_w_idx % stride_w == 0) ||
+             (!is_fwd && TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0))) {
+            val = g_ptr[ring_src_h / stride_h * stride + src_w_idx / stride_w];
         }
         reg[j] = val;
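These loaders now index a strided tensor as if it were zero-dilated: only coordinates that land on a multiple of the stride map to a stored element (via idx / stride), and every other position reads as zero, which is the standard trick for expressing a strided convolution's data gradient as a dense convolution. A hedged standalone sketch of that indexing (load_dilated is an illustrative helper, not part of this kernel):

    #include <cstdio>

    // Zero-dilated read: treat a length-n array as if (stride - 1) zeros were
    // interleaved between its elements. Positions that are not a multiple of
    // `stride` read as 0; the rest map back via division, mirroring the
    // `idx % stride == 0` guard and `idx / stride` indexing in the hunk above.
    float load_dilated(const float* data, int n, int stride, int idx) {
        if (idx % stride != 0)
            return 0.0f;
        int src = idx / stride;
        return src < n ? data[src] : 0.0f;
    }

    int main() {
        float diff[3] = {1.f, 2.f, 3.f};
        for (int i = 0; i < 6; ++i)  // prints: 1 0 2 0 3 0
            std::printf("%g ", load_dilated(diff, 3, 2, i));
        std::printf("\n");
    }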
@@ -286,21 +296,23 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
     int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
         off_oh = threadIdx.y, off_ow = threadIdx.x;

-    const int t2_src_unroll_w = (SrcTileConfig::unroll_w + 1) / 2;
-    const int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2;
-    const int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2;
+    constexpr int t2_src_unroll_w = (SrcTileConfig::unroll_w + 1) / 2;
+    constexpr int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2;
+    constexpr int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2;

     extern __shared__ __align__(8) unsigned char smem[];
     static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
     T* smem_src = reinterpret_cast<T*>(smem);
     T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);

+    int stride_h = is_fwd ? param.stride_h : 1;
+    int stride_w = is_fwd ? param.stride_w : 1;
+
     int off_ichannel = off_ochannel / param.chl_mul,
         off_fchannel = off_ichannel % param.src_chl,
         out_start_h = off_obh * OutTileConfig::block_h,
         out_start_w = off_obw * OutTileConfig::block_w,
-        src_start_h = out_start_h - param.pad_h,
-        src_start_w = out_start_w - param.pad_w,
+        src_start_h = out_start_h * stride_h - param.pad_h,
+        src_start_w = out_start_w * stride_w - param.pad_w,
         out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h;

     T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w;
@@ -308,12 +320,28 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
     T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w;

-    SrcGlobal2ShareVisitor gl2sh_src(
-            smem_src, param.src_w, src_start_h, src_start_w, param.src_h, param.src_w);
-    FilterGlobal2ShareVisitor gl2sh_flt = {
-            smem_flt, param.flt_w, is_fwd ? 0 : param.flt_h - 2,
-            0, param.flt_h, param.flt_w};
+    SrcGlobal2ShareVisitor gl2sh_src = {
+            smem_src,
+            param.src_w,
+            is_fwd ? src_start_h
+                   : src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
+                                    param.src_h * param.stride_h / 2),
+            is_fwd ? src_start_w
+                   : src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
+                                    param.src_w * param.stride_w / 2),
+            is_fwd ? param.src_h : param.src_h * param.stride_h,
+            is_fwd ? param.src_w : param.src_w * param.stride_w,
+            is_fwd ? 1 : param.stride_h,
+            is_fwd ? 1 : param.stride_w};
+    FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt,
+                                           param.flt_w,
+                                           is_fwd ? 0 : param.flt_h - 2,
+                                           0,
+                                           param.flt_h,
+                                           param.flt_w,
+                                           1,
+                                           1};

     gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w;
     gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w;
@@ -326,7 +354,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
     T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w],
             reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w];

-    T2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}};
+    float2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}};

     for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
         gl2sh_src.copy();
@@ -335,7 +363,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
         for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
 #pragma unroll
             for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
-                int src_offset = (off_oh + fh + s_h) % SrcTileCount::smem_h *
+                int src_offset = (off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
                                          SrcTileCount::smem_w +
                                  s_w * 2;
                 reg_src[s_h * t2_src_unroll_w + s_w] =
@@ -373,9 +401,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
 #pragma unroll
                 for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
                     sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
-                            reg_flt[ow % 2][inner_fh * t2_flt_unroll_w + fw],
+                            reg_flt[ow * stride_w % 2]
+                                   [inner_fh * t2_flt_unroll_w + fw],
                             reg_src[(inner_fh + oh) * t2_src_unroll_w + fw +
-                                    ow / 2],
+                                    ow * stride_w / 2],
                             sum[oh * t2_out_unroll_w + ow]);
                 }
             }
@@ -392,7 +421,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
     for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
         for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
-            sum[o] += __shfl_xor(sum[o], i, 32);
+            sum[o].x += __shfl_xor(sum[o].x, i, 32);
+            sum[o].y += __shfl_xor(sum[o].y, i, 32);
         }
     }
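Because the accumulator is now a float2 rather than an __half2, the warp-level reduction must shuffle the .x and .y scalars separately: __shfl_xor exchanges one 32-bit lane value at a time. A self-contained sketch of the same butterfly reduction (written with __shfl_xor_sync, the non-deprecated spelling; the diff keeps the older __shfl_xor):

    #include <cstdio>

    // Each of the 32 threads contributes `val`; after the butterfly, every
    // thread holds the warp-wide sum. Same pattern as the sum[o].x / sum[o].y
    // loops above, applied to a single float.
    __global__ void warp_sum_demo() {
        float val = threadIdx.x;  // thread i contributes i
        for (int i = 1; i < 32; i <<= 1)
            val += __shfl_xor_sync(0xffffffff, val, i, 32);
        if (threadIdx.x == 0)
            printf("warp sum = %g\n", val);  // 0 + 1 + ... + 31 = 496
    }

    int main() {
        warp_sum_demo<<<1, 32>>>();
        cudaDeviceSynchronize();
    }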
@@ -406,9 +436,9 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
             int out_w_idx = out_start_w + j;
             if (out_w_idx >= param.out_w)
                 return;
-            out_base_ptr[out_h_idx * param.out_w + out_w_idx] =
-                    sum[i * OutTileConfig::unroll_w + j].x +
-                    sum[i * OutTileConfig::unroll_w + j].y;
+            out_base_ptr[out_h_idx * param.out_w + out_w_idx] = __float2half(
+                    sum[i * OutTileConfig::unroll_w + j].x +
+                    sum[i * OutTileConfig::unroll_w + j].y);
         }
     }
 }
@@ -433,21 +463,19 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
     int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
         off_oh = threadIdx.y, off_ow = threadIdx.x;

-    const int t2_src_unroll_w = (SrcTileConfig::unroll_w + 1) / 2;
-    const int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2;
-    const int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2;
-
     extern __shared__ __align__(8) unsigned char smem[];
     static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
     T* smem_src = reinterpret_cast<T*>(smem);
     T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);

+    int stride_h = is_fwd ? param.stride_h : 1;
+    int stride_w = is_fwd ? param.stride_w : 1;
+
     int off_ichannel = off_ochannel / param.chl_mul,
         off_fchannel = off_ichannel % param.src_chl,
         out_start_h = off_obh * OutTileConfig::block_h,
         out_start_w = off_obw * OutTileConfig::block_w,
-        src_start_h = out_start_h - param.pad_h,
-        src_start_w = out_start_w - param.pad_w,
+        src_start_h = out_start_h * stride_h - param.pad_h,
+        src_start_w = out_start_w * stride_w - param.pad_w,
         out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h;

     T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w;
@@ -455,12 +483,28 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
     T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w;

-    SrcGlobal2ShareVisitor gl2sh_src(
-            smem_src, param.src_w, src_start_h, src_start_w, param.src_h, param.src_w);
-    FilterGlobal2ShareVisitor gl2sh_flt = {
-            smem_flt, param.flt_w, is_fwd ? 0 : param.flt_h - 2,
-            0, param.flt_h, param.flt_w};
+    SrcGlobal2ShareVisitor gl2sh_src = {
+            smem_src,
+            param.src_w,
+            is_fwd ? src_start_h
+                   : src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
+                                    param.src_h * param.stride_h / 2),
+            is_fwd ? src_start_w
+                   : src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
+                                    param.src_w * param.stride_w / 2),
+            is_fwd ? param.src_h : param.src_h * param.stride_h,
+            is_fwd ? param.src_w : param.src_w * param.stride_w,
+            is_fwd ? 1 : param.stride_h,
+            is_fwd ? 1 : param.stride_w};
+    FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt,
+                                           param.flt_w,
+                                           is_fwd ? 0 : param.flt_h - 2,
+                                           0,
+                                           param.flt_h,
+                                           param.flt_w,
+                                           1,
+                                           1};

     gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w;
     gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w;
@@ -470,10 +514,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
     __syncthreads();

-    T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w],
-            reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w];
+    T reg_src[SrcTileConfig::unroll_h * SrcTileConfig::unroll_w],
+            reg_flt[FilterTileConfig::unroll_h * FilterTileConfig::unroll_w];

-    T2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}};
+    T sum[OutTileConfig::unroll_size] = {0.0};

     for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
         gl2sh_src.copy();
@@ -481,34 +525,28 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
 #pragma unroll
         for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
 #pragma unroll
-            for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
-                int src_offset = (off_oh + fh + s_h) % SrcTileCount::smem_h *
-                                         SrcTileCount::smem_w +
-                                 s_w * 2;
-                reg_src[s_h * t2_src_unroll_w + s_w] =
-                        *reinterpret_cast<T2*>(smem_src_ptr + src_offset);
+            for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
+                reg_src[s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
+                        [(off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
+                                 SrcTileCount::smem_w +
+                         s_w];
+                if (off_ochannel == 0 && off_obw == 0 && off_obh == 0 && off_oh == 30 &&
+                    off_ow == 0) {
+                    printf("reg_src[%d] = %f\n", s_h * SrcTileConfig::unroll_w + s_w,
+                           reg_src[s_h * SrcTileConfig::unroll_w + s_w]);
+                }
             }
         }
 #pragma unroll
         for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
 #pragma unroll
-            for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) {
-                int flt_offset =
-                        (fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w +
-                        f_w * 2;
-                reg_flt[0][f_h * t2_flt_unroll_w + f_w] =
-                        *reinterpret_cast<T2*>(smem_flt_ptr + flt_offset);
-                reg_flt[1][f_h * t2_flt_unroll_w + f_w] = {
-                        f_w > 0 ? reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y
-                                : static_cast<T>(0.0),
-                        reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
+            for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
+                reg_flt[f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
+                        [(fh + f_h) % FilterTileCount::smem_h *
+                                 FilterTileCount::smem_w +
+                         f_w];
             }
-            reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {
-                    static_cast<T>(0.0), static_cast<T>(0.0)};
-            reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {
-                    reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y,
-                    static_cast<T>(0.0)};
         }
 #pragma unroll
@@ -516,14 +554,22 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
 #pragma unroll
         for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
 #pragma unroll
-            for (int fw = 0; fw < t2_flt_unroll_w; ++fw) {
+            for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
 #pragma unroll
                 for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
-                    sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
-                            reg_flt[ow % 2][inner_fh * t2_flt_unroll_w + fw],
-                            reg_src[(inner_fh + oh) * t2_src_unroll_w + fw +
-                                    ow / 2],
-                            sum[oh * t2_out_unroll_w + ow]);
+                    sum[oh * OutTileConfig::unroll_w + ow] +=
+                            reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] *
+                            reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw +
+                                    ow * stride_w];
+                    if (off_ochannel == 0 && off_obw == 0 && off_obh == 0 &&
+                        off_oh == 30) {
+                        printf("sum[%d] += %f * %f\nsum = %f\n",
+                               oh * OutTileConfig::unroll_w + ow,
+                               reg_flt[inner_fh * FilterTileConfig::unroll_w + fw],
+                               reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w +
+                                       fw + ow * stride_w],
+                               sum[oh * OutTileConfig::unroll_w + ow]);
+                    }
                 }
             }
         }
@@ -539,8 +585,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
     for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
         for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
-            sum[o].x += __shfl_xor(sum[o].x, i, 32);
-            sum[o].y += __shfl_xor(sum[o].y, i, 32);
+            sum[o] += __shfl_xor(sum[o], i, 32);
         }
     }
@@ -555,8 +600,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
             if (out_w_idx >= param.out_w)
                 return;
             out_base_ptr[out_h_idx * param.out_w + out_w_idx] =
-                    sum[i * OutTileConfig::unroll_w + j].x +
-                    sum[i * OutTileConfig::unroll_w + j].y;
+                    sum[i * OutTileConfig::unroll_w + j];
         }
     }
 }
@@ -565,7 +609,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
 template <
         typename T, typename T2, DepthwiseConv2dDirection kDirection, int unroll_fw,
-        int unroll_ow>
+        int unroll_ow, int stride>
 void LaunchDepthwiseConv2dGPUSmall(
         const Param& param, const T* input, const T* filter, T* output,
         cudaStream_t stream) {
@@ -574,8 +618,9 @@ void LaunchDepthwiseConv2dGPUSmall(
     using FilterTileConfig = FilterTileConfig<unroll_fh, unroll_fw>;
     using ThreadConfig = ThreadConfig<4, 32>;
     using OutTileConfig = OutTileConfig<ThreadConfig, unroll_oh, unroll_ow>;
-    using IConvTrait =
-            ConvTrait<T, kDirection, ThreadConfig, OutTileConfig, FilterTileConfig>;
+    using IConvTrait = ConvTrait<
+            T, kDirection, ThreadConfig, OutTileConfig, FilterTileConfig, stride,
+            stride>;
     using SrcTileCount = typename IConvTrait::SrcTileCount;
     using FilterTileCount = typename IConvTrait::FilterTileCount;
@@ -593,10 +638,17 @@ void LaunchDepthwiseConv2dGPUSmall(
     after_kernel_launch();
 }

-#define INSTANCE_AB(type1, type2, a, b, direction)                            \
-    if (param.out_w > b * 4) {                                                \
-        LaunchDepthwiseConv2dGPUSmall<type1, type2, direction, a + 2, b + 1>( \
-                param, src, flt, dst, stream);                                \
+#define INSTANCE_AB(type1, type2, a, b, direction)                                   \
+    if (param.out_w > b * 4) {                                                       \
+        printf("param.out_w = %d, b = %d\n", param.out_w, b);                        \
+        if (direction == DepthwiseConv2dDirection::DIRECTION_BACKWARD ||             \
+            (param.stride_h == 1 && param.stride_w == 1)) {                          \
+            LaunchDepthwiseConv2dGPUSmall<type1, type2, direction, a + 2, b + 1, 1>( \
+                    param, src, flt, dst, stream);                                   \
+        } else if (param.stride_h == 2 && param.stride_w == 2) {                     \
+            LaunchDepthwiseConv2dGPUSmall<type1, type2, direction, a + 2, b + 1, 2>( \
+                    param, src, flt, dst, stream);                                   \
+        }                                                                            \
     }

 #define INSTANCE_A(type1, type2, a, direction) \
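INSTANCE_AB now folds the runtime stride into a template argument, so strides 1 and 2 each get a fully specialized kernel instantiation (backward data takes the stride-1 instantiation, handling stride through the zero-dilated loads instead). A minimal sketch of this runtime-to-compile-time dispatch pattern; kernel_impl and run_stride are illustrative names, not from this codebase:

    #include <cstdio>
    #include <stdexcept>

    // A compile-time stride lets the compiler fold the (n - 1) * stride + k
    // tile arithmetic into constants and fully unroll the inner loops.
    template <int stride>
    void kernel_impl(int out_w) {
        std::printf("launch stride=%d kernel for out_w=%d\n", stride, out_w);
    }

    // Runtime-to-compile-time dispatch over the strides the algo supports.
    void run_stride(int stride, int out_w) {
        switch (stride) {
            case 1: kernel_impl<1>(out_w); break;
            case 2: kernel_impl<2>(out_w); break;
            default: throw std::runtime_error("unsupported stride");
        }
    }

    int main() {
        run_stride(1, 128);
        run_stride(2, 128);
    }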
dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu

@@ -11,7 +11,6 @@
 #include "cuda.h"
 #include "cuda_fp16.h"
-// #include "src/cuda/conv_bias/chanwise/fwd_depthwise_large_filter.cuh"
 #include "src/cuda/conv_bias/chanwise/kern.cuh"
 #include "src/cuda/conv_bias/chanwise/kern_helper.cuh"
 #include "src/cuda/conv_bias/chanwise/launch_config.cuh"
dnn/src/cuda/conv_bias/depthwise_large_filter.cpp

@@ -32,15 +32,14 @@ inline bool is_available_depthwise_large_filter(const chanwise::Param& param) {
                       : 1 + (ow + 3) / 4 + flt_smem_w / 4 - 1;
     int out_reg_per_thread = (ow + 3) / 4 * 4;
     if (device_prop.regsPerBlock < 4 * 32 *
-                (flt_reg_per_thread + src_reg_per_thread + out_reg_per_thread) ||
+                (flt_reg_per_thread * 2 + src_reg_per_thread + out_reg_per_thread) ||
         device_prop.sharedMemPerBlock <
                 static_cast<size_t>(
-                        flt_smem_w * flt_smem_h + src_smem_w * src_smem_h)) {
+                        flt_smem_w * flt_smem_h * 2 + src_smem_w * src_smem_h)) {
         return false;
     }
-    return param.stride_h == 1 && param.stride_w == 1 && param.src_h == param.out_h &&
-           param.src_w == param.out_w;
+    return true;
 }
 }  // anonymous namespace
dnn/src/cuda/convolution/backward_data/algo.h

@@ -68,7 +68,7 @@ public:
             const TensorLayout& grad);

     convolution::ForwardSizeArgs as_fwd_args() const {
-        return {handle, grad_layout, filter_layout, filter_meta, diff_layout};
+        return {handle, diff_layout, filter_layout, filter_meta, grad_layout};
     }
 };

 struct ExecArgs : public SizeArgs {
dnn/src/cuda/convolution/backward_data/depthwise_large_filter.cpp

@@ -31,15 +31,17 @@ inline bool is_available_depthwise_large_filter(const chanwise::Param& param) {
                       : 1 + (ow + 3) / 4 + flt_smem_w / 4 - 1;
     int out_reg_per_thread = (ow + 3) / 4 * 4;
     if (device_prop.regsPerBlock < 4 * 32 *
-                (flt_reg_per_thread + src_reg_per_thread + out_reg_per_thread) ||
+                (flt_reg_per_thread * 2 + src_reg_per_thread + out_reg_per_thread) ||
         device_prop.sharedMemPerBlock <
                 static_cast<size_t>(
-                        flt_smem_w * flt_smem_h + src_smem_w * src_smem_h)) {
+                        flt_smem_w * flt_smem_h * 2 + src_smem_w * src_smem_h)) {
         return false;
     }
-    return param.stride_h == 1 && param.stride_w == 1 && param.src_h == param.out_h &&
-           param.src_w == param.out_w;
+    printf("param.src_w = %d, param.src_h = %d, param.out_w = %d, param.out_h = %d\n",
+           param.src_w, param.src_h, param.out_w, param.out_h);
+    return (param.stride_h == 1 && param.stride_w == 1) ||
+           (param.stride_h == 2 && param.stride_w == 2);
 }
 }  // anonymous namespace
dnn/src/cuda/fp16_help.cuh

@@ -45,6 +45,12 @@ fma2(const __half2 a, const __half2 b, const __half2 c) {
 #endif
 }

+__device__ __forceinline__ float2
+fma2(const __half2 a, const __half2 b, const float2 c) {
+    return {__half2float(a.x) * __half2float(b.x) + c.x,
+            __half2float(a.y) * __half2float(b.y) + c.y};
+}
+
 #endif  // CUDA_VERSION >= 9000

 }  // namespace cuda
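This new fma2 overload is the heart of the change named in the commit title: the __half2 operands are widened with __half2float and the products are accumulated into a float2, so an fp16 depthwise convolution with a large filter sums hundreds of taps in fp32 rather than fp16. A hedged standalone sketch of the same multiply-in-half, accumulate-in-float idea (the fill value 0.1 and n = 4096 are illustrative):

    #include <cuda_fp16.h>
    #include <cstdio>

    // fp16 multiply, fp32 accumulate: the same shape as the new fma2 overload.
    // The exact real-number result is 4096 * 0.1^2 = 40.96; half inputs round
    // 0.1 slightly, so expect ~40.9. Accumulating in half instead would stall
    // once the running sum's ULP (1/32 at magnitudes >= 32) exceeds each ~0.01
    // addend.
    __global__ void mixed_precision_sum(const __half* data, int n, float* out) {
        float acc = 0.f;  // fp32 accumulator, like the kernel's float2 sum[]
        for (int i = 0; i < n; ++i)
            acc += __half2float(data[i]) * __half2float(data[i]);
        *out = acc;
    }

    int main() {
        const int n = 4096;
        __half* d;
        float* out;
        cudaMallocManaged(&d, n * sizeof(__half));
        cudaMallocManaged(&out, sizeof(float));
        for (int i = 0; i < n; ++i)
            d[i] = __float2half(0.1f);
        mixed_precision_sum<<<1, 1>>>(d, n, out);
        cudaDeviceSynchronize();
        printf("sum = %f (exact real-number result: 40.96)\n", *out);
        cudaFree(d);
        cudaFree(out);
    }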
dnn/test/cuda/conv_bias.cpp

@@ -701,8 +701,10 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
             ConvBiasForward::algo_name<ConvBias::DirectParam>(
                     "DEPTHWISE_LARGE_FILTER", {})
                     .c_str()));
-    for (auto dtype : std::vector<DType>{dtype::Float16()}) {
-        auto run = [&checker, &dtype](size_t n, size_t g, size_t h, size_t fh) {
+    for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
+        auto run = [&checker, &dtype](
+                           size_t n, size_t g, size_t h, size_t fh, size_t padding,
+                           size_t stride) {
             param::ConvBias cur_param;
             cur_param.mode = param::ConvBias::Mode::CROSS_CORRELATION;
             cur_param.sparse = ConvBias::Param::Sparse::GROUP;

@@ -711,42 +713,52 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
                     .set_dtype(2, dtype)
                     .set_dtype(3, dtype)
                     .set_dtype(4, dtype);
             float scale = 64.f / sqrt(fh * fh);
             UniformFloatRNG rng(scale, 2 * scale);
             checker.set_rng(0, &rng)
                     .set_rng(1, &rng)
                     .set_rng(2, &rng)
                     .set_rng(3, &rng)
                     .set_rng(4, &rng);
             if (dtype.enumv() == DTypeEnum::Float16) {
                 checker.set_epsilon(1e-1);
             }
-            cur_param.pad_h = cur_param.pad_w = fh / 2;
-            cur_param.stride_h = cur_param.stride_w = 1;
+            cur_param.pad_h = cur_param.pad_w = padding;
+            cur_param.stride_h = cur_param.stride_w = stride;
             checker.set_param(cur_param).execs(
                     {{n, g, h, h}, {g, 1, 1, fh, fh}, {}, {}, {}});
         };
-        run(4, 8, 32, 5);
-        run(4, 8, 32, 7);
-        run(4, 8, 32, 9);
-        run(4, 8, 32, 11);
-        run(4, 8, 32, 13);
-        run(4, 8, 32, 15);
-        run(4, 8, 32, 17);
-        run(4, 8, 32, 19);
-        run(4, 8, 32, 21);
-        run(4, 8, 32, 23);
-        run(4, 8, 32, 25);
-        run(4, 8, 32, 27);
-        run(4, 8, 32, 29);
-        run(4, 8, 32, 31);
-        run(4, 8, 64, 5);
-        run(4, 8, 64, 7);
-        run(4, 8, 64, 9);
-        run(4, 8, 64, 11);
-        run(4, 8, 64, 13);
-        run(4, 8, 64, 15);
-        run(4, 8, 64, 17);
-        run(4, 8, 64, 19);
-        run(4, 8, 64, 21);
-        run(4, 8, 64, 23);
-        run(4, 8, 64, 25);
-        run(4, 8, 64, 27);
-        run(4, 8, 64, 29);
-        run(4, 8, 64, 31);
-        run(1, 2, 128, 31);
-        run(1, 2, 256, 31);
+        run(4, 8, 32, 5, 5 / 2, 1);
+        run(4, 8, 32, 7, 7 / 2, 1);
+        run(4, 8, 32, 9, 9 / 2, 1);
+        run(4, 8, 32, 11, 11 / 2, 1);
+        run(4, 8, 32, 13, 13 / 2, 1);
+        run(4, 8, 32, 15, 15 / 2, 1);
+        run(4, 8, 32, 17, 17 / 2, 1);
+        run(4, 8, 32, 19, 19 / 2, 1);
+        run(4, 8, 32, 21, 21 / 2, 1);
+        run(4, 8, 32, 23, 23 / 2, 1);
+        run(4, 8, 32, 25, 25 / 2, 1);
+        run(4, 8, 32, 27, 27 / 2, 1);
+        run(4, 8, 32, 29, 29 / 2, 1);
+        run(4, 8, 32, 31, 31 / 2, 1);
+        run(4, 8, 64, 5, 5 / 3, 2);
+        run(4, 8, 64, 7, 7 / 3, 2);
+        run(4, 8, 64, 9, 9 / 3, 2);
+        run(4, 8, 64, 11, 11 / 3, 2);
+        run(4, 8, 64, 13, 13 / 3, 2);
+        run(4, 8, 64, 15, 15 / 3, 2);
+        run(4, 8, 64, 17, 17 / 3, 2);
+        run(4, 8, 64, 19, 19 / 3, 2);
+        run(4, 8, 64, 21, 21 / 3, 2);
+        run(4, 8, 64, 23, 23 / 3, 2);
+        run(4, 8, 64, 25, 25 / 3, 2);
+        run(4, 8, 64, 27, 27 / 3, 2);
+        run(4, 8, 64, 29, 29 / 3, 2);
+        run(4, 8, 64, 31, 31 / 3, 2);
+        run(1, 2, 128, 31, 10, 2);
+        run(1, 2, 256, 31, 10, 2);
     }
 }
@@ -1530,7 +1542,7 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_TENSORCORE_INT8) {
     run_bench(256, 512, 7, 7, 2048, 1, 1, 1, 1, 1000);
 }

-TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
+TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP16) {
     require_compute_capability(7, 5);
     Benchmarker<ConvBiasForward> bencher(handle_cuda());
     bencher.set_display(false);
@@ -1552,6 +1564,11 @@
         param.stride_h = sh;
         param.stride_w = sw;
+        bencher.set_param(param)
+                .set_dtype(0, dtype::Float16())
+                .set_dtype(1, dtype::Float16())
+                .set_dtype(2, dtype::Float16())
+                .set_dtype(4, dtype::Float16());
         bencher.set_times(nr_times);
         size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
         size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
@@ -1562,25 +1579,13 @@
                        out.total_nr_elems()) /
                (1024 * 1024 * 1024) * 1e3;

-        bencher.set_param(param)
-                .set_dtype(0, dtype::Float32())
-                .set_dtype(1, dtype::Float32())
-                .set_dtype(2, dtype::Float32())
-                .set_dtype(4, dtype::Float32());
-        auto fp32_time_in_ms = bencher.execs({inp, kern, {}, {}, out}) / nr_times;
-        bencher.set_param(param)
-                .set_dtype(0, dtype::Float16())
-                .set_dtype(1, dtype::Float16())
-                .set_dtype(2, dtype::Float16())
-                .set_dtype(4, dtype::Float16());
-        auto fp16_time_in_ms = bencher.execs({inp, kern, {}, {}, out}) / nr_times;
-        printf("chanwise_depthwise_large_filter: inp=%s, kern=%s, out=%s, fp32_time: "
-               "%.2fms, fp16_time: %.2fms, speedup: %0.2f (fp16/fp32) "
-               "fp32_bandwidth: %.2fGB/s fp16_bandwidth: %.2fGB/s.\n",
-               inp.to_string().c_str(), kern.to_string().c_str(),
-               out.to_string().c_str(), fp32_time_in_ms, fp16_time_in_ms,
-               fp32_time_in_ms / fp16_time_in_ms, bandwith * 4 / fp32_time_in_ms,
-               bandwith * 2 / fp16_time_in_ms);
+        auto time_in_ms = bencher.execs({inp, kern, {}, {}, out}) / nr_times;
+        auto ops = 2.0 * batch * g * ho * wo * fh * fw / (time_in_ms * 1e-3) * 1e-12;
+        printf("chanwise_depthwise_large_filter: inp=%s, kern=%s, out=%s, time: "
+               "%.2fms, "
+               "perf: %.2f Tops bandwidth: %.2fGB/s.\n",
+               inp.to_string().c_str(), kern.to_string().c_str(),
+               out.to_string().c_str(), time_in_ms, ops, bandwith * 4 / time_in_ms);
     };

     run_bench(64, 384, 32, 32, 3, 3, 1, 1, 10);
@@ -1600,7 +1605,7 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
     run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10);
 }

-TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP16) {
+TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP32) {
     require_compute_capability(7, 5);
     Benchmarker<ConvBiasForward> bencher(handle_cuda());
     bencher.set_display(false);
@@ -1623,10 +1628,10 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP16) {
         param.stride_w = sw;
         bencher.set_param(param)
-                .set_dtype(0, dtype::Float16())
-                .set_dtype(1, dtype::Float16())
-                .set_dtype(2, dtype::Float16())
-                .set_dtype(4, dtype::Float16());
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .set_dtype(2, dtype::Float32())
+                .set_dtype(4, dtype::Float32());
         bencher.set_times(nr_times);
         size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
         size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
dnn/test/cuda/convolution.cpp

@@ -728,48 +728,58 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DEPTHWISE_LARGE_FILTER) {
     Checker<ConvolutionBackwardData> checker(handle_cuda());
     checker.set_before_exec_callback(
             AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER"));
-    for (auto dtype : std::vector<DType>{dtype::Float16()}) {
-        auto run = [&checker, &dtype](size_t n, size_t g, size_t h, size_t fh) {
+    for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
+        auto run = [&checker, &dtype](
+                           size_t n, size_t g, size_t h, size_t fh, size_t padding,
+                           size_t stride) {
             param::Convolution param;
-            param.stride_h = param.stride_w = 1;
-            param.pad_h = param.pad_w = fh / 2;
+            param.stride_h = param.stride_w = stride;
+            param.pad_h = param.pad_w = padding;
             param.mode = Convolution::Mode::CROSS_CORRELATION;
             param.sparse = param::Convolution::Sparse::GROUP;
             checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
             float scale = 64.f / sqrt(fh * fh);
-            UniformFloatRNG rng(scale, 2 * scale);
+            UniformFloatRNG rng(1.0, 1.0);
             checker.set_rng(0, &rng).set_rng(1, &rng).set_rng(2, &rng);
             if (dtype.enumv() == DTypeEnum::Float16)
                 checker.set_epsilon(1e-1);
             checker.set_param(param).execs(
-                    {{g, 1, 1, fh, fh}, {n, g, h, h}, {n, g, h, h}});
+                    {{g, 1, 1, fh, fh},
+                     {n, g, (h + 2 * padding - fh + 1) / stride,
+                      (h + 2 * padding - fh + 1) / stride},
+                     {n, g, h, h}});
         };
-        run(4, 8, 32, 5);
-        run(4, 8, 32, 7);
-        run(4, 8, 32, 9);
-        run(4, 8, 32, 11);
-        run(4, 8, 32, 13);
-        run(4, 8, 32, 15);
-        run(4, 8, 32, 17);
-        run(4, 8, 32, 19);
-        run(4, 8, 32, 21);
-        run(4, 8, 32, 23);
-        run(4, 8, 32, 25);
-        run(4, 8, 32, 27);
-        run(4, 8, 32, 29);
-        run(4, 8, 32, 31);
-        run(4, 8, 64, 7);
-        run(4, 8, 64, 5);
-        run(4, 8, 64, 9);
-        run(4, 8, 64, 11);
-        run(4, 8, 64, 13);
-        run(4, 8, 64, 15);
-        run(4, 8, 64, 17);
-        run(4, 8, 64, 19);
-        run(4, 8, 64, 21);
-        run(4, 8, 64, 23);
-        run(4, 8, 64, 25);
-        run(4, 8, 64, 27);
-        run(4, 8, 64, 29);
-        run(4, 8, 64, 31);
-        run(1, 2, 128, 31);
-        run(1, 2, 256, 31);
+        run(4, 8, 32, 5, 5 / 2, 1);
+        run(4, 8, 32, 7, 7 / 2, 1);
+        run(4, 8, 32, 9, 9 / 2, 1);
+        run(4, 8, 32, 11, 11 / 2, 1);
+        run(4, 8, 32, 13, 13 / 2, 1);
+        run(4, 8, 32, 15, 15 / 2, 1);
+        run(4, 8, 32, 17, 17 / 2, 1);
+        run(4, 8, 32, 19, 19 / 2, 1);
+        run(4, 8, 32, 21, 21 / 2, 1);
+        run(4, 8, 32, 23, 23 / 2, 1);
+        run(4, 8, 32, 25, 25 / 2, 1);
+        run(4, 8, 32, 27, 27 / 2, 1);
+        run(4, 8, 32, 29, 29 / 2, 1);
+        run(4, 8, 32, 31, 31 / 2, 1);
+        run(4, 8, 64, 5, 5 / 2, 2);
+        run(4, 8, 64, 7, 7 / 3, 2);
+        run(4, 8, 64, 9, 9 / 3, 2);
+        run(4, 8, 64, 11, 11 / 3, 2);
+        run(4, 8, 64, 13, 13 / 3, 2);
+        run(4, 8, 64, 15, 15 / 3, 2);
+        run(4, 8, 64, 17, 17 / 3, 2);
+        run(4, 8, 64, 19, 19 / 3, 2);
+        run(4, 8, 64, 21, 21 / 3, 2);
+        run(4, 8, 64, 23, 23 / 3, 2);
+        run(4, 8, 64, 25, 25 / 3, 2);
+        run(4, 8, 64, 27, 27 / 3, 2);
+        run(4, 8, 64, 29, 29 / 3, 2);
+        run(4, 8, 64, 31, 31 / 3, 2);
+        run(1, 2, 128, 31, 31 / 3, 2);
+        run(1, 2, 256, 31, 31 / 3, 2);
     }
 }
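The second tensor (the diff, i.e. the forward output) gets its shape from the forward output-size formula o = (h + 2p - fh) / s + 1, which for the paddings and strides exercised here equals the (h + 2 * padding - fh + 1) / stride expression written in the test: for h = 64, fh = 7, padding = 7 / 3 = 2, stride = 2, both give (64 + 4 - 7 + 1) / 2 = 31. A small compile-time check of that agreement (a sketch, valid for these specific parameter combinations):

    // Forward conv output size; the test's (h + 2*p - fh + 1) / s form agrees
    // with it for the stride/padding pairs used in this test.
    constexpr int conv_out(int h, int fh, int p, int s) {
        return (h + 2 * p - fh) / s + 1;
    }

    int main() {
        static_assert(conv_out(64, 7, 7 / 3, 2) == (64 + 2 * 2 - 7 + 1) / 2, "");
        static_assert(conv_out(64, 7, 7 / 3, 2) == 31, "");
        static_assert(conv_out(32, 5, 5 / 2, 1) == 32, "");  // 'same' conv, stride 1
    }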
@@ -950,7 +960,7 @@ TEST_F(CUDA, CONVOLUTION_BWD_DATA_BENCHMARK) {
     run(32, 64, 64, 56, 56, 1, 1, 0);
 }

-TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_DEPTHWISE_LARGE_FILTER) {
+TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_DEPTHWISE_LARGE_FILTER_FP32) {
     CUBenchmarker<ConvolutionBackwardData> bencher{handle_cuda()};
     bencher.set_display(false);
     bencher.set_before_exec_callback(