Commit b9f2bb37 authored by D dongzhihong

"wait add generic"

Parent f9fab14c
......@@ -52,6 +52,11 @@ static std::shared_ptr<OperatorBase> EmptyOp() {
static std::shared_ptr<OperatorBase> BackwardImpl(
const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
// struct OpIdentity {
// size_t local_op_id;
// size_t op_output_offset;
// };
if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
no_grad_names)) {
return EmptyOp();
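The guard above short-circuits backward construction: when every relevant name (with the gradient suffix appended) is in the no-gradient set, an empty op is returned instead of a real backward net. A sketch of what an AllInSet-style predicate could look like, reconstructed from the call site; this helper is an assumption, not the actual Paddle implementation:

```cpp
#include <string>
#include <unordered_set>
#include <vector>

// Assumed shape of the AllInSet predicate used at the call site above:
// true iff every `name + suffix` is present in the given set.
bool AllInSet(const std::vector<std::string>& names, const std::string& suffix,
              const std::unordered_set<std::string>& set) {
  for (const auto& name : names) {
    if (set.find(name + suffix) == set.end()) return false;
  }
  return true;
}
```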
......@@ -66,44 +71,51 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
return EmptyOp();
}
auto* net = new NetOp();
// auto* net = new NetOp();
if (forwardOp.IsNetOp()) {
//! TODO(dzh)
std::unordered_map<std::string, int> dup_output;
std::unordered_map<std::string, std::vector<int>> dup_output_ops;
// const unsigned uniq_id_local = uniq_id;
int op_id_offset = 0;
std::unordered_map<std::string /*var name*/,
std::vector<size_t> /*op offset*/>
dup_output_ops;
size_t local_op_id = 0;
// Because it is a net op, it is safe to static_cast here.
auto& forwardNet = static_cast<const NetOp&>(forwardOp);
// traverse the sub-net's ops
for (auto& fwd : forwardNet.ops_) {
auto bwd = Backward(*fwd, no_grad_names);
net->AddOp(bwd);
for (size_t i = 0; i < bwd->outputs_.size(); ++i) {
bwd->outputs_[i] += OperatorBase::EMPTY_VAR_NAME();
if (dup_output.find(bwd->inputs_[i]) == dup_output.end()) {
dup_output[bwd->inputs_[i]] = 1;
dup_output_ops[bwd->inputs_[i]] = std::vector<int>{op_id_offset++};
} else {
dup_output[bwd->inputs_[i]]++;
dup_output_ops[bwd->inputs_[i]].emplace_back(op_id_offset++);
}
dup_output_ops[bwd->outputs_[i]].emplace_back(local_op_id);
}
local_op_id++;
}
for (auto dup : dup_output) {
if (dup.second == 1) continue;
auto op_ids = dup_output_ops.at(dup.first);
for (auto& op_id : op_ids) {
auto& op_ptr = net->ops_[op_id];
for (size_t i = 0; i < op_ptr->inputs_.size(); ++i) {
if (op_ptr->inputs_[i] == dup.first) {
// unique the duplicate name
op_ptr->inputs_[i] += std::to_string(uniq_id++);
// TODO(dzh): need a generic add op here
}
}
// unique the duplicate name
auto uid = uniq_id++;
std::unordered_map<size_t, std::shared_ptr<OperatorBase>> insert_position;
for (auto& dup_output_op : dup_output_ops) {
const std::string& name = dup_output_op.first;
auto& dup_op = dup_output_op.second;
if (dup_op.size() == 1) continue;
std::vector<std::string> dup_outputs;
for (size_t i = 0; i < dup_op.size(); ++i) {
auto op_offset = dup_op[i];
dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" +
std::to_string(i));
net->ops_[op_offset]->Rename(name, dup_outputs.back());
}
// sum the renamed duplicates back into `name`; the Add op goes after the
// last duplicated writer once the pending generic add support lands
insert_position[dup_op.back()] =
OpRegistry::CreateOp("Add", {dup_outputs}, {name}, {});
// process shared variable
// while(dup_op.size()) {
//
// AddOp(OpRegistry::CreateOp("generic_add", {dup_outputs},
// {dup_op->inputs_}, {}));
//}
}
} else {
......
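The block above implements the duplicate-gradient strategy this commit is working toward: when several backward ops write the same gradient variable, each writer is renamed to `name@RENAME@<uid>@<i>`, and a generic add op then sums the renamed copies back into the original name. A minimal standalone sketch of that naming scheme; the helper and variable names here are illustrative, not Paddle APIs:

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Illustrative helper: builds the unique name a duplicated writer is
// renamed to, following the name@RENAME@<uid>@<i> scheme used above.
std::string MakeRenamed(const std::string& name, size_t uid, size_t i) {
  return name + "@RENAME@" + std::to_string(uid) + "@" + std::to_string(i);
}

int main() {
  // Suppose the ops at offsets 0 and 3 both write "x@GRAD".
  std::vector<size_t> writers = {0, 3};
  size_t uid = 0;
  std::vector<std::string> dup_outputs;
  for (size_t i = 0; i < writers.size(); ++i) {
    dup_outputs.push_back(MakeRenamed("x@GRAD", uid, i));
    std::cout << "op " << writers[i] << " now writes " << dup_outputs[i]
              << "\n";
  }
  // An Add op inserted after the last writer would take dup_outputs as its
  // inputs and write the sum back to the original name "x@GRAD".
  return 0;
}
```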
......@@ -74,5 +74,27 @@ std::string NetOp::DebugString() const {
bool NetOp::IsNetOp() const { return true; }
void NetOp::Rename(const std::unordered_map<
std::string, std::vector<OpIdentity>>& dup_output_ops,
size_t& uniq_id) {
for (auto& op : ops_) {
if (op->IsNetOp()) {
// recurse into the sub-net; ops_ stores OperatorBase pointers, so cast
// back to NetOp first
static_cast<NetOp*>(op.get())->Rename(dup_output_ops, uniq_id);
}
std::vector<std::string> dup_outputs;
for (size_t i = 0; i < op->outputs_.size(); ++i) {
if (dup_output_ops.find(op->outputs_[i]) != dup_output_ops.end()) {
op->outputs_[i] += std::to_string(uniq_id++);
dup_outputs.push_back(op->outputs_[i]);
}
}
// add the duplicated outputs together. replace with AddOp
if (dup_outputs.size() >= 2) {
AddOp(OpRegistry::CreateOp("generic_add", {dup_outputs}, {op->inputs_},
{}));
}
}
}
} // namespace framework
} // namespace paddle
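NetOp::Rename above descends recursively: a net op forwards the pass into its sub-ops, while a leaf op gets its duplicated outputs uniquified with the shared counter. A self-contained sketch of that traversal pattern, using a stand-in Node type rather than the real OperatorBase hierarchy:

```cpp
#include <cstddef>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Stand-in for the OperatorBase/NetOp hierarchy; illustrative only.
struct Node {
  std::string output;
  bool is_net = false;
  std::vector<std::unique_ptr<Node>> children;
};

// Recursively walk the tree; leaf outputs get a unique numeric suffix,
// mirroring the uniq_id counter threaded through NetOp::Rename.
void Rename(Node& node, size_t& uniq_id) {
  if (node.is_net) {
    for (auto& child : node.children) Rename(*child, uniq_id);
    return;
  }
  node.output += std::to_string(uniq_id++);
}

int main() {
  Node net;
  net.is_net = true;
  for (int i = 0; i < 2; ++i) {
    auto leaf = std::make_unique<Node>();
    leaf->output = "x@GRAD";
    net.children.push_back(std::move(leaf));
  }
  size_t uniq_id = 0;
  Rename(net, uniq_id);
  for (auto& c : net.children) std::cout << c->output << "\n";  // x@GRAD0, x@GRAD1
  return 0;
}
```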
......@@ -49,6 +49,11 @@ class NetOp : public OperatorBase {
}
}
/**
* @brief rename duplicated output gradient names in the Net
*/
void Rename(const std::unordered_map<std::string, std::vector<OpIdentity>>&
dup_output_ops,
size_t& uniq_id);
/**
* @brief Run the network.
*
......@@ -88,5 +93,9 @@ class NetOp : public OperatorBase {
}
};
/**
* @brief Identify an operator inside a local Net; used in backward.
*/
struct OpIdentity {
size_t local_op_id;
size_t op_output_offset;
};
} // namespace framework
} // namespace paddle
......@@ -99,5 +99,11 @@ std::string OperatorBase::DebugString() const {
return ss.str();
}
void OperatorBase::Rename(const std::string& old_name,
const std::string& new_name) {
std::replace(inputs_.begin(), inputs_.end(), old_name, new_name);
std::replace(outputs_.begin(), outputs_.end(), old_name, new_name);
}
} // namespace framework
} // namespace paddle
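OperatorBase::Rename relies on std::replace over the flat name vectors, so every occurrence of the old name is rewritten in one pass. A quick standalone check of that behavior; the driver code here is hypothetical:

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // A variable can appear several times in inputs_; std::replace rewrites
  // every occurrence, which is exactly what Rename needs.
  std::vector<std::string> inputs = {"x", "w", "x"};
  std::replace(inputs.begin(), inputs.end(), std::string("x"),
               std::string("x@RENAME@0@0"));
  for (const auto& in : inputs) std::cout << in << "\n";
  // prints: x@RENAME@0@0, w, x@RENAME@0@0
  return 0;
}
```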
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <boost/variant.hpp>
#include <string>
#include <unordered_map>
......@@ -95,6 +96,9 @@ class OperatorBase {
virtual bool IsNetOp() const { return false; }
/// rename the input/output variable names
void Rename(const std::string& old_name, const std::string& new_name);
//! Get an input with the argument name described in `op_proto`
const std::string& Input(const std::string& name) const;
//! Get an input which has multiple variables.
......@@ -108,7 +112,13 @@ class OperatorBase {
public:
std::string type_;
// NOTE: in case of OpGrad, inputs_ contains:
// I (Inputs)
// O (Outputs)
// OG (Output Gradients)
std::vector<std::string> inputs_;
// NOTE: in case of OpGrad, outputs_ contains
// IG (Inputs Gradients)
std::vector<std::string> outputs_;
AttributeMap attrs_;
// store the arguments' offset described in op_desc.
......
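The NOTE added above pins down the gradient op's argument layout: inputs_ packs the forward inputs (I), forward outputs (O), and output gradients (OG), while outputs_ holds the input gradients (IG). An illustration under assumed names, for a hypothetical mul op with inputs {X, W} and output {Y}:

```cpp
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Hypothetical mul op: Y = X * W. Its grad op's argument lists follow
  // the I / O / OG / IG layout described in the NOTE.
  std::vector<std::string> grad_inputs = {
      "X", "W",   // I:  forward inputs
      "Y",        // O:  forward outputs
      "Y@GRAD"};  // OG: gradients of the forward outputs
  std::vector<std::string> grad_outputs = {"X@GRAD", "W@GRAD"};  // IG
  for (const auto& v : grad_inputs) std::cout << "in:  " << v << "\n";
  for (const auto& v : grad_outputs) std::cout << "out: " << v << "\n";
  return 0;
}
```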