Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
e698ec20
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
e698ec20
编写于
2月 15, 2022
作者:
M
Megvii Engine Team
提交者:
王彪
2月 27, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(cuda): float16 depthwise large kernel conv compute fp32
GitOrigin-RevId: 3050d48f2691faeeda4fb054134041cc620b5a35
上级
48406382
变更
8
展开全部
显示空白变更内容
内联
并排
Showing
8 changed file
with
264 addition
and
191 deletion
+264
-191
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl
...c/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl
+137
-85
dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu
dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu
+0
-1
dnn/src/cuda/conv_bias/depthwise_large_filter.cpp
dnn/src/cuda/conv_bias/depthwise_large_filter.cpp
+4
-5
dnn/src/cuda/convolution/backward_data/algo.h
dnn/src/cuda/convolution/backward_data/algo.h
+1
-1
dnn/src/cuda/convolution/backward_data/depthwise_large_filter.cpp
...cuda/convolution/backward_data/depthwise_large_filter.cpp
+7
-5
dnn/src/cuda/fp16_help.cuh
dnn/src/cuda/fp16_help.cuh
+6
-0
dnn/test/cuda/conv_bias.cpp
dnn/test/cuda/conv_bias.cpp
+63
-58
dnn/test/cuda/convolution.cpp
dnn/test/cuda/convolution.cpp
+46
-36
未找到文件。
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl
浏览文件 @
e698ec20
此差异已折叠。
点击以展开。
dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu
浏览文件 @
e698ec20
...
...
@@ -11,7 +11,6 @@
#include "cuda.h"
#include "cuda_fp16.h"
// #include "src/cuda/conv_bias/chanwise/fwd_depthwise_large_filter.cuh"
#include "src/cuda/conv_bias/chanwise/kern.cuh"
#include "src/cuda/conv_bias/chanwise/kern_helper.cuh"
#include "src/cuda/conv_bias/chanwise/launch_config.cuh"
...
...
dnn/src/cuda/conv_bias/depthwise_large_filter.cpp
浏览文件 @
e698ec20
...
...
@@ -32,15 +32,14 @@ inline bool is_available_depthwise_large_filter(const chanwise::Param& param) {
:
1
+
(
ow
+
3
)
/
4
+
flt_smem_w
/
4
-
1
;
int
out_reg_per_thread
=
(
ow
+
3
)
/
4
*
4
;
if
(
device_prop
.
regsPerBlock
<
4
*
32
*
(
flt_reg_per_thread
+
src_reg_per_thread
+
out_reg_per_thread
)
||
(
flt_reg_per_thread
*
2
+
src_reg_per_thread
+
out_reg_per_thread
)
||
device_prop
.
sharedMemPerBlock
<
static_cast
<
size_t
>
(
flt_smem_w
*
flt_smem_h
+
src_smem_w
*
src_smem_h
))
{
flt_smem_w
*
flt_smem_h
*
2
+
src_smem_w
*
src_smem_h
))
{
return
false
;
}
return
param
.
stride_h
==
1
&&
param
.
stride_w
==
1
&&
param
.
src_h
==
param
.
out_h
&&
param
.
src_w
==
param
.
out_w
;
return
true
;
}
}
// anonymous namespace
...
...
dnn/src/cuda/convolution/backward_data/algo.h
浏览文件 @
e698ec20
...
...
@@ -68,7 +68,7 @@ public:
const
TensorLayout
&
grad
);
convolution
::
ForwardSizeArgs
as_fwd_args
()
const
{
return
{
handle
,
grad_layout
,
filter_layout
,
filter_meta
,
diff
_layout
};
return
{
handle
,
diff_layout
,
filter_layout
,
filter_meta
,
grad
_layout
};
}
};
struct
ExecArgs
:
public
SizeArgs
{
...
...
dnn/src/cuda/convolution/backward_data/depthwise_large_filter.cpp
浏览文件 @
e698ec20
...
...
@@ -31,15 +31,17 @@ inline bool is_available_depthwise_large_filter(const chanwise::Param& param) {
:
1
+
(
ow
+
3
)
/
4
+
flt_smem_w
/
4
-
1
;
int
out_reg_per_thread
=
(
ow
+
3
)
/
4
*
4
;
if
(
device_prop
.
regsPerBlock
<
4
*
32
*
(
flt_reg_per_thread
+
src_reg_per_thread
+
out_reg_per_thread
)
||
(
flt_reg_per_thread
*
2
+
src_reg_per_thread
+
out_reg_per_thread
)
||
device_prop
.
sharedMemPerBlock
<
static_cast
<
size_t
>
(
flt_smem_w
*
flt_smem_h
+
src_smem_w
*
src_smem_h
))
{
flt_smem_w
*
flt_smem_h
*
2
+
src_smem_w
*
src_smem_h
))
{
return
false
;
}
return
param
.
stride_h
==
1
&&
param
.
stride_w
==
1
&&
param
.
src_h
==
param
.
out_h
&&
param
.
src_w
==
param
.
out_w
;
printf
(
"param.src_w = %d, param.src_h = %d, param.out_w = %d, param.out_h = %d
\n
"
,
param
.
src_w
,
param
.
src_h
,
param
.
out_w
,
param
.
out_h
);
return
(
param
.
stride_h
==
1
&&
param
.
stride_w
==
1
)
||
(
param
.
stride_h
==
2
&&
param
.
stride_w
==
2
);
}
}
// anonymous namespace
...
...
dnn/src/cuda/fp16_help.cuh
浏览文件 @
e698ec20
...
...
@@ -45,6 +45,12 @@ fma2(const __half2 a, const __half2 b, const __half2 c) {
#endif
}
__device__
__forceinline__
float2
fma2
(
const
__half2
a
,
const
__half2
b
,
const
float2
c
)
{
return
{
__half2float
(
a
.
x
)
*
__half2float
(
b
.
x
)
+
c
.
x
,
__half2float
(
a
.
y
)
*
__half2float
(
b
.
y
)
+
c
.
y
};
}
#endif // CUDA_VERSION >= 9000
}
// namespace cuda
...
...
dnn/test/cuda/conv_bias.cpp
浏览文件 @
e698ec20
...
...
@@ -701,8 +701,10 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
ConvBiasForward
::
algo_name
<
ConvBias
::
DirectParam
>
(
"DEPTHWISE_LARGE_FILTER"
,
{})
.
c_str
()));
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float16
()})
{
auto
run
=
[
&
checker
,
&
dtype
](
size_t
n
,
size_t
g
,
size_t
h
,
size_t
fh
)
{
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float32
(),
dtype
::
Float16
()})
{
auto
run
=
[
&
checker
,
&
dtype
](
size_t
n
,
size_t
g
,
size_t
h
,
size_t
fh
,
size_t
padding
,
size_t
stride
)
{
param
::
ConvBias
cur_param
;
cur_param
.
mode
=
param
::
ConvBias
::
Mode
::
CROSS_CORRELATION
;
cur_param
.
sparse
=
ConvBias
::
Param
::
Sparse
::
GROUP
;
...
...
@@ -711,42 +713,52 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
.
set_dtype
(
2
,
dtype
)
.
set_dtype
(
3
,
dtype
)
.
set_dtype
(
4
,
dtype
);
float
scale
=
64.
f
/
sqrt
(
fh
*
fh
);
UniformFloatRNG
rng
(
scale
,
2
*
scale
);
checker
.
set_rng
(
0
,
&
rng
)
.
set_rng
(
1
,
&
rng
)
.
set_rng
(
2
,
&
rng
)
.
set_rng
(
3
,
&
rng
)
.
set_rng
(
4
,
&
rng
);
if
(
dtype
.
enumv
()
==
DTypeEnum
::
Float16
)
{
checker
.
set_epsilon
(
1e-1
);
}
cur_param
.
pad_h
=
cur_param
.
pad_w
=
fh
/
2
;
cur_param
.
stride_h
=
cur_param
.
stride_w
=
1
;
cur_param
.
pad_h
=
cur_param
.
pad_w
=
padding
;
cur_param
.
stride_h
=
cur_param
.
stride_w
=
stride
;
checker
.
set_param
(
cur_param
).
execs
(
{{
n
,
g
,
h
,
h
},
{
g
,
1
,
1
,
fh
,
fh
},
{},
{},
{}});
};
run
(
4
,
8
,
32
,
5
);
run
(
4
,
8
,
32
,
7
);
run
(
4
,
8
,
32
,
9
);
run
(
4
,
8
,
32
,
11
);
run
(
4
,
8
,
32
,
13
);
run
(
4
,
8
,
32
,
15
);
run
(
4
,
8
,
32
,
17
);
run
(
4
,
8
,
32
,
19
);
run
(
4
,
8
,
32
,
21
);
run
(
4
,
8
,
32
,
23
);
run
(
4
,
8
,
32
,
25
);
run
(
4
,
8
,
32
,
27
);
run
(
4
,
8
,
32
,
29
);
run
(
4
,
8
,
32
,
31
);
run
(
4
,
8
,
64
,
5
);
run
(
4
,
8
,
64
,
7
);
run
(
4
,
8
,
64
,
9
);
run
(
4
,
8
,
64
,
11
);
run
(
4
,
8
,
64
,
13
);
run
(
4
,
8
,
64
,
15
);
run
(
4
,
8
,
64
,
17
);
run
(
4
,
8
,
64
,
19
);
run
(
4
,
8
,
64
,
21
);
run
(
4
,
8
,
64
,
23
);
run
(
4
,
8
,
64
,
25
);
run
(
4
,
8
,
64
,
27
);
run
(
4
,
8
,
64
,
29
);
run
(
4
,
8
,
64
,
31
);
run
(
1
,
2
,
128
,
31
);
run
(
1
,
2
,
256
,
31
);
run
(
4
,
8
,
32
,
5
,
5
/
2
,
1
);
run
(
4
,
8
,
32
,
7
,
7
/
2
,
1
);
run
(
4
,
8
,
32
,
9
,
9
/
2
,
1
);
run
(
4
,
8
,
32
,
11
,
11
/
2
,
1
);
run
(
4
,
8
,
32
,
13
,
13
/
2
,
1
);
run
(
4
,
8
,
32
,
15
,
15
/
2
,
1
);
run
(
4
,
8
,
32
,
17
,
17
/
2
,
1
);
run
(
4
,
8
,
32
,
19
,
19
/
2
,
1
);
run
(
4
,
8
,
32
,
21
,
21
/
2
,
1
);
run
(
4
,
8
,
32
,
23
,
23
/
2
,
1
);
run
(
4
,
8
,
32
,
25
,
25
/
2
,
1
);
run
(
4
,
8
,
32
,
27
,
27
/
2
,
1
);
run
(
4
,
8
,
32
,
29
,
29
/
2
,
1
);
run
(
4
,
8
,
32
,
31
,
31
/
2
,
1
);
run
(
4
,
8
,
64
,
5
,
5
/
3
,
2
);
run
(
4
,
8
,
64
,
7
,
7
/
3
,
2
);
run
(
4
,
8
,
64
,
9
,
9
/
3
,
2
);
run
(
4
,
8
,
64
,
11
,
11
/
3
,
2
);
run
(
4
,
8
,
64
,
13
,
13
/
3
,
2
);
run
(
4
,
8
,
64
,
15
,
15
/
3
,
2
);
run
(
4
,
8
,
64
,
17
,
17
/
3
,
2
);
run
(
4
,
8
,
64
,
19
,
19
/
3
,
2
);
run
(
4
,
8
,
64
,
21
,
21
/
3
,
2
);
run
(
4
,
8
,
64
,
23
,
23
/
3
,
2
);
run
(
4
,
8
,
64
,
25
,
25
/
3
,
2
);
run
(
4
,
8
,
64
,
27
,
27
/
3
,
2
);
run
(
4
,
8
,
64
,
29
,
29
/
3
,
2
);
run
(
4
,
8
,
64
,
31
,
31
/
3
,
2
);
run
(
1
,
2
,
128
,
31
,
10
,
2
);
run
(
1
,
2
,
256
,
31
,
10
,
2
);
}
}
...
...
@@ -1530,7 +1542,7 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_TENSORCORE_INT8) {
run_bench
(
256
,
512
,
7
,
7
,
2048
,
1
,
1
,
1
,
1
,
1000
);
}
TEST_F
(
CUDA
,
BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER
)
{
TEST_F
(
CUDA
,
BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER
_FP16
)
{
require_compute_capability
(
7
,
5
);
Benchmarker
<
ConvBiasForward
>
bencher
(
handle_cuda
());
bencher
.
set_display
(
false
);
...
...
@@ -1552,6 +1564,11 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
param
.
stride_h
=
sh
;
param
.
stride_w
=
sw
;
bencher
.
set_param
(
param
)
.
set_dtype
(
0
,
dtype
::
Float16
())
.
set_dtype
(
1
,
dtype
::
Float16
())
.
set_dtype
(
2
,
dtype
::
Float16
())
.
set_dtype
(
4
,
dtype
::
Float16
());
bencher
.
set_times
(
nr_times
);
size_t
ho
=
infer_conv_shape
(
hi
,
fh
,
sh
,
param
.
pad_h
);
size_t
wo
=
infer_conv_shape
(
wi
,
fw
,
sw
,
param
.
pad_w
);
...
...
@@ -1562,25 +1579,13 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
out
.
total_nr_elems
())
/
(
1024
*
1024
*
1024
)
*
1e3
;
bencher
.
set_param
(
param
)
.
set_dtype
(
0
,
dtype
::
Float32
())
.
set_dtype
(
1
,
dtype
::
Float32
())
.
set_dtype
(
2
,
dtype
::
Float32
())
.
set_dtype
(
4
,
dtype
::
Float32
());
auto
fp32_time_in_ms
=
bencher
.
execs
({
inp
,
kern
,
{},
{},
out
})
/
nr_times
;
bencher
.
set_param
(
param
)
.
set_dtype
(
0
,
dtype
::
Float16
())
.
set_dtype
(
1
,
dtype
::
Float16
())
.
set_dtype
(
2
,
dtype
::
Float16
())
.
set_dtype
(
4
,
dtype
::
Float16
());
auto
fp16_time_in_ms
=
bencher
.
execs
({
inp
,
kern
,
{},
{},
out
})
/
nr_times
;
printf
(
"chanwise_depthwise_large_filter: inp=%s, kern=%s, out=%s, fp32_time: "
"%.2fms, fp16_time: %.2fms, speedup: %0.2f (fp16/fp32) "
"fp32_bandwidth: %.2fGB/s fp16_bandwidth: %.2fGB/s.
\n
"
,
auto
time_in_ms
=
bencher
.
execs
({
inp
,
kern
,
{},
{},
out
})
/
nr_times
;
auto
ops
=
2.0
*
batch
*
g
*
ho
*
wo
*
fh
*
fw
/
(
time_in_ms
*
1e-3
)
*
1e-12
;
printf
(
"chanwise_depthwise_large_filter: inp=%s, kern=%s, out=%s, time: "
"%.2fms, "
"perf: %.2f Tops bandwidth: %.2fGB/s.
\n
"
,
inp
.
to_string
().
c_str
(),
kern
.
to_string
().
c_str
(),
out
.
to_string
().
c_str
(),
fp32_time_in_ms
,
fp16_time_in_ms
,
fp32_time_in_ms
/
fp16_time_in_ms
,
bandwith
*
4
/
fp32_time_in_ms
,
bandwith
*
2
/
fp16_time_in_ms
);
out
.
to_string
().
c_str
(),
time_in_ms
,
ops
,
bandwith
*
4
/
time_in_ms
);
};
run_bench
(
64
,
384
,
32
,
32
,
3
,
3
,
1
,
1
,
10
);
...
...
@@ -1600,7 +1605,7 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
run_bench
(
64
,
384
,
32
,
32
,
31
,
31
,
1
,
1
,
10
);
}
TEST_F
(
CUDA
,
BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP
16
)
{
TEST_F
(
CUDA
,
BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP
32
)
{
require_compute_capability
(
7
,
5
);
Benchmarker
<
ConvBiasForward
>
bencher
(
handle_cuda
());
bencher
.
set_display
(
false
);
...
...
@@ -1623,10 +1628,10 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP16) {
param
.
stride_w
=
sw
;
bencher
.
set_param
(
param
)
.
set_dtype
(
0
,
dtype
::
Float
16
())
.
set_dtype
(
1
,
dtype
::
Float
16
())
.
set_dtype
(
2
,
dtype
::
Float
16
())
.
set_dtype
(
4
,
dtype
::
Float
16
());
.
set_dtype
(
0
,
dtype
::
Float
32
())
.
set_dtype
(
1
,
dtype
::
Float
32
())
.
set_dtype
(
2
,
dtype
::
Float
32
())
.
set_dtype
(
4
,
dtype
::
Float
32
());
bencher
.
set_times
(
nr_times
);
size_t
ho
=
infer_conv_shape
(
hi
,
fh
,
sh
,
param
.
pad_h
);
size_t
wo
=
infer_conv_shape
(
wi
,
fw
,
sw
,
param
.
pad_w
);
...
...
dnn/test/cuda/convolution.cpp
浏览文件 @
e698ec20
...
...
@@ -728,48 +728,58 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DEPTHWISE_LARGE_FILTER) {
Checker
<
ConvolutionBackwardData
>
checker
(
handle_cuda
());
checker
.
set_before_exec_callback
(
AlgoChecker
<
ConvolutionBackwardData
>
(
"DEPTHWISE_LARGE_FILTER"
));
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float16
()})
{
auto
run
=
[
&
checker
,
&
dtype
](
size_t
n
,
size_t
g
,
size_t
h
,
size_t
fh
)
{
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float32
(),
dtype
::
Float16
()})
{
auto
run
=
[
&
checker
,
&
dtype
](
size_t
n
,
size_t
g
,
size_t
h
,
size_t
fh
,
size_t
padding
,
size_t
stride
)
{
param
::
Convolution
param
;
param
.
stride_h
=
param
.
stride_w
=
1
;
param
.
pad_h
=
param
.
pad_w
=
fh
/
2
;
param
.
stride_h
=
param
.
stride_w
=
stride
;
param
.
pad_h
=
param
.
pad_w
=
padding
;
param
.
mode
=
Convolution
::
Mode
::
CROSS_CORRELATION
;
param
.
sparse
=
param
::
Convolution
::
Sparse
::
GROUP
;
checker
.
set_dtype
(
0
,
dtype
).
set_dtype
(
1
,
dtype
).
set_dtype
(
2
,
dtype
);
float
scale
=
64.
f
/
sqrt
(
fh
*
fh
);
UniformFloatRNG
rng
(
1.0
,
1.0
);
checker
.
set_rng
(
0
,
&
rng
).
set_rng
(
1
,
&
rng
).
set_rng
(
2
,
&
rng
);
if
(
dtype
.
enumv
()
==
DTypeEnum
::
Float16
)
checker
.
set_epsilon
(
1e-1
);
checker
.
set_param
(
param
).
execs
(
{{
g
,
1
,
1
,
fh
,
fh
},
{
n
,
g
,
h
,
h
},
{
n
,
g
,
h
,
h
}});
{{
g
,
1
,
1
,
fh
,
fh
},
{
n
,
g
,
(
h
+
2
*
padding
-
fh
+
1
)
/
stride
,
(
h
+
2
*
padding
-
fh
+
1
)
/
stride
},
{
n
,
g
,
h
,
h
}});
};
run
(
4
,
8
,
32
,
5
);
run
(
4
,
8
,
32
,
7
);
run
(
4
,
8
,
32
,
9
);
run
(
4
,
8
,
32
,
11
);
run
(
4
,
8
,
32
,
13
);
run
(
4
,
8
,
32
,
15
);
run
(
4
,
8
,
32
,
17
);
run
(
4
,
8
,
32
,
19
);
run
(
4
,
8
,
32
,
21
);
run
(
4
,
8
,
32
,
23
);
run
(
4
,
8
,
32
,
25
);
run
(
4
,
8
,
32
,
27
);
run
(
4
,
8
,
32
,
29
);
run
(
4
,
8
,
32
,
31
);
run
(
4
,
8
,
64
,
7
);
run
(
4
,
8
,
64
,
5
);
run
(
4
,
8
,
64
,
9
);
run
(
4
,
8
,
64
,
11
);
run
(
4
,
8
,
64
,
13
);
run
(
4
,
8
,
64
,
15
);
run
(
4
,
8
,
64
,
17
);
run
(
4
,
8
,
64
,
19
);
run
(
4
,
8
,
64
,
21
);
run
(
4
,
8
,
64
,
23
);
run
(
4
,
8
,
64
,
25
);
run
(
4
,
8
,
64
,
27
);
run
(
4
,
8
,
64
,
29
);
run
(
4
,
8
,
64
,
31
);
run
(
1
,
2
,
128
,
31
);
run
(
1
,
2
,
256
,
31
);
run
(
4
,
8
,
32
,
5
,
5
/
2
,
1
);
run
(
4
,
8
,
32
,
7
,
7
/
2
,
1
);
run
(
4
,
8
,
32
,
9
,
9
/
2
,
1
);
run
(
4
,
8
,
32
,
11
,
11
/
2
,
1
);
run
(
4
,
8
,
32
,
13
,
13
/
2
,
1
);
run
(
4
,
8
,
32
,
15
,
15
/
2
,
1
);
run
(
4
,
8
,
32
,
17
,
17
/
2
,
1
);
run
(
4
,
8
,
32
,
19
,
19
/
2
,
1
);
run
(
4
,
8
,
32
,
21
,
21
/
2
,
1
);
run
(
4
,
8
,
32
,
23
,
23
/
2
,
1
);
run
(
4
,
8
,
32
,
25
,
25
/
2
,
1
);
run
(
4
,
8
,
32
,
27
,
27
/
2
,
1
);
run
(
4
,
8
,
32
,
29
,
29
/
2
,
1
);
run
(
4
,
8
,
32
,
31
,
31
/
2
,
1
);
run
(
4
,
8
,
64
,
5
,
5
/
2
,
2
);
run
(
4
,
8
,
64
,
7
,
7
/
3
,
2
);
run
(
4
,
8
,
64
,
9
,
9
/
3
,
2
);
run
(
4
,
8
,
64
,
11
,
11
/
3
,
2
);
run
(
4
,
8
,
64
,
13
,
13
/
3
,
2
);
run
(
4
,
8
,
64
,
15
,
15
/
3
,
2
);
run
(
4
,
8
,
64
,
17
,
17
/
3
,
2
);
run
(
4
,
8
,
64
,
19
,
19
/
3
,
2
);
run
(
4
,
8
,
64
,
21
,
21
/
3
,
2
);
run
(
4
,
8
,
64
,
23
,
23
/
3
,
2
);
run
(
4
,
8
,
64
,
25
,
25
/
3
,
2
);
run
(
4
,
8
,
64
,
27
,
27
/
3
,
2
);
run
(
4
,
8
,
64
,
29
,
29
/
3
,
2
);
run
(
4
,
8
,
64
,
31
,
31
/
3
,
2
);
run
(
1
,
2
,
128
,
31
,
31
/
3
,
2
);
run
(
1
,
2
,
256
,
31
,
31
/
3
,
2
);
}
}
...
...
@@ -950,7 +960,7 @@ TEST_F(CUDA, CONVOLUTION_BWD_DATA_BENCHMARK) {
run
(
32
,
64
,
64
,
56
,
56
,
1
,
1
,
0
);
}
TEST_F
(
CUDA
,
BENCHMARK_CONVOLUTION_BWD_DATA_DEPTHWISE_LARGE_FILTER
)
{
TEST_F
(
CUDA
,
BENCHMARK_CONVOLUTION_BWD_DATA_DEPTHWISE_LARGE_FILTER
_FP32
)
{
CUBenchmarker
<
ConvolutionBackwardData
>
bencher
{
handle_cuda
()};
bencher
.
set_display
(
false
);
bencher
.
set_before_exec_callback
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录