未验证 提交 218c0129 编写于 作者: Q qipengh 提交者: GitHub

[MLU] Fix the sync_bn unit test (#46797)

上级 d8b4ca92
......@@ -47,6 +47,7 @@ class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase):
self.global_ring_id = 0
self.dtype = np.float32
self.bn_dtype = np.float32
self.N = 8
self.C = 16
self.H = 32
......@@ -77,6 +78,8 @@ class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase):
param_attr=fluid.ParamAttr(name='conv2d_weight'),
bias_attr=False,
use_cudnn=use_cudnn)
if self.bn_dtype == np.float16:
conv = fluid.layers.cast(conv, 'float16')
bn = fluid.layers.batch_norm(
conv,
param_attr=fluid.ParamAttr(name='bn_scale'),
......@@ -85,8 +88,8 @@ class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase):
moving_variance_name='bn_moving_variance',
data_layout=layout,
is_test=only_forward)
# if self.dtype == np.float16:
# bn = fluid.layers.cast(bn, 'float32')
if self.bn_dtype == np.float16:
bn = fluid.layers.cast(bn, 'float32')
sigmoid = fluid.layers.sigmoid(bn)
out = fluid.layers.reduce_sum(sigmoid)
# if not sync_bn:
......
......@@ -126,22 +126,22 @@ class TestSyncBatchNormRunnerBase(object):
self._compare(args, place, layout, True)
# Test FP16 - @TODO
# self.dtype = np.float16
# self.atol = 1e-2
# # Test training
# for place in places:
# for layout in ["NCHW", "NHWC"]:
# self._compare(args, place, layout, False)
# # Test inference
# for place in places:
# for layout in ["NCHW", "NHWC"]:
# self._compare(args, place, layout, True)
# sys.stdout.buffer.write(
# pickle.dumps(
# 'training, inference, fp32, fp16, NCHW, NHWC all passed'))
self.bn_dtype = np.float16
self.atol = 3e-3
# Test training
for place in places:
for layout in ["NCHW", "NHWC"]:
self._compare(args, place, layout, False)
# Test inference
for place in places:
for layout in ["NCHW", "NHWC"]:
self._compare(args, place, layout, True)
sys.stdout.buffer.write(
pickle.dumps(
'training, inference, fp32, fp16, NCHW, NHWC all passed'))
def _compare(self, args, place, layout, only_forward):
scope = core.Scope()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册