提交 b9f2bb37 编写于 作者: D dongzhihong

"wait add generic"

上级 f9fab14c
...@@ -52,6 +52,11 @@ static std::shared_ptr<OperatorBase> EmptyOp() { ...@@ -52,6 +52,11 @@ static std::shared_ptr<OperatorBase> EmptyOp() {
static std::shared_ptr<OperatorBase> BackwardImpl( static std::shared_ptr<OperatorBase> BackwardImpl(
const OperatorBase& forwardOp, const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) { std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
// struct OpIdentity {
// size_t local_op_id;
// size_t op_output_offset;
// };
if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(), if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
no_grad_names)) { no_grad_names)) {
return EmptyOp(); return EmptyOp();
...@@ -66,44 +71,51 @@ static std::shared_ptr<OperatorBase> BackwardImpl( ...@@ -66,44 +71,51 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
return EmptyOp(); return EmptyOp();
} }
auto* net = new NetOp(); // auto* net = new NetOp();
if (forwardOp.IsNetOp()) { if (forwardOp.IsNetOp()) {
//! TODO(dzh) //! TODO(dzh)
std::unordered_map<std::string, int> dup_output; std::unordered_map<std::string /*var name*/,
std::unordered_map<std::string, std::vector<int>> dup_output_ops; std::vector<size_t> /*op offs et*/>
// const unsigned uniq_id_local = uniq_id; dup_output_ops;
int op_id_offset = 0; size_t local_op_id = 0;
// Because it is a net op, it can static_cast. // Because it is a net op, it can static_cast.
auto& forwardNet = static_cast<const NetOp&>(forwardOp); auto& forwardNet = static_cast<const NetOp&>(forwardOp);
// travesal subnet/op
for (auto& fwd : forwardNet.ops_) { for (auto& fwd : forwardNet.ops_) {
auto bwd = Backward(*fwd, no_grad_names); auto bwd = Backward(*fwd, no_grad_names);
net->AddOp(bwd); net->AddOp(bwd);
for (size_t i = 0; i < bwd->outputs_.size(); ++i) { for (size_t i = 0; i < bwd->outputs_.size(); ++i) {
bwd->outputs_[i] += OperatorBase::EMPTY_VAR_NAME(); dup_output_ops[bwd->outputs_[i]].emplace_back(local_op_id);
if (dup_output.find(bwd->inputs_[i]) == dup_output.end()) {
dup_output[bwd->inputs_[i]] = 1;
dup_output_ops[bwd->inputs_[i]] = std::vector<int>{op_id_offset++};
} else {
dup_output[bwd->inputs_[i]]++;
dup_output_ops[bwd->inputs_[i]].emplace_back(op_id_offset++);
}
} }
local_op_id++;
} }
for (auto dup : dup_output) { // unique the duplicate name
if (dup.second == 1) continue; auto uid = uniq_id++;
auto op_ids = dup_output_ops.at(dup.first); std::unordered_map<size_t, OperatorBase> insert_postion;
for (auto& op_id : op_ids) { for (auto& dup_output_op : dup_output_ops) {
auto& op_ptr = net->ops_[op_id]; std::string& name = dup_output_op.first;
for (size_t i = 0; i < op_ptr->inputs_.size(); ++i) { auto& dup_op = dup_output_op.second;
if (op_ptr->inputs_[i] == dup.first) { if (dup_op.size() == 1) continue;
// unique the duplicate name std::vector<std::string> dup_outputs;
op_ptr->inputs_[i] += std::to_string(uniq_id++);
// TODO(dzh): need a generic add op here for (size_t i = 0; i < dup_op.size(); ++i) {
} auto op_offset = dup_op[i];
} net->ops_[op_offset].Rename(
name,
name + "@RENAME@" + std::to_string(uid) + "@" + std::to_string(i));
} }
insert_postion[op_offset] =
OpRegistry::CreateOp("Add", {}, {dup_op->inputs_}, {});
net->AddOp("Add");
net->AddOp();
// process shared variable
// while(dup_op.size()) {
//
// AddOp(OpRegistry::CreateOp("generic_add", {dup_outputs},
// {dup_op->inputs_}, {}));
//}
} }
} else { } else {
......
...@@ -74,5 +74,27 @@ std::string NetOp::DebugString() const { ...@@ -74,5 +74,27 @@ std::string NetOp::DebugString() const {
bool NetOp::IsNetOp() const { return true; } bool NetOp::IsNetOp() const { return true; }
void NetOp::Rename(const std::unordered_map<
std::string, std::vector<OpIdentity>>& dup_output_ops,
size_t& uniq_id) {
for (auto& op : ops_) {
if (op->isNetOp()) {
op->Rename(dup_output_ops, uniq_id);
}
for (size_t i = 0; i < op->outputs_.size(); ++i) {
std::vector<std::string> dup_outputs;
if (op->outputs_[i] ==) {
op->outputs_[i] += std::to_string(uniq_id++);
dup_outputs.push_back(op->outputs_[i]);
}
// add duplicate output together. replace with AddOp
if (dup_outputs.size() >= 2) {
AddOp(OpRegistry::CreateOp("generic_add", {dup_outputs}, {op->inputs_},
{}));
}
}
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -49,6 +49,11 @@ class NetOp : public OperatorBase { ...@@ -49,6 +49,11 @@ class NetOp : public OperatorBase {
} }
} }
/**
* @brief rename duplicated output gradient name in Net
*/
bool Rename(size_t& uniq_id);
/** /**
* @brief Run the network. * @brief Run the network.
* *
...@@ -88,5 +93,9 @@ class NetOp : public OperatorBase { ...@@ -88,5 +93,9 @@ class NetOp : public OperatorBase {
} }
}; };
/**
* @brief Identify operator in local Net. used in backward
*/
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -99,5 +99,11 @@ std::string OperatorBase::DebugString() const { ...@@ -99,5 +99,11 @@ std::string OperatorBase::DebugString() const {
return ss.str(); return ss.str();
} }
void OperatorBase::Rename(const std::string& old_name,
const std::string& new_name) {
std::replace(inputs_.begin(), inputs_.end(), old_name, new_name);
std::replace(outputs_.begin(), outputs_.end(), old_name, new_name);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <algorithm>
#include <boost/variant.hpp> #include <boost/variant.hpp>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -95,6 +96,9 @@ class OperatorBase { ...@@ -95,6 +96,9 @@ class OperatorBase {
virtual bool IsNetOp() const { return false; } virtual bool IsNetOp() const { return false; }
/// rename inputs outputs name
void Rename(const std::string& old_name, const std::string& new_name);
//! Get a input with argument's name described in `op_proto` //! Get a input with argument's name described in `op_proto`
const std::string& Input(const std::string& name) const; const std::string& Input(const std::string& name) const;
//! Get a input which has multiple variables. //! Get a input which has multiple variables.
...@@ -108,7 +112,13 @@ class OperatorBase { ...@@ -108,7 +112,13 @@ class OperatorBase {
public: public:
std::string type_; std::string type_;
// NOTE: in case of OpGrad, inputs_ contains:
// I (Inputs)
// O (Outputs)
// OG (Output Gradients)
std::vector<std::string> inputs_; std::vector<std::string> inputs_;
// NOTE: in case of OpGrad, outputs_ contains
// IG (Inputs Gradients)
std::vector<std::string> outputs_; std::vector<std::string> outputs_;
AttributeMap attrs_; AttributeMap attrs_;
// store the arguments' offset described in op_desc. // store the arguments' offset described in op_desc.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册