提交 9eaef753 编写于 作者: Y Yan Chunwei 提交者: GitHub

RNN backward create (#3490)

* insert rnn's backward into Backward()

* add device_context into backward_test
上级 d08550fd
...@@ -38,7 +38,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD ...@@ -38,7 +38,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
cc_library(backward SRCS backward.cc DEPS net_op) cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
if(WITH_PYTHON) if(WITH_PYTHON)
cc_library(paddle_pybind SHARED cc_library(paddle_pybind SHARED
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <list> #include <list>
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h" #include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -178,6 +179,22 @@ std::shared_ptr<OperatorBase> BackwardRecursive( ...@@ -178,6 +179,22 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
return false; return false;
}); });
// process recurrent gradient op as a special operator.
if (forwardOp.Type() == "recurrent_op") {
// NOTE clean up cycle call somewhere (RNN's stepnet constains itself), or
// this will result in infinite loop.
const auto& rnnop =
*static_cast<const operators::RecurrentOp*>(&forwardOp);
auto rnn_grad_op =
static_cast<operators::RecurrentGradientOp*>(grad_op.get());
const auto& stepnet_op =
*static_cast<const OperatorBase*>(&rnnop.stepnet());
// create stepnet's gradient op
auto grad_stepnet = BackwardRecursive(stepnet_op, no_grad_names, uniq_id);
rnn_grad_op->set_stepnet(
std::static_pointer_cast<operators::NetOp>(grad_stepnet));
}
if (net->ops_.empty()) { // Current no aux op is added to network if (net->ops_.empty()) { // Current no aux op is added to network
return grad_op; return grad_op;
} }
......
...@@ -127,7 +127,7 @@ class RecurrentOp final : public framework::OperatorBase { ...@@ -127,7 +127,7 @@ class RecurrentOp final : public framework::OperatorBase {
} }
void set_stepnet(std::shared_ptr<NetOp> net) { stepnet_ = net; } void set_stepnet(std::shared_ptr<NetOp> net) { stepnet_ = net; }
const NetOp* stepnet() const { return stepnet_.get(); } const NetOp& stepnet() const { return *stepnet_; }
static const rnn::ArgumentName kArgName; static const rnn::ArgumentName kArgName;
...@@ -158,7 +158,7 @@ class RecurrentGradientOp final : public framework::OperatorBase { ...@@ -158,7 +158,7 @@ class RecurrentGradientOp final : public framework::OperatorBase {
static const rnn::ArgumentName kArgName; static const rnn::ArgumentName kArgName;
void set_stepnet(const std::shared_ptr<NetOp>& net) { stepnet_ = net; } void set_stepnet(const std::shared_ptr<NetOp>& net) { stepnet_ = net; }
const NetOp* stepnet() const { return stepnet_.get(); } const NetOp& stepnet() const { return *stepnet_; }
private: private:
RecurrentGradientAlgorithm alg_; RecurrentGradientAlgorithm alg_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册