diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index 03985260241689a099ae9ebc136bd04831a44167..68304c9fc8b8fa13cb1f99b82517abc87c71496c 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -38,7 +38,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
     WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
 
 cc_library(backward SRCS backward.cc DEPS net_op)
-cc_test(backward_test SRCS backward_test.cc DEPS backward)
+cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
 
 if(WITH_PYTHON)
 cc_library(paddle_pybind SHARED
diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 83b7e4cdac9bc79ebf687cf199f6d2bc8d1695cf..c226e4e3d2a58d1a647e204c4cd26f4eb6bcd968 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -17,6 +17,7 @@
 #include <list>
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/net_op.h"
+#include "paddle/operators/recurrent_op.h"
 
 namespace paddle {
 namespace framework {
@@ -178,6 +179,22 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
     return false;
   });
 
+  // process recurrent gradient op as a special operator.
+  if (forwardOp.Type() == "recurrent_op") {
+    // NOTE clean up cycle call somewhere (RNN's stepnet contains itself), or
+    // this will result in an infinite loop.
+    const auto& rnnop =
+        *static_cast<const operators::RecurrentOp*>(&forwardOp);
+    auto rnn_grad_op =
+        static_cast<operators::RecurrentGradientOp*>(grad_op.get());
+    const auto& stepnet_op =
+        *static_cast<const OperatorBase*>(&rnnop.stepnet());
+    // create stepnet's gradient op
+    auto grad_stepnet = BackwardRecursive(stepnet_op, no_grad_names, uniq_id);
+    rnn_grad_op->set_stepnet(
+        std::static_pointer_cast<operators::NetOp>(grad_stepnet));
+  }
+
   if (net->ops_.empty()) {  // Current no aux op is added to network
     return grad_op;
   }
diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h
index caca644c96c3f8c741bac4a3b5a6239d2a4555c7..171a0bd2ae80245939a9237f51d195e8bc9178fc 100644
--- a/paddle/operators/recurrent_op.h
+++ b/paddle/operators/recurrent_op.h
@@ -127,7 +127,7 @@ class RecurrentOp final : public framework::OperatorBase {
   }
 
   void set_stepnet(std::shared_ptr<NetOp> net) { stepnet_ = net; }
-  const NetOp* stepnet() const { return stepnet_.get(); }
+  const NetOp& stepnet() const { return *stepnet_; }
 
   static const rnn::ArgumentName kArgName;
 
@@ -158,7 +158,7 @@ class RecurrentGradientOp final : public framework::OperatorBase {
   static const rnn::ArgumentName kArgName;
 
   void set_stepnet(const std::shared_ptr<NetOp>& net) { stepnet_ = net; }
-  const NetOp* stepnet() const { return stepnet_.get(); }
+  const NetOp& stepnet() const { return *stepnet_; }
 
  private:
   RecurrentGradientAlgorithm alg_;
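
Note on the BackwardRecursive change above: the recurrent op is special because its gradient op must itself carry a generated sub-network, so the backward pass recurses into the forward stepnet and installs the resulting gradient net on the gradient RNN op. The standalone sketch below only illustrates that control flow under simplified assumptions; it is not the Paddle API, and OperatorBase, NetOp, RecurrentOp, RecurrentGradientOp, and CreateGradOp here are stand-ins defined inside the sketch itself.

```cpp
// Standalone sketch, not the Paddle API: every type below is a simplified
// stand-in defined here just to show the control flow of the patch.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct OperatorBase {
  virtual ~OperatorBase() = default;
  virtual std::string Type() const = 0;
};

// A NetOp is an ordered list of sub-operators.
struct NetOp : OperatorBase {
  std::string Type() const override { return "net_op"; }
  std::vector<std::shared_ptr<OperatorBase>> ops_;
};

// Forward RNN op: owns the step network that gets unrolled over time.
struct RecurrentOp : OperatorBase {
  std::string Type() const override { return "recurrent_op"; }
  const NetOp& stepnet() const { return *stepnet_; }
  void set_stepnet(std::shared_ptr<NetOp> net) { stepnet_ = net; }
  std::shared_ptr<NetOp> stepnet_;
};

// Gradient RNN op: must carry the gradient of the forward step network.
struct RecurrentGradientOp : OperatorBase {
  std::string Type() const override { return "recurrent_op_grad"; }
  void set_stepnet(std::shared_ptr<NetOp> net) { stepnet_ = net; }
  std::shared_ptr<NetOp> stepnet_;
};

struct FcOp : OperatorBase {
  std::string Type() const override { return "fc"; }
};
struct FcGradOp : OperatorBase {
  std::string Type() const override { return "fc_grad"; }
};

// Toy stand-in for the gradient-op registry.
std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& fwd) {
  if (fwd.Type() == "fc") return std::make_shared<FcGradOp>();
  if (fwd.Type() == "recurrent_op") return std::make_shared<RecurrentGradientOp>();
  return nullptr;
}

// Simplified BackwardRecursive: a net recurses over its sub-ops in reverse,
// and a recurrent op additionally recurses into its stepnet so the gradient
// RNN op ends up holding the backward step network.
std::shared_ptr<OperatorBase> BackwardRecursive(const OperatorBase& fwd) {
  if (fwd.Type() == "net_op") {
    const auto& net = static_cast<const NetOp&>(fwd);
    auto grad_net = std::make_shared<NetOp>();
    for (auto it = net.ops_.rbegin(); it != net.ops_.rend(); ++it) {
      grad_net->ops_.push_back(BackwardRecursive(**it));
    }
    return grad_net;
  }
  auto grad_op = CreateGradOp(fwd);
  if (fwd.Type() == "recurrent_op") {
    const auto& rnn = static_cast<const RecurrentOp&>(fwd);
    auto* rnn_grad = static_cast<RecurrentGradientOp*>(grad_op.get());
    // Mirrors the patch: build the stepnet's backward pass and install it on
    // the gradient RNN op. Recursing only into the stepnet (never back into
    // the recurrent op itself) is what keeps the recursion from looping.
    auto grad_stepnet = BackwardRecursive(rnn.stepnet());
    rnn_grad->set_stepnet(std::static_pointer_cast<NetOp>(grad_stepnet));
  }
  return grad_op;
}

int main() {
  auto step = std::make_shared<NetOp>();
  step->ops_.push_back(std::make_shared<FcOp>());

  RecurrentOp rnn;
  rnn.set_stepnet(step);

  auto grad = BackwardRecursive(rnn);
  auto* rnn_grad = static_cast<RecurrentGradientOp*>(grad.get());
  std::cout << grad->Type() << ": stepnet holds "
            << rnn_grad->stepnet_->ops_[0]->Type() << "\n";
  // Prints: recurrent_op_grad: stepnet holds fc_grad
}
```

The property mirrored from the patch is that the recursion descends only into the stepnet, never back into the recurrent op itself, which is what the NOTE in backward.cc about the stepnet cycle is warning about.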