diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index d8653b5dd681603b7261e58de02c6787bcdcebfe..5b35de77e4030fb37a63bbe14e9d7d0e8b6b75fe 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -48,9 +48,11 @@ static std::shared_ptr<OperatorBase> EmptyOp() {
   return net_op;
 }
 
-static std::shared_ptr<OperatorBase> BackwardImpl(
-    const OperatorBase& forwardOp,
-    std::unordered_set<std::string>& no_grad_names, int& uniq_id) {
+static void DeDuplicate(NetOp* net, std::unordered_set<std::string>&);
+
+static std::shared_ptr<OperatorBase> BackwardImpl(
+    const OperatorBase& forwardOp,
+    std::unordered_set<std::string>& no_grad_names, unsigned& uniq_id) {
   if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
                no_grad_names)) {
     return EmptyOp();
@@ -68,6 +70,40 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
   auto* net = new NetOp();
 
   if (forwardOp.IsNetOp()) {
+    std::unordered_map<std::string, int> dup_output;
+    std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
+    size_t op_id_offset = 0;
+    auto& forwardNet = static_cast<const NetOp&>(forwardOp);
+    for (auto& fwd : forwardNet.ops_) {
+      auto bwd = BackwardImpl(*fwd, no_grad_names, uniq_id);
+      net->AddOp(bwd);
+      // Count how many backward ops write each output name.
+      for (size_t i = 0; i < bwd->outputs_.size(); ++i) {
+        auto& out_name = bwd->outputs_[i];
+        if (dup_output.find(out_name) == dup_output.end()) {
+          dup_output[out_name] = 1;
+          dup_output_ops[out_name] = std::vector<size_t>{op_id_offset};
+        } else {
+          dup_output[out_name]++;
+          dup_output_ops[out_name].emplace_back(op_id_offset);
+        }
+      }
+      ++op_id_offset;
+    }
+    for (auto& dup : dup_output) {
+      if (dup.second == 1) continue;
+      const auto& op_ids = dup_output_ops.at(dup.first);
+      for (auto op_id : op_ids) {
+        auto& op_ptr = net->ops_[op_id];
+        for (size_t i = 0; i < op_ptr->outputs_.size(); ++i) {
+          if (op_ptr->outputs_[i] == dup.first) {
+            // unique the duplicate name
+            op_ptr->outputs_[i] += std::to_string(uniq_id++);
+          }
+        }
+      }
+    }
+    //! TODO(dzh)
+  } else {
     //! TODO(fjy)
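
The second hunk records, for every gradient output name, which backward ops in the freshly built NetOp produce it, then renames the duplicated outputs with a running uniq_id; summing the renamed copies back into the original variable is left as the TODO(dzh). (The DeDuplicate declaration is truncated in the source; its std::unordered_set<std::string> parameter above is a best guess.) The following is a minimal, self-contained C++ sketch of the same rename-then-sum idea, outside the Paddle code base: Op, producers, and the "@RENAME@" suffix are illustrative names invented for this sketch, not Paddle APIs.

// Sketch: detect gradient output names written by more than one op,
// rename each occurrence to a unique variant, and note where a summation
// op would be inserted (the diff above leaves that step as a TODO).
#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct Op {
  std::string type;
  std::vector<std::string> outputs;
};

int main() {
  // Two ops both write "x@GRAD", as happens when a forward variable feeds
  // two operators; their gradients must be accumulated, not overwritten.
  std::vector<Op> net = {
      {"mul_grad", {"x@GRAD", "w@GRAD"}},
      {"relu_grad", {"x@GRAD"}},
  };

  // Pass 1: record which op indices produce each output name.
  std::unordered_map<std::string, std::vector<size_t>> producers;
  for (size_t op_id = 0; op_id < net.size(); ++op_id) {
    for (auto& out : net[op_id].outputs) {
      producers[out].push_back(op_id);
    }
  }

  // Pass 2: rename every duplicated output to a unique variant.
  unsigned uniq_id = 0;
  for (auto& p : producers) {
    if (p.second.size() == 1) continue;
    std::vector<std::string> renamed;
    for (size_t op_id : p.second) {
      for (auto& out : net[op_id].outputs) {
        if (out == p.first) {
          out += "@RENAME@" + std::to_string(uniq_id++);
          renamed.push_back(out);
        }
      }
    }
    // A real implementation would now append: add(renamed...) -> p.first.
    std::cout << "sum " << renamed.size() << " copies into " << p.first << "\n";
  }

  for (auto& op : net) {
    for (auto& out : op.outputs) std::cout << op.type << " -> " << out << "\n";
  }
}

Rename-then-sum keeps every gradient op write-once, so the backward net stays a pure dataflow graph; the alternative of accumulating into the shared name in place would impose ordering constraints between the ops that write the same variable.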