Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
218c0129
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
218c0129
编写于
10月 09, 2022
作者:
Q
qipengh
提交者:
GitHub
10月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[MLU]fix unittest of sync_bn (#46797)
上级
d8b4ca92
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
21 addition
and
18 deletion
+21
-18
python/paddle/fluid/tests/unittests/mlu/sync_batch_norm_op_mlu.py
...addle/fluid/tests/unittests/mlu/sync_batch_norm_op_mlu.py
+5
-2
python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_base_mlu.py
...luid/tests/unittests/mlu/test_sync_batch_norm_base_mlu.py
+16
-16
未找到文件。
python/paddle/fluid/tests/unittests/mlu/sync_batch_norm_op_mlu.py
浏览文件 @
218c0129
...
...
@@ -47,6 +47,7 @@ class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase):
self
.
global_ring_id
=
0
self
.
dtype
=
np
.
float32
self
.
bn_dtype
=
np
.
float32
self
.
N
=
8
self
.
C
=
16
self
.
H
=
32
...
...
@@ -77,6 +78,8 @@ class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase):
param_attr
=
fluid
.
ParamAttr
(
name
=
'conv2d_weight'
),
bias_attr
=
False
,
use_cudnn
=
use_cudnn
)
if
self
.
bn_dtype
==
np
.
float16
:
conv
=
fluid
.
layers
.
cast
(
conv
,
'float16'
)
bn
=
fluid
.
layers
.
batch_norm
(
conv
,
param_attr
=
fluid
.
ParamAttr
(
name
=
'bn_scale'
),
...
...
@@ -85,8 +88,8 @@ class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase):
moving_variance_name
=
'bn_moving_variance'
,
data_layout
=
layout
,
is_test
=
only_forward
)
# if self.
dtype == np.float16:
#
bn = fluid.layers.cast(bn, 'float32')
if
self
.
bn_
dtype
==
np
.
float16
:
bn
=
fluid
.
layers
.
cast
(
bn
,
'float32'
)
sigmoid
=
fluid
.
layers
.
sigmoid
(
bn
)
out
=
fluid
.
layers
.
reduce_sum
(
sigmoid
)
# if not sync_bn:
...
...
python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_base_mlu.py
浏览文件 @
218c0129
...
...
@@ -126,22 +126,22 @@ class TestSyncBatchNormRunnerBase(object):
self
.
_compare
(
args
,
place
,
layout
,
True
)
# Test FP16 - @TODO
# self.
dtype = np.float16
# self.atol = 1e-2
#
#
Test training
#
for place in places:
#
for layout in ["NCHW", "NHWC"]:
#
self._compare(args, place, layout, False)
#
#
Test inference
#
for place in places:
#
for layout in ["NCHW", "NHWC"]:
#
self._compare(args, place, layout, True)
#
sys.stdout.buffer.write(
#
pickle.dumps(
#
'training, inference, fp32, fp16, NCHW, NHWC all passed'))
self
.
bn_
dtype
=
np
.
float16
self
.
atol
=
3e-3
# Test training
for
place
in
places
:
for
layout
in
[
"NCHW"
,
"NHWC"
]:
self
.
_compare
(
args
,
place
,
layout
,
False
)
# Test inference
for
place
in
places
:
for
layout
in
[
"NCHW"
,
"NHWC"
]:
self
.
_compare
(
args
,
place
,
layout
,
True
)
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
'training, inference, fp32, fp16, NCHW, NHWC all passed'
))
def
_compare
(
self
,
args
,
place
,
layout
,
only_forward
):
scope
=
core
.
Scope
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录