From 218c01299c646f6cc23d002f234efb2a8afe1e50 Mon Sep 17 00:00:00 2001
From: qipengh
Date: Sun, 9 Oct 2022 16:29:13 +0800
Subject: [PATCH] [MLU]fix unittest of sync_bn (#46797)

---
 .../unittests/mlu/sync_batch_norm_op_mlu.py   |  7 ++--
 .../mlu/test_sync_batch_norm_base_mlu.py      | 32 +++++++++----------
 2 files changed, 21 insertions(+), 18 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/mlu/sync_batch_norm_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/sync_batch_norm_op_mlu.py
index 29c0a961e00..fbec31a16fa 100644
--- a/python/paddle/fluid/tests/unittests/mlu/sync_batch_norm_op_mlu.py
+++ b/python/paddle/fluid/tests/unittests/mlu/sync_batch_norm_op_mlu.py
@@ -47,6 +47,7 @@ class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase):
         self.global_ring_id = 0
 
         self.dtype = np.float32
+        self.bn_dtype = np.float32
         self.N = 8
         self.C = 16
         self.H = 32
@@ -77,6 +78,8 @@ class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase):
                     param_attr=fluid.ParamAttr(name='conv2d_weight'),
                     bias_attr=False,
                     use_cudnn=use_cudnn)
+                if self.bn_dtype == np.float16:
+                    conv = fluid.layers.cast(conv, 'float16')
                 bn = fluid.layers.batch_norm(
                     conv,
                     param_attr=fluid.ParamAttr(name='bn_scale'),
@@ -85,8 +88,8 @@ class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase):
                     moving_variance_name='bn_moving_variance',
                     data_layout=layout,
                     is_test=only_forward)
-                # if self.dtype == np.float16:
-                #     bn = fluid.layers.cast(bn, 'float32')
+                if self.bn_dtype == np.float16:
+                    bn = fluid.layers.cast(bn, 'float32')
                 sigmoid = fluid.layers.sigmoid(bn)
                 out = fluid.layers.reduce_sum(sigmoid)
                 # if not sync_bn:
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_base_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_base_mlu.py
index 2b66996cebe..720ee9f47d8 100644
--- a/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_base_mlu.py
+++ b/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_base_mlu.py
@@ -126,22 +126,22 @@ class TestSyncBatchNormRunnerBase(object):
                 self._compare(args, place, layout, True)
 
         # Test FP16 - @TODO
-        # self.dtype = np.float16
-        # self.atol = 1e-2
-
-        # # Test training
-        # for place in places:
-        #     for layout in ["NCHW", "NHWC"]:
-        #         self._compare(args, place, layout, False)
-
-        # # Test inference
-        # for place in places:
-        #     for layout in ["NCHW", "NHWC"]:
-        #         self._compare(args, place, layout, True)
-
-        # sys.stdout.buffer.write(
-        #     pickle.dumps(
-        #         'training, inference, fp32, fp16, NCHW, NHWC all passed'))
+        self.bn_dtype = np.float16
+        self.atol = 3e-3
+
+        # Test training
+        for place in places:
+            for layout in ["NCHW", "NHWC"]:
+                self._compare(args, place, layout, False)
+
+        # Test inference
+        for place in places:
+            for layout in ["NCHW", "NHWC"]:
+                self._compare(args, place, layout, True)
+
+        sys.stdout.buffer.write(
+            pickle.dumps(
+                'training, inference, fp32, fp16, NCHW, NHWC all passed'))
 
     def _compare(self, args, place, layout, only_forward):
         scope = core.Scope()
--
GitLab
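
Note (not part of the patch): the change enables the previously commented-out FP16 path by casting the conv output to float16 before fluid.layers.batch_norm, casting the result back to float32, and loosening the comparison tolerance to atol = 3e-3. The standalone NumPy sketch below illustrates why a looser tolerance is reasonable once the batch-norm input has been rounded to float16; the helper batch_norm_nchw and the data shapes are illustrative assumptions, not code taken from the test files.

# Standalone illustration (not the Paddle test): how much a batch-norm output
# drifts when its input is rounded to float16, compared with the fp32 result.
import numpy as np


def batch_norm_nchw(x, eps=1e-5):
    """Per-channel training-mode normalization (scale=1, bias=0), NCHW layout."""
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(2022)
# Same N/C/H/W as the test above (8, 16, 32, 32), purely for illustration.
x = rng.standard_normal((8, 16, 32, 32)).astype(np.float32)

out_fp32 = batch_norm_nchw(x)
# Mirror the patch's cast-around pattern: round the batch-norm input to
# float16 (as fluid.layers.cast(conv, 'float16') would), compare in float32.
out_fp16_input = batch_norm_nchw(x.astype(np.float16).astype(np.float32))

max_diff = np.abs(out_fp32 - out_fp16_input).max()
print(f"max |fp32 - fp16-input| = {max_diff:.2e}")  # on the order of 1e-3
print("within atol=3e-3:", np.allclose(out_fp32, out_fp16_input, atol=3e-3))

Rounding to float16 costs at most about half an ulp (float16 epsilon is roughly 1e-3), so for inputs a few standard deviations from zero the per-element drift stays on the order of 1e-3, inside the 3e-3 tolerance the patch sets for the FP16 runs, while the fp32 comparisons can keep a much tighter bound.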