diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp index 1135ff555dd0375663ef304e57830425a93f0d01..60645122b75dfdd3d8cae06998ee6f5c81678dc5 100644 --- a/oneflow/core/graph/task_node.cpp +++ b/oneflow/core/graph/task_node.cpp @@ -201,6 +201,7 @@ std::string TaskNode::VisualStr() const { bool TaskNode::IsMeaningLess() { return produced_regsts_.empty() && consumed_regsts_.empty(); } void TaskNode::ToProto(TaskProto* task_proto) { + // Step1: process some scalar items. CHECK_NE(chain_id_, -1); task_proto->set_task_type(GetTaskType()); task_proto->set_machine_id(machine_id_); @@ -209,13 +210,19 @@ void TaskNode::ToProto(TaskProto* task_proto) { task_proto->set_job_id(GlobalJobDesc().job_id()); task_proto->mutable_task_set_info()->set_chain_id(chain_id_); task_proto->mutable_task_set_info()->set_order_in_graph(order_in_graph_); + + // Step2: process exec_gph. exec_gph_.ToExecSequence(parallel_ctx(), task_proto->mutable_exec_sequence()); + + // Step3: process produced_regst. auto* produced_regst_proto = task_proto->mutable_produced_regst_desc(); for (auto& pair : produced_regsts_) { RegstDescProto regst_desc_proto; pair.second->ToProto(®st_desc_proto); CHECK(produced_regst_proto->insert({pair.first, regst_desc_proto}).second); } + + // Step4: process consumed_regst. auto* consumed_regst_proto = task_proto->mutable_consumed_regst_desc_id(); for (const auto& pair : consumed_regsts_) { RegstDescIdSet regst_desc_ids; diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index efbe476e8ec3ef2f15acd87ceae73dfc8caa8707..e5639be88b7c6ea596a8a717ce86bbc1f1a9cd47 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -85,15 +85,21 @@ void CreateOpAttributeRef(Plan* plan, int64_t job_id, TaskProto* task_proto) { } void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const { - const JobDesc& job_desc = GlobalJobDesc(); + // Step1: ensure job is completed. 
if (need_job_complete) { JobCompleter().Complete(job); } + + // Step2: new Global and set log configs. Global::New(*job); + const JobDesc& job_desc = GlobalJobDesc(); if (Global::Get()->enable_debug_mode() || Global::Get()->enable_dry_run()) { TeePersistentLogStream::Create(StrCat("optimized_job", job_desc.job_id()))->Write(*job); Global::Get()->ToDotWithFilePath("optimized_dlnet_" + std::to_string(job_desc.job_id()) + "_op_graph.dot"); } + + // Step3: build task_gph. + // TODO(levi): we can rewrite this part of code in visitor pattern. auto task_gph = std::make_unique(); using std::placeholders::_1; task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1)); @@ -102,33 +108,31 @@ void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const { task_gph->TopoForEachNode(&TaskNode::Build); task_gph->RemoveEmptyRegsts(); task_gph->MergeChainAndAddOrderingCtrlEdgeInSameChain(); - if (job_desc.enable_inplace()) { - auto IsReachable = Global::Get()->MakePredicatorIsOpNameDataOrCtrlReachable(); - task_gph->EnableInplaceMemSharing(IsReachable); - } + auto IsReachable = Global::Get()->MakePredicatorIsOpNameDataOrCtrlReachable(); + if (job_desc.enable_inplace()) { task_gph->EnableInplaceMemSharing(IsReachable); } task_gph->TopoForEachNode(&TaskNode::InferTimeShapeIfMeaningful); - task_gph->ForEachEdge([&](TaskEdge* task_edge) { task_edge->CheckRegstLbiValid(); }); + // Step4: put information from task_gph into plan. 
task_gph->ForEachNode([&](TaskNode* task_node) { if (task_node->IsMeaningLess()) { return; } - const bool use_op_attribute_ref = task_node->GetTaskType() == kNormalForward; TaskProto task_proto; task_node->ToProto(&task_proto); - if (use_op_attribute_ref) { CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto); } + if (task_node->GetTaskType() == kNormalForward) { + CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto); + } plan->mutable_task()->Add(std::move(task_proto)); }); - { - auto* job_id2job_conf = plan->mutable_job_confs()->mutable_job_id2job_conf(); - (*job_id2job_conf)[GlobalJobDesc().job_id()] = GlobalJobDesc().job_conf(); - } - { - // NOTE(chengcheng): infer mem blob id & set inplace & add ctrl - auto IsReachable = Global::Get()->MakePredicatorIsOpNameDataOrCtrlReachable(); - IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(plan, IsReachable); - PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(plan); - PlanUtil::GenMemBlockAndChunk4Plan(plan); - } + // NOTE(levi): release task_gph here to decrease memory peak. + task_gph.reset(); + + // Step5: post-process for plan and delete Global. 
+ auto* job_id2job_conf = plan->mutable_job_confs()->mutable_job_id2job_conf(); + (*job_id2job_conf)[GlobalJobDesc().job_id()] = GlobalJobDesc().job_conf(); + // NOTE(chengcheng): infer mem blob id & set inplace & add ctrl + IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(plan, IsReachable); + PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(plan); + PlanUtil::GenMemBlockAndChunk4Plan(plan); Global::Delete(); } diff --git a/oneflow/core/job/oneflow.cpp b/oneflow/core/job/oneflow.cpp index c244ff1ded4cebf9944044ed3c268496218ab0ae..2d47003ec6f05c4e83d39340e2dbbf56c6c3efe7 100644 --- a/oneflow/core/job/oneflow.cpp +++ b/oneflow/core/job/oneflow.cpp @@ -385,10 +385,11 @@ void MergePlanWithoutGenNetTopo(Plan* plan, Plan&& other) { CHECK( plan->mutable_collective_boxing_plan()->mutable_job_id2request_set()->insert(pair).second); } - for (auto& pair : other.job_id2op_attribute_ref_table()) { - const bool result = - plan->mutable_job_id2op_attribute_ref_table()->insert(std::move(pair)).second; - CHECK(result) << "fail to merge op attribute info for job: " << pair.first; + for (auto& pair : *(other.mutable_job_id2op_attribute_ref_table())) { + CHECK(plan->job_id2op_attribute_ref_table().find(pair.first) + == plan->job_id2op_attribute_ref_table().end()) + << "fail to merge op attribute info for job: " << pair.first; + (*plan->mutable_job_id2op_attribute_ref_table())[pair.first] = std::move(pair.second); } } @@ -1141,6 +1142,7 @@ Maybe CompileJobsAndMergePlans(const PbRpf& job_confs, Plan& plan) { MakePullJob(std::string("System-Pull-") + pair.first, pair.first, pair.second, pull_job.get()); jobs.emplace_back(pull_job); } + std::vector sub_plans(jobs.size()); FOR_RANGE(int64_t, i, 0, jobs.size()) { AddJobName2JobId(jobs.at(i)->job_conf().job_name(), i); diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index d91760cfa32e29bdc17fadcf4c2e6f40943ba132..a05e1874ac2001cbc803d0dc7cb0e69148b57182 100644 --- a/oneflow/core/operator/operator.cpp 
+++ b/oneflow/core/operator/operator.cpp @@ -899,6 +899,7 @@ void Operator::GenKernelConf( if (blob_desc == nullptr) { continue; } (*dtype_signature->mutable_name2dtype())[ibn] = blob_desc->data_type(); } + CHECK_JUST(ToOpAttribute(kernel_conf->mutable_op_attribute())); if (HasBlobDescWithField(GetBlobDesc4BnInOp, output_bns(), [](const BlobDesc* blob_desc) { return blob_desc->is_dynamic(); })) { @@ -912,6 +913,7 @@ void Operator::GenKernelConf( } kernel_conf->set_data_type(data_type); } + if (parallel_ctx != nullptr) { *(kernel_conf->mutable_parallel_ctx()) = *parallel_ctx; } VirtualGenKernelConf(GetBlobDesc4BnInOp, parallel_ctx, kernel_conf);