From a7b13d385a8b46c7e2000ecf5e5eb31c13e8bf4d Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Tue, 4 Jan 2022 10:54:38 +0800 Subject: [PATCH] Support test_imperative container_sequential and signal_handler with eager_guard (#38614) --- .../test_imperative_container_sequential.py | 15 +++++++-- .../test_imperative_signal_handler.py | 31 ++++++++++++++++--- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_container_sequential.py b/python/paddle/fluid/tests/unittests/test_imperative_container_sequential.py index 972f1b64e1..dcf4e8de5e 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_container_sequential.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_container_sequential.py @@ -17,10 +17,11 @@ from __future__ import print_function import unittest import paddle.fluid as fluid import numpy as np +from paddle.fluid.framework import _test_eager_guard class TestImperativeContainerSequential(unittest.TestCase): - def test_sequential(self): + def func_sequential(self): data = np.random.uniform(-1, 1, [5, 10]).astype('float32') with fluid.dygraph.guard(): data = fluid.dygraph.to_variable(data) @@ -55,7 +56,12 @@ class TestImperativeContainerSequential(unittest.TestCase): loss2 = fluid.layers.reduce_mean(res2) loss2.backward() - def test_sequential_list_params(self): + def test_sequential(self): + with _test_eager_guard(): + self.func_sequential() + self.func_sequential() + + def func_sequential_list_params(self): data = np.random.uniform(-1, 1, [5, 10]).astype('float32') with fluid.dygraph.guard(): data = fluid.dygraph.to_variable(data) @@ -90,6 +96,11 @@ class TestImperativeContainerSequential(unittest.TestCase): loss2 = fluid.layers.reduce_mean(res2) loss2.backward() + def test_sequential_list_params(self): + with _test_eager_guard(): + self.func_sequential_list_params() + self.func_sequential_list_params() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_signal_handler.py b/python/paddle/fluid/tests/unittests/test_imperative_signal_handler.py index b388efc5f3..8aadb155b0 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_signal_handler.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_signal_handler.py @@ -21,6 +21,7 @@ import time import paddle.compat as cpt from paddle.fluid import core +from paddle.fluid.framework import _test_eager_guard def set_child_signal_handler(self, child_pid): @@ -37,8 +38,8 @@ def set_child_signal_handler(self, child_pid): signal.signal(signal.SIGCHLD, __handler__) -class TestDygraphDataLoaderSingalHandler(unittest.TestCase): - def test_child_process_exit_with_error(self): +class DygraphDataLoaderSingalHandler(unittest.TestCase): + def func_child_process_exit_with_error(self): def __test_process__(): core._set_process_signal_handler() sys.exit(1) @@ -65,7 +66,12 @@ class TestDygraphDataLoaderSingalHandler(unittest.TestCase): self.assertIsNotNone(exception) - def test_child_process_killed_by_sigsegv(self): + def test_child_process_exit_with_error(self): + with _test_eager_guard(): + self.func_child_process_exit_with_error() + self.func_child_process_exit_with_error() + + def func_child_process_killed_by_sigsegv(self): def __test_process__(): core._set_process_signal_handler() os.kill(os.getpid(), signal.SIGSEGV) @@ -93,7 +99,12 @@ class TestDygraphDataLoaderSingalHandler(unittest.TestCase): self.assertIsNotNone(exception) - def test_child_process_killed_by_sigbus(self): + def test_child_process_killed_by_sigsegv(self): + with _test_eager_guard(): + self.func_child_process_killed_by_sigsegv() + self.func_child_process_killed_by_sigsegv() + + def func_child_process_killed_by_sigbus(self): def __test_process__(): core._set_process_signal_handler() os.kill(os.getpid(), signal.SIGBUS) @@ -120,7 +131,12 @@ class TestDygraphDataLoaderSingalHandler(unittest.TestCase): self.assertIsNotNone(exception) - def test_child_process_killed_by_sigterm(self): + def test_child_process_killed_by_sigbus(self): + with _test_eager_guard(): + self.func_child_process_killed_by_sigbus() + self.func_child_process_killed_by_sigbus() + + def func_child_process_killed_by_sigterm(self): def __test_process__(): core._set_process_signal_handler() time.sleep(10) @@ -132,6 +148,11 @@ class TestDygraphDataLoaderSingalHandler(unittest.TestCase): set_child_signal_handler(id(self), test_process.pid) time.sleep(1) + def test_child_process_killed_by_sigterm(self): + with _test_eager_guard(): + self.func_child_process_killed_by_sigterm() + self.func_child_process_killed_by_sigterm() + if __name__ == '__main__': unittest.main() -- GitLab