From 10bab9f1c85ee11004c7611e400d45b540901abc Mon Sep 17 00:00:00 2001 From: Jiabin Yang <360788950@qq.com> Date: Tue, 22 Mar 2022 13:30:57 +0800 Subject: [PATCH] fix edge count error (#40761) --- paddle/fluid/eager/backward.cc | 4 +++- paddle/fluid/eager/grad_node_info.cc | 4 ++++ paddle/fluid/eager/grad_node_info.h | 14 ++++++++++++-- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/eager/backward.cc b/paddle/fluid/eager/backward.cc index ebd3333c526..0e9dc19c2e3 100644 --- a/paddle/fluid/eager/backward.cc +++ b/paddle/fluid/eager/backward.cc @@ -612,7 +612,9 @@ std::vector RunBackward( for (size_t i = 0; i < edges.size(); i++) { for (size_t j = 0; j < edges[i].size(); j++) { const Edge& edge = edges[i][j]; - + if (!edge.IsInitialized()) { + continue; + } auto edge_rank = edge.GetEdgeRankInfo(); // Since we make edge has as same rank as bwd outputs, we indexing them // with diff --git a/paddle/fluid/eager/grad_node_info.cc b/paddle/fluid/eager/grad_node_info.cc index 891ad4d8983..1d44d842b08 100644 --- a/paddle/fluid/eager/grad_node_info.cc +++ b/paddle/fluid/eager/grad_node_info.cc @@ -63,6 +63,8 @@ void GradNodeBase::AddEdges(std::vector* metas, size_t slot_id) { adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(), meta->OutRankInfo()); + } else { + adj_edges_[slot_id].emplace_back(); } } } @@ -85,6 +87,8 @@ void GradNodeBase::AddEdges(AutogradMeta* meta, size_t slot_id) { adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(), meta->OutRankInfo()); + } else { + adj_edges_[slot_id].emplace_back(); } } diff --git a/paddle/fluid/eager/grad_node_info.h b/paddle/fluid/eager/grad_node_info.h index 4b21a193ee0..28c12717a24 100644 --- a/paddle/fluid/eager/grad_node_info.h +++ b/paddle/fluid/eager/grad_node_info.h @@ -257,12 +257,22 @@ class Edge { } // Currently we use grad_node_ to identify if a edge is initialized. - bool IsInitialized() const { return grad_node_.get(); } + bool IsInitialized() const { + if (!grad_node_) { + return false; + } else { + if (!(grad_node_.get())) { + return false; + } else { + return true; + } + } + } private: size_t in_slot_id_; size_t in_rank_; - std::shared_ptr grad_node_; + std::shared_ptr grad_node_{nullptr}; }; } // namespace egr -- GitLab