Commit a270dbb7 authored by fengjiayi

Add support for rnn_op

Parent c61e82bc
@@ -20,6 +20,7 @@
 #include "paddle/framework/block_desc.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/framework/program_desc.h"
 #include "paddle/operators/net_op.h"
 #include "paddle/operators/recurrent_op.h"
@@ -254,7 +255,7 @@ static bool AllGradInSet(const std::vector<std::string>& names,
   return true;
 }

-std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs(
+std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
     const std::unique_ptr<OpDescBind>& op_desc,
     std::unordered_set<std::string>& no_grad_vars) {
   std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
@@ -295,20 +296,35 @@ std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs(
   for (auto& p : pending_fill_zeros_ops) {
     grad_op_descs.push_back(std::move(p));
   }
-  // TODO(fengjiayi): RNN op
   return grad_op_descs;
 }
-void AppendBackwardOpDescs(BlockDescBind& block_desc,
-                           std::unordered_set<std::string>& no_grad_vars) {
+std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
+    ProgramDescBind& program_desc, int block_idx,
+    std::unordered_set<std::string>& no_grad_vars) {
+  BlockDescBind* cur_block = program_desc.Block(block_idx);
+  std::deque<std::unique_ptr<OpDescBind>>& op_descs = cur_block->ops_;
   std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
   size_t grad_desc_idx = 0;
-  std::deque<std::unique_ptr<OpDescBind>>& block_op_descs = block_desc.ops_;
   std::vector<std::unique_ptr<OpDescBind>> backward_descs;
-  for (auto it = block_op_descs.rbegin(); it != block_op_descs.rend(); ++it) {
+  for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
     std::vector<std::unique_ptr<OpDescBind>> op_grads =
-        MakeGradOpDescs(*it, no_grad_vars);
+        MakeOpGrad(*it, no_grad_vars);
+    if ((*it)->Type() == "recurrent") {
+      PADDLE_ENFORCE_EQ(
+          op_grads.size(), size_t(1),
+          "rnn_op's gradient process should contain only one op.");
+      int step_block_idx = (*it)->GetBlockAttr("step_block");
+      auto backward_block_op_descs =
+          MakeBlockBackward(program_desc, step_block_idx, no_grad_vars);
+      BlockDescBind* backward_block = program_desc.AppendBlock(*cur_block);
+      for (auto& ptr : backward_block_op_descs) {
+        backward_block->ops_.push_back(std::move(ptr));
+      }
+      op_grads[0]->SetBlockAttr("step_block", *backward_block);
+    }
     for (const auto& desc : op_grads) {
       for (const std::string& out_name : desc->OutputArgumentNames()) {
         dup_out_ops[out_name].emplace_back(grad_desc_idx);
@@ -345,11 +361,24 @@ void AppendBackwardOpDescs(BlockDescBind& block_desc,
     backward_descs.insert(backward_descs.begin() + p.first + 1,
                           std::move(p.second));
   }
-  // Append backward_descs to BlockDescBind::ops_
-  for (std::unique_ptr<OpDescBind>& ptr : backward_descs) {
-    block_op_descs.push_back(std::move(ptr));
-  }
-  return;
+  return backward_descs;
+}
+
+void AppendBackward(ProgramDescBind& program_desc,
+                    const std::unordered_set<std::string>& no_grad_vars) {
+  std::unordered_set<std::string> no_grad_var_names;
+  no_grad_var_names.reserve(no_grad_vars.size() + 1);
+  no_grad_var_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
+  for (auto& name : no_grad_vars) {
+    no_grad_var_names.insert(GradVarName(name));
+  }
+  const int root_block_idx = 0;
+  auto backward_op_descs =
+      MakeBlockBackward(program_desc, root_block_idx, no_grad_var_names);
+  auto& forw_op_descs = program_desc.Block(root_block_idx)->ops_;
+  for (auto& ptr : backward_op_descs) {
+    forw_op_descs.push_back(std::move(ptr));
+  }
 }
 }  // namespace framework
...
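The core of this change is the recursion in MakeBlockBackward: while walking a block's ops in reverse, a "recurrent" op triggers a recursive call on its step block; the resulting gradient ops go into a freshly appended block of the program, and that block is attached to the single recurrent gradient op via SetBlockAttr("step_block", ...). The sketch below shows just that recursion shape; it is a simplified illustration, not Paddle code, and Op, Block, and Program are hypothetical stand-ins for OpDescBind, BlockDescBind, and ProgramDescBind.

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

struct Op {
  std::string type;
  int sub_block = -1;  // step-block index, used only by "recurrent" ops
};

struct Block {
  std::vector<Op> ops;
};

struct Program {
  std::vector<Block> blocks;
  int AppendBlock() {  // returns the index of the new block
    blocks.emplace_back();
    return static_cast<int>(blocks.size()) - 1;
  }
};

// Walk a block's ops in reverse order; for a "recurrent" op, recurse into
// its step block first and attach the resulting backward block to the
// gradient op (the counterpart of SetBlockAttr("step_block", ...)).
std::vector<Op> MakeBlockBackward(Program& prog, int block_idx) {
  std::vector<Op> backward;
  const std::size_t n = prog.blocks[block_idx].ops.size();
  for (std::size_t i = n; i-- > 0;) {
    Op fwd = prog.blocks[block_idx].ops[i];  // copy: AppendBlock may reallocate
    Op grad{fwd.type + "_grad"};
    if (fwd.type == "recurrent") {
      std::vector<Op> step_backward = MakeBlockBackward(prog, fwd.sub_block);
      int bwd_block = prog.AppendBlock();
      prog.blocks[bwd_block].ops = std::move(step_backward);
      grad.sub_block = bwd_block;
    }
    backward.push_back(std::move(grad));
  }
  return backward;
}

int main() {
  Program prog;
  prog.blocks.resize(2);
  prog.blocks[1].ops = {{"mul"}, {"add"}};          // block 1: the RNN step block
  prog.blocks[0].ops = {{"recurrent", 1}, {"fc"}};  // block 0: the root block
  for (const Op& op : MakeBlockBackward(prog, 0)) {
    std::cout << op.type << " (step block " << op.sub_block << ")\n";
  }
  // prints: fc_grad (step block -1), recurrent_grad (step block 2)
}

Note that the backward block is appended to the program rather than nested inside the gradient op: blocks in a ProgramDesc are addressed by index, so the gradient op only needs to record the new block in an attribute.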
@@ -32,8 +32,13 @@ class ProgramDescBind;
 class BlockDescBind {
  public:
-  friend void AppendBackwardOpDescs(
-      BlockDescBind &block_desc, std::unordered_set<std::string> &no_grad_vars);
+  friend std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
+      ProgramDescBind &program_desc, int block_idx,
+      std::unordered_set<std::string> &no_grad_vars);
+
+  friend void AppendBackward(
+      ProgramDescBind &program_desc,
+      const std::unordered_set<std::string> &no_grad_vars);

   BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
       : prog_(prog), desc_(desc), need_update_(false) {}
...
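Both new free functions mutate BlockDescBind::ops_ directly, which is private, so BlockDescBind has to name them as friends, and each friend declaration must repeat the function's exact signature. A minimal stand-alone illustration of the pattern (BlockSketch and AppendOp are hypothetical names, not Paddle API):

#include <deque>
#include <memory>
#include <string>

class BlockSketch;  // simplified stand-in for BlockDescBind

// Free function that needs access to the block's private op list.
void AppendOp(BlockSketch& block, const std::string& type);

class BlockSketch {
 public:
  // Friendship is granted per exact signature; a mismatched
  // declaration would name a different, non-friend overload.
  friend void AppendOp(BlockSketch& block, const std::string& type);

 private:
  std::deque<std::unique_ptr<std::string>> ops_;  // stands in for the op descs
};

void AppendOp(BlockSketch& block, const std::string& type) {
  block.ops_.push_back(std::make_unique<std::string>(type));  // OK: friend
}

int main() {
  BlockSketch block;
  AppendOp(block, "recurrent");
}

An alternative would be a public accessor for ops_, but friendship keeps the mutable op list out of the public API while still letting the backward builder splice in gradient ops.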