提交 7088654a 编写于 作者: D dongzhihong

"add duplicate"

上级 a0669ead
......@@ -13,6 +13,7 @@
limitations under the License. */
#include "paddle/framework/backward.h"
#include <list>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
......@@ -71,7 +72,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
return EmptyOp();
}
// auto* net = new NetOp();
auto* net = new NetOp();
if (forwardOp.IsNetOp()) {
//! TODO(dzh)
......@@ -93,29 +94,32 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
}
// unique the duplicate name
auto uid = uniq_id++;
std::unordered_map<size_t, OperatorBase> insert_postion;
// TODO(dzh): more comment
typedef std::pair<size_t, std::shared_ptr<OperatorBase>> Pos;
std::list<Pos> insert_postion;
for (auto& dup_output_op : dup_output_ops) {
std::string& name = dup_output_op.first;
const std::string& name = dup_output_op.first;
auto& dup_op = dup_output_op.second;
if (dup_op.size() == 1) continue;
std::vector<std::string> dup_outputs;
for (size_t i = 0; i < dup_op.size(); ++i) {
auto op_offset = dup_op[i];
net->ops_[op_offset].Rename(
name,
name + "@RENAME@" + std::to_string(uid) + "@" + std::to_string(i));
}
insert_postion[op_offset] =
OpRegistry::CreateOp("Add", {}, {dup_op->inputs_}, {});
net->AddOp("Add");
net->AddOp();
// process shared variable
// while(dup_op.size()) {
//
// AddOp(OpRegistry::CreateOp("generic_add", {dup_outputs},
// {dup_op->inputs_}, {}));
//}
dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" +
std::to_string(i));
net->ops_[op_offset]->Rename(name, dup_outputs.back());
}
insert_postion.push_back(
{dup_op.back(),
OpRegistry::CreateOp(
"Add", {dup_outputs}, {name},
{{"input_format",
std::vector<int>{0, (int)dup_outputs.size()}}})});
}
insert_postion.sort(
[](const Pos& l, const Pos& r) { return l.first > r.first; });
for (auto& pos : insert_postion) {
net->InsertOp(pos.first, pos.second);
}
} else {
......
......@@ -215,7 +215,7 @@ TEST(Backward, net_input_of_network_not_need_grad) {
ASSERT_EQ(all_output.find("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
all_output.end());
ASSERT_EQ(2, bwd_net->ops_.size());
ASSERT_EQ(2UL, bwd_net->ops_.size());
ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get());
ASSERT_EQ(3, first_fc_grad->ops_.size());
......
......@@ -74,27 +74,5 @@ std::string NetOp::DebugString() const {
bool NetOp::IsNetOp() const { return true; }
void NetOp::Rename(const std::unordered_map<
std::string, std::vector<OpIdentity>>& dup_output_ops,
size_t& uniq_id) {
for (auto& op : ops_) {
if (op->isNetOp()) {
op->Rename(dup_output_ops, uniq_id);
}
for (size_t i = 0; i < op->outputs_.size(); ++i) {
std::vector<std::string> dup_outputs;
if (op->outputs_[i] ==) {
op->outputs_[i] += std::to_string(uniq_id++);
dup_outputs.push_back(op->outputs_[i]);
}
// add duplicate output together. replace with AddOp
if (dup_outputs.size() >= 2) {
AddOp(OpRegistry::CreateOp("generic_add", {dup_outputs}, {op->inputs_},
{}));
}
}
}
}
} // namespace framework
} // namespace paddle
......@@ -49,11 +49,6 @@ class NetOp : public OperatorBase {
}
}
/**
* @brief rename duplicated output gradient name in Net
*/
bool Rename(size_t& uniq_id);
/**
* @brief Run the network.
*
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册