diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 7b470adb47f63d855ec784e9b2765e7fe0fc3ae6..dae457f858593f476cb0973b7bc119e546f12c79 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -49,11 +49,9 @@ static std::shared_ptr<OperatorBase> EmptyOp() {
   return net_op;
 }
 
-static void DeDuplicate(NetOp* net, std::unordered_se)
-
-static std::shared_ptr<OperatorBase> BackwardImpl(
-    const OperatorBase& forwardOp,
-    std::unordered_set<std::string>& no_grad_names, unsigned& uniq_id) {
+static std::shared_ptr<OperatorBase> BackwardImpl(
+    const OperatorBase& forwardOp,
+    std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
   if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
                no_grad_names)) {
     return EmptyOp();
@@ -73,13 +71,16 @@ static void DeDuplicate(NetOp* net, std::unordered_se)
   if (forwardOp.IsNetOp()) {
     //! TODO(dzh)
     std::unordered_map<std::string, int> dup_output;
-    std::unordered_map<std::string, std::vector<unsigned>> dup_output_ops;
-    const unsigned uniq_id_local = uniq_id;
-    unsigned op_id_offset = 0;
-    for (auto& fwd : forwardOp) {
-      auto bwd = Backward(fwd, no_grad_names);
+    std::unordered_map<std::string, std::vector<int>> dup_output_ops;
+    // const unsigned uniq_id_local = uniq_id;
+    int op_id_offset = 0;
+    // Because it is a net op, it can static_cast.
+    auto& forwardNet = static_cast<const NetOp&>(forwardOp);
+
+    for (auto& fwd : forwardNet.ops_) {
+      auto bwd = Backward(*fwd, no_grad_names);
       net->AddOp(bwd);
-      for (size_t i = 0; i < bwd.outputs_; ++i) {
+      for (size_t i = 0; i < bwd->outputs_.size(); ++i) {
         bwd->outputs_[i] += OperatorBase::EMPTY_VAR_NAME();
         if (dup_output.find(bwd->inputs_[i]) == dup_output.end()) {
           dup_output[bwd->inputs_[i]] = 1;
@@ -138,7 +139,7 @@ extern std::shared_ptr<OperatorBase> Backward(
   for (auto& name : no_grad_vars) {
     no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
   }
-  int uid = 0;
+  size_t uid = 0;
   return BackwardImpl(forwardOp, no_grad_names, uid);
 }
 }  // namespace framework