Commit e6caa9ff
Authored Jun 01, 2021 by Megvii Engine Team
Committed by huangxinda on Jul 19, 2021

feat(opr): add bn backward for inference mode
GitOrigin-RevId: bb643cb62fbba90ca8846a3550f88bf6763ddd58
Parent: c90fa087

Showing 3 changed files with 57 additions and 35 deletions:
imperative/python/megengine/module/batchnorm.py   +0  -10
imperative/python/test/integration/test_bn.py     +29 -17
src/opr/impl/dnn/batch_norm.cpp                   +28 -8

imperative/python/megengine/module/batchnorm.py

@@ -100,16 +100,6 @@ class _BatchNorm(Module):
             if _bias is not None:
                 _bias = _bias.detach()
-
-            # Need to expand to elementwise operations here
-            # see MGB_IMPL_OPR_GRAD(BatchNormForward) in src/opr/impl/dnn/batch_norm.cpp
-            scale = (self.running_var + self.eps) ** (-0.5)
-            if _weight is not None:
-                scale *= _weight
-            bias = -self.running_mean * scale
-            if _bias is not None:
-                bias += _bias
-            return inp * scale + bias
 
         if self.training and self.track_running_stats:
             exponential_average_factor = self.momentum
         else:
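
The branch deleted above expanded a frozen BatchNorm into plain elementwise operations because MGB_IMPL_OPR_GRAD(BatchNormForward) could previously take gradients only in training mode; with the batch_norm.cpp change in this commit the operator also provides gradients in inference mode, so the Python-side expansion is no longer needed. For reference, a minimal NumPy sketch of the affine form the deleted lines computed (illustrative names, not MegEngine code):

import numpy as np

# Sketch of the deleted elementwise expansion: in inference mode, BN is an
# affine map of the input built from the (detached) running statistics.
def frozen_bn_expanded(inp, running_mean, running_var, weight, bias, eps=1e-5):
    scale = (running_var + eps) ** -0.5   # 1 / sqrt(var + eps)
    if weight is not None:
        scale = scale * weight            # fold the affine weight into the scale
    shift = -running_mean * scale
    if bias is not None:
        shift = shift + bias
    return inp * scale + shift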

imperative/python/test/integration/test_bn.py

@@ -19,9 +19,13 @@ from megengine.jit import trace
 from megengine.module import BatchNorm2d, Conv2d, Module, Sequential, SyncBatchNorm
 
-def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False):
+def run_frozen_bn(BNModule, is_training, use_trace, use_symbolic):
     nchannel = 3
     m = BNModule(nchannel, freeze=True)
+    if is_training:
+        m.train()
+    else:
+        m.eval()
 
     var = 4.0
     bias = 1.0
     shape = (1, nchannel, 1, 1)
@@ -51,30 +55,33 @@ def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False):
         train_fn = trace(train_fn, symbolic=use_symbolic)
 
     for _ in range(3):
-        loss = train_fn(megengine.Tensor(data))
-        np.testing.assert_equal(m.running_var.numpy(), saved_var)
-        np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
+        loss = train_fn(megengine.tensor(data))
+        if not is_training:
+            np.testing.assert_equal(m.running_var.numpy(), saved_var)
+            np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
+            np.testing.assert_almost_equal(
+                loss.numpy(), ((data - bias) / np.sqrt(var)).mean(), 5
+            )
         np.testing.assert_equal(m.weight.numpy(), saved_wt)
         np.testing.assert_equal(m.bias.numpy(), saved_bias)
-        np.testing.assert_almost_equal(
-            loss.numpy(), ((data - bias) / np.sqrt(var)).mean(), 5
-        )
 
 
-def test_frozen_bn():
-    run_frozen_bn(BatchNorm2d)
-    run_frozen_bn(BatchNorm2d, True, False)
-    run_frozen_bn(BatchNorm2d, True, True)
+@pytest.mark.parametrize("is_training", [False, True])
+@pytest.mark.parametrize("use_trace", [False, True])
+@pytest.mark.parametrize("use_symbolic", [False, True])
+def test_frozen_bn(is_training, use_trace, use_symbolic):
+    run_frozen_bn(BatchNorm2d, is_training, use_trace, use_symbolic)
 
 
 @pytest.mark.require_ngpu(2)
 @pytest.mark.isolated_distributed
-def test_frozen_synced_bn():
+@pytest.mark.parametrize("is_training", [False, True])
+@pytest.mark.parametrize("use_trace", [False, True])
+@pytest.mark.parametrize("use_symbolic", [False, True])
+def test_frozen_synced_bn(is_training, use_trace, use_symbolic):
     @dist.launcher(n_gpus=2)
     def worker():
-        run_frozen_bn(SyncBatchNorm)
-        run_frozen_bn(SyncBatchNorm, True, False)
-        run_frozen_bn(SyncBatchNorm, True, True)
+        run_frozen_bn(SyncBatchNorm, is_training, use_trace, use_symbolic)
 
     worker()
@@ -190,8 +197,13 @@ def test_trace_several_syncbn(trace_mode):
 
 # https://github.com/MegEngine/MegEngine/issues/145
-def test_frozen_bn_no_affine():
+@pytest.mark.parametrize("is_training", [False, True])
+def test_frozen_bn_no_affine(is_training):
     nchannel = 3
     m = BatchNorm2d(nchannel, freeze=True, affine=False)
-    data = tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
+    if is_training:
+        m.train()
+    else:
+        m.eval()
+
+    data = megengine.tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
     m(data).numpy()
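
The updated test encodes two contracts: a frozen BN module keeps its weight and bias unchanged in both train() and eval() mode, while the running-statistics and loss-value assertions apply only when is_training is False. The expected loss is simply the mean of the normalized data; restated in plain NumPy with the test's constants (illustrative, not part of the test file):

import numpy as np

# Loss value asserted by run_frozen_bn, using the test's constants
# var = 4.0 and bias = 1.0 (data shape here mirrors the no-affine test).
var, bias = 4.0, 1.0
data = np.random.random((6, 3, 2, 2)).astype("float32")
expected_loss = ((data - bias) / np.sqrt(var)).mean()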

src/opr/impl/dnn/batch_norm.cpp

@@ -12,6 +12,8 @@
 #include "megbrain/opr/dnn/batch_norm.h"
 #include "megbrain/opr/io.h"
 #include "megbrain/graph/grad_impl.h"
+#include "megbrain/opr/basic_arith.h"
+#include "megbrain/opr/tensor_manip.h"
 #include "../internal/megdnn_opr_wrapper.inl"
@@ -243,16 +245,34 @@ void BatchNormForward::mem_plan_fwd_in2out_writable() {
 
 #if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchNormForward) {
-    mgb_assert(opr.param().fwd_mode == BatchNorm::Param::FwdMode::TRAINING,
-               "batch norm could only take grad in training mode");
     mgb_assert(wrt_idx < 5, "wrt_idx %zu is out of range", wrt_idx);
     VarNodeArray ret(opr.input().size(), nullptr);
-    SymbolVarArray grad = BatchNormBackward::make(
-            opr.input(0), out_grad[4], opr.output(2),
-            opr.output(3), opr.input(1), opr.param());
-    for (size_t i = 0; i < 3; ++i) {
-        ret[i] = grad[(i + 2) % 3].node();
+    SymbolVarArray grad;
+    switch (opr.param().fwd_mode) {
+        case BatchNorm::Param::FwdMode::TRAINING:
+            grad = BatchNormBackward::make(
+                    opr.input(0), out_grad[4], opr.output(2),
+                    opr.output(3), opr.input(1), opr.param());
+            for (size_t i = 0; i < 3; ++i) {
+                ret[i] = grad[(i + 2) % 3].node();
+            }
+            return ret;
+        case BatchNorm::Param::FwdMode::INFERENCE:
+            auto sqrt_var = PowC::make(
+                    (SymbolVar{opr.input(4)} +
+                     static_cast<dt_float32>(opr.param().epsilon)),
+                    0.5, opr.config());
+            auto d_bn_scale_unreduced =
+                    SymbolVar{out_grad[4]} *
+                    (SymbolVar{opr.input(0)} - SymbolVar{opr.input(3)}) / sqrt_var;
+            auto d_bn_scale = Reduce::make(
+                    d_bn_scale_unreduced, Reduce::Param::Mode::SUM,
+                    GetVarShape::make(opr.input(1)));
+            auto d_bn_bias = Reduce::make(
+                    out_grad[4], Reduce::Param::Mode::SUM,
+                    GetVarShape::make(opr.input(2)));
+            auto dx = SymbolVar{out_grad[4]} * SymbolVar{opr.input(1)} / sqrt_var;
+            ret[0] = dx.node();
+            ret[1] = d_bn_scale.node();
+            ret[2] = d_bn_bias.node();
+            return ret;
     }
     return ret;
 }
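
In inference mode the forward pass is y = (x - mean) / sqrt(var + eps) * scale + bias, with mean = input(3) and var = input(4) held constant, so the new INFERENCE branch emits closed-form gradients directly: dx = dy * scale / sqrt(var + eps), d_bn_scale is the sum of dy * (x - mean) / sqrt(var + eps) reduced to the scale parameter's shape, and d_bn_bias is the sum of dy reduced to the bias parameter's shape; the running statistics (ret[3], ret[4]) receive no gradient. A NumPy restatement of the same computation (illustrative NCHW layout with (1, C, 1, 1) statistics, not MegEngine code):

import numpy as np

# NumPy sketch of the gradients built by the INFERENCE branch above.
def frozen_bn_backward(dy, x, scale, mean, var, eps):
    sqrt_var = np.sqrt(var + eps)                   # PowC::make(var + eps, 0.5)
    dx = dy * scale / sqrt_var                      # ret[0]
    d_scale = (dy * (x - mean) / sqrt_var).sum(     # ret[1], reduced to scale's shape
        axis=(0, 2, 3), keepdims=True)
    d_bias = dy.sum(axis=(0, 2, 3), keepdims=True)  # ret[2]
    return dx, d_scale, d_bias                      # mean / var are constants: no grad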