Commit 2c46666e authored by fengjiayi, committed by GitHub

Add grad_name_map to record correspondences between vars and grad_vars (#4794)

* Add grad_name_map

* Fix bug

* Fix bug

* Follow comments
Parent 9165235a
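In short: this commit threads a `grad_to_var` map (the "grad_name_map") through the whole backward-pass builder. Whenever a gradient variable is actually going to be computed, the pair *gradient var name → forward var name* is recorded in the map. The sketch below illustrates the mapping that ends up recorded; it is a toy, assuming the framework's `kGradVarSuffix` is the literal `"@GRAD"`, with `GradVarName` re-implemented locally:

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

// Toy stand-in for Paddle's GradVarName; assumes kGradVarSuffix == "@GRAD".
std::string GradVarName(const std::string& var_name) {
  return var_name + "@GRAD";
}

int main() {
  std::unordered_set<std::string> no_grad_set = {GradVarName("b")};
  std::unordered_map<std::string, std::string> grad_to_var;

  // Mirrors the new branch in GradOpDescMakerBase::InputGrad: every gradient
  // variable that will actually be computed is mapped back to its forward var.
  for (std::string fwd : {"W", "b", "x"}) {
    std::string g_name = GradVarName(fwd);
    if (no_grad_set.count(g_name) == 0) {
      grad_to_var[g_name] = fwd;
    }
  }

  for (const auto& kv : grad_to_var) {
    std::cout << kv.first << " -> " << kv.second << "\n";
  }
  // Prints (unordered): W@GRAD -> W and x@GRAD -> x; b@GRAD is skipped.
}
```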
@@ -28,15 +28,15 @@ namespace paddle {
 namespace framework {
 
 static inline std::unique_ptr<OperatorBase> CreateGradOp(
-    const OperatorBase& op,
-    const std::unordered_set<std::string>& no_grad_set) {
+    const OperatorBase& op, const std::unordered_set<std::string>& no_grad_set,
+    std::unordered_map<std::string, std::string>* grad_to_var) {
   OpDescBind op_desc;
   op_desc.SetInputMap(op.Inputs());
   op_desc.SetOutputMap(op.Outputs());
   op_desc.SetType(op.Type());
   op_desc.SetAttrMap(op.Attrs());
   auto& info = OpInfoMap::Instance().Get(op.Type());
-  auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set);
+  auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set, grad_to_var);
   std::vector<std::unique_ptr<OperatorBase>> grad_ops;
   grad_ops.reserve(grad_descs.size());
   std::transform(grad_descs.begin(), grad_descs.end(),
@@ -99,7 +99,9 @@ static std::unique_ptr<OperatorBase> NOP() {
 //  See Backward.h for details
 static std::unique_ptr<OperatorBase> BackwardRecursive(
     const OperatorBase& forwardOp,
-    std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
+    std::unordered_set<std::string>& no_grad_names,
+    std::unordered_map<std::string, std::string>* grad_to_var,
+    size_t& uniq_id) {
   // If all input gradients of forwarding operator do not need to calculate,
   // just return an NOP. Not return null ptr because NOP does not take
   // too much time for calculation, but it is useful for simplifying logic.
@@ -137,7 +139,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
     for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
          ++it, ++local_op_id) {
       auto& fwd = *it;
-      auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
+      auto bwd = BackwardRecursive(*fwd, no_grad_names, grad_to_var, uniq_id);
       ForEachVarName(bwd->Outputs(),
                      [&dup_output_ops, local_op_id](const std::string& out) {
                        dup_output_ops[out].emplace_back(local_op_id);
@@ -189,7 +191,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
     }
   } else {
     std::unique_ptr<OperatorBase> grad_op(
-        CreateGradOp(forwardOp, no_grad_names));
+        CreateGradOp(forwardOp, no_grad_names, grad_to_var));
     ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
                                           const std::string& grad_input) {
@@ -228,7 +230,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
         *static_cast<const OperatorBase*>(&rnnop.stepnet());
     // create stepnet's gradient op
     rnn_grad_op->set_stepnet(
-        BackwardRecursive(stepnet_op, no_grad_names, uniq_id));
+        BackwardRecursive(stepnet_op, no_grad_names, grad_to_var, uniq_id));
   }
   if (net->ops_.empty()) {  // Current no aux op is added to network
@@ -255,7 +257,8 @@ std::unique_ptr<OperatorBase> Backward(
     no_grad_names.insert(name + kGradVarSuffix);
   }
   size_t uid = 0;
-  return BackwardRecursive(forwardOp, no_grad_names, uid);
+  std::unordered_map<std::string, std::string> grad_to_var;
+  return BackwardRecursive(forwardOp, no_grad_names, &grad_to_var, uid);
 }
 
 // ====================================  //
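The map is allocated once at the entry point (`Backward` above, and `AppendBackward` further down) and handed down by raw pointer through every level of `BackwardRecursive`, so a single map collects the pairs produced anywhere in the operator tree. A toy illustration of that out-parameter pattern (not Paddle code; `Node` and `Visit` are made up):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// The top-level call owns the map; every recursive level fills it through
// the same pointer, mirroring how BackwardRecursive threads grad_to_var.
struct Node {
  std::string name;
  std::vector<Node> children;
};

void Visit(const Node& n,
           std::unordered_map<std::string, std::string>* grad_to_var) {
  (*grad_to_var)[n.name + "@GRAD"] = n.name;  // record this level's pair
  for (const Node& c : n.children) {
    Visit(c, grad_to_var);  // same pointer flows down the recursion
  }
}

int main() {
  Node root{"fc_out", {Node{"W", {}}, Node{"x", {}}}};
  std::unordered_map<std::string, std::string> grad_to_var;  // owned here
  Visit(root, &grad_to_var);
  std::cout << grad_to_var.size() << " pairs recorded\n";  // 3 pairs recorded
}
```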
@@ -272,30 +275,31 @@ static bool AllGradInSet(const std::vector<std::string>& names,
 
 std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
     const std::unique_ptr<OpDescBind>& op_desc,
-    std::unordered_set<std::string>& no_grad_vars) {
+    std::unordered_set<std::string>* no_grad_vars,
+    std::unordered_map<std::string, std::string>* grad_to_var) {
   std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
   // All input gradients of forwarding operator do not need to calculate.
   const std::vector<std::string>& inputs = op_desc->InputArgumentNames();
-  if (AllGradInSet(inputs, no_grad_vars)) {
+  if (AllGradInSet(inputs, *no_grad_vars)) {
     return grad_op_descs;  // empty vector
   }
   // All output gradients of forwarding operator do not need to calculate.
   const std::vector<std::string>& outputs = op_desc->OutputArgumentNames();
-  if (AllGradInSet(outputs, no_grad_vars)) {
+  if (AllGradInSet(outputs, *no_grad_vars)) {
     for (const std::string& name : inputs) {
-      no_grad_vars.insert(GradVarName(name));
+      no_grad_vars->insert(GradVarName(name));
     }
     return grad_op_descs;  // empty vector
   }
   grad_op_descs = OpInfoMap::Instance()
                       .Get(op_desc->Type())
-                      .GradOpMaker()(*op_desc, no_grad_vars);
+                      .GradOpMaker()(*op_desc, *no_grad_vars, grad_to_var);
   std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
   for (auto& desc : grad_op_descs) {
     for (const std::string& in_name : desc->InputArgumentNames()) {
-      if (no_grad_vars.count(in_name)) {
+      if (no_grad_vars->count(in_name)) {
         std::string prefix = in_name.substr(
             0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
         std::string new_name = prefix + kZeroVarSuffix;
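A side note on the `substr` arithmetic at the end of this hunk: `sizeof` on a char-array constant counts the trailing `'\0'`, so subtracting `sizeof(kGradVarSuffix) / sizeof(char) - 1` (written above as `- sizeof(...) + 1`) from `size()` strips exactly the suffix characters. A checked sketch, assuming `kGradVarSuffix` is `"@GRAD"` and `kZeroVarSuffix` is `"@ZERO"` (illustrative values):

```cpp
#include <cassert>
#include <string>

int main() {
  constexpr char kGradVarSuffix[] = "@GRAD";  // sizeof == 6, incl. '\0'
  constexpr char kZeroVarSuffix[] = "@ZERO";
  std::string in_name = "x@GRAD";             // size() == 6
  // 6 - 6 + 1 == 1 -> keep only "x", i.e. drop the 5-char "@GRAD" suffix.
  std::string prefix = in_name.substr(
      0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
  assert(prefix == "x");
  std::string new_name = prefix + kZeroVarSuffix;
  assert(new_name == "x@ZERO");  // placeholder var filled by fill_zeros ops
}
```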
@@ -315,7 +319,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
 
 std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
     ProgramDescBind& program_desc, int block_idx,
-    std::unordered_set<std::string>& no_grad_vars) {
+    std::unordered_set<std::string>* no_grad_vars,
+    std::unordered_map<std::string, std::string>* grad_to_var) {
   BlockDescBind* cur_block = program_desc.Block(block_idx);
   std::deque<std::unique_ptr<OpDescBind>>& op_descs = cur_block->ops_;
   std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
@@ -323,15 +328,15 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
   std::vector<std::unique_ptr<OpDescBind>> backward_descs;
   for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
     std::vector<std::unique_ptr<OpDescBind>> op_grads =
-        MakeOpGrad(*it, no_grad_vars);
+        MakeOpGrad(*it, no_grad_vars, grad_to_var);
     if ((*it)->Type() == "recurrent") {
       PADDLE_ENFORCE_EQ(
           op_grads.size(), size_t(1),
           "rnn_op's gradient process should contain only one op.");
       int step_block_idx = (*it)->GetBlockAttr("stop_block");
-      auto backward_block_op_descs =
-          MakeBlockBackward(program_desc, step_block_idx, no_grad_vars);
+      auto backward_block_op_descs = MakeBlockBackward(
+          program_desc, step_block_idx, no_grad_vars, grad_to_var);
       BlockDescBind* backward_block = program_desc.AppendBlock(*cur_block);
       for (auto& ptr : backward_block_op_descs) {
         backward_block->ops_.push_back(std::move(ptr));
@@ -387,8 +392,9 @@ void AppendBackward(ProgramDescBind& program_desc,
     no_grad_var_names.insert(GradVarName(name));
   }
   const int root_block_idx = 0;
-  auto backward_op_descs =
-      MakeBlockBackward(program_desc, root_block_idx, no_grad_var_names);
+  std::unordered_map<std::string, std::string> grad_to_var;
+  auto backward_op_descs = MakeBlockBackward(program_desc, root_block_idx,
+                                             &no_grad_var_names, &grad_to_var);
   auto& forw_op_descs = program_desc.Block(root_block_idx)->ops_;
   for (auto& ptr : backward_op_descs) {
     forw_op_descs.push_back(std::move(ptr));
......
@@ -35,7 +35,8 @@ class BlockDescBind {
  public:
   friend std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
       ProgramDescBind &program_desc, int block_idx,
-      std::unordered_set<std::string> &no_grad_vars);
+      std::unordered_set<std::string> *no_grad_vars,
+      std::unordered_map<std::string, std::string> *grad_to_var);
 
   friend void AppendBackward(
       ProgramDescBind &program_desc,
......
@@ -99,8 +99,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
   void operator()(const char* op_type, OpInfo* info) const {
     info->grad_op_maker_ = [](
         const OpDescBind& fwd_op,
-        const std::unordered_set<std::string>& no_grad_set) {
-      T maker(fwd_op, no_grad_set);
+        const std::unordered_set<std::string>& no_grad_set,
+        std::unordered_map<std::string, std::string>* grad_to_var) {
+      T maker(fwd_op, no_grad_set, grad_to_var);
       return maker();
     };
   }
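The filler stores a captureless lambda that constructs the concrete maker `T` with the new third argument and invokes it on the spot; that lambda is exactly the shape the widened `GradOpMakerFN` alias (last hunk below) now demands. A simplified, self-contained version of the construct-and-call pattern, with plain strings standing in for `OpDescBind` and `EchoGradMaker` a made-up maker type:

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

// Hypothetical maker standing in for a real GradOpDescMaker subclass.
struct EchoGradMaker {
  EchoGradMaker(const std::string& fwd_op,
                const std::unordered_set<std::string>& /*no_grad_set*/,
                std::unordered_map<std::string, std::string>* grad_to_var)
      : fwd_op_(fwd_op), grad_to_var_(grad_to_var) {}

  std::string operator()() const {
    (*grad_to_var_)[fwd_op_ + "@GRAD"] = fwd_op_;  // the new side channel
    return fwd_op_ + "_grad";
  }

  std::string fwd_op_;
  std::unordered_map<std::string, std::string>* grad_to_var_;
};

// Construct-and-call, as in OpInfoFiller: T maker(...); return maker();
template <typename T>
std::string MakeGrad(const std::string& fwd_op,
                     const std::unordered_set<std::string>& no_grad_set,
                     std::unordered_map<std::string, std::string>* grad_to_var) {
  T maker(fwd_op, no_grad_set, grad_to_var);
  return maker();
}

int main() {
  std::unordered_map<std::string, std::string> grad_to_var;
  std::cout << MakeGrad<EchoGradMaker>("mul", {}, &grad_to_var) << "\n";  // mul_grad
  std::cout << grad_to_var["mul@GRAD"] << "\n";                           // mul
}
```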
......
@@ -25,8 +25,9 @@ class GradOpDescMakerBase {
  public:
   explicit GradOpDescMakerBase(
       const OpDescBind& fwd_op,
-      const std::unordered_set<std::string>& no_grad_set)
-      : fwd_op_(fwd_op), no_grad_set_(no_grad_set) {}
+      const std::unordered_set<std::string>& no_grad_set,
+      std::unordered_map<std::string, std::string>* grad_to_var)
+      : fwd_op_(fwd_op), no_grad_set_(no_grad_set), grad_to_var_(grad_to_var) {}
 
   virtual ~GradOpDescMakerBase() = default;
   virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
@@ -37,11 +38,16 @@ class GradOpDescMakerBase {
     std::vector<std::string> ret_val;
     auto var_names = this->Input(name);
     ret_val.reserve(var_names.size());
-    std::transform(
-        var_names.begin(), var_names.end(), std::back_inserter(ret_val),
-        [this](const std::string& fwd_var_name) -> std::string {
-          auto g_name = GradVarName(fwd_var_name);
-          return no_grad_set_.count(g_name) == 0 ? g_name : kEmptyVarName;
-        });
+    std::transform(var_names.begin(), var_names.end(),
+                   std::back_inserter(ret_val),
+                   [this](const std::string& fwd_var_name) -> std::string {
+                     auto g_name = GradVarName(fwd_var_name);
+                     if (no_grad_set_.count(g_name)) {
+                       return kEmptyVarName;
+                     } else {
+                       (*this->grad_to_var_)[g_name] = fwd_var_name;
+                       return g_name;
+                     }
+                   });
     if (!drop_empty_grad) {
       return ret_val;
@@ -95,6 +101,7 @@ class GradOpDescMakerBase {
  private:
   const OpDescBind& fwd_op_;
   const std::unordered_set<std::string>& no_grad_set_;
+  std::unordered_map<std::string, std::string>* grad_to_var_;
 };
 
 class SingleGradOpDescMaker : public GradOpDescMakerBase {
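The key behavioral change sits in `InputGrad` above: the old ternary had no room for a side effect, so it becomes an if/else whose taken branch also writes the *grad → forward* pair into `grad_to_var_`. A standalone sketch of the new logic as a free function (hypothetical: `kEmptyVarName`'s value `"@EMPTY@"` is an assumption, and `GradVarName` is inlined as a suffix append):

```cpp
#include <algorithm>
#include <iterator>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Assumed placeholder value; stands in for Paddle's kEmptyVarName constant.
const std::string kEmptyVarName = "@EMPTY@";

// Free-function sketch of GradOpDescMakerBase::InputGrad's new body.
std::vector<std::string> InputGradSketch(
    const std::vector<std::string>& var_names,
    const std::unordered_set<std::string>& no_grad_set,
    std::unordered_map<std::string, std::string>* grad_to_var) {
  std::vector<std::string> ret_val;
  ret_val.reserve(var_names.size());
  std::transform(var_names.begin(), var_names.end(),
                 std::back_inserter(ret_val),
                 [&](const std::string& fwd_var_name) -> std::string {
                   auto g_name = fwd_var_name + "@GRAD";  // GradVarName stand-in
                   if (no_grad_set.count(g_name)) {
                     return kEmptyVarName;  // gradient suppressed
                   }
                   // The new side effect: remember the grad -> forward pair.
                   (*grad_to_var)[g_name] = fwd_var_name;
                   return g_name;
                 });
  return ret_val;
}

int main() {
  std::unordered_map<std::string, std::string> grad_to_var;
  auto grads = InputGradSketch({"W", "b"}, {"b@GRAD"}, &grad_to_var);
  // grads == {"W@GRAD", "@EMPTY@"}; grad_to_var == {{"W@GRAD", "W"}}
  (void)grads;
}
```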
......
@@ -37,7 +37,8 @@ using OpCreator = std::function<OperatorBase*(
     const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
 
 using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDescBind>>(
-    const OpDescBind&, const std::unordered_set<std::string>& /*no_grad_set*/)>;
+    const OpDescBind&, const std::unordered_set<std::string>& /*no_grad_set*/,
+    std::unordered_map<std::string, std::string>* /*grad_to_var*/)>;
 
 }  // namespace framework
 }  // namespace paddle