From 17549fa4e8a7c7233b0e1a2e9f23577b7545d48a Mon Sep 17 00:00:00 2001 From: lixinqi Date: Wed, 8 Jan 2020 22:59:19 +0800 Subject: [PATCH] refine model_io_job.cpp --- oneflow/core/job/model_io_job.cpp | 6 ++++++ oneflow/core/job/oneflow.cpp | 16 +++++++++++----- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/oneflow/core/job/model_io_job.cpp b/oneflow/core/job/model_io_job.cpp index 9cd39f24a9..f118d9146b 100644 --- a/oneflow/core/job/model_io_job.cpp +++ b/oneflow/core/job/model_io_job.cpp @@ -88,6 +88,8 @@ void MakeModelInitJob( const std::string& job_name, Job* job, const HashMap& var_op_name2op_conf, const HashMap& var_op_name2parallel_blob_conf) { + auto* flag_name2flag_value = job->mutable_job_conf()->mutable_flag_name2flag_value(); + (*flag_name2flag_value)["__is_user_function__"].set_at_bool(false); SetModelIoDefaultJobConf(job->mutable_job_conf(), job_name); Global::Get()->set_global_model_init_job_name(job_name); JobBuilder job_builder(job); @@ -128,6 +130,8 @@ void MakeModelLoadJob( const std::string& job_name, Job* job, const HashMap& var_op_name2op_conf, const HashMap& var_op_name2parallel_blob_conf) { + auto* flag_name2flag_value = job->mutable_job_conf()->mutable_flag_name2flag_value(); + (*flag_name2flag_value)["__is_user_function__"].set_at_bool(false); SetModelIoDefaultJobConf(job->mutable_job_conf(), job_name); Global::Get()->set_global_model_load_job_name(job_name); JobBuilder job_builder(job); @@ -166,6 +170,8 @@ void MakeModelSaveJob( const std::string& job_name, Job* job, const HashMap& var_op_name2op_conf, const HashMap& var_op_name2parallel_blob_conf) { + auto* flag_name2flag_value = job->mutable_job_conf()->mutable_flag_name2flag_value(); + (*flag_name2flag_value)["__is_user_function__"].set_at_bool(false); Global::Get()->set_global_model_save_job_name(job_name); SetModelIoDefaultJobConf(job->mutable_job_conf(), job_name); JobBuilder job_builder(job); diff --git a/oneflow/core/job/oneflow.cpp b/oneflow/core/job/oneflow.cpp index 25725f5c99..88c1ebec60 100644 --- a/oneflow/core/job/oneflow.cpp +++ b/oneflow/core/job/oneflow.cpp @@ -679,6 +679,17 @@ REGISTER_FUNCTION_CONFIG_DEF().Bool("__is_user_function__", true, "is user defin void CompileAndMergePlanOnMaster(const PbRpf& conf_jobs, Plan* plan) { std::vector> jobs(conf_jobs.size()); FOR_RANGE(int, i, 0, jobs.size()) { jobs.at(i).reset(new Job(conf_jobs.Get(i))); } + if (Global::Get()->IsThisMachineMaster()) { + HashMap var_op_name2parallel_blob_conf; + FilterOpName2ParallelBlobConf({OperatorConf::kVariableConf}, jobs, + &var_op_name2parallel_blob_conf); + auto AppendJob = [&](Job* job) { + JobDesc job_desc(job->job_conf(), jobs.size()); + CHECK(!job_desc.Bool("__is_user_function__")); + jobs.emplace_back(new Job(*job)); + }; + MakeModelIoJobs(jobs, var_op_name2parallel_blob_conf, AppendJob); + } std::vector> function_jobs; function_jobs.reserve(jobs.size()); FOR_RANGE(int, i, 0, jobs.size()) { @@ -692,9 +703,6 @@ void CompileAndMergePlanOnMaster(const PbRpf& conf_jobs, Plan* plan) { HashMap pull_op_name2parallel_blob_conf; FilterOpName2ParallelBlobConf({OperatorConf::kReturnConf}, function_jobs, &pull_op_name2parallel_blob_conf); - HashMap var_op_name2parallel_blob_conf; - FilterOpName2ParallelBlobConf({OperatorConf::kVariableConf}, function_jobs, - &var_op_name2parallel_blob_conf); for (const auto& pair : push_op_name2parallel_blob_conf) { auto push_job = std::make_shared(); MakePushJob(std::string("System-Push-") + pair.first, pair.first, pair.second, @@ -707,8 +715,6 @@ void CompileAndMergePlanOnMaster(const PbRpf& conf_jobs, Plan* plan) { pull_job.get()); jobs.emplace_back(pull_job); } - MakeModelIoJobs(function_jobs, var_op_name2parallel_blob_conf, - [&](Job* job) { jobs.emplace_back(new Job(*job)); }); } std::vector sub_plans(jobs.size()); FOR_RANGE(int64_t, i, 0, jobs.size()) { -- GitLab