diff --git a/oneflow/core/graph/boxing_task_node.cpp b/oneflow/core/graph/boxing_task_node.cpp index a97988c22c58063820d0b2df52a4bf179e81a3af..e873efe883e235bd899dc978a9df3ed96fa8bf2c 100644 --- a/oneflow/core/graph/boxing_task_node.cpp +++ b/oneflow/core/graph/boxing_task_node.cpp @@ -143,9 +143,9 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair( CHECK_EQ(lbns.size(), 1); lbns.clear(); auto in_regst_0 = GetRelatedRegst(sorted_in_edges.at(0)); - for (const auto& pair : in_regst_0->lbn2shape()) { - lbns.push_back(pair.first); - } + in_regst_0->ForEachLbn([&](const std::string& lbn) { + lbns.push_back(lbn); + }); } // Enroll Lbn auto middle_regst = GetProducedRegstDesc("middle"); diff --git a/oneflow/core/graph/model_save_comp_task_node.cpp b/oneflow/core/graph/model_save_comp_task_node.cpp index 9bf0797454826e6f0479b94bfcad410805c94b7f..5f0c680a5605cc85a7fa1f4c36932fd0f07488da 100644 --- a/oneflow/core/graph/model_save_comp_task_node.cpp +++ b/oneflow/core/graph/model_save_comp_task_node.cpp @@ -16,9 +16,9 @@ void MdSaveCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) { OperatorConf op_conf; op_conf.set_name("model_save_op" + updt_task->node_id_str()); op_conf.mutable_model_save_conf(); - for (const auto& pair : GetRelatedRegst(SoleInEdge())->lbn2shape()) { - op_conf.mutable_model_save_conf()->add_lbns(pair.first); - } + GetRelatedRegst(SoleInEdge())->ForEachLbn([&](const std::string& lbn) { + op_conf.mutable_model_save_conf()->add_lbns(lbn); + }); ExecNode* exec_node = mut_exec_gph().NewNode(); exec_node->mut_op() = OpMgr::Singleton().ConstructOp(op_conf); diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp index f940b4d60b3eaf44964dcffdbe854e40c213e561..f3df43834b4ad1efbc94ee1a12daa8b9da426d76 100644 --- a/oneflow/core/graph/task_node.cpp +++ b/oneflow/core/graph/task_node.cpp @@ -77,7 +77,7 @@ void TaskNode::TakeOverRegstDesc(TaskNode* rhs, void TaskNode::EraseProducedEmptyRegsts() { EraseIf> (&produced_regst_descs_, [] (HashMap>::iterator it) { - return it->second->lbn2shape().empty(); + return it->second->NumOfLbn() == 0; }); } diff --git a/oneflow/core/register/register_desc.cpp b/oneflow/core/register/register_desc.cpp index 56585eab85c6fc0a776ede68c74f39551464d30b..5813a487d5c5890e754935ef220df31378696fe6 100644 --- a/oneflow/core/register/register_desc.cpp +++ b/oneflow/core/register/register_desc.cpp @@ -68,13 +68,10 @@ Shape* RegstDesc::GetMutShapePtr(const std::string& lbn) { return lbn2shape_.at(lbn).get(); } -HashMap>& RegstDesc::mut_lbn2shape() { - return lbn2shape_; -} - -const HashMap>& -RegstDesc::lbn2shape() const { - return lbn2shape_; +void RegstDesc::ForEachLbn(std::function func) const { + for (const auto& p : lbn2shape_) { + func(p.first); + } } void RegstDesc::EraseZeroSizeBlob() { diff --git a/oneflow/core/register/register_desc.h b/oneflow/core/register/register_desc.h index aad76803718ba78faf31a35752974a0e2c15ca5a..5d4b7723a30e7e6e2575f2a8a6f36364b71ec6c6 100644 --- a/oneflow/core/register/register_desc.h +++ b/oneflow/core/register/register_desc.h @@ -33,8 +33,8 @@ class RegstDesc final { void EnrollLbn(const std::string& lbn); const Shape& GetShape(const std::string& lbn) const; Shape* GetMutShapePtr(const std::string& lbn); - HashMap>& mut_lbn2shape(); - const HashMap>& lbn2shape() const; + void ForEachLbn(std::function func) const; + size_t NumOfLbn() const { return lbn2shape_.size(); } // void EraseZeroSizeBlob();