Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
f7b2bdae
MegEngine
项目概览
MegEngine 天元
/
MegEngine
9 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
f7b2bdae
编写于
11月 07, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(dnn): refactor algorithm type interface
GitOrigin-RevId: 843d885f82a42456c8b0f1018290a3b5c04a3f00
上级
d793c87c
变更
53
隐藏空白更改
内联
并排
Showing
53 changed file
with
230 addition
and
277 deletion
+230
-277
dnn/include/megdnn/oprs/base.h
dnn/include/megdnn/oprs/base.h
+3
-2
dnn/src/aarch64/conv_bias/opr_impl.cpp
dnn/src/aarch64/conv_bias/opr_impl.cpp
+1
-1
dnn/src/aarch64/conv_bias/opr_impl.h
dnn/src/aarch64/conv_bias/opr_impl.h
+7
-2
dnn/src/aarch64/matrix_mul/algos.h
dnn/src/aarch64/matrix_mul/algos.h
+0
-21
dnn/src/aarch64/matrix_mul/opr_impl.cpp
dnn/src/aarch64/matrix_mul/opr_impl.cpp
+2
-2
dnn/src/aarch64/matrix_mul/opr_impl.h
dnn/src/aarch64/matrix_mul/opr_impl.h
+8
-2
dnn/src/arm_common/conv_bias/opr_impl.cpp
dnn/src/arm_common/conv_bias/opr_impl.cpp
+16
-11
dnn/src/arm_common/conv_bias/opr_impl.h
dnn/src/arm_common/conv_bias/opr_impl.h
+10
-8
dnn/src/arm_common/convolution/int8x8x32/algos.cpp
dnn/src/arm_common/convolution/int8x8x32/algos.cpp
+10
-6
dnn/src/arm_common/convolution/int8x8x32/algos.h
dnn/src/arm_common/convolution/int8x8x32/algos.h
+18
-15
dnn/src/arm_common/convolution/opr_impl.cpp
dnn/src/arm_common/convolution/opr_impl.cpp
+20
-25
dnn/src/arm_common/convolution/opr_impl.h
dnn/src/arm_common/convolution/opr_impl.h
+12
-9
dnn/src/arm_common/convolution/quint8/algos.cpp
dnn/src/arm_common/convolution/quint8/algos.cpp
+10
-6
dnn/src/arm_common/convolution/quint8/algos.h
dnn/src/arm_common/convolution/quint8/algos.h
+18
-13
dnn/src/arm_common/matrix_mul/algos.h
dnn/src/arm_common/matrix_mul/algos.h
+0
-8
dnn/src/arm_common/matrix_mul/opr_impl.cpp
dnn/src/arm_common/matrix_mul/opr_impl.cpp
+2
-9
dnn/src/arm_common/matrix_mul/opr_impl.h
dnn/src/arm_common/matrix_mul/opr_impl.h
+8
-3
dnn/src/armv7/conv_bias/opr_impl.cpp
dnn/src/armv7/conv_bias/opr_impl.cpp
+1
-1
dnn/src/armv7/conv_bias/opr_impl.h
dnn/src/armv7/conv_bias/opr_impl.h
+7
-2
dnn/src/armv7/matrix_mul/algos.h
dnn/src/armv7/matrix_mul/algos.h
+0
-14
dnn/src/armv7/matrix_mul/opr_impl.cpp
dnn/src/armv7/matrix_mul/opr_impl.cpp
+2
-2
dnn/src/armv7/matrix_mul/opr_impl.h
dnn/src/armv7/matrix_mul/opr_impl.h
+8
-1
dnn/src/cuda/batch_conv_bias/algo.h
dnn/src/cuda/batch_conv_bias/algo.h
+1
-0
dnn/src/cuda/batched_matrix_mul/algo.h
dnn/src/cuda/batched_matrix_mul/algo.h
+1
-0
dnn/src/cuda/conv_bias/algo.h
dnn/src/cuda/conv_bias/algo.h
+1
-0
dnn/src/cuda/convolution/backward_data/algo.h
dnn/src/cuda/convolution/backward_data/algo.h
+1
-0
dnn/src/cuda/convolution/backward_filter/algo.h
dnn/src/cuda/convolution/backward_filter/algo.h
+1
-0
dnn/src/cuda/convolution3d/backward_data/algo.h
dnn/src/cuda/convolution3d/backward_data/algo.h
+1
-0
dnn/src/cuda/convolution3d/backward_filter/algo.h
dnn/src/cuda/convolution3d/backward_filter/algo.h
+3
-2
dnn/src/cuda/convolution3d/forward/algo.h
dnn/src/cuda/convolution3d/forward/algo.h
+3
-2
dnn/src/cuda/deformable_conv/bwd_data/algo.h
dnn/src/cuda/deformable_conv/bwd_data/algo.h
+1
-0
dnn/src/cuda/deformable_conv/bwd_flt/algo.h
dnn/src/cuda/deformable_conv/bwd_flt/algo.h
+1
-0
dnn/src/cuda/deformable_conv/fwd/algo.h
dnn/src/cuda/deformable_conv/fwd/algo.h
+1
-0
dnn/src/cuda/local_share/backward_data/algo.h
dnn/src/cuda/local_share/backward_data/algo.h
+1
-0
dnn/src/cuda/local_share/backward_filter/algo.h
dnn/src/cuda/local_share/backward_filter/algo.h
+1
-0
dnn/src/cuda/local_share/forward/algo.h
dnn/src/cuda/local_share/forward/algo.h
+1
-0
dnn/src/cuda/matrix_mul/algos.h
dnn/src/cuda/matrix_mul/algos.h
+3
-2
dnn/src/fallback/conv_bias/opr_impl.h
dnn/src/fallback/conv_bias/opr_impl.h
+3
-0
dnn/src/fallback/convolution/algos.h
dnn/src/fallback/convolution/algos.h
+0
-4
dnn/src/fallback/convolution/opr_impl.cpp
dnn/src/fallback/convolution/opr_impl.cpp
+2
-10
dnn/src/fallback/convolution/opr_impl.h
dnn/src/fallback/convolution/opr_impl.h
+6
-4
dnn/src/fallback/matrix_mul/opr_impl.h
dnn/src/fallback/matrix_mul/opr_impl.h
+1
-0
dnn/src/rocm/batched_matrix_mul/opr_impl.cpp
dnn/src/rocm/batched_matrix_mul/opr_impl.cpp
+5
-4
dnn/src/rocm/convolution/backward_data/algo.h
dnn/src/rocm/convolution/backward_data/algo.h
+1
-0
dnn/src/rocm/convolution/backward_filter/algo.h
dnn/src/rocm/convolution/backward_filter/algo.h
+1
-0
dnn/src/rocm/convolution/forward/algo.h
dnn/src/rocm/convolution/forward/algo.h
+1
-0
dnn/src/x86/conv_bias/f32/algos.h
dnn/src/x86/conv_bias/f32/algos.h
+0
-7
dnn/src/x86/conv_bias/int8/algos.h
dnn/src/x86/conv_bias/int8/algos.h
+0
-7
dnn/src/x86/conv_bias/opr_impl.cpp
dnn/src/x86/conv_bias/opr_impl.cpp
+10
-50
dnn/src/x86/conv_bias/opr_impl.h
dnn/src/x86/conv_bias/opr_impl.h
+7
-2
dnn/src/x86/matrix_mul/algos.h
dnn/src/x86/matrix_mul/algos.h
+0
-10
dnn/src/x86/matrix_mul/opr_impl.cpp
dnn/src/x86/matrix_mul/opr_impl.cpp
+2
-8
dnn/src/x86/matrix_mul/opr_impl.h
dnn/src/x86/matrix_mul/opr_impl.h
+7
-2
未找到文件。
dnn/include/megdnn/oprs/base.h
浏览文件 @
f7b2bdae
...
...
@@ -11,6 +11,7 @@
#pragma once
#include "megdnn/basic_types.h"
#include "megdnn/handle.h"
#include "megdnn/internal/visibility_prologue.h"
namespace
megdnn
{
...
...
@@ -105,11 +106,11 @@ public:
virtual
bool
is_reproducible
()
const
=
0
;
virtual
const
char
*
name
()
const
=
0
;
//! a pointer to represent class type
virtual
void
*
type
()
const
{
return
nullptr
;
}
Handle
::
HandleType
handle_type
()
const
{
return
m_handle_type
;
}
protected:
~
Algorithm
()
=
default
;
Handle
::
HandleType
m_handle_type
=
Handle
::
HandleType
::
NAIVE
;
};
/*!
...
...
dnn/src/aarch64/conv_bias/opr_impl.cpp
浏览文件 @
f7b2bdae
...
...
@@ -45,7 +45,7 @@ public:
SmallVector
<
AlgoBase
*>
matmul_algos
;
};
SmallVector
<
ConvBiasImpl
::
AlgoBase
*>
ConvBiasImpl
::
algo_pack
()
{
SmallVector
<
fallback
::
ConvBiasImpl
::
AlgoBase
*>
ConvBiasImpl
::
algo_pack
()
{
static
AlgoPack
sl_algo_pack
;
auto
&&
algos
=
arm_common
::
ConvBiasImpl
::
algo_pack
();
algos
.
insert
(
algos
.
begin
(),
sl_algo_pack
.
direct_algos
.
begin
(),
...
...
dnn/src/aarch64/conv_bias/opr_impl.h
浏览文件 @
f7b2bdae
...
...
@@ -18,11 +18,16 @@ namespace aarch64 {
class
ConvBiasImpl
:
public
arm_common
::
ConvBiasImpl
{
public:
using
arm_common
::
ConvBiasImpl
::
ConvBiasImpl
;
class
AlgoBase
:
public
arm_common
::
ConvBiasImpl
::
AlgoBase
{
public:
AlgoBase
()
:
arm_common
::
ConvBiasImpl
::
AlgoBase
()
{
m_handle_type
=
Handle
::
HandleType
::
AARCH64
;
}
};
SmallVector
<
AlgoBase
*>
algo_pack
()
override
;
SmallVector
<
fallback
::
ConvBiasImpl
::
AlgoBase
*>
algo_pack
()
override
;
protected:
const
char
*
get_algorithm_set_name
()
const
override
;
private:
...
...
dnn/src/aarch64/matrix_mul/algos.h
浏览文件 @
f7b2bdae
...
...
@@ -26,7 +26,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -37,7 +36,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -48,7 +46,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -59,7 +56,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
4
,
16
,
4
,
4
,
AlgoDataType
::
FLOAT32
,
MK4
)
};
...
...
@@ -75,7 +71,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -86,7 +81,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
8
,
8
,
2
,
AlgoDataType
::
FLOAT16
,
MK8
)
};
...
...
@@ -103,7 +97,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -116,7 +109,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
#else
...
...
@@ -129,7 +121,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
DEFAULT
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
...
...
@@ -143,7 +134,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -156,7 +146,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
#endif
...
...
@@ -169,7 +158,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -182,7 +170,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -196,7 +183,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
DEFAULT
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
...
...
@@ -212,7 +198,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
DEFAULT
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
...
...
@@ -226,7 +211,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
DEFAULT
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
...
...
@@ -240,7 +224,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -251,7 +234,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
8
,
8
,
2
,
AlgoDataType
::
INT16X16X32
,
MK8
)
};
...
...
@@ -266,7 +248,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -278,7 +259,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
{
return
0
;
}
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
,
AlgoDataType
::
QUINT8X8X32
,
DEFAULT
)
...
...
@@ -292,7 +272,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
#endif
...
...
dnn/src/aarch64/matrix_mul/opr_impl.cpp
浏览文件 @
f7b2bdae
...
...
@@ -52,7 +52,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#endif
public:
SmallVector
<
MatrixMulImpl
::
AlgoBase
*>
all_algos
;
SmallVector
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
all_algos
;
AlgoPack
()
{
all_algos
.
emplace_back
(
&
f32_gemv
);
...
...
@@ -89,7 +89,7 @@ public:
}
};
SmallVector
<
MatrixMulImpl
::
AlgoBase
*>
MatrixMulImpl
::
algo_pack
()
{
SmallVector
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
MatrixMulImpl
::
algo_pack
()
{
static
AlgoPack
s_algo_pack
;
auto
&&
algos
=
arm_common
::
MatrixMulImpl
::
algo_pack
();
algos
.
insert
(
algos
.
begin
(),
s_algo_pack
.
all_algos
.
begin
(),
...
...
dnn/src/aarch64/matrix_mul/opr_impl.h
浏览文件 @
f7b2bdae
...
...
@@ -18,8 +18,14 @@ namespace aarch64 {
class
MatrixMulImpl
:
public
arm_common
::
MatrixMulImpl
{
public:
using
arm_common
::
MatrixMulImpl
::
MatrixMulImpl
;
class
AlgoBase
:
public
arm_common
::
MatrixMulImpl
::
AlgoBase
{
public:
AlgoBase
()
:
arm_common
::
MatrixMulImpl
::
AlgoBase
()
{
m_handle_type
=
Handle
::
HandleType
::
AARCH64
;
}
};
SmallVector
<
AlgoBase
*>
algo_pack
()
override
;
SmallVector
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
algo_pack
()
override
;
private:
class
AlgoF32K8x12x1
;
// Aarch64 F32 Kernel 8X12X1
...
...
@@ -57,7 +63,7 @@ private:
#else
class
AlgoQuint8K8x8x8
;
// Aarch64 Quint8 Kernel 8x8x8
#endif
class
AlgoInt8x8x16MK4_K8x8x8
;
// Aarch64 Int4x4x16 Kernel 4x4x16
class
AlgoInt8x8x16MK4_K8x8x8
;
// Aarch64 Int4x4x16 Kernel 4x4x16
class
AlgoPack
;
};
...
...
dnn/src/arm_common/conv_bias/opr_impl.cpp
浏览文件 @
f7b2bdae
...
...
@@ -11,6 +11,7 @@
*/
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/base.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8x8x16/algos.h"
#include "src/arm_common/conv_bias/quint8/algos.h"
...
...
@@ -18,6 +19,7 @@
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/naive/handle.h"
#include "src/arm_common/convolution/opr_impl.h"
...
...
@@ -37,7 +39,12 @@ using namespace megdnn;
using
namespace
arm_common
;
namespace
{
uint8_t
arm_common_algo_type_storage
;
bool
is_fallback_or_naive
(
const
detail
::
Algorithm
*
algo
)
{
return
algo
->
handle_type
()
==
Handle
::
HandleType
::
NAIVE
||
algo
->
handle_type
()
==
Handle
::
HandleType
::
FALLBACK
;
}
}
// anonymous namespace
class
ConvBiasImpl
::
AlgoPack
:
NonCopyableObj
{
...
...
@@ -50,7 +57,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoS8DirectStride1
s8_direct_stride1
;
AlgoS8ChanWiseStride1NCHW44
s8_channel_wise_stride1_nchw44
;
AlgoS8ChanWiseStride2NCHW44
s8_channel_wise_stride2_nchw44
;
AlgoS8x8x16ChanWiseStride1Stride2NCHW44
s8x8x16_channel_wise_stride1_stride2_nchw44
;
AlgoS8x8x16ChanWiseStride1Stride2NCHW44
s8x8x16_channel_wise_stride1_stride2_nchw44
;
#if __ARM_FEATURE_DOTPROD
AlgoDotS8DirectStride1
ds8_direct_stride1
;
...
...
@@ -129,7 +137,7 @@ public:
->
select_algo_type
(
{
AlgoDataType
::
FLOAT32
,
MatmulFormat
::
MK4
});
for
(
auto
&&
algo
:
matmul_algos
)
{
if
(
algo
->
type
()
==
nullptr
)
if
(
is_fallback_or_naive
(
algo
)
)
continue
;
for
(
uint32_t
tile_size
:
{
16
,
8
,
24
,
32
})
{
refhold
.
emplace_back
(
new
AlgoFP32WinogradF23_4x4
(
...
...
@@ -166,7 +174,7 @@ public:
->
select_algo_type
({
AlgoDataType
::
FLOAT32
,
MatmulFormat
::
DEFAULT
});
for
(
auto
&&
algo
:
matmul_algos
)
{
if
(
algo
->
type
()
==
nullptr
)
if
(
is_fallback_or_naive
(
algo
)
)
continue
;
for
(
uint32_t
tile_size
:
{
16
,
8
,
24
,
32
})
{
refhold
.
emplace_back
(
new
AlgoFP32WinogradF63
(
...
...
@@ -189,7 +197,7 @@ public:
->
select_algo_type
({
AlgoDataType
::
FLOAT16
,
MatmulFormat
::
DEFAULT
});
for
(
auto
&&
algo
:
matmul_algos
)
{
if
(
algo
->
type
()
==
nullptr
)
if
(
is_fallback_or_naive
(
algo
)
)
continue
;
for
(
uint32_t
tile_size
:
{
16
,
8
,
24
,
32
})
{
refhold
.
emplace_back
(
new
AlgoFP16WinogradF23
(
...
...
@@ -210,7 +218,7 @@ public:
->
select_algo_type
({
AlgoDataType
::
FLOAT16
,
MatmulFormat
::
MK8
});
for
(
auto
&&
algo
:
matmul_algos
)
{
if
(
algo
->
type
()
==
nullptr
)
if
(
is_fallback_or_naive
(
algo
)
)
continue
;
for
(
uint32_t
tile_size
:
{
16
,
8
,
24
,
32
})
{
refhold
.
emplace_back
(
new
AlgoFP16WinogradF23_8x8
(
...
...
@@ -224,7 +232,7 @@ public:
->
select_algo_type
({
AlgoDataType
::
INT16X16X32
,
MatmulFormat
::
MK8
});
for
(
auto
&&
algo
:
matmul_algos
)
{
if
(
algo
->
type
()
==
nullptr
)
if
(
is_fallback_or_naive
(
algo
)
)
continue
;
for
(
uint32_t
tile_size
:
{
16
,
8
,
24
,
32
})
{
refhold
.
emplace_back
(
new
AlgoS8WinogradF23_8x8
(
...
...
@@ -242,7 +250,7 @@ public:
SmallVector
<
AlgoBase
*>
winograd_algos
;
};
SmallVector
<
ConvBiasImpl
::
AlgoBase
*>
ConvBiasImpl
::
algo_pack
()
{
SmallVector
<
fallback
::
ConvBiasImpl
::
AlgoBase
*>
ConvBiasImpl
::
algo_pack
()
{
static
AlgoPack
sl_algo_pack
;
auto
&&
algos
=
fallback
::
ConvBiasImpl
::
algo_pack
();
algos
.
insert
(
algos
.
begin
(),
sl_algo_pack
.
direct_algos
.
begin
(),
...
...
@@ -252,9 +260,6 @@ SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
return
std
::
move
(
algos
);
}
void
*
const
ConvBiasImpl
::
sm_arm_common_algo_type
=
&
arm_common_algo_type_storage
;
bool
ConvBiasImpl
::
is_matmul_quantized_prefer
(
const
ConvBiasImpl
::
NCBKernSizeParam
&
param
)
const
{
fallback
::
ConvBiasImpl
::
NCBKernSizeParam
conv_ncb_param
(
...
...
dnn/src/arm_common/conv_bias/opr_impl.h
浏览文件 @
f7b2bdae
...
...
@@ -19,23 +19,25 @@ namespace arm_common {
class
ConvBiasImpl
:
public
fallback
::
ConvBiasImpl
{
public:
using
fallback
::
ConvBiasImpl
::
ConvBiasImpl
;
using
FallbackConvBiasImpl
=
fallback
::
ConvBiasImpl
;
using
NCBKernIndex
=
fallback
::
ConvBiasImpl
::
NCBKernIndex
;
bool
is_thread_safe
()
const
override
{
return
true
;
}
class
AlgoBase
:
public
fallback
::
ConvBiasImpl
::
AlgoBase
{
public:
AlgoBase
()
:
fallback
::
ConvBiasImpl
::
AlgoBase
()
{
m_handle_type
=
Handle
::
HandleType
::
ARM_COMMON
;
}
};
SmallVector
<
AlgoBase
*>
algo_pack
()
override
;
SmallVector
<
fallback
::
ConvBiasImpl
::
AlgoBase
*>
algo_pack
()
override
;
bool
is_matmul_quantized_prefer
(
const
ConvBiasImpl
::
NCBKernSizeParam
&
ncb_param
)
const
override
;
const
fallback
::
ConvBiasImpl
::
NCBKernSizeParam
&
ncb_param
)
const
override
;
SmallVector
<
AlgoCategory
>
suggest_algo_category_order
(
const
NCBKernSizeParam
&
param
)
const
override
;
class
AlgoPack
;
protected:
static
void
*
const
sm_arm_common_algo_type
;
const
char
*
get_algorithm_set_name
()
const
override
;
private:
...
...
@@ -93,7 +95,7 @@ private:
class
AlgoF16Direct
;
class
AlgoF16DirectStride1
;
#endif
};
};
}
// namespace arm_common
}
// namespace megdnn
...
...
dnn/src/arm_common/convolution/int8x8x32/algos.cpp
浏览文件 @
f7b2bdae
...
...
@@ -26,12 +26,14 @@ using namespace arm_common;
/* ===================== ConvolutionBackwardData ===================== */
/* ===================== direct stride 1 algo ===================== */
bool
ConvolutionBackwardDataImpl
::
AlgoSdot8DirectStride1
::
usable
(
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
return
deconv
::
can_stride1_int8x8x32_dot
(
param
);
}
size_t
ConvolutionBackwardDataImpl
::
AlgoSdot8DirectStride1
::
get_workspace
(
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
MIDOUT_BEGIN
(
megdnn_arm_conv_int8832_kimpl
,
midout_iv
(
"AlgoSdot8DirectStride1::get_workspace"
_hash
))
{
return
deconv
::
get_workspace_in_bytes_stride1_int8x8x32_dot
(
param
);
...
...
@@ -42,7 +44,7 @@ size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::get_workspace(
ConvolutionBackwardDataImpl
::
ncb_kern_t
ConvolutionBackwardDataImpl
::
AlgoSdot8DirectStride1
::
dispatch_kern
(
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
)
const
{
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
)
const
{
MIDOUT_BEGIN
(
megdnn_arm_conv_int8832_kimpl
,
midout_iv
(
"AlgoSdot8DirectStride1::dispatch_kern"
_hash
))
{
return
deconv
::
stride1_int8x8x32_dot
;
...
...
@@ -53,12 +55,14 @@ ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern(
/* ===================== direct stride 2 algo ===================== */
bool
ConvolutionBackwardDataImpl
::
AlgoSdot8DirectStride2
::
usable
(
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
return
deconv
::
can_stride2_int8x8x32_dot
(
param
);
}
size_t
ConvolutionBackwardDataImpl
::
AlgoSdot8DirectStride2
::
get_workspace
(
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
MIDOUT_BEGIN
(
megdnn_arm_conv_int8832_kimpl
,
midout_iv
(
"AlgoSdot8DirectStride2::get_workspace"
_hash
))
{
return
deconv
::
get_workspace_in_bytes_stride2_int8x8x32_dot
(
param
);
...
...
@@ -69,7 +73,7 @@ size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::get_workspace(
ConvolutionBackwardDataImpl
::
ncb_kern_t
ConvolutionBackwardDataImpl
::
AlgoSdot8DirectStride2
::
dispatch_kern
(
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
)
const
{
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
)
const
{
MIDOUT_BEGIN
(
megdnn_arm_conv_int8832_kimpl
,
midout_iv
(
"AlgoSdot8DirectStride2::dispatch_kern"
_hash
))
{
return
deconv
::
stride2_int8x8x32_dot
;
...
...
dnn/src/arm_common/convolution/int8x8x32/algos.h
浏览文件 @
f7b2bdae
...
...
@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
...
...
@@ -19,38 +20,40 @@ namespace arm_common {
#if __ARM_FEATURE_DOTPROD
/* ===================== ConvolutionBackwardData ===================== */
class
ConvolutionBackwardDataImpl
::
AlgoSdot8DirectStride1
final
:
public
AlgoBase
{
class
ConvolutionBackwardDataImpl
::
AlgoSdot8DirectStride1
final
:
public
AlgoBase
{
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"AARCH32_I8x8x32_DECONV_STRIDE1"
;
}
const
char
*
name
()
const
override
{
return
"AARCH32_I8x8x32_DECONV_STRIDE1"
;
}
bool
usable
(
ConvolutionBackwardDataImpl
*
,
bool
usable
(
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
override
;
size_t
get_workspace
(
ConvolutionBackwardDataImpl
*
,
size_t
get_workspace
(
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
override
;
ncb_kern_t
dispatch_kern
(
ConvolutionBackwardDataImpl
*
,
ncb_kern_t
dispatch_kern
(
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
};
class
ConvolutionBackwardDataImpl
::
AlgoSdot8DirectStride2
final
:
public
AlgoBase
{
class
ConvolutionBackwardDataImpl
::
AlgoSdot8DirectStride2
final
:
public
AlgoBase
{
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"AARCH32_I8x8x32_DECONV_STRIDE2"
;
}
const
char
*
name
()
const
override
{
return
"AARCH32_I8x8x32_DECONV_STRIDE2"
;
}
bool
usable
(
ConvolutionBackwardDataImpl
*
,
bool
usable
(
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
override
;
size_t
get_workspace
(
ConvolutionBackwardDataImpl
*
,
size_t
get_workspace
(
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
override
;
ncb_kern_t
dispatch_kern
(
ConvolutionBackwardDataImpl
*
,
ncb_kern_t
dispatch_kern
(
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
};
#endif
...
...
dnn/src/arm_common/convolution/opr_impl.cpp
浏览文件 @
f7b2bdae
...
...
@@ -21,9 +21,6 @@
using
namespace
megdnn
;
using
namespace
arm_common
;
namespace
{
uint8_t
arm_common_algo_type_storage
;
}
// anonymous namespace
/* ===================== ConvolutionBackwardData ===================== */
struct
ConvolutionBackwardDataImpl
::
AlgoPack
{
...
...
@@ -36,46 +33,44 @@ struct ConvolutionBackwardDataImpl::AlgoPack {
};
ConvolutionBackwardDataImpl
::
AlgoPack
ConvolutionBackwardDataImpl
::
sm_algo_pack
;
void
*
const
ConvolutionBackwardDataImpl
::
sm_arm_common_algo_type
=
&
arm_common_algo_type_storage
;
ConvolutionBackwardDataImpl
::
ncb_kern_t
ConvolutionBackwardDataImpl
::
ncb_1g_dispatch_kern
(
ConvolutionBackwardDataImpl
::
ncb_kern_t
ConvolutionBackwardDataImpl
::
ncb_1g_dispatch_kern
(
Algorithm
*
algo
,
const
NCBKernSizeParam
&
param
)
{
if
(
algo
->
type
()
==
sm_arm_common_algo_type
)
{
if
(
algo
->
handle_type
()
==
Handle
::
HandleType
::
ARM_COMMON
)
{
return
static_cast
<
AlgoBase
*>
(
algo
)
->
dispatch_kern
(
this
,
param
);
}
return
fallback
::
ConvolutionBackwardDataImpl
::
ncb_1g_dispatch_kern
(
algo
,
param
);
return
fallback
::
ConvolutionBackwardDataImpl
::
ncb_1g_dispatch_kern
(
algo
,
param
);
}
size_t
ConvolutionBackwardDataImpl
::
ncb_1g_get_workspace
(
Algorithm
*
algo
,
const
NCBKernSizeParam
&
param
)
{
if
(
algo
->
type
()
==
sm_arm_common_algo_type
)
{
size_t
ConvolutionBackwardDataImpl
::
ncb_1g_get_workspace
(
Algorithm
*
algo
,
const
NCBKernSizeParam
&
param
)
{
if
(
algo
->
handle_type
()
==
Handle
::
HandleType
::
ARM_COMMON
)
{
return
static_cast
<
AlgoBase
*>
(
algo
)
->
get_workspace
(
this
,
param
);
}
return
fallback
::
ConvolutionBackwardDataImpl
::
ncb_1g_get_workspace
(
algo
,
param
);
return
fallback
::
ConvolutionBackwardDataImpl
::
ncb_1g_get_workspace
(
algo
,
param
);
}
std
::
vector
<
ConvolutionBackwardDataImpl
::
Algorithm
*>
ConvolutionBackwardDataImpl
::
ncb_1g_get_all_algorithms
(
const
NCBKernSizeParam
&
param
)
{
auto
ret
=
fallback
::
ConvolutionBackwardDataImpl
::
ncb_1g_get_all_algorithms
(
param
);
ConvolutionBackwardDataImpl
::
ncb_1g_get_all_algorithms
(
const
NCBKernSizeParam
&
param
)
{
auto
ret
=
fallback
::
ConvolutionBackwardDataImpl
::
ncb_1g_get_all_algorithms
(
param
);
#if __ARM_FEATURE_DOTPROD
if
((
param
.
filter_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
||
param
.
filter_type
.
enumv
()
==
DTypeEnum
::
Int8
)
&&
(
param
.
grad_type
.
enumv
()
==
DTypeEnum
::
QuantizedS32
||
param
.
grad_type
.
enumv
()
==
DTypeEnum
::
Int32
))
{
if
((
param
.
filter_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
||
param
.
filter_type
.
enumv
()
==
DTypeEnum
::
Int8
)
&&
(
param
.
grad_type
.
enumv
()
==
DTypeEnum
::
QuantizedS32
||
param
.
grad_type
.
enumv
()
==
DTypeEnum
::
Int32
))
{
if
(
sm_algo_pack
.
i8x8x32_direct_stride1_sdot
.
usable
(
this
,
param
))
{
ret
.
insert
(
ret
.
begin
(),
&
sm_algo_pack
.
i8x8x32_direct_stride1_sdot
);
}
if
(
sm_algo_pack
.
i8x8x32_direct_stride2_sdot
.
usable
(
this
,
param
))
{
ret
.
insert
(
ret
.
begin
(),
&
sm_algo_pack
.
i8x8x32_direct_stride2_sdot
);
}
}
else
if
(
param
.
filter_type
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
&&
param
.
grad_type
.
enumv
()
==
DTypeEnum
::
QuantizedS32
)
{
}
else
if
(
param
.
filter_type
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
&&
param
.
grad_type
.
enumv
()
==
DTypeEnum
::
QuantizedS32
)
{
if
(
sm_algo_pack
.
quint8_direct_stride1_udot
.
usable
(
this
,
param
))
{
ret
.
insert
(
ret
.
begin
(),
&
sm_algo_pack
.
quint8_direct_stride1_udot
);
}
...
...
dnn/src/arm_common/convolution/opr_impl.h
浏览文件 @
f7b2bdae
...
...
@@ -18,24 +18,27 @@ namespace arm_common {
class
ConvBiasImpl
;
class
ConvolutionBackwardDataImpl
:
public
fallback
::
ConvolutionBackwardDataImpl
{
class
ConvolutionBackwardDataImpl
:
public
fallback
::
ConvolutionBackwardDataImpl
{
public:
using
fallback
::
ConvolutionBackwardDataImpl
::
ConvolutionBackwardDataImpl
;
protected:
static
void
*
const
sm_arm_common_algo_type
;
class
AlgoBase
:
public
Algorithm
{
class
AlgoBase
:
public
fallback
::
ConvolutionBackwardDataImpl
::
AlgoBase
{
protected:
~
AlgoBase
()
=
default
;
public:
virtual
bool
usable
(
ConvolutionBackwardDataImpl
*
opr
,
AlgoBase
()
:
fallback
::
ConvolutionBackwardDataImpl
::
AlgoBase
()
{
m_handle_type
=
Handle
::
HandleType
::
ARM_COMMON
;
}
virtual
bool
usable
(
fallback
::
ConvolutionBackwardDataImpl
*
opr
,
const
NCBKernSizeParam
&
param
)
const
=
0
;
virtual
size_t
get_workspace
(
ConvolutionBackwardDataImpl
*
opr
,
virtual
size_t
get_workspace
(
fallback
::
ConvolutionBackwardDataImpl
*
opr
,
const
NCBKernSizeParam
&
param
)
const
=
0
;
virtual
ncb_kern_t
dispatch_kern
(
ConvolutionBackwardDataImpl
*
opr
,
const
NCBKernSizeParam
&
param
)
const
=
0
;
fallback
::
ConvolutionBackwardDataImpl
*
opr
,
const
NCBKernSizeParam
&
param
)
const
=
0
;
};
ncb_kern_t
ncb_1g_dispatch_kern
(
Algorithm
*
algo
,
...
...
@@ -49,7 +52,7 @@ protected:
const
char
*
get_algorithm_set_name
()
const
override
;
private:
private:
#if __ARM_FEATURE_DOTPROD
class
AlgoSdot8DirectStride1
;
class
AlgoSdot8DirectStride2
;
...
...
@@ -62,4 +65,4 @@ protected:
}
// namespace arm_common
}
// namespace megdnn
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
dnn/src/arm_common/convolution/quint8/algos.cpp
浏览文件 @
f7b2bdae
...
...
@@ -27,12 +27,14 @@ using namespace arm_common;
/* ===================== direct stride 1 algo ===================== */
bool
ConvolutionBackwardDataImpl
::
AlgoUdot8DirectStride1
::
usable
(
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
return
deconv
::
can_stride1_quint8_dot
(
param
);
}
size_t
ConvolutionBackwardDataImpl
::
AlgoUdot8DirectStride1
::
get_workspace
(
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
MIDOUT_BEGIN
(
megdnn_arm_conv_quint8_kimpl
,
midout_iv
(
"AlgoUdot8DirectStride1::get_workspace"
_hash
))
{
return
deconv
::
get_workspace_in_bytes_stride1_quint8_dot
(
param
);
...
...
@@ -43,7 +45,7 @@ size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::get_workspace(
ConvolutionBackwardDataImpl
::
ncb_kern_t
ConvolutionBackwardDataImpl
::
AlgoUdot8DirectStride1
::
dispatch_kern
(
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
)
const
{
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
)
const
{
MIDOUT_BEGIN
(
megdnn_arm_conv_quint8_kimpl
,
midout_iv
(
"AlgoUdot8DirectStride1::dispatch_kern"
_hash
))
{
return
deconv
::
stride1_quint8_dot
;
...
...
@@ -54,12 +56,14 @@ ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern(
/* ===================== direct stride 2 algo ===================== */
bool
ConvolutionBackwardDataImpl
::
AlgoUdot8DirectStride2
::
usable
(
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
return
deconv
::
can_stride2_quint8_dot
(
param
);
}
size_t
ConvolutionBackwardDataImpl
::
AlgoUdot8DirectStride2
::
get_workspace
(
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
{
MIDOUT_BEGIN
(
megdnn_arm_conv_quint8_kimpl
,
midout_iv
(
"AlgoUdot8DirectStride2::get_workspace"
_hash
))
{
return
deconv
::
get_workspace_in_bytes_stride2_quint8_dot
(
param
);
...
...
@@ -70,7 +74,7 @@ size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::get_workspace(
ConvolutionBackwardDataImpl
::
ncb_kern_t
ConvolutionBackwardDataImpl
::
AlgoUdot8DirectStride2
::
dispatch_kern
(
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
)
const
{
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
)
const
{
MIDOUT_BEGIN
(
megdnn_arm_conv_quint8_kimpl
,
midout_iv
(
"AlgoUdot8DirectStride2::dispatch_kern"
_hash
))
{
return
deconv
::
stride2_quint8_dot
;
...
...
dnn/src/arm_common/convolution/quint8/algos.h
浏览文件 @
f7b2bdae
...
...
@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
...
...
@@ -18,38 +19,42 @@ namespace arm_common {
#if __ARM_FEATURE_DOTPROD
/* ===================== ConvolutionBackwardData ===================== */
class
ConvolutionBackwardDataImpl
::
AlgoUdot8DirectStride1
final
:
public
AlgoBase
{
class
ConvolutionBackwardDataImpl
::
AlgoUdot8DirectStride1
final
:
public
AlgoBase
{
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE1"
;
}
const
char
*
name
()
const
override
{
return
"ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE1"
;
}
bool
usable
(
ConvolutionBackwardDataImpl
*
,
bool
usable
(
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
override
;
size_t
get_workspace
(
ConvolutionBackwardDataImpl
*
,
size_t
get_workspace
(
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
override
;
ncb_kern_t
dispatch_kern
(
ConvolutionBackwardDataImpl
*
,
ncb_kern_t
dispatch_kern
(
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
};
class
ConvolutionBackwardDataImpl
::
AlgoUdot8DirectStride2
final
:
public
AlgoBase
{
class
ConvolutionBackwardDataImpl
::
AlgoUdot8DirectStride2
final
:
public
AlgoBase
{
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE2"
;
}
const
char
*
name
()
const
override
{
return
"ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE2"
;
}
bool
usable
(
ConvolutionBackwardDataImpl
*
,
bool
usable
(
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
override
;
size_t
get_workspace
(
ConvolutionBackwardDataImpl
*
,
size_t
get_workspace
(
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
param
)
const
override
;
ncb_kern_t
dispatch_kern
(
ConvolutionBackwardDataImpl
*
,
ncb_kern_t
dispatch_kern
(
fallback
::
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
};
#endif
}
// namespace arm_common
...
...
dnn/src/arm_common/matrix_mul/algos.h
浏览文件 @
f7b2bdae
...
...
@@ -24,7 +24,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
4
,
AlgoDataType
::
INT8X8X16
,
DEFAULT
)
};
...
...
@@ -37,7 +36,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
{
return
0
;
}
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
,
AlgoDataType
::
QINT8X8X32
,
DEFAULT
)
...
...
@@ -51,7 +49,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
{
return
0
;
}
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
,
AlgoDataType
::
QINT8X8X32
,
MK4
)
...
...
@@ -66,7 +63,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
{
return
0
;
}
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
,
AlgoDataType
::
QINT8X8X32
,
MK4_DOT
)
...
...
@@ -84,7 +80,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
{
return
0
;
}
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
4
,
AlgoDataType
::
FLOAT32
,
DEFAULT
)
...
...
@@ -98,7 +93,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
{
return
0
;
}
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
4
,
1
,
1
,
4
,
AlgoDataType
::
FLOAT32
,
MK4
)
...
...
@@ -113,7 +107,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
{
return
0
;
}
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
,
AlgoDataType
::
FLOAT16
,
DEFAULT
)
...
...
@@ -128,7 +121,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
{
return
0
;
}
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
...
...
dnn/src/arm_common/matrix_mul/opr_impl.cpp
浏览文件 @
f7b2bdae
...
...
@@ -15,13 +15,6 @@
using
namespace
megdnn
;
using
namespace
arm_common
;
namespace
{
uint8_t
arm_common_algo_type_storage
;
}
// anonymous namespace
void
*
const
MatrixMulImpl
::
sm_arm_common_algo_type
=
&
arm_common_algo_type_storage
;
class
MatrixMulImpl
::
AlgoPack
:
NonCopyableObj
{
AlgoInt8x8x16
int8x8x16
;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
...
...
@@ -49,10 +42,10 @@ public:
all_algos
.
emplace_back
(
&
f32_gemv_mk4
);
all_algos
.
emplace_back
(
&
gevm
);
}
SmallVector
<
AlgoBase
*>
all_algos
;
SmallVector
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
all_algos
;
};
SmallVector
<
MatrixMulImpl
::
AlgoBase
*>
MatrixMulImpl
::
algo_pack
()
{
SmallVector
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
MatrixMulImpl
::
algo_pack
()
{
static
AlgoPack
s_algo_pack
;
auto
&&
algos
=
fallback
::
MatrixMulImpl
::
algo_pack
();
algos
.
insert
(
algos
.
begin
(),
s_algo_pack
.
all_algos
.
begin
(),
...
...
dnn/src/arm_common/matrix_mul/opr_impl.h
浏览文件 @
f7b2bdae
...
...
@@ -18,13 +18,18 @@ namespace arm_common {
class
MatrixMulImpl
:
public
fallback
::
MatrixMulImpl
{
public:
using
fallback
::
MatrixMulImpl
::
MatrixMulImpl
;
bool
is_thread_safe
()
const
override
{
return
true
;
}
SmallVector
<
AlgoBase
*>
algo_pack
()
override
;
class
AlgoBase
:
public
fallback
::
MatrixMulImpl
::
AlgoBase
{
public:
AlgoBase
()
:
fallback
::
MatrixMulImpl
::
AlgoBase
()
{
m_handle_type
=
Handle
::
HandleType
::
ARM_COMMON
;
}
};
SmallVector
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
algo_pack
()
override
;
protected:
static
void
*
const
sm_arm_common_algo_type
;
class
AlgoF32Gemv
;
// Arm_common F32 Gemv
class
AlgoF32GemvMK4
;
// Arm_common F32 Gemv NCHW44
class
AlgoInt8x8x32Gemv
;
// Arm_common Int8x8x32 Gemv
...
...
dnn/src/armv7/conv_bias/opr_impl.cpp
浏览文件 @
f7b2bdae
...
...
@@ -32,7 +32,7 @@ public:
SmallVector
<
AlgoBase
*>
all_algos
;
};
SmallVector
<
ConvBiasImpl
::
AlgoBase
*>
ConvBiasImpl
::
algo_pack
()
{
SmallVector
<
fallback
::
ConvBiasImpl
::
AlgoBase
*>
ConvBiasImpl
::
algo_pack
()
{
static
AlgoPack
sl_algo_pack
;
auto
&&
algos
=
arm_common
::
ConvBiasImpl
::
algo_pack
();
//! TODO fused matmul bias is slower than matmul + elemwise in armv7 now,
...
...
dnn/src/armv7/conv_bias/opr_impl.h
浏览文件 @
f7b2bdae
...
...
@@ -18,11 +18,16 @@ namespace armv7 {
class
ConvBiasImpl
:
public
arm_common
::
ConvBiasImpl
{
public:
using
arm_common
::
ConvBiasImpl
::
ConvBiasImpl
;
class
AlgoBase
:
public
arm_common
::
ConvBiasImpl
::
AlgoBase
{
public:
AlgoBase
()
:
arm_common
::
ConvBiasImpl
::
AlgoBase
()
{
m_handle_type
=
Handle
::
HandleType
::
ARMV7
;
}
};
SmallVector
<
AlgoBase
*>
algo_pack
()
override
;
SmallVector
<
fallback
::
ConvBiasImpl
::
AlgoBase
*>
algo_pack
()
override
;
protected:
const
char
*
get_algorithm_set_name
()
const
override
;
private:
...
...
dnn/src/armv7/matrix_mul/algos.h
浏览文件 @
f7b2bdae
...
...
@@ -26,7 +26,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -37,7 +36,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -48,7 +46,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
4
,
8
,
4
,
4
,
AlgoDataType
::
FLOAT32
,
MK4
)
};
...
...
@@ -61,7 +58,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
class
MatrixMulImpl
::
AlgoF16MK8_4x8
final
:
public
AlgoBase
{
...
...
@@ -71,7 +67,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
4
,
8
,
8
,
2
,
AlgoDataType
::
FLOAT16
,
MK8
)
};
...
...
@@ -121,7 +116,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -133,7 +127,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -144,7 +137,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -156,7 +148,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -168,7 +159,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -180,7 +170,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -192,7 +181,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -203,7 +191,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
4
,
8
,
8
,
2
,
AlgoDataType
::
INT16X16X32
,
MK8
)
};
...
...
@@ -216,7 +203,6 @@ public:
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
dnn/src/armv7/matrix_mul/opr_impl.cpp
浏览文件 @
f7b2bdae
...
...
@@ -44,7 +44,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoInt16x16x32MK8_4x8
int16x16x32_mk8_4x8
;
public:
SmallVector
<
MatrixMulImpl
::
AlgoBase
*>
all_algos
;
SmallVector
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
all_algos
;
AlgoPack
()
{
all_algos
.
emplace_back
(
&
f32_gemv
);
...
...
@@ -73,7 +73,7 @@ public:
}
};
SmallVector
<
MatrixMulImpl
::
AlgoBase
*>
MatrixMulImpl
::
algo_pack
()
{
SmallVector
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
MatrixMulImpl
::
algo_pack
()
{
static
AlgoPack
s_algo_pack
;
auto
algos
=
arm_common
::
MatrixMulImpl
::
algo_pack
();
algos
.
insert
(
algos
.
begin
(),
s_algo_pack
.
all_algos
.
begin
(),
...
...
dnn/src/armv7/matrix_mul/opr_impl.h
浏览文件 @
f7b2bdae
...
...
@@ -18,7 +18,14 @@ namespace armv7 {
class
MatrixMulImpl
:
public
arm_common
::
MatrixMulImpl
{
public:
using
arm_common
::
MatrixMulImpl
::
MatrixMulImpl
;
SmallVector
<
AlgoBase
*>
algo_pack
()
override
;
class
AlgoBase
:
public
arm_common
::
MatrixMulImpl
::
AlgoBase
{
public:
AlgoBase
()
:
arm_common
::
MatrixMulImpl
::
AlgoBase
()
{
m_handle_type
=
Handle
::
HandleType
::
ARMV7
;
}
};
SmallVector
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
algo_pack
()
override
;
private:
class
AlgoF32
;
// Armv7 F32
...
...
dnn/src/cuda/batch_conv_bias/algo.h
浏览文件 @
f7b2bdae
...
...
@@ -26,6 +26,7 @@ protected:
~
AlgoBase
()
=
default
;
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
CUDA
;
}
struct
SizeArgs
{
BatchConvBiasForwardImpl
*
opr
;
TensorLayout
src_layout
,
filter_layout
,
bias_layout
,
z_layout
,
...
...
dnn/src/cuda/batched_matrix_mul/algo.h
浏览文件 @
f7b2bdae
...
...
@@ -28,6 +28,7 @@ protected:
~
AlgoBase
()
=
default
;
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
CUDA
;
}
struct
SizeArgs
{
BatchedMatrixMulForwardImpl
*
opr
;
TensorLayout
layout_a
,
layout_b
,
layout_c
;
...
...
dnn/src/cuda/conv_bias/algo.h
浏览文件 @
f7b2bdae
...
...
@@ -38,6 +38,7 @@ protected:
~
AlgoBase
()
=
default
;
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
CUDA
;
}
struct
SizeArgs
:
public
conv_bias
::
BiasForwardSizeArgs
{
ConvBiasForwardImpl
*
opr
;
...
...
dnn/src/cuda/convolution/backward_data/algo.h
浏览文件 @
f7b2bdae
...
...
@@ -28,6 +28,7 @@ class ConvolutionBackwardDataImpl::AlgoBase: public Algorithm {
~
AlgoBase
()
=
default
;
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
CUDA
;
}
struct
SizeArgs
{
HandleImpl
*
handle
;
CanonizedFilterMeta
filter_meta
;
...
...
dnn/src/cuda/convolution/backward_filter/algo.h
浏览文件 @
f7b2bdae
...
...
@@ -28,6 +28,7 @@ class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm {
~
AlgoBase
()
=
default
;
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
CUDA
;
}
struct
SizeArgs
{
HandleImpl
*
handle
;
const
TensorLayout
*
src_layout
,
*
diff_layout
,
*
grad_layout
;
...
...
dnn/src/cuda/convolution3d/backward_data/algo.h
浏览文件 @
f7b2bdae
...
...
@@ -28,6 +28,7 @@ class Convolution3DBackwardDataImpl::AlgoBase: public Algorithm {
~
AlgoBase
()
=
default
;
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
CUDA
;
}
struct
SizeArgs
{
HandleImpl
*
handle
;
CanonizedFilterMeta
filter_meta
;
...
...
dnn/src/cuda/convolution3d/backward_filter/algo.h
浏览文件 @
f7b2bdae
...
...
@@ -22,6 +22,7 @@ class Convolution3DBackwardFilterImpl::AlgoBase: public Algorithm {
~
AlgoBase
()
=
default
;
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
CUDA
;
}
struct
SizeArgs
{
HandleImpl
*
handle
;
const
TensorLayout
*
src_layout
,
*
diff_layout
;
...
...
@@ -128,8 +129,8 @@ class Convolution3DBackwardFilterImpl::AlgoInplaceMatmul final: public AlgoBase
const
char
*
name
()
const
override
{
return
"INPLACE_MATMUL"
;
}
bool
is_reproducible
()
const
override
{
return
false
;
bool
is_reproducible
()
const
override
{
return
false
;
}
};
...
...
dnn/src/cuda/convolution3d/forward/algo.h
浏览文件 @
f7b2bdae
...
...
@@ -34,6 +34,7 @@ class Convolution3DForwardImpl::AlgoBase: public Algorithm {
~
AlgoBase
()
=
default
;
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
CUDA
;
}
struct
SizeArgs
:
public
convolution3d
::
ForwardSizeArgs
{
Convolution3DForwardImpl
*
opr
;
...
...
@@ -42,11 +43,11 @@ class Convolution3DForwardImpl::AlgoBase: public Algorithm {
desc
.
set
(
*
src_layout
,
filter_meta
,
*
dst_layout
,
opr
->
param
());
}
SizeArgs
(
Convolution3DForwardImpl
*
opr
,
const
TensorLayout
&
src
,
const
TensorLayout
&
src
,
const
TensorLayout
&
filter
,
const
TensorLayout
&
dst
);
SizeArgs
(
Convolution3DForwardImpl
*
opr
,
const
TensorLayout
&
src
,
const
TensorLayout
&
src
,
const
CanonizedFilterMeta
&
filter
,
const
TensorLayout
&
dst
);
};
...
...
dnn/src/cuda/deformable_conv/bwd_data/algo.h
浏览文件 @
f7b2bdae
...
...
@@ -26,6 +26,7 @@ protected:
~
AlgoBase
()
=
default
;
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
CUDA
;
}
struct
SizeArgs
{
DeformableConvBackwardDataImpl
*
opr
;
HandleImpl
*
handle
;
...
...
dnn/src/cuda/deformable_conv/bwd_flt/algo.h
浏览文件 @
f7b2bdae
...
...
@@ -26,6 +26,7 @@ protected:
~
AlgoBase
()
=
default
;
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
CUDA
;
}
struct
SizeArgs
{
DeformableConvBackwardFilterImpl
*
opr
;
HandleImpl
*
handle
;
...
...
dnn/src/cuda/deformable_conv/fwd/algo.h
浏览文件 @
f7b2bdae
...
...
@@ -24,6 +24,7 @@ protected:
~
AlgoBase
()
=
default
;
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
CUDA
;
}
struct
SizeArgs
{
DeformableConvForwardImpl
*
opr
;
HandleImpl
*
handle
;
...
...
dnn/src/cuda/local_share/backward_data/algo.h
浏览文件 @
f7b2bdae
...
...
@@ -25,6 +25,7 @@ protected:
~
AlgoBase
()
=
default
;
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
CUDA
;
}
struct
SizeArgs
{
LocalShareBackwardDataImpl
*
opr
;
TensorLayout
filter_layout
,
diff_layout
,
grad_layout
;
...
...
dnn/src/cuda/local_share/backward_filter/algo.h
浏览文件 @
f7b2bdae
...
...
@@ -25,6 +25,7 @@ protected:
~
AlgoBase
()
=
default
;
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
CUDA
;
}
struct
SizeArgs
{
LocalShareBackwardFilterImpl
*
opr
;
TensorLayout
src_layout
,
diff_layout
,
grad_layout
;
...
...
dnn/src/cuda/local_share/forward/algo.h
浏览文件 @
f7b2bdae
...
...
@@ -25,6 +25,7 @@ protected:
~
AlgoBase
()
=
default
;
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
CUDA
;
}
struct
SizeArgs
{
LocalShareForwardImpl
*
opr
;
TensorLayout
src_layout
,
filter_layout
,
dst_layout
;
...
...
dnn/src/cuda/matrix_mul/algos.h
浏览文件 @
f7b2bdae
...
...
@@ -32,13 +32,14 @@ protected:
~
AlgoBase
()
=
default
;
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
CUDA
;
}
struct
SizeArgs
{
MatrixMulForwardImpl
*
opr
;
TensorLayout
layout_a
,
layout_b
,
layout_c
;
std
::
string
to_string
()
const
;
SizeArgs
(
MatrixMulForwardImpl
*
opr
,
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
);
SizeArgs
(
MatrixMulForwardImpl
*
opr
,
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
);
bool
can_be_treated_as_int8x8x32
()
const
{
return
layout_a
.
dtype
.
enumv
()
==
layout_b
.
dtype
.
enumv
()
&&
...
...
dnn/src/fallback/conv_bias/opr_impl.h
浏览文件 @
f7b2bdae
...
...
@@ -213,6 +213,9 @@ public:
class
AlgoBase
:
public
Algorithm
{
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
FALLBACK
;
}
virtual
~
AlgoBase
()
=
default
;
virtual
bool
usable
(
const
NCBKernSizeParam
&
param
,
...
...
dnn/src/fallback/convolution/algos.h
浏览文件 @
f7b2bdae
...
...
@@ -141,8 +141,6 @@ public:
return
get_kimpl
(
m_algorithm
,
param
);
}
void
*
type
()
const
override
{
return
sm_fallback_conv_algo_type
;
}
//! select matmul to the highest preference
bool
is_preferred
(
const
NCBKernSizeParam
&
param
)
const
override
;
...
...
@@ -168,7 +166,6 @@ public:
const
NCBKernSizeParam
&
param
)
const
override
;
ncb_kern_t
dispatch_kern
(
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_fallback_deconv_algo_type
;
}
};
class
ConvolutionBackwardDataImpl
::
AlgoMatrixMul
final
:
public
AlgoBase
{
...
...
@@ -181,7 +178,6 @@ public:
const
NCBKernSizeParam
&
param
)
const
override
;
ncb_kern_t
dispatch_kern
(
ConvolutionBackwardDataImpl
*
,
const
NCBKernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_fallback_deconv_algo_type
;
}
};
}
// namespace fallback
...
...
dnn/src/fallback/convolution/opr_impl.cpp
浏览文件 @
f7b2bdae
...
...
@@ -37,8 +37,6 @@ class NaiveConvolutionBackwardData final
const
char
*
name
()
const
override
{
return
"NCBD"
;
}
};
NaiveConvolutionBackwardData
naive_conv_backward_data
;
uint8_t
fallback_deconv_algo_type_storage
;
uint8_t
fallback_conv_algo_type_storage
;
template
<
typename
T
>
void
incr_ptr
(
T
*&
dst
,
ptrdiff_t
delta
)
{
...
...
@@ -69,9 +67,6 @@ public:
SmallVector
<
AlgoBase
*>
all_algos
;
};
void
*
const
ConvolutionImpl
::
sm_fallback_conv_algo_type
=
&
fallback_conv_algo_type_storage
;
SmallVector
<
ConvolutionImpl
::
AlgoBase
*>
ConvolutionImpl
::
algo_pack
()
{
static
AlgoPack
sl_algo_pack
;
return
sl_algo_pack
.
all_algos
;
...
...
@@ -412,9 +407,6 @@ ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type() const {
/* ===================== ConvolutionBackwardData ===================== */
void
*
const
ConvolutionBackwardDataImpl
::
sm_fallback_deconv_algo_type
=
&
fallback_deconv_algo_type_storage
;
struct
ConvolutionBackwardDataImpl
::
AlgoPack
{
AlgoDirect
direct
;
AlgoMatrixMul
matmul
;
...
...
@@ -630,7 +622,7 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic_with_ncb(
size_t
ConvolutionBackwardDataImpl
::
ncb_1g_get_workspace
(
Algorithm
*
algo
,
const
NCBKernSizeParam
&
param
)
{
megdnn_assert
(
param
.
filter_meta
.
group
==
1
);
if
(
algo
->
type
()
==
sm_fallback_deconv_algo_type
)
{
if
(
algo
->
handle_type
()
==
Handle
::
HandleType
::
FALLBACK
)
{
return
static_cast
<
AlgoBase
*>
(
algo
)
->
get_workspace
(
this
,
param
);
}
megdnn_assert
(
algo
==
&
naive_conv_backward_data
);
...
...
@@ -642,7 +634,7 @@ ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(
Algorithm
*
algo
,
const
NCBKernSizeParam
&
param
)
{
megdnn_assert
(
param
.
filter_meta
.
group
==
1
);
if
(
algo
->
type
()
==
sm_fallback_deconv_algo_type
)
{
if
(
algo
->
handle_type
()
==
Handle
::
HandleType
::
FALLBACK
)
{
return
static_cast
<
AlgoBase
*>
(
algo
)
->
dispatch_kern
(
this
,
param
);
}
...
...
dnn/src/fallback/convolution/opr_impl.h
浏览文件 @
f7b2bdae
...
...
@@ -177,8 +177,6 @@ public:
}
};
static
void
*
const
sm_fallback_conv_algo_type
;
/**
* \brief Kernel run time id, This information is used for getting the
* work data
...
...
@@ -197,6 +195,9 @@ public:
class
AlgoBase
:
public
Algorithm
{
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
FALLBACK
;
}
virtual
~
AlgoBase
()
=
default
;
virtual
bool
usable
(
const
NCBKernSizeParam
&
param
,
AlgoSelectionStrategy
)
const
=
0
;
...
...
@@ -407,13 +408,14 @@ protected:
const
NCBKernSizeParam
&
param
,
size_t
workspace_limit_in_bytes
,
bool
reproducible
=
false
);
static
void
*
const
sm_fallback_deconv_algo_type
;
class
AlgoBase
:
public
Algorithm
{
protected:
~
AlgoBase
()
=
default
;
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
FALLBACK
;
}
virtual
bool
usable
(
ConvolutionBackwardDataImpl
*
opr
,
const
NCBKernSizeParam
&
param
)
const
=
0
;
virtual
size_t
get_workspace
(
ConvolutionBackwardDataImpl
*
opr
,
...
...
dnn/src/fallback/matrix_mul/opr_impl.h
浏览文件 @
f7b2bdae
...
...
@@ -103,6 +103,7 @@ public:
}
public:
AlgoBase
()
{
m_handle_type
=
Handle
::
HandleType
::
FALLBACK
;
}
enum
class
AlgoSet
:
uint32_t
{
ALGO_TYPE_GEMM
=
0
,
ALGO_TYPE_GEMV
=
1
,
...
...
dnn/src/rocm/batched_matrix_mul/opr_impl.cpp
浏览文件 @
f7b2bdae
...
...
@@ -6,10 +6,11 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "hcc_detail/hcc_defs_prologue.h"
#include "./opr_impl.h"
#include "hcc_detail/hcc_defs_prologue.h"
#include "src/common/utils.cuh"
#include "src/rocm/handle.h"
...
...
@@ -92,8 +93,8 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
static_cast
<
const
rocblas_half
*>
(
A
.
raw_ptr
),
A
.
layout
.
stride
[
1
],
A
.
layout
.
stride
[
0
],
reinterpret_cast
<
const
rocblas_half
*>
(
zero_half
),
static_cast
<
rocblas_half
*>
(
C
.
raw_ptr
),
C
.
layout
.
stride
[
1
],
C
.
layout
.
stride
[
0
],
batch
));
static_cast
<
rocblas_half
*>
(
C
.
raw_ptr
),
C
.
layout
.
stride
[
1
],
C
.
layout
.
stride
[
0
],
batch
));
};
#endif
...
...
dnn/src/rocm/convolution/backward_data/algo.h
浏览文件 @
f7b2bdae
...
...
@@ -25,6 +25,7 @@ protected:
~
AlgoBase
()
=
default
;
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
ROCM
;
}
struct
SizeArgs
{
HandleImpl
*
handle
;
CanonizedFilterMeta
filter_meta
;
...
...
dnn/src/rocm/convolution/backward_filter/algo.h
浏览文件 @
f7b2bdae
...
...
@@ -26,6 +26,7 @@ protected:
~
AlgoBase
()
=
default
;
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
ROCM
;
}
struct
SizeArgs
{
HandleImpl
*
handle
;
const
TensorLayout
*
src_layout
,
*
diff_layout
;
...
...
dnn/src/rocm/convolution/forward/algo.h
浏览文件 @
f7b2bdae
...
...
@@ -32,6 +32,7 @@ protected:
~
AlgoBase
()
=
default
;
public:
AlgoBase
()
:
Algorithm
()
{
m_handle_type
=
Handle
::
HandleType
::
ROCM
;
}
struct
SizeArgs
:
public
convolution
::
ForwardSizeArgs
{
ConvolutionForwardImpl
*
opr
;
...
...
dnn/src/x86/conv_bias/f32/algos.h
浏览文件 @
f7b2bdae
...
...
@@ -47,8 +47,6 @@ public:
return
get_kimpls
(
param
);
}
void
*
type
()
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
FLOAT32
,
AlgoCategory
::
DIRECT
};
}
...
...
@@ -84,8 +82,6 @@ public:
return
get_kimpls
(
param
);
}
void
*
type
()
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
FLOAT32
,
AlgoCategory
::
DIRECT
};
}
...
...
@@ -103,7 +99,6 @@ public:
}
return
m_name
.
c_str
();
}
void
*
type
()
const
override
;
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
(
AlgoDataType
::
FLOAT32
);
};
...
...
@@ -119,7 +114,6 @@ public:
}
return
m_name
.
c_str
();
}
void
*
type
()
const
override
;
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
(
AlgoDataType
::
FLOAT32
);
};
...
...
@@ -161,7 +155,6 @@ public:
};
return
{{
kern
,
{
1
_z
,
1
_z
,
1
_z
}}};
}
void
*
type
()
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
FLOAT32
,
AlgoCategory
::
DIRECT
};
...
...
dnn/src/x86/conv_bias/int8/algos.h
浏览文件 @
f7b2bdae
...
...
@@ -32,7 +32,6 @@ public:
const
NCBKernSizeParam
&
param
)
const
override
{
return
get_kimpls
(
param
);
}
void
*
type
()
const
override
;
bool
is_preferred
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
...
...
@@ -57,7 +56,6 @@ public:
const
NCBKernSizeParam
&
param
)
const
override
{
return
get_kimpls
(
param
);
}
void
*
type
()
const
override
;
bool
is_preferred
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
...
...
@@ -82,7 +80,6 @@ public:
const
NCBKernSizeParam
&
param
)
const
override
{
return
get_kimpls
(
param
);
}
void
*
type
()
const
override
;
bool
is_preferred
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
...
...
@@ -107,7 +104,6 @@ public:
const
NCBKernSizeParam
&
param
)
const
override
{
return
get_kimpls
(
param
);
}
void
*
type
()
const
override
;
bool
is_preferred
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
...
...
@@ -148,7 +144,6 @@ public:
};
return
{{
kern
,
{
group
,
n
,
1
_z
}}};
}
void
*
type
()
const
override
;
bool
is_preferred
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
...
...
@@ -179,8 +174,6 @@ public:
//! select matmul to the highest preference
bool
is_preferred
(
const
NCBKernSizeParam
&
param
)
const
override
;
void
*
type
()
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QINT8X8X32
,
AlgoCategory
::
IM2COL
};
}
...
...
dnn/src/x86/conv_bias/opr_impl.cpp
浏览文件 @
f7b2bdae
...
...
@@ -22,54 +22,14 @@
using
namespace
megdnn
;
using
namespace
x86
;
namespace
{
uint8_t
x86_algo_type_storage
;
void
*
x86_algo_type
=
&
x86_algo_type_storage
;
}
// anonymous namespace
#if MEGDNN_X86_WITH_MKL_DNN
void
*
ConvBiasImpl
::
AlgoMkldnnQint8
::
type
()
const
{
return
x86_algo_type
;
}
void
*
ConvBiasImpl
::
AlgoMkldnnMatmulQint8
::
type
()
const
{
return
x86_algo_type
;
}
void
*
ConvBiasImpl
::
AlgoMkldnnConv
::
type
()
const
{
return
x86_algo_type
;
}
#endif
void
*
ConvBiasImpl
::
AlgoDirect
::
type
()
const
{
return
x86_algo_type
;
}
void
*
ConvBiasImpl
::
AlgoDirectStride2
::
type
()
const
{
return
x86_algo_type
;
}
void
*
ConvBiasImpl
::
AlgoDirectAvx2Stride1Int8
::
type
()
const
{
return
x86_algo_type
;
bool
is_fallback_or_naive
(
const
detail
::
Algorithm
*
algo
)
{
return
algo
->
handle_type
()
==
Handle
::
HandleType
::
NAIVE
||
algo
->
handle_type
()
==
Handle
::
HandleType
::
FALLBACK
;
}
void
*
ConvBiasImpl
::
AlgoFP32WinogradF63_8x8
::
type
()
const
{
return
x86_algo_type
;
}
void
*
ConvBiasImpl
::
AlgoFP32WinogradF23_8x8
::
type
()
const
{
return
x86_algo_type
;
}
void
*
ConvBiasImpl
::
AlgoAVX2DirectConvStride2
::
type
()
const
{
return
x86_algo_type
;
}
void
*
ConvBiasImpl
::
AlgoChanWiseAvx2Stride1Qint8
::
type
()
const
{
return
x86_algo_type
;
}
void
*
ConvBiasImpl
::
AlgoChanWiseAvx2Stride2Qint8
::
type
()
const
{
return
x86_algo_type
;
}
}
// anonymous namespace
class
ConvBiasImpl
::
AlgoPack
:
NonCopyableObj
{
AlgoDirect
stride1_direct
;
...
...
@@ -88,8 +48,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
public:
AlgoPack
()
{
//! FIXME: preference to use mkldnn algo on VNNI devices
//! But now mkldnn algo preference issue with NCHW->NHWC->NCHW
//! FIXME: preference to use mkldnn algo on VNNI devices
//! But now mkldnn algo preference issue with NCHW->NHWC->NCHW
#if MEGDNN_X86_WITH_MKL_DNN
//! Create the mkldnn algo
all_algos
.
emplace_back
(
&
mkldnn_conv_fp32
);
...
...
@@ -108,7 +68,7 @@ public:
auto
&&
matmul_algos
=
static_cast
<
MatrixMulImpl
*>
(
matmul_opr
)
->
algo_pack
();
for
(
auto
&&
algo
:
matmul_algos
)
{
if
(
algo
->
type
()
==
nullptr
)
if
(
is_fallback_or_naive
(
algo
)
)
continue
;
for
(
uint32_t
tile_size
:
{
8
,
16
,
24
})
{
refhold
.
emplace_back
(
new
AlgoFP32WinogradF63_8x8
(
...
...
@@ -126,7 +86,7 @@ public:
SmallVector
<
AlgoBase
*>
winograd_algos
;
};
SmallVector
<
ConvBiasImpl
::
AlgoBase
*>
ConvBiasImpl
::
algo_pack
()
{
SmallVector
<
fallback
::
ConvBiasImpl
::
AlgoBase
*>
ConvBiasImpl
::
algo_pack
()
{
static
AlgoPack
sl_algo_pack
;
auto
&&
algos
=
fallback
::
ConvBiasImpl
::
algo_pack
();
algos
.
insert
(
algos
.
begin
(),
sl_algo_pack
.
all_algos
.
begin
(),
...
...
@@ -176,8 +136,8 @@ bool ConvBiasImpl::is_matmul_quantized_prefer(
!
chanwise_avx2_stride2_qint8_usable_preferred
(
param
));
}
SmallVector
<
AlgoCategory
>
ConvBiasImpl
::
suggest_algo_category_order
(
const
NCBKernSizeParam
&
param
)
const
{
SmallVector
<
AlgoCategory
>
ConvBiasImpl
::
suggest_algo_category_order
(
const
NCBKernSizeParam
&
param
)
const
{
auto
IC
=
param
.
filter_meta
.
icpg
;
auto
OC
=
param
.
filter_meta
.
ocpg
;
auto
FH
=
param
.
filter_meta
.
spatial
[
0
];
...
...
dnn/src/x86/conv_bias/opr_impl.h
浏览文件 @
f7b2bdae
...
...
@@ -20,10 +20,15 @@ namespace x86 {
class
ConvBiasImpl
:
public
fallback
::
ConvBiasImpl
{
public:
using
fallback
::
ConvBiasImpl
::
ConvBiasImpl
;
using
FallbackConvBiasImpl
=
fallback
::
ConvBiasImpl
;
class
AlgoBase
:
public
fallback
::
ConvBiasImpl
::
AlgoBase
{
public:
AlgoBase
()
:
fallback
::
ConvBiasImpl
::
AlgoBase
()
{
m_handle_type
=
Handle
::
HandleType
::
X86
;
}
};
bool
is_thread_safe
()
const
override
{
return
true
;
}
SmallVector
<
AlgoBase
*>
algo_pack
()
override
;
SmallVector
<
fallback
::
ConvBiasImpl
::
AlgoBase
*>
algo_pack
()
override
;
SmallVector
<
AlgoCategory
>
suggest_algo_category_order
(
const
NCBKernSizeParam
&
param
)
const
override
;
...
...
dnn/src/x86/matrix_mul/algos.h
浏览文件 @
f7b2bdae
...
...
@@ -25,7 +25,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
{
return
0
;
}
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
4
,
AlgoDataType
::
FLOAT32
,
DEFAULT
)
};
...
...
@@ -38,7 +37,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
{
return
0
;
}
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
ONLY_PACKA
;
}
kern_naked_t
get_kern_naked
(
const
KernSizeParam
&
)
const
override
;
void
pack_A
(
const
KernParam
&
kern_param
,
void
*
out
,
size_t
index
,
...
...
@@ -60,7 +58,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -71,7 +68,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -86,7 +82,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -102,7 +97,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -114,7 +108,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
...
...
@@ -125,7 +118,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
8
,
8
,
4
,
AlgoDataType
::
FLOAT32
,
MK8
)
};
...
...
@@ -138,7 +130,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
#endif
...
...
@@ -151,7 +142,6 @@ public:
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
{
return
0
;
}
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
,
AlgoDataType
::
QINT8X8X32
,
DEFAULT
)
};
...
...
dnn/src/x86/matrix_mul/opr_impl.cpp
浏览文件 @
f7b2bdae
...
...
@@ -16,12 +16,6 @@
using
namespace
megdnn
;
using
namespace
x86
;
namespace
{
uint8_t
x86_algo_type_storage
;
}
// anonymous namespace
void
*
const
MatrixMulImpl
::
sm_x86_algo_type
=
&
x86_algo_type_storage
;
class
MatrixMulImpl
::
AlgoPack
:
NonCopyableObj
{
AlgoF32Blas
f32blas
;
...
...
@@ -62,10 +56,10 @@ public:
all_algos
.
emplace_back
(
&
f32mkl_packa
);
#endif
}
SmallVector
<
AlgoBase
*>
all_algos
;
SmallVector
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
all_algos
;
};
SmallVector
<
MatrixMulImpl
::
AlgoBase
*>
MatrixMulImpl
::
algo_pack
()
{
SmallVector
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
MatrixMulImpl
::
algo_pack
()
{
static
AlgoPack
s_algo_pack
;
auto
&&
algos
=
fallback
::
MatrixMulImpl
::
algo_pack
();
algos
.
insert
(
algos
.
begin
(),
s_algo_pack
.
all_algos
.
begin
(),
...
...
dnn/src/x86/matrix_mul/opr_impl.h
浏览文件 @
f7b2bdae
...
...
@@ -33,13 +33,18 @@ namespace x86 {
class
MatrixMulImpl
:
public
fallback
::
MatrixMulImpl
{
public:
using
fallback
::
MatrixMulImpl
::
MatrixMulImpl
;
class
AlgoBase
:
public
fallback
::
MatrixMulImpl
::
AlgoBase
{
public:
AlgoBase
()
:
fallback
::
MatrixMulImpl
::
AlgoBase
()
{
m_handle_type
=
Handle
::
HandleType
::
X86
;
}
};
bool
is_thread_safe
()
const
override
{
return
true
;
}
SmallVector
<
AlgoBase
*>
algo_pack
()
override
;
SmallVector
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
algo_pack
()
override
;
protected:
static
void
*
const
sm_x86_algo_type
;
class
AlgoF32Blas
;
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
class
AlgoF32MKLPackA
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录