From 9413219024910b72c3a37f9bf93696f6fe5ada1d Mon Sep 17 00:00:00 2001
From: WangZhen <23097963+0x45f@users.noreply.github.com>
Date: Wed, 19 Oct 2022 11:28:33 +0800
Subject: [PATCH] [Dy2St]Fix recurrent op eager deletion pass error in dy2st
 (#47105)

* Fix recurrent op eager deletion pass error in dy2st

* Polish code

* Refine error message
---
 .../recurrent_op_eager_deletion_pass.cc       | 23 +++++++++++
 .../fluid/operators/controlflow/op_variant.cc | 17 ++++++++
 .../fluid/operators/controlflow/op_variant.h  |  5 +++
 .../test_eager_deletion_recurrent_op.py       | 40 +++++++++++++++++++
 4 files changed, 85 insertions(+)

diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc
index 399ad4a3ca5..223d944c83a 100644
--- a/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc
+++ b/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc
@@ -36,6 +36,29 @@ void RecurrentOpEagerDeletionPass::ApplyImpl(Graph *graph) const {
   std::unordered_map<size_t, OpAndGradOpPair> target_ops =
       DeviceIdToRecurrentAndRecurrentGradOp(*graph);
 
+  if (graph->IsConstructedByPartialProgram()) {
+    PADDLE_ENFORCE_LE(target_ops.size(),
+                      1,
+                      platform::errors::InvalidArgument(
+                          "Unsupported multi devices if graph is constructed "
+                          "with partial program."));
+    size_t scope_idx = 0;
+    auto &recur_ops = target_ops[scope_idx].first;
+    auto &recur_grad_ops = target_ops[scope_idx].second;
+
+    auto all_ops = graph->OriginProgram().Block(0).AllOps();
+    if (recur_ops.empty()) {
+      operators::AppendOpVariantByOpName(
+          all_ops, std::string("recurrent"), &recur_ops);
+    } else if (recur_grad_ops.empty()) {
+      operators::AppendOpVariantByOpName(
+          all_ops, std::string("recurrent_grad"), &recur_grad_ops);
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "One of recur_ops or recur_grad_ops should be empty."));
+    }
+  }
+
   for (auto &entry : target_ops) {
     // Prepare safe eager deletion on different devices because the garbage
     // collection may be different across devices
diff --git a/paddle/fluid/operators/controlflow/op_variant.cc b/paddle/fluid/operators/controlflow/op_variant.cc
index 48b7a434106..8d43a21e664 100644
--- a/paddle/fluid/operators/controlflow/op_variant.cc
+++ b/paddle/fluid/operators/controlflow/op_variant.cc
@@ -81,5 +81,22 @@ void AppendOpVariantByOpName(const std::vector<framework::OpDesc *> &op_descs,
   }
 }
 
+void AppendOpVariantByOpName(
+    const std::vector<framework::OpDesc *> &op_descs,
+    const std::string &candidate_op_name,
+    std::unordered_set<OpVariant, OpVariant::Hasher> *result_ops) {
+  PADDLE_ENFORCE_NOT_NULL(
+      result_ops,
+      platform::errors::Unavailable("result_ops should not be a null_ptr."));
+  for (auto *op_desc : op_descs) {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_desc,
+        platform::errors::Unavailable("op_desc should not be a null_ptr."));
+    if (op_desc->Type() == candidate_op_name) {
+      result_ops->emplace(op_desc);
+    }
+  }
+}
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/controlflow/op_variant.h b/paddle/fluid/operators/controlflow/op_variant.h
index 738e7a4acc7..ad7cc6b741e 100644
--- a/paddle/fluid/operators/controlflow/op_variant.h
+++ b/paddle/fluid/operators/controlflow/op_variant.h
@@ -78,5 +78,10 @@ void AppendOpVariantByOpName(const std::vector<framework::OpDesc *> &op_descs,
                              const std::string &candidate_op_name,
                              std::vector<OpVariant> *result_ops);
 
+void AppendOpVariantByOpName(
+    const std::vector<framework::OpDesc *> &op_descs,
+    const std::string &candidate_op_name,
+    std::unordered_set<OpVariant, OpVariant::Hasher> *result_ops);
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_eager_deletion_recurrent_op.py b/python/paddle/fluid/tests/unittests/test_eager_deletion_recurrent_op.py
index cea7da9f72c..018ab42c50a 100644
--- a/python/paddle/fluid/tests/unittests/test_eager_deletion_recurrent_op.py
+++ b/python/paddle/fluid/tests/unittests/test_eager_deletion_recurrent_op.py
@@ -690,5 +690,45 @@ class EagerDeletionFarwardOnlyRnnAndBackwardRnnTest(
         np.testing.assert_allclose(pd_output, py_output, rtol=0.01)
 
 
+class RecurrentNet(paddle.nn.Layer):
+
+    def __init__(self):
+        super(RecurrentNet, self).__init__()
+        self.cell = paddle.nn.SimpleRNNCell(16, 32)
+        self.rnn = paddle.nn.RNN(self.cell)
+
+    def forward(self, inputs, prev_h):
+        outputs, final_states = self.rnn(inputs, prev_h)
+        return outputs, final_states
+
+
+class TestDy2StRecurrentOpBackward(unittest.TestCase):
+
+    def setUp(self):
+        paddle.disable_static()
+        paddle.seed(100)
+
+    def tearDown(self):
+        paddle.enable_static()
+
+    def test_recurrent_backward(self):
+        net = RecurrentNet()
+        inputs = paddle.rand((4, 23, 16))
+        inputs.stop_gradient = False
+        prev_h = paddle.randn((4, 32))
+        prev_h.stop_gradient = False
+
+        outputs, final_states = net(inputs, prev_h)
+        outputs.backward()
+        dy_grad = inputs.gradient()
+        inputs.clear_gradient()
+
+        net = paddle.jit.to_static(net)
+        outputs, final_states = net(inputs, prev_h)
+        outputs.backward()
+        st_grad = inputs.gradient()
+        np.testing.assert_allclose(dy_grad, st_grad)
+
+
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab