From 113c8b93bea9657c1515677591a9dccf711d9477 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Mon, 27 Dec 2021 19:29:23 +0800 Subject: [PATCH] fix accumulator bug when multiple inplace OPs are executed continuously (#38406) * fix accumulator bug * fix unittest --- paddle/fluid/imperative/basic_engine.cc | 15 ++++++++------- .../paddle/fluid/tests/unittests/test_inplace.py | 13 +++++++++++++ 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc index ee1c4d1be51..9d377926536 100644 --- a/paddle/fluid/imperative/basic_engine.cc +++ b/paddle/fluid/imperative/basic_engine.cc @@ -174,7 +174,7 @@ void BasicEngine::PrepareGradAccumulators( if (!var) continue; bool find_grad_node_of_var = false; - if (var->HasGradNode()) { + if (grad_pending_nodes.size()) { // Because Inplace op overwrites the grad_node of the input grad_var. So // only the information of grad_pending_node can be used to find the // grad_node of grad_var. @@ -240,7 +240,7 @@ void BasicEngine::PrepareGradAccumulators( } } - if (!var->HasGradNode() || !find_grad_node_of_var) { + if (!grad_pending_nodes.size() || !find_grad_node_of_var) { auto& accumulator = accumulators_[var.get()]; if (!accumulator) { if (FLAGS_sort_sum_gradient) { @@ -438,15 +438,15 @@ void BasicEngine::Execute() { continue; } + const auto& grad_pending_nodes = shared_cur_node->GradPendingNodes(); std::unordered_map>::iterator iter; bool flag_find_grad = false; - if (var->HasGradNode()) { + if (grad_pending_nodes.size()) { VLOG(10) << "Find gradient of var (" << var->Name() << ") with grad_node."; - for (auto& grad_pending_node : - shared_cur_node->GradPendingNodes()) { + for (auto& grad_pending_node : grad_pending_nodes) { const auto& iter_grad_node = accumulators_with_grad_node_.find(grad_pending_node); if (iter_grad_node != accumulators_with_grad_node_.end()) { @@ -458,10 +458,11 @@ void BasicEngine::Execute() { } } if (!flag_find_grad) { - VLOG(6) << "Cannot find gradient of variable " << var->Name(); + VLOG(6) << "Cannot find gradient of variable " << var->Name() + << " in accumulators_with_grad_node_"; } } - if (!var->HasGradNode() || !flag_find_grad) { + if (!grad_pending_nodes.size() || !flag_find_grad) { VLOG(10) << "Find gradient of var (" << var->Name() << ") with no grad_node."; iter = accumulators_.find(var.get()); diff --git a/python/paddle/fluid/tests/unittests/test_inplace.py b/python/paddle/fluid/tests/unittests/test_inplace.py index 98e2d2367fd..316db187535 100644 --- a/python/paddle/fluid/tests/unittests/test_inplace.py +++ b/python/paddle/fluid/tests/unittests/test_inplace.py @@ -434,5 +434,18 @@ class TestLossIsInplaceVar(unittest.TestCase): self.assertTrue(np.array_equal(inplace_grad_var_a, grad_var_a)) +class TestContinuouslyInplace(unittest.TestCase): + def test_continuously_inplace(self): + a = paddle.rand([2, 3]) + a.stop_gradient = False + b = a * 2 + + b.reshape_([-1]) + b.reshape_([2, 3]) + b.reshape_([-1]) + + b.backward() + + if __name__ == '__main__': unittest.main() -- GitLab