From 9f23114793dfb44445fb39df63f1dc92bdff9c53 Mon Sep 17 00:00:00 2001
From: Aurelius84
Date: Wed, 1 Feb 2023 09:22:54 +0800
Subject: [PATCH] [PrimCinn]Fix some vars are wrongly gc in
 CINN+InterpreterCore (#50116)

* [PrimCinn]Fix some vars are wrongly gc in CINN+InterpreterCore

* fix baseline unittest config

* fix code style

---
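The gist of the fix: InterpreterCore has no eager_deletion_op, so every
Input/Output variable of the CINN subgraph must be registered in
skip_gc_vars_ unconditionally, while skip_eager_vars_ (which serves the
ParallelExecutor path) still only collects variables that the outer
graph's eager_deletion_op will not erase. Below is a minimal
self-contained sketch of that partitioning; the containers and sample
variable names are illustrative stand-ins, not the exact Paddle
declarations.

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

int main() {
  // Stand-ins for the CinnLaunchContext members touched by this patch.
  std::unordered_set<std::string> skip_gc_vars;  // read by InterpreterCore
  std::vector<std::string> skip_eager_vars;      // read by ParallelExecutor
  // Vars known to the outer graph; its eager_deletion_op erases these.
  std::unordered_map<std::string, int> outer_varinfo = {{"x", 0}};

  auto add_skip_var_fn = [&](const std::string& var_name) {
    // InterpreterCore has no eager_deletion_op, so graph I/O must
    // always be protected from GC there.
    skip_gc_vars.insert(var_name);
    // Under PE, vars present in outer_varinfo are erased by the
    // eager_deletion_op that follows the cinn_launch op; only the
    // remaining vars need explicit protection.
    if (!outer_varinfo.count(var_name)) {
      skip_eager_vars.emplace_back(var_name);
    }
  };

  for (const char* name : {"x", "y", "out"}) add_skip_var_fn(name);

  std::cout << "skip_gc_vars has x: " << skip_gc_vars.count("x") << "\n";   // 1
  std::cout << "skip_eager_vars size: " << skip_eager_vars.size() << "\n";  // 2
}

Before this patch both containers were filled inside the if branch, so a
variable like "x" above never reached skip_gc_vars and could be freed
prematurely by InterpreterCore's GC.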
 paddle/fluid/operators/cinn/cinn_launch_context.cc                 | 8 ++++++--
 .../unittests/prim/prim/vjp/static/test_comp_add_grad.py           | 2 +-
 .../prim/prim/vjp/static/test_comp_add_tanh_grad.py                | 2 +-
 .../unittests/prim/prim/vjp/static/test_comp_div_grad.py           | 2 +-
 .../unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py          | 2 +-
 .../unittests/prim/prim/vjp/static/test_comp_sub_grad.py           | 2 +-
 .../unittests/prim/prim/vjp/static/test_comp_tanh_grad.py          | 2 +-
 7 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/paddle/fluid/operators/cinn/cinn_launch_context.cc b/paddle/fluid/operators/cinn/cinn_launch_context.cc
index af429e0f01..0b999ccab0 100644
--- a/paddle/fluid/operators/cinn/cinn_launch_context.cc
+++ b/paddle/fluid/operators/cinn/cinn_launch_context.cc
@@ -119,12 +119,16 @@ CinnLaunchContext::CinnLaunchContext(const framework::ir::Graph& graph,
   // collect variables name list to be skipped in GC
   skip_eager_vars_.reserve(input_var_names.size() + output_var_names.size());
   auto add_skip_var_fn = [&outer_varinfo, this](const std::string& var_name) {
+    // Always consider Input/Output of Graph as skip_gc_vars, because
+    // InterpreterCore has no eager_deletion_op to deal with it.
+
+    VLOG(4) << "Append a skip_gc_var for InterpreterCore:" << var_name;
+    skip_gc_vars_.insert(var_name);
     // if a var exists at the outer_varinfo map, that means it will be
     // erased by the following eager_deletion_op of current cinn_launch op
     if (!outer_varinfo.count(var_name)) {
       skip_eager_vars_.emplace_back(var_name);
-      skip_gc_vars_.insert(var_name);
-      VLOG(4) << "Append a skip_gc_var:" << var_name;
+      VLOG(4) << "Append a skip_gc_var for PE:" << var_name;
     }
   };
   std::for_each(
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py
index 1673ff083e..50ef9f6f13 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py
@@ -91,7 +91,7 @@ class TestAddGradComp(unittest.TestCase):
     def test_cinn(self):
         paddle.disable_static()
         dy_res = self.train(use_prim=False, use_cinn=False)
-        comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
+        comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
 
         for i in range(len(dy_res)):
             np.testing.assert_allclose(
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py
index 5dd7417130..b037cc73bf 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py
@@ -92,7 +92,7 @@ class TestDivGradComp(unittest.TestCase):
     def test_cinn(self):
         paddle.disable_static()
         dy_res = self.train(use_prim=False, use_cinn=False)
-        comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
+        comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
 
         for i in range(len(dy_res)):
             np.testing.assert_allclose(
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py
index 95d3c3027f..606b55b5a9 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py
@@ -91,7 +91,7 @@ class TestDivGradComp(unittest.TestCase):
     def test_cinn(self):
         paddle.disable_static()
         dy_res = self.train(use_prim=False, use_cinn=False)
-        comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
+        comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
 
         for i in range(len(dy_res)):
             np.testing.assert_allclose(
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py
index 8df50c768c..8e623100dd 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py
@@ -70,7 +70,7 @@ class TestSqrtGradComp(unittest.TestCase):
     def test_cinn(self):
         paddle.disable_static()
         dy_res = self.train(use_prim=False, use_cinn=False)
-        comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
+        comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
 
         for i in range(len(dy_res)):
             np.testing.assert_allclose(
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py
index 693bf8b942..3245d11876 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py
@@ -92,7 +92,7 @@ class TestDivGradComp(unittest.TestCase):
     def test_cinn(self):
         paddle.disable_static()
         dy_res = self.train(use_prim=False, use_cinn=False)
-        comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
+        comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
 
         for i in range(len(dy_res)):
             np.testing.assert_allclose(
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py
index e643cf620a..d28f84a685 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py
@@ -70,7 +70,7 @@ class TestTanhGradComp(unittest.TestCase):
     def test_cinn(self):
         paddle.disable_static()
         dy_res = self.train(use_prim=False, use_cinn=False)
-        comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
+        comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
 
         for i in range(len(dy_res)):
             np.testing.assert_allclose(
-- 
GitLab