Commit 9eaef753 authored by Yan Chunwei, committed by GitHub

RNN backward create (#3490)

* Insert the RNN's backward op into Backward()

* Add device_context to backward_test
Parent d08550fd
@@ -38,7 +38,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
cc_library(backward SRCS backward.cc DEPS net_op)
-cc_test(backward_test SRCS backward_test.cc DEPS backward)
+cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
if(WITH_PYTHON)
cc_library(paddle_pybind SHARED
@@ -17,6 +17,7 @@
#include <list>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
namespace paddle {
namespace framework {
@@ -178,6 +179,22 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
return false;
});
+  // process the recurrent gradient op as a special operator.
+  if (forwardOp.Type() == "recurrent_op") {
+    // NOTE: clean up the cyclic reference somewhere (an RNN's stepnet
+    // contains the op itself); otherwise this will result in an infinite loop.
+    const auto& rnnop =
+        *static_cast<const operators::RecurrentOp*>(&forwardOp);
+    auto rnn_grad_op =
+        static_cast<operators::RecurrentGradientOp*>(grad_op.get());
+    const auto& stepnet_op =
+        *static_cast<const OperatorBase*>(&rnnop.stepnet());
+    // create the stepnet's gradient op
+    auto grad_stepnet = BackwardRecursive(stepnet_op, no_grad_names, uniq_id);
+    rnn_grad_op->set_stepnet(
+        std::static_pointer_cast<operators::NetOp>(grad_stepnet));
+  }
  if (net->ops_.empty()) {  // Currently no aux op is added to the network
    return grad_op;
  }
@@ -127,7 +127,7 @@ class RecurrentOp final : public framework::OperatorBase {
}
void set_stepnet(std::shared_ptr<NetOp> net) { stepnet_ = net; }
-  const NetOp* stepnet() const { return stepnet_.get(); }
+  const NetOp& stepnet() const { return *stepnet_; }
static const rnn::ArgumentName kArgName;
@@ -158,7 +158,7 @@ class RecurrentGradientOp final : public framework::OperatorBase {
static const rnn::ArgumentName kArgName;
void set_stepnet(const std::shared_ptr<NetOp>& net) { stepnet_ = net; }
-  const NetOp* stepnet() const { return stepnet_.get(); }
+  const NetOp& stepnet() const { return *stepnet_; }
private:
RecurrentGradientAlgorithm alg_;
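The accessor change in both classes, from returning a raw `const NetOp*` to a `const NetOp&`, moves the non-null invariant into the type: callers such as the new code in `BackwardRecursive` can take the address of the stepnet without a null check. A small sketch of the pattern, with a hypothetical `Holder` class standing in for the two ops:

```cpp
#include <cassert>
#include <memory>

struct NetOp { /* placeholder for the real net op */ };

// Hypothetical holder mirroring the accessor change: returning a reference
// instead of a raw pointer documents that stepnet_ must already be set.
class Holder {
 public:
  void set_stepnet(std::shared_ptr<NetOp> net) { stepnet_ = std::move(net); }
  const NetOp& stepnet() const {
    assert(stepnet_ != nullptr && "stepnet must be set before use");
    return *stepnet_;
  }

 private:
  std::shared_ptr<NetOp> stepnet_;
};

int main() {
  Holder h;
  h.set_stepnet(std::make_shared<NetOp>());
  const NetOp& net = h.stepnet();  // no null check needed at the call site
  (void)net;
}
```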