提交 7088654a 编写于 作者: D dongzhihong

"add duplicate"

上级 a0669ead
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
limitations under the License. */ limitations under the License. */
#include "paddle/framework/backward.h" #include "paddle/framework/backward.h"
#include <list>
#include "paddle/framework/net.h" #include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
...@@ -71,7 +72,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl( ...@@ -71,7 +72,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
return EmptyOp(); return EmptyOp();
} }
// auto* net = new NetOp(); auto* net = new NetOp();
if (forwardOp.IsNetOp()) { if (forwardOp.IsNetOp()) {
//! TODO(dzh) //! TODO(dzh)
...@@ -93,29 +94,32 @@ static std::shared_ptr<OperatorBase> BackwardImpl( ...@@ -93,29 +94,32 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
} }
// unique the duplicate name // unique the duplicate name
auto uid = uniq_id++; auto uid = uniq_id++;
std::unordered_map<size_t, OperatorBase> insert_postion; // TODO(dzh): more comment
typedef std::pair<size_t, std::shared_ptr<OperatorBase>> Pos;
std::list<Pos> insert_postion;
for (auto& dup_output_op : dup_output_ops) { for (auto& dup_output_op : dup_output_ops) {
std::string& name = dup_output_op.first; const std::string& name = dup_output_op.first;
auto& dup_op = dup_output_op.second; auto& dup_op = dup_output_op.second;
if (dup_op.size() == 1) continue; if (dup_op.size() == 1) continue;
std::vector<std::string> dup_outputs; std::vector<std::string> dup_outputs;
for (size_t i = 0; i < dup_op.size(); ++i) { for (size_t i = 0; i < dup_op.size(); ++i) {
auto op_offset = dup_op[i]; auto op_offset = dup_op[i];
net->ops_[op_offset].Rename( dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" +
name, std::to_string(i));
name + "@RENAME@" + std::to_string(uid) + "@" + std::to_string(i)); net->ops_[op_offset]->Rename(name, dup_outputs.back());
} }
insert_postion[op_offset] = insert_postion.push_back(
OpRegistry::CreateOp("Add", {}, {dup_op->inputs_}, {}); {dup_op.back(),
net->AddOp("Add"); OpRegistry::CreateOp(
net->AddOp(); "Add", {dup_outputs}, {name},
// process shared variable {{"input_format",
// while(dup_op.size()) { std::vector<int>{0, (int)dup_outputs.size()}}})});
// }
// AddOp(OpRegistry::CreateOp("generic_add", {dup_outputs}, insert_postion.sort(
// {dup_op->inputs_}, {})); [](const Pos& l, const Pos& r) { return l.first > r.first; });
//} for (auto& pos : insert_postion) {
net->InsertOp(pos.first, pos.second);
} }
} else { } else {
......
...@@ -215,7 +215,7 @@ TEST(Backward, net_input_of_network_not_need_grad) { ...@@ -215,7 +215,7 @@ TEST(Backward, net_input_of_network_not_need_grad) {
ASSERT_EQ(all_output.find("X" + f::OperatorBase::GRAD_VAR_SUFFIX()), ASSERT_EQ(all_output.find("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
all_output.end()); all_output.end());
ASSERT_EQ(2, bwd_net->ops_.size()); ASSERT_EQ(2UL, bwd_net->ops_.size());
ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp()); ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get()); auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get());
ASSERT_EQ(3, first_fc_grad->ops_.size()); ASSERT_EQ(3, first_fc_grad->ops_.size());
......
...@@ -74,27 +74,5 @@ std::string NetOp::DebugString() const { ...@@ -74,27 +74,5 @@ std::string NetOp::DebugString() const {
bool NetOp::IsNetOp() const { return true; } bool NetOp::IsNetOp() const { return true; }
void NetOp::Rename(const std::unordered_map<
std::string, std::vector<OpIdentity>>& dup_output_ops,
size_t& uniq_id) {
for (auto& op : ops_) {
if (op->isNetOp()) {
op->Rename(dup_output_ops, uniq_id);
}
for (size_t i = 0; i < op->outputs_.size(); ++i) {
std::vector<std::string> dup_outputs;
if (op->outputs_[i] ==) {
op->outputs_[i] += std::to_string(uniq_id++);
dup_outputs.push_back(op->outputs_[i]);
}
// add duplicate output together. replace with AddOp
if (dup_outputs.size() >= 2) {
AddOp(OpRegistry::CreateOp("generic_add", {dup_outputs}, {op->inputs_},
{}));
}
}
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -49,11 +49,6 @@ class NetOp : public OperatorBase { ...@@ -49,11 +49,6 @@ class NetOp : public OperatorBase {
} }
} }
/**
* @brief rename duplicated output gradient name in Net
*/
bool Rename(size_t& uniq_id);
/** /**
* @brief Run the network. * @brief Run the network.
* *
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册