From 3d33bde22ce9840fccf15cadbc0888723788f656 Mon Sep 17 00:00:00 2001 From: Li Xinqi Date: Wed, 17 Feb 2021 18:30:34 +0800 Subject: [PATCH] Multi reentrant lock (#4225) * source subset tick * remove useless header files * insert DstSubsetTickOp * remove incorrect CHECK * add tick op for each machine * TryBindBnWithOneofRegst * add sink tick op in main_job * refactor LinkMainJob * fix typo in task_graph * refactor AddGlobalCriticalSection * rename and refactor DstSubsetTick::InferBlobDescs and SrcSubsetTick::InferBlobDescs * add src_subset_tick for input-output critical section * refactor AutoSourceTick and AutoSinkTick * vectorizedly link main job * resize vector identity_tick_op_names then access elements * SrcSubsetTickCompTaskNode: bind bns and in_regst if bns is valid in current device * refactor optional input to repeated inputs for SrcSubsetTickOpConf * fix a bug in CaseCompTaskNode; fix a bug when creating identity tick in main_job * 1) Insert tick between source tick and src_subset_tick; 2) Insert tick between dst_subset_tick and sink tick * stash code * refactor MakeMainJob by using Range::ForEachSubRange * refactor MakeMainJob by using Range::ForEachSubRange * rename ReentrantLockLinkPoint to ReentrantLockBackEdge * set piece id for regst sent by wait_and_send_ids actor * callback_notifier_sink_tick Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- .../actor/wait_and_send_ids_compute_actor.cpp | 6 +- .../actor/wait_and_send_ids_compute_actor.h | 1 + oneflow/core/common/range.cpp | 13 +- oneflow/core/common/range.h | 4 + oneflow/core/graph_impl/logical_node.cpp | 6 + oneflow/core/job/oneflow.cpp | 269 +++++++++++------- oneflow/core/job_rewriter/autotick.cpp | 16 +- 7 files changed, 207 insertions(+), 108 deletions(-) diff --git a/oneflow/core/actor/wait_and_send_ids_compute_actor.cpp b/oneflow/core/actor/wait_and_send_ids_compute_actor.cpp index 059b0b7938..933a7b19be 100644 --- 
a/oneflow/core/actor/wait_and_send_ids_compute_actor.cpp +++ b/oneflow/core/actor/wait_and_send_ids_compute_actor.cpp @@ -24,6 +24,7 @@ void WaitAndSendIdsCompActor::VirtualCompActorInit(const TaskProto& task_proto) wait_and_send_ids_status_.in_id_ = 0; wait_and_send_ids_status_.out_idx_ = 0; wait_and_send_ids_status_.out_num_ = 0; + cur_piece_id_ = -1; OF_SET_MSG_HANDLER(&WaitAndSendIdsCompActor::HandlerWaitToStart); } @@ -36,7 +37,10 @@ void WaitAndSendIdsCompActor::Act() { void WaitAndSendIdsCompActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() { if (wait_and_send_ids_status_.buffer_status_ == kBufferStatusSuccess) { - HandleProducedNaiveDataRegstToConsumer(); + HandleProducedNaiveDataRegstToConsumer([&](Regst* regst) { + regst->set_piece_id(++cur_piece_id_); + return true; + }); } } diff --git a/oneflow/core/actor/wait_and_send_ids_compute_actor.h b/oneflow/core/actor/wait_and_send_ids_compute_actor.h index 7ff141d7a6..20eef92c14 100644 --- a/oneflow/core/actor/wait_and_send_ids_compute_actor.h +++ b/oneflow/core/actor/wait_and_send_ids_compute_actor.h @@ -41,6 +41,7 @@ class WaitAndSendIdsCompActor final : public CompActor { int HandlerWaitToStart(const ActorMsg&); WaitAndSendIdsStatus wait_and_send_ids_status_; + int64_t cur_piece_id_; }; } // namespace oneflow diff --git a/oneflow/core/common/range.cpp b/oneflow/core/common/range.cpp index c694b74b59..a6ec5959ae 100644 --- a/oneflow/core/common/range.cpp +++ b/oneflow/core/common/range.cpp @@ -27,6 +27,17 @@ void Range::ToProto(RangeProto* ret) const { ret->set_end(end_); } +Maybe Range::ForEachSubRange( + int64_t sub_range_size, const std::function(const Range&)>& DoEachRange) const { + CHECK_EQ_OR_RETURN(size() % sub_range_size, 0); + int64_t start = begin(); + for (; start < end(); start += sub_range_size) { + JUST(DoEachRange(Range(start, start + sub_range_size))); + } + CHECK_EQ_OR_RETURN(start, end()); + return Maybe::Ok(); +} + Range FindIntersectant(const Range& lhs, const Range& rhs) { if 
(lhs.end() > rhs.begin() && rhs.end() > lhs.begin()) { int64_t left = lhs.begin() > rhs.begin() ? lhs.begin() : rhs.begin(); @@ -37,4 +48,4 @@ Range FindIntersectant(const Range& lhs, const Range& rhs) { } } -} // namespace oneflow \ No newline at end of file +} // namespace oneflow diff --git a/oneflow/core/common/range.h b/oneflow/core/common/range.h index 3781b1856f..6d8b8e3c64 100644 --- a/oneflow/core/common/range.h +++ b/oneflow/core/common/range.h @@ -17,6 +17,7 @@ limitations under the License. #define ONEFLOW_CORE_COMMON_RANGE_H_ #include "oneflow/core/common/util.h" +#include "oneflow/core/common/maybe.h" #include "oneflow/core/common/range.pb.h" namespace oneflow { @@ -41,6 +42,9 @@ class Range final { int64_t size() const { return end_ - begin_; } + Maybe ForEachSubRange(int64_t sub_range_size, + const std::function(const Range&)>& DoEachRange) const; + void ToProto(RangeProto* ret) const; private: diff --git a/oneflow/core/graph_impl/logical_node.cpp b/oneflow/core/graph_impl/logical_node.cpp index 826eb5c5f4..1666fa3241 100644 --- a/oneflow/core/graph_impl/logical_node.cpp +++ b/oneflow/core/graph_impl/logical_node.cpp @@ -189,6 +189,12 @@ BldSubTskGphMthd GetMthdForBldSubTskGph(const LogicalNode* src_node, const Logic std::shared_ptr src_pd = src_node->parallel_desc(); std::shared_ptr dst_pd = dst_node->parallel_desc(); if (src_node->op_vec().size() == 1 && dst_node->op_vec().size() == 1) { + if (src_node->SoleOp()->op_conf().has_wait_and_send_ids_conf() + && dst_node->SoleOp()->op_conf().has_reentrant_lock_conf()) { + CHECK_EQ(src_pd->parallel_num(), 1); + CHECK_EQ(dst_pd->parallel_num(), 1); + return &TaskGraph::BldSubTskGphByBoxing; + } if (src_node->SoleOp()->op_conf().has_record_load_conf() && dst_node->SoleOp()->op_conf().has_tick_conf()) { CHECK(src_pd->parallel_num() == dst_pd->parallel_num()); diff --git a/oneflow/core/job/oneflow.cpp b/oneflow/core/job/oneflow.cpp index d6be2771db..c2b1c236fa 100644 --- a/oneflow/core/job/oneflow.cpp +++ 
b/oneflow/core/job/oneflow.cpp @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/range.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/control/ctrl_client.h" @@ -64,6 +65,20 @@ bool operator==(const ParallelBlobConf& lhs, const ParallelBlobConf& rhs) { namespace { +// There are circles in MainJob. +// A MainJob is a Job like: +// +// wait_and_send_ids_op -> reentrant_lock_op -> case_op -> identity_op -> esac_op -> +// \________________________________________________/ +// +// back edges esac_op -> reentrant_lock_op are linked by rewriting the plan instead of +// compiling OpGraph to TaskGraph. +// ReentrantLockBackEdge holds the key information of a back edge +struct ReentrantLockBackEdge { + std::string reentrant_lock_op_name; // back edge destination. + LogicalBlobId critical_section_sink_lbi; // back edge source. 
+}; + std::string cluster_thrd_ids_key(const std::string& plan_name) { return plan_name + "_cluster_thrd_ids"; } @@ -84,6 +99,15 @@ std::string block7chunk_key(const std::string& plan_name, int64_t machine_id) { return plan_name + "_" + std::to_string(machine_id) + "_block7chunk"; } +std::shared_ptr CreateSinkTickOpConf(const std::string& in_op_name) { + auto tick_op = std::make_shared(); + tick_op->set_name("System-Main-CallbackNotifier_TmpSinkTick_" + NewUniqueId()); + auto* tick_conf = tick_op->mutable_sink_tick_conf(); + tick_conf->add_tick(in_op_name + "/out"); + tick_conf->set_out("out"); + return tick_op; +} + void PushPlan(const std::string& plan_name, const Plan& plan) { HashMap> machine_id2thrd_id_set; HashMap, std::vector> mchn_thrd_id2task_protos; @@ -392,7 +416,7 @@ void UpdateSoleObnRegstDescId(TaskProto* task) { // return: // op_A --> op_identity_tick --> op_C --> op_D --> op_E --> op_sink_tick --> op_B // / -// op_src_tick ---/ +// op_src_tick -->/ // // note: after this function called, op_src_tick is illegal and need to be deleted from plan void LinkTickTaskProto(TaskProto* identity_tick, TaskProto* src_tick, TaskProto* sink_tick) { @@ -430,7 +454,7 @@ void FixRegstHostMemCase(TaskProto* task_proto, } void LinkMainPlan(Plan* plan, const Plan& main_plan, - const std::vector>& identity_tick_op_names) { + const std::vector>& identity_tick_op_names) { std::function IsInterfaceTickTockTask; { auto task_ids = std::make_shared>(); @@ -582,40 +606,27 @@ void CheckNonDistributeOptimizerAvailable(const std::vector } } -void MakeMainJob(Job* main_job, std::vector>* identity_tick_op_names, - LogicalBlobId* critical_section_sink_lbi) { - JobBuilder job_builder(main_job); - CHECK(Global::Get()->IsThisMachineMaster()); - std::vector op_confs; - OperatorConf wait_and_send_ids_op_conf; - { - wait_and_send_ids_op_conf.set_name(std::string("System-Main-WaitAndSendIds_") + NewUniqueId()); - auto* wait_and_send_ids_conf = 
wait_and_send_ids_op_conf.mutable_wait_and_send_ids_conf(); - wait_and_send_ids_conf->set_out("out"); - wait_and_send_ids_conf->set_wait_buffer_name(kBufferNameGlobalWaitJobId); - wait_and_send_ids_conf->set_data_type(DataType::kInt32); - auto* id_list = wait_and_send_ids_conf->mutable_id_list(); - FOR_RANGE(int32_t, i, 0, Global::Get()->size()) { id_list->Add(); } - HashSet unique_check; - for (const auto& pair : *Global::Get()) { - int64_t job_id = pair.second; - CHECK(unique_check.insert(job_id).second); - const auto& cs_idx = Global::Get()->CriticalSectionIds4JobId(job_id); - *id_list->Mutable(job_id)->mutable_value() = {cs_idx.begin(), cs_idx.end()}; - } - } - op_confs.push_back(wait_and_send_ids_op_conf); +Maybe MakeMainJobComponent( + const std::string& wait_and_send_ids_lbn, const Range& machine_id_range, + JobBuilder* job_builder, std::vector>* identity_tick_op_names, + std::vector>* cb_sink_tick_op_names) { + ParallelConf parallel_conf; + parallel_conf.set_device_tag("cpu"); + parallel_conf.add_device_name(std::to_string(machine_id_range.begin()) + ":0"); + auto lock_back_edge = std::make_shared(); OperatorConf reentrant_lock_op_conf; { - reentrant_lock_op_conf.set_name(std::string("System-Main-ReentrantLock_") + NewUniqueId()); + lock_back_edge->reentrant_lock_op_name = + std::string("System-Main-ReentrantLock_") + NewUniqueId(); + reentrant_lock_op_conf.set_name(lock_back_edge->reentrant_lock_op_name); auto* reentrant_lock_conf = reentrant_lock_op_conf.mutable_reentrant_lock_conf(); - reentrant_lock_conf->set_start(wait_and_send_ids_op_conf.name() + "/out"); + reentrant_lock_conf->set_start(wait_and_send_ids_lbn); // ibn "end" is set after plan generated because we don't like cycle in job reentrant_lock_conf->set_out("out"); Global::Get()->DumpCriticalSectionId2IntersectinIds( reentrant_lock_conf->mutable_lock_id2intersecting_lock_ids()); + JUST(job_builder->AddOp(parallel_conf, reentrant_lock_op_conf)); } - op_confs.push_back(reentrant_lock_op_conf); 
// critical section case op conf OperatorConf cs_case_op_conf; { @@ -625,47 +636,60 @@ void MakeMainJob(Job* main_job, std::vector>* iden FOR_RANGE(int64_t, i, 0, Global::Get()->CriticalSectionNum()) { cs_case_conf->add_out(GenRepeatedBn("out", i)); } + JUST(job_builder->AddOp(parallel_conf, cs_case_op_conf)); } - op_confs.push_back(cs_case_op_conf); + const int64_t num_critial_sections = Global::Get()->CriticalSectionNum(); std::vector snk_tick_op_names; - int64_t num_critial_sections = Global::Get()->CriticalSectionNum(); - identity_tick_op_names->resize(num_critial_sections); - int64_t num_machines = Global::Get()->TotalMachineNum(); FOR_RANGE(int64_t, i, 0, num_critial_sections) { // source tick OperatorConf src_tick_op_conf; { std::string name_prefix = "System-Main-SourceTick_CriticalSection_"; - src_tick_op_conf.set_name(name_prefix + std::to_string(i)); + src_tick_op_conf.set_name(name_prefix + std::to_string(i) + "_" + NewUniqueId()); auto* src_tick_conf = src_tick_op_conf.mutable_tick_conf(); src_tick_conf->add_tick(cs_case_op_conf.name() + "/" + GenRepeatedBn("out", i)); src_tick_conf->set_out("out"); - op_confs.push_back(src_tick_op_conf); + JUST(job_builder->AddOp(parallel_conf, src_tick_op_conf)); } // identity tick - auto* cur_id_tick_op_names = &identity_tick_op_names->at(i); - for (int64_t machine_id = 0; machine_id < num_machines; ++machine_id) { + auto* cur_cb_sink_tick_op_names = &cb_sink_tick_op_names->at(i); + for (int64_t machine_id = machine_id_range.begin(); machine_id < machine_id_range.end(); + ++machine_id) { OperatorConf identity_tick_op_conf; - std::string name_prefix = "System-Main-Tick_CriticalSection_"; - identity_tick_op_conf.set_name(name_prefix + std::to_string(i) + "_" - + std::to_string(machine_id)); - auto* identity_tick_conf = identity_tick_op_conf.mutable_tick_conf(); - identity_tick_conf->add_tick(src_tick_op_conf.name() + "/out"); - identity_tick_conf->set_out("out"); - op_confs.push_back(identity_tick_op_conf); - 
CHECK(cur_id_tick_op_names->emplace(machine_id, identity_tick_op_conf.name()).second); + { + std::string name_prefix = "System-Main-Tick_CriticalSection_"; + identity_tick_op_conf.set_name(name_prefix + std::to_string(i) + "_" + NewUniqueId()); + auto* identity_tick_conf = identity_tick_op_conf.mutable_tick_conf(); + identity_tick_conf->add_tick(src_tick_op_conf.name() + "/out"); + identity_tick_conf->set_out("out"); + JUST(job_builder->AddOp(parallel_conf, identity_tick_op_conf)); + auto* cur_id_tick_op_names = &identity_tick_op_names->at(i); + CHECK_OR_RETURN( + cur_id_tick_op_names->emplace(machine_id, identity_tick_op_conf.name()).second); + } + { + OperatorConf cb_sink_tick_op_conf; + std::string name_prefix = "System-Main-CallbackSinkTick_"; + cb_sink_tick_op_conf.set_name(name_prefix + std::to_string(i) + NewUniqueId()); + auto* cb_sink_tick_conf = cb_sink_tick_op_conf.mutable_sink_tick_conf(); + cb_sink_tick_conf->add_tick(identity_tick_op_conf.name() + "/out"); + cb_sink_tick_conf->set_out("out"); + JUST(job_builder->AddOp(parallel_conf, cb_sink_tick_op_conf)); + CHECK_OR_RETURN( + cur_cb_sink_tick_op_names->emplace(machine_id, cb_sink_tick_op_conf.name()).second); + } } // sink tick - OperatorConf snk_tick_op_conf; { + OperatorConf snk_tick_op_conf; std::string name_prefix = "System-Main-SinkTick_CriticalSection_"; - snk_tick_op_conf.set_name(name_prefix + std::to_string(i)); + snk_tick_op_conf.set_name(name_prefix + std::to_string(i) + NewUniqueId()); auto* snk_tick_conf = snk_tick_op_conf.mutable_sink_tick_conf(); - for (const auto& pair : *cur_id_tick_op_names) { + for (const auto& pair : *cur_cb_sink_tick_op_names) { snk_tick_conf->add_tick(pair.second + "/out"); } snk_tick_conf->set_out("out"); - op_confs.push_back(snk_tick_op_conf); + JUST(job_builder->AddOp(parallel_conf, snk_tick_op_conf)); snk_tick_op_names.push_back(snk_tick_op_conf.name()); } } @@ -679,32 +703,89 @@ void MakeMainJob(Job* main_job, std::vector>* iden } 
cs_esac_conf->set_out("out"); cs_esac_conf->set_data_type(DataType::kInt32); + JUST(job_builder->AddOp(parallel_conf, cs_esac_op_conf)); + } + lock_back_edge->critical_section_sink_lbi.set_op_name(cs_esac_op_conf.name()); + lock_back_edge->critical_section_sink_lbi.set_blob_name("out"); + return lock_back_edge; +} + +Maybe MakeCallbackNotifierSinkTick( + const Range& machine_id_range, + const std::vector>& cb_sink_tick_op_names, + JobBuilder* job_builder, const std::function& DoEachSinkTickLbn) { + ParallelConf parallel_conf; + parallel_conf.set_device_tag("cpu"); + parallel_conf.add_device_name("0:0"); + for (int64_t total_job_cs_id : + Global::Get()->job_id2total_job_critical_section_id()) { + OperatorConf snk_tick_op_conf; + { + std::string name_prefix = "System-Main-CallbackNotifier_CriticalSection_"; + snk_tick_op_conf.set_name(name_prefix + std::to_string(total_job_cs_id)); + auto* snk_tick_conf = snk_tick_op_conf.mutable_sink_tick_conf(); + for (int64_t machine_id = machine_id_range.begin(); machine_id < machine_id_range.end(); + ++machine_id) { + const auto& cb_sink_tick_op_name = cb_sink_tick_op_names.at(total_job_cs_id).at(machine_id); + snk_tick_conf->add_tick(cb_sink_tick_op_name + "/out"); + } + snk_tick_conf->set_out("out"); + JUST(job_builder->AddOp(parallel_conf, snk_tick_op_conf)); + } + DoEachSinkTickLbn(snk_tick_op_conf.name() + "/out"); } - op_confs.push_back(cs_esac_op_conf); + return Maybe::Ok(); +} + +Maybe MakeMainJob(Job* main_job, + std::vector>* identity_tick_op_names, + std::vector* lock_back_edges) { + JobBuilder job_builder(main_job); + CHECK_OR_RETURN(Global::Get()->IsThisMachineMaster()); + ParallelConf parallel_conf; + parallel_conf.set_device_tag("cpu"); + parallel_conf.add_device_name("0:0"); + OperatorConf wait_and_send_ids_op_conf; + { + wait_and_send_ids_op_conf.set_name(std::string("System-Main-WaitAndSendIds_") + NewUniqueId()); + auto* wait_and_send_ids_conf = wait_and_send_ids_op_conf.mutable_wait_and_send_ids_conf(); + 
wait_and_send_ids_conf->set_out("out"); + wait_and_send_ids_conf->set_wait_buffer_name(kBufferNameGlobalWaitJobId); + wait_and_send_ids_conf->set_data_type(DataType::kInt32); + auto* id_list = wait_and_send_ids_conf->mutable_id_list(); + FOR_RANGE(int32_t, i, 0, Global::Get()->size()) { id_list->Add(); } + HashSet unique_check; + for (const auto& pair : *Global::Get()) { + int64_t job_id = pair.second; + CHECK_OR_RETURN(unique_check.insert(job_id).second); + const auto& cs_idx = Global::Get()->CriticalSectionIds4JobId(job_id); + *id_list->Mutable(job_id)->mutable_value() = {cs_idx.begin(), cs_idx.end()}; + } + JUST(job_builder.AddOp(parallel_conf, wait_and_send_ids_op_conf)); + } + const int64_t num_critial_sections = Global::Get()->CriticalSectionNum(); + std::vector> cb_sink_tick_op_names; + identity_tick_op_names->resize(num_critial_sections); + cb_sink_tick_op_names.resize(num_critial_sections); + const int64_t num_machines = Global::Get()->TotalMachineNum(); + const Range machine_id_range(0, num_machines); + JUST(machine_id_range.ForEachSubRange(1, [&](const Range& sub_range) -> Maybe { + const auto& in_lbn = wait_and_send_ids_op_conf.name() + "/out"; + lock_back_edges->push_back(*JUST(MakeMainJobComponent( + in_lbn, sub_range, &job_builder, identity_tick_op_names, &cb_sink_tick_op_names))); + return Maybe::Ok(); + })); OperatorConf callback_notify_esac_op_conf; { callback_notify_esac_op_conf.set_name(std::string("System-Main-Esac_") + NewUniqueId()); auto* callback_notify_esac_conf = callback_notify_esac_op_conf.mutable_esac_conf(); - for (int64_t total_job_cs_id : - Global::Get()->job_id2total_job_critical_section_id()) { - OperatorConf snk_tick_op_conf; - { - std::string name_prefix = "System-Main-CallbackNotifier_CriticalSection_"; - snk_tick_op_conf.set_name(name_prefix + std::to_string(total_job_cs_id)); - auto* snk_tick_conf = snk_tick_op_conf.mutable_sink_tick_conf(); - for (int64_t machine_id = 0; machine_id < num_machines; ++machine_id) { - const 
auto& id_tick_op_name = identity_tick_op_names->at(total_job_cs_id).at(machine_id); - snk_tick_conf->add_tick(id_tick_op_name + "/out"); - } - snk_tick_conf->set_out("out"); - op_confs.push_back(snk_tick_op_conf); - } - callback_notify_esac_conf->add_in(snk_tick_op_conf.name() + "/out"); - } + JUST(MakeCallbackNotifierSinkTick( + machine_id_range, cb_sink_tick_op_names, &job_builder, + [&](const std::string& lbn) { callback_notify_esac_conf->add_in(lbn); })); callback_notify_esac_conf->set_out("out"); callback_notify_esac_conf->set_data_type(DataType::kInt32); + JUST(job_builder.AddOp(parallel_conf, callback_notify_esac_op_conf)); } - op_confs.push_back(callback_notify_esac_op_conf); OperatorConf callback_notify_op_conf; { callback_notify_op_conf.set_name(std::string("System-Main-CallbackNotify_") + NewUniqueId()); @@ -717,42 +798,37 @@ void MakeMainJob(Job* main_job, std::vector>* iden const auto& buffer_name = GetCallbackNotifierBufferName(pair.first); *buffer_names->Mutable(job_id) = buffer_name; } + JUST(job_builder.AddOp(parallel_conf, callback_notify_op_conf)); } - op_confs.push_back(callback_notify_op_conf); - - critical_section_sink_lbi->set_op_name(cs_esac_op_conf.name()); - critical_section_sink_lbi->set_blob_name("out"); - ParallelConf parallel_conf; - parallel_conf.set_device_tag("cpu"); - parallel_conf.add_device_name("0:0"); - job_builder.AddOps(parallel_conf, op_confs); auto* job_conf = main_job->mutable_job_conf(); job_conf->set_job_name("MainJob-unamed"); job_conf->mutable_predict_conf(); job_conf->set_default_data_type(DataType::kInt32); + return Maybe::Ok(); } -void ConnectCriticalSectionEndToReentrantLockEnd(Plan* main_plan, - const LogicalBlobId& critical_section_sink_lbi) { +Maybe ConnectCriticalSectionEndToReentrantLockEnd( + Plan* main_plan, const ReentrantLockBackEdge& lock_back_edge) { TaskProto* reentrant_lock_task = nullptr; TaskProto* cs_sink_task = nullptr; FOR_RANGE(int64_t, i, 0, main_plan->task_size()) { auto* task = 
main_plan->mutable_task(i); - CHECK_EQ(task->exec_sequence().exec_node_size(), 1); - if (task->task_type() == TaskType::kReentrantLock) { - CHECK_ISNULL(reentrant_lock_task); + CHECK_EQ_OR_RETURN(task->exec_sequence().exec_node_size(), 1); + const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf(); + const auto& op_name = kernel_conf.op_attribute().op_conf().name(); + if (op_name == lock_back_edge.reentrant_lock_op_name) { + CHECK_ISNULL_OR_RETURN(reentrant_lock_task); reentrant_lock_task = task; + } else if (op_name == lock_back_edge.critical_section_sink_lbi.op_name()) { + CHECK_ISNULL_OR_RETURN(cs_sink_task); + cs_sink_task = task; } else { - const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf(); - if (critical_section_sink_lbi.op_name() == kernel_conf.op_attribute().op_conf().name()) { - CHECK_ISNULL(cs_sink_task); - cs_sink_task = task; - } + // do nothing } } - CHECK_NOTNULL(reentrant_lock_task); - CHECK_NOTNULL(cs_sink_task); + CHECK_NOTNULL_OR_RETURN(reentrant_lock_task); + CHECK_NOTNULL_OR_RETURN(cs_sink_task); RegstDescProto* cs_end_regst = PlanUtil::GetSoleProducedDataRegst(cs_sink_task); cs_end_regst->add_consumer_task_id(reentrant_lock_task->task_id()); reentrant_lock_task->mutable_consumed_regst_desc_id()->at("in").add_regst_desc_id( @@ -764,20 +840,23 @@ void ConnectCriticalSectionEndToReentrantLockEnd(Plan* main_plan, auto* op_attribute = reentrant_exec_node->mutable_kernel_conf()->mutable_op_attribute(); op_attribute->add_input_bns("end"); (*op_attribute->mutable_arg_signature()->mutable_bn_in_op2lbi())["end"] = - critical_section_sink_lbi; + lock_back_edge.critical_section_sink_lbi; auto* reentrant_lock_conf = op_attribute->mutable_op_conf()->mutable_reentrant_lock_conf(); - reentrant_lock_conf->set_end(GenLogicalBlobName(critical_section_sink_lbi)); + reentrant_lock_conf->set_end(GenLogicalBlobName(lock_back_edge.critical_section_sink_lbi)); + return Maybe::Ok(); } -Maybe CompileMainJob(Job* main_job, const 
LogicalBlobId& critical_section_sink_lbi, +Maybe CompileMainJob(Job* main_job, const std::vector& lock_back_edges, int64_t job_id, Plan* main_plan) { - CHECK(Global::Get()->IsThisMachineMaster()); + CHECK_OR_RETURN(Global::Get()->IsThisMachineMaster()); { auto scope = std::make_unique(main_job->job_conf(), job_id); JUST(CompileCurJobOnMaster(main_job, main_plan, false)); } - ConnectCriticalSectionEndToReentrantLockEnd(main_plan, critical_section_sink_lbi); + for (const auto& lock_back_edge : lock_back_edges) { + JUST(ConnectCriticalSectionEndToReentrantLockEnd(main_plan, lock_back_edge)); + } return Maybe::Ok(); } @@ -1011,13 +1090,13 @@ Maybe CompileAndMergePlanOnMaster(const PbRpf& conf_jobs, Plan* plan) PlanUtil::SetForceInplaceMemBlock(plan); FinishGlobalCriticalSectionDesc(*plan, jobs.size()); Plan main_plan; - std::vector> identity_tick_op_names; + std::vector> identity_tick_op_names; { Job main_job; - LogicalBlobId critical_section_sink_lbi; - MakeMainJob(&main_job, &identity_tick_op_names, &critical_section_sink_lbi); + std::vector lock_back_edges; + JUST(MakeMainJob(&main_job, &identity_tick_op_names, &lock_back_edges)); AddJobName2JobId(main_job.job_conf().job_name(), jobs.size()); - JUST(CompileMainJob(&main_job, critical_section_sink_lbi, sub_plans.size(), &main_plan)); + JUST(CompileMainJob(&main_job, lock_back_edges, sub_plans.size(), &main_plan)); } LinkMainPlan(plan, main_plan, identity_tick_op_names); PlanUtil::CleanUselessMemBlockAndCheckValid(plan); diff --git a/oneflow/core/job_rewriter/autotick.cpp b/oneflow/core/job_rewriter/autotick.cpp index d6083c3e24..f87fc16658 100644 --- a/oneflow/core/job_rewriter/autotick.cpp +++ b/oneflow/core/job_rewriter/autotick.cpp @@ -94,9 +94,6 @@ Maybe CreateDstSubsetTickAndSinkTicks(CriticalSection* critical_section, JUST(BuildDstSubsetTickOpAndParallelConf(tick_lbis, &dst_subset_tick, job_builder)); int64_t num_machines = Global::Get()->TotalMachineNum(); auto* map = 
critical_section->mutable_machine_id2sink_tick_op_name(); - ParallelConf cpu0_parallel_conf; - cpu0_parallel_conf.set_device_tag("cpu"); - cpu0_parallel_conf.add_device_name("0:0"); for (int64_t machine_id = 0; machine_id < num_machines; ++machine_id) { ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); @@ -116,7 +113,7 @@ Maybe CreateDstSubsetTickAndSinkTicks(CriticalSection* critical_section, auto* sink_tick_conf = sink_tick_op.mutable_sink_tick_conf(); sink_tick_conf->add_tick(tick_op.name() + "/out"); sink_tick_conf->set_out("out"); - JUST(job_builder->AddOp(cpu0_parallel_conf, sink_tick_op)); + JUST(job_builder->AddOp(parallel_conf, sink_tick_op)); } (*map)[machine_id] = sink_tick_op.name(); } @@ -142,15 +139,15 @@ Maybe CreateSourceTicksAndSrcSubsetTick(CriticalSection* critical_section, JobBuilder* job_builder) { int64_t num_machines = Global::Get()->TotalMachineNum(); auto* map = critical_section->mutable_machine_id2source_tick_op_name(); - ParallelConf cpu0_parallel_conf; - cpu0_parallel_conf.set_device_tag("cpu"); - cpu0_parallel_conf.add_device_name("0:0"); for (int64_t machine_id = 0; machine_id < num_machines; ++machine_id) { + ParallelConf parallel_conf; + parallel_conf.set_device_tag("cpu"); + parallel_conf.add_device_name(std::to_string(machine_id) + ":0"); OperatorConf src_tick_op; { src_tick_op.set_name("System-AutoTick-SourceTick_" + NewUniqueId()); src_tick_op.mutable_source_tick_conf()->set_out("out"); - JUST(job_builder->AddOp(cpu0_parallel_conf, src_tick_op)); + JUST(job_builder->AddOp(parallel_conf, src_tick_op)); } (*map)[machine_id] = src_tick_op.name(); OperatorConf tick_op; @@ -158,9 +155,6 @@ Maybe CreateSourceTicksAndSrcSubsetTick(CriticalSection* critical_section, tick_op.set_name("System-AutoTick-Tick_" + NewUniqueId()); tick_op.mutable_tick_conf()->add_tick(src_tick_op.name() + "/out"); tick_op.mutable_tick_conf()->set_out("out"); - ParallelConf parallel_conf; - parallel_conf.set_device_tag("cpu"); - 
parallel_conf.add_device_name(std::to_string(machine_id) + ":0"); JUST(job_builder->AddOp(parallel_conf, tick_op)); } src_subset_tick_op->mutable_src_subset_tick_conf()->add_in(tick_op.name() + "/out"); -- GitLab