diff --git a/python/paddle/fluid/tests/unittests/test_collective_api_base.py b/python/paddle/fluid/tests/unittests/test_collective_api_base.py index 04f36c2e74cf880ef9c7134767e9b127211e5894..ef011d2ddc83fa9b5a34abc507065c9826d18d85 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_api_base.py +++ b/python/paddle/fluid/tests/unittests/test_collective_api_base.py @@ -229,6 +229,10 @@ class TestDistBase(unittest.TestCase): required_envs["GLOG_logtostderr"] = "1" required_envs["GLOO_LOG_LEVEL"] = "TRACE" + if os.getenv('NVIDIA_TF32_OVERRIDE', '') is not None: + required_envs['NVIDIA_TF32_OVERRIDE'] = os.getenv( + 'NVIDIA_TF32_OVERRIDE', '') + if eager_mode: required_envs["FLAGS_enable_eager_mode"] = "%d" % 1 else: diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 987af61c489de299ef71aca7896f5cb405c47fce..70b1d0568a011b6c2df8b0c6fb3cf38792ac9f48 100755 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -1468,6 +1468,10 @@ class TestDistBase(unittest.TestCase): "grpc_server=10,request_handler_impl=10,section_worker=10" required_envs["GLOG_logtostderr"] = "1" + if os.getenv('NVIDIA_TF32_OVERRIDE', '') is not None: + required_envs['NVIDIA_TF32_OVERRIDE'] = os.getenv( + 'NVIDIA_TF32_OVERRIDE', '') + required_envs.update(need_envs) return required_envs