Unverified commit 8768ffb7, authored by Zeng Jinle, committed by GitHub

fix double grad hang bug (#34023)

Parent 3508bd28
@@ -73,6 +73,7 @@ static void GetGraphInfoBetweenTargets(
     std::unordered_map<OpBase *, size_t> *op_deps_ptr,
     std::unordered_set<VariableWrapper *> *related_grad_vars_ptr,
     const std::unordered_set<VariableWrapper *> &no_grad_var_grad) {
+  VLOG(10) << "prune graph starts";
   /**
    * Step 1. Find the candidate startup grad ops, prepared for following BFS.
    */
@@ -117,6 +118,8 @@ static void GetGraphInfoBetweenTargets(
     auto *op = op_node_pair.first;
     auto *node = op_node_pair.second;
+    VLOG(10) << "Visit node " << node << " , visit op " << op->Type();
+
     for (auto &output_pair : op->GetOutsMap()) {
       if (!output_pair.second.IsGrad()) {
         VLOG(10) << "WARNING: " << op->Type() << " outputs a forward var";
@@ -135,6 +138,7 @@ static void GetGraphInfoBetweenTargets(
     for (auto &pending_node : node->GradPendingNodes()) {
       if (visited.count(pending_node.get()) == 0) {
+        visited.insert(pending_node.get());
         for (auto &pending_op : *pending_node) {
           preceding_ops[&pending_op].insert(op);
           q.emplace(&pending_op, pending_node.get());
@@ -143,6 +147,8 @@ static void GetGraphInfoBetweenTargets(
     }
   }
+  VLOG(10) << "Found endpoint op ends";
+
   /**
    * Step 3. Based on the found input_target_grads, BFS the graph in reverse
    * order. `target_vars` would record all grad vars in the graph, and
@@ -246,6 +252,8 @@ static void GetGraphInfoBetweenTargets(
     }
   }
+  VLOG(10) << "Found startup op ends";
+
   /**
    * Step 4. Prune output_targets which is not the input of startup_ops
    */
......
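The hang fix itself is the one-line `visited.insert(pending_node.get());` in the `@@ -135,6 +138,7 @@` hunk: the BFS tested `visited.count(pending_node.get()) == 0` but never actually marked the pending node as visited, so any node reachable along several paths was re-enqueued once per path. On a graph with many converging branches the queue grows explosively and the traversal never finishes. The sketch below shows the corrected mark-on-enqueue pattern in isolation; it is an illustration only, not Paddle's code, and `pending_of` is a hypothetical adjacency callback.

```python
from collections import deque

def bfs(start_nodes, pending_of):
    """Mark-on-enqueue BFS: a node joins `visited` the moment it is queued,
    so it is processed exactly once. Removing the `visited.add` line below
    reproduces the bug fixed above: every converging path re-enqueues the
    same subgraph, and on deep graphs the loop effectively never ends.
    """
    visited = set(start_nodes)
    q = deque(start_nodes)
    order = []
    while q:
        node = q.popleft()
        order.append(node)
        for nxt in pending_of(node):
            if nxt not in visited:
                visited.add(nxt)  # counterpart of the added visited.insert(...)
                q.append(nxt)
    return order
```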
@@ -15,6 +15,7 @@
 import paddle.fluid as fluid
 import paddle
 from paddle.fluid.wrapped_decorator import wrap_decorator
+from paddle.vision.models import resnet50, resnet101
 import unittest
 from unittest import TestCase
 import numpy as np
@@ -228,8 +229,6 @@ class TestDygraphDoubleGrad(TestCase):
             x_grad_expected = (i + 2) * (2.0 / float(numel) * (
                 x_np + dx_expected *
                 (x_np > 0) * 2 / float(numel))).astype('float32')
-            print(x_grad_actual)
-            print(x_grad_expected)
             self.assertTrue(np.allclose(x_grad_actual, x_grad_expected))

     @dygraph_guard
@@ -369,5 +368,38 @@ class TestRaiseNoDoubleGradOp(TestCase):
         self.assertRaises(RuntimeError, self.raise_no_grad_op)


+class TestDoubleGradResNetBase(TestCase):
+    @dygraph_guard
+    def check_resnet(self):
+        data = np.random.rand(1, 3, 224, 224).astype(np.float32)
+        data = paddle.to_tensor(data)
+        data.stop_gradient = False
+        out = self.model(data)
+        preds = paddle.argmax(out, axis=1)
+        label_onehot = paddle.nn.functional.one_hot(
+            paddle.to_tensor(preds), num_classes=out.shape[1])
+        target = paddle.sum(out * label_onehot, axis=1)
+
+        g = paddle.grad(outputs=target, inputs=out)[0]
+        g_numpy = g.numpy()
+        self.assertEqual(list(g_numpy.shape), list(out.shape))
+
+
+class TestDoubleGradResNet50(TestDoubleGradResNetBase):
+    def setUp(self):
+        self.model = resnet50(pretrained=False)
+
+    def test_main(self):
+        self.check_resnet()
+
+
+class TestDoubleGradResNet101(TestDoubleGradResNetBase):
+    def setUp(self):
+        self.model = resnet101(pretrained=False)
+
+    def test_main(self):
+        self.check_resnet()
+
+
 if __name__ == '__main__':
     unittest.main()
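The new `TestDoubleGradResNet*` cases serve as regression tests: with `data.stop_gradient = False`, a full ResNet forward pass is recorded, and the subsequent `paddle.grad` call has to prune exactly the kind of many-branched graph on which the unbounded BFS used to hang. For reference, here is a minimal double-grad example with the same API (a sketch, not part of the commit): `create_graph=True` keeps the first gradient differentiable, so a second `paddle.grad` call can be taken through it.

```python
import numpy as np
import paddle

x = paddle.to_tensor(np.random.rand(4).astype('float32'))
x.stop_gradient = False  # record the graph through x
y = paddle.sum(x * x)    # y = sum(x^2)

# First-order gradient; create_graph=True makes dx itself differentiable.
(dx,) = paddle.grad(outputs=y, inputs=x, create_graph=True)  # dx = 2x

# Second-order gradient through dx: d(sum(2x))/dx = 2 elementwise.
(ddx,) = paddle.grad(outputs=paddle.sum(dx), inputs=x)
assert np.allclose(ddx.numpy(), 2.0)
```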