未验证 提交 ee1e1642 编写于 作者: P pangyoki 提交者: GitHub

fix inplace bug when the first grad_var(loss_grad) is inplace var (#37420)

* fix inplace bug

* fix custom grad input error

* add unittest

* fix inplace bug
上级 79800978
......@@ -53,6 +53,10 @@ void BasicEngine::Init(
platform::errors::AlreadyExists(
"Accumulators are not empty before preparing it for "
"backward network execution."));
PADDLE_ENFORCE_EQ(accumulators_with_grad_node_.empty(), true,
platform::errors::AlreadyExists(
"Accumulators with grad_node as the key are not empty "
"before preparing it for backward network execution."));
for (size_t i = 0; i < tensors.size(); ++i) {
auto var = tensors[i];
......@@ -73,7 +77,6 @@ void BasicEngine::Init(
VLOG(5) << "Clear the auto-grad graph from grad var " << var->Name()
<< " because of retain_graph=False when calling backward";
var->GradVarBase()->SetGraphIsFreed(true);
var->GradVarBase()->ClearGradNode();
}
if (init_node == nullptr || var->OverridedStopGradient()) {
......@@ -108,7 +111,9 @@ void BasicEngine::Init(
}
VariableWrapper* init_grad_var = var->GradVarBase()->SharedVar().get();
auto& accumulator = accumulators_[init_grad_var];
auto& accumulator =
accumulators_with_grad_node_[init_grad_var->GetGradNode()]
[init_grad_var];
if (!accumulator) {
if (FLAGS_sort_sum_gradient) {
accumulator.reset(new SortedGradientAccumulator(init_grad_var));
......@@ -116,6 +121,8 @@ void BasicEngine::Init(
accumulator.reset(new EagerGradientAccumulator(init_grad_var));
}
}
accumulator->IncreaseRefCnt();
accumulator->IncreaseCurCnt();
init_nodes_.push_back(init_node);
}
......@@ -253,10 +260,6 @@ void BasicEngine::PrepareDeps() {
node_deps_.empty(), true,
platform::errors::AlreadyExists("Op deps are not empty before preparing "
"it for backward network execution."));
PADDLE_ENFORCE_EQ(accumulators_with_grad_node_.empty(), true,
platform::errors::AlreadyExists(
"Accumulators with grad_node as the key are not empty "
"before preparing it for backward network execution."));
std::queue<GradOpNode*> q;
std::unordered_set<GradOpNode*> visited;
......
......@@ -409,5 +409,30 @@ class TestDygraphInplaceSubtract(TestDygraphInplaceAdd):
return var.subtract_(self.input_var_2)
class TestLossIsInplaceVar(unittest.TestCase):
def test_loss_is_inplace_var(self):
with paddle.fluid.dygraph.guard():
var_a = paddle.ones((2, 2))
var_a.stop_gradient = False
var_b = var_a * 2
loss = var_b.tanh_()
loss.backward()
inplace_grad_var_a = var_a.grad.numpy()
with paddle.fluid.dygraph.guard():
var_a = paddle.ones((2, 2))
var_a.stop_gradient = False
var_b = var_a * 2
loss = var_b.tanh()
loss.backward()
grad_var_a = var_a.grad.numpy()
self.assertTrue(np.array_equal(inplace_grad_var_a, grad_var_a))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册