diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc
index c9081ddd8a969a0a0c00061b3c5855c09e09eac8..0b45c189dd714adedc1fb1600e2b350c3dedb62b 100644
--- a/paddle/fluid/imperative/partial_grad_engine.cc
+++ b/paddle/fluid/imperative/partial_grad_engine.cc
@@ -36,6 +36,15 @@ namespace paddle {
 namespace imperative {
 
+struct HashPair {
+  template <class T1, class T2>
+  size_t operator()(const std::pair<T1, T2> &p) const noexcept {
+    auto hash1 = std::hash<T1>{}(p.first);
+    auto hash2 = std::hash<T2>{}(p.second);
+    return hash1 ^ hash2;
+  }
+};
+
 /**
  * This function prunes the graph to get the ops between `output_targets`
  * and `input_target_grads`.
 *
@@ -152,8 +161,10 @@ static void GetGraphInfoBetweenTargets(
   target_vars = *input_target_grads;
 
   std::queue<std::pair<OpBase *, OpBase *>> op_queue;
+  std::unordered_set<std::pair<OpBase *, OpBase *>, HashPair> op_base_visited;
   for (auto &endpoint_op : endpoint_ops) {
     op_queue.emplace(endpoint_op, nullptr);
+    op_base_visited.emplace(endpoint_op, nullptr);
   }
 
   while (!op_queue.empty()) {
@@ -207,6 +218,7 @@ static void GetGraphInfoBetweenTargets(
     if (pending_op) {
       VLOG(10) << "Pending op of " << op->Type() << " is "
                << pending_op->Type();
+
       pending_ops[op].insert(pending_op);
       ++op_deps[pending_op];
     } else {
@@ -216,7 +228,10 @@ static void GetGraphInfoBetweenTargets(
     auto iter = preceding_ops.find(op);
     if (iter != preceding_ops.end()) {
       for (auto &preceding_op : iter->second) {
-        op_queue.emplace(preceding_op, op);
+        if (op_base_visited.count(std::make_pair(preceding_op, op)) == 0) {
+          op_queue.emplace(preceding_op, op);
+          op_base_visited.emplace(preceding_op, op);
+        }
       }
     }
   }
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py
index 5f1d020bff89c12f81f413d2eb803771ed782598..d7d0af009034bcd26cdb3e2f7dedb81a90233b71 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import paddle.fluid as fluid
+import paddle
 from paddle.fluid.wrapped_decorator import wrap_decorator
 import unittest
 from unittest import TestCase
@@ -295,5 +296,48 @@ class TestDygraphDoubleGradSortGradient(TestDygraphDoubleGrad):
         self.shape = [5, 10]
 
 
+class TestDygraphDoubleGradVisitedUniq(TestCase):
+    def test_compare(self):
+        value = np.random.uniform(-0.5, 0.5, 100).reshape(10, 2,
+                                                           5).astype("float32")
+
+        def model_f(input):
+            linear = fluid.dygraph.Linear(5, 3, bias_attr=False)
+            for i in range(10):
+                if i == 0:
+                    out = linear(input)
+                else:
+                    out = out + linear(input)
+            return out
+
+        backward_strategy = fluid.dygraph.BackwardStrategy()
+        backward_strategy.sort_sum_gradient = True
+        with fluid.dygraph.guard():
+            fluid.default_startup_program().random_seed = 123
+            fluid.default_main_program().random_seed = 123
+            a = fluid.dygraph.to_variable(value)
+            a.stop_gradient = False
+
+            out = model_f(a)
+
+            dx=fluid.dygraph.grad(outputs=[out],inputs=[a],create_graph=False,retain_graph=False, \
+                                  only_inputs=True,allow_unused=False, backward_strategy=backward_strategy)
+
+            grad_1 = dx[0].numpy()
+
+        with fluid.dygraph.guard():
+            fluid.default_startup_program().random_seed = 123
+            fluid.default_main_program().random_seed = 123
+            a = fluid.dygraph.to_variable(value)
+            a.stop_gradient = False
+
+            out = model_f(a)
+            out.backward(backward_strategy)
+
+            grad_2 = a.gradient()
+
+        self.assertTrue(np.array_equal(grad_1, grad_2))
+
+
 if __name__ == '__main__':
     unittest.main()
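
Note (not part of the patch): the C++ change above guards the breadth-first walk over the op graph with a visited set keyed on (op, pending_op) pairs, so the same edge is never enqueued twice and op_deps is not over-counted when an op is reachable from the endpoint along several paths (which is exactly what the repeated `out = out + linear(input)` in the new test produces). The following is a minimal, standalone sketch of that pattern, not Paddle code: it uses plain ints in place of OpBase pointers, and the toy graph, the Edge alias, and kNoPendingOp are invented for illustration; only the XOR-combined pair hash mirrors the HashPair added in the diff.

#include <cstddef>
#include <cstdio>
#include <functional>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Same shape as the HashPair in the diff: combine the element hashes with
// XOR so std::pair can be used as the key of an unordered_set.
struct HashPair {
  template <class T1, class T2>
  std::size_t operator()(const std::pair<T1, T2> &p) const noexcept {
    auto hash1 = std::hash<T1>{}(p.first);
    auto hash2 = std::hash<T2>{}(p.second);
    return hash1 ^ hash2;
  }
};

int main() {
  // Hypothetical backward graph: preceding_ops[v] lists the ops that must
  // run before op v. Op 1 is reached from endpoint 4 along two paths
  // (4 -> 2 -> 1 and 4 -> 3 -> 1), so a naive BFS would enqueue the edge
  // (0, 1) twice and increment op_deps[1] twice.
  std::unordered_map<int, std::vector<int>> preceding_ops = {
      {4, {2, 3}}, {2, {1}}, {3, {1}}, {1, {0}}};

  using Edge = std::pair<int, int>;  // (op, pending op)
  std::queue<Edge> op_queue;
  std::unordered_set<Edge, HashPair> visited;
  std::unordered_map<int, int> op_deps;

  const int kNoPendingOp = -1;  // stands in for nullptr in the real engine
  op_queue.emplace(4, kNoPendingOp);
  visited.emplace(4, kNoPendingOp);

  while (!op_queue.empty()) {
    Edge cur = op_queue.front();
    op_queue.pop();
    int op = cur.first;
    int pending_op = cur.second;

    if (pending_op != kNoPendingOp) {
      // op_deps[pending_op] counts how many ops must finish before
      // pending_op can run; duplicate edges would over-count it.
      ++op_deps[pending_op];
    }

    auto it = preceding_ops.find(op);
    if (it == preceding_ops.end()) continue;
    for (int prev : it->second) {
      // The guard added by the diff: only enqueue each edge once.
      if (visited.count(std::make_pair(prev, op)) == 0) {
        op_queue.emplace(prev, op);
        visited.emplace(prev, op);
      }
    }
  }

  // Expected: op 1 waits on exactly 1 preceding op (op 0), not 2.
  for (const auto &kv : op_deps) {
    std::printf("op %d waits on %d preceding op(s)\n", kv.first, kv.second);
  }
  return 0;
}

A side note on the hash itself: XOR-combining is the simplest choice and is what the patch uses; pairs like (a, b) and (b, a), or pairs with equal elements, land in the same bucket, but that only affects bucket distribution, not correctness, since the set still compares full pairs for equality.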