diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc
index d5402699553c7fa8040cb5aa351505900d00a4a5..3437f782a548668644025843663b76fd4064565d 100644
--- a/paddle/fluid/imperative/basic_engine.cc
+++ b/paddle/fluid/imperative/basic_engine.cc
@@ -173,27 +173,11 @@ void BasicEngine::PrepareGradAccumulators(
     for (const auto& var : pair.second) {
       if (!var) continue;
 
-      if (!var->HasGradNode()) {
-        auto& accumulator = accumulators_[var.get()];
-        if (!accumulator) {
-          if (FLAGS_sort_sum_gradient) {
-            accumulator.reset(new SortedGradientAccumulator(var.get()));
-          } else {
-            accumulator.reset(new EagerGradientAccumulator(var.get()));
-          }
-        }
-
-        accumulator->IncreaseRefCnt();
-
-        VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() << "("
-                << var.get()
-                << ") that don't have grad node with reference count "
-                << accumulator->RefCnt();
-      } else {
+      bool find_grad_node_of_var = false;
+      if (var->HasGradNode()) {
         // Because Inplace op overwrites the grad_node of the input grad_var. So
         // only the information of grad_pending_node can be used to find the
         // grad_node of grad_var.
-        bool find_grad_node_of_var = false;
         for (auto& grad_pending_node : grad_pending_nodes) {
           PADDLE_ENFORCE_NOT_NULL(
               grad_pending_node,
@@ -245,11 +229,33 @@ void BasicEngine::PrepareGradAccumulators(
             break;
           }
         }
-        PADDLE_ENFORCE_EQ(
-            find_grad_node_of_var, true,
-            platform::errors::NotFound(
-                "No grad node corresponding to grad Tensor (%s) was found.",
-                var->Name()));
+        if (!find_grad_node_of_var) {
+          // Special case: `set_value` is an inplace op, and it can change
+          // a var with `stop_gradient=True` into a var with
+          // `stop_gradient=False`.
+          // This inplace var has grad_node (the inplace op), but it
+          // isn't the input of grad_pending_op.
+ VLOG(6) << "No grad node corresponding to grad Tensor (" + << var->Name() << ") was found."; + } + } + + if (!var->HasGradNode() || !find_grad_node_of_var) { + auto& accumulator = accumulators_[var.get()]; + if (!accumulator) { + if (FLAGS_sort_sum_gradient) { + accumulator.reset(new SortedGradientAccumulator(var.get())); + } else { + accumulator.reset(new EagerGradientAccumulator(var.get())); + } + } + + accumulator->IncreaseRefCnt(); + + VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() << "(" + << var.get() + << ") that don't have grad node with reference count " + << accumulator->RefCnt(); } } } @@ -435,16 +441,8 @@ void BasicEngine::Execute() { std::unordered_map>::iterator iter; - if (!var->HasGradNode()) { - VLOG(10) << "Find gradient of var (" << var->Name() - << ") with no grad_node."; - iter = accumulators_.find(var.get()); - PADDLE_ENFORCE_EQ( - iter != accumulators_.end(), true, - platform::errors::NotFound( - "Cannot find gradient of variable %s", var->Name())); - } else { - bool flag_find_grad = false; + bool flag_find_grad = false; + if (var->HasGradNode()) { VLOG(10) << "Find gradient of var (" << var->Name() << ") with grad_node."; for (auto& grad_pending_node : @@ -459,8 +457,16 @@ void BasicEngine::Execute() { } } } + if (!flag_find_grad) { + VLOG(6) << "Cannot find gradient of variable " << var->Name(); + } + } + if (!var->HasGradNode() || !flag_find_grad) { + VLOG(10) << "Find gradient of var (" << var->Name() + << ") with no grad_node."; + iter = accumulators_.find(var.get()); PADDLE_ENFORCE_EQ( - flag_find_grad, true, + iter != accumulators_.end(), true, platform::errors::NotFound( "Cannot find gradient of variable %s", var->Name())); } diff --git a/paddle/fluid/imperative/dygraph_grad_maker.h b/paddle/fluid/imperative/dygraph_grad_maker.h index f1eb8aa62c9271b194d5159883392372f4cbd4f3..7a567c85854729ca36bdaa93b7cfb316c9b887b1 100644 --- a/paddle/fluid/imperative/dygraph_grad_maker.h +++ b/paddle/fluid/imperative/dygraph_grad_maker.h @@ -269,8 +269,14 @@ class TracedGradOp { for (auto& var : vars) { if (var && !var->OverridedStopGradient() && var->GradNode()) { if (map_dirty_grad_node_.find(var) != map_dirty_grad_node_.end()) { + // Because inplace var isn't a leaf var, it should have + // dirty_grad_node. node_->InsertGradPendingNode(map_dirty_grad_node_[var]); - } else { + } else if (node_ != var->GradNode()) { + // For non-inplace var. + // Special case: `set_value` is inplace op, and it can change + // the var with `stop_gradient=True` to the var with + // `stop_gradient=False`. 
           node_->InsertGradPendingNode(var->GradNode());
         }
       }
diff --git a/python/paddle/fluid/tests/unittests/test_set_value_op.py b/python/paddle/fluid/tests/unittests/test_set_value_op.py
index 057d1b590a0d15b7a277ea690129eab1d2cb5c0c..fd277757d73a9a8413bad805f57d3af8a7747254 100644
--- a/python/paddle/fluid/tests/unittests/test_set_value_op.py
+++ b/python/paddle/fluid/tests/unittests/test_set_value_op.py
@@ -1322,5 +1322,41 @@ class TestGradientTruncated(unittest.TestCase):
             array = array[0]
 
 
+class TestSetValueInplaceLeafVar(unittest.TestCase):
+    def test_inplace_var_become_leaf_var(self):
+        paddle.disable_static()
+
+        a_grad_1, b_grad_1, a_grad_2, b_grad_2 = 0, 1, 2, 3
+        with paddle.fluid.dygraph.guard():
+            paddle.seed(100)
+            a = paddle.rand(shape=[1, 4])
+            b = paddle.rand(shape=[1, 4])
+            a.stop_gradient = False
+            b.stop_gradient = False
+            c = a / b
+            c.sum().backward()
+            a_grad_1 = a.grad.numpy()
+            b_grad_1 = b.grad.numpy()
+
+        with paddle.fluid.dygraph.guard():
+            paddle.seed(100)
+            a = paddle.rand(shape=[1, 4])
+            b = paddle.rand(shape=[1, 4])
+            a.stop_gradient = False
+            b.stop_gradient = False
+            c = a / b
+            d = paddle.zeros((4, 4))
+            self.assertTrue(d.stop_gradient)
+            d[0, :] = c
+            self.assertFalse(d.stop_gradient)
+            d[0, :].sum().backward()
+            a_grad_2 = a.grad.numpy()
+            b_grad_2 = b.grad.numpy()
+
+        self.assertTrue(np.array_equal(a_grad_1, a_grad_2))
+        self.assertTrue(np.array_equal(b_grad_1, b_grad_2))
+        paddle.enable_static()
+
+
 if __name__ == '__main__':
     unittest.main()
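
For reviewers who want to reproduce the behavior outside the test suite, here is a minimal standalone sketch of the scenario this patch targets, mirroring the new `TestSetValueInplaceLeafVar` case (it assumes a dygraph-mode build that includes this change; the seed and shapes are only illustrative). This is the case that previously tripped the `PADDLE_ENFORCE_EQ(find_grad_node_of_var, true, ...)` check, which the patch downgrades to a `VLOG(6)`.

```python
import paddle

# Writing a grad-requiring tensor into a slice of a `stop_gradient=True`
# tensor goes through the inplace `set_value` op. The slice assignment
# flips `stop_gradient` on `d`, and backward() should still reach the
# original leaf inputs `a` and `b`.
paddle.disable_static()
paddle.seed(100)

a = paddle.rand(shape=[1, 4])
b = paddle.rand(shape=[1, 4])
a.stop_gradient = False
b.stop_gradient = False

c = a / b
d = paddle.zeros((4, 4))    # leaf var created with stop_gradient=True
assert d.stop_gradient

d[0, :] = c                 # inplace set_value; d now requires grad
assert not d.stop_gradient

d[0, :].sum().backward()    # with this patch, grads reach a and b
print(a.grad.numpy())
print(b.grad.numpy())
```

The resulting gradients are expected to match the plain `c = a / b; c.sum().backward()` run, which is exactly what the first half of the new unit test asserts.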