diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 1531cb53f9dc7ace9911a231e928e869cd7eca28..a4660d7156e506aba3021749214c263771bf676b 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -49,9 +49,11 @@ static std::shared_ptr<OperatorBase> EmptyOp() {
   return net_op;
 }
 
-static std::shared_ptr<OperatorBase> BackwardImpl(
-    const OperatorBase& forwardOp,
-    std::unordered_set<std::string>& no_grad_names, int& uniq_id) {
+static void DeDuplicate(NetOp* net, std::unordered_se)
+
+    static std::shared_ptr<OperatorBase> BackwardImpl(
+        const OperatorBase& forwardOp,
+        std::unordered_set<std::string>& no_grad_names, unsigned& uniq_id) {
   if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
                no_grad_names)) {
     return EmptyOp();
@@ -70,6 +72,39 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
 
   if (forwardOp.IsNetOp()) {
     //! TODO(dzh)
+    std::unordered_map<std::string, int> dup_output;
+    std::unordered_map<std::string, std::vector<unsigned>> dup_output_ops;
+    const unsigned uniq_id_local = uniq_id;
+    unsigned op_id_offset = 0;
+    for (auto& fwd : forwardOp) {
+      auto bwd = Backward(fwd, no_grad_names);
+      net->AddOp(bwd);
+      for (size_t i = 0; i < bwd.outputs_; ++i) {
+        bwd->outputs_[i] += OperatorBase::EMPTY_VAR_NAME();
+        if (dup_output.find(bwd->inputs_[i]) == dup_output.end()) {
+          dup_output[bwd->inputs_[i]] = 1;
+          dup_output_ops[bwd->inputs_[i]] = std::vector<unsigned>{op_id_offset++};
+        } else {
+          dup_output[bwd->inputs_[i]]++;
+          dup_output_ops[bwd->inputs_[i]].emplace_back(op_id_offset++);
+        }
+      }
+    }
+    for (auto dup : dup_output) {
+      if (dup.second == 1) continue;
+      auto op_ids = dup_output_ops.at(dup.first);
+      for (auto& op_id : op_ids) {
+        auto& op_ptr = net->ops_[op_id];
+        for (size_t i = 0; i < op_ptr->inputs_.size(); ++i) {
+          if (op_ptr->inputs_[i] == dup.first) {
+            // unique the duplicate name
+            op_ptr->inputs_[i] += std::to_string(uniq_id++);
+            // TODO(dzh): need a generic add op here
+          }
+        }
+      }
+    }
+
   } else {
     //! TODO(fjy)
     std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
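
For context, the work-in-progress block in the second hunk walks every sub-op of a forward NetOp, builds its backward op, and counts how often the same gradient variable name is produced across the generated ops; any name that occurs more than once gets a unique integer suffix so that a later summation op can combine the copies (the "need a generic add op here" TODO). The standalone sketch below models only that counting-and-renaming pass. The Op struct, DeDuplicateOutputs function, and "@RENAME@" suffix are hypothetical stand-ins, not PaddlePaddle's OperatorBase/NetOp API; it is a minimal illustration of the technique, not the patch itself.

    // Illustrative sketch only: Op, DeDuplicateOutputs, and the "@RENAME@"
    // suffix are made up for this example and are not PaddlePaddle's API.
    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <vector>

    struct Op {
      std::vector<std::string> inputs;
      std::vector<std::string> outputs;
    };

    // Rename duplicated output names, e.g. "x@GRAD" -> "x@GRAD@RENAME@0",
    // "x@GRAD@RENAME@1", ... so a summation op could later be inserted to
    // add the renamed copies back into "x@GRAD".
    void DeDuplicateOutputs(std::vector<Op>& ops) {
      // First pass: record which ops write each output name.
      std::unordered_map<std::string, std::vector<size_t>> writers;
      for (size_t op_id = 0; op_id < ops.size(); ++op_id) {
        for (const auto& out : ops[op_id].outputs) {
          writers[out].push_back(op_id);
        }
      }
      // Second pass: every name written more than once gets a unique suffix
      // per writing op.
      for (const auto& kv : writers) {
        if (kv.second.size() == 1) continue;
        unsigned uniq_id = 0;
        for (size_t op_id : kv.second) {
          for (auto& out : ops[op_id].outputs) {
            if (out == kv.first) {
              out += "@RENAME@" + std::to_string(uniq_id++);
              // A fuller implementation would also append an add/sum op that
              // reads all renamed copies and writes the original name.
            }
          }
        }
      }
    }

    int main() {
      std::vector<Op> backward_ops = {
          {{"out@GRAD"}, {"x@GRAD"}},
          {{"out@GRAD"}, {"x@GRAD"}},  // duplicate producer of x@GRAD
          {{"out@GRAD"}, {"y@GRAD"}},
      };
      DeDuplicateOutputs(backward_ops);
      for (const auto& op : backward_ops) {
        for (const auto& out : op.outputs) std::cout << out << "\n";
      }
      return 0;
    }

Running the sketch prints x@GRAD@RENAME@0, x@GRAD@RENAME@1, and y@GRAD: the duplicated gradient is split into uniquely named copies while the unique one is left untouched, which is exactly the state a follow-up add op would consume to reproduce x@GRAD.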