diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc
index 399ad4a3ca52317c7fbaab2542a5d8ccdd4d1330..223d944c83a853dc90b153c47bbb764ed529c83b 100644
--- a/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc
+++ b/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc
@@ -36,6 +36,29 @@ void RecurrentOpEagerDeletionPass::ApplyImpl(Graph *graph) const {
   std::unordered_map<size_t, OpAndGradOpPair> target_ops =
       DeviceIdToRecurrentAndRecurrentGradOp(*graph);
 
+  if (graph->IsConstructedByPartialProgram()) {
+    PADDLE_ENFORCE_LE(target_ops.size(),
+                      1,
+                      platform::errors::InvalidArgument(
+                          "Unsupported multi devices if graph is constructed "
+                          "with partial program."));
+    size_t scope_idx = 0;
+    auto &recur_ops = target_ops[scope_idx].first;
+    auto &recur_grad_ops = target_ops[scope_idx].second;
+
+    auto all_ops = graph->OriginProgram().Block(0).AllOps();
+    if (recur_ops.empty()) {
+      operators::AppendOpVariantByOpName(
+          all_ops, std::string("recurrent"), &recur_ops);
+    } else if (recur_grad_ops.empty()) {
+      operators::AppendOpVariantByOpName(
+          all_ops, std::string("recurrent_grad"), &recur_grad_ops);
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "One of recur_ops or recur_grad_ops should be empty."));
+    }
+  }
+
   for (auto &entry : target_ops) {
     // Prepare safe eager deletion on different devices because the garbage
     // collection may be different across devices
diff --git a/paddle/fluid/operators/controlflow/op_variant.cc b/paddle/fluid/operators/controlflow/op_variant.cc
index 48b7a4341067284a4c48e2c9e89c26e077d49570..8d43a21e66437fcdca9a0f8f9226f6814f493095 100644
--- a/paddle/fluid/operators/controlflow/op_variant.cc
+++ b/paddle/fluid/operators/controlflow/op_variant.cc
@@ -81,5 +81,22 @@ void AppendOpVariantByOpName(const std::vector<framework::OpDesc *> &op_descs,
   }
 }
 
+void AppendOpVariantByOpName(
+    const std::vector<framework::OpDesc *> &op_descs,
+    const std::string &candidate_op_name,
+    std::unordered_set<OpVariant, OpVariant::Hasher> *result_ops) {
+  PADDLE_ENFORCE_NOT_NULL(
+      result_ops,
+      platform::errors::Unavailable("result_ops should not be a null_ptr."));
+  for (auto *op_desc : op_descs) {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_desc,
+        platform::errors::Unavailable("op_desc should not be a null_ptr."));
+    if (op_desc->Type() == candidate_op_name) {
+      result_ops->emplace(op_desc);
+    }
+  }
+}
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/controlflow/op_variant.h b/paddle/fluid/operators/controlflow/op_variant.h
index 738e7a4acc7eb003d5f6647f5ef39b79546506fc..ad7cc6b741eb9ed01cd3bcf07d9da79b412efb25 100644
--- a/paddle/fluid/operators/controlflow/op_variant.h
+++ b/paddle/fluid/operators/controlflow/op_variant.h
@@ -78,5 +78,10 @@ void AppendOpVariantByOpName(const std::vector<framework::OpDesc *> &op_descs,
                              const std::string &candidate_op_name,
                              std::vector<OpVariant> *result_ops);
 
+void AppendOpVariantByOpName(
+    const std::vector<framework::OpDesc *> &op_descs,
+    const std::string &candidate_op_name,
+    std::unordered_set<OpVariant, OpVariant::Hasher> *result_ops);
+
 }  // namespace operators
 }  // namespace paddle
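For context, below is a minimal standalone sketch of the technique the new overload relies on: filtering op descriptions by type name into an std::unordered_set with a custom hasher, so repeated collection passes cannot register the same op twice. MiniOpDesc, MiniOpVariant, and AppendByOpName are illustrative stand-ins, not Paddle's real types; hashing by the wrapped pointer is a simplification of whatever OpVariant::Hasher actually does.

#include <cassert>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Illustrative stand-in for framework::OpDesc.
struct MiniOpDesc {
  std::string type;
  const std::string &Type() const { return type; }
};

// Thin wrapper over a MiniOpDesc pointer, comparable and hashable by the
// underlying pointer, mirroring OpVariant plus OpVariant::Hasher.
class MiniOpVariant {
 public:
  explicit MiniOpVariant(const MiniOpDesc *op) : op_(op) {}
  bool operator==(const MiniOpVariant &other) const { return op_ == other.op_; }

  struct Hasher {
    size_t operator()(const MiniOpVariant &op) const {
      return std::hash<const void *>()(op.op_);
    }
  };

 private:
  const MiniOpDesc *op_;
};

using MiniOpVariantSet =
    std::unordered_set<MiniOpVariant, MiniOpVariant::Hasher>;

// Analogue of the new AppendOpVariantByOpName overload: emplacing into a set
// collapses duplicates, where a std::vector would accumulate them.
void AppendByOpName(const std::vector<MiniOpDesc *> &op_descs,
                    const std::string &candidate_op_name,
                    MiniOpVariantSet *result_ops) {
  assert(result_ops != nullptr);
  for (auto *op_desc : op_descs) {
    assert(op_desc != nullptr);
    if (op_desc->Type() == candidate_op_name) {
      result_ops->emplace(op_desc);
    }
  }
}

int main() {
  MiniOpDesc fwd{"recurrent"}, grad{"recurrent_grad"}, mul{"mul"};
  std::vector<MiniOpDesc *> all_ops{&fwd, &grad, &mul};

  MiniOpVariantSet recur_ops;
  AppendByOpName(all_ops, "recurrent", &recur_ops);
  AppendByOpName(all_ops, "recurrent", &recur_ops);  // second call is a no-op
  std::cout << recur_ops.size() << "\n";             // prints 1
  return 0;
}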
diff --git a/python/paddle/fluid/tests/unittests/test_eager_deletion_recurrent_op.py b/python/paddle/fluid/tests/unittests/test_eager_deletion_recurrent_op.py
index cea7da9f72c87baa92978d194a63000c5b200365..018ab42c50a55602da8c948f9758b4889f22a09b 100644
--- a/python/paddle/fluid/tests/unittests/test_eager_deletion_recurrent_op.py
+++ b/python/paddle/fluid/tests/unittests/test_eager_deletion_recurrent_op.py
@@ -690,5 +690,45 @@ class EagerDeletionFarwardOnlyRnnAndBackwardRnnTest(
         np.testing.assert_allclose(pd_output, py_output, rtol=0.01)
 
 
+class RecurrentNet(paddle.nn.Layer):
+
+    def __init__(self):
+        super(RecurrentNet, self).__init__()
+        self.cell = paddle.nn.SimpleRNNCell(16, 32)
+        self.rnn = paddle.nn.RNN(self.cell)
+
+    def forward(self, inputs, prev_h):
+        outputs, final_states = self.rnn(inputs, prev_h)
+        return outputs, final_states
+
+
+class TestDy2StRecurrentOpBackward(unittest.TestCase):
+
+    def setUp(self):
+        paddle.disable_static()
+        paddle.seed(100)
+
+    def tearDown(self):
+        paddle.enable_static()
+
+    def test_recurrent_backward(self):
+        net = RecurrentNet()
+        inputs = paddle.rand((4, 23, 16))
+        inputs.stop_gradient = False
+        prev_h = paddle.randn((4, 32))
+        prev_h.stop_gradient = False
+
+        outputs, final_states = net(inputs, prev_h)
+        outputs.backward()
+        dy_grad = inputs.gradient()
+        inputs.clear_gradient()
+
+        net = paddle.jit.to_static(net)
+        outputs, final_states = net(inputs, prev_h)
+        outputs.backward()
+        st_grad = inputs.gradient()
+        np.testing.assert_allclose(dy_grad, st_grad)
+
+
 if __name__ == '__main__':
     unittest.main()
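The new case covers exactly the path patched above: paddle.jit.to_static traces RecurrentNet into a partial program, so its graph reaches RecurrentOpEagerDeletionPass with IsConstructedByPartialProgram() true, and the static-graph input gradient is checked against the eager (dygraph) run. Assuming a local Paddle build, it can be run in isolation via unittest's name filtering:

python python/paddle/fluid/tests/unittests/test_eager_deletion_recurrent_op.py TestDy2StRecurrentOpBackward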