Unverified Commit 3f728c00 authored by levi and committed by GitHub

Lml/mem optimize (#4725)

* add memory detect info

* small fix in opattrref optimize

* use bitset

* refactor using vector

* refine

* refine

* rename

* refine

* address review

* address review

* refine

* refine

* address review

* smaller BITSET_SIZE

* refine

* refine

* refine

* refine naming

* refine

* refine

* refine

* update

* delete swp file

* small update

* format fix

* format modify

* format modify

* Update compiler.cpp

fix per review comment

* Update reshape_user_op_util.cpp

a bug in reshape is fixed

Co-authored-by: jackalcooper <jackalcooper@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Parent 4b876e46
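The commit titles above ("use bitset", "refactor using vector", "smaller BITSET_SIZE") outline the data structure behind the new memory-detect info. As a minimal sketch of that general technique — the class, its methods, and the BITSET_SIZE value below are hypothetical stand-ins, not code from this commit — a growable bit vector can be assembled from fixed-size std::bitset chunks held in a std::vector, so a smaller BITSET_SIZE trades wasted tail bits against more chunk bookkeeping:

#include <bitset>
#include <cstddef>
#include <vector>

// Hypothetical chunk width; "smaller BITSET_SIZE" above suggests this
// constant was tuned down, but the real value is not shown in this diff.
constexpr std::size_t BITSET_SIZE = 64;

// Growable bitmask: std::vector supplies the dynamic length,
// std::bitset supplies cheap fixed-width bit storage per chunk.
class ChunkedBitset {
 public:
  void Set(std::size_t i) {
    const std::size_t chunk = i / BITSET_SIZE;
    if (chunk >= chunks_.size()) { chunks_.resize(chunk + 1); }
    chunks_[chunk].set(i % BITSET_SIZE);
  }
  bool Test(std::size_t i) const {
    const std::size_t chunk = i / BITSET_SIZE;
    return chunk < chunks_.size() && chunks_[chunk].test(i % BITSET_SIZE);
  }

 private:
  std::vector<std::bitset<BITSET_SIZE>> chunks_;
};

int main() {
  ChunkedBitset live_ids;
  live_ids.Set(100);  // e.g. mark regst desc id 100 as alive
  return live_ids.Test(100) ? 0 : 1;
}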
@@ -201,6 +201,7 @@ std::string TaskNode::VisualStr() const {
 bool TaskNode::IsMeaningLess() { return produced_regsts_.empty() && consumed_regsts_.empty(); }
 
 void TaskNode::ToProto(TaskProto* task_proto) {
+  // Step1: process some scalar items.
+  CHECK_NE(chain_id_, -1);
   task_proto->set_task_type(GetTaskType());
   task_proto->set_machine_id(machine_id_);
@@ -209,13 +210,19 @@ void TaskNode::ToProto(TaskProto* task_proto) {
   task_proto->set_job_id(GlobalJobDesc().job_id());
   task_proto->mutable_task_set_info()->set_chain_id(chain_id_);
   task_proto->mutable_task_set_info()->set_order_in_graph(order_in_graph_);
+  // Step2: process exec_gph.
   exec_gph_.ToExecSequence(parallel_ctx(), task_proto->mutable_exec_sequence());
+  // Step3: process produced_regst.
   auto* produced_regst_proto = task_proto->mutable_produced_regst_desc();
   for (auto& pair : produced_regsts_) {
     RegstDescProto regst_desc_proto;
     pair.second->ToProto(&regst_desc_proto);
     CHECK(produced_regst_proto->insert({pair.first, regst_desc_proto}).second);
   }
+  // Step4: process consumed_regst.
   auto* consumed_regst_proto = task_proto->mutable_consumed_regst_desc_id();
   for (const auto& pair : consumed_regsts_) {
     RegstDescIdSet regst_desc_ids;
......
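The CHECK wrapped around produced_regst_proto->insert(...) above leans on the map-insert contract: insert returns an {iterator, inserted} pair, and the bool is false when the key was already present. A self-contained sketch of the idiom, with std::map standing in for the protobuf map type (an assumption made for illustration):

#include <cassert>
#include <map>
#include <string>

int main() {
  // Toy stand-in: regst name -> regst desc id.
  std::map<std::string, int> produced;
  // .second of the returned pair reports whether the key was new,
  // so asserting on it catches a regst name registered twice.
  assert(produced.insert({"out_regst", 1}).second);
  assert(!produced.insert({"out_regst", 2}).second);  // duplicate is rejected
  return 0;
}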
@@ -85,15 +85,21 @@ void CreateOpAttributeRef(Plan* plan, int64_t job_id, TaskProto* task_proto) {
 }
 
 void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const {
+  const JobDesc& job_desc = GlobalJobDesc();
+  // Step1: ensure job is completed.
   if (need_job_complete) { JobCompleter().Complete(job); }
+  // Step2: new Global<OpGraph> and set log configs.
   Global<OpGraph>::New(*job);
-  const JobDesc& job_desc = GlobalJobDesc();
   if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()
       || Global<ResourceDesc, ForSession>::Get()->enable_dry_run()) {
     TeePersistentLogStream::Create(StrCat("optimized_job", job_desc.job_id()))->Write(*job);
     Global<OpGraph>::Get()->ToDotWithFilePath("optimized_dlnet_" + std::to_string(job_desc.job_id())
                                               + "_op_graph.dot");
   }
+  // Step3: build task_gph.
+  // TODO(levi): we can rewrite this part of code in visitor pattern.
   auto task_gph = std::make_unique<TaskGraph>();
   using std::placeholders::_1;
   task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1));
@@ -102,33 +108,31 @@ void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const {
   task_gph->TopoForEachNode(&TaskNode::Build);
   task_gph->RemoveEmptyRegsts();
   task_gph->MergeChainAndAddOrderingCtrlEdgeInSameChain();
-  if (job_desc.enable_inplace()) {
-    auto IsReachable = Global<OpGraph>::Get()->MakePredicatorIsOpNameDataOrCtrlReachable();
-    task_gph->EnableInplaceMemSharing(IsReachable);
-  }
+  auto IsReachable = Global<OpGraph>::Get()->MakePredicatorIsOpNameDataOrCtrlReachable();
+  if (job_desc.enable_inplace()) { task_gph->EnableInplaceMemSharing(IsReachable); }
   task_gph->TopoForEachNode(&TaskNode::InferTimeShapeIfMeaningful);
   task_gph->ForEachEdge([&](TaskEdge* task_edge) { task_edge->CheckRegstLbiValid(); });
+  // Step4: put information from task_gph into plan.
   task_gph->ForEachNode([&](TaskNode* task_node) {
     if (task_node->IsMeaningLess()) { return; }
-    const bool use_op_attribute_ref = task_node->GetTaskType() == kNormalForward;
     TaskProto task_proto;
     task_node->ToProto(&task_proto);
-    if (use_op_attribute_ref) { CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto); }
+    if (task_node->GetTaskType() == kNormalForward) {
+      CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto);
+    }
     plan->mutable_task()->Add(std::move(task_proto));
   });
-  {
-    auto* job_id2job_conf = plan->mutable_job_confs()->mutable_job_id2job_conf();
-    (*job_id2job_conf)[GlobalJobDesc().job_id()] = GlobalJobDesc().job_conf();
-  }
-  {
-    // NOTE(chengcheng): infer mem blob id & set inplace & add ctrl
-    auto IsReachable = Global<OpGraph>::Get()->MakePredicatorIsOpNameDataOrCtrlReachable();
-    IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(plan, IsReachable);
-    PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(plan);
-    PlanUtil::GenMemBlockAndChunk4Plan(plan);
-  }
+  // NOTE(levi): release task_gph here to decrease memory peak.
+  task_gph.reset();
+  // Step5: post-process for plan and delete Global<OpGraph>.
+  auto* job_id2job_conf = plan->mutable_job_confs()->mutable_job_id2job_conf();
+  (*job_id2job_conf)[GlobalJobDesc().job_id()] = GlobalJobDesc().job_conf();
+  // NOTE(chengcheng): infer mem blob id & set inplace & add ctrl
+  IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(plan, IsReachable);
+  PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(plan);
+  PlanUtil::GenMemBlockAndChunk4Plan(plan);
+  Global<OpGraph>::Delete();
 }
......
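The NOTE(levi) line in the hunk above carries the core idea of this memory optimization: task_gph is an owning std::unique_ptr, so resetting it once the plan has absorbed the graph frees the task graph before the memory-hungry plan post-processing runs, rather than at the end of Compile's scope. A toy sketch of the effect, with stand-in types rather than OneFlow's real ones:

#include <memory>
#include <vector>

struct TaskGraph {
  std::vector<char> payload = std::vector<char>(64 << 20);  // stand-in bulk (~64 MiB)
};
struct Plan {};

void PostProcessPlan(Plan* /*plan*/) { /* imagine memory-hungry plan passes here */ }

void Compile(Plan* plan) {
  auto task_gph = std::make_unique<TaskGraph>();
  // ... everything the plan needs is copied out of *task_gph ...
  task_gph.reset();       // the graph's memory is returned here,
  PostProcessPlan(plan);  // so it never overlaps with these passes.
}

int main() {
  Plan plan;
  Compile(&plan);
  return 0;
}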
@@ -385,10 +385,11 @@ void MergePlanWithoutGenNetTopo(Plan* plan, Plan&& other) {
     CHECK(
         plan->mutable_collective_boxing_plan()->mutable_job_id2request_set()->insert(pair).second);
   }
-  for (auto& pair : other.job_id2op_attribute_ref_table()) {
-    const bool result =
-        plan->mutable_job_id2op_attribute_ref_table()->insert(std::move(pair)).second;
-    CHECK(result) << "fail to merge op attribute info for job: " << pair.first;
+  for (auto& pair : *(other.mutable_job_id2op_attribute_ref_table())) {
+    CHECK(plan->job_id2op_attribute_ref_table().find(pair.first)
+          == plan->job_id2op_attribute_ref_table().end())
+        << "fail to merge op attribute info for job: " << pair.first;
+    (*plan->mutable_job_id2op_attribute_ref_table())[pair.first] = std::move(pair.second);
   }
 }
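Besides checking for duplicates up front, the rewritten loop above sidesteps a use-after-move hazard: the replaced code moved the whole pair into insert() and then read pair.first when building the failure message. Testing for the key first, and moving only pair.second once the check passes, keeps the key valid wherever the message is needed. A sketch with std::map standing in for the protobuf table (an assumption):

#include <iostream>
#include <map>
#include <string>
#include <utility>

int main() {
  std::map<int, std::string> plan_table{{0, "job0-op-attrs"}};
  std::map<int, std::string> other_table{{0, "clash"}, {1, "job1-op-attrs"}};
  for (auto& pair : other_table) {
    // Check before moving: pair.first is guaranteed intact for the message.
    if (plan_table.find(pair.first) != plan_table.end()) {
      std::cerr << "fail to merge op attribute info for job: " << pair.first << "\n";
      continue;  // the real code CHECK-fails instead of continuing
    }
    plan_table[pair.first] = std::move(pair.second);  // move only after the check
  }
  return 0;
}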
@@ -1141,6 +1142,7 @@ Maybe<void> CompileJobsAndMergePlans(const PbRpf<Job>& job_confs, Plan& plan) {
     MakePullJob(std::string("System-Pull-") + pair.first, pair.first, pair.second, pull_job.get());
     jobs.emplace_back(pull_job);
   }
+  std::vector<Plan> sub_plans(jobs.size());
   FOR_RANGE(int64_t, i, 0, jobs.size()) {
     AddJobName2JobId(jobs.at(i)->job_conf().job_name(), i);
......
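The added std::vector<Plan> sub_plans(jobs.size()) sizes the vector once, before the FOR_RANGE loop. One plausible motivation — an assumption on my part, since the loop body is collapsed here — is that a pre-sized vector lets each job's compilation fill its own slot, even from separate worker threads, with no reallocation moving elements out from under another writer:

#include <cstddef>
#include <string>
#include <thread>
#include <vector>

struct Plan {  // toy stand-in for the real Plan proto
  std::string debug_info;
};

int main() {
  const std::size_t num_jobs = 4;
  std::vector<Plan> sub_plans(num_jobs);  // sized up front; never resized below
  std::vector<std::thread> workers;
  for (std::size_t i = 0; i < num_jobs; ++i) {
    // Each worker writes only sub_plans[i]; distinct elements of a
    // non-resizing vector can be written concurrently without locks.
    workers.emplace_back(
        [&sub_plans, i] { sub_plans[i].debug_info = "job-" + std::to_string(i); });
  }
  for (std::thread& worker : workers) { worker.join(); }
  return 0;
}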
@@ -899,6 +899,7 @@ void Operator::GenKernelConf(
     if (blob_desc == nullptr) { continue; }
     (*dtype_signature->mutable_name2dtype())[ibn] = blob_desc->data_type();
   }
+  CHECK_JUST(ToOpAttribute(kernel_conf->mutable_op_attribute()));
   if (HasBlobDescWithField(GetBlobDesc4BnInOp, output_bns(),
                            [](const BlobDesc* blob_desc) { return blob_desc->is_dynamic(); })) {
@@ -912,6 +913,7 @@
     }
     kernel_conf->set_data_type(data_type);
   }
+  if (parallel_ctx != nullptr) { *(kernel_conf->mutable_parallel_ctx()) = *parallel_ctx; }
   VirtualGenKernelConf(GetBlobDesc4BnInOp, parallel_ctx, kernel_conf);
......