提交 ba8cfc3a 作者: lixinqi

Refine signature of InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs: take the user jobs vector directly instead of (jobs, user_job_size)

上级 eb320970
......@@ -337,9 +337,8 @@ void InterJobMemSharingUtil::MergeMemSharedInterfaceMemBlockBetweenJobs(
}
void InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs(
const std::vector<std::shared_ptr<Job>>& jobs, Plan* plan, int64_t user_job_size) {
if (user_job_size == 1) { return; }
std::vector<std::shared_ptr<Job>> user_jobs(jobs.begin(), jobs.begin() + user_job_size);
const std::vector<std::shared_ptr<Job>>& user_jobs, Plan* plan) {
if (user_jobs.size() == 1) { return; }
std::vector<HashSet<int64_t>> reuse_mem_job_groups = GetMutualExclusionJobGroups(user_jobs);
HashMap<int64_t, ChunkProto> chunk_id2chunk;
......
......@@ -10,8 +10,8 @@ struct InterJobMemSharingUtil {
static void MergeMemSharedInterfaceMemBlockBetweenJobs(
const std::vector<std::shared_ptr<Job>>& jobs, Plan* plan);
static void MergeMemReusedChunkBetweenUserJobs(const std::vector<std::shared_ptr<Job>>& jobs,
Plan* plan, int64_t user_job_size);
static void MergeMemReusedChunkBetweenUserJobs(const std::vector<std::shared_ptr<Job>>& user_jobs,
Plan* plan);
};
} // namespace oneflow
......
......@@ -41,14 +41,12 @@ Maybe<void> JobBuildAndInferCtx::SetJobConf(const JobConfigProto& job_conf) {
return Maybe<void>::Ok();
}
REGISTER_FUNCTION_CONFIG_DEF().Bool("is_user_function", true, "is user defined function");
Maybe<void> JobBuildAndInferCtx::Complete() {
CHECK_NOTNULL(Global<JobDesc>::Get());
Global<JobDesc>::Delete();
auto scope = std::make_unique<GlobalJobDescScope>(job_->job_conf(), job_id_);
auto DoPass = [&](const std::string& pass_name) { FunctionPass(pass_name)(job_); };
if (GlobalJobDesc().Bool("is_user_function")) {
if (GlobalJobDesc().Bool("__is_user_function__")) {
DoPass("CompleteOfrecordDecoder");
DoPass("SetDefaultVariableConf");
DoPass("AutoMixedPrecision");
......
......@@ -674,19 +674,26 @@ void MakePushJob(const std::string& job_name, const std::string& op_name,
job_conf->set_default_data_type(data_type);
}
REGISTER_FUNCTION_CONFIG_DEF().Bool("__is_user_function__", true, "is user defined function");
void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) {
size_t user_job_size = conf_jobs.size();
std::vector<std::shared_ptr<Job>> jobs(conf_jobs.size());
FOR_RANGE(int, i, 0, jobs.size()) { jobs.at(i).reset(new Job(conf_jobs.Get(i))); }
std::vector<std::shared_ptr<Job>> function_jobs;
function_jobs.reserve(jobs.size());
FOR_RANGE(int, i, 0, jobs.size()) {
JobDesc job_desc(jobs.at(i)->job_conf(), i);
if (job_desc.Bool("__is_user_function__")) { function_jobs.push_back(jobs.at(i)); }
}
if (Global<MachineCtx>::Get()->IsThisMachineMaster()) {
HashMap<std::string, ParallelBlobConf> push_op_name2parallel_blob_conf;
FilterOpName2ParallelBlobConf({OperatorConf::kInputConf}, jobs,
FilterOpName2ParallelBlobConf({OperatorConf::kInputConf}, function_jobs,
&push_op_name2parallel_blob_conf);
HashMap<std::string, ParallelBlobConf> pull_op_name2parallel_blob_conf;
FilterOpName2ParallelBlobConf({OperatorConf::kReturnConf}, jobs,
FilterOpName2ParallelBlobConf({OperatorConf::kReturnConf}, function_jobs,
&pull_op_name2parallel_blob_conf);
HashMap<std::string, ParallelBlobConf> var_op_name2parallel_blob_conf;
FilterOpName2ParallelBlobConf({OperatorConf::kVariableConf}, jobs,
FilterOpName2ParallelBlobConf({OperatorConf::kVariableConf}, function_jobs,
&var_op_name2parallel_blob_conf);
for (const auto& pair : push_op_name2parallel_blob_conf) {
auto push_job = std::make_shared<Job>();
......@@ -700,7 +707,7 @@ void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) {
pull_job.get());
jobs.emplace_back(pull_job);
}
MakeModelIoJobs(jobs, var_op_name2parallel_blob_conf,
MakeModelIoJobs(function_jobs, var_op_name2parallel_blob_conf,
[&](Job* job) { jobs.emplace_back(new Job(*job)); });
}
std::vector<Plan> sub_plans(jobs.size());
......@@ -711,7 +718,7 @@ void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) {
}
if (Global<MachineCtx>::Get()->IsThisMachineMaster()) {
MergeSubPlanWithoutGenNetTopo(plan, sub_plans);
InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs(jobs, plan, user_job_size);
InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs(function_jobs, plan);
InterJobMemSharingUtil::MergeMemSharedInterfaceMemBlockBetweenJobs(jobs, plan);
FinishGlobalCriticalSectionDesc(*plan, jobs.size());
Plan main_plan;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册