提交 cf84a6e8 编写于 作者: L Li Xinqi 提交者: Niu Chong

no out_diff then no backward node (#1250)

上级 31693ec1
......@@ -7,6 +7,41 @@
namespace oneflow {
namespace {
std::function<bool(const LogicalNode*)> MakePredicatorHasActualOutDiff(const LogicalGraph* graph) {
std::list<LogicalNode*> loss_nodes;
graph->ForEachNode([&](LogicalNode* node) {
if (dynamic_cast<LossLogicalNode*>(node)) { loss_nodes.push_back(node); }
});
auto nodes_have_actual_out_diff_ptr = std::make_shared<HashSet<const LogicalNode*>>();
auto HasBwConnection = [](const LogicalNode* prev, const LogicalNode* next) {
HashSet<LogicalBlobId> idbn_lbis;
for (const auto& idbn : next->SoleOp()->input_diff_bns()) {
idbn_lbis.insert(next->SoleOp()->BnInOp2Lbi(idbn));
}
for (const auto& odbn : prev->SoleOp()->output_diff_bns()) {
LogicalBlobId lbi = prev->SoleOp()->BnInOp2Lbi(odbn);
if (idbn_lbis.find(lbi) != idbn_lbis.end()) { return true; }
}
return false;
};
auto ForEachNext = [&](LogicalNode* node, const std::function<void(LogicalNode*)>& Handler) {
node->ForEachNodeOnInEdge([&](LogicalNode* in_node) {
if (HasBwConnection(in_node, node)) { Handler(in_node); }
});
};
graph->BfsForEachNode(loss_nodes, ForEachNext,
[nodes_have_actual_out_diff_ptr](LogicalNode* node) {
nodes_have_actual_out_diff_ptr->insert(node);
});
return [nodes_have_actual_out_diff_ptr](const LogicalNode* node) {
return nodes_have_actual_out_diff_ptr->find(node) != nodes_have_actual_out_diff_ptr->end();
};
}
} // namespace
LogicalGraph::LogicalGraph(bool is_train) {
BuildFwStruct();
if (is_train) { GroupNodesForReduceStruct(); }
......@@ -166,6 +201,7 @@ void LogicalGraph::BuildBwStruct() {
}
void LogicalGraph::NaiveBuildBwStruct() {
auto HasActualOutDiff = MakePredicatorHasActualOutDiff(this);
HashSet<LogicalNode*> nodes_need_bw;
TopoForEachNode([&](LogicalNode* logical_node) {
auto fw_node = dynamic_cast<ForwardLogicalNode*>(logical_node);
......@@ -175,7 +211,8 @@ void LogicalGraph::NaiveBuildBwStruct() {
return;
}
for (LogicalEdge* edge : fw_node->in_edges()) {
if (nodes_need_bw.find(edge->src_node()) != nodes_need_bw.end()) {
if (nodes_need_bw.find(edge->src_node()) != nodes_need_bw.end()
&& HasActualOutDiff(fw_node)) {
CHECK(nodes_need_bw.insert(fw_node).second);
return;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册