提交 ec783d6b 编写于 作者: Y Yu Yang 提交者: GitHub

Feature/backward return map (#4806)

* Final step of backward, return a map from param_name to grad

* Complete the final step of backward

Return the param_name to grad_info
上级 d7383c6d
...@@ -273,18 +273,30 @@ static bool AllGradInSet(const std::vector<std::string>& names, ...@@ -273,18 +273,30 @@ static bool AllGradInSet(const std::vector<std::string>& names,
return true; return true;
} }
static void CreateGradVarInBlock(BlockDescBind* block_desc, static void CreateGradVarInBlock(
size_t grad_op_start_index) { std::unordered_map<std::string, GradVarInfo>* grad_var_record,
BlockDescBind* block_desc, size_t grad_op_start_index,
const std::unordered_map<std::string, std::string>& param_name_map) {
auto ops = block_desc->AllOps(); auto ops = block_desc->AllOps();
for (size_t op_index = grad_op_start_index; op_index < ops.size(); for (size_t op_index = grad_op_start_index; op_index < ops.size();
++op_index) { ++op_index) {
for (const auto& output : ops[op_index]->Outputs()) { ForEachVarName(ops[op_index]->Outputs(),
for (const auto& real_output : output.second) { [&](const std::string& grad_var_name) {
if (!block_desc->HasVar(real_output)) { if (block_desc->HasVar(grad_var_name)) {
block_desc->NewVar(real_output); return false;
}
} }
block_desc->NewVar(grad_var_name);
auto it = param_name_map.find(grad_var_name);
if (it == param_name_map.end()) {
return false;
} }
auto param_var_name = it->second;
auto& grad_record = (*grad_var_record)[param_var_name];
grad_record.name_ = grad_var_name;
grad_record.block_idx_ = block_desc->ID();
grad_record.op_idx_ = static_cast<int>(op_index);
return false; /* not break */
});
} }
} }
...@@ -400,7 +412,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( ...@@ -400,7 +412,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
return backward_descs; return backward_descs;
} }
void AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target, std::unordered_map<std::string /*fwd_var_name*/, GradVarInfo /*grad_var_info*/>
AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
const std::unordered_set<std::string>& no_grad_vars) { const std::unordered_set<std::string>& no_grad_vars) {
std::unordered_set<std::string> no_grad_var_names; std::unordered_set<std::string> no_grad_var_names;
no_grad_var_names.reserve(no_grad_vars.size() + 1); no_grad_var_names.reserve(no_grad_vars.size() + 1);
...@@ -423,20 +436,28 @@ void AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target, ...@@ -423,20 +436,28 @@ void AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
all_ops.push_back(std::move(fill_one_op)); all_ops.push_back(std::move(fill_one_op));
size_t forward_op_num = all_ops.size(); size_t forward_op_num = all_ops.size();
size_t forward_block_num = program_desc.Size(); size_t forward_block_num = program_desc.Size();
// Insert backward operators
std::unordered_map<std::string, std::string> grad_to_var; std::unordered_map<std::string, std::string> grad_to_var;
auto backward_op_descs = MakeBlockBackward(program_desc, root_block_idx, auto backward_op_descs = MakeBlockBackward(program_desc, root_block_idx,
&no_grad_var_names, &grad_to_var); &no_grad_var_names, &grad_to_var);
std::unordered_map<std::string, GradVarInfo> retv;
// Create Variable
for (auto& ptr : backward_op_descs) { for (auto& ptr : backward_op_descs) {
all_ops.push_back(std::move(ptr)); all_ops.push_back(std::move(ptr));
} }
root_block->NewVar(fill_one_op_out); root_block->NewVar(fill_one_op_out);
// create grad_var for all blocks in this program // create grad_var for all blocks in this program
CreateGradVarInBlock(root_block, forward_op_num); CreateGradVarInBlock(&retv, root_block, forward_op_num, grad_to_var);
for (size_t block_index = forward_block_num; for (size_t block_index = forward_block_num;
block_index < program_desc.Size(); ++block_index) { block_index < program_desc.Size(); ++block_index) {
CreateGradVarInBlock(program_desc.Block(block_index), 0); CreateGradVarInBlock(&retv, program_desc.Block(block_index), 0,
grad_to_var);
} }
return retv;
} }
} // namespace framework } // namespace framework
......
...@@ -14,7 +14,10 @@ ...@@ -14,7 +14,10 @@
#pragma once #pragma once
#include <string>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include "paddle/framework/program_desc.h" #include "paddle/framework/program_desc.h"
...@@ -27,9 +30,16 @@ extern std::unique_ptr<OperatorBase> Backward( ...@@ -27,9 +30,16 @@ extern std::unique_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp, const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars); const std::unordered_set<std::string>& no_grad_vars);
struct GradVarInfo {
std::string name_;
int block_idx_;
int op_idx_;
};
// TODO(jiayi): Add target as parameter and generate backward op // TODO(jiayi): Add target as parameter and generate backward op
// according to target. // according to target.
void AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target, std::unordered_map<std::string /*fwd_var_name*/, GradVarInfo /*grad_var_info*/>
AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
const std::unordered_set<std::string>& no_grad_vars); const std::unordered_set<std::string>& no_grad_vars);
} // namespace framework } // namespace framework
......
...@@ -33,15 +33,6 @@ class ProgramDescBind; ...@@ -33,15 +33,6 @@ class ProgramDescBind;
class BlockDescBind { class BlockDescBind {
public: public:
friend std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
ProgramDescBind &program_desc, int block_idx,
std::unordered_set<std::string> *no_grad_vars,
std::unordered_map<std::string, std::string> *grad_to_var);
friend void AppendBackward(
ProgramDescBind &program_desc, const VarDescBind &target,
const std::unordered_set<std::string> &no_grad_vars);
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc) BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) {} : prog_(prog), desc_(desc), need_update_(false) {}
...@@ -69,7 +60,9 @@ class BlockDescBind { ...@@ -69,7 +60,9 @@ class BlockDescBind {
BlockDesc *Proto(); BlockDesc *Proto();
private: // FIXME(yuyang18): backward will access private data of BlockDesc.
// Mark it public temporary. We can fix it later.
public:
ProgramDescBind *prog_; // not_own ProgramDescBind *prog_; // not_own
BlockDesc *desc_; // not_own BlockDesc *desc_; // not_own
bool need_update_; bool need_update_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册