From fb138d8af9c22514da242326d06a08a1b41a5df5 Mon Sep 17 00:00:00 2001 From: lixinqi Date: Wed, 8 Jan 2020 18:24:25 +0800 Subject: [PATCH] add pass DumpTimeShapeAndBlobParallelConfPass --- oneflow/core/graph/op_graph.cpp | 21 ++++++------- oneflow/core/graph/op_graph.h | 8 ++--- oneflow/core/job/job_build_and_infer_ctx.cpp | 30 +++++++++++-------- oneflow/core/job/oneflow.cpp | 16 +++++----- ...time_shape_and_blob_parallel_conf_pass.cpp | 30 +++++++++++++++++++ oneflow/core/job_completer/job_completer.cpp | 16 ++-------- 6 files changed, 74 insertions(+), 47 deletions(-) create mode 100644 oneflow/core/job_completer/dump_time_shape_and_blob_parallel_conf_pass.cpp diff --git a/oneflow/core/graph/op_graph.cpp b/oneflow/core/graph/op_graph.cpp index c6274f8b98..94be8387f2 100644 --- a/oneflow/core/graph/op_graph.cpp +++ b/oneflow/core/graph/op_graph.cpp @@ -746,8 +746,8 @@ std::list OpGraph::DataOrCtrlSourceNodes() const { return ret; } -void OpGraph::DumpLogicalBlobDesc(JobBuilder* job_builder) const { - auto* helper = job_builder->mutable_helper(); +void OpGraph::DumpLogicalBlobDesc(Job* job) const { + auto* helper = job->mutable_helper(); ForEachNode([&](const OpNode* node) { for (const auto& obn : node->op().output_bns()) { const auto& lbi = node->op().BnInOp2Lbi(obn); @@ -757,17 +757,17 @@ void OpGraph::DumpLogicalBlobDesc(JobBuilder* job_builder) const { }); } -void OpGraph::DumpSbpSignature(JobBuilder* job_builder) const { +void OpGraph::DumpSbpSignature(Job* job) const { ForEachNode([&](const OpNode* node) { - (*job_builder->mutable_sbp_conf()->mutable_op_name2sbp_signature_conf())[node->op().op_name()] = + (*job->mutable_sbp_conf()->mutable_op_name2sbp_signature_conf())[node->op().op_name()] = node->sbp_signature(); }); } -void OpGraph::DumpOpTimeShape(JobBuilder* job_builder) const { +void OpGraph::DumpOpTimeShape(Job* job) const { ForEachNode([&](OpNode* op_node) { auto* op_time_shape = - &(*job_builder->mutable_helper()->mutable_op_name2op_time_shape())[op_node->op().op_name()]; + &(*job->mutable_helper()->mutable_op_name2op_time_shape())[op_node->op().op_name()]; if (op_node->out_blob_time_shape() != nullptr) { op_node->out_blob_time_shape()->ToProto(op_time_shape->mutable_out_blob_time_shape()); } @@ -778,14 +778,15 @@ void OpGraph::DumpOpTimeShape(JobBuilder* job_builder) const { }); } -void OpGraph::DumpBatchAxisLbi(JobBuilder* job_builder) const { - auto* lbn2batch_axis = job_builder->mutable_helper()->mutable_lbn2batch_axis(); +void OpGraph::DumpBatchAxisLbi(Job* job) const { + auto* lbn2batch_axis = job->mutable_helper()->mutable_lbn2batch_axis(); ForEachNode([&](OpNode* op_node) { for (const auto& obn : op_node->op().output_bns()) { const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(obn); const auto& lbn = GenLogicalBlobName(lbi); - const auto& pair = PbMapPair(lbn, op_node->BatchAxis4Lbi(lbi)); - CHECK(lbn2batch_axis->insert(pair).second); + const auto& batch_axis = op_node->BatchAxis4Lbi(lbi); + const auto& pair = PbMapPair(lbn, batch_axis); + CHECK(lbn2batch_axis->insert(pair).first->second == batch_axis); } }); } diff --git a/oneflow/core/graph/op_graph.h b/oneflow/core/graph/op_graph.h index 9eaff2f91f..08409ba7a8 100644 --- a/oneflow/core/graph/op_graph.h +++ b/oneflow/core/graph/op_graph.h @@ -149,10 +149,10 @@ class OpGraph final : public Graph { void ForEachDataAndCtrlInNode(OpNode* node, const std::function& Handler) const; void ForEachDataAndCtrlOutNode(OpNode* node, const std::function& Handler) const; - void DumpLogicalBlobDesc(JobBuilder* job_builder) const; - void DumpSbpSignature(JobBuilder* job_builder) const; - void DumpOpTimeShape(JobBuilder* job_builder) const; - void DumpBatchAxisLbi(JobBuilder* job_builder) const; + void DumpLogicalBlobDesc(Job* job) const; + void DumpSbpSignature(Job* job) const; + void DumpOpTimeShape(Job* job) const; + void DumpBatchAxisLbi(Job* job) const; private: void Init(const Job& job); diff --git a/oneflow/core/job/job_build_and_infer_ctx.cpp b/oneflow/core/job/job_build_and_infer_ctx.cpp index 8d05e56d01..f177e60c2a 100644 --- a/oneflow/core/job/job_build_and_infer_ctx.cpp +++ b/oneflow/core/job/job_build_and_infer_ctx.cpp @@ -1,6 +1,7 @@ #include "oneflow/core/job/job_build_and_infer_ctx.h" #include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/framework/user_op_conf.h" +#include "oneflow/core/framework/config_def.h" #include "oneflow/core/common/protobuf.h" namespace oneflow { @@ -40,23 +41,28 @@ Maybe JobBuildAndInferCtx::SetJobConf(const JobConfigProto& job_conf) { return Maybe::Ok(); } +REGISTER_FUNCTION_CONFIG_DEF().Bool("is_user_function", true, "is user defined function"); + Maybe JobBuildAndInferCtx::Complete() { CHECK_NOTNULL(Global::Get()); Global::Delete(); auto scope = std::make_unique(job_->job_conf(), job_id_); auto DoPass = [&](const std::string& pass_name) { FunctionPass(pass_name)(job_); }; - DoPass("CompleteOfrecordDecoder"); - DoPass("SetDefaultVariableConf"); - DoPass("AutoMixedPrecision"); - DoPass("TieUpChainHeadersUnReachableFromAnyVariableOps"); - DoPass("NonDistributedOptimizerPass"); - DoPass("AutoTrainStep"); - DoPass("AutoLearningRate"); - DoPass("GenerateBackwardAndOptimizerOpConfs"); - DoPass("SequentializeNcclTupleBroadcastReducePass"); - DoPass("AddAllReduceGroupPass"); - DoPass("AddLbiDiffWatcherOpConfs"); - DoPass("SequentializeAllReduceGroupPass"); + if (GlobalJobDesc().Bool("is_user_function")) { + DoPass("CompleteOfrecordDecoder"); + DoPass("SetDefaultVariableConf"); + DoPass("AutoMixedPrecision"); + DoPass("TieUpChainHeadersUnReachableFromAnyVariableOps"); + DoPass("NonDistributedOptimizerPass"); + DoPass("AutoTrainStep"); + DoPass("AutoLearningRate"); + DoPass("GenerateBackwardAndOptimizerOpConfs"); + DoPass("SequentializeNcclTupleBroadcastReducePass"); + DoPass("AddAllReduceGroupPass"); + DoPass("AddLbiDiffWatcherOpConfs"); + DoPass("SequentializeAllReduceGroupPass"); + } + DoPass("DumpTimeShapeAndBlobParallelConfPass"); return Maybe::Ok(); } diff --git a/oneflow/core/job/oneflow.cpp b/oneflow/core/job/oneflow.cpp index c0e0ec9edb..fd601cf705 100644 --- a/oneflow/core/job/oneflow.cpp +++ b/oneflow/core/job/oneflow.cpp @@ -678,16 +678,17 @@ void MakePushJob(const std::string& job_name, const std::string& op_name, void CompileAndMergePlanOnMaster(const PbRpf& conf_jobs, Plan* plan) { std::vector jobs(conf_jobs.size()); std::vector sub_plans(conf_jobs.size()); - FOR_RANGE(int64_t, job_id, 0, sub_plans.size()) { - jobs.at(job_id) = conf_jobs.Get(job_id); - AddJobName2JobId(jobs.at(job_id).job_conf().job_name(), job_id); + size_t user_job_size = jobs.size(); + int64_t job_id = -1; + FOR_RANGE(int64_t, i, 0, sub_plans.size()) { + jobs.at(i) = conf_jobs.Get(i); + AddJobName2JobId(jobs.at(i).job_conf().job_name(), i); { - auto scope = std::make_unique(jobs.at(job_id).job_conf(), job_id); - CompileCurJobOnMaster(&jobs.at(job_id), &sub_plans.at(job_id), true); + auto scope = std::make_unique(jobs.at(i).job_conf(), i); + CompileCurJobOnMaster(&jobs.at(i), &sub_plans.at(i), true); } } if (Global::Get()->IsThisMachineMaster()) { - size_t user_job_size = jobs.size(); HashMap push_op_name2parallel_blob_conf; FilterOpName2ParallelBlobConf({OperatorConf::kInputConf}, &jobs, &push_op_name2parallel_blob_conf); @@ -697,7 +698,6 @@ void CompileAndMergePlanOnMaster(const PbRpf& conf_jobs, Plan* plan) { HashMap var_op_name2parallel_blob_conf; FilterOpName2ParallelBlobConf({OperatorConf::kVariableConf}, &jobs, &var_op_name2parallel_blob_conf); - int64_t job_id = -1; { size_t helper_job_size = push_op_name2parallel_blob_conf.size() + pull_op_name2parallel_blob_conf.size(); @@ -727,6 +727,8 @@ void CompileAndMergePlanOnMaster(const PbRpf& conf_jobs, Plan* plan) { CompileHelperJob(&pull_job); } MakeModelIoJobs(jobs, var_op_name2parallel_blob_conf, [&](Job* job) { CompileHelperJob(job); }); + } + if (Global::Get()->IsThisMachineMaster()) { MergeSubPlanWithoutGenNetTopo(plan, sub_plans); InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs(jobs, plan, user_job_size); InterJobMemSharingUtil::MergeMemSharedInterfaceMemBlockBetweenJobs(jobs, plan); diff --git a/oneflow/core/job_completer/dump_time_shape_and_blob_parallel_conf_pass.cpp b/oneflow/core/job_completer/dump_time_shape_and_blob_parallel_conf_pass.cpp new file mode 100644 index 0000000000..a8c905b376 --- /dev/null +++ b/oneflow/core/job_completer/dump_time_shape_and_blob_parallel_conf_pass.cpp @@ -0,0 +1,30 @@ +#include "oneflow/core/common/util.h" +#include "oneflow/core/job_completer/op_graph_pass.h" +#include "oneflow/core/job/job.pb.h" + +namespace oneflow { + +namespace { + +class DumpTimeShapeAndBlobParallelConfPass final : public OpGraphPass { + public: + OF_DISALLOW_COPY_AND_MOVE(DumpTimeShapeAndBlobParallelConfPass); + DumpTimeShapeAndBlobParallelConfPass() = default; + ~DumpTimeShapeAndBlobParallelConfPass() override = default; + + bool IsEnabled() const override { return true; } + + void Apply(const OpGraph& op_graph, Job* job) const override { + op_graph.DumpOpTimeShape(job); + op_graph.DumpBatchAxisLbi(job); + op_graph.DumpLogicalBlobDesc(job); + op_graph.DumpSbpSignature(job); + } +}; + +REGISTER_FUNCTION_PASS("DumpTimeShapeAndBlobParallelConfPass", + DumpTimeShapeAndBlobParallelConfPass); + +} // namespace + +} // namespace oneflow diff --git a/oneflow/core/job_completer/job_completer.cpp b/oneflow/core/job_completer/job_completer.cpp index a217644a3f..d87cd96217 100644 --- a/oneflow/core/job_completer/job_completer.cpp +++ b/oneflow/core/job_completer/job_completer.cpp @@ -73,20 +73,10 @@ void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder } } -void SetOpTimeShape7BatchAxisLbis(const OpGraph& op_graph, JobBuilder* job_builder) { - op_graph.DumpOpTimeShape(job_builder); - op_graph.DumpBatchAxisLbi(job_builder); -} - -void DumpLogicalBlobDescAndSbpSignature(const OpGraph& op_graph, JobBuilder* job_builder) { - op_graph.DumpLogicalBlobDesc(job_builder); - op_graph.DumpSbpSignature(job_builder); -} - } // namespace void JobCompleter::Complete(Job* job) const { - WithOpGraphAndMutJobBuilder(job, &DumpLogicalBlobDescAndSbpSignature); + FunctionPass("DumpTimeShapeAndBlobParallelConfPass")(job); WithOpGraphAndMutJobBuilder(job, &GroupBoxingByDstParallel); WithOpGraphAndMutJobBuilder(job, &AddKeepHeaderOnlyOp); WithOpGraphAndMutJobBuilder(job, &SetCtrlInOpName4VariableOp); @@ -97,9 +87,7 @@ void JobCompleter::Complete(Job* job) const { AddGlobalTotalJobCriticalSection(*job); WithOpGraphAndMutJobBuilder(job, &AddGlobalInputCriticalSections); WithOpGraphAndMutJobBuilder(job, &AddGlobalOutputCriticalSections); - WithOpGraphAndMutJobBuilder(job, &DumpLogicalBlobDescAndSbpSignature); - WithOpGraphAndMutJobBuilder(job, &SetOpTimeShape7BatchAxisLbis); - + FunctionPass("DumpTimeShapeAndBlobParallelConfPass")(job); if (XrtCompilationEnabled(GlobalJobDesc())) { #ifdef OF_WITH_XRT WithOpGraphAndMutJob(job, &RebuildXrtCompiledJob); -- GitLab