Commit 37b0bb15 authored by fengjiayi

Fix compile errors

Parent d3559108
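Nearly every error this commit fixes traces to one rule: `std::unique_ptr` is move-only, so any code that implicitly copies one (or copies a container of them) does not compile. A minimal standalone sketch of the rule, using a stand-in `Op` type rather than PaddlePaddle's `OpDescBind`:

```cpp
// Standalone sketch, not PaddlePaddle code: `Op` stands in for OpDescBind.
// std::unique_ptr has a deleted copy constructor, so both the element and
// the container holding it must be transferred with std::move.
#include <memory>
#include <utility>
#include <vector>

struct Op {};

int main() {
  std::vector<std::unique_ptr<Op>> ops;
  std::unique_ptr<Op> op(new Op);

  // ops.push_back(op);          // error: use of deleted copy constructor
  ops.push_back(std::move(op));  // OK: ownership moves into the vector

  std::vector<std::unique_ptr<Op>> others;
  // others = ops;               // error: copying the vector copies elements
  others = std::move(ops);       // OK: the whole vector is moved
  return others.size() == 1 ? 0 : 1;
}
```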
......@@ -14,9 +14,11 @@
#include "paddle/framework/backward.h"
#include <deque>
#include <list>
#include <memory>
#include "paddle/framework/block_desc.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
......@@ -254,23 +256,22 @@ static bool AllGradInSet(const std::vector<std::string>& names,
std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs(
const std::unique_ptr<OpDescBind>& op_desc,
-unordered_set<std::string>& no_grad_vars) {
+std::unordered_set<std::string>& no_grad_vars) {
std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
// All input gradients of forwarding operator do not need to calculate.
-if (AllGradInSet(op_desc->InputArgumentNames(), kGradVarSuffix,
-no_grad_vars)) {
+if (AllGradInSet(op_desc->InputArgumentNames(), no_grad_vars)) {
return grad_op_descs; // empty vector
}
// All output gradients of forwarding operator do not need to calculate.
-const std::vector<std::string>& outputs = op_desc->OutputArugumentNames();
-if (AllGradInSet(outputs, kGradVarSuffix, no_grad_vars)) {
+const std::vector<std::string>& outputs = op_desc->OutputArgumentNames();
+if (AllGradInSet(outputs, no_grad_vars)) {
for (const std::string& name : outputs) {
no_grad_vars.insert(GradVarName(name));
}
return grad_op_descs; // empty vector
}
-grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc);
+grad_op_descs = OpRegistry::CreateGradOpDescs(*op_desc);
std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
for (auto& desc : grad_op_descs) {
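The last change in the hunk above passes `*op_desc` instead of `op_desc`, the usual fix when a callee takes the object by const reference while the caller holds a `std::unique_ptr`. A sketch under that assumption (the `CreateGrad` signature here is hypothetical, not the real `OpRegistry` API):

```cpp
// Sketch with assumed signatures (CreateGrad is hypothetical, not the real
// OpRegistry API): a parameter of type `const T&` cannot bind to a
// std::unique_ptr<T>, so the caller dereferences the smart pointer.
#include <memory>
#include <string>

struct OpDesc {
  std::string type;
};

std::string CreateGrad(const OpDesc& op) { return op.type + "_grad"; }

int main() {
  std::unique_ptr<OpDesc> op(new OpDesc{"mul"});
  // CreateGrad(op);  // error: no conversion from unique_ptr<OpDesc>
  std::string g = CreateGrad(*op);  // OK: pass the pointee itself
  return g == "mul_grad" ? 0 : 1;
}
```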
......@@ -280,43 +281,43 @@ std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs(
0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
std::string new_name = prefix + kZeroVarSuffix;
desc->Rename(in_name, new_name);
-OpDescBind* fill_zeros_op = new OpDescBind(
-"fill_zeros_like", {{"X", {prefix}}}, {{"Y", {new_name}}}, {});
-pending_fill_zeros_ops.push_back({fill_zeros_op});
+std::unique_ptr<OpDescBind> fill_zeros_op(new OpDescBind(
+"fill_zeros_like", {{"X", {prefix}}}, {{"Y", {new_name}}}, {}));
+pending_fill_zeros_ops.push_back(std::move(fill_zeros_op));
}
}
-for (const std::string& out_name : desc->OutputArgumentName()) {
+for (const std::string& out_name : desc->OutputArgumentNames()) {
if (no_grad_vars.count(out_name)) {
desc->Rename(out_name, kEmptyVarName);
}
}
}
-grad_op_descs.insert(std::begin(grad_op_descs),
-std::begin(pending_fill_zeros_ops),
-std::end(pending_fill_zeros_ops));
+for (auto& p : pending_fill_zeros_ops) {
+grad_op_descs.push_back(std::move(p));
+}
-// TODO (fengjiayi): RNN op
+// TODO(fengjiayi): RNN op
return grad_op_descs;
}
-void AppendBackwardOpDescs(
-BlockDescBind& block_desc,
-const std::unordered_set<std::string>& no_grad_vars) {
+void AppendBackwardOpDescs(BlockDescBind& block_desc,
+std::unordered_set<std::string>& no_grad_vars) {
std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
size_t grad_desc_idx = 0;
-std::deque<std::unique_ptr<OpDescBind>> block_op_descs = block_desc.ops_;
+std::deque<std::unique_ptr<OpDescBind>>& block_op_descs = block_desc.ops_;
std::vector<std::unique_ptr<OpDescBind>> backward_descs;
for (auto it = block_op_descs.rbegin(); it != block_op_descs.rend(); ++it) {
std::vector<std::unique_ptr<OpDescBind>> op_grads =
MakeGradOpDescs(*it, no_grad_vars);
for (const auto& desc : op_grads) {
-for (const std::string& out_name : desc->OutputArugumentNames()) {
+for (const std::string& out_name : desc->OutputArgumentNames()) {
dup_out_ops[out_name].emplace_back(grad_desc_idx);
}
++grad_desc_idx;
}
-backward_descs.insert(backward_descs.end(), op_grads.begin(),
-op_grads.end());
+std::transform(
+op_grads.begin(), op_grads.end(), std::back_inserter(backward_descs),
+[](std::unique_ptr<OpDescBind>& ptr) { return std::move(ptr); });
}
// Check whether some variables are written more than once
std::list<std::pair<size_t, std::unique_ptr<OpDescBind>>> pending_sum_ops;
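Both replacements in this hunk (the explicit move loop over `pending_fill_zeros_ops` and the `std::transform` with a move lambda) work around the same limitation: a range `insert` copies its source elements, and `unique_ptr` cannot be copied. A standard alternative worth noting, sketched here with stand-in types, is `std::make_move_iterator`, which turns the same range insert into element-wise moves:

```cpp
// Sketch with stand-in types: wrapping the source iterators in
// std::make_move_iterator makes the range insert move each unique_ptr
// instead of copying it, so no explicit loop or transform is needed.
#include <iterator>
#include <list>
#include <memory>
#include <vector>

struct Op {};

int main() {
  std::list<std::unique_ptr<Op>> src;
  src.push_back(std::unique_ptr<Op>(new Op));

  std::vector<std::unique_ptr<Op>> dst;
  dst.insert(dst.end(), std::make_move_iterator(src.begin()),
             std::make_move_iterator(src.end()));
  // src still has one node, but its unique_ptr is now null.
  return dst.size() == 1 ? 0 : 1;
}
```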
......@@ -330,9 +331,9 @@ void AppendBackwardOpDescs(
backward_descs[dup_op[i]]->Rename(out_name, new_name);
sum_op_inputs.emplace_back(new_name);
}
-OpDescBind* sum_op = new OpDescBind("sum", {{"X", sum_op_inputs}},
-{{"Out", {out_name}}}, {});
-pending_sum_ops.push_back({dup_op.back(), {sum_op}});
+std::unique_ptr<OpDescBind> sum_op(new OpDescBind(
+"sum", {{"X", sum_op_inputs}}, {{"Out", {out_name}}}, {}));
+pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)});
}
}
pending_sum_ops.sort(
......@@ -345,8 +346,9 @@ void AppendBackwardOpDescs(
std::move(p.second));
}
// Append backward_descs to BlockDescBind::ops_
-block_op_descs.insert(std::end(block_op_descs), std::begin(backward_descs),
-std::end(backward_descs));
+for (std::unique_ptr<OpDescBind>& ptr : backward_descs) {
+block_op_descs.push_back(std::move(ptr));
+}
return;
}
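Two fixes above share one detail: `BlockDescBind::ops_` is a `std::deque<std::unique_ptr<OpDescBind>>`, so initializing a new deque from it would copy every element and fail to compile, while binding a reference compiles and also makes the appended backward ops land in the block itself. A sketch with stand-in types:

```cpp
// Sketch with stand-in types: a reference aliases the owning container, so
// it avoids the non-compiling element-wise copy and writes through to the
// original block.
#include <deque>
#include <memory>

struct Op {};
struct Block {
  std::deque<std::unique_ptr<Op>> ops_;
};

int main() {
  Block block;
  // std::deque<std::unique_ptr<Op>> ops = block.ops_;  // error: deleted copy
  std::deque<std::unique_ptr<Op>>& ops = block.ops_;    // OK: no copy at all
  ops.push_back(std::unique_ptr<Op>(new Op));
  return block.ops_.size() == 1 ? 0 : 1;  // push is visible via the block
}
```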
......
......@@ -32,9 +32,8 @@ class ProgramDescBind;
class BlockDescBind {
public:
-friend void AppendBackwardOps(
-BlockDescBind &block_desc,
-const std::unordered_set<std::string> &no_grad_vars);
+friend void AppendBackwardOpDescs(
+BlockDescBind &block_desc, std::unordered_set<std::string> &no_grad_vars);
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) {}
......
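The header-side change mirrors the renamed function: a friend declaration grants access only to a function whose name and parameter types match exactly, so the stale `AppendBackwardOps(..., const std::unordered_set<...>&)` friend would leave the new `AppendBackwardOpDescs` definition without access to the private `ops_`. A minimal sketch with simplified names, not the real class:

```cpp
// Sketch with simplified names: the friend declaration must match the
// definition exactly (name and parameter types, including const-ness),
// or the function loses access to private members.
#include <string>
#include <unordered_set>

class Block;
void AppendOps(Block& block, std::unordered_set<std::string>& names);

class Block {
 public:
  // Matches the declaration above; a stale friend for a differently named
  // or differently typed function would grant nothing.
  friend void AppendOps(Block& block, std::unordered_set<std::string>& names);

 private:
  int op_count_ = 0;
};

void AppendOps(Block& block, std::unordered_set<std::string>& names) {
  block.op_count_ = static_cast<int>(names.size());  // OK only as a friend
}

int main() {
  Block b;
  std::unordered_set<std::string> names{"x"};
  AppendOps(b, names);
  return 0;
}
```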