Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
8e9fa80c
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
8e9fa80c
编写于
6月 16, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/fallback): add matmul description for im2col
GitOrigin-RevId: 5bde0b60f0b8102cd8bad14457cf123bc7e6dafa
上级
af3de7e1
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
190 addition
and
132 deletion
+190
-132
dnn/src/aarch64/matrix_mul/algos.h
dnn/src/aarch64/matrix_mul/algos.h
+4
-0
dnn/src/arm_common/matrix_mul/algos.h
dnn/src/arm_common/matrix_mul/algos.h
+6
-0
dnn/src/armv7/matrix_mul/algos.h
dnn/src/armv7/matrix_mul/algos.h
+3
-0
dnn/src/fallback/conv_bias/im2col/algos.cpp
dnn/src/fallback/conv_bias/im2col/algos.cpp
+59
-70
dnn/src/fallback/conv_bias/im2col/strategy_base.h
dnn/src/fallback/conv_bias/im2col/strategy_base.h
+50
-35
dnn/src/fallback/conv_bias/im2col/strategy_default.cpp
dnn/src/fallback/conv_bias/im2col/strategy_default.cpp
+14
-10
dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp
...src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp
+1
-1
dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44.cpp
dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44.cpp
+1
-1
dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot.cpp
...rc/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot.cpp
+1
-1
dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp
...allback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp
+1
-1
dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp
dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp
+9
-4
dnn/src/fallback/conv_bias/im2col/strategy_onlypacka.cpp
dnn/src/fallback/conv_bias/im2col/strategy_onlypacka.cpp
+9
-4
dnn/src/fallback/matrix_mul/algos.h
dnn/src/fallback/matrix_mul/algos.h
+1
-0
dnn/src/fallback/matrix_mul/gemm_common.h
dnn/src/fallback/matrix_mul/gemm_common.h
+18
-3
dnn/src/fallback/matrix_mul/opr_impl.h
dnn/src/fallback/matrix_mul/opr_impl.h
+7
-1
dnn/src/x86/matrix_mul/algos.h
dnn/src/x86/matrix_mul/algos.h
+6
-1
未找到文件。
dnn/src/aarch64/matrix_mul/algos.h
浏览文件 @
8e9fa80c
...
@@ -60,6 +60,7 @@ public:
...
@@ -60,6 +60,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
4
,
16
,
4
,
4
)
};
};
class
MatrixMulImpl
::
AlgoF32Gemv
final
class
MatrixMulImpl
::
AlgoF32Gemv
final
...
@@ -86,6 +87,7 @@ public:
...
@@ -86,6 +87,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
8
,
8
,
2
)
};
};
#endif
#endif
...
@@ -207,6 +209,7 @@ public:
...
@@ -207,6 +209,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
8
,
8
,
2
)
};
};
#if __ARM_FEATURE_DOTPROD
#if __ARM_FEATURE_DOTPROD
...
@@ -234,6 +237,7 @@ public:
...
@@ -234,6 +237,7 @@ public:
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
)
};
};
#else
#else
...
...
dnn/src/arm_common/matrix_mul/algos.h
浏览文件 @
8e9fa80c
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#pragma once
#pragma once
#include "src/arm_common/matrix_mul/opr_impl.h"
#include "src/arm_common/matrix_mul/opr_impl.h"
#include "src/fallback/matrix_mul/gemm_common.h"
namespace
megdnn
{
namespace
megdnn
{
namespace
arm_common
{
namespace
arm_common
{
...
@@ -25,6 +26,7 @@ public:
...
@@ -25,6 +26,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
4
)
};
};
class
MatrixMulImpl
::
AlgoInt8x8x32Gemv
:
public
AlgoBase
{
class
MatrixMulImpl
::
AlgoInt8x8x32Gemv
:
public
AlgoBase
{
...
@@ -38,6 +40,7 @@ public:
...
@@ -38,6 +40,7 @@ public:
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
)
};
};
class
MatrixMulImpl
::
AlgoF32Gemv
:
public
AlgoBase
{
class
MatrixMulImpl
::
AlgoF32Gemv
:
public
AlgoBase
{
...
@@ -54,6 +57,7 @@ public:
...
@@ -54,6 +57,7 @@ public:
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
4
)
};
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
...
@@ -68,6 +72,7 @@ public:
...
@@ -68,6 +72,7 @@ public:
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
)
};
};
#endif
#endif
...
@@ -82,6 +87,7 @@ public:
...
@@ -82,6 +87,7 @@ public:
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
4
)
};
};
...
...
dnn/src/armv7/matrix_mul/algos.h
浏览文件 @
8e9fa80c
...
@@ -49,6 +49,7 @@ public:
...
@@ -49,6 +49,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
4
,
8
,
4
,
4
)
};
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
...
@@ -71,6 +72,7 @@ public:
...
@@ -71,6 +72,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
4
,
8
,
8
,
2
)
};
};
#endif
#endif
#if __ARM_FEATURE_DOTPROD
#if __ARM_FEATURE_DOTPROD
...
@@ -190,6 +192,7 @@ public:
...
@@ -190,6 +192,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
4
,
8
,
8
,
2
)
};
};
class
MatrixMulImpl
::
AlgoInt8x8x32MK4_4x2x16
final
:
public
AlgoBase
{
class
MatrixMulImpl
::
AlgoInt8x8x32MK4_4x2x16
final
:
public
AlgoBase
{
...
...
dnn/src/fallback/conv_bias/im2col/algos.cpp
浏览文件 @
8e9fa80c
...
@@ -47,14 +47,17 @@ static void copy_padding_kern(WorkspaceBundle bundle,
...
@@ -47,14 +47,17 @@ static void copy_padding_kern(WorkspaceBundle bundle,
}
}
//! packA_kern
//! packA_kern
static
void
packA_kern
(
WorkspaceBundle
bundle
,
static
void
packA_kern
(
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
WorkspaceBundle
bundle
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmulparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmulparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
StrategyBase
*
im2colstrategy
,
size_t
pack_oc_size
)
{
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
StrategyBase
*
im2colstrategy
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
&
matmul_desc
,
size_t
pack_oc_size
)
{
im2colstrategy
->
packA_kern
(
bundle
,
param
,
matmulparam
,
matmul_algo
,
im2colstrategy
->
packA_kern
(
bundle
,
param
,
matmulparam
,
matmul_algo
,
ncb_index
,
pack_oc_size
);
ncb_index
,
matmul_desc
,
pack_oc_size
);
}
}
/*!
/*!
...
@@ -72,7 +75,8 @@ public:
...
@@ -72,7 +75,8 @@ public:
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
const
ConvBiasImpl
::
NCBKernParam
&
param
,
const
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmul_kernsize_param
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmul_kernsize_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
&
matmul_desc
,
StrategyParam
strategyparam
,
StrategyParam
strategyparam
,
fallback
::
ConvBiasImpl
::
NCBKernIndex
ncb_index
,
fallback
::
ConvBiasImpl
::
NCBKernIndex
ncb_index
,
size_t
ohw_tile_size
,
StrategyBase
*
im2colstrategy
)
{
size_t
ohw_tile_size
,
StrategyBase
*
im2colstrategy
)
{
...
@@ -111,7 +115,8 @@ public:
...
@@ -111,7 +115,8 @@ public:
//! 2.packb and matmul compute
//! 2.packb and matmul compute
im2colstrategy
->
exec_matmul
(
param
,
strategyparam
,
bundle
,
bundle_thread
,
im2colstrategy
->
exec_matmul
(
param
,
strategyparam
,
bundle
,
bundle_thread
,
matmul_param
,
matmul_algo
,
ncb_index
);
matmul_param
,
matmul_algo
,
ncb_index
,
matmul_desc
);
//! 3.postprocess and copy dst if need
//! 3.postprocess and copy dst if need
im2colstrategy
->
exec_postprocess
(
param
,
strategyparam
,
bundle_thread
);
im2colstrategy
->
exec_postprocess
(
param
,
strategyparam
,
bundle_thread
);
...
@@ -151,7 +156,8 @@ public:
...
@@ -151,7 +156,8 @@ public:
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
const
ConvBiasImpl
::
NCBKernParam
&
param
,
const
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmul_kernsize_param
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmul_kernsize_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
&
matmul_desc
,
StrategyParam
strategyparam
,
StrategyParam
strategyparam
,
fallback
::
ConvBiasImpl
::
NCBKernIndex
ncb_index
,
fallback
::
ConvBiasImpl
::
NCBKernIndex
ncb_index
,
size_t
ohw_tile_size
,
StrategyBase
*
im2colstrategy
)
{
size_t
ohw_tile_size
,
StrategyBase
*
im2colstrategy
)
{
...
@@ -191,7 +197,8 @@ public:
...
@@ -191,7 +197,8 @@ public:
//! 2.packb and matmul compute
//! 2.packb and matmul compute
im2colstrategy
->
exec_matmul
(
param
,
strategyparam
,
bundle
,
bundle_thread
,
im2colstrategy
->
exec_matmul
(
param
,
strategyparam
,
bundle
,
bundle_thread
,
matmul_param
,
matmul_algo
,
ncb_index
);
matmul_param
,
matmul_algo
,
ncb_index
,
matmul_desc
);
//! 3.postprocess and copy dst if need
//! 3.postprocess and copy dst if need
im2colstrategy
->
exec_postprocess
(
param
,
strategyparam
,
bundle_thread
);
im2colstrategy
->
exec_postprocess
(
param
,
strategyparam
,
bundle_thread
);
...
@@ -232,7 +239,8 @@ public:
...
@@ -232,7 +239,8 @@ public:
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
const
ConvBiasImpl
::
NCBKernParam
&
param
,
const
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmul_kernsize_param
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmul_kernsize_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
&
matmul_desc
,
StrategyParam
strategyparam
,
StrategyParam
strategyparam
,
fallback
::
ConvBiasImpl
::
NCBKernIndex
ncb_index
,
fallback
::
ConvBiasImpl
::
NCBKernIndex
ncb_index
,
size_t
ohw_tile_size
,
StrategyBase
*
im2colstrategy
)
{
size_t
ohw_tile_size
,
StrategyBase
*
im2colstrategy
)
{
...
@@ -272,7 +280,8 @@ public:
...
@@ -272,7 +280,8 @@ public:
//! 2.packb and matmul compute
//! 2.packb and matmul compute
im2colstrategy
->
exec_matmul
(
param
,
strategyparam
,
bundle
,
bundle_thread
,
im2colstrategy
->
exec_matmul
(
param
,
strategyparam
,
bundle
,
bundle_thread
,
matmul_param
,
matmul_algo
,
ncb_index
);
matmul_param
,
matmul_algo
,
ncb_index
,
matmul_desc
);
//! 3.postprocess and copy dst if need
//! 3.postprocess and copy dst if need
im2colstrategy
->
exec_postprocess
(
param
,
strategyparam
,
bundle_thread
);
im2colstrategy
->
exec_postprocess
(
param
,
strategyparam
,
bundle_thread
);
...
@@ -401,13 +410,15 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
...
@@ -401,13 +410,15 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
size_t
padding
=
0
,
packa_size
=
0
,
packa_group_size
=
0
;
size_t
padding
=
0
,
packa_size
=
0
,
packa_group_size
=
0
;
size_t
nr_threads
=
param
.
nr_threads
;
size_t
nr_threads
=
param
.
nr_threads
;
size_t
GROUP
=
param
.
filter_meta
.
group
;
size_t
GROUP
=
param
.
filter_meta
.
group
;
bool
need_pack
=
m_matmul_algo
->
packmode
()
==
Pack_Mode
::
DEFAULT
;
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
mdesc
=
bool
only_packA
=
m_matmul_algo
->
packmode
()
==
Pack_Mode
::
ONLY_PACKA
;
m_matmul_algo
->
matmul_description
();
bool
need_pack
=
mdesc
.
packmode
==
Pack_Mode
::
DEFAULT
;
bool
only_packA
=
mdesc
.
packmode
==
Pack_Mode
::
ONLY_PACKA
;
size_t
oc_tile_size
=
0
,
ohw_tile_size
=
0
;
size_t
oc_tile_size
=
0
,
ohw_tile_size
=
0
;
choice_ohw_oc_block
(
param
,
oc_tile_size
,
ohw_tile_size
,
mdesc
.
innerblocksize
.
m
,
mdesc
.
innerblocksize
.
n
,
mdesc
.
packmode
);
if
(
need_pack
||
only_packA
)
{
if
(
need_pack
||
only_packA
)
{
auto
inner_block
=
m_matmul_algo
->
get_inner_block_size
();
choice_ohw_oc_block
(
param
,
oc_tile_size
,
ohw_tile_size
,
inner_block
.
m
,
inner_block
.
n
,
m_matmul_algo
->
packmode
());
auto
im2col_kern_param
=
get_matmul_kern_param
(
auto
im2col_kern_param
=
get_matmul_kern_param
(
param
,
ohw_tile_size
,
only_packA
?
oc_tile_size
:
OC
);
param
,
ohw_tile_size
,
only_packA
?
oc_tile_size
:
OC
);
size_t
oc_parallel_times
=
div_ceil
<
size_t
>
(
OC
,
oc_tile_size
);
size_t
oc_parallel_times
=
div_ceil
<
size_t
>
(
OC
,
oc_tile_size
);
...
@@ -415,11 +426,6 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
...
@@ -415,11 +426,6 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
packa_group_size
=
only_packA
?
oc_parallel_times
*
wb
.
get_size
(
0
)
packa_group_size
=
only_packA
?
oc_parallel_times
*
wb
.
get_size
(
0
)
:
wb
.
get_size
(
0
);
:
wb
.
get_size
(
0
);
}
else
{
//! not support pack,not need pack
}
else
{
//! not support pack,not need pack
size_t
nopack_default_blockm
=
8
;
size_t
nopack_default_blockn
=
16
;
choice_ohw_oc_block
(
param
,
oc_tile_size
,
ohw_tile_size
,
nopack_default_blockm
,
nopack_default_blockn
,
m_matmul_algo
->
packmode
());
packa_group_size
=
0
;
packa_group_size
=
0
;
}
}
...
@@ -481,23 +487,18 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
...
@@ -481,23 +487,18 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
WorkspaceBundle
bundle
=
get_bundle
(
param
);
WorkspaceBundle
bundle
=
get_bundle
(
param
);
WorkspaceBundle
bundle_thread
=
{
nullptr
,
{}};
WorkspaceBundle
bundle_thread
=
{
nullptr
,
{}};
bool
need_padding
=
(
PH
!=
0
||
PW
!=
0
);
bool
need_padding
=
(
PH
!=
0
||
PW
!=
0
);
Pack_Mode
packmode
=
m_matmul_algo
->
packmode
();
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
mdesc
=
m_matmul_algo
->
matmul_description
();
Pack_Mode
packmode
=
mdesc
.
packmode
;
bool
default_pack
=
packmode
==
Pack_Mode
::
DEFAULT
;
bool
default_pack
=
packmode
==
Pack_Mode
::
DEFAULT
;
bool
no_pack
=
packmode
==
Pack_Mode
::
NO_PACK
;
bool
no_pack
=
packmode
==
Pack_Mode
::
NO_PACK
;
bool
only_packA
=
packmode
==
Pack_Mode
::
ONLY_PACKA
;
bool
only_packA
=
packmode
==
Pack_Mode
::
ONLY_PACKA
;
if
(
default_pack
||
only_packA
)
{
choice_ohw_oc_block
(
param
,
oc_tile_size
,
ohw_tile_size
,
auto
inner_block
=
m_matmul_algo
->
get_inner_block_size
();
mdesc
.
innerblocksize
.
m
,
mdesc
.
innerblocksize
.
n
,
choice_ohw_oc_block
(
param
,
oc_tile_size
,
ohw_tile_size
,
mdesc
.
packmode
);
inner_block
.
m
,
inner_block
.
n
,
m_matmul_algo
->
packmode
());
}
else
{
//! nopack_mode
size_t
nopack_default_blockm
=
8
;
size_t
nopack_default_blockn
=
16
;
choice_ohw_oc_block
(
param
,
oc_tile_size
,
ohw_tile_size
,
nopack_default_blockm
,
nopack_default_blockn
,
m_matmul_algo
->
packmode
());
}
size_t
ohw_parallel_times
=
div_ceil
(
ohw
,
ohw_tile_size
);
size_t
ohw_parallel_times
=
div_ceil
(
ohw
,
ohw_tile_size
);
size_t
oc_parallel_times
=
div_ceil
<
size_t
>
(
OC
,
oc_tile_size
);
size_t
oc_parallel_times
=
div_ceil
<
size_t
>
(
OC
,
oc_tile_size
);
...
@@ -507,18 +508,17 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
...
@@ -507,18 +508,17 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
if
(
only_packA
)
{
if
(
only_packA
)
{
packa_parallel_times
=
div_ceil
<
size_t
>
(
OC
,
oc_tile_size
);
packa_parallel_times
=
div_ceil
<
size_t
>
(
OC
,
oc_tile_size
);
}
else
if
(
default_pack
)
{
}
else
if
(
default_pack
)
{
packa_parallel_times
=
div_ceil
<
size_t
>
(
packa_parallel_times
=
div_ceil
<
size_t
>
(
OC
,
mdesc
.
innerblocksize
.
m
);
OC
,
m_matmul_algo
->
get_inner_block_size
().
m
);
}
}
auto
matmul_param
=
get_matmul_kern_param
(
auto
matmul_param
=
get_matmul_kern_param
(
param
,
ohw_tile_size
,
only_packA
?
oc_tile_size
:
OC
);
param
,
ohw_tile_size
,
only_packA
?
oc_tile_size
:
OC
);
if
(
m
_matmul_algo
->
packmode
()
==
Pack_Mode
::
DEFAULT
)
{
if
(
m
desc
.
packmode
==
Pack_Mode
::
DEFAULT
)
{
Im2colKerns
<
Pack_Mode
::
DEFAULT
>
defaultkern
;
Im2colKerns
<
Pack_Mode
::
DEFAULT
>
defaultkern
;
bundle_thread
=
defaultkern
.
get_thread_bundle
(
bundle_thread
=
defaultkern
.
get_thread_bundle
(
param
,
matmul_param
,
m_matmul_algo
,
ohw_tile_size
,
param
,
matmul_param
,
m_matmul_algo
,
ohw_tile_size
,
oc_tile_size
);
oc_tile_size
);
}
else
if
(
m
_matmul_algo
->
packmode
()
==
Pack_Mode
::
ONLY_PACKA
)
{
}
else
if
(
m
desc
.
packmode
==
Pack_Mode
::
ONLY_PACKA
)
{
Im2colKerns
<
Pack_Mode
::
ONLY_PACKA
>
onlypackakern
;
Im2colKerns
<
Pack_Mode
::
ONLY_PACKA
>
onlypackakern
;
bundle_thread
=
onlypackakern
.
get_thread_bundle
(
bundle_thread
=
onlypackakern
.
get_thread_bundle
(
param
,
matmul_param
,
m_matmul_algo
,
ohw_tile_size
,
param
,
matmul_param
,
m_matmul_algo
,
ohw_tile_size
,
...
@@ -559,24 +559,24 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
...
@@ -559,24 +559,24 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
auto
kern_packA
=
[
bundle
,
matmul_algo
=
m_matmul_algo
,
auto
kern_packA
=
[
bundle
,
matmul_algo
=
m_matmul_algo
,
matmul_param
,
im2colstrategy
,
matmul_param
,
im2colstrategy
,
pack_oc_size
=
pack_oc_size
](
pack_oc_size
=
pack_oc_size
,
const
NCBKernParam
&
param
,
mdesc
=
mdesc
](
const
NCBKernParam
&
param
,
const
NCBKernIndex
&
ncb_index
)
{
const
NCBKernIndex
&
ncb_index
)
{
packA_kern
(
bundle
,
param
,
matmul_param
,
matmul_algo
,
ncb_index
,
packA_kern
(
bundle
,
param
,
matmul_param
,
matmul_algo
,
ncb_index
,
im2colstrategy
,
pack_oc_size
);
im2colstrategy
,
mdesc
,
pack_oc_size
);
};
};
if
(
default_pack
)
{
if
(
default_pack
)
{
auto
kern_compute_default
=
auto
kern_compute_default
=
[
bundle
,
bundle_thread
,
matmul_param
,
[
bundle
,
bundle_thread
,
matmul_param
,
matmul_algo
=
m_matmul_algo
,
matmul_algo
=
m_matmul_algo
,
ohw_tile_size
=
ohw_tile_size
,
ohw_tile_size
=
ohw_tile_size
,
strategyparam
=
strategyparam
,
strategyparam
=
strategyparam
,
matmul_desc
=
mdesc
,
im2colstrategy
](
const
NCBKernParam
&
param
,
im2colstrategy
](
const
NCBKernParam
&
param
,
const
NCBKernIndex
&
ncb_index
)
{
const
NCBKernIndex
&
ncb_index
)
{
Im2colKerns
<
Pack_Mode
::
DEFAULT
>::
kerns
(
Im2colKerns
<
Pack_Mode
::
DEFAULT
>::
kerns
(
bundle
,
bundle_thread
,
param
,
matmul_param
,
bundle
,
bundle_thread
,
param
,
matmul_param
,
matmul_algo
,
strategyparam
,
ncb_index
,
matmul_algo
,
matmul_desc
,
strategyparam
,
ohw_tile_size
,
im2colstrategy
);
ncb_index
,
ohw_tile_size
,
im2colstrategy
);
};
};
ret_kern
.
push_back
({
kern_packA
,
{
GROUP
,
packa_parallel_times
}});
ret_kern
.
push_back
({
kern_packA
,
{
GROUP
,
packa_parallel_times
}});
...
@@ -592,13 +592,13 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
...
@@ -592,13 +592,13 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
[
bundle
,
bundle_thread
,
matmul_param
,
[
bundle
,
bundle_thread
,
matmul_param
,
matmul_algo
=
m_matmul_algo
,
matmul_algo
=
m_matmul_algo
,
strategyparam
=
strategyparam
,
strategyparam
=
strategyparam
,
ohw_tile_size
=
ohw_tile_size
,
ohw_tile_size
=
ohw_tile_size
,
matmul_desc
=
mdesc
,
im2colstrategy
](
const
NCBKernParam
&
param
,
im2colstrategy
](
const
NCBKernParam
&
param
,
const
NCBKernIndex
&
ncb_index
)
{
const
NCBKernIndex
&
ncb_index
)
{
Im2colKerns
<
Pack_Mode
::
ONLY_PACKA
>::
kerns
(
Im2colKerns
<
Pack_Mode
::
ONLY_PACKA
>::
kerns
(
bundle
,
bundle_thread
,
param
,
matmul_param
,
bundle
,
bundle_thread
,
param
,
matmul_param
,
matmul_algo
,
strategyparam
,
ncb_index
,
matmul_algo
,
matmul_desc
,
strategyparam
,
ohw_tile_size
,
im2colstrategy
);
ncb_index
,
ohw_tile_size
,
im2colstrategy
);
};
};
ret_kern
.
push_back
({
kern_packA
,
{
GROUP
,
packa_parallel_times
}});
ret_kern
.
push_back
({
kern_packA
,
{
GROUP
,
packa_parallel_times
}});
if
(
need_padding
)
{
if
(
need_padding
)
{
...
@@ -612,13 +612,13 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
...
@@ -612,13 +612,13 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
[
bundle
,
bundle_thread
,
matmul_param
,
[
bundle
,
bundle_thread
,
matmul_param
,
matmul_algo
=
m_matmul_algo
,
matmul_algo
=
m_matmul_algo
,
strategyparam
=
strategyparam
,
strategyparam
=
strategyparam
,
ohw_tile_size
=
ohw_tile_size
,
ohw_tile_size
=
ohw_tile_size
,
matmul_desc
=
mdesc
,
im2colstrategy
](
const
NCBKernParam
&
param
,
im2colstrategy
](
const
NCBKernParam
&
param
,
const
NCBKernIndex
&
ncb_index
)
{
const
NCBKernIndex
&
ncb_index
)
{
Im2colKerns
<
Pack_Mode
::
NO_PACK
>::
kerns
(
Im2colKerns
<
Pack_Mode
::
NO_PACK
>::
kerns
(
bundle
,
bundle_thread
,
param
,
matmul_param
,
bundle
,
bundle_thread
,
param
,
matmul_param
,
matmul_algo
,
strategyparam
,
ncb_index
,
matmul_algo
,
matmul_desc
,
strategyparam
,
ohw_tile_size
,
im2colstrategy
);
ncb_index
,
ohw_tile_size
,
im2colstrategy
);
};
};
if
(
need_padding
)
{
if
(
need_padding
)
{
...
@@ -668,10 +668,12 @@ bool ConvBiasImpl::AlgoIm2col::usable(
...
@@ -668,10 +668,12 @@ bool ConvBiasImpl::AlgoIm2col::usable(
return
false
;
return
false
;
}
}
}
}
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
mdesc
=
m_matmul_algo
->
matmul_description
();
if
(
opr
->
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW44
||
if
(
opr
->
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW44
||
opr
->
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW44_DOT
)
{
opr
->
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW44_DOT
)
{
//! current NCHW44 im2col only support DEFAULT mode matmul
//! current NCHW44 im2col only support DEFAULT mode matmul
if
(
m
_matmul_algo
->
packmode
()
!=
Pack_Mode
::
DEFAULT
)
{
if
(
m
desc
.
packmode
!=
Pack_Mode
::
DEFAULT
)
{
return
false
;
return
false
;
//! nchw44 hybird mode and channel wise is not support
//! nchw44 hybird mode and channel wise is not support
}
else
if
(
param
.
filter_meta
.
icpg
<
4
_z
||
}
else
if
(
param
.
filter_meta
.
icpg
<
4
_z
||
...
@@ -682,22 +684,9 @@ bool ConvBiasImpl::AlgoIm2col::usable(
...
@@ -682,22 +684,9 @@ bool ConvBiasImpl::AlgoIm2col::usable(
}
}
size_t
oc_tile_size
=
0
,
ohw_tile_size
=
0
;
size_t
oc_tile_size
=
0
,
ohw_tile_size
=
0
;
Pack_Mode
packmode
=
m_matmul_algo
->
packmode
();
choice_ohw_oc_block
(
param
,
oc_tile_size
,
ohw_tile_size
,
bool
default_pack
=
packmode
==
Pack_Mode
::
DEFAULT
;
mdesc
.
innerblocksize
.
m
,
mdesc
.
innerblocksize
.
n
,
bool
only_packA
=
packmode
==
Pack_Mode
::
ONLY_PACKA
;
m_matmul_algo
->
packmode
());
if
(
default_pack
||
only_packA
)
{
auto
inner_block
=
m_matmul_algo
->
get_inner_block_size
();
choice_ohw_oc_block
(
param
,
oc_tile_size
,
ohw_tile_size
,
inner_block
.
m
,
inner_block
.
n
,
m_matmul_algo
->
packmode
());
}
else
{
//! not support pack,not need pack
size_t
nopack_default_blockm
=
8
;
size_t
nopack_default_blockn
=
16
;
choice_ohw_oc_block
(
param
,
oc_tile_size
,
ohw_tile_size
,
nopack_default_blockm
,
nopack_default_blockn
,
m_matmul_algo
->
packmode
());
}
fallback
::
MatrixMulImpl
::
KernSizeParam
matmul_param
=
fallback
::
MatrixMulImpl
::
KernSizeParam
matmul_param
=
get_matmul_kern_param
(
param
,
ohw_tile_size
,
oc_tile_size
);
get_matmul_kern_param
(
param
,
ohw_tile_size
,
oc_tile_size
);
bool
matmulusable
=
m_matmul_algo
->
usable
(
matmul_param
);
bool
matmulusable
=
m_matmul_algo
->
usable
(
matmul_param
);
...
...
dnn/src/fallback/conv_bias/im2col/strategy_base.h
浏览文件 @
8e9fa80c
...
@@ -58,8 +58,9 @@ public:
...
@@ -58,8 +58,9 @@ public:
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmulparam
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmulparam
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
&
matmul_desec
,
size_t
pack_size
)
=
0
;
size_t
pack_size
)
=
0
;
virtual
void
exec_im2col
(
virtual
void
exec_im2col
(
...
@@ -67,15 +68,17 @@ public:
...
@@ -67,15 +68,17 @@ public:
const
StrategyParam
&
sparam
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
=
0
;
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
=
0
;
virtual
void
exec_matmul
(
virtual
void
exec_matmul
(
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
StrategyParam
&
sparam
,
WorkspaceBundle
bundle
,
const
StrategyParam
&
sparam
,
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
WorkspaceBundle
bundle_thread
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
)
=
0
;
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
&
matmul_desc
)
=
0
;
virtual
void
exec_postprocess
(
virtual
void
exec_postprocess
(
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
...
@@ -284,26 +287,30 @@ public:
...
@@ -284,26 +287,30 @@ public:
Strategy
()
=
default
;
Strategy
()
=
default
;
virtual
void
packA_kern
(
WorkspaceBundle
bundle
,
virtual
void
packA_kern
(
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
WorkspaceBundle
bundle
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmulparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmulparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
size_t
pack_size
)
override
;
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
&
matmul_desc
,
size_t
pack_size
)
override
;
virtual
void
exec_im2col
(
virtual
void
exec_im2col
(
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
const
StrategyParam
&
sparam
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
override
;
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
override
;
void
exec_matmul
(
void
exec_matmul
(
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
StrategyParam
&
sparam
,
WorkspaceBundle
bundle
,
const
StrategyParam
&
sparam
,
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
WorkspaceBundle
bundle_thread
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
)
override
;
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
&
matmul_desc
)
override
;
void
exec_postprocess
(
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
void
exec_postprocess
(
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
StrategyParam
&
sparam
,
const
StrategyParam
&
sparam
,
WorkspaceBundle
bundle_thread
)
override
{
WorkspaceBundle
bundle_thread
)
override
{
...
@@ -338,7 +345,7 @@ public:
...
@@ -338,7 +345,7 @@ public:
const
StrategyParam
&
sparam
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
override
;
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
override
;
};
};
template
<
typename
src_ctype
,
typename
bias_ctype
,
typename
dst_ctype
,
template
<
typename
src_ctype
,
typename
bias_ctype
,
typename
dst_ctype
,
...
@@ -359,20 +366,24 @@ public:
...
@@ -359,20 +366,24 @@ public:
Strategy
()
=
default
;
Strategy
()
=
default
;
void
packA_kern
(
WorkspaceBundle
bundle
,
void
packA_kern
(
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
WorkspaceBundle
bundle
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmulparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmulparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
size_t
pack_size
)
override
;
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
&
MDsec
,
size_t
pack_size
)
override
;
void
exec_matmul
(
void
exec_matmul
(
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
StrategyParam
&
sparam
,
WorkspaceBundle
bundle
,
const
StrategyParam
&
sparam
,
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
WorkspaceBundle
bundle_thread
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
)
override
;
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
&
matmul_desc
)
override
;
void
*
get_matmul_dst_ptr
(
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
void
*
get_matmul_dst_ptr
(
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
WorkspaceBundle
&
bundle_thread
,
const
WorkspaceBundle
&
bundle_thread
,
...
@@ -382,7 +393,7 @@ public:
...
@@ -382,7 +393,7 @@ public:
const
StrategyParam
&
sparam
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
override
;
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
override
;
void
exec_postprocess
(
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
void
exec_postprocess
(
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
StrategyParam
&
sparam
,
const
StrategyParam
&
sparam
,
WorkspaceBundle
bundle_thread
)
override
{
WorkspaceBundle
bundle_thread
)
override
{
...
@@ -411,26 +422,30 @@ public:
...
@@ -411,26 +422,30 @@ public:
Strategy
()
=
default
;
Strategy
()
=
default
;
void
packA_kern
(
WorkspaceBundle
bundle
,
void
packA_kern
(
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
WorkspaceBundle
bundle
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmulparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmulparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
size_t
pack_size
)
override
;
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
&
MDsec
,
size_t
pack_size
)
override
;
void
exec_im2col
(
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
void
exec_im2col
(
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
const
StrategyParam
&
sparam
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
override
;
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
override
;
void
exec_matmul
(
void
exec_matmul
(
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
StrategyParam
&
sparam
,
WorkspaceBundle
bundle
,
const
StrategyParam
&
sparam
,
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
WorkspaceBundle
bundle_thread
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
)
override
;
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
&
matmul_desc
)
override
;
void
*
get_matmul_dst_ptr
(
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
void
*
get_matmul_dst_ptr
(
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
WorkspaceBundle
&
bundle_thread
,
const
WorkspaceBundle
&
bundle_thread
,
...
@@ -465,7 +480,7 @@ public:
...
@@ -465,7 +480,7 @@ public:
const
StrategyParam
&
sparam
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
override
;
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
override
;
};
};
template
<
typename
op_ctype
,
typename
op_dtype
,
template
<
typename
op_ctype
,
typename
op_dtype
,
...
@@ -487,7 +502,7 @@ public:
...
@@ -487,7 +502,7 @@ public:
const
StrategyParam
&
sparam
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
override
;
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
override
;
};
};
...
@@ -510,7 +525,7 @@ public:
...
@@ -510,7 +525,7 @@ public:
const
StrategyParam
&
sparam
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
override
;
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
override
;
};
};
#endif
#endif
...
...
dnn/src/fallback/conv_bias/im2col/strategy_default.cpp
浏览文件 @
8e9fa80c
...
@@ -21,8 +21,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
...
@@ -21,8 +21,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
packA_kern
(
WorkspaceBundle
bundle
,
packA_kern
(
WorkspaceBundle
bundle
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmulparam
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmulparam
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
&
matmul_desc
,
size_t
)
{
size_t
)
{
bundle
.
set
(
param
.
workspace_ptr
);
bundle
.
set
(
param
.
workspace_ptr
);
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
;
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
;
...
@@ -31,16 +33,16 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
...
@@ -31,16 +33,16 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
matmulparam
;
matmulparam
;
size_t
packA_group_size
=
matmul_algo
->
get_bundle
(
matmul_param
).
get_size
(
0
);
size_t
packA_group_size
=
matmul_algo
->
get_bundle
(
matmul_param
).
get_size
(
0
);
size_t
packed_per_oc_block_size
=
size_t
packed_per_oc_block_size
=
round_up
(
matmul_param
.
K
,
matmul_
algo
->
get_inner_block_size
()
.
k
)
*
round_up
(
matmul_param
.
K
,
matmul_
desc
.
innerblocksize
.
k
)
*
matmul_
algo
->
get_inner_block_size
().
m
*
matmul_
desc
.
innerblocksize
.
m
*
matmul_desc
.
packa_type_size
;
matmul_algo
->
get_packA_type_size
();
size_t
a_panel_offset
=
ncb_index
.
ndrange_id
[
1
]
*
packed_per_oc_block_size
;
size_t
a_panel_offset
=
ncb_index
.
ndrange_id
[
1
]
*
packed_per_oc_block_size
;
int8_t
*
a_panel
=
static_cast
<
int8_t
*>
(
bundle
.
get
(
BUNDLE_PACKA_INDEX
))
+
int8_t
*
a_panel
=
static_cast
<
int8_t
*>
(
bundle
.
get
(
BUNDLE_PACKA_INDEX
))
+
group_id
*
packA_group_size
+
a_panel_offset
;
group_id
*
packA_group_size
+
a_panel_offset
;
matmul_param
.
A_ptr
=
matmul_param
.
A_ptr
=
const_cast
<
src_ctype
*>
(
param
.
filter
<
src_ctype
>
(
group_id
));
const_cast
<
src_ctype
*>
(
param
.
filter
<
src_ctype
>
(
group_id
));
matmul_algo
->
pack_A
(
matmul_param
,
a_panel
,
ncb_index
.
ndrange_id
[
1
],
matmul_algo
->
pack_A
(
matmul_param
,
a_panel
,
ncb_index
.
ndrange_id
[
1
],
matmul_
algo
->
get_inner_block_size
()
.
m
);
matmul_
desc
.
innerblocksize
.
m
);
}
}
template
<
typename
src_ctype
,
typename
bias_ctype
,
typename
dst_ctype
,
template
<
typename
src_ctype
,
typename
bias_ctype
,
typename
dst_ctype
,
...
@@ -52,7 +54,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
...
@@ -52,7 +54,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const
StrategyParam
&
sparam
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
{
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
{
size_t
sh
=
param
.
filter_meta
.
stride
[
0
];
size_t
sh
=
param
.
filter_meta
.
stride
[
0
];
size_t
sw
=
param
.
filter_meta
.
stride
[
1
];
size_t
sw
=
param
.
filter_meta
.
stride
[
1
];
size_t
oc
=
param
.
filter_meta
.
ocpg
;
size_t
oc
=
param
.
filter_meta
.
ocpg
;
...
@@ -140,11 +142,13 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
...
@@ -140,11 +142,13 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const
StrategyParam
&
sparam
,
WorkspaceBundle
bundle
,
const
StrategyParam
&
sparam
,
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
WorkspaceBundle
bundle_thread
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
)
{
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
&
matmul_desc
)
{
size_t
packA_per_oc_block_size
=
size_t
packA_per_oc_block_size
=
round_up
(
matmul_param
.
K
,
matmul_
algo
->
get_inner_block_size
()
.
k
)
*
round_up
(
matmul_param
.
K
,
matmul_
desc
.
innerblocksize
.
k
)
*
sparam
.
oc_tile_size
*
matmul_
algo
->
get_packA_type_size
()
;
sparam
.
oc_tile_size
*
matmul_
desc
.
packa_type_size
;
size_t
packA_group_size
=
matmul_algo
->
get_bundle
(
matmul_param
).
get_size
(
0
);
size_t
packA_group_size
=
matmul_algo
->
get_bundle
(
matmul_param
).
get_size
(
0
);
size_t
a_panel_offset
=
ncb_index
.
ndrange_id
[
1
]
*
packA_group_size
+
size_t
a_panel_offset
=
ncb_index
.
ndrange_id
[
1
]
*
packA_group_size
+
ncb_index
.
ndrange_id
[
3
]
*
packA_per_oc_block_size
;
ncb_index
.
ndrange_id
[
3
]
*
packA_per_oc_block_size
;
...
...
dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp
浏览文件 @
8e9fa80c
...
@@ -33,7 +33,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
...
@@ -33,7 +33,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const
StrategyParam
&
sparam
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
{
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
{
size_t
sh
=
param
.
filter_meta
.
stride
[
0
];
size_t
sh
=
param
.
filter_meta
.
stride
[
0
];
size_t
sw
=
param
.
filter_meta
.
stride
[
1
];
size_t
sw
=
param
.
filter_meta
.
stride
[
1
];
size_t
oc
=
param
.
filter_meta
.
ocpg
;
size_t
oc
=
param
.
filter_meta
.
ocpg
;
...
...
dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44.cpp
浏览文件 @
8e9fa80c
...
@@ -173,7 +173,7 @@ void StrategyFuse4x4x16Nchw44<op_ctype, op_dtype, postprocess_mode>::
...
@@ -173,7 +173,7 @@ void StrategyFuse4x4x16Nchw44<op_ctype, op_dtype, postprocess_mode>::
const
StrategyParam
&
sparam
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
,
fallback
::
MatrixMulImpl
::
KernParam
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
)
{
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
)
{
size_t
ow
=
param
.
osz
[
1
];
size_t
ow
=
param
.
osz
[
1
];
size_t
ic
=
param
.
filter_meta
.
icpg
;
size_t
ic
=
param
.
filter_meta
.
icpg
;
size_t
ih
=
param
.
isz
[
0
]
+
param
.
filter_meta
.
padding
[
0
]
*
2
;
size_t
ih
=
param
.
isz
[
0
]
+
param
.
filter_meta
.
padding
[
0
]
*
2
;
...
...
dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot.cpp
浏览文件 @
8e9fa80c
...
@@ -176,7 +176,7 @@ void StrategyFuse8x12x4Nchw44Dot<op_ctype, op_dtype, postprocess_mode>::
...
@@ -176,7 +176,7 @@ void StrategyFuse8x12x4Nchw44Dot<op_ctype, op_dtype, postprocess_mode>::
const
StrategyParam
&
sparam
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
/*matmul_param*/
,
fallback
::
MatrixMulImpl
::
KernParam
/*matmul_param*/
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
/*matmul_algo*/
)
{
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
/*matmul_algo*/
)
{
size_t
ow
=
param
.
osz
[
1
];
size_t
ow
=
param
.
osz
[
1
];
size_t
ic
=
param
.
filter_meta
.
icpg
;
size_t
ic
=
param
.
filter_meta
.
icpg
;
size_t
ih
=
param
.
isz
[
0
]
+
param
.
filter_meta
.
padding
[
0
]
*
2
;
size_t
ih
=
param
.
isz
[
0
]
+
param
.
filter_meta
.
padding
[
0
]
*
2
;
...
...
dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp
浏览文件 @
8e9fa80c
...
@@ -168,7 +168,7 @@ void StrategyFuse8x12x1Nchw44K3x3S2<op_ctype, op_dtype, postprocess_mode>::
...
@@ -168,7 +168,7 @@ void StrategyFuse8x12x1Nchw44K3x3S2<op_ctype, op_dtype, postprocess_mode>::
const
StrategyParam
&
sparam
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
/*matmul_param*/
,
fallback
::
MatrixMulImpl
::
KernParam
/*matmul_param*/
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
/*matmul_algo*/
)
{
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
/*matmul_algo*/
)
{
size_t
ow
=
param
.
osz
[
1
];
size_t
ow
=
param
.
osz
[
1
];
size_t
ic
=
param
.
filter_meta
.
icpg
;
size_t
ic
=
param
.
filter_meta
.
icpg
;
size_t
ih
=
param
.
isz
[
0
]
+
param
.
filter_meta
.
padding
[
0
]
*
2
;
size_t
ih
=
param
.
isz
[
0
]
+
param
.
filter_meta
.
padding
[
0
]
*
2
;
...
...
dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp
浏览文件 @
8e9fa80c
...
@@ -22,8 +22,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
...
@@ -22,8 +22,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
packA_kern
(
WorkspaceBundle
bundle
,
packA_kern
(
WorkspaceBundle
bundle
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmulparam
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmulparam
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
&
/*matmul_dsec*/
,
size_t
)
{
size_t
)
{
MEGDNN_MARK_USED_VAR
(
bundle
);
MEGDNN_MARK_USED_VAR
(
bundle
);
MEGDNN_MARK_USED_VAR
(
param
);
MEGDNN_MARK_USED_VAR
(
param
);
...
@@ -62,8 +64,11 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
...
@@ -62,8 +64,11 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const
StrategyParam
&
sparam
,
WorkspaceBundle
bundle
,
const
StrategyParam
&
sparam
,
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
WorkspaceBundle
bundle_thread
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
)
{
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
&
/*matmul_desc*/
)
{
MEGDNN_MARK_USED_VAR
(
bundle
);
MEGDNN_MARK_USED_VAR
(
bundle
);
MEGDNN_MARK_USED_VAR
(
ncb_index
);
MEGDNN_MARK_USED_VAR
(
ncb_index
);
matmul_param
.
workspace_ptr
=
bundle_thread
.
get
(
THREAD_BUNDLE_MATCOMP_INDEX
);
matmul_param
.
workspace_ptr
=
bundle_thread
.
get
(
THREAD_BUNDLE_MATCOMP_INDEX
);
...
@@ -95,7 +100,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
...
@@ -95,7 +100,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const
StrategyParam
&
sparam
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
{
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
{
MEGDNN_MARK_USED_VAR
(
matmul_param
);
MEGDNN_MARK_USED_VAR
(
matmul_param
);
MEGDNN_MARK_USED_VAR
(
matmul_algo
);
MEGDNN_MARK_USED_VAR
(
matmul_algo
);
size_t
sh
=
param
.
filter_meta
.
stride
[
0
];
size_t
sh
=
param
.
filter_meta
.
stride
[
0
];
...
...
dnn/src/fallback/conv_bias/im2col/strategy_onlypacka.cpp
浏览文件 @
8e9fa80c
...
@@ -22,8 +22,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
...
@@ -22,8 +22,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
packA_kern
(
WorkspaceBundle
bundle
,
packA_kern
(
WorkspaceBundle
bundle
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmulparam
,
fallback
::
MatrixMulImpl
::
KernSizeParam
matmulparam
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
&
/*matmul_desc*/
,
size_t
)
{
size_t
)
{
bundle
.
set
(
param
.
workspace_ptr
);
bundle
.
set
(
param
.
workspace_ptr
);
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
;
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
;
...
@@ -57,8 +59,11 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
...
@@ -57,8 +59,11 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const
StrategyParam
&
sparam
,
WorkspaceBundle
bundle
,
const
StrategyParam
&
sparam
,
WorkspaceBundle
bundle
,
WorkspaceBundle
bundle_thread
,
WorkspaceBundle
bundle_thread
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
)
{
const
fallback
::
ConvBiasImpl
::
NCBKernIndex
&
ncb_index
,
const
fallback
::
MatrixMulImpl
::
AlgoBase
::
MatmulDescription
&
/*matmul_desc*/
)
{
size_t
packA_group_size
=
size_t
packA_group_size
=
bundle
.
get_size
(
BUNDLE_PACKA_INDEX
)
/
param
.
filter_meta
.
group
;
bundle
.
get_size
(
BUNDLE_PACKA_INDEX
)
/
param
.
filter_meta
.
group
;
size_t
a_panel_offset
=
ncb_index
.
ndrange_id
[
3
]
*
size_t
a_panel_offset
=
ncb_index
.
ndrange_id
[
3
]
*
...
@@ -95,7 +100,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
...
@@ -95,7 +100,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
const
StrategyParam
&
sparam
,
const
StrategyParam
&
sparam
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
const
fallback
::
ConvBiasImpl
::
NCBKernParam
&
param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
KernParam
matmul_param
,
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
{
const
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
)
{
MEGDNN_MARK_USED_VAR
(
matmul_param
);
MEGDNN_MARK_USED_VAR
(
matmul_param
);
MEGDNN_MARK_USED_VAR
(
matmul_algo
);
MEGDNN_MARK_USED_VAR
(
matmul_algo
);
size_t
sh
=
param
.
filter_meta
.
stride
[
0
];
size_t
sh
=
param
.
filter_meta
.
stride
[
0
];
...
...
dnn/src/fallback/matrix_mul/algos.h
浏览文件 @
8e9fa80c
...
@@ -37,6 +37,7 @@ public:
...
@@ -37,6 +37,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
4
)
};
};
}
// namespace fallback
}
// namespace fallback
...
...
dnn/src/fallback/matrix_mul/gemm_common.h
浏览文件 @
8e9fa80c
...
@@ -352,6 +352,15 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
...
@@ -352,6 +352,15 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
DType dtype_c) \
DType dtype_c) \
: A_dtype(dtype_a), B_dtype(dtype_b), C_dtype(dtype_c) {}
: A_dtype(dtype_a), B_dtype(dtype_b), C_dtype(dtype_c) {}
#define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa_type_size) \
MatmulDescription matmul_description() const override { \
MatmulDescription mdesc; \
mdesc.packmode = packmode(); \
mdesc.innerblocksize = {_m, _n, _k}; \
mdesc.packa_type_size = _packa_type_size; \
return mdesc; \
}
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL() \
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL() \
WorkspaceBundle get_bundle(const KernSizeParam&) const override; \
WorkspaceBundle get_bundle(const KernSizeParam&) const override; \
kern_naked_t get_kern_naked(const KernSizeParam&) const override; \
kern_naked_t get_kern_naked(const KernSizeParam&) const override; \
...
@@ -360,7 +369,7 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
...
@@ -360,7 +369,7 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
void pack_B(const KernParam& kern_param, void* out, size_t x0, \
void pack_B(const KernParam& kern_param, void* out, size_t x0, \
size_t xmax) const override; \
size_t xmax) const override; \
InnerBlockSize get_inner_block_size() const override; \
InnerBlockSize get_inner_block_size() const override; \
size_t get_packA_type_size
() const override;
MatmulDescription matmul_description
() const override;
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( \
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( \
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \
...
@@ -458,8 +467,14 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
...
@@ -458,8 +467,14 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
_strategy::UNROLL_K}; \
_strategy::UNROLL_K}; \
} \
} \
\
\
size_t MatrixMulImpl::_algo_name::get_packA_type_size() const { \
MatrixMulImpl::_algo_name::MatmulDescription \
return sizeof(_packa_type); \
MatrixMulImpl::_algo_name::matmul_description() const { \
MatmulDescription mdesc; \
mdesc.packmode = PackMode(); \
mdesc.innerblocksize = {_strategy::KERNEL_H, _strategy::KERNEL_W, \
_strategy::UNROLL_K}; \
mdesc.packa_type_size = sizeof(_packa_type); \
return mdesc; \
}
}
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( \
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( \
...
...
dnn/src/fallback/matrix_mul/opr_impl.h
浏览文件 @
8e9fa80c
...
@@ -104,6 +104,12 @@ public:
...
@@ -104,6 +104,12 @@ public:
size_t
m
,
n
,
k
;
size_t
m
,
n
,
k
;
};
};
struct
MatmulDescription
{
PackMode
packmode
;
InnerBlockSize
innerblocksize
;
size_t
packa_type_size
;
};
virtual
bool
usable
(
const
KernSizeParam
&
)
const
=
0
;
virtual
bool
usable
(
const
KernSizeParam
&
)
const
=
0
;
virtual
bool
preferred
(
const
KernSizeParam
&
)
const
{
return
true
;
}
virtual
bool
preferred
(
const
KernSizeParam
&
)
const
{
return
true
;
}
virtual
size_t
get_workspace
(
const
KernSizeParam
&
)
const
=
0
;
virtual
size_t
get_workspace
(
const
KernSizeParam
&
)
const
=
0
;
...
@@ -125,11 +131,11 @@ public:
...
@@ -125,11 +131,11 @@ public:
virtual
InnerBlockSize
get_inner_block_size
()
const
{
virtual
InnerBlockSize
get_inner_block_size
()
const
{
megdnn_assert
(
0
);
megdnn_assert
(
0
);
};
};
virtual
size_t
get_packA_type_size
()
const
{
megdnn_assert
(
0
);
};
bool
preferred_reproducible
(
const
KernSizeParam
&
param
,
bool
preferred_reproducible
(
const
KernSizeParam
&
param
,
bool
reproducible
=
true
)
{
bool
reproducible
=
true
)
{
return
(
!
reproducible
||
is_reproducible
())
&&
preferred
(
param
);
return
(
!
reproducible
||
is_reproducible
())
&&
preferred
(
param
);
};
};
virtual
MatmulDescription
matmul_description
()
const
=
0
;
};
};
/**
/**
...
...
dnn/src/x86/matrix_mul/algos.h
浏览文件 @
8e9fa80c
...
@@ -27,6 +27,7 @@ public:
...
@@ -27,6 +27,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
4
)
};
};
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
...
@@ -46,7 +47,9 @@ public:
...
@@ -46,7 +47,9 @@ public:
megdnn_assert
(
0
);
megdnn_assert
(
0
);
};
};
WorkspaceBundle
get_bundle
(
const
KernSizeParam
&
param
)
const
override
;
WorkspaceBundle
get_bundle
(
const
KernSizeParam
&
param
)
const
override
;
InnerBlockSize
get_inner_block_size
()
const
override
{
return
{
8
,
16
,
1
};
};
InnerBlockSize
get_inner_block_size
()
const
override
{
return
{
8
,
16
,
1
};
};
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
4
)
};
};
#endif
#endif
...
@@ -124,6 +127,7 @@ public:
...
@@ -124,6 +127,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
8
,
8
,
4
)
};
};
#if MEGDNN_X86_WITH_VNNI
#if MEGDNN_X86_WITH_VNNI
...
@@ -149,6 +153,7 @@ public:
...
@@ -149,6 +153,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
)
};
};
#endif
#endif
}
// namespace x86
}
// namespace x86
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录