提交 fa6a46a8 编写于 作者: Y Yu Yang

Merge branch 'feature/backward' of github.com:reyoung/Paddle into feature/backward

...@@ -49,9 +49,11 @@ static std::shared_ptr<OperatorBase> EmptyOp() { ...@@ -49,9 +49,11 @@ static std::shared_ptr<OperatorBase> EmptyOp() {
return net_op; return net_op;
} }
static std::shared_ptr<OperatorBase> BackwardImpl( static void DeDuplicate(NetOp* net, std::unordered_se)
const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, int& uniq_id) { static std::shared_ptr<OperatorBase> BackwardImpl(
const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, unsigned& uniq_id) {
if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(), if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
no_grad_names)) { no_grad_names)) {
return EmptyOp(); return EmptyOp();
...@@ -70,6 +72,39 @@ static std::shared_ptr<OperatorBase> BackwardImpl( ...@@ -70,6 +72,39 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
if (forwardOp.IsNetOp()) { if (forwardOp.IsNetOp()) {
//! TODO(dzh) //! TODO(dzh)
std::unordered_map<std::string, int> dup_output;
std::unordered_map<std::string std::vector<int>> dup_output_ops;
const unsigned uniq_id_local = uniq_id;
unsigned op_id_offset = 0;
for (auto& fwd : forwardOp) {
auto bwd = Backward(fwd, no_grad_names);
net->AddOp(bwd);
for (size_t i = 0; i < bwd.outputs_; ++i) {
bwd->outputs_[i] += OperatorBase::EMPTY_VAR_NAME();
if (dup_output.find(bwd->inputs_[i]) == dup_output.end()) {
dup_output[bwd->inputs_[i]] = 1;
dup_output_ops[bwd->inputs_[i]] = std::vector<int>{op_id_offset++};
} else {
dup_output[bwd->inputs_[i]]++;
dup_output_ops[bwd->inputs_[i]].emplace_back(op_id_offset++);
}
}
}
for (auto dup : dup_output) {
if (dup.second == 1) continue;
auto op_ids = dup_output_ops.at(dup.first);
for (auto& op_id : op_ids) {
auto& op_ptr = net->ops_[op_id];
for (size_t i = 0; i < op_ptr->inputs_.size(); ++i) {
if (op_ptr->inputs_[i] == dup.first) {
// unique the duplicate name
op_ptr->inputs_[i] += std::to_string(uniq_id++);
// TODO(dzh): need a generic add op here
}
}
}
}
} else { } else {
//! TODO(fjy) //! TODO(fjy)
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp); std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册