提交 ba8cfc3a 编写于 作者: L lixinqi

refine signature of InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs

上级 eb320970
...@@ -337,9 +337,8 @@ void InterJobMemSharingUtil::MergeMemSharedInterfaceMemBlockBetweenJobs( ...@@ -337,9 +337,8 @@ void InterJobMemSharingUtil::MergeMemSharedInterfaceMemBlockBetweenJobs(
} }
void InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs( void InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs(
const std::vector<std::shared_ptr<Job>>& jobs, Plan* plan, int64_t user_job_size) { const std::vector<std::shared_ptr<Job>>& user_jobs, Plan* plan) {
if (user_job_size == 1) { return; } if (user_jobs.size() == 1) { return; }
std::vector<std::shared_ptr<Job>> user_jobs(jobs.begin(), jobs.begin() + user_job_size);
std::vector<HashSet<int64_t>> reuse_mem_job_groups = GetMutualExclusionJobGroups(user_jobs); std::vector<HashSet<int64_t>> reuse_mem_job_groups = GetMutualExclusionJobGroups(user_jobs);
HashMap<int64_t, ChunkProto> chunk_id2chunk; HashMap<int64_t, ChunkProto> chunk_id2chunk;
......
...@@ -10,8 +10,8 @@ struct InterJobMemSharingUtil { ...@@ -10,8 +10,8 @@ struct InterJobMemSharingUtil {
static void MergeMemSharedInterfaceMemBlockBetweenJobs( static void MergeMemSharedInterfaceMemBlockBetweenJobs(
const std::vector<std::shared_ptr<Job>>& jobs, Plan* plan); const std::vector<std::shared_ptr<Job>>& jobs, Plan* plan);
static void MergeMemReusedChunkBetweenUserJobs(const std::vector<std::shared_ptr<Job>>& jobs, static void MergeMemReusedChunkBetweenUserJobs(const std::vector<std::shared_ptr<Job>>& user_jobs,
Plan* plan, int64_t user_job_size); Plan* plan);
}; };
} // namespace oneflow } // namespace oneflow
......
...@@ -41,14 +41,12 @@ Maybe<void> JobBuildAndInferCtx::SetJobConf(const JobConfigProto& job_conf) { ...@@ -41,14 +41,12 @@ Maybe<void> JobBuildAndInferCtx::SetJobConf(const JobConfigProto& job_conf) {
return Maybe<void>::Ok(); return Maybe<void>::Ok();
} }
REGISTER_FUNCTION_CONFIG_DEF().Bool("is_user_function", true, "is user defined function");
Maybe<void> JobBuildAndInferCtx::Complete() { Maybe<void> JobBuildAndInferCtx::Complete() {
CHECK_NOTNULL(Global<JobDesc>::Get()); CHECK_NOTNULL(Global<JobDesc>::Get());
Global<JobDesc>::Delete(); Global<JobDesc>::Delete();
auto scope = std::make_unique<GlobalJobDescScope>(job_->job_conf(), job_id_); auto scope = std::make_unique<GlobalJobDescScope>(job_->job_conf(), job_id_);
auto DoPass = [&](const std::string& pass_name) { FunctionPass(pass_name)(job_); }; auto DoPass = [&](const std::string& pass_name) { FunctionPass(pass_name)(job_); };
if (GlobalJobDesc().Bool("is_user_function")) { if (GlobalJobDesc().Bool("__is_user_function__")) {
DoPass("CompleteOfrecordDecoder"); DoPass("CompleteOfrecordDecoder");
DoPass("SetDefaultVariableConf"); DoPass("SetDefaultVariableConf");
DoPass("AutoMixedPrecision"); DoPass("AutoMixedPrecision");
......
...@@ -674,19 +674,26 @@ void MakePushJob(const std::string& job_name, const std::string& op_name, ...@@ -674,19 +674,26 @@ void MakePushJob(const std::string& job_name, const std::string& op_name,
job_conf->set_default_data_type(data_type); job_conf->set_default_data_type(data_type);
} }
REGISTER_FUNCTION_CONFIG_DEF().Bool("__is_user_function__", true, "is user defined function");
void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) { void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) {
size_t user_job_size = conf_jobs.size();
std::vector<std::shared_ptr<Job>> jobs(conf_jobs.size()); std::vector<std::shared_ptr<Job>> jobs(conf_jobs.size());
FOR_RANGE(int, i, 0, jobs.size()) { jobs.at(i).reset(new Job(conf_jobs.Get(i))); } FOR_RANGE(int, i, 0, jobs.size()) { jobs.at(i).reset(new Job(conf_jobs.Get(i))); }
std::vector<std::shared_ptr<Job>> function_jobs;
function_jobs.reserve(jobs.size());
FOR_RANGE(int, i, 0, jobs.size()) {
JobDesc job_desc(jobs.at(i)->job_conf(), i);
if (job_desc.Bool("__is_user_function__")) { function_jobs.push_back(jobs.at(i)); }
}
if (Global<MachineCtx>::Get()->IsThisMachineMaster()) { if (Global<MachineCtx>::Get()->IsThisMachineMaster()) {
HashMap<std::string, ParallelBlobConf> push_op_name2parallel_blob_conf; HashMap<std::string, ParallelBlobConf> push_op_name2parallel_blob_conf;
FilterOpName2ParallelBlobConf({OperatorConf::kInputConf}, jobs, FilterOpName2ParallelBlobConf({OperatorConf::kInputConf}, function_jobs,
&push_op_name2parallel_blob_conf); &push_op_name2parallel_blob_conf);
HashMap<std::string, ParallelBlobConf> pull_op_name2parallel_blob_conf; HashMap<std::string, ParallelBlobConf> pull_op_name2parallel_blob_conf;
FilterOpName2ParallelBlobConf({OperatorConf::kReturnConf}, jobs, FilterOpName2ParallelBlobConf({OperatorConf::kReturnConf}, function_jobs,
&pull_op_name2parallel_blob_conf); &pull_op_name2parallel_blob_conf);
HashMap<std::string, ParallelBlobConf> var_op_name2parallel_blob_conf; HashMap<std::string, ParallelBlobConf> var_op_name2parallel_blob_conf;
FilterOpName2ParallelBlobConf({OperatorConf::kVariableConf}, jobs, FilterOpName2ParallelBlobConf({OperatorConf::kVariableConf}, function_jobs,
&var_op_name2parallel_blob_conf); &var_op_name2parallel_blob_conf);
for (const auto& pair : push_op_name2parallel_blob_conf) { for (const auto& pair : push_op_name2parallel_blob_conf) {
auto push_job = std::make_shared<Job>(); auto push_job = std::make_shared<Job>();
...@@ -700,7 +707,7 @@ void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) { ...@@ -700,7 +707,7 @@ void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) {
pull_job.get()); pull_job.get());
jobs.emplace_back(pull_job); jobs.emplace_back(pull_job);
} }
MakeModelIoJobs(jobs, var_op_name2parallel_blob_conf, MakeModelIoJobs(function_jobs, var_op_name2parallel_blob_conf,
[&](Job* job) { jobs.emplace_back(new Job(*job)); }); [&](Job* job) { jobs.emplace_back(new Job(*job)); });
} }
std::vector<Plan> sub_plans(jobs.size()); std::vector<Plan> sub_plans(jobs.size());
...@@ -711,7 +718,7 @@ void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) { ...@@ -711,7 +718,7 @@ void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) {
} }
if (Global<MachineCtx>::Get()->IsThisMachineMaster()) { if (Global<MachineCtx>::Get()->IsThisMachineMaster()) {
MergeSubPlanWithoutGenNetTopo(plan, sub_plans); MergeSubPlanWithoutGenNetTopo(plan, sub_plans);
InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs(jobs, plan, user_job_size); InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs(function_jobs, plan);
InterJobMemSharingUtil::MergeMemSharedInterfaceMemBlockBetweenJobs(jobs, plan); InterJobMemSharingUtil::MergeMemSharedInterfaceMemBlockBetweenJobs(jobs, plan);
FinishGlobalCriticalSectionDesc(*plan, jobs.size()); FinishGlobalCriticalSectionDesc(*plan, jobs.size());
Plan main_plan; Plan main_plan;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册