Unverified commit d31d597f authored by pangyoki, committed by GitHub

Cherry-pick PR 37420, fix inplace bug when the first grad_var(loss_grad) is inplace var (#37420) (#37488)

fix inplace bug, cherry-pick PR #37420
Parent bed652d6
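For context, a minimal sketch of the failing scenario, adapted directly from the `TestLossIsInplaceVar` unit test added in the diff below: when the loss itself is the output of an inplace op such as `tanh_`, the first grad var (`loss_grad`) is an inplace var, and before this fix `backward()` could compute a wrong gradient for its inputs.

```python
# Sketch of the bug scenario, adapted from the new unit test in this commit.
# Before the fix, var_a.grad could differ between the inplace (tanh_) and
# out-of-place (tanh) versions of the same computation.
import numpy as np
import paddle

with paddle.fluid.dygraph.guard():
    var_a = paddle.ones((2, 2))
    var_a.stop_gradient = False
    var_b = var_a * 2
    loss = var_b.tanh_()  # the loss itself is an inplace var
    loss.backward()
    inplace_grad = var_a.grad.numpy()

with paddle.fluid.dygraph.guard():
    var_a = paddle.ones((2, 2))
    var_a.stop_gradient = False
    var_b = var_a * 2
    loss = var_b.tanh()  # same computation, no inplace op
    loss.backward()
    grad = var_a.grad.numpy()

# With this fix applied, both paths yield the same gradient.
assert np.array_equal(inplace_grad, grad)
```

The fix keys the initial gradient accumulator by grad node (`accumulators_with_grad_node_`) and bumps its ref/cur counts in `BasicEngine::Init`, so the gradient of an inplace loss is accumulated against the same entry the backward pass later looks up.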
@@ -53,6 +53,10 @@ void BasicEngine::Init(
                     platform::errors::AlreadyExists(
                         "Accumulators are not empty before preparing it for "
                         "backward network execution."));
+  PADDLE_ENFORCE_EQ(accumulators_with_grad_node_.empty(), true,
+                    platform::errors::AlreadyExists(
+                        "Accumulators with grad_node as the key are not empty "
+                        "before preparing it for backward network execution."));
   for (size_t i = 0; i < tensors.size(); ++i) {
     auto var = tensors[i];
@@ -73,7 +77,6 @@ void BasicEngine::Init(
         VLOG(5) << "Clear the auto-grad graph from grad var " << var->Name()
                 << " because of retain_graph=False when calling backward";
         var->GradVarBase()->SetGraphIsFreed(true);
-        var->GradVarBase()->ClearGradNode();
       }
       if (init_node == nullptr || var->OverridedStopGradient()) {
@@ -108,7 +111,9 @@ void BasicEngine::Init(
     }
     VariableWrapper* init_grad_var = var->GradVarBase()->SharedVar().get();
-    auto& accumulator = accumulators_[init_grad_var];
+    auto& accumulator =
+        accumulators_with_grad_node_[init_grad_var->GetGradNode()]
+                                    [init_grad_var];
     if (!accumulator) {
       if (FLAGS_sort_sum_gradient) {
         accumulator.reset(new SortedGradientAccumulator(init_grad_var));
@@ -116,6 +121,8 @@ void BasicEngine::Init(
         accumulator.reset(new EagerGradientAccumulator(init_grad_var));
       }
     }
+    accumulator->IncreaseRefCnt();
+    accumulator->IncreaseCurCnt();
     init_nodes_.push_back(init_node);
   }
@@ -253,10 +260,6 @@ void BasicEngine::PrepareDeps() {
       node_deps_.empty(), true,
       platform::errors::AlreadyExists("Op deps are not empty before preparing "
                                       "it for backward network execution."));
-  PADDLE_ENFORCE_EQ(accumulators_with_grad_node_.empty(), true,
-                    platform::errors::AlreadyExists(
-                        "Accumulators with grad_node as the key are not empty "
-                        "before preparing it for backward network execution."));
   std::queue<GradOpNode*> q;
   std::unordered_set<GradOpNode*> visited;
...
@@ -409,5 +409,30 @@ class TestDygraphInplaceSubtract(TestDygraphInplaceAdd):
         return var.subtract_(self.input_var_2)
 
+
+class TestLossIsInplaceVar(unittest.TestCase):
+    def test_loss_is_inplace_var(self):
+        with paddle.fluid.dygraph.guard():
+            var_a = paddle.ones((2, 2))
+            var_a.stop_gradient = False
+
+            var_b = var_a * 2
+            loss = var_b.tanh_()
+            loss.backward()
+            inplace_grad_var_a = var_a.grad.numpy()
+
+        with paddle.fluid.dygraph.guard():
+            var_a = paddle.ones((2, 2))
+            var_a.stop_gradient = False
+
+            var_b = var_a * 2
+            loss = var_b.tanh()
+            loss.backward()
+            grad_var_a = var_a.grad.numpy()
+
+        self.assertTrue(np.array_equal(inplace_grad_var_a, grad_var_a))
+
+
 if __name__ == '__main__':
     unittest.main()