From ba8cfc3a96ab2b1919cd5038c00bafd0be084f9e Mon Sep 17 00:00:00 2001 From: lixinqi Date: Wed, 8 Jan 2020 21:54:31 +0800 Subject: [PATCH] refine signature of InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs --- .../core/job/inter_job_mem_sharing_util.cpp | 5 ++--- oneflow/core/job/inter_job_mem_sharing_util.h | 4 ++-- oneflow/core/job/job_build_and_infer_ctx.cpp | 4 +--- oneflow/core/job/oneflow.cpp | 19 +++++++++++++------ 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/oneflow/core/job/inter_job_mem_sharing_util.cpp b/oneflow/core/job/inter_job_mem_sharing_util.cpp index fb5d5d0b1c..803ca317dc 100644 --- a/oneflow/core/job/inter_job_mem_sharing_util.cpp +++ b/oneflow/core/job/inter_job_mem_sharing_util.cpp @@ -337,9 +337,8 @@ void InterJobMemSharingUtil::MergeMemSharedInterfaceMemBlockBetweenJobs( } void InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs( - const std::vector>& jobs, Plan* plan, int64_t user_job_size) { - if (user_job_size == 1) { return; } - std::vector> user_jobs(jobs.begin(), jobs.begin() + user_job_size); + const std::vector>& user_jobs, Plan* plan) { + if (user_jobs.size() == 1) { return; } std::vector> reuse_mem_job_groups = GetMutualExclusionJobGroups(user_jobs); HashMap chunk_id2chunk; diff --git a/oneflow/core/job/inter_job_mem_sharing_util.h b/oneflow/core/job/inter_job_mem_sharing_util.h index 6ce9d08081..0c55eb05e3 100644 --- a/oneflow/core/job/inter_job_mem_sharing_util.h +++ b/oneflow/core/job/inter_job_mem_sharing_util.h @@ -10,8 +10,8 @@ struct InterJobMemSharingUtil { static void MergeMemSharedInterfaceMemBlockBetweenJobs( const std::vector>& jobs, Plan* plan); - static void MergeMemReusedChunkBetweenUserJobs(const std::vector>& jobs, - Plan* plan, int64_t user_job_size); + static void MergeMemReusedChunkBetweenUserJobs(const std::vector>& user_jobs, + Plan* plan); }; } // namespace oneflow diff --git a/oneflow/core/job/job_build_and_infer_ctx.cpp b/oneflow/core/job/job_build_and_infer_ctx.cpp index f177e60c2a..04e18f0b10 100644 --- a/oneflow/core/job/job_build_and_infer_ctx.cpp +++ b/oneflow/core/job/job_build_and_infer_ctx.cpp @@ -41,14 +41,12 @@ Maybe JobBuildAndInferCtx::SetJobConf(const JobConfigProto& job_conf) { return Maybe::Ok(); } -REGISTER_FUNCTION_CONFIG_DEF().Bool("is_user_function", true, "is user defined function"); - Maybe JobBuildAndInferCtx::Complete() { CHECK_NOTNULL(Global::Get()); Global::Delete(); auto scope = std::make_unique(job_->job_conf(), job_id_); auto DoPass = [&](const std::string& pass_name) { FunctionPass(pass_name)(job_); }; - if (GlobalJobDesc().Bool("is_user_function")) { + if (GlobalJobDesc().Bool("__is_user_function__")) { DoPass("CompleteOfrecordDecoder"); DoPass("SetDefaultVariableConf"); DoPass("AutoMixedPrecision"); diff --git a/oneflow/core/job/oneflow.cpp b/oneflow/core/job/oneflow.cpp index 9c92715446..25725f5c99 100644 --- a/oneflow/core/job/oneflow.cpp +++ b/oneflow/core/job/oneflow.cpp @@ -674,19 +674,26 @@ void MakePushJob(const std::string& job_name, const std::string& op_name, job_conf->set_default_data_type(data_type); } +REGISTER_FUNCTION_CONFIG_DEF().Bool("__is_user_function__", true, "is user defined function"); + void CompileAndMergePlanOnMaster(const PbRpf& conf_jobs, Plan* plan) { - size_t user_job_size = conf_jobs.size(); std::vector> jobs(conf_jobs.size()); FOR_RANGE(int, i, 0, jobs.size()) { jobs.at(i).reset(new Job(conf_jobs.Get(i))); } + std::vector> function_jobs; + function_jobs.reserve(jobs.size()); + FOR_RANGE(int, i, 0, jobs.size()) { + JobDesc job_desc(jobs.at(i)->job_conf(), i); + if (job_desc.Bool("__is_user_function__")) { function_jobs.push_back(jobs.at(i)); } + } if (Global::Get()->IsThisMachineMaster()) { HashMap push_op_name2parallel_blob_conf; - FilterOpName2ParallelBlobConf({OperatorConf::kInputConf}, jobs, + FilterOpName2ParallelBlobConf({OperatorConf::kInputConf}, function_jobs, &push_op_name2parallel_blob_conf); HashMap pull_op_name2parallel_blob_conf; - FilterOpName2ParallelBlobConf({OperatorConf::kReturnConf}, jobs, + FilterOpName2ParallelBlobConf({OperatorConf::kReturnConf}, function_jobs, &pull_op_name2parallel_blob_conf); HashMap var_op_name2parallel_blob_conf; - FilterOpName2ParallelBlobConf({OperatorConf::kVariableConf}, jobs, + FilterOpName2ParallelBlobConf({OperatorConf::kVariableConf}, function_jobs, &var_op_name2parallel_blob_conf); for (const auto& pair : push_op_name2parallel_blob_conf) { auto push_job = std::make_shared(); @@ -700,7 +707,7 @@ void CompileAndMergePlanOnMaster(const PbRpf& conf_jobs, Plan* plan) { pull_job.get()); jobs.emplace_back(pull_job); } - MakeModelIoJobs(jobs, var_op_name2parallel_blob_conf, + MakeModelIoJobs(function_jobs, var_op_name2parallel_blob_conf, [&](Job* job) { jobs.emplace_back(new Job(*job)); }); } std::vector sub_plans(jobs.size()); @@ -711,7 +718,7 @@ void CompileAndMergePlanOnMaster(const PbRpf& conf_jobs, Plan* plan) { } if (Global::Get()->IsThisMachineMaster()) { MergeSubPlanWithoutGenNetTopo(plan, sub_plans); - InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs(jobs, plan, user_job_size); + InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs(function_jobs, plan); InterJobMemSharingUtil::MergeMemSharedInterfaceMemBlockBetweenJobs(jobs, plan); FinishGlobalCriticalSectionDesc(*plan, jobs.size()); Plan main_plan; -- GitLab