Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
d063d577
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
d063d577
编写于
8月 11, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(functional): use fma to reduce elemwise but disable subgraph compilation
GitOrigin-RevId: c75a6e1a09b8e727a48e3b5eaabc6926aa046a46
上级
2a063f8e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
38 addition
and
22 deletion
+38
-22
imperative/python/megengine/core/tensor/utils.py
imperative/python/megengine/core/tensor/utils.py
+18
-2
imperative/python/megengine/functional/math.py
imperative/python/megengine/functional/math.py
+2
-2
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+18
-18
未找到文件。
imperative/python/megengine/core/tensor/utils.py
浏览文件 @
d063d577
...
@@ -242,16 +242,32 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
...
@@ -242,16 +242,32 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
"-"
:
lambda
:
builtin
.
Elemwise
(
mode
=
"negate"
),
"-"
:
lambda
:
builtin
.
Elemwise
(
mode
=
"negate"
),
}
}
ternary_ops
=
{
"fma3"
:
lambda
:
builtin
.
Elemwise
(
mode
=
"FUSE_MUL_ADD3"
),
}
quaternary_ops
=
{
"fma4"
:
lambda
:
builtin
.
Elemwise
(
mode
=
"FUSE_MUL_ADD4"
)}
def
decorator
(
func
):
def
decorator
(
func
):
builder
=
_SubgraphBuilder
(
name
)
builder
=
_SubgraphBuilder
(
name
)
def
apply_expr
(
op
,
*
args
):
def
apply_expr
(
op
,
*
args
,
nr_out
=
None
):
if
isinstance
(
op
,
str
):
if
isinstance
(
op
,
str
):
if
len
(
args
)
==
2
:
if
len
(
args
)
==
2
:
op
=
binary_ops
[
op
]()
op
=
binary_ops
[
op
]()
elif
len
(
args
)
==
1
:
elif
len
(
args
)
==
1
:
op
=
unary_ops
[
op
]()
op
=
unary_ops
[
op
]()
return
builder
.
apply
(
op
,
args
,
1
)[
0
]
elif
len
(
args
)
==
3
:
op
=
ternary_ops
[
op
]()
elif
len
(
args
)
==
4
:
op
=
quaternary_ops
[
op
]()
results
=
builder
.
apply
(
op
,
args
,
1
if
nr_out
is
None
else
nr_out
)
if
nr_out
is
None
:
assert
len
(
results
)
==
1
return
results
[
0
]
else
:
assert
len
(
results
)
==
nr_out
return
results
def
apply_const
(
value
,
dtype
=
dtype
,
device
=
device
):
def
apply_const
(
value
,
dtype
=
dtype
,
device
=
device
):
return
builder
.
apply_const
(
value
,
dtype
,
device
)
return
builder
.
apply_const
(
value
,
dtype
,
device
)
...
...
imperative/python/megengine/functional/math.py
浏览文件 @
d063d577
...
@@ -784,7 +784,7 @@ class _Hashable:
...
@@ -784,7 +784,7 @@ class _Hashable:
def
_get_extentedMatrixMulOp
(
def
_get_extentedMatrixMulOp
(
device
,
dtype
,
dim1
,
dim2
,
transpose_a
,
transpose_b
,
compute_mode
,
format
,
strategy
,
device
,
dtype
,
dim1
,
dim2
,
transpose_a
,
transpose_b
,
compute_mode
,
format
,
strategy
,
):
):
@
subgraph
(
"extentedMatrixMulOp"
,
dtype
,
device
,
2
,
gopt_level
=
3
)
@
subgraph
(
"extentedMatrixMulOp"
,
dtype
,
device
,
2
,
gopt_level
=
2
)
def
extentedMatrixMulOp
(
inputs
,
f
,
c
):
def
extentedMatrixMulOp
(
inputs
,
f
,
c
):
assert
len
(
inputs
)
==
2
assert
len
(
inputs
)
==
2
inp1
,
inp2
=
inputs
inp1
,
inp2
=
inputs
...
@@ -884,7 +884,7 @@ def _get_extentedMatrixMulOp(
...
@@ -884,7 +884,7 @@ def _get_extentedMatrixMulOp(
def
_get_extentedBatchedMatrixMulOp
(
def
_get_extentedBatchedMatrixMulOp
(
device
,
dtype
,
dim1
,
dim2
,
transpose_a
,
transpose_b
,
compute_mode
,
format
,
strategy
,
device
,
dtype
,
dim1
,
dim2
,
transpose_a
,
transpose_b
,
compute_mode
,
format
,
strategy
,
):
):
@
subgraph
(
"extentedBatchedMatrixMulOp"
,
dtype
,
device
,
2
,
gopt_level
=
3
)
@
subgraph
(
"extentedBatchedMatrixMulOp"
,
dtype
,
device
,
2
,
gopt_level
=
2
)
def
extentedBatchedMatrixMulOp
(
inputs
,
f
,
c
):
def
extentedBatchedMatrixMulOp
(
inputs
,
f
,
c
):
assert
len
(
inputs
)
==
2
assert
len
(
inputs
)
==
2
inp1
,
inp2
=
inputs
inp1
,
inp2
=
inputs
...
...
imperative/python/megengine/functional/nn.py
浏览文件 @
d063d577
...
@@ -1174,7 +1174,7 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
...
@@ -1174,7 +1174,7 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
reduce_size_f
=
f
(
TypeCvt
(
dtype
=
dtype
),
reduce_size
)
reduce_size_f
=
f
(
TypeCvt
(
dtype
=
dtype
),
reduce_size
)
return
(
reduce_shape
,
reduce_size_f
,
channel_x1s
,
channel_x2s
),
(
False
,
False
,
True
,
True
)
return
(
reduce_shape
,
reduce_size_f
,
channel_x1s
,
channel_x2s
),
(
False
,
False
,
True
,
True
)
@
subgraph
(
"SyncBnStage1"
,
dtype
,
device
,
7
,
gopt_level
=
3
)
@
subgraph
(
"SyncBnStage1"
,
dtype
,
device
,
7
)
def
syncbn_stage1
(
inputs
,
f
,
c
):
def
syncbn_stage1
(
inputs
,
f
,
c
):
input
,
reduce_size
,
channel_x1s
,
channel_x2s
,
eps
=
inputs
[
0
:
5
]
input
,
reduce_size
,
channel_x1s
,
channel_x2s
,
eps
=
inputs
[
0
:
5
]
weight
,
bias
=
inputs
[
5
:
7
]
weight
,
bias
=
inputs
[
5
:
7
]
...
@@ -1187,12 +1187,12 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
...
@@ -1187,12 +1187,12 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
inv_var_wt
=
f
(
"*"
,
invsqrt_channel_var
,
weight
)
inv_var_wt
=
f
(
"*"
,
invsqrt_channel_var
,
weight
)
neg_channel_mean
=
f
(
"-"
,
channel_mean
)
neg_channel_mean
=
f
(
"-"
,
channel_mean
)
outvar
=
\
outvar
=
\
f
(
"
+"
,
f
(
"*"
,
input
,
inv_var_wt
)
,
f
(
"
fma3"
,
input
,
inv_var_wt
,
f
(
"+"
,
f
(
"*"
,
neg_channel_mean
,
inv_var_wt
),
f
(
"+"
,
f
(
"*"
,
neg_channel_mean
,
inv_var_wt
),
bias
))
bias
))
return
(
outvar
,
channel_mean
,
channel_var
,
inv_var_wt
),
(
True
,
False
,
False
,
False
)
return
(
outvar
,
channel_mean
,
channel_var
,
inv_var_wt
),
(
True
,
False
,
False
,
False
)
@
subgraph
(
"SyncBnStage1Inference"
,
dtype
,
device
,
6
,
gopt_level
=
3
)
@
subgraph
(
"SyncBnStage1Inference"
,
dtype
,
device
,
6
)
def
syncbn_stage1_inference
(
inputs
,
f
,
c
):
def
syncbn_stage1_inference
(
inputs
,
f
,
c
):
input
,
channel_mean
,
channel_var
,
eps
=
inputs
[
0
:
4
]
input
,
channel_mean
,
channel_var
,
eps
=
inputs
[
0
:
4
]
weight
,
bias
=
inputs
[
4
:
6
]
weight
,
bias
=
inputs
[
4
:
6
]
...
@@ -1205,36 +1205,36 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
...
@@ -1205,36 +1205,36 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
bias
))
bias
))
return
(
outvar
,),
(
True
,)
return
(
outvar
,),
(
True
,)
@
subgraph
(
"SyncBnStage2"
,
dtype
,
device
,
7
,
gopt_level
=
3
)
@
subgraph
(
"SyncBnStage2"
,
dtype
,
device
,
7
)
def
syncbn_stage2
(
inputs
,
f
,
c
):
def
syncbn_stage2
(
inputs
,
f
,
c
):
running_mean
,
running_var
,
momentum
=
inputs
[
0
:
3
]
running_mean
,
running_var
,
momentum
=
inputs
[
0
:
3
]
reduce_size
,
channel_x1s
,
channel_x2s
,
channel_mean
=
inputs
[
3
:
7
]
reduce_size
,
channel_x1s
,
channel_x2s
,
channel_mean
=
inputs
[
3
:
7
]
running_mean
=
f
(
"*"
,
running_mean
,
momentum
)
c1_minus_momentum
=
f
(
"-"
,
c
(
1
),
momentum
)
running_mean
=
\
reduce_size_minus_c1
=
f
(
"-"
,
reduce_size
,
c
(
1
))
f
(
"+"
,
running_mean
,
running_mean
=
f
(
"fma4"
,
f
(
"*"
,
f
(
"-"
,
c
(
1
),
momentum
),
running_mean
,
momentum
,
channel_mean
))
c1_minus_momentum
,
channel_mean
,
)
channel_variance_unbiased
=
\
channel_variance_unbiased
=
\
f
(
"+"
,
f
(
"/"
,
f
(
"**"
,
channel_x1s
,
c
(
2
)),
f
(
"+"
,
f
(
"/"
,
f
(
"**"
,
channel_x1s
,
c
(
2
)),
f
(
"*"
,
f
(
"-"
,
reduce_size
),
f
(
"*"
,
f
(
"-"
,
reduce_size
),
f
(
"-"
,
reduce_size
,
c
(
1
))
)),
reduce_size_minus_c1
)),
f
(
"/"
,
channel_x2s
,
f
(
"/"
,
channel_x2s
,
f
(
"-"
,
reduce_size
,
c
(
1
))))
reduce_size_minus_c1
))
running_var
=
f
(
"*"
,
running_var
,
momentum
)
running_var
=
f
(
"fma4"
,
running_var
=
\
running_var
,
momentum
,
f
(
"+"
,
running_var
,
c1_minus_momentum
,
channel_variance_unbiased
f
(
"*"
,
f
(
"-"
,
c
(
1
),
momentum
),
)
channel_variance_unbiased
))
return
(
running_mean
,
running_var
),
(
True
,
True
)
return
(
running_mean
,
running_var
),
(
True
,
True
)
@
subgraph
(
"SyncBnConcatStats"
,
dtype
,
device
,
3
,
gopt_level
=
3
)
@
subgraph
(
"SyncBnConcatStats"
,
dtype
,
device
,
3
)
def
syncbn_concat_stats
(
inputs
,
f
,
c
):
def
syncbn_concat_stats
(
inputs
,
f
,
c
):
reduce_size
,
channel_x1s
,
channel_x2s
=
inputs
[
0
:
3
]
reduce_size
,
channel_x1s
,
channel_x2s
=
inputs
[
0
:
3
]
reduce_size
=
f
(
builtin
.
Broadcast
(),
reduce_size
,
c
([
1
]
*
ndim
,
dtype
=
"int32"
))
reduce_size
=
f
(
builtin
.
Broadcast
(),
reduce_size
,
c
([
1
]
*
ndim
,
dtype
=
"int32"
))
stats
=
f
(
builtin
.
Concat
(
axis
=
1
,
comp_node
=
device
),
reduce_size
,
channel_x1s
,
channel_x2s
)
stats
=
f
(
builtin
.
Concat
(
axis
=
1
,
comp_node
=
device
),
reduce_size
,
channel_x1s
,
channel_x2s
)
return
(
stats
,),
(
True
,)
return
(
stats
,),
(
True
,)
@
subgraph
(
"SyncBnSplitStats"
,
dtype
,
device
,
1
,
gopt_level
=
3
)
@
subgraph
(
"SyncBnSplitStats"
,
dtype
,
device
,
1
)
def
syncbn_split_stats
(
inputs
,
f
,
c
):
def
syncbn_split_stats
(
inputs
,
f
,
c
):
stats
=
inputs
[
0
]
stats
=
inputs
[
0
]
c_1
=
c
(
1
,
dtype
=
"int32"
)
c_1
=
c
(
1
,
dtype
=
"int32"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录