未验证 提交 9f231147 编写于 作者: A Aurelius84 提交者: GitHub

[PrimCinn] Fix some vars being wrongly GC'd in CINN+InterpreterCore (#50116)

* [PrimCinn] Fix some vars being wrongly GC'd in CINN+InterpreterCore

* fix baseline unittest config

* fix code style
上级 057ba778
......@@ -119,12 +119,16 @@ CinnLaunchContext::CinnLaunchContext(const framework::ir::Graph& graph,
// collect variables name list to be skipped in GC
skip_eager_vars_.reserve(input_var_names.size() + output_var_names.size());
auto add_skip_var_fn = [&outer_varinfo, this](const std::string& var_name) {
// Always consider Input/Output of Graph as skip_gc_vars, because
// InterpreterCore has no eager_deletion_op to deal with it.
VLOG(4) << "Append a skip_gc_var for InterpreterCore:" << var_name;
skip_gc_vars_.insert(var_name);
// if a var exists at the outer_varinfo map, that means it will be
// erased by the following eager_deletion_op of current cinn_launch op
if (!outer_varinfo.count(var_name)) {
skip_eager_vars_.emplace_back(var_name);
skip_gc_vars_.insert(var_name);
VLOG(4) << "Append a skip_gc_var:" << var_name;
VLOG(4) << "Append a skip_gc_var for PE:" << var_name;
}
};
std::for_each(
......
......@@ -91,7 +91,7 @@ class TestAddGradComp(unittest.TestCase):
def test_cinn(self):
paddle.disable_static()
dy_res = self.train(use_prim=False, use_cinn=False)
comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
for i in range(len(dy_res)):
np.testing.assert_allclose(
......
......@@ -92,7 +92,7 @@ class TestDivGradComp(unittest.TestCase):
def test_cinn(self):
paddle.disable_static()
dy_res = self.train(use_prim=False, use_cinn=False)
comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
for i in range(len(dy_res)):
np.testing.assert_allclose(
......
......@@ -91,7 +91,7 @@ class TestDivGradComp(unittest.TestCase):
def test_cinn(self):
paddle.disable_static()
dy_res = self.train(use_prim=False, use_cinn=False)
comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
for i in range(len(dy_res)):
np.testing.assert_allclose(
......
......@@ -70,7 +70,7 @@ class TestSqrtGradComp(unittest.TestCase):
def test_cinn(self):
paddle.disable_static()
dy_res = self.train(use_prim=False, use_cinn=False)
comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
for i in range(len(dy_res)):
np.testing.assert_allclose(
......
......@@ -92,7 +92,7 @@ class TestDivGradComp(unittest.TestCase):
def test_cinn(self):
paddle.disable_static()
dy_res = self.train(use_prim=False, use_cinn=False)
comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
for i in range(len(dy_res)):
np.testing.assert_allclose(
......
......@@ -70,7 +70,7 @@ class TestTanhGradComp(unittest.TestCase):
def test_cinn(self):
paddle.disable_static()
dy_res = self.train(use_prim=False, use_cinn=False)
comp_st_cinn_res = self.train(use_prim=True, use_cinn=False)
comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
for i in range(len(dy_res)):
np.testing.assert_allclose(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册