diff --git a/paddle/fluid/operators/cinn/cinn_launch_context.cc b/paddle/fluid/operators/cinn/cinn_launch_context.cc
index af429e0f01e336c65b0183a08ef4acfb319006c8..0b999ccab016f24431a2511c4d44ad8d055f5e02 100644
--- a/paddle/fluid/operators/cinn/cinn_launch_context.cc
+++ b/paddle/fluid/operators/cinn/cinn_launch_context.cc
@@ -119,12 +119,16 @@ CinnLaunchContext::CinnLaunchContext(const framework::ir::Graph& graph,
   // collect variables name list to be skipped in GC
   skip_eager_vars_.reserve(input_var_names.size() + output_var_names.size());
   auto add_skip_var_fn = [&outer_varinfo, this](const std::string& var_name) {
+    // Always consider Input/Output of Graph as skip_gc_vars, because
+    // InterpreterCore has no eager_deletion_op to deal with it.
+
+    VLOG(4) << "Append a skip_gc_var for InterpreterCore:" << var_name;
+    skip_gc_vars_.insert(var_name);
     // if a var exists at the outer_varinfo map, that means it will be
     // erased by the following eager_deletion_op of current cinn_launch op
     if (!outer_varinfo.count(var_name)) {
       skip_eager_vars_.emplace_back(var_name);
-      skip_gc_vars_.insert(var_name);
-      VLOG(4) << "Append a skip_gc_var:" << var_name;
+      VLOG(4) << "Append a skip_gc_var for PE:" << var_name;
     }
   };
   std::for_each(
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py
index 1673ff083e7cf4081b300ab7de6e585e7b7d1c21..50ef9f6f13036ac114b2c9e6a7e4be7bb82b24e6 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py
@@ -91,7 +91,7 @@ class TestAddGradComp(unittest.TestCase):
     def test_cinn(self):
         paddle.disable_static()
         dy_res = self.train(use_prim=False, use_cinn=False)
-        comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
+        comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
 
         for i in range(len(dy_res)):
             np.testing.assert_allclose(
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py
index 5dd7417130bc1137b751ea420384d63350c216b0..b037cc73bfd545c9e21862d183d9f2759333d0b5 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py
@@ -92,7 +92,7 @@ class TestDivGradComp(unittest.TestCase):
     def test_cinn(self):
         paddle.disable_static()
         dy_res = self.train(use_prim=False, use_cinn=False)
-        comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
+        comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
 
         for i in range(len(dy_res)):
             np.testing.assert_allclose(
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py
index 95d3c3027fd9d28e4b054806959c8ad8ec391e9a..606b55b5a95c06fc097c03e0fe8964e38faadce6 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py
@@ -91,7 +91,7 @@ class TestDivGradComp(unittest.TestCase):
     def test_cinn(self):
         paddle.disable_static()
         dy_res = self.train(use_prim=False, use_cinn=False)
-        comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
+        comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
 
         for i in range(len(dy_res)):
             np.testing.assert_allclose(
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py
index 8df50c768c2b72e11d5de955b3b88e65183c0aad..8e623100dd09cb86b7aae0562035536944598b9f 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sqrt_grad.py
@@ -70,7 +70,7 @@ class TestSqrtGradComp(unittest.TestCase):
     def test_cinn(self):
         paddle.disable_static()
         dy_res = self.train(use_prim=False, use_cinn=False)
-        comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
+        comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
 
         for i in range(len(dy_res)):
             np.testing.assert_allclose(
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py
index 693bf8b942bab23e9af6b10c5456b8e76936d38b..3245d118760b2ba3596af964428bafc0620badcf 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py
@@ -92,7 +92,7 @@ class TestDivGradComp(unittest.TestCase):
     def test_cinn(self):
         paddle.disable_static()
         dy_res = self.train(use_prim=False, use_cinn=False)
-        comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
+        comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
 
         for i in range(len(dy_res)):
             np.testing.assert_allclose(
diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py
index e643cf620a8118fb1c36202ada5543b60b3f0012..d28f84a685b0d0d83d085f1f76edfe003b6bd2fb 100644
--- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py
@@ -70,7 +70,7 @@ class TestTanhGradComp(unittest.TestCase):
     def test_cinn(self):
         paddle.disable_static()
         dy_res = self.train(use_prim=False, use_cinn=False)
-        comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
+        comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
 
         for i in range(len(dy_res)):
             np.testing.assert_allclose(
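
To make the intent of the C++ hunk concrete, below is a minimal, self-contained sketch of the revised skip-var collection: every graph input/output var is now protected from InterpreterCore's GC unconditionally, while ParallelExecutor (PE) only protects vars the outer graph will not erase itself. The sample variable names and the plain std:: containers are illustrative assumptions, not the operator's actual members.

#include <iostream>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Vars tracked by the outer graph; these get an eager_deletion_op
  // after the cinn_launch op, so PE can release them on its own.
  std::unordered_map<std::string, int> outer_varinfo = {{"x", 0}};
  // Hypothetical input/output vars of the CINN subgraph.
  std::vector<std::string> io_vars = {"x", "y"};

  std::set<std::string> skip_gc_vars;        // protected under InterpreterCore
  std::vector<std::string> skip_eager_vars;  // protected under PE

  for (const auto& var_name : io_vars) {
    // InterpreterCore has no eager_deletion_op for graph I/O, so every
    // input/output var must be shielded from its GC.
    skip_gc_vars.insert(var_name);
    // PE only needs protection for vars absent from outer_varinfo; the
    // rest are freed by the trailing eager_deletion_op.
    if (!outer_varinfo.count(var_name)) {
      skip_eager_vars.emplace_back(var_name);
    }
  }

  std::cout << "skip_gc_vars: " << skip_gc_vars.size() << "\n";        // 2
  std::cout << "skip_eager_vars: " << skip_eager_vars.size() << "\n";  // 1
}

With graph inputs/outputs unconditionally protected under InterpreterCore, the six prim VJP tests above can flip use_cinn from False to True and compare their composite-gradient results under CINN directly against the eager baseline.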