Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
4a227083
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
4a227083
编写于
8月 14, 2020
作者:
M
Megvii Engine Team
提交者:
Xinran Xu
8月 25, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mgb/fallback): fix conv1x1 and conv1x1_gemv nchw44 usable
GitOrigin-RevId: 90aa75d51e9ec52afd699a837b20ce605ae21971
上级
b778d225
变更
6
展开全部
隐藏空白更改
内联
并排
Showing
6 changed file
with
407 addition
and
321 deletion
+407
-321
dnn/src/arm_common/pooling/algo.cpp
dnn/src/arm_common/pooling/algo.cpp
+4
-4
dnn/src/fallback/conv_bias/conv1x1/algos.cpp
dnn/src/fallback/conv_bias/conv1x1/algos.cpp
+134
-236
dnn/src/fallback/conv_bias/conv1x1/algos.h
dnn/src/fallback/conv_bias/conv1x1/algos.h
+6
-1
dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp
dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp
+46
-53
dnn/src/fallback/conv_bias/conv1x1/conv1x1_dispatcher.h
dnn/src/fallback/conv_bias/conv1x1/conv1x1_dispatcher.h
+201
-10
dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h
dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h
+16
-17
未找到文件。
dnn/src/arm_common/pooling/algo.cpp
浏览文件 @
4a227083
...
...
@@ -612,7 +612,7 @@ bool PoolingImpl::AlgoFilter3ModexStridexNCHW44::usable(
(
param
.
mode
==
Mode
::
MAX
||
param
.
mode
==
Mode
::
AVERAGE
)
&&
FH
==
3
&&
FW
==
3
&&
SW
==
SH
&&
(
SH
==
1
||
SW
==
2
);
//! Int8 not support average, because its round mode is different form
//! q
u
int8
//! qint8
avaible
&=
!
(
param
.
src_type
.
enumv
()
==
DTypeEnum
::
Int8
&&
param
.
mode
==
Mode
::
AVERAGE
);
return
avaible
;
...
...
@@ -705,7 +705,7 @@ bool PoolingImpl::AlgoFilter2ModexStridexNCHW44::usable(
(
param
.
mode
==
Mode
::
MAX
||
param
.
mode
==
Mode
::
AVERAGE
)
&&
FH
==
2
&&
FW
==
2
&&
SH
==
SW
&&
(
SW
==
1
||
SW
==
2
);
//! Int8 not support average, because its round mode is different form
//! q
u
int8
//! qint8
avaible
&=
!
(
param
.
src_type
.
enumv
()
==
DTypeEnum
::
Int8
&&
param
.
mode
==
Mode
::
AVERAGE
);
return
avaible
;
...
...
@@ -799,7 +799,7 @@ bool PoolingImpl::AlgoFilter4ModexStridexNCHW44::usable(
FH
==
4
&&
FW
==
4
&&
SH
==
SW
&&
(
SW
==
1
||
SW
==
2
);
//! Int8 not support average, because its round mode is different form
//! q
u
int8
//! qint8
avaible
&=
!
(
param
.
src_type
.
enumv
()
==
DTypeEnum
::
Int8
&&
param
.
mode
==
Mode
::
AVERAGE
);
return
avaible
;
...
...
@@ -892,7 +892,7 @@ bool PoolingImpl::AlgoFilter5ModexStridexNCHW44::usable(
(
param
.
mode
==
Mode
::
MAX
||
param
.
mode
==
Mode
::
AVERAGE
)
&&
FH
==
5
&&
FW
==
5
&&
SH
==
SW
&&
(
SW
==
1
||
SW
==
2
);
//! Int8 not support average, because its round mode is different form
//! q
u
int8
//! qint8
avaible
&=
!
(
param
.
src_type
.
enumv
()
==
DTypeEnum
::
Int8
&&
param
.
mode
==
Mode
::
AVERAGE
);
return
avaible
;
...
...
dnn/src/fallback/conv_bias/conv1x1/algos.cpp
浏览文件 @
4a227083
此差异已折叠。
点击以展开。
dnn/src/fallback/conv_bias/conv1x1/algos.h
浏览文件 @
4a227083
...
...
@@ -20,6 +20,11 @@ namespace megdnn {
namespace
fallback
{
class
ConvBiasImpl
::
AlgoConv1x1
final
:
public
AlgoBase
{
WorkspaceBundle
get_bundle_according_packmode
(
const
NCBKernSizeParam
&
param
)
const
;
SmallVector
<
NCBKern
>
get_kerns_according_packmode
(
const
NCBKernSizeParam
&
param
,
bool
weight_preprocess
)
const
;
public:
AlgoConv1x1
(
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
size_t
oc_block_size
)
:
m_matmul_algo
(
matmul_algo
),
m_oc_block_size
(
oc_block_size
)
{}
...
...
@@ -41,7 +46,7 @@ public:
const
NCBKernSizeParam
&
param
)
const
override
;
bool
is_preferred
(
const
NCBKernSizeParam
&
)
const
override
;
SmallVector
<
TensorLayout
>
deduce_preprocessed_filter_layout
(
const
NCBKernSizeParam
&
param
)
const
override
;
size_t
get_preprocess_workspace
(
...
...
dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp
浏览文件 @
4a227083
...
...
@@ -360,23 +360,23 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
dt_uint8
,
PostprocessMode
::
QUANTIZED
,
"NCHW::GEMV::QUINT8x8x32_QUINT8"
_hash
);
break
;
//!no support nchw44 8x8x16
case
param
::
ConvBias
::
Format
::
NCHW44
:
cb1
(
param
::
ConvBias
::
Format
::
NCHW44
,
dt_float32
,
dt_float32
,
PostprocessMode
::
FLOAT
,
"NCHW44::GEMV::FLOAT"
_hash
);
cb
2
(
param
::
ConvBias
::
Format
::
NCHW44
,
dt_int8
,
dt_int32
,
dt_int32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
NO_PROCES
S
,
cb
3
(
param
::
ConvBias
::
Format
::
NCHW44
,
dt_int8
,
dt_int32
,
dt_int32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
ADD_BIA
S
,
"NCHW44::GEMV::INT8x8x32_INT32"
_hash
);
cb
2
(
param
::
ConvBias
::
Format
::
NCHW44
,
dtype
::
QuantizedS8
,
cb
3
(
param
::
ConvBias
::
Format
::
NCHW44
,
dtype
::
QuantizedS8
,
dtype
::
QuantizedS32
,
dtype
::
QuantizedS32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
NO_PROCES
S
,
dt_int32
,
PostprocessMode
::
ADD_BIA
S
,
"NCHW44::GEMV::QINT8x8x32_QINT32"
_hash
);
cb2
(
param
::
ConvBias
::
Format
::
NCHW44
,
dtype
::
QuantizedS8
,
dtype
::
QuantizedS32
,
dtype
::
QuantizedS8
,
dt_int8
,
dt_int32
,
dt_int8
,
PostprocessMode
::
QUANTIZED
,
"NCHW44::GEMV::QINT8x8x32_QINT8"
_hash
);
break
;
//!no support nchw44-dot 8x8x16
case
param
::
ConvBias
::
Format
::
NCHW44_DOT
:
cb3
(
param
::
ConvBias
::
Format
::
NCHW44_DOT
,
dt_int8
,
dt_int32
,
dt_int32
,
dt_int8
,
dt_int32
,
dt_int32
,
...
...
@@ -420,81 +420,74 @@ bool ConvBiasImpl::AlgoConv1x1Gemv::usable(const NCBKernSizeParam& param,
MIDOUT_BEGIN
(
megdnn_fallback_conv1x1_gemv
,
midout_iv
(
"AlgoConv1x1Gemv::usable"
_hash
))
{
auto
format
=
param
.
filter_meta
.
format
;
#if MEGDNN_X86
if
(
format
!=
param
::
ConvBias
::
Format
::
NCHW
)
return
false
;
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
if
(
format
!=
param
::
ConvBias
::
Format
::
NCHW
&&
format
!=
param
::
ConvBias
::
Format
::
NCHW44
&&
format
!=
param
::
ConvBias
::
Format
::
NCHW44_DOT
)
return
false
;
#endif
//! whether 1x1
size_t
FH
=
param
.
filter_meta
.
spatial
[
0
],
FW
=
param
.
filter_meta
.
spatial
[
1
];
size_t
PH
=
param
.
filter_meta
.
padding
[
0
],
PW
=
param
.
filter_meta
.
padding
[
1
];
size_t
SH
=
param
.
filter_meta
.
stride
[
0
],
SW
=
param
.
filter_meta
.
stride
[
1
];
if
(
FH
!=
1
||
FW
!=
1
||
PH
||
PW
||
SH
!=
1
||
SW
!=
1
)
{
return
false
;
}
//! whether gemv
size_t
OH
=
param
.
osz
[
0
];
size_t
OW
=
param
.
osz
[
1
];
if
(
OH
*
OW
!=
1
)
{
//! whether gemv and 1x1
if
(
OH
*
OW
!=
1
||
FH
!=
1
||
FW
!=
1
||
PH
||
PW
||
SH
!=
1
||
SW
!=
1
)
{
return
false
;
}
//! even no naive support in gemv
if
((
param
.
src_type
.
enumv
()
==
param
.
filter_type
.
enumv
()
&&
param
.
src_type
.
enumv
()
==
DTypeEnum
::
Int16
)
&&
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
Int32
)
{
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
if
(
format
!=
param
::
ConvBias
::
Format
::
NCHW
&&
format
!=
param
::
ConvBias
::
Format
::
NCHW44
&&
format
!=
param
::
ConvBias
::
Format
::
NCHW44_DOT
)
{
return
false
;
}
//! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode
//! is identity otherwise return false mean that 8x8x32 and 8x8x16
//! not support PostProcess
if
(
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
Int16
||
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
Int32
||
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
QuantizedS32
)
{
if
(
param
.
nonlineMode
!=
megdnn
::
NonlineMode
::
IDENTITY
)
{
return
false
;
}
}
//! supports a few dtypes
if
(
param
.
src_type
.
enumv
()
!=
param
.
filter_type
.
enumv
())
{
#else
if
(
format
!=
param
::
ConvBias
::
Format
::
NCHW
)
{
return
false
;
}
if
(
param
.
src_type
.
enumv
()
!=
DTypeEnum
::
Int8
&&
param
.
src_type
.
enumv
()
!=
DTypeEnum
::
QuantizedS8
&&
param
.
src_type
.
enumv
()
!=
DTypeEnum
::
Quantized8Asymm
&&
#endif
//! supports a few dtypes
if
(
param
.
src_type
.
enumv
()
!=
param
.
filter_type
.
enumv
()
||
(
param
.
src_type
.
enumv
()
!=
DTypeEnum
::
Int8
&&
param
.
src_type
.
enumv
()
!=
DTypeEnum
::
QuantizedS8
&&
param
.
src_type
.
enumv
()
!=
DTypeEnum
::
Quantized8Asymm
&&
#if !MEGDNN_DISABLE_FLOAT16
param
.
src_type
.
enumv
()
!=
DTypeEnum
::
Float16
&&
param
.
src_type
.
enumv
()
!=
DTypeEnum
::
Float16
&&
#endif
param
.
src_type
.
enumv
()
!=
DTypeEnum
::
Float32
)
{
param
.
src_type
.
enumv
()
!=
DTypeEnum
::
Float32
)
)
{
return
false
;
}
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
if
(
format
==
param
::
ConvBias
::
Format
::
NCHW44
)
{
if
(
param
.
src_type
.
enumv
()
!=
DTypeEnum
::
Float32
&&
param
.
src_type
.
enumv
()
!=
DTypeEnum
::
Int8
&&
param
.
src_type
.
enumv
()
!=
DTypeEnum
::
QuantizedS8
)
{
return
false
;
}
//! 8x8x16 is not support nchw44
if
(
param
.
src_type
.
enumv
()
==
DTypeEnum
::
Int8
&&
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
Int16
)
{
return
false
;
}
}
else
if
(
format
==
param
::
ConvBias
::
Format
::
NCHW44_DOT
)
{
if
(
param
.
src_type
.
enumv
()
!=
DTypeEnum
::
Int8
&&
param
.
src_type
.
enumv
()
!=
DTypeEnum
::
QuantizedS8
)
{
if
((
param
.
src_type
.
enumv
()
!=
DTypeEnum
::
Int8
&&
param
.
src_type
.
enumv
()
!=
DTypeEnum
::
QuantizedS8
)
||
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
Int16
)
{
return
false
;
}
}
#endif
//! make sure 8x8x16 and 8x8x32 biasmode nonlineMode is identity
//! otherwise return false
if
(
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
Int16
||
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
Int32
||
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
QuantizedS32
)
{
if
(
param
.
nonlineMode
!=
megdnn
::
NonlineMode
::
IDENTITY
)
{
return
false
;
}
}
//! even no naive support in gemv
if
((
param
.
src_type
.
enumv
()
==
param
.
filter_type
.
enumv
()
&&
param
.
src_type
.
enumv
()
==
DTypeEnum
::
Int16
)
&&
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
Int32
)
{
return
false
;
}
return
(
param
.
filter_meta
.
dilation
[
0
]
==
param
.
filter_meta
.
dilation
[
1
]
&&
param
.
filter_meta
.
dilation
[
0
]
==
1
)
&&
...
...
dnn/src/fallback/conv_bias/conv1x1/conv1x1_dispatcher.h
浏览文件 @
4a227083
...
...
@@ -11,14 +11,19 @@
#pragma once
#include "src/fallback/conv_bias/conv1x1/conv1x1_strategy.h"
#include "src/fallback/conv_bias/conv1x1/conv1x1_utils.h"
#include "src/fallback/conv_bias/opr_impl.h"
namespace
megdnn
{
namespace
fallback
{
namespace
conv1x1
{
template
<
MatrixMulImpl
::
AlgoBase
::
PackMode
pack_mode
>
class
Conv1x1Kerns
{
class
Conv1x1Kerns
;
template
<
>
class
Conv1x1Kerns
<
MatrixMulImpl
::
AlgoBase
::
PackMode
::
DEFAULT
>
{
public:
//! get_bundle
WorkspaceBundle
get_bundle
(
const
ConvBiasImpl
::
NCBKernSizeParam
&
param
,
...
...
@@ -28,13 +33,12 @@ public:
size_t
GROUP
=
param
.
filter_meta
.
group
;
size_t
OC
=
param
.
filter_meta
.
ocpg
;
size_t
BATCH
=
param
.
n
;
//! bundle per thread
//! matmul_param records a matmul with M = oc_tile_size, K = IC, N = OH
//! * OW this does not bother packb bytes
auto
matmul_bundle
=
matmul_algo
->
get_bundle
(
matmul_param
);
auto
thread_bundle
=
utils
::
get_thread_bundle
(
param
,
matmul_bundle
.
get_size
(
2
),
oc_tile_size
);
auto
thread_bundle
=
utils
::
get_thread_bundle
(
param
,
matmul_bundle
.
get_size
(
2
),
oc_tile_size
);
//! size per thread
size_t
all_threads_bytes
=
thread_bundle
.
total_size_in_bytes
()
*
param
.
nr_threads
;
...
...
@@ -46,11 +50,6 @@ public:
is_enable_filter_preprocess
(
param
)
?
0
:
packa_bytes_per_oc_tile
*
oc_tiles_per_group
*
GROUP
;
if
(
pack_mode
==
MatrixMulImpl
::
AlgoBase
::
PackMode
::
ONLY_PACKA
)
return
WorkspaceBundle
{
nullptr
,
{
all_packa_bytes
,
0
,
all_threads_bytes
}};
//! packb size = N * GROUP * packb_size_per_group
size_t
packb_bytes_per_group
=
matmul_bundle
.
get_size
(
1
);
size_t
all_packb_bytes
=
packb_bytes_per_group
*
GROUP
*
BATCH
;
...
...
@@ -58,6 +57,165 @@ public:
return
WorkspaceBundle
{
nullptr
,
{
all_packa_bytes
,
all_packb_bytes
,
all_threads_bytes
}};
}
SmallVector
<
ConvBiasImpl
::
NCBKern
>
get_kern
(
const
ConvBiasImpl
::
NCBKernSizeParam
&
param
,
WorkspaceBundle
&
whole_bundle
,
WorkspaceBundle
&
matmul_bundle
,
WorkspaceBundle
&
thread_bundle
,
Conv1x1StrategyBase
*
conv1x1_strategy
,
const
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
size_t
oc_block_size
)
{
auto
kern_packA
=
[
whole_bundle
,
matmul_bundle
,
param
,
matmul_algo
,
oc_block_size
,
conv1x1_strategy
](
const
ConvBiasImpl
::
NCBKernParam
&
ncb_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
)
mutable
{
conv1x1_strategy
->
packA
(
whole_bundle
,
matmul_bundle
,
oc_block_size
,
matmul_algo
,
param
,
ncb_param
,
std
::
move
(
ncb_index
));
};
auto
kern_packB
=
[
whole_bundle
,
matmul_bundle
,
param
,
matmul_algo
,
conv1x1_strategy
](
const
ConvBiasImpl
::
NCBKernParam
&
ncb_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
)
mutable
{
conv1x1_strategy
->
packB
(
whole_bundle
,
matmul_bundle
,
matmul_algo
,
param
,
ncb_param
,
std
::
move
(
ncb_index
));
};
auto
kern_compt
=
[
whole_bundle
,
matmul_bundle
,
thread_bundle
,
matmul_algo
,
param
,
oc_block_size
,
conv1x1_strategy
](
const
ConvBiasImpl
::
NCBKernParam
&
ncb_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
)
mutable
{
conv1x1_strategy
->
exec
(
whole_bundle
,
matmul_bundle
,
thread_bundle
,
oc_block_size
,
matmul_algo
,
param
,
ncb_param
,
std
::
move
(
ncb_index
));
};
size_t
GROUP
=
param
.
filter_meta
.
group
;
size_t
BATCH
=
param
.
n
;
size_t
OC
=
param
.
filter_meta
.
ocpg
;
size_t
oc_blocks_per_group
=
div_ceil
(
OC
,
oc_block_size
);
SmallVector
<
ConvBiasImpl
::
NCBKern
>
ret_kern
;
if
(
!
is_enable_filter_preprocess
(
param
))
{
ret_kern
.
push_back
({
kern_packA
,
{
GROUP
,
oc_blocks_per_group
}});
}
ret_kern
.
push_back
({
kern_packB
,
{
BATCH
}});
ret_kern
.
push_back
({
kern_compt
,
{
BATCH
,
GROUP
,
oc_blocks_per_group
}});
return
ret_kern
;
}
SmallVector
<
ConvBiasImpl
::
NCBKern
>
get_kern_preprocess
(
const
ConvBiasImpl
::
NCBKernSizeParam
&
param
,
WorkspaceBundle
&
whole_bundle
,
WorkspaceBundle
&
matmul_bundle
,
Conv1x1StrategyBase
*
conv1x1_strategy
,
const
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
size_t
oc_block_size
)
{
auto
kern_packA
=
[
whole_bundle
,
matmul_bundle
,
param
,
matmul_algo
,
oc_block_size
,
conv1x1_strategy
](
const
ConvBiasImpl
::
NCBKernParam
&
ncb_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
)
mutable
{
conv1x1_strategy
->
packA
(
whole_bundle
,
matmul_bundle
,
oc_block_size
,
matmul_algo
,
param
,
ncb_param
,
std
::
move
(
ncb_index
));
};
size_t
GROUP
=
param
.
filter_meta
.
group
;
size_t
OC
=
param
.
filter_meta
.
ocpg
;
size_t
oc_blocks_per_group
=
div_ceil
(
OC
,
oc_block_size
);
SmallVector
<
ConvBiasImpl
::
NCBKern
>
ret_kern
;
ret_kern
.
push_back
({
kern_packA
,
{
GROUP
,
oc_blocks_per_group
}});
return
ret_kern
;
}
};
template
<
>
class
Conv1x1Kerns
<
MatrixMulImpl
::
AlgoBase
::
PackMode
::
ONLY_PACKA
>
{
public:
//! get_bundle
WorkspaceBundle
get_bundle
(
const
ConvBiasImpl
::
NCBKernSizeParam
&
param
,
const
MatrixMulImpl
::
KernSizeParam
&
matmul_param
,
const
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
size_t
oc_tile_size
)
{
size_t
GROUP
=
param
.
filter_meta
.
group
;
size_t
OC
=
param
.
filter_meta
.
ocpg
;
//! bundle per thread
//! matmul_param records a matmul with M = oc_tile_size, K = IC, N = OH
//! * OW this does not bother packb bytes
auto
matmul_bundle
=
matmul_algo
->
get_bundle
(
matmul_param
);
auto
thread_bundle
=
utils
::
get_thread_bundle
(
param
,
matmul_bundle
.
get_size
(
2
),
oc_tile_size
);
//! size per thread
size_t
all_threads_bytes
=
thread_bundle
.
total_size_in_bytes
()
*
param
.
nr_threads
;
//! packa size = GROUP * packa_size_each_group
size_t
packa_bytes_per_oc_tile
=
matmul_bundle
.
get_size
(
0
);
size_t
oc_tiles_per_group
=
div_ceil
(
OC
,
oc_tile_size
);
size_t
all_packa_bytes
=
is_enable_filter_preprocess
(
param
)
?
0
:
packa_bytes_per_oc_tile
*
oc_tiles_per_group
*
GROUP
;
return
WorkspaceBundle
{
nullptr
,
{
all_packa_bytes
,
0
,
all_threads_bytes
}};
}
SmallVector
<
ConvBiasImpl
::
NCBKern
>
get_kern
(
const
ConvBiasImpl
::
NCBKernSizeParam
&
param
,
WorkspaceBundle
&
whole_bundle
,
WorkspaceBundle
&
matmul_bundle
,
WorkspaceBundle
&
thread_bundle
,
Conv1x1StrategyBase
*
conv1x1_strategy
,
const
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
size_t
oc_block_size
)
{
auto
kern_packA
=
[
whole_bundle
,
matmul_bundle
,
param
,
matmul_algo
,
oc_block_size
,
conv1x1_strategy
](
const
ConvBiasImpl
::
NCBKernParam
&
ncb_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
)
mutable
{
conv1x1_strategy
->
packA
(
whole_bundle
,
matmul_bundle
,
oc_block_size
,
matmul_algo
,
param
,
ncb_param
,
std
::
move
(
ncb_index
));
};
auto
kern_compt
=
[
whole_bundle
,
matmul_bundle
,
thread_bundle
,
matmul_algo
,
param
,
oc_block_size
,
conv1x1_strategy
](
const
ConvBiasImpl
::
NCBKernParam
&
ncb_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
)
mutable
{
conv1x1_strategy
->
exec
(
whole_bundle
,
matmul_bundle
,
thread_bundle
,
oc_block_size
,
matmul_algo
,
param
,
ncb_param
,
std
::
move
(
ncb_index
));
};
size_t
GROUP
=
param
.
filter_meta
.
group
;
size_t
BATCH
=
param
.
n
;
size_t
OC
=
param
.
filter_meta
.
ocpg
;
size_t
oc_blocks_per_group
=
div_ceil
(
OC
,
oc_block_size
);
SmallVector
<
ConvBiasImpl
::
NCBKern
>
ret_kern
;
if
(
!
is_enable_filter_preprocess
(
param
))
{
ret_kern
.
push_back
({
kern_packA
,
{
GROUP
,
oc_blocks_per_group
}});
}
ret_kern
.
push_back
({
kern_compt
,
{
BATCH
,
GROUP
,
oc_blocks_per_group
}});
return
ret_kern
;
}
SmallVector
<
ConvBiasImpl
::
NCBKern
>
get_kern_preprocess
(
const
ConvBiasImpl
::
NCBKernSizeParam
&
param
,
WorkspaceBundle
&
whole_bundle
,
WorkspaceBundle
&
matmul_bundle
,
Conv1x1StrategyBase
*
conv1x1_strategy
,
const
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
size_t
oc_block_size
)
{
auto
kern_packA
=
[
whole_bundle
,
matmul_bundle
,
param
,
matmul_algo
,
oc_block_size
,
conv1x1_strategy
](
const
ConvBiasImpl
::
NCBKernParam
&
ncb_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
)
mutable
{
conv1x1_strategy
->
packA
(
whole_bundle
,
matmul_bundle
,
oc_block_size
,
matmul_algo
,
param
,
ncb_param
,
std
::
move
(
ncb_index
));
};
size_t
GROUP
=
param
.
filter_meta
.
group
;
size_t
OC
=
param
.
filter_meta
.
ocpg
;
size_t
oc_blocks_per_group
=
div_ceil
(
OC
,
oc_block_size
);
SmallVector
<
ConvBiasImpl
::
NCBKern
>
ret_kern
;
ret_kern
.
push_back
({
kern_packA
,
{
GROUP
,
oc_blocks_per_group
}});
return
ret_kern
;
}
};
template
<
>
...
...
@@ -69,14 +227,47 @@ public:
const
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
size_t
oc_tile_size
)
{
size_t
matmul_size
=
matmul_algo
->
get_workspace
(
matmul_param
);
auto
thread_bundle
=
utils
::
get_thread_bundle
(
param
,
matmul_size
,
oc_tile_size
);
auto
thread_bundle
=
utils
::
get_thread_bundle
(
param
,
matmul_size
,
oc_tile_size
);
//! size per thread
size_t
all_threads_bytes
=
thread_bundle
.
total_size_in_bytes
()
*
param
.
nr_threads
;
return
WorkspaceBundle
{
nullptr
,
{
0
,
0
,
all_threads_bytes
}};
}
SmallVector
<
ConvBiasImpl
::
NCBKern
>
get_kern
(
const
ConvBiasImpl
::
NCBKernSizeParam
&
param
,
WorkspaceBundle
&
whole_bundle
,
WorkspaceBundle
&
matmul_bundle
,
WorkspaceBundle
&
thread_bundle
,
Conv1x1StrategyBase
*
conv1x1_strategy
,
const
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
size_t
oc_block_size
)
{
auto
kern_compt
=
[
whole_bundle
,
matmul_bundle
,
thread_bundle
,
matmul_algo
,
param
,
oc_block_size
,
conv1x1_strategy
](
const
ConvBiasImpl
::
NCBKernParam
&
ncb_param
,
const
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
)
mutable
{
conv1x1_strategy
->
exec
(
whole_bundle
,
matmul_bundle
,
thread_bundle
,
oc_block_size
,
matmul_algo
,
param
,
ncb_param
,
std
::
move
(
ncb_index
));
};
size_t
GROUP
=
param
.
filter_meta
.
group
;
size_t
BATCH
=
param
.
n
;
size_t
OC
=
param
.
filter_meta
.
ocpg
;
size_t
oc_blocks_per_group
=
div_ceil
(
OC
,
oc_block_size
);
SmallVector
<
ConvBiasImpl
::
NCBKern
>
ret_kern
;
ret_kern
.
push_back
({
kern_compt
,
{
BATCH
,
GROUP
,
oc_blocks_per_group
}});
return
ret_kern
;
}
SmallVector
<
ConvBiasImpl
::
NCBKern
>
get_kern_preprocess
(
const
ConvBiasImpl
::
NCBKernSizeParam
&
,
WorkspaceBundle
&
,
WorkspaceBundle
&
,
Conv1x1StrategyBase
*
,
const
MatrixMulImpl
::
AlgoBase
*
,
size_t
)
{
return
{};
}
};
}
// namespace conv1x1
}
// namespace fallback
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h
浏览文件 @
4a227083
...
...
@@ -59,7 +59,8 @@ public:
template
<
typename
src_ctype
,
typename
bias_ctype
,
typename
dst_ctype
,
typename
op_ctype
,
typename
op_dtype
,
megdnn
::
PostprocessMode
postprocess_mode
,
MatrixMulImpl
::
AlgoBase
::
PackMode
pack_mode
>
megdnn
::
PostprocessMode
postprocess_mode
,
MatrixMulImpl
::
AlgoBase
::
PackMode
pack_mode
>
class
Conv1x1Strategy
:
public
Conv1x1StrategyBase
{
public:
explicit
Conv1x1Strategy
(
size_t
pack_size
=
1
)
:
m_pack_size
(
pack_size
)
{}
...
...
@@ -136,32 +137,30 @@ public:
size_t
packb_bytes_per_group
=
matmul_bundle
.
get_size
(
1
);
size_t
GROUP
=
param
.
filter_meta
.
group
;
size_t
BATCH
=
param
.
n
;
size_t
SH
=
param
.
filter_meta
.
stride
[
0
];
size_t
SW
=
param
.
filter_meta
.
stride
[
1
];
size_t
OH
=
param
.
osz
[
0
];
size_t
OW
=
param
.
osz
[
1
];
size_t
OC
=
param
.
filter_meta
.
ocpg
;
size_t
batch
=
ncb_index
.
ndrange_id
[
0
];
MatrixMulImpl
::
KernParam
matmul_kern_param
;
static_cast
<
MatrixMulImpl
::
KernSizeParam
&>
(
matmul_kern_param
)
=
utils
::
get_matmul_kern_param
(
param
,
OH
*
OW
,
OC
);
rep
(
batch
,
BATCH
)
{
rep
(
g
,
GROUP
)
{
if
(
SH
==
2
&&
SW
==
2
)
megdnn_throw
(
"no support for stride = 2"
);
size_t
bytes_offset_of_b_panel
=
batch
*
packb_bytes_per_group
*
GROUP
+
g
*
packb_bytes_per_group
;
src_ctype
*
b_panel
=
reinterpret_cast
<
src_ctype
*>
(
reinterpret_cast
<
int8_t
*>
(
whole_bundle
.
get
(
1
))
+
bytes_offset_of_b_panel
);
matmul_kern_param
.
B_ptr
=
const_cast
<
src_ctype
*>
(
ncb_param
.
src
<
src_ctype
>
(
batch
,
g
));
matmul_algo
->
pack_B
(
matmul_kern_param
,
b_panel
,
0
,
OH
*
OW
);
}
rep
(
g
,
GROUP
)
{
if
(
SH
==
2
&&
SW
==
2
)
megdnn_throw
(
"no support for stride = 2"
);
size_t
bytes_offset_of_b_panel
=
batch
*
packb_bytes_per_group
*
GROUP
+
g
*
packb_bytes_per_group
;
src_ctype
*
b_panel
=
reinterpret_cast
<
src_ctype
*>
(
reinterpret_cast
<
int8_t
*>
(
whole_bundle
.
get
(
1
))
+
bytes_offset_of_b_panel
);
matmul_kern_param
.
B_ptr
=
const_cast
<
src_ctype
*>
(
ncb_param
.
src
<
src_ctype
>
(
batch
,
g
));
matmul_algo
->
pack_B
(
matmul_kern_param
,
b_panel
,
0
,
OH
*
OW
);
}
}
else
{
megdnn_log_error
(
"OnlyPackA mode and NoPack mode has no packB kernel"
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录