未验证 提交 10bab9f1 编写于 作者: J Jiabin Yang 提交者: GitHub

fix edge count error (#40761)

上级 d2aaa751
...@@ -612,7 +612,9 @@ std::vector<paddle::experimental::Tensor> RunBackward( ...@@ -612,7 +612,9 @@ std::vector<paddle::experimental::Tensor> RunBackward(
for (size_t i = 0; i < edges.size(); i++) { for (size_t i = 0; i < edges.size(); i++) {
for (size_t j = 0; j < edges[i].size(); j++) { for (size_t j = 0; j < edges[i].size(); j++) {
const Edge& edge = edges[i][j]; const Edge& edge = edges[i][j];
if (!edge.IsInitialized()) {
continue;
}
auto edge_rank = edge.GetEdgeRankInfo(); auto edge_rank = edge.GetEdgeRankInfo();
// Since we make edge has as same rank as bwd outputs, we indexing them // Since we make edge has as same rank as bwd outputs, we indexing them
// with // with
......
...@@ -63,6 +63,8 @@ void GradNodeBase::AddEdges(std::vector<AutogradMeta*>* metas, size_t slot_id) { ...@@ -63,6 +63,8 @@ void GradNodeBase::AddEdges(std::vector<AutogradMeta*>* metas, size_t slot_id) {
adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(), adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
meta->OutRankInfo()); meta->OutRankInfo());
} else {
adj_edges_[slot_id].emplace_back();
} }
} }
} }
...@@ -85,6 +87,8 @@ void GradNodeBase::AddEdges(AutogradMeta* meta, size_t slot_id) { ...@@ -85,6 +87,8 @@ void GradNodeBase::AddEdges(AutogradMeta* meta, size_t slot_id) {
adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(), adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
meta->OutRankInfo()); meta->OutRankInfo());
} else {
adj_edges_[slot_id].emplace_back();
} }
} }
......
...@@ -257,12 +257,22 @@ class Edge { ...@@ -257,12 +257,22 @@ class Edge {
} }
// Currently we use grad_node_ to identify if a edge is initialized. // Currently we use grad_node_ to identify if a edge is initialized.
bool IsInitialized() const { return grad_node_.get(); } bool IsInitialized() const {
if (!grad_node_) {
return false;
} else {
if (!(grad_node_.get())) {
return false;
} else {
return true;
}
}
}
private: private:
size_t in_slot_id_; size_t in_slot_id_;
size_t in_rank_; size_t in_rank_;
std::shared_ptr<GradNodeBase> grad_node_; std::shared_ptr<GradNodeBase> grad_node_{nullptr};
}; };
} // namespace egr } // namespace egr
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册