提交 0fc062e7 编写于 作者: W willzhang4a58

scheduler


Former-commit-id: c989a9cc
上级 9cebdb9a
# main cpp
list(APPEND of_main_cc ${PROJECT_SOURCE_DIR}/oneflow/core/job/compiler.cpp)
list(APPEND of_main_cc ${PROJECT_SOURCE_DIR}/oneflow/core/job/runtime.cpp)
list(APPEND of_main_cc ${PROJECT_SOURCE_DIR}/oneflow/core/job/scheduler.cpp)
# source_group
if(WIN32)
......
......@@ -4,10 +4,6 @@ rm -rf ./predict_log ./core.*
mkdir predict_log
GLOG_logtostderr=0 GLOG_log_dir=./predict_log ./compiler \
-job_conf_filepath="./predict_job.prototxt" \
-plan_filepath="./predict_plan" \
GLOG_logtostderr=0 GLOG_log_dir=./predict_log GLOG_logbuflevel=-1 GLOG_v=0 ./runtime \
-plan_filepath="./predict_plan" \
GLOG_logtostderr=0 GLOG_log_dir=./predict_log GLOG_logbuflevel=-1 GLOG_v=0 ./scheduler \
-job_conf="./predict_job.prototxt" \
-this_machine_name="centos-0"
machine {
addr: ""
port: ""
data_port: 9000
ctrl_port: 9001
name: "centos-0"
}
......
../../build/scheduler
\ No newline at end of file
......@@ -4,10 +4,6 @@ rm -rf ./train_log ./core.* ./snapshots
mkdir train_log
GLOG_logtostderr=0 GLOG_log_dir=./train_log ./compiler \
-job_conf_filepath="./train_job.prototxt" \
-plan_filepath="./train_plan" \
GLOG_logtostderr=0 GLOG_log_dir=./train_log GLOG_logbuflevel=-1 GLOG_v=0 ./runtime \
-plan_filepath="./train_plan" \
GLOG_logtostderr=0 GLOG_log_dir=./train_log GLOG_logbuflevel=-1 GLOG_v=0 ./scheduler \
-job_conf="./train_job.prototxt" \
-this_machine_name="centos-0"
......@@ -37,10 +37,21 @@ namespace oneflow {
#define TODO() LOG(FATAL) << "TODO";
#define OF_SINGLETON(ClassName) \
static ClassName* Singleton() { \
static ClassName* ptr = new ClassName; \
return ptr; \
// Injects singleton accessors into ClassName (use inside the class body):
//   SingletonPPtr()    - address of the function-local static pointer; the
//                        first call initializes it with `new ClassName`.
//   Singleton()        - current instance pointer (nullptr after
//                        DeleteSingleton()).
//   DeleteSingleton()  - destroys the current instance, if any, and resets the
//                        pointer to nullptr.
//   RefreshSingleton() - destroys the current instance and installs a freshly
//                        constructed one.
#define OF_SINGLETON(ClassName)                              \
  static ClassName** SingletonPPtr() {                       \
    static ClassName* instance = new ClassName;              \
    return &instance;                                        \
  }                                                          \
  static ClassName* Singleton() { return *SingletonPPtr(); } \
  static void DeleteSingleton() {                            \
    ClassName** slot = SingletonPPtr();                      \
    delete *slot; /* deleting nullptr is a no-op */          \
    *slot = nullptr;                                         \
  }                                                          \
  static void RefreshSingleton() {                           \
    DeleteSingleton();                                       \
    *SingletonPPtr() = new ClassName;                        \
  }
#define COMMAND(...) \
......
#include "gflags/gflags.h"
#include "oneflow/core/job/compiler.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/str_util.h"
#include "oneflow/core/graph/data_comp_task_node.h"
#include "oneflow/core/graph/data_task_graph.h"
#include "oneflow/core/graph/loss_accumulate_task_graph.h"
#include "oneflow/core/graph/loss_record_task_graph.h"
#include "oneflow/core/graph/model_diff_accumulate_task_graph.h"
#include "oneflow/core/graph/model_save_comp_task_node.h"
#include "oneflow/core/graph/model_save_task_graph.h"
#include "oneflow/core/graph/model_update_task_graph.h"
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/job/job_conf.pb.h"
#include "oneflow/core/job/plan.pb.h"
#include "oneflow/core/register/register_desc.h"
namespace oneflow {
// Singleton that owns the task graphs built for a job (ordered_task_gphs_)
// and exposes a single entry point, Compile().
class Compiler final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(Compiler);
  ~Compiler() = default;

  OF_SINGLETON(Compiler);

  // Compiles `job_conf`; `plan_filepath` is the path handed on to
  // GenPlanFile() for the resulting plan.
  void Compile(const JobConf& job_conf, const std::string& plan_filepath);

 private:
  // Only reachable through Singleton().
  Compiler() = default;

  // Traversal helpers over the graphs held in ordered_task_gphs_.
  void ConstForEachChainNode(std::function<void(const ChainNode*)> func);
  void ConstForEachStageNode(std::function<void(const StageNode*)> func);
  void ForEachTaskNode(std::function<void(TaskNode*)> func);

  // Compilation stages.
  void BuildGraphs();
  void BuildModelGraphs(
      const std::pair<const ChainNode*, std::vector<CompTaskNode*>>&);
  void BuildLossGraph(
      const std::pair<const ChainNode*, std::vector<CompTaskNode*>>& pair);
  void InferBlobDesc4Regsts();
  void EraseMeaningLessRegsts();

  // Plan output.
  void GenPlanFile(const std::string& plan_filepath);
  void Plan2DotFile(const Plan& plan);

  std::vector<std::unique_ptr<TaskGraph>> ordered_task_gphs_;
};
void Compiler::ConstForEachChainNode(
std::function<void(const ChainNode*)> func) {
for (const auto& task_gph : ordered_task_gphs_) {
......@@ -67,14 +34,11 @@ void Compiler::ForEachTaskNode(std::function<void(TaskNode*)> func) {
}
// TODO: infer register_num for each register_desc
void Compiler::Compile(const JobConf& job_conf,
const std::string& plan_filepath) {
JobDesc::Singleton()->InitFromJobConf(job_conf);
IDMgr::Singleton()->InitFromResource(JobDesc::Singleton()->resource());
Plan Compiler::Compile() {
BuildGraphs();
InferBlobDesc4Regsts();
EraseMeaningLessRegsts();
GenPlanFile(plan_filepath);
return GenPlanFile();
}
void Compiler::BuildGraphs() {
......@@ -179,7 +143,7 @@ void Compiler::EraseMeaningLessRegsts() {
});
}
void Compiler::GenPlanFile(const std::string& plan_filepath) {
Plan Compiler::GenPlanFile() {
HashMap<const ChainNode*, int64_t> chain2meaningless_task_cnt;
ForEachTaskNode([&](const TaskNode* node) {
auto comp_task_node = dynamic_cast<const DataCompTaskNode*>(node);
......@@ -224,8 +188,9 @@ void Compiler::GenPlanFile(const std::string& plan_filepath) {
// TODO: unique
});
});
PrintProtoToTextFile(plan, plan_filepath);
Plan2DotFile(plan);
OpMgr::DeleteSingleton();
return plan;
}
void Compiler::Plan2DotFile(const Plan& plan) {
......@@ -264,17 +229,3 @@ void Compiler::Plan2DotFile(const Plan& plan) {
}
} // namespace oneflow
DEFINE_string(job_conf_filepath, "", "");
DEFINE_string(plan_filepath, "", "");
int main(int argc, char** argv) {
google::InitGoogleLogging(argv[0]);
google::ParseCommandLineFlags(&argc, &argv, true);
LOG(INFO) << "Compiler Starting Up";
oneflow::JobConf job_conf;
oneflow::ParseProtoFromTextFile(FLAGS_job_conf_filepath, &job_conf);
oneflow::Compiler::Singleton()->Compile(job_conf, FLAGS_plan_filepath);
LOG(INFO) << "Compiler Shutting Down";
return 0;
}
#ifndef ONEFLOW_CORE_JOB_COMPILER_H_
#define ONEFLOW_CORE_JOB_COMPILER_H_
#include "oneflow/core/graph/data_comp_task_node.h"
#include "oneflow/core/graph/data_task_graph.h"
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/job/job_conf.pb.h"
#include "oneflow/core/job/plan.pb.h"
namespace oneflow {
// Singleton that owns the task graphs built for a job (ordered_task_gphs_)
// and turns them into a Plan via Compile().
class Compiler final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(Compiler);
  ~Compiler() = default;

  OF_SINGLETON(Compiler);

  // Runs the compilation stages and returns the generated Plan by value.
  Plan Compile();

 private:
  // Only reachable through Singleton().
  Compiler() = default;

  void InitRelatedSingleton(const JobConf& job_conf);

  // Traversal helpers over the graphs held in ordered_task_gphs_.
  void ConstForEachChainNode(std::function<void(const ChainNode*)> func);
  void ConstForEachStageNode(std::function<void(const StageNode*)> func);
  void ForEachTaskNode(std::function<void(TaskNode*)> func);

  // Compilation stages.
  void BuildGraphs();
  void BuildModelGraphs(
      const std::pair<const ChainNode*, std::vector<CompTaskNode*>>&);
  void BuildLossGraph(
      const std::pair<const ChainNode*, std::vector<CompTaskNode*>>& pair);
  void InferBlobDesc4Regsts();
  void EraseMeaningLessRegsts();

  // Plan output.
  Plan GenPlanFile();
  void Plan2DotFile(const Plan& plan);

  std::vector<std::unique_ptr<TaskGraph>> ordered_task_gphs_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPILER_H_
......@@ -2,6 +2,7 @@
#define ONEFLOW_CORE_JOB_ID_MANAGER_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/resource.pb.h"
namespace oneflow {
......@@ -13,9 +14,10 @@ class IDMgr final {
OF_SINGLETON(IDMgr);
void InitFromResource(const Resource& resource) {
void Init() {
LOG(INFO) << "Init IDManager";
Clear();
const Resource& resource = JobDesc::Singleton()->resource();
machine_num_ = resource.machine_size();
CHECK_LT(machine_num_, static_cast<int64_t>(1) << machine_id_bit_num_);
device_num_per_machine_ = resource.device_num_per_machine();
......
......@@ -3,8 +3,9 @@ package oneflow;
message Machine {
string addr = 1; // domain name or ip
string port = 2;
string name = 3;
uint32 ctrl_port = 2;
uint32 data_port = 3;
string name = 4;
}
enum DeviceType {
......
#include "gflags/gflags.h"
#include "oneflow/core/actor/actor_message_bus.h"
#include "oneflow/core/comm_network/rdma_comm_network.h"
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/plan.pb.h"
#include "oneflow/core/job/runtime_context.h"
#include "oneflow/core/kernel/kernel_manager.h"
#include "oneflow/core/thread/thread_manager.h"
namespace oneflow {
// Singleton that executes an already-compiled Plan on one machine: it
// initializes the runtime singletons, hands the tasks that belong to this
// machine to their threads, drives them through the command sequence below,
// and finally deletes the singletons it is responsible for.
class Runtime final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(Runtime);
  ~Runtime() = default;
  OF_SINGLETON(Runtime);

  // Runs `plan` on the machine called `this_machine_name`. Blocks until the
  // active-actor counter reaches zero.
  void Run(const Plan& plan, const std::string& this_machine_name) {
    InitSingleton(plan, this_machine_name);
    // find tasks on this machine
    // Tasks are split into three groups: model-update tasks, source tasks
    // (no consumed register descriptors) and everything else.
    std::vector<const TaskProto*> mdupdt_tasks;
    std::vector<const TaskProto*> source_tasks;
    std::vector<const TaskProto*> other_tasks;
    for (const TaskProto& task : plan.task()) {
      if (task.machine_id() != RuntimeCtx::Singleton()->this_machine_id()) {
        continue;
      }
      if (task.type() == kMdUpdtCompTask) {
        mdupdt_tasks.push_back(&task);
      } else if (task.consumed_regst_desc_id().empty()) {
        source_tasks.push_back(&task);
      } else {
        other_tasks.push_back(&task);
      }
    }
    size_t this_machine_task_num =
        mdupdt_tasks.size() + source_tasks.size() + other_tasks.size();
    LOG(INFO) << "number of mdupdt tasks is " << mdupdt_tasks.size();
    LOG(INFO) << "number of source tasks is " << source_tasks.size();
    LOG(INFO) << "number of other tasks is " << other_tasks.size();
    // Counters are initialized before any task is handed out; the waits
    // below block until they reach zero.
    RuntimeCtx::Singleton()->mut_inactive_actor_cnt().Init(
        "inactive_actor_cnt", this_machine_task_num);
    RuntimeCtx::Singleton()->mut_model_init_cnt().Init("model_init_cnt",
                                                       mdupdt_tasks.size());
    // Phase 1: model-update tasks are handed out first and told to
    // initialize the model.
    HandoutTasks(mdupdt_tasks);
    SendCmdMsg(mdupdt_tasks, ActorCmd::kInitializeModel);
    RuntimeCtx::Singleton()->mut_model_init_cnt().WaitUntilCntEqualZero();
    LOG(INFO) << "InitModel on this machine done";
    OF_BARRIER();
    LOG(INFO) << "InitModel on all machine done";
    // Phase 2: hand out the remaining tasks and wait until the inactive-actor
    // counter drops to zero.
    HandoutTasks(source_tasks);
    HandoutTasks(other_tasks);
    RuntimeCtx::Singleton()->mut_inactive_actor_cnt().WaitUntilCntEqualZero();
    LOG(INFO) << "All actor on this machine are activated";
    OF_BARRIER();
    LOG(INFO) << "All actor on all machine are activated";
    CommNet::Singleton()->RegisterMemoryDone();
    // Phase 3: start execution and block until the active-actor counter
    // drops to zero.
    RuntimeCtx::Singleton()->mut_active_actor_cnt().Init("active_actor_cnt",
                                                         this_machine_task_num);
    SendCmdMsg(mdupdt_tasks, ActorCmd::kSendInitialModel);
    SendCmdMsg(source_tasks, ActorCmd::kStart);
    RuntimeCtx::Singleton()->mut_active_actor_cnt().WaitUntilCntEqualZero();
    DeleteSingleton();
  }

 private:
  Runtime() = default;

  // Initializes the singletons used by Run(), in this exact order.
  void InitSingleton(const Plan& plan, const std::string& this_machine_name) {
    JobDesc::Singleton()->InitFromProto(plan.job_desc());
    IDMgr::Singleton()->InitFromResource(JobDesc::Singleton()->resource());
    RuntimeCtx::Singleton()->set_this_machine_name(this_machine_name);
    KernelMgr::Singleton()->InitFromPlan(plan);
    RdmaCommNet::Init();
    SnapshotMgr::Singleton()->Init(plan);
    ActorMsgBus::Singleton()->Init();
    // Result is discarded; the call is made only for its effect of touching
    // the ThreadMgr singleton.
    ThreadMgr::Singleton();
  }

  // Deletes the ThreadMgr, ActorMsgBus and SnapshotMgr instances, in that
  // order.
  // NOTE(review): the OF_SINGLETON variant that defines a static
  // DeleteSingleton() declares a member with this same name and signature -
  // confirm which macro version this class is built against.
  void DeleteSingleton() {
    delete ThreadMgr::Singleton();
    delete ActorMsgBus::Singleton();
    delete SnapshotMgr::Singleton();
  }

  // Adds every task to the thread selected by its thrd_local_id(), then sends
  // kActivateActor to all of them.
  void HandoutTasks(const std::vector<const TaskProto*>& tasks) {
    for (const TaskProto* task : tasks) {
      ThreadMgr::Singleton()->GetThrd(task->thrd_local_id())->AddTask(*task);
    }
    SendCmdMsg(tasks, ActorCmd::kActivateActor);
  }

  // Sends command `cmd` through the ActorMsgBus to the actor of each task.
  void SendCmdMsg(const std::vector<const TaskProto*>& tasks, ActorCmd cmd) {
    for (const TaskProto* task : tasks) {
      ActorMsg msg = ActorMsg::BuildCommandMsg(
          IDMgr::Singleton()->ActorId4TaskId(task->id()), cmd);
      ;
      ActorMsgBus::Singleton()->SendMsg(msg);
    }
  }
};
} // namespace oneflow
DEFINE_string(plan_filepath, "", "");
DEFINE_string(this_machine_name, "", "");
int main(int argc, char** argv) {
google::InitGoogleLogging(argv[0]);
google::ParseCommandLineFlags(&argc, &argv, true);
LOG(INFO) << "Runtime Starting Up";
oneflow::Plan plan;
LOG(INFO) << "Parse Plan File";
oneflow::ParseProtoFromTextFile(FLAGS_plan_filepath, &plan);
oneflow::Runtime::Singleton()->Run(plan, FLAGS_this_machine_name);
LOG(INFO) << "Runtime Shutting Down";
return 0;
}
......@@ -15,8 +15,8 @@ class RuntimeCtx final {
OF_SINGLETON(RuntimeCtx);
int64_t this_machine_id() const { return this_machine_id_; }
void set_this_machine_name(const std::string& name);
bool IsThisMachineMaster() const { return this_machine_id_ == 0; }
ThreadSafeCounter& mut_model_init_cnt() { return model_init_cnt_; }
......
#include "oneflow/core/job/scheduler.h"
#include "oneflow/core/comm_network/rdma_comm_network.h"
#include "oneflow/core/job/compiler.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/runtime_context.h"
#include "oneflow/core/kernel/kernel_manager.h"
#include "oneflow/core/thread/thread_manager.h"
namespace oneflow {
void Scheduler::Process(const JobConf& job_conf,
const std::string& this_machine_name) {
Plan plan = GetPlanFromJobConf(job_conf, this_machine_name);
// find tasks on this machine
std::vector<const TaskProto*> mdupdt_tasks;
std::vector<const TaskProto*> source_tasks;
std::vector<const TaskProto*> other_tasks;
for (const TaskProto& task : plan.task()) {
if (task.machine_id() != RuntimeCtx::Singleton()->this_machine_id()) {
continue;
}
if (task.type() == kMdUpdtCompTask) {
mdupdt_tasks.push_back(&task);
} else if (task.consumed_regst_desc_id().empty()) {
source_tasks.push_back(&task);
} else {
other_tasks.push_back(&task);
}
}
size_t this_machine_task_num =
mdupdt_tasks.size() + source_tasks.size() + other_tasks.size();
LOG(INFO) << "number of mdupdt tasks is " << mdupdt_tasks.size();
LOG(INFO) << "number of source tasks is " << source_tasks.size();
LOG(INFO) << "number of other tasks is " << other_tasks.size();
RuntimeCtx::Singleton()->mut_inactive_actor_cnt().Init("inactive_actor_cnt",
this_machine_task_num);
RuntimeCtx::Singleton()->mut_model_init_cnt().Init("model_init_cnt",
mdupdt_tasks.size());
HandoutTasks(mdupdt_tasks);
SendCmdMsg(mdupdt_tasks, ActorCmd::kInitializeModel);
RuntimeCtx::Singleton()->mut_model_init_cnt().WaitUntilCntEqualZero();
LOG(INFO) << "InitModel on this machine done";
OF_BARRIER();
LOG(INFO) << "InitModel on all machine done";
HandoutTasks(source_tasks);
HandoutTasks(other_tasks);
RuntimeCtx::Singleton()->mut_inactive_actor_cnt().WaitUntilCntEqualZero();
LOG(INFO) << "All actor on this machine are activated";
OF_BARRIER();
LOG(INFO) << "All actor on all machine are activated";
CommNet::Singleton()->RegisterMemoryDone();
RuntimeCtx::Singleton()->mut_active_actor_cnt().Init("active_actor_cnt",
this_machine_task_num);
SendCmdMsg(mdupdt_tasks, ActorCmd::kSendInitialModel);
SendCmdMsg(source_tasks, ActorCmd::kStart);
RuntimeCtx::Singleton()->mut_active_actor_cnt().WaitUntilCntEqualZero();
DeleteAllSingleton();
}
// Initializes the job-level singletons from `job_conf`, produces the Plan
// (compiled locally when this machine is the master), then initializes the
// runtime singletons from that plan. Returns the plan by value.
// The order of the initialization calls is preserved deliberately.
Plan Scheduler::GetPlanFromJobConf(const JobConf& job_conf,
                                   const std::string& this_machine_name) {
  JobDesc::Singleton()->InitFromJobConf(job_conf);
  IDMgr::Singleton()->Init();
  RuntimeCtx* const runtime_ctx = RuntimeCtx::Singleton();
  runtime_ctx->set_this_machine_name(this_machine_name);
  // TODO: build rpc connect

  Plan plan;
  const bool is_master = runtime_ctx->IsThisMachineMaster();
  if (!is_master) {
    // TODO: receive plan
  } else {
    plan = Compiler::Singleton()->Compile();
    // Install a fresh OpMgr instance once compilation has finished.
    OpMgr::RefreshSingleton();
    // TODO: send plan
  }

  KernelMgr::Singleton()->InitFromPlan(plan);
  RdmaCommNet::Init();
  SnapshotMgr::Singleton()->Init(plan);
  ActorMsgBus::Singleton()->Init();
  // Result is discarded; the call is made only for its effect of touching
  // the ThreadMgr singleton.
  ThreadMgr::Singleton();
  return plan;
}
// Destroys the ThreadMgr, ActorMsgBus and SnapshotMgr singletons, in that
// order.
// Uses the DeleteSingleton() helper that OF_SINGLETON defines instead of a raw
// `delete X::Singleton()`: the helper also resets the cached instance pointer
// to nullptr, so a later Singleton() call does not return a pointer to freed
// memory.
// NOTE(review): assumes these three managers declare OF_SINGLETON - confirm.
void Scheduler::DeleteAllSingleton() {
  ThreadMgr::DeleteSingleton();
  ActorMsgBus::DeleteSingleton();
  SnapshotMgr::DeleteSingleton();
}
// Adds every task in `tasks` to the thread selected by its thrd_local_id(),
// then sends kActivateActor to all of them.
void Scheduler::HandoutTasks(const std::vector<const TaskProto*>& tasks) {
  for (const TaskProto* task_proto : tasks) {
    const auto thrd_local_id = task_proto->thrd_local_id();
    ThreadMgr::Singleton()->GetThrd(thrd_local_id)->AddTask(*task_proto);
  }
  SendCmdMsg(tasks, ActorCmd::kActivateActor);
}
// Builds a command message carrying `cmd` for the actor of each task in
// `tasks` and sends it through the ActorMsgBus.
void Scheduler::SendCmdMsg(const std::vector<const TaskProto*>& tasks,
                           ActorCmd cmd) {
  for (const TaskProto* task_proto : tasks) {
    const auto actor_id = IDMgr::Singleton()->ActorId4TaskId(task_proto->id());
    ActorMsg command_msg = ActorMsg::BuildCommandMsg(actor_id, cmd);
    ActorMsgBus::Singleton()->SendMsg(command_msg);
  }
}
} // namespace oneflow
DEFINE_string(job_conf, "", "");
DEFINE_string(this_machine_name, "", "");
int main(int argc, char** argv) {
google::InitGoogleLogging(argv[0]);
google::ParseCommandLineFlags(&argc, &argv, true);
oneflow::JobConf job_conf;
oneflow::ParseProtoFromTextFile(FLAGS_job_conf, &job_conf);
oneflow::Scheduler::Singleton()->Process(job_conf, FLAGS_this_machine_name);
return 0;
}
#ifndef ONEFLOW_CORE_JOB_SCHEDULER_H_
#define ONEFLOW_CORE_JOB_SCHEDULER_H_
#include "oneflow/core/actor/actor_message_bus.h"
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/job/plan.pb.h"
namespace oneflow {
// Singleton with a single public entry point, Process(), which takes a job
// configuration and the name of the local machine.
class Scheduler final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(Scheduler);
  ~Scheduler() = default;

  OF_SINGLETON(Scheduler);

  // Processes `job_conf` as the machine named `this_machine_name`.
  void Process(const JobConf& job_conf, const std::string& this_machine_name);

 private:
  // Only reachable through Singleton().
  Scheduler() = default;

  // Returns the Plan for `job_conf`.
  Plan GetPlanFromJobConf(const JobConf& job_conf,
                          const std::string& this_machine_name);

  void DeleteAllSingleton();

  // Task dispatch helpers; both take pointers to task descriptions.
  void HandoutTasks(const std::vector<const TaskProto*>& tasks);
  void SendCmdMsg(const std::vector<const TaskProto*>& tasks, ActorCmd cmd);
};
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_SCHEDULER_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册