未验证 提交 10bab9f1 编写于 作者: J Jiabin Yang 提交者: GitHub

fix edge count error (#40761)

上级 d2aaa751
......@@ -612,7 +612,9 @@ std::vector<paddle::experimental::Tensor> RunBackward(
for (size_t i = 0; i < edges.size(); i++) {
for (size_t j = 0; j < edges[i].size(); j++) {
const Edge& edge = edges[i][j];
if (!edge.IsInitialized()) {
continue;
}
auto edge_rank = edge.GetEdgeRankInfo();
// Since we make edge has as same rank as bwd outputs, we indexing them
// with
......
......@@ -63,6 +63,8 @@ void GradNodeBase::AddEdges(std::vector<AutogradMeta*>* metas, size_t slot_id) {
adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
meta->OutRankInfo());
} else {
adj_edges_[slot_id].emplace_back();
}
}
}
......@@ -85,6 +87,8 @@ void GradNodeBase::AddEdges(AutogradMeta* meta, size_t slot_id) {
adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
meta->OutRankInfo());
} else {
adj_edges_[slot_id].emplace_back();
}
}
......
......@@ -257,12 +257,22 @@ class Edge {
}
// Currently we use grad_node_ to identify if a edge is initialized.
bool IsInitialized() const { return grad_node_.get(); }
bool IsInitialized() const {
if (!grad_node_) {
return false;
} else {
if (!(grad_node_.get())) {
return false;
} else {
return true;
}
}
}
private:
size_t in_slot_id_;
size_t in_rank_;
std::shared_ptr<GradNodeBase> grad_node_;
std::shared_ptr<GradNodeBase> grad_node_{nullptr};
};
} // namespace egr
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册