Unverified · Commit b9cfa974 · authored by pangyoki, committed by GitHub

Cherry-pick PR38014, fix dygraph_grad_maker to support set_value (#38014) (#38521)

Cherry-pick PR #38014
Parent 624a2b9c
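For context, here is the user-visible behavior this commit fixes, distilled into a minimal dygraph repro from the unit test added at the bottom of this diff: assigning a trainable tensor into a slice of a `stop_gradient=True` tensor dispatches the inplace `set_value` op, which flips the target's `stop_gradient` to `False`, and `backward()` must then be able to route gradients back to the original leaf vars. Before this fix, the backward pass could abort with "No grad node corresponding to grad Tensor (...) was found".

```python
import paddle

paddle.disable_static()
paddle.seed(100)

a = paddle.rand(shape=[1, 4])
a.stop_gradient = False       # `a` is a trainable leaf var

d = paddle.zeros((4, 4))      # d.stop_gradient is True by default
d[0, :] = 2 * a               # inplace set_value flips d.stop_gradient to False
d[0, :].sum().backward()      # must reach `a` even though `d` started frozen

print(a.grad.numpy())         # gradient of sum(2 * a): all elements equal 2
```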
paddle/fluid/imperative/basic_engine.cc

@@ -173,27 +173,11 @@ void BasicEngine::PrepareGradAccumulators(
     for (const auto& var : pair.second) {
       if (!var) continue;
 
-      if (!var->HasGradNode()) {
-        auto& accumulator = accumulators_[var.get()];
-        if (!accumulator) {
-          if (FLAGS_sort_sum_gradient) {
-            accumulator.reset(new SortedGradientAccumulator(var.get()));
-          } else {
-            accumulator.reset(new EagerGradientAccumulator(var.get()));
-          }
-        }
-        accumulator->IncreaseRefCnt();
-
-        VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() << "("
-                << var.get()
-                << ") that don't have grad node with reference count "
-                << accumulator->RefCnt();
-      } else {
+      bool find_grad_node_of_var = false;
+      if (var->HasGradNode()) {
         // Because Inplace op overwrites the grad_node of the input grad_var. So
         // only the information of grad_pending_node can be used to find the
         // grad_node of grad_var.
-        bool find_grad_node_of_var = false;
         for (auto& grad_pending_node : grad_pending_nodes) {
           PADDLE_ENFORCE_NOT_NULL(
               grad_pending_node,
@@ -245,11 +229,33 @@ void BasicEngine::PrepareGradAccumulators(
             break;
           }
         }
-        PADDLE_ENFORCE_EQ(
-            find_grad_node_of_var, true,
-            platform::errors::NotFound(
-                "No grad node corresponding to grad Tensor (%s) was found.",
-                var->Name()));
+        if (!find_grad_node_of_var) {
+          // Special case: `set_value` is inplace op, and it can change
+          // the var with `stop_gradient=True` to the var with
+          // `stop_gradient=False `.
+          // This inplace var has grad_node (the inplace op), but it
+          // isn't the input of grad_pending_op.
+          VLOG(6) << "No grad node corresponding to grad Tensor ("
+                  << var->Name() << ") was found.";
+        }
       }
+
+      if (!var->HasGradNode() || !find_grad_node_of_var) {
+        auto& accumulator = accumulators_[var.get()];
+        if (!accumulator) {
+          if (FLAGS_sort_sum_gradient) {
+            accumulator.reset(new SortedGradientAccumulator(var.get()));
+          } else {
+            accumulator.reset(new EagerGradientAccumulator(var.get()));
+          }
+        }
+        accumulator->IncreaseRefCnt();
+
+        VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() << "("
+                << var.get()
+                << ") that don't have grad node with reference count "
+                << accumulator->RefCnt();
+      }
     }
   }
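In short, `PrepareGradAccumulators` now creates a gradient accumulator in two situations instead of one: when the var has no grad node at all (the original leaf case), and when the var has a grad node that cannot be matched against any `grad_pending_node` (the inplace `set_value` case, which is now a VLOG instead of a hard error). A hypothetical Python paraphrase of that control flow; the names `has_grad_node`, `matches`, and `make_accumulator` are illustrative, not Paddle APIs:

```python
def prepare_grad_accumulator(var, grad_pending_nodes, accumulators):
    find_grad_node_of_var = False
    if var.has_grad_node():
        # Inplace ops overwrite a grad var's grad_node, so the node has to
        # be recovered through grad_pending_nodes.
        find_grad_node_of_var = any(
            node.matches(var) for node in grad_pending_nodes)
        # A miss is tolerated now: set_value can hand a grad node to a var
        # that is not an input of any grad_pending_op.
    if not var.has_grad_node() or not find_grad_node_of_var:
        if var not in accumulators:
            accumulators[var] = make_accumulator(var)  # sorted or eager
        accumulators[var].increase_ref_cnt()
```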
paddle/fluid/imperative/basic_engine.cc

@@ -435,16 +441,8 @@ void BasicEngine::Execute() {
       std::unordered_map<VariableWrapper*,
                          std::unique_ptr<GradientAccumulator>>::iterator
           iter;
-      if (!var->HasGradNode()) {
-        VLOG(10) << "Find gradient of var (" << var->Name()
-                 << ") with no grad_node.";
-        iter = accumulators_.find(var.get());
-        PADDLE_ENFORCE_EQ(
-            iter != accumulators_.end(), true,
-            platform::errors::NotFound(
-                "Cannot find gradient of variable %s", var->Name()));
-      } else {
-        bool flag_find_grad = false;
+      bool flag_find_grad = false;
+      if (var->HasGradNode()) {
         VLOG(10) << "Find gradient of var (" << var->Name()
                  << ") with grad_node.";
         for (auto& grad_pending_node :
@@ -459,8 +457,16 @@ void BasicEngine::Execute() {
             }
           }
         }
-        PADDLE_ENFORCE_EQ(
-            flag_find_grad, true,
-            platform::errors::NotFound(
-                "Cannot find gradient of variable %s", var->Name()));
+        if (!flag_find_grad) {
+          VLOG(6) << "Cannot find gradient of variable " << var->Name();
+        }
+      }
+
+      if (!var->HasGradNode() || !flag_find_grad) {
+        VLOG(10) << "Find gradient of var (" << var->Name()
+                 << ") with no grad_node.";
+        iter = accumulators_.find(var.get());
+        PADDLE_ENFORCE_EQ(
+            iter != accumulators_.end(), true,
+            platform::errors::NotFound(
+                "Cannot find gradient of variable %s", var->Name()));
       }
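`Execute` mirrors the same relaxation on the lookup side: a var whose grad node yields no match among the inner accumulators of its pending grad nodes now falls back to the leaf-style accumulator map prepared above, and only a miss in both places remains a hard `NotFound` error. Sketching that lookup order in the same hypothetical Python style (illustrative names, not Paddle APIs):

```python
def find_gradient_accumulator(var, grad_pending_nodes, accumulators):
    if var.has_grad_node():
        # Search the inner accumulators hanging off the pending grad nodes.
        for node in grad_pending_nodes:
            if node.has_inner_accumulator(var):
                return node.inner_accumulator(var)
    # No grad node, or no match (the set_value case): fall back to the
    # leaf-style accumulator prepared by PrepareGradAccumulators.
    if var not in accumulators:
        raise LookupError(f"Cannot find gradient of variable {var}")
    return accumulators[var]
```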
paddle/fluid/imperative/dygraph_grad_maker.h

@@ -269,8 +269,14 @@ class TracedGradOp {
     for (auto& var : vars) {
       if (var && !var->OverridedStopGradient() && var->GradNode()) {
         if (map_dirty_grad_node_.find(var) != map_dirty_grad_node_.end()) {
+          // Because inplace var isn't a leaf var, it should have
+          // dirty_grad_node.
           node_->InsertGradPendingNode(map_dirty_grad_node_[var]);
-        } else {
+        } else if (node_ != var->GradNode()) {
+          // For non-inplace var.
+          // Special case: `set_value` is inplace op, and it can change
+          // the var with `stop_gradient=True` to the var with
+          // `stop_gradient=False`.
           node_->InsertGradPendingNode(var->GradNode());
         }
       }
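The new `else if (node_ != var->GradNode())` guard is the dygraph_grad_maker half of the fix: after `set_value` overwrites a var's grad node, the node currently being traced may itself be that grad node, and inserting it as its own grad-pending node would put a self-cycle into the backward graph. A hypothetical Python paraphrase (illustrative names, not Paddle API):

```python
def insert_grad_pending_nodes(node, vars, map_dirty_grad_node):
    for var in vars:
        if var is None or var.overrided_stop_gradient or not var.grad_node:
            continue
        if var in map_dirty_grad_node:
            # An inplace var is not a leaf, so follow its dirty grad node.
            node.insert_grad_pending_node(map_dirty_grad_node[var])
        elif node is not var.grad_node:
            # Non-inplace var. Skip when the traced node already *is* the
            # var's grad node (set_value overwrote it): inserting it here
            # would make the node pend on itself.
            node.insert_grad_pending_node(var.grad_node)
```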
python/paddle/fluid/tests/unittests/test_set_value_op.py

@@ -1322,5 +1322,41 @@ class TestGradientTruncated(unittest.TestCase):
             array = array[0]
 
 
+class TestSetValueInplaceLeafVar(unittest.TestCase):
+    def test_inplace_var_become_leaf_var(self):
+        paddle.disable_static()
+
+        a_grad_1, b_grad_1, a_grad_2, b_grad_2 = 0, 1, 2, 3
+        with paddle.fluid.dygraph.guard():
+            paddle.seed(100)
+            a = paddle.rand(shape=[1, 4])
+            b = paddle.rand(shape=[1, 4])
+            a.stop_gradient = False
+            b.stop_gradient = False
+            c = a / b
+            c.sum().backward()
+            a_grad_1 = a.grad.numpy()
+            b_grad_1 = b.grad.numpy()
+
+        with paddle.fluid.dygraph.guard():
+            paddle.seed(100)
+            a = paddle.rand(shape=[1, 4])
+            b = paddle.rand(shape=[1, 4])
+            a.stop_gradient = False
+            b.stop_gradient = False
+            c = a / b
+            d = paddle.zeros((4, 4))
+            self.assertTrue(d.stop_gradient)
+            d[0, :] = c
+            self.assertFalse(d.stop_gradient)
+            d[0, :].sum().backward()
+            a_grad_2 = a.grad.numpy()
+            b_grad_2 = b.grad.numpy()
+
+        self.assertTrue(np.array_equal(a_grad_1, a_grad_2))
+        self.assertTrue(np.array_equal(b_grad_1, b_grad_2))
+        paddle.enable_static()
+
+
 if __name__ == '__main__':
     unittest.main()