提交 9935fdd3 编写于 作者: F fengjiayi

Update

上级 d2701959
...@@ -235,14 +235,17 @@ static bool AllGradInSet(const std::vector<std::string>& names, ...@@ -235,14 +235,17 @@ static bool AllGradInSet(const std::vector<std::string>& names,
} }
std::vector<OpDescBind> CreatBackwardOps( std::vector<OpDescBind> CreatBackwardOps(
const OpDescBind& op_desc, unordered_map<std::string>& no_grad_vars) { const std::unique_ptr<OpDescBind>& op_desc_ptr,
unordered_map<std::string>& no_grad_vars) {
const OpDescBind& op_desc = *op_desc_ptr;
std::vector<OpDescBind> grad_op_descs; std::vector<OpDescBind> grad_op_descs;
// All input gradients of forwarding operator do not need to calculat. // All input gradients of forwarding operator do not need to calculat.
if (AllGradInSet(op_desc_.InputNames(), kGradVarSuffix, no_grad_vars)) { if (AllGradInSet(op_desc_.InputArgumentNames(), kGradVarSuffix,
no_grad_vars)) {
return grad_op_descs; // empty vector return grad_op_descs; // empty vector
} }
// All output gradients of forwarding operator do not need to calculate. // All output gradients of forwarding operator do not need to calculate.
const std::vector<std::string>& outputs = op_desc_.OutputNames(); const std::vector<std::string>& outputs = op_desc_.OutputArugumentNames();
if (AllGradInSet(outputs, kGradVarSuffix, no_grad_vars)) { if (AllGradInSet(outputs, kGradVarSuffix, no_grad_vars)) {
for (const std::string& name : outputs) { for (const std::string& name : outputs) {
no_grad_vars.insert(GradVarName(name)); no_grad_vars.insert(GradVarName(name));
...@@ -254,7 +257,7 @@ std::vector<OpDescBind> CreatBackwardOps( ...@@ -254,7 +257,7 @@ std::vector<OpDescBind> CreatBackwardOps(
std::vector<OpDescBind> fill_zeros_ops; std::vector<OpDescBind> fill_zeros_ops;
for (OpDescBind& desc : grad_op_descs) { for (OpDescBind& desc : grad_op_descs) {
for (const std::string& in_name : desc.InputNames()) { for (const std::string& in_name : desc.InputArgumentNames()) {
if (no_grad_vars.count(in_name)) { if (no_grad_vars.count(in_name)) {
std::string prefix = in_name.substr( std::string prefix = in_name.substr(
0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1); 0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
...@@ -278,5 +281,51 @@ std::vector<OpDescBind> CreatBackwardOps( ...@@ -278,5 +281,51 @@ std::vector<OpDescBind> CreatBackwardOps(
return grad_op_descs; return grad_op_descs;
} }
void AppendBackwardOps(BlockDescBind& block_desc,
const std::unordered_set<std::string>& no_grad_vars) {
std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
size_t grad_desc_idx = 0;
std::deque<std::unique_ptr<OpDescBind>> op_descs = block_desc.ops_;
std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
std::vector<OpDescBind> op_grads = CreatBackwardOps(*it, no_grad_vars);
for (const OpDescBind& desc : op_grads) {
for (const std::string& out_name : desc.OutputArugumentNames()) {
dup_out_ops[out_name].emplace_back(grad_desc_idx);
}
++grad_desc_idx;
}
grad_op_descs.insert(grad_op_descs.end(), op_grads.begin(), op_grads.end());
}
// Check whether some variables are written more than once
std::list<std::pair<size_t, OpDescBind>> pending_sum_ops;
for (const auto& dup : dup_out_ops) {
const std::string& out_name = dup.first;
const std::vector<size_t> dup_op = dup.second;
if (out_name != kEmptyVarName && dup_op.size() > 1) {
std::vector<std::string> sum_op_inputs;
for (size_t i = 0; i < dup_op.size(); ++i) {
std::string new_name = out_name + "@RENAME@" + std::to_string(i);
grad_op_descs[dup_op[i]].Rename(out_name, new_name);
sum_op_inputs.emplace_back(new_name);
}
pending_sum_ops.push_back(
{dup_op.back(),
OpDescBind(
{"sum", {{"X", {sum_op_inputs}}}, {{"Out", {out_name}}}, {}})});
}
}
pending_sum_ops.sort(
[](const std::pair<size_t, OpDescBind>& a,
const std::pair<size_t, OpDescBind>& b) { return a.first > b.first; });
for (auto& p : pending_sum_ops) {
grad_op_descs.insert(grad_op_descs.begin() + p.first + 1,
std::move(p.second));
}
// Append grad_op_descs to BlockDescBind::ops_
for () {
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -32,6 +32,10 @@ class ProgramDescBind; ...@@ -32,6 +32,10 @@ class ProgramDescBind;
class BlockDescBind { class BlockDescBind {
public: public:
friend void AppendBackwardOps(
BlockDescBind &block_desc,
const std::unordered_set<std::string> &no_grad_vars);
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc) BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) {} : prog_(prog), desc_(desc), need_update_(false) {}
......
...@@ -49,6 +49,14 @@ std::vector<std::string> OpDescBind::InputNames() const { ...@@ -49,6 +49,14 @@ std::vector<std::string> OpDescBind::InputNames() const {
return retv; return retv;
} }
std::vector<std::string> InputArgumentNames() const {
std::vector<std::string> retv;
for (auto &ipt : this->inputs_) {
retv.insert(retv.end(), ipt.second.begin(), ipt.second.end());
}
return retv;
}
void OpDescBind::SetInput(const std::string &param_name, void OpDescBind::SetInput(const std::string &param_name,
const std::vector<std::string> &args) { const std::vector<std::string> &args) {
need_update_ = true; need_update_ = true;
...@@ -72,6 +80,14 @@ std::vector<std::string> OpDescBind::OutputNames() const { ...@@ -72,6 +80,14 @@ std::vector<std::string> OpDescBind::OutputNames() const {
return retv; return retv;
} }
std::vector<std::string> OutputArgumentNames() const {
std::vector<std::string> retv;
for (auto &ipt : this->outputs_) {
retv.insert(retv.end(), ipt.second.begin(), ipt.second.end());
}
return retv;
}
void OpDescBind::SetOutput(const std::string &param_name, void OpDescBind::SetOutput(const std::string &param_name,
const std::vector<std::string> &args) { const std::vector<std::string> &args) {
need_update_ = true; need_update_ = true;
......
...@@ -42,6 +42,8 @@ class OpDescBind { ...@@ -42,6 +42,8 @@ class OpDescBind {
std::vector<std::string> InputNames() const; std::vector<std::string> InputNames() const;
std::vector<std::string> InputArgumentNames() const;
void SetInput(const std::string &param_name, void SetInput(const std::string &param_name,
const std::vector<std::string> &args); const std::vector<std::string> &args);
...@@ -49,6 +51,8 @@ class OpDescBind { ...@@ -49,6 +51,8 @@ class OpDescBind {
std::vector<std::string> OutputNames() const; std::vector<std::string> OutputNames() const;
std::vector<std::string> OutputArgumentNames() const;
void SetOutput(const std::string &param_name, void SetOutput(const std::string &param_name,
const std::vector<std::string> &args); const std::vector<std::string> &args);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册