提交 ec783d6b 编写于 作者: Y Yu Yang 提交者: GitHub

Feature/backward return map (#4806)

* Final step of backward, return a map from param_name to grad

* Complete the final step of backward

Return the param_name to grad_info
上级 d7383c6d
......@@ -273,18 +273,30 @@ static bool AllGradInSet(const std::vector<std::string>& names,
return true;
}
static void CreateGradVarInBlock(BlockDescBind* block_desc,
size_t grad_op_start_index) {
static void CreateGradVarInBlock(
std::unordered_map<std::string, GradVarInfo>* grad_var_record,
BlockDescBind* block_desc, size_t grad_op_start_index,
const std::unordered_map<std::string, std::string>& param_name_map) {
auto ops = block_desc->AllOps();
for (size_t op_index = grad_op_start_index; op_index < ops.size();
++op_index) {
for (const auto& output : ops[op_index]->Outputs()) {
for (const auto& real_output : output.second) {
if (!block_desc->HasVar(real_output)) {
block_desc->NewVar(real_output);
}
}
}
ForEachVarName(ops[op_index]->Outputs(),
[&](const std::string& grad_var_name) {
if (block_desc->HasVar(grad_var_name)) {
return false;
}
block_desc->NewVar(grad_var_name);
auto it = param_name_map.find(grad_var_name);
if (it == param_name_map.end()) {
return false;
}
auto param_var_name = it->second;
auto& grad_record = (*grad_var_record)[param_var_name];
grad_record.name_ = grad_var_name;
grad_record.block_idx_ = block_desc->ID();
grad_record.op_idx_ = static_cast<int>(op_index);
return false; /* not break */
});
}
}
......@@ -400,8 +412,9 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
return backward_descs;
}
void AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
const std::unordered_set<std::string>& no_grad_vars) {
std::unordered_map<std::string /*fwd_var_name*/, GradVarInfo /*grad_var_info*/>
AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
const std::unordered_set<std::string>& no_grad_vars) {
std::unordered_set<std::string> no_grad_var_names;
no_grad_var_names.reserve(no_grad_vars.size() + 1);
no_grad_var_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
......@@ -423,20 +436,28 @@ void AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
all_ops.push_back(std::move(fill_one_op));
size_t forward_op_num = all_ops.size();
size_t forward_block_num = program_desc.Size();
// Insert backward operators
std::unordered_map<std::string, std::string> grad_to_var;
auto backward_op_descs = MakeBlockBackward(program_desc, root_block_idx,
&no_grad_var_names, &grad_to_var);
std::unordered_map<std::string, GradVarInfo> retv;
// Create Variable
for (auto& ptr : backward_op_descs) {
all_ops.push_back(std::move(ptr));
}
root_block->NewVar(fill_one_op_out);
// create grad_var for all blocks in this program
CreateGradVarInBlock(root_block, forward_op_num);
CreateGradVarInBlock(&retv, root_block, forward_op_num, grad_to_var);
for (size_t block_index = forward_block_num;
block_index < program_desc.Size(); ++block_index) {
CreateGradVarInBlock(program_desc.Block(block_index), 0);
CreateGradVarInBlock(&retv, program_desc.Block(block_index), 0,
grad_to_var);
}
return retv;
}
} // namespace framework
......
......@@ -14,7 +14,10 @@
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/framework/operator.h"
#include "paddle/framework/program_desc.h"
......@@ -27,10 +30,17 @@ extern std::unique_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars);
struct GradVarInfo {
std::string name_;
int block_idx_;
int op_idx_;
};
// TODO(jiayi): Add target as parameter and generate backward op
// according to target.
void AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
const std::unordered_set<std::string>& no_grad_vars);
std::unordered_map<std::string /*fwd_var_name*/, GradVarInfo /*grad_var_info*/>
AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
const std::unordered_set<std::string>& no_grad_vars);
} // namespace framework
} // namespace paddle
......@@ -33,15 +33,6 @@ class ProgramDescBind;
class BlockDescBind {
public:
friend std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
ProgramDescBind &program_desc, int block_idx,
std::unordered_set<std::string> *no_grad_vars,
std::unordered_map<std::string, std::string> *grad_to_var);
friend void AppendBackward(
ProgramDescBind &program_desc, const VarDescBind &target,
const std::unordered_set<std::string> &no_grad_vars);
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) {}
......@@ -69,7 +60,9 @@ class BlockDescBind {
BlockDesc *Proto();
private:
// FIXME(yuyang18): backward will access private data of BlockDesc.
// Mark it public temporary. We can fix it later.
public:
ProgramDescBind *prog_; // not_own
BlockDesc *desc_; // not_own
bool need_update_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册