Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a226f02e
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a226f02e
编写于
9月 02, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/imperative): fix syncbn in symbolic mode
GitOrigin-RevId: a9794318a7e28aa262d6047b90e39f2904c35d8c
上级
34333593
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
20 addition
and
20 deletion
+20
-20
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+20
-15
imperative/python/test/unit/module/test_batchnorm.py
imperative/python/test/unit/module/test_batchnorm.py
+0
-5
未找到文件。
imperative/python/megengine/functional/nn.py
浏览文件 @
a226f02e
...
@@ -22,7 +22,7 @@ from .debug_param import get_conv_execution_strategy
...
@@ -22,7 +22,7 @@ from .debug_param import get_conv_execution_strategy
from
.distributed
import
all_reduce_sum
from
.distributed
import
all_reduce_sum
from
.elemwise
import
exp
,
floor
,
log
,
log1p
,
maximum
,
minimum
,
relu
from
.elemwise
import
exp
,
floor
,
log
,
log1p
,
maximum
,
minimum
,
relu
from
.math
import
argsort
,
max
,
sum
from
.math
import
argsort
,
max
,
sum
from
.tensor
import
add_axis
,
broadcast
,
concat
,
full
,
remove_axis
,
reshape
from
.tensor
import
add_axis
,
broadcast
,
concat
,
remove_axis
,
reshape
from
.types
import
_pair
,
_pair_nonzero
from
.types
import
_pair
,
_pair_nonzero
__all__
=
[
__all__
=
[
...
@@ -692,7 +692,7 @@ def batch_norm2d(
...
@@ -692,7 +692,7 @@ def batch_norm2d(
def
sync_batch_norm
(
def
sync_batch_norm
(
inp
ut
:
Tensor
,
inp
:
Tensor
,
running_mean
:
Tensor
,
running_mean
:
Tensor
,
running_var
:
Tensor
,
running_var
:
Tensor
,
weight
:
Optional
[
Tensor
]
=
None
,
weight
:
Optional
[
Tensor
]
=
None
,
...
@@ -723,25 +723,30 @@ def sync_batch_norm(
...
@@ -723,25 +723,30 @@ def sync_batch_norm(
Default: 1e-5.
Default: 1e-5.
"""
"""
assert
eps_mode
in
{
"MAX"
,
"ADDITIVE"
},
"unknown eps_mode: {}"
.
format
(
eps_mode
)
assert
eps_mode
in
{
"MAX"
,
"ADDITIVE"
},
"unknown eps_mode: {}"
.
format
(
eps_mode
)
_channels
=
input
.
shape
[
1
]
_channels
=
inp
.
shape
[
1
]
_ndim
=
input
.
ndim
_ndim
=
inp
.
ndim
_device
=
inp
.
device
_dtype
=
inp
.
dtype
_param_shape
=
(
1
,
_channels
)
+
(
1
,)
*
(
_ndim
-
2
)
_param_shape
=
(
1
,
_channels
)
+
(
1
,)
*
(
_ndim
-
2
)
_reduce_axis
=
[
0
]
+
[
i
for
i
in
range
(
2
,
_ndim
)]
if
training
:
if
training
:
def
_sum_on_channel
(
inp
ut
):
def
_sum_on_channel
(
inp
):
return
apply
(
builtin
.
Reduce
(
mode
=
"SUM"
),
input
,
Tensor
(
_param_shape
))[
0
]
return
inp
.
sum
(
axis
=
_reduce_axis
,
keepdims
=
True
)
reduce_size
=
inp
ut
.
shape
[
0
]
reduce_size
=
inp
.
shape
[
0
]
for
i
in
range
(
2
,
_ndim
):
for
i
in
range
(
2
,
_ndim
):
reduce_size
=
reduce_size
*
inp
ut
.
shape
[
i
]
reduce_size
=
reduce_size
*
inp
.
shape
[
i
]
channel_x1s
=
_sum_on_channel
(
inp
ut
)
channel_x1s
=
_sum_on_channel
(
inp
)
channel_x2s
=
_sum_on_channel
(
inp
ut
**
2
)
channel_x2s
=
_sum_on_channel
(
inp
**
2
)
if
is_distributed
():
if
is_distributed
():
# reduce all nodes' data to calculate mean and variance
# reduce all nodes' data to calculate mean and variance
reduce_size
=
full
([
1
for
_
in
range
(
_ndim
)],
reduce_size
)
reduce_size
=
broadcast
(
Tensor
(
reduce_size
,
dtype
=
_dtype
),
[
1
]
*
_ndim
)
stat
=
concat
([
reduce_size
,
channel_x1s
,
channel_x2s
],
axis
=
1
)
stat
=
concat
(
[
reduce_size
.
astype
(
_dtype
),
channel_x1s
,
channel_x2s
],
axis
=
1
)
stat
=
all_reduce_sum
(
stat
,
group
)
stat
=
all_reduce_sum
(
stat
,
group
)
reduce_size
=
stat
[:,
:
1
].
reshape
(
1
)
reduce_size
=
stat
[:,
:
1
].
reshape
(
1
)
channel_x1s
=
stat
[:,
1
:
1
+
_channels
]
channel_x1s
=
stat
[:,
1
:
1
+
_channels
]
...
@@ -775,11 +780,11 @@ def sync_batch_norm(
...
@@ -775,11 +780,11 @@ def sync_batch_norm(
inv_var_wt
=
invsqrt_channel_variance
*
weight
inv_var_wt
=
invsqrt_channel_variance
*
weight
neg_channel_mean
=
-
channel_mean
neg_channel_mean
=
-
channel_mean
if
bias
is
not
None
:
if
bias
is
not
None
:
outvar
=
inp
ut
*
inv_var_wt
+
(
neg_channel_mean
*
inv_var_wt
+
bias
)
outvar
=
inp
*
inv_var_wt
+
(
neg_channel_mean
*
inv_var_wt
+
bias
)
else
:
else
:
outvar
=
inp
ut
*
inv_var_wt
+
neg_channel_mean
*
inv_var_wt
outvar
=
inp
*
inv_var_wt
+
neg_channel_mean
*
inv_var_wt
else
:
else
:
outvar
=
inp
ut
*
invsqrt_channel_variance
+
(
outvar
=
inp
*
invsqrt_channel_variance
+
(
-
channel_mean
*
invsqrt_channel_variance
-
channel_mean
*
invsqrt_channel_variance
)
)
if
bias
is
not
None
:
if
bias
is
not
None
:
...
...
imperative/python/test/unit/module/test_batchnorm.py
浏览文件 @
a226f02e
...
@@ -27,7 +27,6 @@ from megengine.test import assertTensorClose
...
@@ -27,7 +27,6 @@ from megengine.test import assertTensorClose
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
platform
.
system
()
==
"Windows"
,
reason
=
"do not imp GPU mode at Windows now"
platform
.
system
()
==
"Windows"
,
reason
=
"do not imp GPU mode at Windows now"
)
)
@
pytest
.
mark
.
skipif
(
use_tensor_shape
(),
reason
=
"syncbn doesnot support symbolic shape"
)
@
pytest
.
mark
.
isolated_distributed
@
pytest
.
mark
.
isolated_distributed
def
test_syncbn
():
def
test_syncbn
():
nr_chan
=
8
nr_chan
=
8
...
@@ -154,7 +153,6 @@ def test_batchnorm():
...
@@ -154,7 +153,6 @@ def test_batchnorm():
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
platform
.
system
()
==
"Windows"
,
reason
=
"do not imp GPU mode at Windows now"
platform
.
system
()
==
"Windows"
,
reason
=
"do not imp GPU mode at Windows now"
)
)
@
pytest
.
mark
.
skipif
(
use_tensor_shape
(),
reason
=
"syncbn doesnot support symbolic shape"
)
@
pytest
.
mark
.
isolated_distributed
@
pytest
.
mark
.
isolated_distributed
def
test_syncbn1d
():
def
test_syncbn1d
():
nr_chan
=
8
nr_chan
=
8
...
@@ -257,7 +255,6 @@ def test_batchnorm2d():
...
@@ -257,7 +255,6 @@ def test_batchnorm2d():
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
platform
.
system
()
==
"Windows"
,
reason
=
"do not imp GPU mode at Windows now"
platform
.
system
()
==
"Windows"
,
reason
=
"do not imp GPU mode at Windows now"
)
)
@
pytest
.
mark
.
skipif
(
use_tensor_shape
(),
reason
=
"syncbn doesnot support symbolic shape"
)
@
pytest
.
mark
.
isolated_distributed
@
pytest
.
mark
.
isolated_distributed
def
test_syncbn2d
():
def
test_syncbn2d
():
nr_chan
=
8
nr_chan
=
8
...
@@ -336,7 +333,6 @@ def test_batchnorm_no_stats():
...
@@ -336,7 +333,6 @@ def test_batchnorm_no_stats():
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
platform
.
system
()
==
"Windows"
,
reason
=
"do not imp GPU mode at Windows now"
platform
.
system
()
==
"Windows"
,
reason
=
"do not imp GPU mode at Windows now"
)
)
@
pytest
.
mark
.
skipif
(
use_tensor_shape
(),
reason
=
"syncbn doesnot support symbolic shape"
)
@
pytest
.
mark
.
isolated_distributed
@
pytest
.
mark
.
isolated_distributed
def
test_syncbn_no_stats
():
def
test_syncbn_no_stats
():
nr_chan
=
8
nr_chan
=
8
...
@@ -393,7 +389,6 @@ def test_batchnorm2d_no_stats():
...
@@ -393,7 +389,6 @@ def test_batchnorm2d_no_stats():
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
platform
.
system
()
==
"Windows"
,
reason
=
"do not imp GPU mode at Windows now"
platform
.
system
()
==
"Windows"
,
reason
=
"do not imp GPU mode at Windows now"
)
)
@
pytest
.
mark
.
skipif
(
use_tensor_shape
(),
reason
=
"syncbn doesnot support symbolic shape"
)
@
pytest
.
mark
.
isolated_distributed
@
pytest
.
mark
.
isolated_distributed
def
test_syncbn2d_no_stats
():
def
test_syncbn2d_no_stats
():
nr_chan
=
8
nr_chan
=
8
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录