diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc
index 62d449ccd2ea8c873629a6dade5fce2fac167aed..d5402699553c7fa8040cb5aa351505900d00a4a5 100644
--- a/paddle/fluid/imperative/basic_engine.cc
+++ b/paddle/fluid/imperative/basic_engine.cc
@@ -53,6 +53,10 @@ void BasicEngine::Init(
                     platform::errors::AlreadyExists(
                         "Accumulators are not empty before preparing it for "
                         "backward network execution."));
+  PADDLE_ENFORCE_EQ(accumulators_with_grad_node_.empty(), true,
+                    platform::errors::AlreadyExists(
+                        "Accumulators with grad_node as the key are not empty "
+                        "before preparing it for backward network execution."));
 
   for (size_t i = 0; i < tensors.size(); ++i) {
     auto var = tensors[i];
@@ -73,7 +77,6 @@ void BasicEngine::Init(
       VLOG(5) << "Clear the auto-grad graph from grad var " << var->Name()
               << " because of retain_graph=False when calling backward";
       var->GradVarBase()->SetGraphIsFreed(true);
-      var->GradVarBase()->ClearGradNode();
     }
 
     if (init_node == nullptr || var->OverridedStopGradient()) {
@@ -108,7 +111,9 @@ void BasicEngine::Init(
     }
 
     VariableWrapper* init_grad_var = var->GradVarBase()->SharedVar().get();
-    auto& accumulator = accumulators_[init_grad_var];
+    auto& accumulator =
+        accumulators_with_grad_node_[init_grad_var->GetGradNode()]
+                                    [init_grad_var];
     if (!accumulator) {
       if (FLAGS_sort_sum_gradient) {
         accumulator.reset(new SortedGradientAccumulator(init_grad_var));
@@ -116,6 +121,8 @@ void BasicEngine::Init(
         accumulator.reset(new EagerGradientAccumulator(init_grad_var));
       }
     }
+    accumulator->IncreaseRefCnt();
+    accumulator->IncreaseCurCnt();
 
     init_nodes_.push_back(init_node);
   }
@@ -253,10 +260,6 @@ void BasicEngine::PrepareDeps() {
       node_deps_.empty(), true,
       platform::errors::AlreadyExists("Op deps are not empty before preparing "
                                       "it for backward network execution."));
-  PADDLE_ENFORCE_EQ(accumulators_with_grad_node_.empty(), true,
-                    platform::errors::AlreadyExists(
-                        "Accumulators with grad_node as the key are not empty "
-                        "before preparing it for backward network execution."));
 
   std::queue<GradOpNode*> q;
   std::unordered_set<GradOpNode*> visited;
diff --git a/python/paddle/fluid/tests/unittests/test_inplace.py b/python/paddle/fluid/tests/unittests/test_inplace.py
index 3d158763527e715885d99cb8cdb15920aecf2ce4..98e2d2367fd5edcd7207bb1bd279c0777cce5688 100644
--- a/python/paddle/fluid/tests/unittests/test_inplace.py
+++ b/python/paddle/fluid/tests/unittests/test_inplace.py
@@ -409,5 +409,30 @@ class TestDygraphInplaceSubtract(TestDygraphInplaceAdd):
         return var.subtract_(self.input_var_2)
 
 
+class TestLossIsInplaceVar(unittest.TestCase):
+    def test_loss_is_inplace_var(self):
+        with paddle.fluid.dygraph.guard():
+            var_a = paddle.ones((2, 2))
+            var_a.stop_gradient = False
+
+            var_b = var_a * 2
+            loss = var_b.tanh_()
+
+            loss.backward()
+            inplace_grad_var_a = var_a.grad.numpy()
+
+        with paddle.fluid.dygraph.guard():
+            var_a = paddle.ones((2, 2))
+            var_a.stop_gradient = False
+
+            var_b = var_a * 2
+            loss = var_b.tanh()
+
+            loss.backward()
+            grad_var_a = var_a.grad.numpy()
+
+        self.assertTrue(np.array_equal(inplace_grad_var_a, grad_var_a))
+
+
 if __name__ == '__main__':
     unittest.main()