Unverified commit 218c0129, authored by qipengh, committed by GitHub

[MLU]fix unittest of sync_bn (#46797)

Parent d8b4ca92
@@ -47,6 +47,7 @@ class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase):
         self.global_ring_id = 0
         self.dtype = np.float32
+        self.bn_dtype = np.float32
         self.N = 8
         self.C = 16
         self.H = 32
@@ -77,6 +78,8 @@ class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase):
             param_attr=fluid.ParamAttr(name='conv2d_weight'),
             bias_attr=False,
             use_cudnn=use_cudnn)
+        if self.bn_dtype == np.float16:
+            conv = fluid.layers.cast(conv, 'float16')
         bn = fluid.layers.batch_norm(
             conv,
             param_attr=fluid.ParamAttr(name='bn_scale'),
@@ -85,8 +88,8 @@ class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase):
             moving_variance_name='bn_moving_variance',
             data_layout=layout,
             is_test=only_forward)
-        # if self.dtype == np.float16:
-        #     bn = fluid.layers.cast(bn, 'float32')
+        if self.bn_dtype == np.float16:
+            bn = fluid.layers.cast(bn, 'float32')
         sigmoid = fluid.layers.sigmoid(bn)
         out = fluid.layers.reduce_sum(sigmoid)
         # if not sync_bn:
...
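The hunks above route only the normalization through half precision: the convolution output is cast to float16 before batch_norm, and the result is cast back to float32, so the sigmoid and reduce_sum that follow keep running in full precision. A minimal sketch of that pattern, reassembled from the diff (the helper name, the `data` input, and the conv hyperparameters are illustrative assumptions, not copied from the full test file):

```python
import numpy as np
import paddle.fluid as fluid

def build_fp16_bn_block(data, bn_dtype=np.float16, layout='NCHW',
                        only_forward=False):
    # Convolution stays in float32, as in the test's model-building code.
    conv = fluid.layers.conv2d(
        data,
        num_filters=16,
        filter_size=1,
        param_attr=fluid.ParamAttr(name='conv2d_weight'),
        bias_attr=False)
    if bn_dtype == np.float16:
        # Cast the batch_norm input down so the op itself runs in fp16.
        conv = fluid.layers.cast(conv, 'float16')
    bn = fluid.layers.batch_norm(
        conv,
        param_attr=fluid.ParamAttr(name='bn_scale'),
        moving_mean_name='bn_moving_mean',
        moving_variance_name='bn_moving_variance',
        data_layout=layout,
        is_test=only_forward)
    if bn_dtype == np.float16:
        # Cast back so the rest of the graph accumulates in fp32.
        bn = fluid.layers.cast(bn, 'float32')
    sigmoid = fluid.layers.sigmoid(bn)
    return fluid.layers.reduce_sum(sigmoid)
```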
@@ -126,22 +126,22 @@ class TestSyncBatchNormRunnerBase(object):
                 self._compare(args, place, layout, True)

-        # Test FP16 - @TODO
-        # self.dtype = np.float16
-        # self.atol = 1e-2
+        # Test FP16
+        self.bn_dtype = np.float16
+        self.atol = 3e-3

-        # # Test training
-        # for place in places:
-        #     for layout in ["NCHW", "NHWC"]:
-        #         self._compare(args, place, layout, False)
+        # Test training
+        for place in places:
+            for layout in ["NCHW", "NHWC"]:
+                self._compare(args, place, layout, False)

-        # # Test inference
-        # for place in places:
-        #     for layout in ["NCHW", "NHWC"]:
-        #         self._compare(args, place, layout, True)
+        # Test inference
+        for place in places:
+            for layout in ["NCHW", "NHWC"]:
+                self._compare(args, place, layout, True)

-        # sys.stdout.buffer.write(
-        #     pickle.dumps(
-        #         'training, inference, fp32, fp16, NCHW, NHWC all passed'))
+        sys.stdout.buffer.write(
+            pickle.dumps(
+                'training, inference, fp32, fp16, NCHW, NHWC all passed'))

     def _compare(self, args, place, layout, only_forward):
         scope = core.Scope()
...
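The re-enabled `sys.stdout.buffer.write(pickle.dumps(...))` line is how this kind of runner reports its verdict: the test is launched as a subprocess, and the result is written to stdout as raw pickled bytes rather than printed as text. A minimal sketch of both halves of that handshake, assuming the parent side resembles Paddle's usual collective-test launcher (the parent is not shown in this diff):

```python
import pickle
import subprocess
import sys

def report(result):
    # Child side (what the hunk re-enables): pickle straight to the raw
    # stdout buffer, since pickle output is bytes, not valid text.
    sys.stdout.buffer.write(pickle.dumps(result))

def collect(cmd):
    # Parent side (assumed): run the child, capture stdout, unpickle it.
    out = subprocess.run(cmd, capture_output=True, check=True).stdout
    return pickle.loads(out)
```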