提交 e47770bd 编写于 作者: F fengjiayi

Update

上级 9935fdd3
...@@ -234,18 +234,17 @@ static bool AllGradInSet(const std::vector<std::string>& names, ...@@ -234,18 +234,17 @@ static bool AllGradInSet(const std::vector<std::string>& names,
return true; return true;
} }
std::vector<OpDescBind> CreatBackwardOps( std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs(
const std::unique_ptr<OpDescBind>& op_desc_ptr, const std::unique_ptr<OpDescBind>& op_desc,
unordered_map<std::string>& no_grad_vars) { unordered_set<std::string>& no_grad_vars) {
const OpDescBind& op_desc = *op_desc_ptr; std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
std::vector<OpDescBind> grad_op_descs;
// All input gradients of forwarding operator do not need to calculate. // All input gradients of forwarding operator do not need to calculate.
if (AllGradInSet(op_desc_.InputArgumentNames(), kGradVarSuffix, if (AllGradInSet(op_desc->InputArgumentNames(), kGradVarSuffix,
no_grad_vars)) { no_grad_vars)) {
return grad_op_descs; // empty vector return grad_op_descs; // empty vector
} }
// All output gradients of forwarding operator do not need to calculate. // All output gradients of forwarding operator do not need to calculate.
const std::vector<std::string>& outputs = op_desc_.OutputArugumentNames(); const std::vector<std::string>& outputs = op_desc->OutputArugumentNames();
if (AllGradInSet(outputs, kGradVarSuffix, no_grad_vars)) { if (AllGradInSet(outputs, kGradVarSuffix, no_grad_vars)) {
for (const std::string& name : outputs) { for (const std::string& name : outputs) {
no_grad_vars.insert(GradVarName(name)); no_grad_vars.insert(GradVarName(name));
...@@ -255,50 +254,54 @@ std::vector<OpDescBind> CreatBackwardOps( ...@@ -255,50 +254,54 @@ std::vector<OpDescBind> CreatBackwardOps(
grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc); grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc);
std::vector<OpDescBind> fill_zeros_ops; std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
for (OpDescBind& desc : grad_op_descs) { for (auto& desc : grad_op_descs) {
for (const std::string& in_name : desc.InputArgumentNames()) { for (const std::string& in_name : desc->InputArgumentNames()) {
if (no_grad_vars.count(in_name)) { if (no_grad_vars.count(in_name)) {
std::string prefix = in_name.substr( std::string prefix = in_name.substr(
0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1); 0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
std::string new_name = prefix + kZeroVarSuffix; std::string new_name = prefix + kZeroVarSuffix;
desc.Rename(in_name, new_name); desc->Rename(in_name, new_name);
OpDescBind op_desc_bind( OpDescBind* fill_zeros_op = new OpDescBind(
{"fill_zeros_like", {{"X", {prefix}}}, {{"Y", {new_name}}}, {}}); "fill_zeros_like", {{"X", {prefix}}}, {{"Y", {new_name}}}, {});
fill_zeros_ops.push_back(op_desc_bind); pending_fill_zeros_ops.push_back({fill_zeros_op});
} }
} }
for (const std::string& out_name : desc.OutputName()) { for (const std::string& out_name : desc->OutputArgumentName()) {
if (no_grad_vars.count(out_name)) { if (no_grad_vars.count(out_name)) {
desc.Rename(out_name, kEmptyVarName); desc->Rename(out_name, kEmptyVarName);
} }
} }
} }
grad_op_descs.insert(grad_op_descs.begin(), fill_zeros_ops.begin(), grad_op_descs.insert(std::begin(grad_op_descs),
fill_zeros_ops.end()); std::begin(pending_fill_zeros_ops),
std::end(pending_fill_zeros_ops));
// TODO (fengjiayi): RNN op // TODO (fengjiayi): RNN op
return grad_op_descs; return grad_op_descs;
} }
void AppendBackwardOps(BlockDescBind& block_desc, void AppendBackwardOpDescs(
const std::unordered_set<std::string>& no_grad_vars) { BlockDescBind& block_desc,
const std::unordered_set<std::string>& no_grad_vars) {
std::unordered_map<std::string, std::vector<size_t>> dup_out_ops; std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
size_t grad_desc_idx = 0; size_t grad_desc_idx = 0;
std::deque<std::unique_ptr<OpDescBind>> op_descs = block_desc.ops_; std::deque<std::unique_ptr<OpDescBind>> block_op_descs = block_desc.ops_;
std::vector<std::unique_ptr<OpDescBind>> grad_op_descs; std::vector<std::unique_ptr<OpDescBind>> backward_descs;
for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) { for (auto it = block_op_descs.rbegin(); it != block_op_descs.rend(); ++it) {
std::vector<OpDescBind> op_grads = CreatBackwardOps(*it, no_grad_vars); std::vector<std::unique_ptr<OpDescBind>> op_grads =
for (const OpDescBind& desc : op_grads) { MakeGradOpDescs(*it, no_grad_vars);
for (const std::string& out_name : desc.OutputArugumentNames()) { for (const auto& desc : op_grads) {
for (const std::string& out_name : desc->OutputArugumentNames()) {
dup_out_ops[out_name].emplace_back(grad_desc_idx); dup_out_ops[out_name].emplace_back(grad_desc_idx);
} }
++grad_desc_idx; ++grad_desc_idx;
} }
grad_op_descs.insert(grad_op_descs.end(), op_grads.begin(), op_grads.end()); backward_descs.insert(backward_descs.end(), op_grads.begin(),
op_grads.end());
} }
// Check whether some variables are written more than once // Check whether some variables are written more than once
std::list<std::pair<size_t, OpDescBind>> pending_sum_ops; std::list<std::pair<size_t, std::unique_ptr<OpDescBind>>> pending_sum_ops;
for (const auto& dup : dup_out_ops) { for (const auto& dup : dup_out_ops) {
const std::string& out_name = dup.first; const std::string& out_name = dup.first;
const std::vector<size_t> dup_op = dup.second; const std::vector<size_t> dup_op = dup.second;
...@@ -306,25 +309,27 @@ void AppendBackwardOps(BlockDescBind& block_desc, ...@@ -306,25 +309,27 @@ void AppendBackwardOps(BlockDescBind& block_desc,
std::vector<std::string> sum_op_inputs; std::vector<std::string> sum_op_inputs;
for (size_t i = 0; i < dup_op.size(); ++i) { for (size_t i = 0; i < dup_op.size(); ++i) {
std::string new_name = out_name + "@RENAME@" + std::to_string(i); std::string new_name = out_name + "@RENAME@" + std::to_string(i);
grad_op_descs[dup_op[i]].Rename(out_name, new_name); backward_descs[dup_op[i]]->Rename(out_name, new_name);
sum_op_inputs.emplace_back(new_name); sum_op_inputs.emplace_back(new_name);
} }
pending_sum_ops.push_back( OpDescBind* sum_op = new OpDescBind("sum", {{"X", sum_op_inputs}},
{dup_op.back(), {{"Out", {out_name}}}, {});
OpDescBind( pending_sum_ops.push_back({dup_op.back(), {sum_op}});
{"sum", {{"X", {sum_op_inputs}}}, {{"Out", {out_name}}}, {}})});
} }
} }
pending_sum_ops.sort( pending_sum_ops.sort(
[](const std::pair<size_t, OpDescBind>& a, [](const std::pair<size_t, std::unique_ptr<OpDescBind>>& a,
const std::pair<size_t, OpDescBind>& b) { return a.first > b.first; }); const std::pair<size_t, std::unique_ptr<OpDescBind>>& b) {
return a.first > b.first;
});
for (auto& p : pending_sum_ops) { for (auto& p : pending_sum_ops) {
grad_op_descs.insert(grad_op_descs.begin() + p.first + 1, backward_descs.insert(backward_descs.begin() + p.first + 1,
std::move(p.second)); std::move(p.second));
}
// Append grad_op_descs to BlockDescBind::ops_
for () {
} }
// Append backward_descs to BlockDescBind::ops_
block_op_descs.insert(std::end(block_op_descs), std::begin(backward_descs),
std::end(backward_descs));
return;
} }
} // namespace framework } // namespace framework
......
...@@ -24,7 +24,7 @@ extern std::unique_ptr<OperatorBase> Backward( ...@@ -24,7 +24,7 @@ extern std::unique_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp, const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars); const std::unordered_set<std::string>& no_grad_vars);
extern void AppendBackwardOps( extern void AppendBackwardOpDescs(
BlockDescBind& block_desc, BlockDescBind& block_desc,
const std::unordered_set<std::string>& no_grad_vars); const std::unordered_set<std::string>& no_grad_vars);
......
...@@ -49,7 +49,7 @@ std::vector<std::string> OpDescBind::InputNames() const { ...@@ -49,7 +49,7 @@ std::vector<std::string> OpDescBind::InputNames() const {
return retv; return retv;
} }
std::vector<std::string> InputArgumentNames() const { std::vector<std::string> OpDescBind::InputArgumentNames() const {
std::vector<std::string> retv; std::vector<std::string> retv;
for (auto &ipt : this->inputs_) { for (auto &ipt : this->inputs_) {
retv.insert(retv.end(), ipt.second.begin(), ipt.second.end()); retv.insert(retv.end(), ipt.second.begin(), ipt.second.end());
...@@ -80,7 +80,7 @@ std::vector<std::string> OpDescBind::OutputNames() const { ...@@ -80,7 +80,7 @@ std::vector<std::string> OpDescBind::OutputNames() const {
return retv; return retv;
} }
std::vector<std::string> OutputArgumentNames() const { std::vector<std::string> OpDescBind::OutputArgumentNames() const {
std::vector<std::string> retv; std::vector<std::string> retv;
for (auto &ipt : this->outputs_) { for (auto &ipt : this->outputs_) {
retv.insert(retv.end(), ipt.second.begin(), ipt.second.end()); retv.insert(retv.end(), ipt.second.begin(), ipt.second.end());
...@@ -137,12 +137,13 @@ const std::unordered_map<std::string, Attribute> &OpDescBind::GetAttrMap() ...@@ -137,12 +137,13 @@ const std::unordered_map<std::string, Attribute> &OpDescBind::GetAttrMap()
return attrs_; return attrs_;
} }
void Rename(const std::string &old_name, const std::string &new_name) { void OpDescBind::Rename(const std::string &old_name,
for (std : string &input : inputs_) { const std::string &new_name) {
for (auto &input : inputs_) {
std::replace(input.second.begin(), input.second.end(), old_name, new_name); std::replace(input.second.begin(), input.second.end(), old_name, new_name);
} }
for (std::string &output : outputs_) { for (auto &output : outputs_) {
std::repalce(output.second.begin(), output.second.end(), old_name, std::replace(output.second.begin(), output.second.end(), old_name,
new_name); new_name);
} }
need_update_ = true; need_update_ = true;
......
...@@ -57,7 +57,8 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateGradOp(const OperatorBase& op) { ...@@ -57,7 +57,8 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateGradOp(const OperatorBase& op) {
return std::unique_ptr<OperatorBase>(BuildGradOp(&op)); return std::unique_ptr<OperatorBase>(BuildGradOp(&op));
} }
static std::vector<OpDescBind> CreateGradOpDescs(const OpDescBind& op_desc) { static std::vector<std::unique_ptr<OpDescBind>> OpRegistry::CreateGradOpDescs(
const OpDescBind& op_desc) {
auto& info = OpInfoMap::Instance().Get(op_desc.Type()); auto& info = OpInfoMap::Instance().Get(op_desc.Type());
return info.grad_op_maker_(op_desc); return info.grad_op_maker_(op_desc);
} }
......
...@@ -69,7 +69,8 @@ class OpRegistry { ...@@ -69,7 +69,8 @@ class OpRegistry {
static std::unique_ptr<OperatorBase> CreateGradOp(const OperatorBase& op); static std::unique_ptr<OperatorBase> CreateGradOp(const OperatorBase& op);
static std::vector<OpDescBind> CreateGradOpDescs(const OpDescBind& op_desc); static std::vector<std::unique_ptr<OpDescBind>> CreateGradOpDescs(
const OpDescBind& op_desc);
}; };
class Registrar { class Registrar {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册