diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cycle_gan.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cycle_gan.py index 3578dc58f1856b966aabf525163f622889d1e2f0..d6840ed62810e13648b869340a41691aa0e89101 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cycle_gan.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cycle_gan.py @@ -32,6 +32,10 @@ import unittest import numpy as np from PIL import Image, ImageOps +import os +# Use GPU:0 to elimate the influence of other tasks. +os.environ["CUDA_VISIBLE_DEVICES"] = "1" + import paddle import paddle.fluid as fluid from paddle.fluid.dygraph import to_variable, declarative, ProgramTranslator @@ -327,6 +331,11 @@ class conv2d(fluid.dygraph.Layer): initializer=fluid.initializer.NormalInitializer( loc=0.0, scale=stddev)), bias_attr=con_bias_attr) + # Note(Aurelius84): The calculation of GPU kernel in BN is non-deterministic, + # failure rate is 1/100 in Dev but seems incremental in CE platform. + # If on GPU, we disable BN temporarily. + if fluid.is_compiled_with_cuda(): + norm = False if norm: self.bn = BatchNorm( use_global_stats=True, # set True to use deterministic algorithm @@ -383,6 +392,8 @@ class DeConv2D(fluid.dygraph.Layer): initializer=fluid.initializer.NormalInitializer( loc=0.0, scale=stddev)), bias_attr=de_bias_attr) + if fluid.is_compiled_with_cuda(): + norm = False if norm: self.bn = BatchNorm( use_global_stats=True, # set True to use deterministic algorithm @@ -606,8 +617,16 @@ class TestCycleGANModel(unittest.TestCase): def test_train(self): st_out = self.train(to_static=True) dy_out = self.train(to_static=False) + + assert_func = np.allclose + # Note(Aurelius84): Because we disable BN on GPU, + # but here we enhance the check on CPU by `np.array_equal` + # which means the dy_out and st_out shall be exactly same. + if not fluid.is_compiled_with_cuda(): + assert_func = np.array_equal + self.assertTrue( - np.allclose(dy_out, st_out), + assert_func(dy_out, st_out), msg="dy_out:\n {}\n st_out:\n{}".format(dy_out, st_out))