From 8768ffb7e20cfa31d0bb7ba1f6234391b54c7404 Mon Sep 17 00:00:00 2001
From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com>
Date: Fri, 9 Jul 2021 10:02:51 +0800
Subject: [PATCH] fix double grad hang bug (#34023)

---
 .../fluid/imperative/partial_grad_engine.cc   |  8 +++++
 .../unittests/test_imperative_double_grad.py  | 36 +++++++++++++++++--
 2 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc
index d905b135082..84ba60fef80 100644
--- a/paddle/fluid/imperative/partial_grad_engine.cc
+++ b/paddle/fluid/imperative/partial_grad_engine.cc
@@ -73,6 +73,7 @@ static void GetGraphInfoBetweenTargets(
     std::unordered_map *op_deps_ptr,
     std::unordered_set *related_grad_vars_ptr,
     const std::unordered_set &no_grad_var_grad) {
+  VLOG(10) << "prune graph starts";
   /**
    * Step 1. Find the candidate startup grad ops, prepared for following BFS.
    */
@@ -117,6 +118,8 @@ static void GetGraphInfoBetweenTargets(
     auto *op = op_node_pair.first;
     auto *node = op_node_pair.second;
 
+    VLOG(10) << "Visit node " << node << " , visit op " << op->Type();
+
     for (auto &output_pair : op->GetOutsMap()) {
       if (!output_pair.second.IsGrad()) {
         VLOG(10) << "WARNING: " << op->Type() << " outputs a forward var";
@@ -135,6 +138,7 @@ static void GetGraphInfoBetweenTargets(
 
     for (auto &pending_node : node->GradPendingNodes()) {
       if (visited.count(pending_node.get()) == 0) {
+        visited.insert(pending_node.get());
         for (auto &pending_op : *pending_node) {
           preceding_ops[&pending_op].insert(op);
           q.emplace(&pending_op, pending_node.get());
@@ -143,6 +147,8 @@ static void GetGraphInfoBetweenTargets(
     }
   }
 
+  VLOG(10) << "Found endpoint op ends";
+
   /**
    * Step 3. Based on the found input_target_grads, BFS the graph in reverse
    * order. `target_vars` would record all grad vars in the graph, and
@@ -246,6 +252,8 @@ static void GetGraphInfoBetweenTargets(
     }
   }
 
+  VLOG(10) << "Found startup op ends";
+
   /**
    * Step 4. Prune output_targets which is not the input of startup_ops
    */
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py
index e41960f6b47..cd4ba5b0542 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py
@@ -15,6 +15,7 @@
 import paddle.fluid as fluid
 import paddle
 from paddle.fluid.wrapped_decorator import wrap_decorator
+from paddle.vision.models import resnet50, resnet101
 import unittest
 from unittest import TestCase
 import numpy as np
@@ -228,8 +229,6 @@ class TestDygraphDoubleGrad(TestCase):
             x_grad_expected = (i + 2) * (2.0 / float(numel) * (
                 x_np + dx_expected *
                 (x_np > 0) * 2 / float(numel))).astype('float32')
-            print(x_grad_actual)
-            print(x_grad_expected)
             self.assertTrue(np.allclose(x_grad_actual, x_grad_expected))
 
     @dygraph_guard
@@ -369,5 +368,38 @@ class TestRaiseNoDoubleGradOp(TestCase):
         self.assertRaises(RuntimeError, self.raise_no_grad_op)
 
 
+class TestDoubleGradResNetBase(TestCase):
+    @dygraph_guard
+    def check_resnet(self):
+        data = np.random.rand(1, 3, 224, 224).astype(np.float32)
+        data = paddle.to_tensor(data)
+        data.stop_gradient = False
+        out = self.model(data)
+        preds = paddle.argmax(out, axis=1)
+        label_onehot = paddle.nn.functional.one_hot(
+            paddle.to_tensor(preds), num_classes=out.shape[1])
+        target = paddle.sum(out * label_onehot, axis=1)
+
+        g = paddle.grad(outputs=target, inputs=out)[0]
+        g_numpy = g.numpy()
+        self.assertEqual(list(g_numpy.shape), list(out.shape))
+
+
+class TestDoubleGradResNet50(TestDoubleGradResNetBase):
+    def setUp(self):
+        self.model = resnet50(pretrained=False)
+
+    def test_main(self):
+        self.check_resnet()
+
+
+class TestDoubleGradResNet101(TestDoubleGradResNetBase):
+    def setUp(self):
+        self.model = resnet101(pretrained=False)
+
+    def test_main(self):
+        self.check_resnet()
+
+
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab
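Note on the one-line fix above: in the endpoint-op BFS inside GetGraphInfoBetweenTargets, pending_node was previously checked against visited but never inserted into it when enqueued, so a grad node reachable through several predecessors was pushed once per path. In deep models with shared structure, which the new ResNet50/ResNet101 tests exercise, the number of pushes grows combinatorially and the backward graph pruning appears to hang. The standalone sketch below reproduces that BFS pattern outside Paddle; the Node type, the CountEnqueues helper, and the diamond-shaped demo graph are invented for illustration only.

// Minimal, self-contained sketch (not Paddle code) of the BFS pattern the
// patch fixes. If a node is only checked against `visited` but never
// inserted when it is enqueued, a node reachable through k different paths
// is enqueued once per path, and the work grows exponentially with depth.
#include <cstdio>
#include <queue>
#include <unordered_set>
#include <vector>

struct Node {
  std::vector<Node *> pending;  // successors, analogous to GradPendingNodes()
};

// Counts how many times nodes get pushed starting from `start`.
// `mark_on_enqueue` mirrors the patched code: insert into `visited`
// as soon as the node is pushed, not only when it is first discovered.
static size_t CountEnqueues(Node *start, bool mark_on_enqueue) {
  std::unordered_set<Node *> visited{start};
  std::queue<Node *> q;
  q.push(start);
  size_t enqueues = 1;
  while (!q.empty()) {
    Node *cur = q.front();
    q.pop();
    for (Node *next : cur->pending) {
      if (visited.count(next) == 0) {
        if (mark_on_enqueue) {
          visited.insert(next);  // the one-line fix: dedupe at enqueue time
        }
        q.push(next);
        ++enqueues;
      }
    }
  }
  return enqueues;
}

int main() {
  // Build a small "diamond chain": each layer has two nodes that both point
  // at both nodes of the next layer, loosely similar to how shared grad
  // nodes appear in a deep network with skip connections.
  constexpr int kLayers = 20;
  std::vector<Node> nodes(2 * kLayers);
  for (int i = 0; i + 1 < kLayers; ++i) {
    for (int a = 0; a < 2; ++a) {
      for (int b = 0; b < 2; ++b) {
        nodes[2 * i + a].pending.push_back(&nodes[2 * (i + 1) + b]);
      }
    }
  }
  Node root;
  root.pending = {&nodes[0], &nodes[1]};

  std::printf("without visited.insert at enqueue: %zu enqueues\n",
              CountEnqueues(&root, /*mark_on_enqueue=*/false));
  std::printf("with    visited.insert at enqueue: %zu enqueues\n",
              CountEnqueues(&root, /*mark_on_enqueue=*/true));
  return 0;
}

On this tiny demo graph the unpatched variant still terminates, but its enqueue count roughly doubles with every layer, whereas marking nodes visited at enqueue time (the behavior restored by visited.insert(pending_node.get())) keeps each node in the queue at most once, which is why deep models such as ResNet-101 no longer stall during double-grad graph pruning.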