Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a6bc250d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a6bc250d
编写于
5月 22, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/common): add matmul impl for naive with matrix format mk4_dot
GitOrigin-RevId: 7c6fbdfa973413b1b8d2f2fb0d18f8bb5ee7f243
上级
bb872965
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
56 addition
and
17 deletion
+56
-17
dnn/scripts/opr_param_defs.py
dnn/scripts/opr_param_defs.py
+4
-1
dnn/src/common/matrix_mul.cpp
dnn/src/common/matrix_mul.cpp
+2
-0
dnn/src/naive/matrix_mul/matrix_mul_helper.h
dnn/src/naive/matrix_mul/matrix_mul_helper.h
+29
-0
dnn/src/naive/matrix_mul/opr_impl.cpp
dnn/src/naive/matrix_mul/opr_impl.cpp
+21
-16
未找到文件。
dnn/scripts/opr_param_defs.py
浏览文件 @
a6bc250d
...
...
@@ -433,7 +433,10 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'
),
Doc
(
'MK8'
,
'Split 8 from M and K, better for neon compute:'
'(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the '
'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'
))
'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'
),
Doc
(
'MK4_DOT'
,
'Split 4 from M and K, better for neon dotprod:'
'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the '
'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))'
))
)
(
pdef
(
'Winograd'
,
'winograd param used in convbias'
).
...
...
dnn/src/common/matrix_mul.cpp
浏览文件 @
a6bc250d
...
...
@@ -186,6 +186,8 @@ size_t MatrixMulForward::pack_size(const Param::Format format) {
return
1
;
case
Param
::
Format
::
MK4
:
return
4
;
case
Param
::
Format
::
MK4_DOT
:
return
4
;
case
Param
::
Format
::
MK8
:
return
8
;
default:
...
...
dnn/src/naive/matrix_mul/matrix_mul_helper.h
浏览文件 @
a6bc250d
...
...
@@ -82,6 +82,35 @@ void run_matrix_mul_mk4_tpl(const itype* A, const itype* B, otype* C, size_t M,
}
}
template
<
typename
itype
,
typename
otype
,
bool
transA
,
bool
transB
,
typename
comp_type
=
otype
>
void
run_matrix_mul_mk4_dot_tpl
(
const
itype
*
A
,
const
itype
*
B
,
otype
*
C
,
size_t
M
,
size_t
N
,
size_t
K
,
size_t
LDA
,
size_t
LDB
,
size_t
LDC
,
const
DType
&
A_type
,
const
DType
&
B_type
)
{
Getter
<
itype
,
comp_type
>
getterA
(
A_type
),
getterB
(
B_type
);
for
(
size_t
m
=
0
;
m
<
M
;
++
m
)
{
for
(
size_t
n
=
0
;
n
<
N
;
++
n
)
{
comp_type
res
[
4
]
=
{
comp_type
(
0
)};
for
(
size_t
k
=
0
;
k
<
K
;
++
k
)
{
for
(
size_t
i
=
0
;
i
<
4
;
i
++
)
{
comp_type
av
,
bv
;
for
(
size_t
j
=
0
;
j
<
4
;
j
++
)
{
av
=
transA
?
getterA
(
A
[
k
*
LDA
+
m
*
16
+
4
*
i
+
j
])
:
getterA
(
A
[
m
*
LDA
+
k
*
16
+
4
*
i
+
j
]),
bv
=
transB
?
getterB
(
B
[
n
*
LDB
+
k
*
4
+
j
])
:
getterB
(
B
[
k
*
LDB
+
n
*
4
+
j
]);
res
[
i
]
+=
av
*
bv
;
}
}
}
for
(
size_t
i
=
0
;
i
<
4
;
i
++
)
{
C
[
m
*
LDC
+
n
*
4
+
i
]
=
res
[
i
];
}
}
}
}
template
<
typename
itype
,
typename
otype
,
bool
transA
,
bool
transB
,
typename
comp_type
=
otype
>
void
run_matrix_mul_mk8_tpl
(
const
itype
*
A
,
const
itype
*
B
,
otype
*
C
,
size_t
M
,
...
...
dnn/src/naive/matrix_mul/opr_impl.cpp
浏览文件 @
a6bc250d
...
...
@@ -49,6 +49,11 @@ void dispatch_ta_tb(_megdnn_tensor_in A, _megdnn_tensor_in B,
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
} else if (param.format == param::MatrixMul::Format::MK4_DOT) { \
return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \
A.layout.dtype, B.layout.dtype); \
} else if (param.format == param::MatrixMul::Format::MK8) { \
return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \
A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录