Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
8ba8c11d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
8ba8c11d
编写于
3月 26, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn): add nchw44 layout
GitOrigin-RevId: d92672b88a48a2de396532ccbc6bd7e467d5eab9
上级
a744b3cb
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
458 addition
and
95 deletion
+458
-95
dnn/scripts/opr_param_defs.py
dnn/scripts/opr_param_defs.py
+2
-1
dnn/src/common/conv_bias.cpp
dnn/src/common/conv_bias.cpp
+16
-10
dnn/src/common/convolution.cpp
dnn/src/common/convolution.cpp
+72
-11
dnn/src/common/pooling.cpp
dnn/src/common/pooling.cpp
+6
-2
dnn/src/fallback/conv_bias/opr_impl.cpp
dnn/src/fallback/conv_bias/opr_impl.cpp
+8
-6
dnn/src/fallback/conv_bias/winograd/winograd.h
dnn/src/fallback/conv_bias/winograd/winograd.h
+29
-8
dnn/src/fallback/convolution/opr_impl.cpp
dnn/src/fallback/convolution/opr_impl.cpp
+4
-2
dnn/src/naive/convolution/helper.h
dnn/src/naive/convolution/helper.h
+59
-2
dnn/src/naive/pooling/opr_impl.cpp
dnn/src/naive/pooling/opr_impl.cpp
+14
-1
dnn/test/naive/conv_bias.cpp
dnn/test/naive/conv_bias.cpp
+248
-52
未找到文件。
dnn/scripts/opr_param_defs.py
浏览文件 @
8ba8c11d
...
@@ -35,9 +35,10 @@ pdef('Axis').add_fields('int32', 'axis', 0)
...
@@ -35,9 +35,10 @@ pdef('Axis').add_fields('int32', 'axis', 0)
).
).
add_enum
(
Doc
(
'Format'
,
'convolution data/filter/output format; see '
add_enum
(
Doc
(
'Format'
,
'convolution data/filter/output format; see '
':class:`RelayoutFormat` for more details'
),
':class:`RelayoutFormat` for more details'
),
'NCHW'
,
'NHWC'
,
'NHWCD4'
,
'NCHW4'
,
'NCHW8'
,
'NCHW32'
,
'NCHW88'
,
'NCHW'
,
'NHWC'
,
'NHWCD4'
,
'NCHW4'
,
'NCHW8'
,
'NCHW32'
,
'NCHW88'
,
'NCHW44'
,
Doc
(
'NCHW_WINOGRAD'
,
'NCHW layout with weights tranformed by winograd'
),
Doc
(
'NCHW_WINOGRAD'
,
'NCHW layout with weights tranformed by winograd'
),
Doc
(
'NCHW88_WINOGRAD'
,
'NCHW88 layout with weights tranformed by winograd'
),
Doc
(
'NCHW88_WINOGRAD'
,
'NCHW88 layout with weights tranformed by winograd'
),
Doc
(
'NCHW44_WINOGRAD'
,
'NCHW44 layout with weights tranformed by winograd'
),
Doc
(
'CHWN4'
,
'CHWN4 is currently only used on Nvidia platform for fast implementation '
Doc
(
'CHWN4'
,
'CHWN4 is currently only used on Nvidia platform for fast implementation '
'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'
))
'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'
))
)
)
...
...
dnn/src/common/conv_bias.cpp
浏览文件 @
8ba8c11d
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#include "src/common/conv_bias.h"
#include "src/common/conv_bias.h"
...
@@ -33,7 +34,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
...
@@ -33,7 +34,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
const
TensorLayout
&
bias
,
const
TensorLayout
&
z
,
const
TensorLayout
&
bias
,
const
TensorLayout
&
z
,
const
TensorLayout
&
dst
,
size_t
workspace_in_bytes
)
{
const
TensorLayout
&
dst
,
size_t
workspace_in_bytes
)
{
if
((
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW_WINOGRAD
||
if
((
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW_WINOGRAD
||
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW88_WINOGRAD
)
&&
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW88_WINOGRAD
||
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW44_WINOGRAD
)
&&
src
.
dtype
.
category
()
==
DTypeCategory
::
QUANTIZED
)
{
src
.
dtype
.
category
()
==
DTypeCategory
::
QUANTIZED
)
{
megdnn_assert
(
filter
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS16
);
megdnn_assert
(
filter
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS16
);
megdnn_assert
(
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS8
||
megdnn_assert
(
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS8
||
...
@@ -45,7 +47,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
...
@@ -45,7 +47,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
float
scale_src
=
src
.
dtype
.
param
<
dtype
::
QuantizedS8
>
().
scale
;
float
scale_src
=
src
.
dtype
.
param
<
dtype
::
QuantizedS8
>
().
scale
;
float
scale_filter
=
0.
f
;
float
scale_filter
=
0.
f
;
if
(
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW_WINOGRAD
||
if
(
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW_WINOGRAD
||
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW88_WINOGRAD
)
{
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW88_WINOGRAD
||
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW44_WINOGRAD
)
{
scale_filter
=
filter
.
dtype
.
param
<
dtype
::
QuantizedS16
>
().
scale
;
scale_filter
=
filter
.
dtype
.
param
<
dtype
::
QuantizedS16
>
().
scale
;
}
else
{
}
else
{
scale_filter
=
filter
.
dtype
.
param
<
dtype
::
QuantizedS8
>
().
scale
;
scale_filter
=
filter
.
dtype
.
param
<
dtype
::
QuantizedS8
>
().
scale
;
...
@@ -58,7 +61,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
...
@@ -58,7 +61,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
float
scale_src
=
src
.
dtype
.
param
<
dtype
::
Quantized8Asymm
>
().
scale
;
float
scale_src
=
src
.
dtype
.
param
<
dtype
::
Quantized8Asymm
>
().
scale
;
float
scale_filter
=
0.
f
;
float
scale_filter
=
0.
f
;
if
(
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW_WINOGRAD
||
if
(
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW_WINOGRAD
||
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW88_WINOGRAD
)
{
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW88_WINOGRAD
||
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW44_WINOGRAD
)
{
scale_filter
=
filter
.
dtype
.
param
<
dtype
::
QuantizedS16
>
().
scale
;
scale_filter
=
filter
.
dtype
.
param
<
dtype
::
QuantizedS16
>
().
scale
;
}
else
{
}
else
{
scale_filter
=
filter
.
dtype
.
param
<
dtype
::
Quantized8Asymm
>
().
scale
;
scale_filter
=
filter
.
dtype
.
param
<
dtype
::
Quantized8Asymm
>
().
scale
;
...
@@ -98,7 +102,9 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
...
@@ -98,7 +102,9 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
megdnn_assert
(
bias
.
shape
[
2
]
==
1
);
megdnn_assert
(
bias
.
shape
[
2
]
==
1
);
megdnn_assert
(
bias
.
shape
[
3
]
==
dst
.
shape
[
3
],
"bias:%s, dst:%s"
,
megdnn_assert
(
bias
.
shape
[
3
]
==
dst
.
shape
[
3
],
"bias:%s, dst:%s"
,
bias
.
to_string
().
c_str
(),
dst
.
to_string
().
c_str
());
bias
.
to_string
().
c_str
(),
dst
.
to_string
().
c_str
());
}
else
if
(
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW4
)
{
}
else
if
(
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW4
||
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW44
||
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW44_WINOGRAD
)
{
megdnn_assert
(
bias
.
shape
[
0
]
==
1
);
megdnn_assert
(
bias
.
shape
[
0
]
==
1
);
megdnn_assert
(
bias
.
shape
[
1
]
==
dst
.
shape
[
1
],
"bias:%s, dst:%s"
,
megdnn_assert
(
bias
.
shape
[
1
]
==
dst
.
shape
[
1
],
"bias:%s, dst:%s"
,
bias
.
to_string
().
c_str
(),
dst
.
to_string
().
c_str
());
bias
.
to_string
().
c_str
(),
dst
.
to_string
().
c_str
());
...
@@ -141,7 +147,10 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
...
@@ -141,7 +147,10 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
if
(
z
.
ndim
!=
0
)
{
if
(
z
.
ndim
!=
0
)
{
megdnn_assert
(
param
().
format
!=
param
::
ConvBias
::
Format
::
NCHW_WINOGRAD
);
megdnn_assert
(
param
().
format
!=
param
::
ConvBias
::
Format
::
NCHW_WINOGRAD
);
megdnn_assert
(
param
().
format
!=
param
::
ConvBias
::
Format
::
NCHW88_WINOGRAD
);
megdnn_assert
(
param
().
format
!=
param
::
ConvBias
::
Format
::
NCHW88_WINOGRAD
);
megdnn_assert
(
param
().
format
!=
param
::
ConvBias
::
Format
::
NCHW44_WINOGRAD
);
megdnn_assert
(
z
.
dtype
.
enumv
()
==
dst
.
dtype
.
enumv
());
megdnn_assert
(
z
.
dtype
.
enumv
()
==
dst
.
dtype
.
enumv
());
megdnn_assert
(
z
.
eq_shape
(
dst
));
megdnn_assert
(
z
.
eq_shape
(
dst
));
}
}
...
@@ -163,10 +172,7 @@ std::string ConvBias::algo_name(const std::string& base, const T& p) {
...
@@ -163,10 +172,7 @@ std::string ConvBias::algo_name(const std::string& base, const T& p) {
}
}
#define FOREACH_CONV_BIAS_PARAM(cb) \
#define FOREACH_CONV_BIAS_PARAM(cb) \
cb(WinogradParam) \
cb(WinogradParam) cb(DirectParam) cb(MatmulParam) cb(DefaultParam)
cb(DirectParam) \
cb(MatmulParam) \
cb(DefaultParam)
#define cb(pt) \
#define cb(pt) \
template <> \
template <> \
...
...
dnn/src/common/convolution.cpp
浏览文件 @
8ba8c11d
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#include "megdnn/oprs/nn.h"
#include "megdnn/oprs/nn.h"
...
@@ -55,7 +56,13 @@ spatial_getter<param::ConvBias, param::ConvBias::Format::NCHW88_WINOGRAD>(
...
@@ -55,7 +56,13 @@ spatial_getter<param::ConvBias, param::ConvBias::Format::NCHW88_WINOGRAD>(
//! f = m + r - 1 -> r = f + 1 - m
//! f = m + r - 1 -> r = f + 1 - m
return
filter
-
param
.
output_block_size
+
1
;
return
filter
-
param
.
output_block_size
+
1
;
}
}
template
<
>
uint32_t
spatial_getter
<
param
::
ConvBias
,
param
::
ConvBias
::
Format
::
NCHW44_WINOGRAD
>
(
uint32_t
filter
,
const
param
::
ConvBias
&
param
)
{
//! f = m + r - 1 -> r = f + 1 - m
return
filter
-
param
.
output_block_size
+
1
;
}
template
<
typename
Parameter
,
typename
Param
>
template
<
typename
Parameter
,
typename
Param
>
void
make_canonized_filter_meta_nchw_nhwc
(
void
make_canonized_filter_meta_nchw_nhwc
(
...
@@ -273,7 +280,7 @@ void make_canonized_filter_meta_nchwxx(
...
@@ -273,7 +280,7 @@ void make_canonized_filter_meta_nchwxx(
/**
/**
* input: N IC/pack_size, H, W, pack_size
* input: N IC/pack_size, H, W, pack_size
*
*
* NCHW88 mode
* NCHW88
and NCHW44
mode
* filter:
* filter:
* {OC/pack_size, IC/pack_size, FH, FW, pack_size(IC), pack_size(OC)}
* {OC/pack_size, IC/pack_size, FH, FW, pack_size(IC), pack_size(OC)}
* [dense]
* [dense]
...
@@ -281,7 +288,7 @@ void make_canonized_filter_meta_nchwxx(
...
@@ -281,7 +288,7 @@ void make_canonized_filter_meta_nchwxx(
* FH, FW, pack_size(IC), pack_size(OC)} [group]
* FH, FW, pack_size(IC), pack_size(OC)} [group]
* {GROUP/pack_size, 1, 1, FH, FW, pack_size} [chan]
* {GROUP/pack_size, 1, 1, FH, FW, pack_size} [chan]
*
*
** NCHW88_WINOGRAD mode
** NCHW88_WINOGRAD
and NCHW44_WINOGRAD
mode
* filter:
* filter:
* {alpha, alpha, OC/pack_size, IC/pack_size, pack_size(IC),
* {alpha, alpha, OC/pack_size, IC/pack_size, pack_size(IC),
*pack_size(OC)} [dense]
*pack_size(OC)} [dense]
...
@@ -291,6 +298,7 @@ void make_canonized_filter_meta_nchwxx(
...
@@ -291,6 +298,7 @@ void make_canonized_filter_meta_nchwxx(
*/
*/
megdnn_assert
(
param
.
format
==
Param
::
Format
::
NCHW88
||
megdnn_assert
(
param
.
format
==
Param
::
Format
::
NCHW88
||
param
.
format
==
Param
::
Format
::
NCHW44
||
param
.
format
==
Param
::
Format
::
NCHW88_WINOGRAD
);
param
.
format
==
Param
::
Format
::
NCHW88_WINOGRAD
);
size_t
img_ndim
=
2
;
size_t
img_ndim
=
2
;
size_t
flt_start
=
0
;
size_t
flt_start
=
0
;
...
@@ -305,7 +313,8 @@ void make_canonized_filter_meta_nchwxx(
...
@@ -305,7 +313,8 @@ void make_canonized_filter_meta_nchwxx(
filter
[
filter
.
ndim
-
1
]);
filter
[
filter
.
ndim
-
1
]);
ret
.
group
=
1
;
ret
.
group
=
1
;
flt_start
=
0
;
flt_start
=
0
;
if
(
param
.
format
==
Param
::
Format
::
NCHW88_WINOGRAD
)
{
if
(
param
.
format
==
Param
::
Format
::
NCHW88_WINOGRAD
||
param
.
format
==
Param
::
Format
::
NCHW44_WINOGRAD
)
{
flt_start
=
2
;
flt_start
=
2
;
}
}
ret
.
ocpg
=
filter
[
flt_start
]
*
pack_size
;
ret
.
ocpg
=
filter
[
flt_start
]
*
pack_size
;
...
@@ -314,6 +323,8 @@ void make_canonized_filter_meta_nchwxx(
...
@@ -314,6 +323,8 @@ void make_canonized_filter_meta_nchwxx(
// ohwi8o
// ohwi8o
megdnn_assert
(
param
.
format
!=
Param
::
Format
::
NCHW88_WINOGRAD
,
megdnn_assert
(
param
.
format
!=
Param
::
Format
::
NCHW88_WINOGRAD
,
"Hybrid nchw88 mode in not support winograd"
);
"Hybrid nchw88 mode in not support winograd"
);
megdnn_assert
(
param
.
format
!=
Param
::
Format
::
NCHW44_WINOGRAD
,
"Hybrid nchw44 mode in not support winograd"
);
flt_start
=
0
;
flt_start
=
0
;
flt_spatial_start
=
1
;
flt_spatial_start
=
1
;
ret
.
group
=
1
;
ret
.
group
=
1
;
...
@@ -321,20 +332,22 @@ void make_canonized_filter_meta_nchwxx(
...
@@ -321,20 +332,22 @@ void make_canonized_filter_meta_nchwxx(
ret
.
icpg
=
filter
[
flt_start
+
3
];
ret
.
icpg
=
filter
[
flt_start
+
3
];
}
else
{
}
else
{
megdnn_assert
(
0
,
"not support nchw
88
filter dim = %zu"
,
megdnn_assert
(
0
,
"not support nchw
xx
filter dim = %zu"
,
filter
.
ndim
);
filter
.
ndim
);
}
}
}
else
{
}
else
{
megdnn_assert
(
param
.
sparse
==
Param
::
Sparse
::
GROUP
,
megdnn_assert
(
param
.
sparse
==
Param
::
Sparse
::
GROUP
,
"invalid convolution sparse type"
);
"invalid convolution sparse type"
);
flt_start
=
1
;
flt_start
=
1
;
if
(
param
.
format
==
Param
::
Format
::
NCHW88_WINOGRAD
)
{
if
(
param
.
format
==
Param
::
Format
::
NCHW88_WINOGRAD
||
param
.
format
==
Param
::
Format
::
NCHW44_WINOGRAD
)
{
flt_start
=
3
;
flt_start
=
3
;
}
}
auto
filter_oc
=
filter
[
flt_start
];
auto
filter_oc
=
filter
[
flt_start
];
auto
filter_ic
=
filter
[
flt_start
+
1
];
auto
filter_ic
=
filter
[
flt_start
+
1
];
if
(
filter_oc
==
1
&&
filter_ic
==
1
&&
filter
.
ndim
==
(
img_ndim
+
4
)
&&
if
(
filter_oc
==
1
&&
filter_ic
==
1
&&
filter
.
ndim
==
(
img_ndim
+
4
)
&&
param
.
format
!=
Param
::
Format
::
NCHW88_WINOGRAD
)
{
param
.
format
!=
Param
::
Format
::
NCHW88_WINOGRAD
&&
param
.
format
!=
Param
::
Format
::
NCHW44_WINOGRAD
)
{
// Depthwise case goihw8g
// Depthwise case goihw8g
megdnn_assert
(
filter
.
ndim
==
img_ndim
+
4
,
megdnn_assert
(
filter
.
ndim
==
img_ndim
+
4
,
"bad filter ndim for group convolution: "
"bad filter ndim for group convolution: "
...
@@ -343,7 +356,7 @@ void make_canonized_filter_meta_nchwxx(
...
@@ -343,7 +356,7 @@ void make_canonized_filter_meta_nchwxx(
megdnn_assert
(
filter
[
filter
.
ndim
-
1
]
==
pack_size
,
megdnn_assert
(
filter
[
filter
.
ndim
-
1
]
==
pack_size
,
"last dim of filter must be %zu, but %zu"
,
pack_size
,
"last dim of filter must be %zu, but %zu"
,
pack_size
,
filter
[
filter
.
ndim
-
1
]);
filter
[
filter
.
ndim
-
1
]);
ret
.
group
=
filter
[
0
]
*
8
;
ret
.
group
=
filter
[
0
]
*
pack_size
;
ret
.
ocpg
=
filter_oc
;
ret
.
ocpg
=
filter_oc
;
ret
.
icpg
=
filter_ic
;
ret
.
icpg
=
filter_ic
;
...
@@ -381,6 +394,10 @@ void make_canonized_filter_meta_nchwxx(
...
@@ -381,6 +394,10 @@ void make_canonized_filter_meta_nchwxx(
ret
.
spatial
[
i
]
=
ret
.
spatial
[
i
]
=
spatial_getter
<
Param
,
Param
::
Format
::
NCHW88_WINOGRAD
>
(
spatial_getter
<
Param
,
Param
::
Format
::
NCHW88_WINOGRAD
>
(
filter
[
i
+
flt_start
-
2
],
param
);
filter
[
i
+
flt_start
-
2
],
param
);
}
else
if
(
param
.
format
==
Param
::
Format
::
NCHW44_WINOGRAD
)
{
ret
.
spatial
[
i
]
=
spatial_getter
<
Param
,
Param
::
Format
::
NCHW44_WINOGRAD
>
(
filter
[
i
+
flt_start
-
2
],
param
);
}
else
{
}
else
{
ret
.
spatial
[
i
]
=
filter
[
i
+
flt_start
+
flt_spatial_start
];
ret
.
spatial
[
i
]
=
filter
[
i
+
flt_start
+
flt_spatial_start
];
}
}
...
@@ -535,6 +552,10 @@ ConvolutionBase<Parameter>::make_canonized_filter_meta(
...
@@ -535,6 +552,10 @@ ConvolutionBase<Parameter>::make_canonized_filter_meta(
param
().
format
==
Param
::
Format
::
NCHW88_WINOGRAD
)
{
param
().
format
==
Param
::
Format
::
NCHW88_WINOGRAD
)
{
make_canonized_filter_meta_nchwxx
<
8
,
Parameter
>
(
src_ndim
,
filter
,
make_canonized_filter_meta_nchwxx
<
8
,
Parameter
>
(
src_ndim
,
filter
,
param
(),
ret
);
param
(),
ret
);
}
else
if
(
param
().
format
==
Param
::
Format
::
NCHW44
||
param
().
format
==
Param
::
Format
::
NCHW44_WINOGRAD
)
{
make_canonized_filter_meta_nchwxx
<
4
,
Parameter
>
(
src_ndim
,
filter
,
param
(),
ret
);
}
else
if
(
param
().
format
==
Param
::
Format
::
NCHW32
)
{
}
else
if
(
param
().
format
==
Param
::
Format
::
NCHW32
)
{
make_canonized_filter_meta_nchwx
<
32
,
Parameter
>
(
src_ndim
,
filter
,
make_canonized_filter_meta_nchwx
<
32
,
Parameter
>
(
src_ndim
,
filter
,
param
(),
ret
);
param
(),
ret
);
...
@@ -629,18 +650,22 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
...
@@ -629,18 +650,22 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
}
else
{
}
else
{
megdnn_assert
(
param
().
format
==
Param
::
Format
::
NHWCD4
||
megdnn_assert
(
param
().
format
==
Param
::
Format
::
NHWCD4
||
param
().
format
==
Param
::
Format
::
NCHW4
||
param
().
format
==
Param
::
Format
::
NCHW4
||
param
().
format
==
Param
::
Format
::
NCHW44
||
param
().
format
==
Param
::
Format
::
NCHW8
||
param
().
format
==
Param
::
Format
::
NCHW8
||
param
().
format
==
Param
::
Format
::
NCHW32
||
param
().
format
==
Param
::
Format
::
NCHW32
||
param
().
format
==
Param
::
Format
::
NCHW88
||
param
().
format
==
Param
::
Format
::
NCHW88
||
param
().
format
==
Param
::
Format
::
NCHW88_WINOGRAD
||
param
().
format
==
Param
::
Format
::
NCHW88_WINOGRAD
||
param
().
format
==
Param
::
Format
::
CHWN4
);
param
().
format
==
Param
::
Format
::
CHWN4
);
img_dim
=
src
.
ndim
-
3
;
img_dim
=
src
.
ndim
-
3
;
if
(
param
().
format
==
Param
::
Format
::
NCHW88
&&
filter
.
ndim
==
5
)
{
if
((
param
().
format
==
Param
::
Format
::
NCHW88
||
param
().
format
==
Param
::
Format
::
NCHW44
)
&&
filter
.
ndim
==
5
)
{
img_dim
=
src
.
ndim
-
2
;
img_dim
=
src
.
ndim
-
2
;
}
}
megdnn_assert
(
filter
.
ndim
==
img_dim
+
3
||
megdnn_assert
(
filter
.
ndim
==
img_dim
+
3
||
(
filter
.
ndim
==
img_dim
+
2
&&
(
filter
.
ndim
==
img_dim
+
2
&&
param
().
format
==
Param
::
Format
::
NCHW88
)
||
(
param
().
format
==
Param
::
Format
::
NCHW88
||
param
().
format
==
Param
::
Format
::
NCHW44
))
||
filter
.
ndim
==
img_dim
+
4
||
filter
.
ndim
==
img_dim
+
4
||
filter
.
ndim
==
img_dim
+
5
,
filter
.
ndim
==
img_dim
+
5
,
"%s"
,
errmsg
().
c_str
());
"%s"
,
errmsg
().
c_str
());
...
@@ -691,6 +716,21 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
...
@@ -691,6 +716,21 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
", and last shape two is 8 but got src %s, filter %s"
,
", and last shape two is 8 but got src %s, filter %s"
,
src
.
to_string
().
c_str
(),
filter
.
to_string
().
c_str
());
src
.
to_string
().
c_str
(),
filter
.
to_string
().
c_str
());
}
}
if
(
param
().
format
==
Param
::
Format
::
NCHW44
||
param
().
format
==
Param
::
Format
::
NCHW44_WINOGRAD
)
{
megdnn_assert
((
src
.
ndim
==
4
&&
filter
.
ndim
==
5
&&
filter
[
filter
.
ndim
-
1
]
==
4
)
||
(
src
.
ndim
==
5
&&
((
filter
.
ndim
==
6
&&
filter
[
filter
.
ndim
-
1
]
==
4
)
||
(
filter
.
ndim
==
7
&&
filter
[
filter
.
ndim
-
1
]
==
4
&&
filter
[
filter
.
ndim
-
2
]
==
4
))
&&
src
[
src
.
ndim
-
1
]
==
4
),
"NCHW44 require src ndim is 5 and filter's ndim is 6 "
", and last shape two is 4 but got src %s, filter %s"
,
src
.
to_string
().
c_str
(),
filter
.
to_string
().
c_str
());
}
if
(
param
().
format
==
Param
::
Format
::
CHWN4
)
{
if
(
param
().
format
==
Param
::
Format
::
CHWN4
)
{
megdnn_assert
(
megdnn_assert
(
src
.
ndim
==
5
&&
(
filter
.
ndim
==
5
||
filter
.
ndim
==
6
)
&&
src
.
ndim
==
5
&&
(
filter
.
ndim
==
5
||
filter
.
ndim
==
6
)
&&
...
@@ -808,6 +848,27 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
...
@@ -808,6 +848,27 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
cflt
.
group
);
cflt
.
group
);
}
}
}
else
if
(
param
().
format
==
Param
::
Format
::
NCHW44
||
param
().
format
==
Param
::
Format
::
NCHW44_WINOGRAD
)
{
megdnn_assert
(
src
.
ndim
==
5
||
(
src
.
ndim
==
4
&&
src
[
1
]
<=
8
),
"invalid src ndim for NCHW44, expected=5 or 4, got=%zu"
,
src
.
ndim
);
dst
.
ndim
=
5
;
dst
[
0
]
=
src
[
0
];
auto
oc
=
cflt
.
ocpg
*
cflt
.
group
;
megdnn_assert
(
oc
%
4
==
0
);
dst
[
1
]
=
oc
/
4
;
dst
[
2
]
=
infer_conv_shape
(
src
[
2
],
cflt
.
dilated_spatial
[
0
],
cflt
.
stride
[
0
],
cflt
.
padding
[
0
]);
dst
[
3
]
=
infer_conv_shape
(
src
[
3
],
cflt
.
dilated_spatial
[
1
],
cflt
.
stride
[
1
],
cflt
.
padding
[
1
]);
dst
[
4
]
=
4
;
if
(
cflt
.
group
==
1
)
{
megdnn_assert
(
cflt
.
icpg
*
cflt
.
group
==
src
[
1
]
*
4
||
(
cflt
.
icpg
*
cflt
.
group
==
src
[
1
]),
"%s icpg=%u group=%u"
,
errmsg
().
c_str
(),
cflt
.
icpg
,
cflt
.
group
);
}
}
else
if
(
param
().
format
==
Param
::
Format
::
CHWN4
)
{
}
else
if
(
param
().
format
==
Param
::
Format
::
CHWN4
)
{
megdnn_assert
(
src
.
ndim
==
5
,
megdnn_assert
(
src
.
ndim
==
5
,
"invalid src ndim for CHWN4, expected=5, got=%zu"
,
"invalid src ndim for CHWN4, expected=5, got=%zu"
,
...
...
dnn/src/common/pooling.cpp
浏览文件 @
8ba8c11d
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#include "megdnn/oprs.h"
#include "megdnn/oprs.h"
...
@@ -47,6 +48,7 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
...
@@ -47,6 +48,7 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
spatial_pos
=
1
;
spatial_pos
=
1
;
c_pos
=
3
;
c_pos
=
3
;
}
else
if
(
param
().
format
==
Param
::
Format
::
NCHW4
||
}
else
if
(
param
().
format
==
Param
::
Format
::
NCHW4
||
param
().
format
==
Param
::
Format
::
NCHW44
||
param
().
format
==
Param
::
Format
::
NCHW88
||
param
().
format
==
Param
::
Format
::
NCHW88
||
param
().
format
==
Param
::
Format
::
NCHW32
)
{
param
().
format
==
Param
::
Format
::
NCHW32
)
{
megdnn_assert
(
src
.
ndim
==
5
_z
,
"%s"
,
errmsg_c
);
megdnn_assert
(
src
.
ndim
==
5
_z
,
"%s"
,
errmsg_c
);
...
@@ -73,6 +75,7 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
...
@@ -73,6 +75,7 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
iw
=
src
[
spatial_pos
+
2
];
iw
=
src
[
spatial_pos
+
2
];
}
}
if
(
param
().
format
==
Param
::
Format
::
NCHW4
||
if
(
param
().
format
==
Param
::
Format
::
NCHW4
||
param
().
format
==
Param
::
Format
::
NCHW44
||
param
().
format
==
Param
::
Format
::
CHWN4
)
{
param
().
format
==
Param
::
Format
::
CHWN4
)
{
c
*=
4
;
c
*=
4
;
}
}
...
@@ -96,7 +99,8 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
...
@@ -96,7 +99,8 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
megdnn_assert
(
param
().
format
==
Param
::
Format
::
NHWC
,
megdnn_assert
(
param
().
format
==
Param
::
Format
::
NHWC
,
"invalid pooling format"
);
"invalid pooling format"
);
dst
=
TensorLayout
({
n
,
oh
,
ow
,
c
},
src
.
dtype
,
src
.
format
);
dst
=
TensorLayout
({
n
,
oh
,
ow
,
c
},
src
.
dtype
,
src
.
format
);
}
else
if
(
param
().
format
==
Param
::
Format
::
NCHW4
)
{
}
else
if
(
param
().
format
==
Param
::
Format
::
NCHW4
||
param
().
format
==
Param
::
Format
::
NCHW44
)
{
dst
=
TensorLayout
{{
n
,
c
/
4
,
oh
,
ow
,
4
},
src
.
dtype
,
src
.
format
};
dst
=
TensorLayout
{{
n
,
c
/
4
,
oh
,
ow
,
4
},
src
.
dtype
,
src
.
format
};
}
else
if
(
param
().
format
==
Param
::
Format
::
NCHW88
)
{
}
else
if
(
param
().
format
==
Param
::
Format
::
NCHW88
)
{
dst
=
TensorLayout
{{
n
,
c
/
8
,
oh
,
ow
,
8
},
src
.
dtype
,
src
.
format
};
dst
=
TensorLayout
{{
n
,
c
/
8
,
oh
,
ow
,
8
},
src
.
dtype
,
src
.
format
};
...
...
dnn/src/fallback/conv_bias/opr_impl.cpp
浏览文件 @
8ba8c11d
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#include "src/fallback/convolution/opr_impl.h"
#include "src/fallback/convolution/opr_impl.h"
#include "src/common/algo_chooser.h"
#include "src/common/algo_chooser.h"
...
@@ -157,9 +158,11 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param(
...
@@ -157,9 +158,11 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param(
if
(
param
().
format
==
Param
::
Format
::
NCHW88
||
if
(
param
().
format
==
Param
::
Format
::
NCHW88
||
param
().
format
==
Param
::
Format
::
NCHW8
||
param
().
format
==
Param
::
Format
::
NCHW8
||
param
().
format
==
Param
::
Format
::
NCHW4
||
param
().
format
==
Param
::
Format
::
NCHW4
||
param
().
format
==
Param
::
Format
::
NCHW44
||
param
().
format
==
Param
::
Format
::
NCHW
||
param
().
format
==
Param
::
Format
::
NCHW
||
param
().
format
==
Param
::
Format
::
NCHW_WINOGRAD
||
param
().
format
==
Param
::
Format
::
NCHW_WINOGRAD
||
param
().
format
==
Param
::
Format
::
NCHW88_WINOGRAD
)
{
param
().
format
==
Param
::
Format
::
NCHW88_WINOGRAD
||
param
().
format
==
Param
::
Format
::
NCHW44_WINOGRAD
)
{
spatial_pos
=
2
;
spatial_pos
=
2
;
}
else
if
(
param
().
format
==
Param
::
Format
::
NHWC
)
{
}
else
if
(
param
().
format
==
Param
::
Format
::
NHWC
)
{
spatial_pos
=
1
;
spatial_pos
=
1
;
...
@@ -188,7 +191,8 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param(
...
@@ -188,7 +191,8 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param(
param
::
MatrixMul
::
Format
format
=
param
::
MatrixMul
::
Format
::
DEFAULT
;
param
::
MatrixMul
::
Format
format
=
param
::
MatrixMul
::
Format
::
DEFAULT
;
if
(
param
().
format
==
Param
::
Format
::
NCHW_WINOGRAD
||
if
(
param
().
format
==
Param
::
Format
::
NCHW_WINOGRAD
||
param
().
format
==
Param
::
Format
::
NCHW88_WINOGRAD
)
{
param
().
format
==
Param
::
Format
::
NCHW88_WINOGRAD
||
param
().
format
==
Param
::
Format
::
NCHW44_WINOGRAD
)
{
size_t
flt_start
=
0
;
size_t
flt_start
=
0
;
if
(
param
().
sparse
==
Param
::
Sparse
::
GROUP
)
{
if
(
param
().
sparse
==
Param
::
Sparse
::
GROUP
)
{
flt_start
=
1
;
flt_start
=
1
;
...
@@ -325,7 +329,7 @@ const char* ConvBiasImpl::get_algorithm_set_name() const {
...
@@ -325,7 +329,7 @@ const char* ConvBiasImpl::get_algorithm_set_name() const {
return
"F0"
;
return
"F0"
;
}
}
namespace
megdnn
{
namespace
megdnn
{
namespace
fallback
{
namespace
fallback
{
template
<
typename
T
>
template
<
typename
T
>
...
@@ -342,7 +346,6 @@ const T* ConvBiasImpl::NCBKernParam::src(size_t batch_id, size_t group_pack_id,
...
@@ -342,7 +346,6 @@ const T* ConvBiasImpl::NCBKernParam::src(size_t batch_id, size_t group_pack_id,
batch_offset
+
group_offset
+
channel_offset
);
batch_offset
+
group_offset
+
channel_offset
);
}
}
template
<
typename
T
>
template
<
typename
T
>
const
T
*
ConvBiasImpl
::
NCBKernParam
::
filter
(
size_t
group_pack_id
,
const
T
*
ConvBiasImpl
::
NCBKernParam
::
filter
(
size_t
group_pack_id
,
size_t
pack_group_size
)
const
{
size_t
pack_group_size
)
const
{
...
@@ -453,5 +456,4 @@ INST(void)
...
@@ -453,5 +456,4 @@ INST(void)
}
// namespace fallback
}
// namespace fallback
}
// namespace megdnn
}
// namespace megdnn
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
dnn/src/fallback/conv_bias/winograd/winograd.h
浏览文件 @
8ba8c11d
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#pragma once
#pragma once
...
@@ -87,7 +88,9 @@ class ConvBias {
...
@@ -87,7 +88,9 @@ class ConvBias {
if
(
param
.
filter_meta
.
format
!=
if
(
param
.
filter_meta
.
format
!=
param
::
ConvBias
::
Format
::
NCHW_WINOGRAD
&&
param
::
ConvBias
::
Format
::
NCHW_WINOGRAD
&&
param
.
filter_meta
.
format
!=
param
.
filter_meta
.
format
!=
param
::
ConvBias
::
Format
::
NCHW88_WINOGRAD
)
{
param
::
ConvBias
::
Format
::
NCHW88_WINOGRAD
&&
param
.
filter_meta
.
format
!=
param
::
ConvBias
::
Format
::
NCHW44_WINOGRAD
)
{
filter_transform_buf_size
=
Strategy
::
ALPHA
*
Strategy
::
ALPHA
*
OC
*
filter_transform_buf_size
=
Strategy
::
ALPHA
*
Strategy
::
ALPHA
*
OC
*
IC
*
sizeof
(
input_filter_compute_type
);
IC
*
sizeof
(
input_filter_compute_type
);
}
}
...
@@ -95,7 +98,8 @@ class ConvBias {
...
@@ -95,7 +98,8 @@ class ConvBias {
get_wbundle_compute
(
param
,
matmul_algo
).
total_size_in_bytes
()
*
get_wbundle_compute
(
param
,
matmul_algo
).
total_size_in_bytes
()
*
nr_threads
;
nr_threads
;
if
(
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW
||
if
(
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW
||
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW88
)
{
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW88
||
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW44
)
{
return
WorkspaceBundle
(
return
WorkspaceBundle
(
nullptr
,
nullptr
,
{
winograd_comput_size
,
filter_transform_buf_size
*
GROUP
});
{
winograd_comput_size
,
filter_transform_buf_size
*
GROUP
});
...
@@ -103,7 +107,9 @@ class ConvBias {
...
@@ -103,7 +107,9 @@ class ConvBias {
megdnn_assert
(
param
.
filter_meta
.
format
==
megdnn_assert
(
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW_WINOGRAD
||
param
::
ConvBias
::
Format
::
NCHW_WINOGRAD
||
param
.
filter_meta
.
format
==
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW88_WINOGRAD
);
param
::
ConvBias
::
Format
::
NCHW88_WINOGRAD
||
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW44_WINOGRAD
);
return
WorkspaceBundle
(
nullptr
,
{
winograd_comput_size
});
return
WorkspaceBundle
(
nullptr
,
{
winograd_comput_size
});
}
}
}
}
...
@@ -210,11 +216,17 @@ public:
...
@@ -210,11 +216,17 @@ public:
reinterpret_cast
<
input_filter_compute_type
*>
(
reinterpret_cast
<
input_filter_compute_type
*>
(
reinterpret_cast
<
uintptr_t
>
(
bundle_compute
.
get
(
2
))
+
reinterpret_cast
<
uintptr_t
>
(
bundle_compute
.
get
(
2
))
+
compute_workspace_size_per_thread
*
thread_id
);
compute_workspace_size_per_thread
*
thread_id
);
const
stype
*
filter_ptr
=
kern_param
.
filter
<
stype
>
(
group_id
);
const
stype
*
filter_ptr
=
kern_param
.
filter
<
stype
>
(
group_id
);
size_t
oc_start
=
oc_id
,
oc_end
=
oc_id
+
1
;
size_t
oc_start
=
oc_id
,
oc_end
=
oc_id
+
1
;
if
(
kern_param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW88
)
{
if
(
kern_param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW88
)
{
oc_start
=
8
*
oc_id
;
oc_start
=
8
*
oc_id
;
oc_end
=
oc_start
+
8
;
oc_end
=
oc_start
+
8
;
}
else
if
(
kern_param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW44
)
{
oc_start
=
4
*
oc_id
;
oc_end
=
oc_start
+
4
;
}
}
strategy
.
filter
(
filter_ptr
,
filter_transform_buf
,
transform_mid_buf
,
OC
,
strategy
.
filter
(
filter_ptr
,
filter_transform_buf
,
transform_mid_buf
,
OC
,
IC
,
oc_start
,
oc_end
);
IC
,
oc_start
,
oc_end
);
...
@@ -279,7 +291,8 @@ public:
...
@@ -279,7 +291,8 @@ public:
static_cast
<
const
input_filter_compute_type
*>
(
static_cast
<
const
input_filter_compute_type
*>
(
ncb_param
.
filter
<
input_filter_compute_type
>
(
group_id
));
ncb_param
.
filter
<
input_filter_compute_type
>
(
group_id
));
if
(
ncb_param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW
||
if
(
ncb_param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW
||
ncb_param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW88
)
{
ncb_param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW88
||
ncb_param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW44
)
{
filter_transform_buf
=
reinterpret_cast
<
input_filter_compute_type
*>
(
filter_transform_buf
=
reinterpret_cast
<
input_filter_compute_type
*>
(
reinterpret_cast
<
uintptr_t
>
(
bundle_top
.
get
(
1
))
+
reinterpret_cast
<
uintptr_t
>
(
bundle_top
.
get
(
1
))
+
group_id
*
filter_group_size
);
group_id
*
filter_group_size
);
...
@@ -404,14 +417,18 @@ public:
...
@@ -404,14 +417,18 @@ public:
param
.
filter_meta
.
stride
[
1
]
==
1
&&
param
.
filter_meta
.
stride
[
1
]
==
1
&&
(
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW
||
(
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW
||
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW88
||
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW88
||
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW44
||
param
.
filter_meta
.
format
==
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW_WINOGRAD
||
param
::
ConvBias
::
Format
::
NCHW_WINOGRAD
||
param
.
filter_meta
.
format
==
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW88_WINOGRAD
));
param
::
ConvBias
::
Format
::
NCHW88_WINOGRAD
||
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW44_WINOGRAD
));
SmallVector
<
NCBKern
>
kerns
;
SmallVector
<
NCBKern
>
kerns
;
if
(
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW
||
if
(
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW
||
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW88
)
{
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW88
||
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW44
)
{
//! probably a gcc bug, labmda require capturing 'this' to call
//! probably a gcc bug, labmda require capturing 'this' to call
//! static member function
//! static member function
auto
filter_process_kern
=
[
this
,
strategy
,
bundle_top
,
auto
filter_process_kern
=
[
this
,
strategy
,
bundle_top
,
...
@@ -426,6 +443,10 @@ public:
...
@@ -426,6 +443,10 @@ public:
if
(
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW88
)
{
if
(
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW88
)
{
megdnn_assert
(
OC
%
8
==
0
);
megdnn_assert
(
OC
%
8
==
0
);
oc_parallelism
=
OC
/
8
;
oc_parallelism
=
OC
/
8
;
}
else
if
(
param
.
filter_meta
.
format
==
param
::
ConvBias
::
Format
::
NCHW44
)
{
megdnn_assert
(
OC
%
4
==
0
);
oc_parallelism
=
OC
/
4
;
}
}
kerns
.
push_back
({
filter_process_kern
,
{
GROUP
,
1
,
oc_parallelism
}});
kerns
.
push_back
({
filter_process_kern
,
{
GROUP
,
1
,
oc_parallelism
}});
}
}
...
...
dnn/src/fallback/convolution/opr_impl.cpp
浏览文件 @
8ba8c11d
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#include "src/fallback/convolution/opr_impl.h"
#include "src/fallback/convolution/opr_impl.h"
...
@@ -142,7 +143,8 @@ ConvolutionImpl::NCBKernSizeParam ConvolutionImpl::make_ncb_kern_size_param(
...
@@ -142,7 +143,8 @@ ConvolutionImpl::NCBKernSizeParam ConvolutionImpl::make_ncb_kern_size_param(
size_t
spatial_pos
;
size_t
spatial_pos
;
if
(
param
().
format
==
Param
::
Format
::
NCHW88
||
if
(
param
().
format
==
Param
::
Format
::
NCHW88
||
param
().
format
==
Param
::
Format
::
NCHW8
||
param
().
format
==
Param
::
Format
::
NCHW8
||
param
().
format
==
Param
::
Format
::
NCHW4
)
{
param
().
format
==
Param
::
Format
::
NCHW4
||
param
().
format
==
Param
::
Format
::
NCHW44
)
{
spatial_pos
=
2
;
spatial_pos
=
2
;
}
else
if
(
param
().
format
==
Param
::
Format
::
NCHW
||
}
else
if
(
param
().
format
==
Param
::
Format
::
NCHW
||
param
().
format
==
Param
::
Format
::
NCHW_WINOGRAD
)
{
param
().
format
==
Param
::
Format
::
NCHW_WINOGRAD
)
{
...
...
dnn/src/naive/convolution/helper.h
浏览文件 @
8ba8c11d
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#pragma once
#pragma once
...
@@ -145,6 +146,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
...
@@ -145,6 +146,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
using
Format
=
param
::
Convolution
::
Format
;
using
Format
=
param
::
Convolution
::
Format
;
if
(
filter_meta
.
format
==
Format
::
NCHW
||
if
(
filter_meta
.
format
==
Format
::
NCHW
||
filter_meta
.
format
==
Format
::
NCHW88
||
filter_meta
.
format
==
Format
::
NCHW88
||
filter_meta
.
format
==
Format
::
NCHW44
||
filter_meta
.
format
==
Format
::
NCHW4
||
filter_meta
.
format
==
Format
::
NCHW4
||
filter_meta
.
format
==
Format
::
NCHW8
||
filter_meta
.
format
==
Format
::
NCHW8
||
filter_meta
.
format
==
Format
::
NCHW32
)
{
filter_meta
.
format
==
Format
::
NCHW32
)
{
...
@@ -171,7 +173,8 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
...
@@ -171,7 +173,8 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
OW
=
dst
.
layout
.
shape
[
spatial_start
+
1
];
OW
=
dst
.
layout
.
shape
[
spatial_start
+
1
];
if
(
filter_meta
.
format
==
Format
::
NCHW4
||
if
(
filter_meta
.
format
==
Format
::
NCHW4
||
filter_meta
.
format
==
Format
::
CHWN4
)
{
filter_meta
.
format
==
Format
::
CHWN4
||
filter_meta
.
format
==
Format
::
NCHW44
)
{
OC
*=
4
;
OC
*=
4
;
}
else
if
(
filter_meta
.
format
==
Format
::
NCHW8
||
}
else
if
(
filter_meta
.
format
==
Format
::
NCHW8
||
filter_meta
.
format
==
Format
::
NCHW88
)
{
filter_meta
.
format
==
Format
::
NCHW88
)
{
...
@@ -216,6 +219,26 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
...
@@ -216,6 +219,26 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
FS_G
=
FS_OC
*
filter_meta
.
ocpg
/
8
;
FS_G
=
FS_OC
*
filter_meta
.
ocpg
/
8
;
}
}
}
}
}
else
if
(
filter_meta
.
format
==
Format
::
NCHW44
)
{
if
(
filter_meta
.
group
>
1
&&
filter_meta
.
icpg
==
1
&&
src
.
layout
.
ndim
==
5
&&
filter_meta
.
ocpg
==
1
)
{
FS_SPATIAL
=
4
;
FS_IC
=
FH
*
FW
*
FS_SPATIAL
;
FS_OC
=
FS_IC
*
filter_meta
.
icpg
;
FS_G
=
FS_OC
*
filter_meta
.
ocpg
;
}
else
{
if
(
src
.
layout
.
ndim
==
4
&&
dst
.
layout
.
ndim
==
5
)
{
FS_IC
=
4
;
FS_SPATIAL
=
filter_meta
.
icpg
*
FS_IC
;
FS_OC
=
FH
*
FW
*
FS_SPATIAL
;
FS_G
=
FS_OC
*
filter_meta
.
ocpg
/
4
;
}
else
{
FS_SPATIAL
=
4
*
4
;
FS_IC
=
FH
*
FW
*
FS_SPATIAL
;
FS_OC
=
FS_IC
*
filter_meta
.
icpg
/
4
;
FS_G
=
FS_OC
*
filter_meta
.
ocpg
/
4
;
}
}
}
else
{
}
else
{
// g, oc, fh, fw, ic
// g, oc, fh, fw, ic
megdnn_assert
(
filter_meta
.
format
==
Format
::
NHWC
);
megdnn_assert
(
filter_meta
.
format
==
Format
::
NHWC
);
...
@@ -259,6 +282,16 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
...
@@ -259,6 +282,16 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
h
*
layout
.
stride
[
2
]
+
w
*
layout
.
stride
[
3
]
+
h
*
layout
.
stride
[
2
]
+
w
*
layout
.
stride
[
3
]
+
(
c
&
0b111
)
*
layout
.
stride
[
4
];
(
c
&
0b111
)
*
layout
.
stride
[
4
];
}
}
}
else
if
(
filter_meta
.
format
==
Format
::
NCHW44
)
{
if
(
filter_meta
.
format
==
Format
::
NCHW44
&&
!
is_output
&&
src
.
layout
.
ndim
==
4
)
{
return
n
*
layout
.
stride
[
0
]
+
c
*
layout
.
stride
[
1
]
+
h
*
layout
.
stride
[
2
]
+
w
*
layout
.
stride
[
3
];
}
else
{
return
n
*
layout
.
stride
[
0
]
+
(
c
/
4
)
*
layout
.
stride
[
1
]
+
h
*
layout
.
stride
[
2
]
+
w
*
layout
.
stride
[
3
]
+
(
c
%
4
)
*
layout
.
stride
[
4
];
}
}
else
if
(
filter_meta
.
format
==
Format
::
NCHW32
)
{
}
else
if
(
filter_meta
.
format
==
Format
::
NCHW32
)
{
return
n
*
layout
.
stride
[
0
]
+
(
c
>>
5
)
*
layout
.
stride
[
1
]
+
return
n
*
layout
.
stride
[
0
]
+
(
c
>>
5
)
*
layout
.
stride
[
1
]
+
h
*
layout
.
stride
[
2
]
+
w
*
layout
.
stride
[
3
]
+
h
*
layout
.
stride
[
2
]
+
w
*
layout
.
stride
[
3
]
+
...
@@ -315,6 +348,27 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
...
@@ -315,6 +348,27 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
megdnn_assert
(
megdnn_assert
(
0
,
"nchw88 naive not support this input and output
\n
"
);
0
,
"nchw88 naive not support this input and output
\n
"
);
}
}
}
else
if
(
filter_meta
.
format
==
Format
::
NCHW44
)
{
if
(
src
.
layout
.
ndim
==
4
)
{
// ic < 8, input is nchw
return
gc_out
.
cur_grp
*
FS_G
+
gc_out
.
cur_off
/
4
*
FS_OC
+
(
fh
*
FW
+
fw
)
*
FS_SPATIAL
+
(
ic
-
ic0
)
*
FS_IC
+
gc_out
.
cur_off
%
4
;
}
else
if
(
filter_meta
.
group
>
1
&&
filter_meta
.
icpg
==
1
&&
filter_meta
.
ocpg
==
1
&&
src
.
layout
.
ndim
==
5
)
{
// dw case
return
gc_out
.
cur_grp
/
4
*
FS_G
+
gc_out
.
cur_off
*
FS_OC
+
(
ic
-
ic0
)
*
FS_IC
+
(
fh
*
FW
+
fw
)
*
FS_SPATIAL
+
gc_out
.
cur_grp
%
4
;
}
else
if
(
src
.
layout
.
ndim
==
5
)
{
// normal case
return
gc_out
.
cur_grp
*
FS_G
+
gc_out
.
cur_off
/
4
*
FS_OC
+
(
ic
-
ic0
)
/
4
*
FS_IC
+
(
fh
*
FW
+
fw
)
*
FS_SPATIAL
+
((
ic
-
ic0
)
%
4
)
*
4
+
gc_out
.
cur_off
%
4
;
}
else
{
megdnn_assert
(
0
,
"nchw44 naive not support this input and output
\n
"
);
}
}
else
{
}
else
{
return
gc_out
.
cur_grp
*
FS_G
+
gc_out
.
cur_off
*
FS_OC
+
return
gc_out
.
cur_grp
*
FS_G
+
gc_out
.
cur_off
*
FS_OC
+
(
ic
-
ic0
)
*
FS_IC
+
(
fh
*
FW
+
fw
)
*
FS_SPATIAL
;
(
ic
-
ic0
)
*
FS_IC
+
(
fh
*
FW
+
fw
)
*
FS_SPATIAL
;
...
@@ -504,6 +558,7 @@ void forward(_megdnn_tensor_in src, const ftype* fptr, _megdnn_tensor_out dst,
...
@@ -504,6 +558,7 @@ void forward(_megdnn_tensor_in src, const ftype* fptr, _megdnn_tensor_out dst,
megdnn_assert
(
filter_meta
.
format
==
param
::
Convolution
::
Format
::
NCHW
||
megdnn_assert
(
filter_meta
.
format
==
param
::
Convolution
::
Format
::
NCHW
||
filter_meta
.
format
==
param
::
Convolution
::
Format
::
NHWC
||
filter_meta
.
format
==
param
::
Convolution
::
Format
::
NHWC
||
filter_meta
.
format
==
param
::
Convolution
::
Format
::
NCHW88
||
filter_meta
.
format
==
param
::
Convolution
::
Format
::
NCHW88
||
filter_meta
.
format
==
param
::
Convolution
::
Format
::
NCHW44
||
filter_meta
.
format
==
param
::
Convolution
::
Format
::
NCHW4
);
filter_meta
.
format
==
param
::
Convolution
::
Format
::
NCHW4
);
compute2d
<
stype
,
ftype
,
dtype
,
comp_type
,
StrategyFwd
>
(
compute2d
<
stype
,
ftype
,
dtype
,
comp_type
,
StrategyFwd
>
(
src
,
const_cast
<
ftype
*>
(
fptr
),
dst
,
filter_meta
);
src
,
const_cast
<
ftype
*>
(
fptr
),
dst
,
filter_meta
);
...
@@ -557,6 +612,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter,
...
@@ -557,6 +612,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter,
switch
(
filter_meta
.
format
)
{
switch
(
filter_meta
.
format
)
{
case
param
::
Convolution
::
Format
::
NCHW
:
case
param
::
Convolution
::
Format
::
NCHW
:
case
param
::
Convolution
::
Format
::
NCHW88
:
case
param
::
Convolution
::
Format
::
NCHW88
:
case
param
::
Convolution
::
Format
::
NCHW44
:
case
param
::
Convolution
::
Format
::
NHWC
:
case
param
::
Convolution
::
Format
::
NHWC
:
case
param
::
Convolution
::
Format
::
NCHW4
:
case
param
::
Convolution
::
Format
::
NCHW4
:
case
param
::
Convolution
::
Format
::
NCHW8
:
case
param
::
Convolution
::
Format
::
NCHW8
:
...
@@ -633,6 +689,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter,
...
@@ -633,6 +689,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter,
} \
} \
} \
} \
} while (0)
} while (0)
case
Format
::
NCHW44
:
case
Format
::
NCHW4
:
{
case
Format
::
NCHW4
:
{
BIAS_ADD_NCHWx
(
4
);
BIAS_ADD_NCHWx
(
4
);
break
;
break
;
...
...
dnn/src/naive/pooling/opr_impl.cpp
浏览文件 @
8ba8c11d
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#include "src/naive/pooling/opr_impl.h"
#include "src/naive/pooling/opr_impl.h"
...
@@ -168,6 +169,13 @@ struct NCHW88IdxGetter {
...
@@ -168,6 +169,13 @@ struct NCHW88IdxGetter {
return
id
;
return
id
;
}
}
};
};
struct
NCHW44IdxGetter
{
static
size_t
get_idx
(
size_t
n
,
size_t
c
,
size_t
h
,
size_t
w
,
size_t
,
size_t
C
,
size_t
H
,
size_t
W
)
{
size_t
id
=
(((
n
*
(
C
>>
2
)
+
(
c
>>
2
))
*
H
+
h
)
*
W
+
w
)
*
4
+
(
c
%
4
);
return
id
;
}
};
struct
CHWN4IdxGetter
{
struct
CHWN4IdxGetter
{
static
size_t
get_idx
(
size_t
n
,
size_t
c
,
size_t
h
,
size_t
w
,
size_t
N
,
static
size_t
get_idx
(
size_t
n
,
size_t
c
,
size_t
h
,
size_t
w
,
size_t
N
,
...
@@ -375,6 +383,7 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
...
@@ -375,6 +383,7 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
if
(
param
().
format
==
Param
::
Format
::
NCHW
||
if
(
param
().
format
==
Param
::
Format
::
NCHW
||
param
().
format
==
Param
::
Format
::
NCHW4
||
param
().
format
==
Param
::
Format
::
NCHW4
||
param
().
format
==
Param
::
Format
::
NCHW88
||
param
().
format
==
Param
::
Format
::
NCHW88
||
param
().
format
==
Param
::
Format
::
NCHW44
||
param
().
format
==
Param
::
Format
::
NCHW32
)
{
param
().
format
==
Param
::
Format
::
NCHW32
)
{
c_pos
=
1
;
c_pos
=
1
;
spatial_pos
=
2
;
spatial_pos
=
2
;
...
@@ -401,6 +410,7 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
...
@@ -401,6 +410,7 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
OW
=
dst
.
layout
.
shape
[
spatial_pos
+
2
];
OW
=
dst
.
layout
.
shape
[
spatial_pos
+
2
];
}
}
if
(
param
().
format
==
Param
::
Format
::
NCHW4
||
if
(
param
().
format
==
Param
::
Format
::
NCHW4
||
param
().
format
==
Param
::
Format
::
NCHW44
||
param
().
format
==
Param
::
Format
::
CHWN4
)
{
param
().
format
==
Param
::
Format
::
CHWN4
)
{
C
*=
4
;
C
*=
4
;
}
}
...
@@ -437,6 +447,9 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
...
@@ -437,6 +447,9 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
case Param::Format::NCHW88: \
case Param::Format::NCHW88: \
DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW88IdxGetter); \
DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW88IdxGetter); \
break; \
break; \
case Param::Format::NCHW44: \
DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW44IdxGetter); \
break; \
case Param::Format::NCHW32: \
case Param::Format::NCHW32: \
DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW32IdxGetter); \
DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW32IdxGetter); \
break; \
break; \
...
...
dnn/test/naive/conv_bias.cpp
浏览文件 @
8ba8c11d
...
@@ -6,13 +6,14 @@
...
@@ -6,13 +6,14 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#include "test/naive/fixture.h"
#include "megdnn/oprs/nn.h"
#include "megdnn/oprs/nn.h"
#include "test/common/checker.h"
#include "test/common/checker.h"
#include "test/common/workspace_wrapper.h"
#include "test/common/workspace_wrapper.h"
#include "test/naive/fixture.h"
using
namespace
megdnn
;
using
namespace
megdnn
;
using
namespace
test
;
using
namespace
test
;
...
@@ -35,55 +36,39 @@ private:
...
@@ -35,55 +36,39 @@ private:
}
// namespace
}
// namespace
TEST_F
(
NAIVE
,
CONV_BIAS_QUANTIZED8x8x32
)
{
TEST_F
(
NAIVE
,
CONV_BIAS_QUANTIZED8x8x32
)
{
Checker
<
ConvBias
>
checker
(
handle
(),
/* check_dispatch */
false
);
Checker
<
ConvBias
>
checker
(
handle
(),
/* check_dispatch */
false
);
ConvBias
::
Param
param
;
ConvBias
::
Param
param
;
param
.
format
=
ConvBias
::
Param
::
Format
::
NCHW
;
param
.
format
=
ConvBias
::
Param
::
Format
::
NCHW
;
checker
.
set_param
(
param
).
exect
(
checker
.
set_param
(
param
).
exect
(
Testcase
{
Testcase
{
TensorValue
({
1
,
1
,
4
,
4
},
dtype
::
QuantizedS8
(
0.1
f
),
TensorValue
({
1
,
1
,
4
,
4
},
dtype
::
QuantizedS8
(
0.1
f
),
{
90
-
128
,
136
-
128
,
85
-
128
,
204
-
128
,
{
90
-
128
,
136
-
128
,
85
-
128
,
204
-
128
,
48
-
128
,
9
-
128
,
226
-
128
,
25
-
128
,
48
-
128
,
9
-
128
,
226
-
128
,
25
-
128
,
118
-
128
,
109
-
128
,
87
-
128
,
132
-
128
,
118
-
128
,
109
-
128
,
87
-
128
,
132
-
128
,
104
-
128
,
163
-
128
,
25
-
128
,
90
-
128
}),
104
-
128
,
163
-
128
,
25
-
128
,
90
-
128
}),
TensorValue
({
3
,
1
,
3
,
3
},
dtype
::
QuantizedS8
(
0.2
f
),
TensorValue
({
3
,
1
,
3
,
3
},
dtype
::
QuantizedS8
(
0.2
f
),
{
153
-
124
,
170
-
124
,
102
-
124
,
103
-
124
,
{
153
-
124
,
170
-
124
,
102
-
124
,
23
-
124
,
213
-
124
,
116
-
124
,
195
-
124
,
103
-
124
,
23
-
124
,
213
-
124
,
191
-
124
,
44
-
124
,
50
-
124
,
247
-
124
,
116
-
124
,
195
-
124
,
191
-
124
,
172
-
124
,
42
-
124
,
32
-
124
,
233
-
124
,
163
-
124
,
247
-
124
,
120
-
124
,
241
-
124
,
44
-
124
,
50
-
124
,
247
-
124
,
209
-
124
,
83
-
124
,
201
-
124
,
115
-
124
,
172
-
124
,
42
-
124
,
32
-
124
,
32
-
124
,
140
-
124
,
147
-
124
}),
233
-
124
,
163
-
124
,
247
-
124
,
TensorValue
({
1
,
3
,
1
,
1
},
dtype
::
QuantizedS32
(
0.02
f
),
{
0
,
0
,
0
}),
120
-
124
,
241
-
124
,
209
-
124
,
TensorValue
({
1
,
3
,
2
,
2
},
dtype
::
QuantizedS32
(
0.3
f
),
83
-
124
,
201
-
124
,
115
-
124
,
{
1234
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
-
234
,
0
,
0
}),
32
-
124
,
140
-
124
,
147
-
124
}),
{}},
TensorValue
({
1
,
3
,
1
,
1
},
dtype
::
QuantizedS32
(
0.02
f
),
Testcase
{{},
{
0
,
0
,
0
}),
{},
TensorValue
({
1
,
3
,
2
,
2
},
dtype
::
QuantizedS32
(
0.3
f
),
{},
{
1234
,
0
,
{},
0
,
0
,
TensorValue
({
1
,
3
,
2
,
2
},
dtype
::
QuantizedS32
(
0.1
f
*
0.2
f
),
{
37127
,
-
22475
,
-
15694
,
-
1920
,
0
,
0
,
0
,
0
,
-
12813
,
4440
,
18190
,
-
13195
,
0
,
-
234
,
-
9659
,
12423
,
-
5558
,
-
4969
})});
0
,
0
}),
{}},
Testcase
{
{},
{},
{},
{},
TensorValue
({
1
,
3
,
2
,
2
},
dtype
::
QuantizedS32
(
0.1
f
*
0.2
f
),
{
37127
,
-
22475
,
-
15694
,
-
1920
,
-
12813
,
4440
,
18190
,
-
13195
,
-
9659
,
12423
,
-
5558
,
-
4969
})});
}
}
TEST_F
(
NAIVE
,
CONV_BIAS_QUANTIZED4x4x32
)
{
TEST_F
(
NAIVE
,
CONV_BIAS_QUANTIZED4x4x32
)
{
...
@@ -175,10 +160,8 @@ TEST_F(NAIVE, CONV_BIAS_QUANTIZED4x4x32) {
...
@@ -175,10 +160,8 @@ TEST_F(NAIVE, CONV_BIAS_QUANTIZED4x4x32) {
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
}),
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
}),
TensorValue
(
TensorValue
(
{
1
,
1
,
2
,
2
,
8
},
dtype
::
QuantizedS32
(
0.3
f
),
{
1
,
1
,
2
,
2
,
8
},
dtype
::
QuantizedS32
(
0.3
f
),
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
-
87
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
}),
0
,
0
,
0
,
0
,
0
,
0
,
-
87
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
}),
{}},
{}},
Testcase
{
Testcase
{
{},
{},
...
@@ -316,8 +299,221 @@ TEST_F(NAIVE, CONV_BIAS_QUANTIZED8x8x32_NCHW32) {
...
@@ -316,8 +299,221 @@ TEST_F(NAIVE, CONV_BIAS_QUANTIZED8x8x32_NCHW32) {
TensorNDArray
{
src_ts_32
.
tensornd
(),
TensorNDArray
{
src_ts_32
.
tensornd
(),
filter_ts_32
.
tensornd
(),
filter_ts_32
.
tensornd
(),
bias_ts_32
.
tensornd
(),
bias_ts_32
.
tensornd
(),
z_ts_32
.
tensornd
(),
{}},
z_ts_32
.
tensornd
(),
{}},
TensorNDArray
{{},
{},
{},
{},
dst_ts_32
.
tensornd
()});
TensorNDArray
{{},
{},
{},
{},
dst_ts_32
.
tensornd
()});
}
}
TEST_F
(
NAIVE
,
CONV_BIAS_NCHW44
)
{
Checker
<
ConvBias
>
checker
(
handle
(),
/* check_dispatch */
false
);
ConvBias
::
Param
param
;
param
.
format
=
ConvBias
::
Param
::
Format
::
NCHW44
;
size_t
n
=
1
;
size_t
ic
=
4
;
size_t
oc
=
8
;
size_t
h
=
2
;
size_t
w
=
2
;
size_t
filter_size
=
3
;
size_t
pad
=
1
;
auto
src_tensor_shape
=
TensorShape
{
n
,
ic
/
4
,
h
,
w
,
4
};
auto
weight_tensor_shape
=
TensorShape
{
oc
/
4
,
ic
/
4
,
filter_size
,
filter_size
,
4
,
4
};
auto
bias_tensor_shape
=
TensorShape
{
1
,
oc
/
4
,
1
,
1
,
4
};
param
.
pad_h
=
pad
;
param
.
pad_w
=
pad
;
UniformIntRNG
rng
{
-
127
,
127
};
checker
.
set_dtype
(
0
,
dtype
::
Float32
())
.
set_dtype
(
1
,
dtype
::
Float32
())
.
set_dtype
(
2
,
dtype
::
Float32
())
.
set_dtype
(
4
,
dtype
::
Float32
())
.
set_rng
(
0
,
&
rng
)
.
set_rng
(
1
,
&
rng
)
.
set_rng
(
2
,
&
rng
)
.
set_epsilon
(
1e-3
)
.
set_param
(
param
)
.
execs
({
src_tensor_shape
,
weight_tensor_shape
,
bias_tensor_shape
,
{},
{}});
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
2.
f
))
.
set_dtype
(
1
,
dtype
::
QuantizedS8
(
3.
f
))
.
set_dtype
(
2
,
dtype
::
QuantizedS32
(
6.
f
))
.
set_dtype
(
4
,
dtype
::
QuantizedS32
(
6.
f
))
.
set_rng
(
0
,
&
rng
)
.
set_rng
(
1
,
&
rng
)
.
set_rng
(
2
,
&
rng
)
.
set_epsilon
(
1e-3
)
.
set_param
(
param
)
.
execs
({
src_tensor_shape
,
weight_tensor_shape
,
bias_tensor_shape
,
{},
{}});
{
// test normal conv
ConvBias
::
Param
param
;
param
.
format
=
ConvBias
::
Param
::
Format
::
NCHW44
;
param
.
sparse
=
ConvBias
::
Param
::
Sparse
::
DENSE
;
param
.
pad_h
=
1
;
param
.
pad_w
=
1
;
checker
.
set_param
(
param
).
exect
(
Testcase
{
TensorValue
({
1
,
1
,
2
,
2
,
4
},
dtype
::
Float32
(),
{
7
,
2
,
2
,
1
,
7
,
5
,
6
,
3
,
1
,
2
,
8
,
3
,
7
,
7
,
6
,
4
}),
TensorValue
(
{
1
,
1
,
3
,
3
,
4
,
4
},
dtype
::
Float32
(),
{
3
,
5
,
5
,
2
,
0
,
1
,
4
,
8
,
3
,
5
,
0
,
7
,
1
,
7
,
0
,
7
,
6
,
4
,
7
,
7
,
5
,
2
,
2
,
4
,
7
,
6
,
6
,
3
,
3
,
2
,
2
,
8
,
5
,
0
,
4
,
4
,
0
,
5
,
1
,
0
,
0
,
4
,
8
,
4
,
7
,
7
,
2
,
0
,
4
,
8
,
7
,
3
,
6
,
2
,
3
,
0
,
0
,
6
,
4
,
4
,
1
,
4
,
3
,
8
,
8
,
8
,
7
,
2
,
2
,
5
,
5
,
1
,
3
,
2
,
8
,
1
,
7
,
0
,
2
,
7
,
1
,
6
,
1
,
5
,
0
,
6
,
3
,
0
,
2
,
4
,
1
,
1
,
4
,
2
,
7
,
5
,
7
,
8
,
4
,
5
,
5
,
7
,
0
,
3
,
3
,
2
,
8
,
6
,
0
,
1
,
4
,
6
,
6
,
6
,
0
,
1
,
2
,
4
,
4
,
1
,
1
,
7
,
8
,
2
,
5
,
2
,
8
,
3
,
8
,
3
,
5
,
0
,
6
,
3
,
4
,
3
,
3
,
7
,
2
,
8
,
1
,
1
,
1
,
4
}),
TensorValue
({
1
,
1
,
1
,
1
,
4
},
dtype
::
Float32
(),
{
7
,
2
,
8
,
1
}),
TensorValue
({
1
,
1
,
2
,
2
,
4
},
dtype
::
Float32
(),
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
}),
{}},
Testcase
{
{},
{},
{},
{},
TensorValue
({
1
,
1
,
2
,
2
,
4
},
dtype
::
Float32
(),
{
264
,
338
,
309
,
195
,
276
,
332
,
390
,
199
,
224
,
268
,
311
,
218
,
288
,
311
,
346
,
277
})});
}
{
// test dw conv
ConvBias
::
Param
param
;
param
.
format
=
ConvBias
::
Param
::
Format
::
NCHW44
;
param
.
sparse
=
ConvBias
::
Param
::
Sparse
::
GROUP
;
param
.
pad_h
=
1
;
param
.
pad_w
=
1
;
checker
.
set_param
(
param
).
exect
(
Testcase
{
TensorValue
({
1
,
1
,
2
,
2
,
4
},
dtype
::
Float32
(),
{
5
,
8
,
3
,
2
,
4
,
6
,
1
,
5
,
0
,
8
,
2
,
6
,
8
,
6
,
5
,
7
}),
TensorValue
({
1
,
1
,
1
,
3
,
3
,
4
},
dtype
::
Float32
(),
{
3
,
0
,
3
,
1
,
6
,
5
,
7
,
3
,
5
,
0
,
0
,
7
,
4
,
6
,
0
,
1
,
8
,
2
,
3
,
7
,
1
,
0
,
2
,
4
,
7
,
5
,
3
,
0
,
6
,
2
,
1
,
5
,
8
,
6
,
3
,
1
}),
TensorValue
({
1
,
1
,
1
,
1
,
4
},
dtype
::
Float32
(),
{
4
,
3
,
5
,
6
}),
TensorValue
({
1
,
1
,
2
,
2
,
4
},
dtype
::
Float32
(),
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
}),
{}},
Testcase
{{},
{},
{},
{},
TensorValue
({
1
,
1
,
2
,
2
,
4
},
dtype
::
Float32
(),
{
112
,
71
,
33
,
77
,
104
,
115
,
19
,
78
,
62
,
59
,
42
,
117
,
107
,
93
,
36
,
78
})});
}
{
// test group conv
ConvBias
::
Param
param
;
param
.
format
=
ConvBias
::
Param
::
Format
::
NCHW44
;
param
.
sparse
=
ConvBias
::
Param
::
Sparse
::
GROUP
;
param
.
pad_h
=
1
;
param
.
pad_w
=
1
;
checker
.
set_param
(
param
).
exect
(
Testcase
{
TensorValue
({
1
,
2
,
2
,
2
,
4
},
dtype
::
Float32
(),
{
6
,
3
,
2
,
7
,
7
,
6
,
4
,
5
,
8
,
6
,
3
,
1
,
1
,
2
,
8
,
3
,
1
,
0
,
6
,
1
,
3
,
3
,
6
,
0
,
0
,
5
,
6
,
7
,
2
,
2
,
4
,
4
}),
TensorValue
(
{
2
,
1
,
1
,
3
,
3
,
4
,
4
},
dtype
::
Float32
(),
{
3
,
5
,
5
,
2
,
0
,
1
,
4
,
8
,
3
,
5
,
0
,
7
,
1
,
7
,
0
,
7
,
6
,
4
,
7
,
7
,
5
,
2
,
2
,
4
,
7
,
6
,
6
,
3
,
3
,
2
,
2
,
8
,
5
,
0
,
4
,
4
,
0
,
5
,
1
,
0
,
0
,
4
,
8
,
4
,
7
,
7
,
2
,
0
,
4
,
8
,
7
,
3
,
6
,
2
,
3
,
0
,
0
,
6
,
4
,
4
,
1
,
4
,
3
,
8
,
8
,
8
,
7
,
2
,
2
,
5
,
5
,
1
,
3
,
2
,
8
,
1
,
7
,
0
,
2
,
7
,
1
,
6
,
1
,
5
,
0
,
6
,
3
,
0
,
2
,
4
,
1
,
1
,
4
,
2
,
7
,
5
,
7
,
8
,
4
,
5
,
5
,
7
,
0
,
3
,
3
,
2
,
8
,
6
,
0
,
1
,
4
,
6
,
6
,
6
,
0
,
1
,
2
,
4
,
4
,
1
,
1
,
7
,
8
,
2
,
5
,
2
,
8
,
3
,
8
,
3
,
5
,
0
,
6
,
3
,
4
,
3
,
3
,
7
,
2
,
8
,
1
,
1
,
1
,
4
,
7
,
4
,
5
,
0
,
6
,
8
,
7
,
4
,
8
,
1
,
3
,
5
,
3
,
0
,
0
,
3
,
7
,
7
,
7
,
3
,
8
,
1
,
2
,
0
,
1
,
1
,
2
,
1
,
3
,
0
,
0
,
1
,
1
,
3
,
0
,
5
,
6
,
3
,
0
,
5
,
4
,
1
,
4
,
7
,
0
,
2
,
1
,
6
,
7
,
8
,
0
,
2
,
1
,
6
,
7
,
6
,
3
,
2
,
7
,
6
,
5
,
1
,
1
,
1
,
2
,
4
,
6
,
3
,
3
,
8
,
0
,
7
,
1
,
3
,
7
,
3
,
2
,
2
,
4
,
3
,
5
,
5
,
6
,
3
,
3
,
1
,
2
,
3
,
0
,
4
,
0
,
3
,
3
,
5
,
5
,
5
,
2
,
3
,
1
,
5
,
4
,
5
,
8
,
1
,
7
,
2
,
1
,
0
,
1
,
8
,
2
,
6
,
7
,
8
,
4
,
4
,
7
,
8
,
4
,
5
,
8
,
1
,
1
,
0
,
7
,
8
,
4
,
2
,
2
,
8
,
6
,
5
,
2
,
4
,
8
,
4
,
0
,
4
,
0
,
2
,
1
,
7
,
1
,
6
}),
TensorValue
({
1
,
2
,
1
,
1
,
4
},
dtype
::
Float32
(),
{
1
,
8
,
5
,
6
,
2
,
8
,
7
,
7
}),
TensorValue
({
1
,
2
,
2
,
2
,
4
},
dtype
::
Float32
(),
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
}),
{}},
Testcase
{
{},
{},
{},
{},
TensorValue
({
1
,
2
,
2
,
2
,
4
},
dtype
::
Float32
(),
{
260
,
342
,
244
,
241
,
293
,
385
,
362
,
257
,
278
,
301
,
303
,
226
,
273
,
306
,
318
,
307
,
180
,
244
,
169
,
156
,
210
,
244
,
206
,
167
,
126
,
165
,
156
,
207
,
191
,
141
,
209
,
172
})});
}
{
// test normal conv
ConvBias
::
Param
param
;
param
.
format
=
ConvBias
::
Param
::
Format
::
NCHW44
;
param
.
sparse
=
ConvBias
::
Param
::
Sparse
::
DENSE
;
param
.
pad_h
=
1
;
param
.
pad_w
=
1
;
checker
.
set_param
(
param
).
exect
(
Testcase
{
TensorValue
({
1
,
1
,
2
,
2
,
4
},
dtype
::
Int8
(),
{
7
,
2
,
2
,
1
,
7
,
5
,
6
,
3
,
1
,
2
,
8
,
3
,
7
,
7
,
6
,
4
}),
TensorValue
(
{
1
,
1
,
3
,
3
,
4
,
4
},
dtype
::
Int8
(),
{
3
,
5
,
5
,
2
,
0
,
1
,
4
,
8
,
3
,
5
,
0
,
7
,
1
,
7
,
0
,
7
,
6
,
4
,
7
,
7
,
5
,
2
,
2
,
4
,
7
,
6
,
6
,
3
,
3
,
2
,
2
,
8
,
5
,
0
,
4
,
4
,
0
,
5
,
1
,
0
,
0
,
4
,
8
,
4
,
7
,
7
,
2
,
0
,
4
,
8
,
7
,
3
,
6
,
2
,
3
,
0
,
0
,
6
,
4
,
4
,
1
,
4
,
3
,
8
,
8
,
8
,
7
,
2
,
2
,
5
,
5
,
1
,
3
,
2
,
8
,
1
,
7
,
0
,
2
,
7
,
1
,
6
,
1
,
5
,
0
,
6
,
3
,
0
,
2
,
4
,
1
,
1
,
4
,
2
,
7
,
5
,
7
,
8
,
4
,
5
,
5
,
7
,
0
,
3
,
3
,
2
,
8
,
6
,
0
,
1
,
4
,
6
,
6
,
6
,
0
,
1
,
2
,
4
,
4
,
1
,
1
,
7
,
8
,
2
,
5
,
2
,
8
,
3
,
8
,
3
,
5
,
0
,
6
,
3
,
4
,
3
,
3
,
7
,
2
,
8
,
1
,
1
,
1
,
4
}),
TensorValue
({
1
,
1
,
1
,
1
,
4
},
dtype
::
Int32
(),
{
7
,
2
,
8
,
1
}),
TensorValue
({
1
,
1
,
2
,
2
,
4
},
dtype
::
Int32
(),
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
}),
{}},
Testcase
{
{},
{},
{},
{},
TensorValue
({
1
,
1
,
2
,
2
,
4
},
dtype
::
Int32
(),
{
264
,
338
,
309
,
195
,
276
,
332
,
390
,
199
,
224
,
268
,
311
,
218
,
288
,
311
,
346
,
277
})});
}
}
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录