diff --git a/oneflow/core/graph/logical_graph.cpp b/oneflow/core/graph/logical_graph.cpp index b4404c066f0d4db28a49b8d40c0ee4c7dbc316de..77412f91be06f95c97bec0787e7950df76a04c45 100644 --- a/oneflow/core/graph/logical_graph.cpp +++ b/oneflow/core/graph/logical_graph.cpp @@ -117,38 +117,33 @@ void LogicalGraph::NaiveBuildFwStruct( void LogicalGraph::FixSharedModelNodes( const HashMap>& op_name2nodes) { + HashSet all_shared_model_op_names; const DLNetConf& dlnet_conf = Global::Get()->dlnet_conf(); for (const OpNameSet& op_name_set : dlnet_conf.shared_model_group()) { + std::vector shared_model_op_names(op_name_set.op_name().begin(), + op_name_set.op_name().end()); + SortAndRemoveDuplication(&shared_model_op_names); + CHECK_GE(shared_model_op_names.size(), 2); + auto shared_model_nodes = std::make_shared>(); - for (const std::string& op_name : op_name_set.op_name()) { + shared_model_nodes->reserve(shared_model_op_names.size()); + for (const std::string& op_name : shared_model_op_names) { + CHECK(all_shared_model_op_names.insert(op_name).second); CHECK_EQ(op_name2nodes.at(op_name).size(), 1); shared_model_nodes->push_back(op_name2nodes.at(op_name).front()); } - SortAndRemoveDuplication(shared_model_nodes.get()); + for (LogicalNode* cur_node : *shared_model_nodes) { cur_node->mut_shared_model_nodes() = shared_model_nodes; } + const std::string& shared_op_name = shared_model_nodes->front()->SoleOp()->op_name(); + const ParallelDesc* shared_parallel_desc = shared_model_nodes->front()->parallel_desc().get(); FOR_RANGE(size_t, i, 1, shared_model_nodes->size()) { shared_model_nodes->at(i)->SoleOp()->FixLbiWhenShareModel(shared_op_name); + CHECK(shared_model_nodes->at(i)->parallel_desc()->Equal(shared_parallel_desc)); } } - ForEachNode([&](LogicalNode* cur_node) { - if (cur_node->shared_model_nodes()) { - for (LogicalNode* shared_node : *(cur_node->shared_model_nodes())) { - if (shared_node->parallel_desc() == nullptr) { continue; } - if (cur_node->parallel_desc()) { - CHECK(cur_node->parallel_desc()->Equal(shared_node->parallel_desc().get())); - } else { - cur_node->mut_parallel_desc() = shared_node->parallel_desc(); - } - } - } else { - // do nothing - } - CHECK(cur_node->parallel_desc()) - << "Please set the placement of " << cur_node->SoleOp()->op_name(); - }); } void LogicalGraph::SetMainModelParallel() {