From b9f2bb3747512f8bd0f5f0a7e024ff329477aabc Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Thu, 27 Jul 2017 16:44:06 +0800 Subject: [PATCH] "wait add generic" --- paddle/framework/backward.cc | 62 +++++++++++++++++++++--------------- paddle/framework/net.cc | 22 +++++++++++++ paddle/framework/net.h | 9 ++++++ paddle/framework/operator.cc | 6 ++++ paddle/framework/operator.h | 10 ++++++ 5 files changed, 84 insertions(+), 25 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index dae457f8585..8538ad9f0a9 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -52,6 +52,11 @@ static std::shared_ptr EmptyOp() { static std::shared_ptr BackwardImpl( const OperatorBase& forwardOp, std::unordered_set& no_grad_names, size_t& uniq_id) { + // struct OpIdentity { + // size_t local_op_id; + // size_t op_output_offset; + // }; + if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(), no_grad_names)) { return EmptyOp(); @@ -66,44 +71,51 @@ static std::shared_ptr BackwardImpl( return EmptyOp(); } - auto* net = new NetOp(); + // auto* net = new NetOp(); if (forwardOp.IsNetOp()) { //! TODO(dzh) - std::unordered_map dup_output; - std::unordered_map> dup_output_ops; - // const unsigned uniq_id_local = uniq_id; - int op_id_offset = 0; + std::unordered_map /*op offs et*/> + dup_output_ops; + size_t local_op_id = 0; // Because it is a net op, it can static_cast. auto& forwardNet = static_cast(forwardOp); + // travesal subnet/op for (auto& fwd : forwardNet.ops_) { auto bwd = Backward(*fwd, no_grad_names); net->AddOp(bwd); for (size_t i = 0; i < bwd->outputs_.size(); ++i) { - bwd->outputs_[i] += OperatorBase::EMPTY_VAR_NAME(); - if (dup_output.find(bwd->inputs_[i]) == dup_output.end()) { - dup_output[bwd->inputs_[i]] = 1; - dup_output_ops[bwd->inputs_[i]] = std::vector{op_id_offset++}; - } else { - dup_output[bwd->inputs_[i]]++; - dup_output_ops[bwd->inputs_[i]].emplace_back(op_id_offset++); - } + dup_output_ops[bwd->outputs_[i]].emplace_back(local_op_id); } + local_op_id++; } - for (auto dup : dup_output) { - if (dup.second == 1) continue; - auto op_ids = dup_output_ops.at(dup.first); - for (auto& op_id : op_ids) { - auto& op_ptr = net->ops_[op_id]; - for (size_t i = 0; i < op_ptr->inputs_.size(); ++i) { - if (op_ptr->inputs_[i] == dup.first) { - // unique the duplicate name - op_ptr->inputs_[i] += std::to_string(uniq_id++); - // TODO(dzh): need a generic add op here - } - } + // unique the duplicate name + auto uid = uniq_id++; + std::unordered_map insert_postion; + for (auto& dup_output_op : dup_output_ops) { + std::string& name = dup_output_op.first; + auto& dup_op = dup_output_op.second; + if (dup_op.size() == 1) continue; + std::vector dup_outputs; + + for (size_t i = 0; i < dup_op.size(); ++i) { + auto op_offset = dup_op[i]; + net->ops_[op_offset].Rename( + name, + name + "@RENAME@" + std::to_string(uid) + "@" + std::to_string(i)); } + insert_postion[op_offset] = + OpRegistry::CreateOp("Add", {}, {dup_op->inputs_}, {}); + net->AddOp("Add"); + net->AddOp(); + // process shared variable + // while(dup_op.size()) { + // + // AddOp(OpRegistry::CreateOp("generic_add", {dup_outputs}, + // {dup_op->inputs_}, {})); + //} } } else { diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index 2cd378c6b21..403d96a22de 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -74,5 +74,27 @@ std::string NetOp::DebugString() const { bool NetOp::IsNetOp() const { return true; } +void NetOp::Rename(const std::unordered_map< + std::string, std::vector>& dup_output_ops, + size_t& uniq_id) { + for (auto& op : ops_) { + if (op->isNetOp()) { + op->Rename(dup_output_ops, uniq_id); + } + for (size_t i = 0; i < op->outputs_.size(); ++i) { + std::vector dup_outputs; + if (op->outputs_[i] ==) { + op->outputs_[i] += std::to_string(uniq_id++); + dup_outputs.push_back(op->outputs_[i]); + } + // add duplicate output together. replace with AddOp + if (dup_outputs.size() >= 2) { + AddOp(OpRegistry::CreateOp("generic_add", {dup_outputs}, {op->inputs_}, + {})); + } + } + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/net.h b/paddle/framework/net.h index 089c1355951..fa8aaf654ce 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -49,6 +49,11 @@ class NetOp : public OperatorBase { } } + /** + * @brief rename duplicated output gradient name in Net + */ + bool Rename(size_t& uniq_id); + /** * @brief Run the network. * @@ -88,5 +93,9 @@ class NetOp : public OperatorBase { } }; +/** + * @brief Identify operator in local Net. used in backward + */ + } // namespace framework } // namespace paddle diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 1e57e9a20f3..c49b2288d61 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -99,5 +99,11 @@ std::string OperatorBase::DebugString() const { return ss.str(); } +void OperatorBase::Rename(const std::string& old_name, + const std::string& new_name) { + std::replace(inputs_.begin(), inputs_.end(), old_name, new_name); + std::replace(outputs_.begin(), outputs_.end(), old_name, new_name); +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index c2cd21a0806..f98359de124 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include @@ -95,6 +96,9 @@ class OperatorBase { virtual bool IsNetOp() const { return false; } + /// rename inputs outputs name + void Rename(const std::string& old_name, const std::string& new_name); + //! Get a input with argument's name described in `op_proto` const std::string& Input(const std::string& name) const; //! Get a input which has multiple variables. @@ -108,7 +112,13 @@ class OperatorBase { public: std::string type_; + // NOTE: in case of OpGrad, inputs_ contains: + // I (Inputs) + // O (Outputs) + // OG (Output Gradients) std::vector inputs_; + // NOTE: in case of OpGrad, outputs_ contains + // IG (Inputs Gradients) std::vector outputs_; AttributeMap attrs_; // store the arguments' offset described in op_desc. -- GitLab