Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
17c6d399
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
17c6d399
编写于
6月 01, 2021
作者:
C
ceci3
提交者:
GitHub
6月 01, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix syncbn (#32989)
* fix syncbn
上级
b751a805
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
82 addition
and
5 deletion
+82
-5
python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py
...n/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py
+66
-1
python/paddle/nn/layer/norm.py
python/paddle/nn/layer/norm.py
+16
-4
未找到文件。
python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py
浏览文件 @
17c6d399
...
...
@@ -248,7 +248,7 @@ class TestConvertSyncBatchNorm(unittest.TestCase):
isinstance
(
model
[
idx
],
paddle
.
nn
.
SyncBatchNorm
),
True
)
class
TestConvertSyncBatchNormCas
e2
(
unittest
.
TestCase
):
class
TestConvertSyncBatchNormCas
t1
(
unittest
.
TestCase
):
def
test_convert
(
self
):
if
not
core
.
is_compiled_with_cuda
():
return
...
...
@@ -277,5 +277,70 @@ class TestConvertSyncBatchNormCase2(unittest.TestCase):
self
.
assertEqual
(
len
(
compare_model
.
sublayers
()),
len
(
model
.
sublayers
()))
class
TestConvertSyncBatchNormCase2
(
unittest
.
TestCase
):
def
test_convert
(
self
):
if
not
core
.
is_compiled_with_cuda
():
return
with
fluid
.
dygraph
.
guard
(
fluid
.
CUDAPlace
(
0
)):
class
SyBNNet
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
in_ch
=
3
,
out_ch
=
3
,
dirate
=
1
):
super
(
SyBNNet
,
self
).
__init__
()
self
.
bn_s1
=
paddle
.
nn
.
SyncBatchNorm
.
convert_sync_batchnorm
(
paddle
.
nn
.
BatchNorm3D
(
out_ch
,
weight_attr
=
paddle
.
ParamAttr
(
regularizer
=
paddle
.
regularizer
.
L2Decay
(
0.
))))
self
.
bn_s2
=
paddle
.
nn
.
SyncBatchNorm
.
convert_sync_batchnorm
(
paddle
.
nn
.
BatchNorm3D
(
out_ch
,
data_format
=
'NDHWC'
))
def
forward
(
self
,
x
):
x
=
self
.
bn_s1
(
x
)
out
=
paddle
.
sum
(
paddle
.
abs
(
self
.
bn_s2
(
x
)))
return
out
class
BNNet
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
in_ch
=
3
,
out_ch
=
3
,
dirate
=
1
):
super
(
BNNet
,
self
).
__init__
()
self
.
bn_s1
=
paddle
.
nn
.
BatchNorm3D
(
out_ch
,
weight_attr
=
paddle
.
ParamAttr
(
regularizer
=
paddle
.
regularizer
.
L2Decay
(
0.
)))
self
.
bn_s2
=
paddle
.
nn
.
SyncBatchNorm
.
convert_sync_batchnorm
(
paddle
.
nn
.
BatchNorm3D
(
out_ch
,
data_format
=
'NDHWC'
))
def
forward
(
self
,
x
):
x
=
self
.
bn_s1
(
x
)
out
=
paddle
.
sum
(
paddle
.
abs
(
self
.
bn_s2
(
x
)))
return
out
bn_model
=
BNNet
()
sybn_model
=
SyBNNet
()
np
.
random
.
seed
(
10
)
data
=
np
.
random
.
random
([
3
,
3
,
3
,
3
,
3
]).
astype
(
'float32'
)
x
=
paddle
.
to_tensor
(
data
)
bn_out
=
bn_model
(
x
)
sybn_out
=
sybn_model
(
x
)
self
.
assertTrue
(
np
.
allclose
(
bn_out
.
numpy
(),
sybn_out
.
numpy
()),
"Output has diff.
\n
"
+
"
\n
BN "
+
str
(
bn_out
.
numpy
())
+
"
\n
"
+
"Sync BN "
+
str
(
sybn_out
.
numpy
()))
class
TestDygraphSyncBatchNormDataFormatError
(
unittest
.
TestCase
):
def
test_errors
(
self
):
if
not
core
.
is_compiled_with_cuda
():
return
with
fluid
.
dygraph
.
guard
(
fluid
.
CUDAPlace
(
0
)):
my_sync_batch_norm
=
paddle
.
nn
.
SyncBatchNorm
(
10
,
data_format
=
'CN'
)
data
=
np
.
random
.
random
([
3
,
3
,
3
]).
astype
(
'float32'
)
x
=
paddle
.
to_tensor
(
data
)
self
.
assertRaises
(
ValueError
,
my_sync_batch_norm
,
x
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/nn/layer/norm.py
浏览文件 @
17c6d399
...
...
@@ -1057,7 +1057,18 @@ class SyncBatchNorm(_BatchNormBase):
self
).
__init__
(
num_features
,
momentum
,
epsilon
,
weight_attr
,
bias_attr
,
data_format
,
None
,
name
)
def
_check_data_format
(
self
):
if
self
.
_data_format
in
[
'NCHW'
,
'NCDHW'
,
'NC'
,
'NCL'
]:
self
.
_data_format
=
'NCHW'
elif
self
.
_data_format
in
[
"NHWC"
,
"NDHWC"
,
'NLC'
]:
self
.
_data_format
=
'NHWC'
else
:
raise
ValueError
(
'expected
\'
NCDHW
\'
,
\'
NDHWC
\'
,
\'
NCL
\'
,
\'
NLC
\'
,
\'
NC
\'
,
\'
NCHW
\'
,
\'
NHWC
\'
for data_format'
)
def
forward
(
self
,
x
):
self
.
_check_data_format
()
# create output
# mean and mean_out share the same memory
mean_out
=
self
.
_mean
...
...
@@ -1142,11 +1153,12 @@ class SyncBatchNorm(_BatchNormBase):
"""
layer_output
=
layer
if
isinstance
(
layer
,
_BatchNormBase
):
if
layer
.
_weight_attr
!=
None
and
not
isinstance
(
layer
.
_weight_attr
,
bool
):
if
layer
.
_weight_attr
!=
None
and
not
isinstance
(
layer
.
_weight_attr
,
bool
)
and
layer
.
_weight_attr
.
name
!=
None
:
layer
.
_weight_attr
.
name
=
layer
.
_weight_attr
.
name
+
'_sync'
if
layer
.
_bias_attr
!=
None
and
not
isinstance
(
layer
.
_weight_attr
,
bool
)
:
if
layer
.
_bias_attr
!=
None
and
not
isinstance
(
layer
.
_bias_attr
,
bool
)
and
layer
.
_bias_attr
.
name
!=
None
:
layer
.
_bias_attr
.
name
=
layer
.
_bias_attr
.
name
+
'_sync'
layer_output
=
SyncBatchNorm
(
layer
.
_num_features
,
layer
.
_momentum
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录