From b21c190b44fefd52bc2ab4ccf09dd5c3fd112eb9 Mon Sep 17 00:00:00 2001 From: Jinhui Yuan Date: Mon, 20 Aug 2018 14:55:59 +0800 Subject: [PATCH] Partition plan (#1138) * fix typo * rm useless IsThisMachineMaster * refine the var name of naive_plan, mem_shared_plan, improved_plan * refactor PushPlan and PullPlan * let master node broadcast subplans instead the whole plan * remove useless code * rm useless code * use total_mbn_name_key --- oneflow/core/job/improver.cpp | 4 +- oneflow/core/job/oneflow.cpp | 90 +++++++++++++++++++++++++++------ oneflow/core/job/sub_plan.proto | 16 ++++++ 3 files changed, 93 insertions(+), 17 deletions(-) create mode 100644 oneflow/core/job/sub_plan.proto diff --git a/oneflow/core/job/improver.cpp b/oneflow/core/job/improver.cpp index d9dbbf9f5d..6c3e533644 100644 --- a/oneflow/core/job/improver.cpp +++ b/oneflow/core/job/improver.cpp @@ -283,7 +283,7 @@ void CollectTailRegstConsumerTaskIds(const std::vector& s void CollectSinkTaskIds(const HashSet& task_ids, const std::function& IsReachable, std::list* sink_task_ids) { - auto IsReachableToAnyOherTask = [&](int64_t src_task_id) -> bool { + auto IsReachableToAnyOtherTask = [&](int64_t src_task_id) -> bool { for (int64_t dst_task_id : task_ids) { if (src_task_id == dst_task_id) { continue; } if (IsReachable(src_task_id, dst_task_id)) { return true; } @@ -292,7 +292,7 @@ void CollectSinkTaskIds(const HashSet& task_ids, }; sink_task_ids->clear(); for (int64_t src_task_id : task_ids) { - if (!IsReachableToAnyOherTask(src_task_id)) { sink_task_ids->push_back(src_task_id); } + if (!IsReachableToAnyOtherTask(src_task_id)) { sink_task_ids->push_back(src_task_id); } } } diff --git a/oneflow/core/job/oneflow.cpp b/oneflow/core/job/oneflow.cpp index 4ec284bb98..b7567326c7 100644 --- a/oneflow/core/job/oneflow.cpp +++ b/oneflow/core/job/oneflow.cpp @@ -6,6 +6,7 @@ #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/machine_context.h" #include "oneflow/core/job/profiler.h" +#include "oneflow/core/job/sub_plan.pb.h" #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/job/runtime.h" #include "oneflow/core/job/available_memory_desc.pb.h" @@ -76,6 +77,61 @@ void FixCpuDeviceNum() { Global::Get()->SetCpuDeviceNum(cpu_device_num); } +std::string sub_plan_key(const std::string& plan_name, int64_t machine_id, int64_t thrd_id) { + return plan_name + "_" + std::to_string(machine_id) + "_" + std::to_string(thrd_id); +} + +std::string total_mbn_num_key(const std::string& plan_name) { return plan_name + "_total_mbn_num"; } + +void PushPlan(const std::string& plan_name, const Plan& plan) { + HashMap> machine_id2thrd_id_set; + HashMap, std::vector> mchn_thrd_id2task_protos; + for (const auto& task : plan.task()) { + machine_id2thrd_id_set[task.machine_id()].insert(task.thrd_id()); + mchn_thrd_id2task_protos[std::make_pair(task.machine_id(), task.thrd_id())].emplace_back(task); + } + + HashMap machine_id2thrd_ids; + for (const auto& pair : machine_id2thrd_id_set) { + CHECK(machine_id2thrd_ids.emplace(pair.first, ThrdIds()).second); + std::vector thrd_id_vec(pair.second.begin(), pair.second.end()); + *(machine_id2thrd_ids.at(pair.first).mutable_thrd_id()) = StdVec2PbRf(thrd_id_vec); + } + + ClusterThrdIds cluster_thrd_ids; + *(cluster_thrd_ids.mutable_machine_id2thrd_ids()) = HashMap2PbMap(machine_id2thrd_ids); + Global::Get()->PushKV(plan_name + "_cluster_thrd_ids", cluster_thrd_ids); + + for (const auto& pair : mchn_thrd_id2task_protos) { + SubPlan sub_plan; + *(sub_plan.mutable_task()) = StdVec2PbRpf(pair.second); + Global::Get()->PushKV(sub_plan_key(plan_name, pair.first.first, pair.first.second), + sub_plan); + } + Global::Get()->PushKV(total_mbn_num_key(plan_name), + std::to_string(plan.total_mbn_num())); +} + +void PullPlan(const std::string& plan_name, Plan* plan) { + ClusterThrdIds cluster_thrd_ids; + Global::Get()->PullKV(plan_name + "_cluster_thrd_ids", &cluster_thrd_ids); + PrintProtoToTextFile(cluster_thrd_ids, JoinPath(LogDir(), plan_name + "_cluster_thrd_ids")); + HashMap machine_id2thrd_ids; + machine_id2thrd_ids = PbMap2HashMap(cluster_thrd_ids.machine_id2thrd_ids()); + for (const auto& pair : machine_id2thrd_ids) { + int64_t machine_id = pair.first; + std::vector thrd_id_vec = PbRf2StdVec(pair.second.thrd_id()); + for (auto thrd_id : thrd_id_vec) { + SubPlan sub_plan; + Global::Get()->PullKV(sub_plan_key(plan_name, machine_id, thrd_id), &sub_plan); + plan->mutable_task()->MergeFrom(sub_plan.task()); + } + } + std::string total_mbn_num; + Global::Get()->PullKV(total_mbn_num_key(plan_name), &total_mbn_num); + plan->set_total_mbn_num(oneflow_cast(total_mbn_num)); +} + } // namespace class Oneflow final { @@ -101,50 +157,54 @@ Oneflow::Oneflow(const std::string& job_conf_filepath, const std::string& this_m Global::New(); // Compile Plan naive_plan; - Plan plan; + Plan mem_shared_plan; + Plan improved_plan; PushAvailableMemDescOfThisMachine(); AvailableMemDesc amd; if (machine_ctx->IsThisMachineMaster()) { naive_plan = Compiler().Compile(); amd = PullAvailableMemDesc(); - plan = Improver().ImproveMemSharedIdOnly(amd, naive_plan); - Global::Get()->PushKV("mem_shared_plan", plan); + mem_shared_plan = Improver().ImproveMemSharedIdOnly(amd, naive_plan); + PushPlan("naive_plan", naive_plan); + PushPlan("mem_shared_plan", mem_shared_plan); } else { - Global::Get()->PullKV("mem_shared_plan", &plan); + PullPlan("naive_plan", &naive_plan); + PullPlan("mem_shared_plan", &mem_shared_plan); } OF_BARRIER(); PrintProtoToTextFile(naive_plan, JoinPath(LogDir(), "naive_plan")); - PrintProtoToTextFile(plan, JoinPath(LogDir(), "mem_shared_plan")); + PrintProtoToTextFile(mem_shared_plan, JoinPath(LogDir(), "mem_shared_plan")); // Experiment Runtime - { Runtime experiment_run(plan, true); } + { Runtime experiment_run(mem_shared_plan, true); } // Improve if (machine_ctx->IsThisMachineMaster()) { PrintProtoToTextFile(amd, JoinPath(LogDir(), "available_mem_desc")); CHECK_GT(amd.machine_amd_size(), 0); - plan = Improver().Improve(amd, naive_plan, - JoinPath(LogDir(), ActEventLogger::experiment_prefix_ - + ActEventLogger::act_event_bin_filename_)); - Global::Get()->PushKV("improved_plan", plan); + improved_plan = + Improver().Improve(amd, naive_plan, + JoinPath(LogDir(), ActEventLogger::experiment_prefix_ + + ActEventLogger::act_event_bin_filename_)); + PushPlan("improved_plan", improved_plan); } else { - Global::Get()->PullKV("improved_plan", &plan); + PullPlan("improved_plan", &improved_plan); } OF_BARRIER(); - PrintProtoToTextFile(plan, JoinPath(LogDir(), "improved_plan")); + PrintProtoToTextFile(improved_plan, JoinPath(LogDir(), "improved_plan")); Global::Get()->Clear(); OF_BARRIER(); // Runtime - { Runtime run(plan, false); } + { Runtime run(improved_plan, false); } if (machine_ctx->IsThisMachineMaster()) { if (Global::Get()->collect_act_event()) { - Global::Get()->Profile(plan, + Global::Get()->Profile(improved_plan, JoinPath(LogDir(), ActEventLogger::act_event_bin_filename_)); } } // Delete All Global Global::Delete(); ctrl_server_.reset(); - if (machine_ctx->IsThisMachineMaster()) { Global::Delete(); } + Global::Delete(); Global::Delete(); Global::Delete(); Global::Delete(); diff --git a/oneflow/core/job/sub_plan.proto b/oneflow/core/job/sub_plan.proto new file mode 100644 index 0000000000..e9c9d2e546 --- /dev/null +++ b/oneflow/core/job/sub_plan.proto @@ -0,0 +1,16 @@ +syntax = "proto2"; +package oneflow; + +import "oneflow/core/job/task.proto"; + +message ThrdIds { + repeated int64 thrd_id = 1; +} + +message ClusterThrdIds { + map machine_id2thrd_ids = 1; +} + +message SubPlan { + repeated TaskProto task = 1; +} -- GitLab