提交 37b0bb15 编写于 作者: F fengjiayi

Fix compile errors

上级 d3559108
...@@ -14,9 +14,11 @@ ...@@ -14,9 +14,11 @@
#include "paddle/framework/backward.h" #include "paddle/framework/backward.h"
#include <deque>
#include <list> #include <list>
#include <memory> #include <memory>
#include "paddle/framework/block_desc.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h" #include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h" #include "paddle/operators/recurrent_op.h"
...@@ -254,23 +256,22 @@ static bool AllGradInSet(const std::vector<std::string>& names, ...@@ -254,23 +256,22 @@ static bool AllGradInSet(const std::vector<std::string>& names,
std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs( std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs(
const std::unique_ptr<OpDescBind>& op_desc, const std::unique_ptr<OpDescBind>& op_desc,
unordered_set<std::string>& no_grad_vars) { std::unordered_set<std::string>& no_grad_vars) {
std::vector<std::unique_ptr<OpDescBind>> grad_op_descs; std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
// All input gradients of forwarding operator do not need to calculat. // All input gradients of forwarding operator do not need to calculat.
if (AllGradInSet(op_desc->InputArgumentNames(), kGradVarSuffix, if (AllGradInSet(op_desc->InputArgumentNames(), no_grad_vars)) {
no_grad_vars)) {
return grad_op_descs; // empty vector return grad_op_descs; // empty vector
} }
// All output gradients of forwarding operator do not need to calculate. // All output gradients of forwarding operator do not need to calculate.
const std::vector<std::string>& outputs = op_desc->OutputArugumentNames(); const std::vector<std::string>& outputs = op_desc->OutputArgumentNames();
if (AllGradInSet(outputs, kGradVarSuffix, no_grad_vars)) { if (AllGradInSet(outputs, no_grad_vars)) {
for (const std::string& name : outputs) { for (const std::string& name : outputs) {
no_grad_vars.insert(GradVarName(name)); no_grad_vars.insert(GradVarName(name));
} }
return grad_op_descs; // empty vector return grad_op_descs; // empty vector
} }
grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc); grad_op_descs = OpRegistry::CreateGradOpDescs(*op_desc);
std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops; std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
for (auto& desc : grad_op_descs) { for (auto& desc : grad_op_descs) {
...@@ -280,43 +281,43 @@ std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs( ...@@ -280,43 +281,43 @@ std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs(
0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1); 0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
std::string new_name = prefix + kZeroVarSuffix; std::string new_name = prefix + kZeroVarSuffix;
desc->Rename(in_name, new_name); desc->Rename(in_name, new_name);
OpDescBind* fill_zeros_op = new OpDescBind( std::unique_ptr<OpDescBind> fill_zeros_op(new OpDescBind(
"fill_zeros_like", {{"X", {prefix}}}, {{"Y", {new_name}}}, {}); "fill_zeros_like", {{"X", {prefix}}}, {{"Y", {new_name}}}, {}));
pending_fill_zeros_ops.push_back({fill_zeros_op}); pending_fill_zeros_ops.push_back(std::move(fill_zeros_op));
} }
} }
for (const std::string& out_name : desc->OutputArgumentName()) { for (const std::string& out_name : desc->OutputArgumentNames()) {
if (no_grad_vars.count(out_name)) { if (no_grad_vars.count(out_name)) {
desc->Rename(out_name, kEmptyVarName); desc->Rename(out_name, kEmptyVarName);
} }
} }
} }
grad_op_descs.insert(std::begin(grad_op_descs), for (auto& p : pending_fill_zeros_ops) {
std::begin(pending_fill_zeros_ops), grad_op_descs.push_back(std::move(p));
std::end(pending_fill_zeros_ops)); }
// TODO (fengjiayi): RNN op // TODO(fengjiayi): RNN op
return grad_op_descs; return grad_op_descs;
} }
void AppendBackwardOpDescs( void AppendBackwardOpDescs(BlockDescBind& block_desc,
BlockDescBind& block_desc, std::unordered_set<std::string>& no_grad_vars) {
const std::unordered_set<std::string>& no_grad_vars) {
std::unordered_map<std::string, std::vector<size_t>> dup_out_ops; std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
size_t grad_desc_idx = 0; size_t grad_desc_idx = 0;
std::deque<std::unique_ptr<OpDescBind>> block_op_descs = block_desc.ops_; std::deque<std::unique_ptr<OpDescBind>>& block_op_descs = block_desc.ops_;
std::vector<std::unique_ptr<OpDescBind>> backward_descs; std::vector<std::unique_ptr<OpDescBind>> backward_descs;
for (auto it = block_op_descs.rbegin(); it != block_op_descs.rend(); ++it) { for (auto it = block_op_descs.rbegin(); it != block_op_descs.rend(); ++it) {
std::vector<std::unique_ptr<OpDescBind>> op_grads = std::vector<std::unique_ptr<OpDescBind>> op_grads =
MakeGradOpDescs(*it, no_grad_vars); MakeGradOpDescs(*it, no_grad_vars);
for (const auto& desc : op_grads) { for (const auto& desc : op_grads) {
for (const std::string& out_name : desc->OutputArugumentNames()) { for (const std::string& out_name : desc->OutputArgumentNames()) {
dup_out_ops[out_name].emplace_back(grad_desc_idx); dup_out_ops[out_name].emplace_back(grad_desc_idx);
} }
++grad_desc_idx; ++grad_desc_idx;
} }
backward_descs.insert(backward_descs.end(), op_grads.begin(), std::transform(
op_grads.end()); op_grads.begin(), op_grads.end(), std::back_inserter(backward_descs),
[](std::unique_ptr<OpDescBind>& ptr) { return std::move(ptr); });
} }
// Check whether some variables are written more than once // Check whether some variables are written more than once
std::list<std::pair<size_t, std::unique_ptr<OpDescBind>>> pending_sum_ops; std::list<std::pair<size_t, std::unique_ptr<OpDescBind>>> pending_sum_ops;
...@@ -330,9 +331,9 @@ void AppendBackwardOpDescs( ...@@ -330,9 +331,9 @@ void AppendBackwardOpDescs(
backward_descs[dup_op[i]]->Rename(out_name, new_name); backward_descs[dup_op[i]]->Rename(out_name, new_name);
sum_op_inputs.emplace_back(new_name); sum_op_inputs.emplace_back(new_name);
} }
OpDescBind* sum_op = new OpDescBind("sum", {{"X", sum_op_inputs}}, std::unique_ptr<OpDescBind> sum_op(new OpDescBind(
{{"Out", {out_name}}}, {}); "sum", {{"X", sum_op_inputs}}, {{"Out", {out_name}}}, {}));
pending_sum_ops.push_back({dup_op.back(), {sum_op}}); pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)});
} }
} }
pending_sum_ops.sort( pending_sum_ops.sort(
...@@ -345,8 +346,9 @@ void AppendBackwardOpDescs( ...@@ -345,8 +346,9 @@ void AppendBackwardOpDescs(
std::move(p.second)); std::move(p.second));
} }
// Append backward_descs to BlockDescBind::ops_ // Append backward_descs to BlockDescBind::ops_
block_op_descs.insert(std::end(block_op_descs), std::begin(backward_descs), for (std::unique_ptr<OpDescBind>& ptr : backward_descs) {
std::end(backward_descs)); block_op_descs.push_back(std::move(ptr));
}
return; return;
} }
......
...@@ -32,9 +32,8 @@ class ProgramDescBind; ...@@ -32,9 +32,8 @@ class ProgramDescBind;
class BlockDescBind { class BlockDescBind {
public: public:
friend void AppendBackwardOps( friend void AppendBackwardOpDescs(
BlockDescBind &block_desc, BlockDescBind &block_desc, std::unordered_set<std::string> &no_grad_vars);
const std::unordered_set<std::string> &no_grad_vars);
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc) BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) {} : prog_(prog), desc_(desc), need_update_(false) {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册