提交 17549fa4 作者: lixinqi

refine model_io_job.cpp

上级 ba8cfc3a
......@@ -88,6 +88,8 @@ void MakeModelInitJob(
const std::string& job_name, Job* job,
const HashMap<std::string, OperatorConf>& var_op_name2op_conf,
const HashMap<std::string, ParallelBlobConf>& var_op_name2parallel_blob_conf) {
auto* flag_name2flag_value = job->mutable_job_conf()->mutable_flag_name2flag_value();
(*flag_name2flag_value)["__is_user_function__"].set_at_bool(false);
SetModelIoDefaultJobConf(job->mutable_job_conf(), job_name);
Global<InterUserJobInfo>::Get()->set_global_model_init_job_name(job_name);
JobBuilder job_builder(job);
......@@ -128,6 +130,8 @@ void MakeModelLoadJob(
const std::string& job_name, Job* job,
const HashMap<std::string, OperatorConf>& var_op_name2op_conf,
const HashMap<std::string, ParallelBlobConf>& var_op_name2parallel_blob_conf) {
auto* flag_name2flag_value = job->mutable_job_conf()->mutable_flag_name2flag_value();
(*flag_name2flag_value)["__is_user_function__"].set_at_bool(false);
SetModelIoDefaultJobConf(job->mutable_job_conf(), job_name);
Global<InterUserJobInfo>::Get()->set_global_model_load_job_name(job_name);
JobBuilder job_builder(job);
......@@ -166,6 +170,8 @@ void MakeModelSaveJob(
const std::string& job_name, Job* job,
const HashMap<std::string, OperatorConf>& var_op_name2op_conf,
const HashMap<std::string, ParallelBlobConf>& var_op_name2parallel_blob_conf) {
auto* flag_name2flag_value = job->mutable_job_conf()->mutable_flag_name2flag_value();
(*flag_name2flag_value)["__is_user_function__"].set_at_bool(false);
Global<InterUserJobInfo>::Get()->set_global_model_save_job_name(job_name);
SetModelIoDefaultJobConf(job->mutable_job_conf(), job_name);
JobBuilder job_builder(job);
......
......@@ -679,6 +679,17 @@ REGISTER_FUNCTION_CONFIG_DEF().Bool("__is_user_function__", true, "is user defin
void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) {
std::vector<std::shared_ptr<Job>> jobs(conf_jobs.size());
FOR_RANGE(int, i, 0, jobs.size()) { jobs.at(i).reset(new Job(conf_jobs.Get(i))); }
if (Global<MachineCtx>::Get()->IsThisMachineMaster()) {
HashMap<std::string, ParallelBlobConf> var_op_name2parallel_blob_conf;
FilterOpName2ParallelBlobConf({OperatorConf::kVariableConf}, jobs,
&var_op_name2parallel_blob_conf);
// Callback passed to MakeModelIoJobs: validates each generated job and appends a copy of it
// to the enclosing `jobs` vector (captured by reference).
auto AppendJob = [&](Job* job) {
// Wrap the job's conf in a JobDesc so its flags can be queried; the current size of `jobs`
// is passed as the second constructor argument (presumably the id the new job will get —
// TODO confirm against JobDesc's constructor).
JobDesc job_desc(job->job_conf(), jobs.size());
// Every job appended through this callback must have "__is_user_function__" set to false.
CHECK(!job_desc.Bool("__is_user_function__"));
// Copy the job; the new shared_ptr element in `jobs` owns the copy.
jobs.emplace_back(new Job(*job));
};
MakeModelIoJobs(jobs, var_op_name2parallel_blob_conf, AppendJob);
}
std::vector<std::shared_ptr<Job>> function_jobs;
function_jobs.reserve(jobs.size());
FOR_RANGE(int, i, 0, jobs.size()) {
......@@ -692,9 +703,6 @@ void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) {
HashMap<std::string, ParallelBlobConf> pull_op_name2parallel_blob_conf;
FilterOpName2ParallelBlobConf({OperatorConf::kReturnConf}, function_jobs,
&pull_op_name2parallel_blob_conf);
HashMap<std::string, ParallelBlobConf> var_op_name2parallel_blob_conf;
FilterOpName2ParallelBlobConf({OperatorConf::kVariableConf}, function_jobs,
&var_op_name2parallel_blob_conf);
for (const auto& pair : push_op_name2parallel_blob_conf) {
auto push_job = std::make_shared<Job>();
MakePushJob(std::string("System-Push-") + pair.first, pair.first, pair.second,
......@@ -707,8 +715,6 @@ void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) {
pull_job.get());
jobs.emplace_back(pull_job);
}
MakeModelIoJobs(function_jobs, var_op_name2parallel_blob_conf,
[&](Job* job) { jobs.emplace_back(new Job(*job)); });
}
std::vector<Plan> sub_plans(jobs.size());
FOR_RANGE(int64_t, i, 0, jobs.size()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册