未验证 提交 b21c190b 编写于 作者: J Jinhui Yuan 提交者: GitHub

Partition plan (#1138)

* fix typo

* rm useless IsThisMachineMaster

* refine the var name of naive_plan, mem_shared_plan, improved_plan

* refactor PushPlan and PullPlan

* let master node broadcast subplans instead of the whole plan

* remove useless code

* rm useless code

* use total_mbn_num_key
上级 be213666
......@@ -283,7 +283,7 @@ void CollectTailRegstConsumerTaskIds(const std::vector<const RegstDescProto*>& s
void CollectSinkTaskIds(const HashSet<int64_t>& task_ids,
const std::function<bool(int64_t, int64_t)>& IsReachable,
std::list<int64_t>* sink_task_ids) {
auto IsReachableToAnyOherTask = [&](int64_t src_task_id) -> bool {
auto IsReachableToAnyOtherTask = [&](int64_t src_task_id) -> bool {
for (int64_t dst_task_id : task_ids) {
if (src_task_id == dst_task_id) { continue; }
if (IsReachable(src_task_id, dst_task_id)) { return true; }
......@@ -292,7 +292,7 @@ void CollectSinkTaskIds(const HashSet<int64_t>& task_ids,
};
sink_task_ids->clear();
for (int64_t src_task_id : task_ids) {
if (!IsReachableToAnyOherTask(src_task_id)) { sink_task_ids->push_back(src_task_id); }
if (!IsReachableToAnyOtherTask(src_task_id)) { sink_task_ids->push_back(src_task_id); }
}
}
......
......@@ -6,6 +6,7 @@
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/machine_context.h"
#include "oneflow/core/job/profiler.h"
#include "oneflow/core/job/sub_plan.pb.h"
#include "oneflow/core/job/plan.pb.h"
#include "oneflow/core/job/runtime.h"
#include "oneflow/core/job/available_memory_desc.pb.h"
......@@ -76,6 +77,61 @@ void FixCpuDeviceNum() {
Global<JobDesc>::Get()->SetCpuDeviceNum(cpu_device_num);
}
// Builds the key under which the sub-plan of a single thread is stored:
// "<plan_name>_<machine_id>_<thrd_id>", with both ids rendered in decimal.
std::string sub_plan_key(const std::string& plan_name, int64_t machine_id, int64_t thrd_id) {
  std::string key(plan_name);
  key += "_";
  key += std::to_string(machine_id);
  key += "_";
  key += std::to_string(thrd_id);
  return key;
}
// Builds the key under which a plan's total_mbn_num value is stored:
// "<plan_name>_total_mbn_num".
std::string total_mbn_num_key(const std::string& plan_name) {
  std::string key(plan_name);
  key.append("_total_mbn_num");
  return key;
}
void PushPlan(const std::string& plan_name, const Plan& plan) {
HashMap<int64_t, std::set<int64_t>> machine_id2thrd_id_set;
HashMap<std::pair<int64_t, int64_t>, std::vector<TaskProto>> mchn_thrd_id2task_protos;
for (const auto& task : plan.task()) {
machine_id2thrd_id_set[task.machine_id()].insert(task.thrd_id());
mchn_thrd_id2task_protos[std::make_pair(task.machine_id(), task.thrd_id())].emplace_back(task);
}
HashMap<int64_t, ThrdIds> machine_id2thrd_ids;
for (const auto& pair : machine_id2thrd_id_set) {
CHECK(machine_id2thrd_ids.emplace(pair.first, ThrdIds()).second);
std::vector<int64_t> thrd_id_vec(pair.second.begin(), pair.second.end());
*(machine_id2thrd_ids.at(pair.first).mutable_thrd_id()) = StdVec2PbRf(thrd_id_vec);
}
ClusterThrdIds cluster_thrd_ids;
*(cluster_thrd_ids.mutable_machine_id2thrd_ids()) = HashMap2PbMap(machine_id2thrd_ids);
Global<CtrlClient>::Get()->PushKV(plan_name + "_cluster_thrd_ids", cluster_thrd_ids);
for (const auto& pair : mchn_thrd_id2task_protos) {
SubPlan sub_plan;
*(sub_plan.mutable_task()) = StdVec2PbRpf(pair.second);
Global<CtrlClient>::Get()->PushKV(sub_plan_key(plan_name, pair.first.first, pair.first.second),
sub_plan);
}
Global<CtrlClient>::Get()->PushKV(total_mbn_num_key(plan_name),
std::to_string(plan.total_mbn_num()));
}
void PullPlan(const std::string& plan_name, Plan* plan) {
ClusterThrdIds cluster_thrd_ids;
Global<CtrlClient>::Get()->PullKV(plan_name + "_cluster_thrd_ids", &cluster_thrd_ids);
PrintProtoToTextFile(cluster_thrd_ids, JoinPath(LogDir(), plan_name + "_cluster_thrd_ids"));
HashMap<int64_t, ThrdIds> machine_id2thrd_ids;
machine_id2thrd_ids = PbMap2HashMap(cluster_thrd_ids.machine_id2thrd_ids());
for (const auto& pair : machine_id2thrd_ids) {
int64_t machine_id = pair.first;
std::vector<int64_t> thrd_id_vec = PbRf2StdVec(pair.second.thrd_id());
for (auto thrd_id : thrd_id_vec) {
SubPlan sub_plan;
Global<CtrlClient>::Get()->PullKV(sub_plan_key(plan_name, machine_id, thrd_id), &sub_plan);
plan->mutable_task()->MergeFrom(sub_plan.task());
}
}
std::string total_mbn_num;
Global<CtrlClient>::Get()->PullKV(total_mbn_num_key(plan_name), &total_mbn_num);
plan->set_total_mbn_num(oneflow_cast<int64_t>(total_mbn_num));
}
} // namespace
class Oneflow final {
......@@ -101,50 +157,54 @@ Oneflow::Oneflow(const std::string& job_conf_filepath, const std::string& this_m
Global<IDMgr>::New();
// Compile
Plan naive_plan;
Plan plan;
Plan mem_shared_plan;
Plan improved_plan;
PushAvailableMemDescOfThisMachine();
AvailableMemDesc amd;
if (machine_ctx->IsThisMachineMaster()) {
naive_plan = Compiler().Compile();
amd = PullAvailableMemDesc();
plan = Improver().ImproveMemSharedIdOnly(amd, naive_plan);
Global<CtrlClient>::Get()->PushKV("mem_shared_plan", plan);
mem_shared_plan = Improver().ImproveMemSharedIdOnly(amd, naive_plan);
PushPlan("naive_plan", naive_plan);
PushPlan("mem_shared_plan", mem_shared_plan);
} else {
Global<CtrlClient>::Get()->PullKV("mem_shared_plan", &plan);
PullPlan("naive_plan", &naive_plan);
PullPlan("mem_shared_plan", &mem_shared_plan);
}
OF_BARRIER();
PrintProtoToTextFile(naive_plan, JoinPath(LogDir(), "naive_plan"));
PrintProtoToTextFile(plan, JoinPath(LogDir(), "mem_shared_plan"));
PrintProtoToTextFile(mem_shared_plan, JoinPath(LogDir(), "mem_shared_plan"));
// Experiment Runtime
{ Runtime experiment_run(plan, true); }
{ Runtime experiment_run(mem_shared_plan, true); }
// Improve
if (machine_ctx->IsThisMachineMaster()) {
PrintProtoToTextFile(amd, JoinPath(LogDir(), "available_mem_desc"));
CHECK_GT(amd.machine_amd_size(), 0);
plan = Improver().Improve(amd, naive_plan,
JoinPath(LogDir(), ActEventLogger::experiment_prefix_
+ ActEventLogger::act_event_bin_filename_));
Global<CtrlClient>::Get()->PushKV("improved_plan", plan);
improved_plan =
Improver().Improve(amd, naive_plan,
JoinPath(LogDir(), ActEventLogger::experiment_prefix_
+ ActEventLogger::act_event_bin_filename_));
PushPlan("improved_plan", improved_plan);
} else {
Global<CtrlClient>::Get()->PullKV("improved_plan", &plan);
PullPlan("improved_plan", &improved_plan);
}
OF_BARRIER();
PrintProtoToTextFile(plan, JoinPath(LogDir(), "improved_plan"));
PrintProtoToTextFile(improved_plan, JoinPath(LogDir(), "improved_plan"));
Global<CtrlClient>::Get()->Clear();
OF_BARRIER();
// Runtime
{ Runtime run(plan, false); }
{ Runtime run(improved_plan, false); }
if (machine_ctx->IsThisMachineMaster()) {
if (Global<JobDesc>::Get()->collect_act_event()) {
Global<Profiler>::Get()->Profile(plan,
Global<Profiler>::Get()->Profile(improved_plan,
JoinPath(LogDir(), ActEventLogger::act_event_bin_filename_));
}
}
// Delete All Global
Global<CtrlClient>::Delete();
ctrl_server_.reset();
if (machine_ctx->IsThisMachineMaster()) { Global<Profiler>::Delete(); }
Global<Profiler>::Delete();
Global<MachineCtx>::Delete();
Global<IDMgr>::Delete();
Global<JobDesc>::Delete();
......
syntax = "proto2";
package oneflow;
import "oneflow/core/job/task.proto";
// A list of thread ids.
message ThrdIds {
// The thread ids, one entry per thread.
repeated int64 thrd_id = 1;
}
// Per-machine thread ids for a cluster.
message ClusterThrdIds {
// Key: machine id. Value: the thread ids recorded for that machine.
map<int64, ThrdIds> machine_id2thrd_ids = 1;
}
// A slice of a plan: a list of tasks (TaskProto comes from task.proto).
message SubPlan {
// The tasks belonging to this slice.
repeated TaskProto task = 1;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册