Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
27a9326c
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
27a9326c
编写于
1月 22, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/module): fix frozen batch norm
GitOrigin-RevId: 143d468a37694522591971a191a549e3f1dd2d05
上级
c3ba0280
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
98 addition
and
26 deletion
+98
-26
imperative/python/megengine/module/batchnorm.py
imperative/python/megengine/module/batchnorm.py
+41
-14
imperative/python/test/integration/test_bn.py
imperative/python/test/integration/test_bn.py
+57
-12
未找到文件。
imperative/python/megengine/module/batchnorm.py
浏览文件 @
27a9326c
...
@@ -35,6 +35,10 @@ class _BatchNorm(Module):
...
@@ -35,6 +35,10 @@ class _BatchNorm(Module):
self
.
track_running_stats
=
track_running_stats
self
.
track_running_stats
=
track_running_stats
self
.
_track_running_stats_saved
=
track_running_stats
self
.
_track_running_stats_saved
=
track_running_stats
self
.
freeze
=
freeze
self
.
freeze
=
freeze
if
self
.
freeze
:
assert
(
self
.
_track_running_stats_saved
),
"track_running_stats must be initilized to True if freeze is True"
tshape
=
(
1
,
self
.
num_features
,
1
,
1
)
tshape
=
(
1
,
self
.
num_features
,
1
,
1
)
if
self
.
affine
:
if
self
.
affine
:
self
.
weight
=
Parameter
(
np
.
ones
(
tshape
,
dtype
=
np
.
float32
))
self
.
weight
=
Parameter
(
np
.
ones
(
tshape
,
dtype
=
np
.
float32
))
...
@@ -84,10 +88,24 @@ class _BatchNorm(Module):
...
@@ -84,10 +88,24 @@ class _BatchNorm(Module):
inp
=
inp
.
reshape
(
new_shape
)
inp
=
inp
.
reshape
(
new_shape
)
if
self
.
freeze
and
self
.
training
and
self
.
_track_running_stats_saved
:
_weight
=
self
.
weight
scale
=
self
.
weight
*
(
self
.
running_var
+
self
.
eps
)
**
(
-
0.5
)
_bias
=
self
.
bias
bias
=
self
.
bias
-
self
.
running_mean
*
scale
return
inp
*
scale
.
detach
()
+
bias
.
detach
()
if
self
.
freeze
:
if
_weight
is
not
None
:
_weight
=
_weight
.
detach
()
if
_bias
is
not
None
:
_bias
=
_bias
.
detach
()
# Need to expand to elementwise operations here
# see MGB_IMPL_OPR_GRAD(BatchNormForward) in src/opr/impl/dnn/batch_norm.cpp
scale
=
(
self
.
running_var
+
self
.
eps
)
**
(
-
0.5
)
if
_weight
is
not
None
:
scale
*=
_weight
bias
=
-
self
.
running_mean
*
scale
if
_bias
is
not
None
:
bias
+=
_bias
return
inp
*
scale
+
bias
if
self
.
training
and
self
.
track_running_stats
:
if
self
.
training
and
self
.
track_running_stats
:
exponential_average_factor
=
self
.
momentum
exponential_average_factor
=
self
.
momentum
...
@@ -98,8 +116,8 @@ class _BatchNorm(Module):
...
@@ -98,8 +116,8 @@ class _BatchNorm(Module):
inp
,
inp
,
self
.
running_mean
if
self
.
track_running_stats
else
None
,
self
.
running_mean
if
self
.
track_running_stats
else
None
,
self
.
running_var
if
self
.
track_running_stats
else
None
,
self
.
running_var
if
self
.
track_running_stats
else
None
,
self
.
weight
,
_
weight
,
self
.
bias
,
_
bias
,
training
=
self
.
training
training
=
self
.
training
or
((
self
.
running_mean
is
None
)
and
(
self
.
running_var
is
None
)),
or
((
self
.
running_mean
is
None
)
and
(
self
.
running_var
is
None
)),
momentum
=
exponential_average_factor
,
momentum
=
exponential_average_factor
,
...
@@ -121,7 +139,7 @@ class _BatchNorm(Module):
...
@@ -121,7 +139,7 @@ class _BatchNorm(Module):
class
SyncBatchNorm
(
_BatchNorm
):
class
SyncBatchNorm
(
_BatchNorm
):
r
"""
r
"""
Applies Synchroniz
ation Batch Normalization
.
Applies Synchroniz
ed Batch Normalization for distributed training
.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -169,15 +187,25 @@ class SyncBatchNorm(_BatchNorm):
...
@@ -169,15 +187,25 @@ class SyncBatchNorm(_BatchNorm):
else
:
else
:
exponential_average_factor
=
0.0
# useless
exponential_average_factor
=
0.0
# useless
_weight
=
self
.
weight
_bias
=
self
.
bias
if
self
.
freeze
:
if
_weight
is
not
None
:
_weight
=
_weight
.
detach
()
if
_bias
is
not
None
:
_bias
=
_bias
.
detach
()
output
=
sync_batch_norm
(
output
=
sync_batch_norm
(
inp
,
inp
,
self
.
running_mean
,
self
.
running_mean
,
self
.
running_var
,
self
.
running_var
,
self
.
weight
,
_weight
,
self
.
bias
,
_bias
,
self
.
training
or
not
self
.
track_running_stats
,
training
=
(
self
.
training
and
not
self
.
freeze
)
exponential_average_factor
,
or
((
self
.
running_mean
is
None
)
and
(
self
.
running_var
is
None
)),
self
.
eps
,
momentum
=
exponential_average_factor
,
eps
=
self
.
eps
,
group
=
self
.
group
,
group
=
self
.
group
,
)
)
...
@@ -257,8 +285,7 @@ class BatchNorm2d(_BatchNorm):
...
@@ -257,8 +285,7 @@ class BatchNorm2d(_BatchNorm):
:param freeze: when set to True, this module does not update the
:param freeze: when set to True, this module does not update the
running mean and variance, and uses the running mean and variance instead of
running mean and variance, and uses the running mean and variance instead of
the batch mean and batch variance to normalize the input. The parameter takes effect
the batch mean and batch variance to normalize the input. The parameter takes effect
only when the module is initilized with track_running_stats as True and
only when the module is initilized with track_running_stats as True.
the module is in training mode.
Default: False
Default: False
Examples:
Examples:
...
...
imperative/python/test/integration/test_bn.py
浏览文件 @
27a9326c
...
@@ -11,15 +11,23 @@ import pytest
...
@@ -11,15 +11,23 @@ import pytest
import
megengine
import
megengine
import
megengine.autodiff
as
ad
import
megengine.autodiff
as
ad
import
megengine.distributed
as
dist
import
megengine.functional
as
F
import
megengine.optimizer
as
optimizer
import
megengine.optimizer
as
optimizer
from
megengine
import
Parameter
,
tensor
from
megengine
import
Parameter
,
tensor
from
megengine.distributed.helper
import
get_device_count_by_fork
from
megengine.jit
import
trace
from
megengine.jit
import
trace
from
megengine.module
import
BatchNorm2d
,
Module
from
megengine.module
import
BatchNorm2d
,
Module
,
SyncBatchNorm
def
test_frozen_bn
(
):
def
run_frozen_bn
(
BNModule
,
use_trace
=
False
,
use_symbolic
=
False
):
nchannel
=
3
nchannel
=
3
m
=
BatchNorm2d
(
nchannel
,
freeze
=
True
)
m
=
BNModule
(
nchannel
,
freeze
=
True
)
var
=
4.0
bias
=
1.0
shape
=
(
1
,
nchannel
,
1
,
1
)
m
.
running_var
[...]
=
var
*
F
.
ones
(
shape
)
m
.
running_mean
[...]
=
bias
*
F
.
ones
(
shape
)
saved_var
=
m
.
running_var
.
numpy
()
saved_var
=
m
.
running_var
.
numpy
()
saved_mean
=
m
.
running_mean
.
numpy
()
saved_mean
=
m
.
running_mean
.
numpy
()
...
@@ -31,16 +39,45 @@ def test_frozen_bn():
...
@@ -31,16 +39,45 @@ def test_frozen_bn():
optim
.
clear_grad
()
optim
.
clear_grad
()
data
=
np
.
random
.
random
((
6
,
nchannel
,
2
,
2
)).
astype
(
"float32"
)
data
=
np
.
random
.
random
((
6
,
nchannel
,
2
,
2
)).
astype
(
"float32"
)
with
gm
:
loss
=
m
(
data
).
mean
()
gm
.
backward
(
loss
)
optim
.
step
()
np
.
testing
.
assert_equal
(
m
.
running_var
.
numpy
(),
saved_var
)
def
train_fn
(
d
):
np
.
testing
.
assert_equal
(
m
.
running_mean
.
numpy
(),
saved_mean
)
for
_
in
range
(
3
):
np
.
testing
.
assert_equal
(
m
.
weight
.
numpy
(),
saved_wt
)
with
gm
:
np
.
testing
.
assert_equal
(
m
.
bias
.
numpy
(),
saved_bias
)
loss
=
m
(
d
).
mean
()
np
.
testing
.
assert_almost_equal
(
loss
.
numpy
(),
data
.
mean
(),
5
)
gm
.
backward
(
loss
)
optim
.
step
()
return
loss
if
use_trace
:
train_fn
=
trace
(
train_fn
,
symbolic
=
use_symbolic
)
for
_
in
range
(
3
):
loss
=
train_fn
(
megengine
.
Tensor
(
data
))
np
.
testing
.
assert_equal
(
m
.
running_var
.
numpy
(),
saved_var
)
np
.
testing
.
assert_equal
(
m
.
running_mean
.
numpy
(),
saved_mean
)
np
.
testing
.
assert_equal
(
m
.
weight
.
numpy
(),
saved_wt
)
np
.
testing
.
assert_equal
(
m
.
bias
.
numpy
(),
saved_bias
)
np
.
testing
.
assert_almost_equal
(
loss
.
numpy
(),
((
data
-
bias
)
/
np
.
sqrt
(
var
)).
mean
(),
5
)
def
test_frozen_bn
():
run_frozen_bn
(
BatchNorm2d
)
run_frozen_bn
(
BatchNorm2d
,
True
,
False
)
run_frozen_bn
(
BatchNorm2d
,
True
,
True
)
@
pytest
.
mark
.
skipif
(
get_device_count_by_fork
(
"gpu"
)
<
2
,
reason
=
"need more gpu device"
)
@
pytest
.
mark
.
isolated_distributed
def
test_frozen_synced_bn
():
@
dist
.
launcher
(
n_gpus
=
2
)
def
worker
():
run_frozen_bn
(
SyncBatchNorm
)
run_frozen_bn
(
SyncBatchNorm
,
True
,
False
)
run_frozen_bn
(
SyncBatchNorm
,
True
,
True
)
worker
()
def
test_bn_no_track_stat
():
def
test_bn_no_track_stat
():
...
@@ -112,3 +149,11 @@ def test_trace_bn_forward_twice():
...
@@ -112,3 +149,11 @@ def test_trace_bn_forward_twice():
x
=
np
.
ones
((
1
,
1
,
32
,
32
),
dtype
=
np
.
float32
)
x
=
np
.
ones
((
1
,
1
,
32
,
32
),
dtype
=
np
.
float32
)
y
=
train_bn
(
x
,
net
=
Simple
())
y
=
train_bn
(
x
,
net
=
Simple
())
np
.
testing
.
assert_equal
(
y
.
numpy
(),
0
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
0
)
# https://github.com/MegEngine/MegEngine/issues/145
def
test_frozen_bn_no_affine
():
nchannel
=
3
m
=
BatchNorm2d
(
nchannel
,
freeze
=
True
,
affine
=
False
)
data
=
megengine
.
Tensor
(
np
.
random
.
random
((
6
,
nchannel
,
2
,
2
)).
astype
(
"float32"
))
m
(
data
).
numpy
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录