提交 d8811520 编写于 作者: C chengtbf

log for logical blob with bw op consumers


Former-commit-id: f015a245efa5e0791bb1e5070b217e3fec616fc9
上级 20301e8a
......@@ -120,6 +120,31 @@ void GenConnectedCheckpointingSubgraphs(
}
Maybe<void> CheckpointingPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {
op_graph.TopoForEachNode([&](const OpNode* op_node) {
HashMap<std::string, HashSet<std::string>> lbn2bw_consumer_op_names;
for (const OpEdge* out_edge : op_node->out_edges()) {
bool is_bw_consumer = false;
const OpNode* out_node = out_edge->dst_node();
if (!IsForwardPassScope(Scope4OpNode(out_node))) { is_bw_consumer = true; }
for (const auto& lbi : out_edge->lbis()) {
std::string lbn = GenLogicalBlobName(lbi);
auto& bw_consumer_op_names = lbn2bw_consumer_op_names[lbn];
if (is_bw_consumer) { bw_consumer_op_names.insert(out_node->op().op_name()); }
}
}
for (const auto& pair : lbn2bw_consumer_op_names) {
int op_num = pair.second.size();
LOG(INFO) << "Checkpointing log: lbn = " << pair.first
<< ", bw_consumer_op_num = " << op_num;
if (op_num > 0) {
std::string log_str = "They are: {";
for (const auto& bw_op_name : pair.second) { log_str += bw_op_name + ","; }
log_str += "}";
LOG(INFO) << log_str;
}
}
});
// step 1. collect all checkpointing ops in forwardpass.
HashMap<std::string, const OpNode*> checkpointing_op_name2op_node;
CollectAllCheckpointingOpsInForwardPass(op_graph, &checkpointing_op_name2op_node);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册