Commit a270dbb7 authored by F fengjiayi

Add support for rnn_op

Parent c61e82bc
paddle/framework/backward.cc
@@ -20,6 +20,7 @@
 #include "paddle/framework/block_desc.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/framework/program_desc.h"
 #include "paddle/operators/net_op.h"
 #include "paddle/operators/recurrent_op.h"
@@ -254,7 +255,7 @@ static bool AllGradInSet(const std::vector<std::string>& names,
   return true;
 }
 
-std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs(
+std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
     const std::unique_ptr<OpDescBind>& op_desc,
     std::unordered_set<std::string>& no_grad_vars) {
   std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
@@ -295,20 +296,35 @@ std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs(
   for (auto& p : pending_fill_zeros_ops) {
     grad_op_descs.push_back(std::move(p));
   }
-
-  // TODO(fengjiayi): RNN op
   return grad_op_descs;
 }
 
-void AppendBackwardOpDescs(BlockDescBind& block_desc,
-                           std::unordered_set<std::string>& no_grad_vars) {
+std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
+    ProgramDescBind& program_desc, int block_idx,
+    std::unordered_set<std::string>& no_grad_vars) {
+  BlockDescBind* cur_block = program_desc.Block(block_idx);
+  std::deque<std::unique_ptr<OpDescBind>>& op_descs = cur_block->ops_;
   std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
   size_t grad_desc_idx = 0;
-  std::deque<std::unique_ptr<OpDescBind>>& block_op_descs = block_desc.ops_;
   std::vector<std::unique_ptr<OpDescBind>> backward_descs;
-  for (auto it = block_op_descs.rbegin(); it != block_op_descs.rend(); ++it) {
+  for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
     std::vector<std::unique_ptr<OpDescBind>> op_grads =
-        MakeGradOpDescs(*it, no_grad_vars);
+        MakeOpGrad(*it, no_grad_vars);
+    if ((*it)->Type() == "recurrent") {
+      PADDLE_ENFORCE_EQ(
+          op_grads.size(), size_t(1),
+          "rnn_op's gradient process should contain only one op.");
+      int step_block_idx = (*it)->GetBlockAttr("step_block");
+      auto backward_block_op_descs =
+          MakeBlockBackward(program_desc, step_block_idx, no_grad_vars);
+      BlockDescBind* backward_block = program_desc.AppendBlock(*cur_block);
+      for (auto& ptr : backward_block_op_descs) {
+        backward_block->ops_.push_back(std::move(ptr));
+      }
+      op_grads[0]->SetBlockAttr("step_block", *backward_block);
+    }
     for (const auto& desc : op_grads) {
       for (const std::string& out_name : desc->OutputArgumentNames()) {
         dup_out_ops[out_name].emplace_back(grad_desc_idx);
@@ -345,11 +361,24 @@ void AppendBackwardOpDescs(BlockDescBind& block_desc,
     backward_descs.insert(backward_descs.begin() + p.first + 1,
                           std::move(p.second));
   }
-  // Append backward_descs to BlockDescBind::ops_
-  for (std::unique_ptr<OpDescBind>& ptr : backward_descs) {
-    block_op_descs.push_back(std::move(ptr));
-  }
+  return backward_descs;
+}
+
+void AppendBackward(ProgramDescBind& program_desc,
+                    const std::unordered_set<std::string>& no_grad_vars) {
+  std::unordered_set<std::string> no_grad_var_names;
+  no_grad_var_names.reserve(no_grad_vars.size() + 1);
+  no_grad_var_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
+  for (auto& name : no_grad_vars) {
+    no_grad_var_names.insert(GradVarName(name));
+  }
+  const int root_block_idx = 0;
+  auto backward_op_descs =
+      MakeBlockBackward(program_desc, root_block_idx, no_grad_var_names);
+  auto& forw_op_descs = program_desc.Block(root_block_idx)->ops_;
+  for (auto& ptr : backward_op_descs) {
+    forw_op_descs.push_back(std::move(ptr));
+  }
   return;
 }
 
 }  // namespace framework
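Taken together, the backward.cc changes turn gradient construction into a per-block recursion: MakeBlockBackward walks a block's ops in reverse, asks MakeOpGrad for each op's gradient descriptors, and, on meeting a "recurrent" op, first recurses into that op's step block, stores the resulting gradient ops in a freshly appended block, and points the gradient op's step_block attribute at it. AppendBackward is the new entry point: it rewrites no_grad_vars into gradient-variable names (GradVarName(name), seeded with kEmptyVarName + kGradVarSuffix) before appending the root block's backward ops. Below is a minimal, self-contained sketch of just this recursion; the Op/Block/Program types and the "_grad" naming are hypothetical stand-ins for OpDescBind, BlockDescBind, and ProgramDescBind, and the duplicate-output merging and fill_zeros_like handling of the real code are omitted.

#include <deque>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Toy stand-ins for OpDescBind / BlockDescBind / ProgramDescBind.
struct Op {
  std::string type;
  int step_block = -1;  // sub-block index; only used by "recurrent" ops
};

struct Block {
  std::vector<Op> ops;
};

struct Program {
  // deque rather than vector: AppendBlock must not invalidate references
  // to blocks that an outer MakeBlockBackward call is still iterating.
  std::deque<Block> blocks;

  int AppendBlock() {  // mirrors ProgramDescBind::AppendBlock
    blocks.emplace_back();
    return static_cast<int>(blocks.size()) - 1;
  }
};

// Mirrors the shape of MakeBlockBackward: visit ops in reverse, emit one
// gradient op per forward op, and recurse into "recurrent" step blocks.
std::vector<Op> MakeBlockBackward(Program& prog, int block_idx) {
  std::vector<Op> backward;
  Block& cur = prog.blocks[block_idx];
  for (auto it = cur.ops.rbegin(); it != cur.ops.rend(); ++it) {
    Op grad{it->type + "_grad"};
    if (it->type == "recurrent") {
      // Build the step block's backward ops first...
      std::vector<Op> step_backward = MakeBlockBackward(prog, it->step_block);
      // ...materialize them as a new block of the program...
      int bwd_idx = prog.AppendBlock();
      prog.blocks[bwd_idx].ops = std::move(step_backward);
      // ...and let the gradient op reference it, as SetBlockAttr does.
      grad.step_block = bwd_idx;
    }
    backward.push_back(std::move(grad));
  }
  return backward;
}

int main() {
  Program prog;
  prog.blocks.resize(2);  // block 0: root; block 1: the RNN step block
  prog.blocks[1].ops = {{"mul"}, {"add"}};
  prog.blocks[0].ops = {{"fc"}, {"recurrent", 1}};

  // Like AppendBackward: build the root block's gradients, then append them.
  std::vector<Op> backward = MakeBlockBackward(prog, 0);
  for (Op& op : backward) {
    prog.blocks[0].ops.push_back(std::move(op));
  }

  for (const Op& op : prog.blocks[0].ops) {
    std::cout << op.type << "\n";
  }
}

Running it prints fc, recurrent, recurrent_grad, fc_grad for the root block, while the program has gained a third block holding add_grad and mul_grad — the same shape the real SetBlockAttr("step_block", *backward_block) call produces.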
paddle/framework/block_desc.h
@@ -32,8 +32,13 @@ class ProgramDescBind;
 
 class BlockDescBind {
  public:
-  friend void AppendBackwardOpDescs(
-      BlockDescBind &block_desc, std::unordered_set<std::string> &no_grad_vars);
+  friend std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
+      ProgramDescBind &program_desc, int block_idx,
+      std::unordered_set<std::string> &no_grad_vars);
+
+  friend void AppendBackward(
+      ProgramDescBind &program_desc,
+      const std::unordered_set<std::string> &no_grad_vars);
 
   BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
       : prog_(prog), desc_(desc), need_update_(false) {}
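A note on the block_desc.h change: MakeBlockBackward and AppendBackward splice generated op descriptors straight into the private BlockDescBind::ops_ deque, so both must be declared friends of the class. A stripped-down illustration of the pattern, with toy Block/OpDesc/SpliceBackward names rather than the real API:

#include <deque>
#include <memory>
#include <utility>

struct OpDesc {};  // stand-in for OpDescBind

class Block {
 public:
  // Grants the backward builder direct access to ops_ below
  // (mirrors the friend declarations added to BlockDescBind).
  friend void SpliceBackward(Block &block,
                             std::deque<std::unique_ptr<OpDesc>> grads);

 private:
  std::deque<std::unique_ptr<OpDesc>> ops_;  // private op list
};

// A free function; it touches block.ops_ and compiles only because
// Block names it as a friend.
void SpliceBackward(Block &block, std::deque<std::unique_ptr<OpDesc>> grads) {
  for (auto &op : grads) {
    block.ops_.push_back(std::move(op));
  }
}

int main() {
  Block b;
  std::deque<std::unique_ptr<OpDesc>> grads;
  grads.push_back(std::make_unique<OpDesc>());
  SpliceBackward(b, std::move(grads));
}

The alternative would be a public mutator on BlockDescBind; friendship instead keeps the ability to bypass the block's normal op-appending path restricted to the two backward-construction entry points.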