Unverified commit 8768ffb7, authored by Zeng Jinle, committed by GitHub

fix double grad hang bug (#34023)

Parent 3508bd28
@@ -73,6 +73,7 @@ static void GetGraphInfoBetweenTargets(
    std::unordered_map<OpBase *, size_t> *op_deps_ptr,
    std::unordered_set<VariableWrapper *> *related_grad_vars_ptr,
    const std::unordered_set<VariableWrapper *> &no_grad_var_grad) {
  VLOG(10) << "prune graph starts";
  /**
   * Step 1. Find the candidate startup grad ops, prepared for following BFS.
   */
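For orientation, a minimal Python sketch of what Step 1 amounts to (hypothetical names and dict-based ops, not the engine's real data structures): the grad ops that touch one of the requested target grad vars form the starting frontier for the BFS that follows.

def find_startup_candidates(grad_ops, target_grad_vars):
    # An op is a startup candidate if it touches a requested target grad var.
    return [op for op in grad_ops if op["vars"] & target_grad_vars]

ops = [{"name": "relu_grad", "vars": {"x@GRAD"}},
       {"name": "matmul_grad", "vars": {"y@GRAD"}}]
print(find_startup_candidates(ops, {"x@GRAD"}))  # -> [the relu_grad op]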
@@ -117,6 +118,8 @@ static void GetGraphInfoBetweenTargets(
    auto *op = op_node_pair.first;
    auto *node = op_node_pair.second;

    VLOG(10) << "Visit node " << node << " , visit op " << op->Type();

    for (auto &output_pair : op->GetOutsMap()) {
      if (!output_pair.second.IsGrad()) {
        VLOG(10) << "WARNING: " << op->Type() << " outputs a forward var";
@@ -135,6 +138,7 @@ static void GetGraphInfoBetweenTargets(

    for (auto &pending_node : node->GradPendingNodes()) {
      if (visited.count(pending_node.get()) == 0) {
        visited.insert(pending_node.get());
        for (auto &pending_op : *pending_node) {
          preceding_ops[&pending_op].insert(op);
          q.emplace(&pending_op, pending_node.get());
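The loop above marks a pending node as visited at the moment it is enqueued, so each node enters the queue at most once even when many ops share the same pending node. A standalone Python sketch of that discipline (plain dicts rather than the engine's node types):

from collections import deque

def bfs(start, pending):
    visited = {start}
    q = deque([start])
    order = []
    while q:
        node = q.popleft()
        order.append(node)
        for nxt in pending.get(node, []):
            if nxt not in visited:
                visited.add(nxt)  # mark before pushing, as the loop above does
                q.append(nxt)
    return order

# 'd' is reachable from both 'b' and 'c', yet is enqueued exactly once.
print(bfs("a", {"a": ["b", "c"], "b": ["d"], "c": ["d"]}))  # a b c d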
@@ -143,6 +147,8 @@ static void GetGraphInfoBetweenTargets(
    }
  }

  VLOG(10) << "Found endpoint op ends";

  /**
   * Step 3. Based on the found input_target_grads, BFS the graph in reverse
   * order. `target_vars` would record all grad vars in the graph, and
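Step 3's reverse traversal can be pictured with the preceding_ops map built in the forward pass above; a hedged Python sketch (hypothetical structures, not the engine's code):

from collections import deque

def collect_target_vars(start_ops, preceding_ops, outputs_of):
    # Walk backwards from the ops producing input_target_grads and record
    # every grad var produced along the way.
    visited = set(start_ops)
    q = deque(start_ops)
    target_vars = set()
    while q:
        op = q.popleft()
        target_vars |= outputs_of[op]
        for prev in preceding_ops.get(op, ()):
            if prev not in visited:
                visited.add(prev)
                q.append(prev)
    return target_vars

print(collect_target_vars(["op2"], {"op2": {"op1"}},
                          {"op1": {"x@GRAD"}, "op2": {"y@GRAD"}}))
# -> {'x@GRAD', 'y@GRAD'}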
@@ -246,6 +252,8 @@ static void GetGraphInfoBetweenTargets(
    }
  }

  VLOG(10) << "Found startup op ends";

  /**
   * Step 4. Prune output_targets that are not inputs of startup_ops
   */
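Read literally, the Step 4 comment describes a set filter; a small Python sketch (hypothetical names):

def prune_output_targets(output_targets, startup_ops):
    # An output target that no startup op consumes cannot influence the
    # requested gradients, so it is dropped.
    consumed = set()
    for op in startup_ops:
        consumed |= op["inputs"]
    return [v for v in output_targets if v in consumed]

print(prune_output_targets(["x@GRAD", "w@GRAD"],
                           [{"inputs": {"x@GRAD"}}]))  # -> ['x@GRAD']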
@@ -15,6 +15,7 @@
import paddle.fluid as fluid
import paddle
from paddle.fluid.wrapped_decorator import wrap_decorator
from paddle.vision.models import resnet50, resnet101
import unittest
from unittest import TestCase
import numpy as np
@@ -228,8 +229,6 @@ class TestDygraphDoubleGrad(TestCase):
            x_grad_expected = (i + 2) * (2.0 / float(numel) * (
                x_np + dx_expected *
                (x_np > 0) * 2 / float(numel))).astype('float32')
            print(x_grad_actual)
            print(x_grad_expected)
            self.assertTrue(np.allclose(x_grad_actual, x_grad_expected))

    @dygraph_guard
@@ -369,5 +368,38 @@
        self.assertRaises(RuntimeError, self.raise_no_grad_op)


class TestDoubleGradResNetBase(TestCase):
    @dygraph_guard
    def check_resnet(self):
        data = np.random.rand(1, 3, 224, 224).astype(np.float32)
        data = paddle.to_tensor(data)
        data.stop_gradient = False
        out = self.model(data)
        preds = paddle.argmax(out, axis=1)
        label_onehot = paddle.nn.functional.one_hot(
            paddle.to_tensor(preds), num_classes=out.shape[1])
        target = paddle.sum(out * label_onehot, axis=1)

        g = paddle.grad(outputs=target, inputs=out)[0]
        g_numpy = g.numpy()
        self.assertEqual(list(g_numpy.shape), list(out.shape))


class TestDoubleGradResNet50(TestDoubleGradResNetBase):
    def setUp(self):
        self.model = resnet50(pretrained=False)

    def test_main(self):
        self.check_resnet()


class TestDoubleGradResNet101(TestDoubleGradResNetBase):
    def setUp(self):
        self.model = resnet101(pretrained=False)

    def test_main(self):
        self.check_resnet()

if __name__ == '__main__':
    unittest.main()
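For context, the new tests exercise the pattern below. Since target = sum(out * label_onehot), and argmax/one_hot contribute no gradient path, the gradient of target with respect to out is exactly the one-hot mask, which is why checking the gradient's shape against out's shape is a meaningful sanity check. A stripped-down version with a toy tensor in place of ResNet (same paddle APIs as the test; not part of the commit):

import numpy as np
import paddle

x = paddle.to_tensor(np.random.rand(2, 5).astype(np.float32))
x.stop_gradient = False
out = x * 2.0  # toy stand-in for model(x)
preds = paddle.argmax(out, axis=1)
label_onehot = paddle.nn.functional.one_hot(preds, num_classes=out.shape[1])
target = paddle.sum(out * label_onehot, axis=1)
g = paddle.grad(outputs=target, inputs=out)[0]
assert list(g.shape) == list(out.shape)
print(g.numpy())  # each row is a one-hot mask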