From bc5b5e0fef2293ad6ce77a94b5384b583295f7b9 Mon Sep 17 00:00:00 2001 From: lixinqi Date: Tue, 7 Jan 2020 21:54:26 +0800 Subject: [PATCH] refactor JobCompleter::Complete --- ...pass.cpp => add_all_reduce_group_pass.cpp} | 19 ++- .../job_completer/add_lbi_diff_watcher.cpp | 16 ++- .../core/job_completer/add_lbi_diff_watcher.h | 11 -- .../core/job_completer/all_reduce_add_pass.h | 23 ---- .../job_completer/all_reduce_sequence_pass.h | 21 --- .../core/job_completer/auto_learning_rate.cpp | 22 +++- .../core/job_completer/auto_learning_rate.h | 13 -- .../core/job_completer/auto_train_step.cpp | 21 ++- oneflow/core/job_completer/auto_train_step.h | 13 -- ...nerate_backward_and_optimizer_op_confs.cpp | 98 ++++++++++++++ oneflow/core/job_completer/job_completer.cpp | 124 ++---------------- ...ccl_tuple_broadcast_reduce_sequence_pass.h | 22 ---- .../non_distributed_optimizer_pass.cpp | 4 +- oneflow/core/job_completer/op_graph_pass.h | 3 + ...> sequentialize_all_reduce_group_pass.cpp} | 19 ++- ...lize_nccl_tuple_broadcast_reduce_pass.cpp} | 20 ++- .../job_completer/tie_up_chain_headers.cpp | 4 +- 17 files changed, 219 insertions(+), 234 deletions(-) rename oneflow/core/job_completer/{all_reduce_add_pass.cpp => add_all_reduce_group_pass.cpp} (95%) delete mode 100644 oneflow/core/job_completer/add_lbi_diff_watcher.h delete mode 100644 oneflow/core/job_completer/all_reduce_add_pass.h delete mode 100644 oneflow/core/job_completer/all_reduce_sequence_pass.h delete mode 100644 oneflow/core/job_completer/auto_learning_rate.h delete mode 100644 oneflow/core/job_completer/auto_train_step.h create mode 100644 oneflow/core/job_completer/generate_backward_and_optimizer_op_confs.cpp delete mode 100644 oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h rename oneflow/core/job_completer/{all_reduce_sequence_pass.cpp => sequentialize_all_reduce_group_pass.cpp} (79%) rename oneflow/core/job_completer/{nccl_tuple_broadcast_reduce_sequence_pass.cpp => sequentialize_nccl_tuple_broadcast_reduce_pass.cpp} (67%) diff --git a/oneflow/core/job_completer/all_reduce_add_pass.cpp b/oneflow/core/job_completer/add_all_reduce_group_pass.cpp similarity index 95% rename from oneflow/core/job_completer/all_reduce_add_pass.cpp rename to oneflow/core/job_completer/add_all_reduce_group_pass.cpp index 2cf8012a38..77ab6bbedf 100644 --- a/oneflow/core/job_completer/all_reduce_add_pass.cpp +++ b/oneflow/core/job_completer/add_all_reduce_group_pass.cpp @@ -1,4 +1,4 @@ -#include "oneflow/core/job_completer/all_reduce_add_pass.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/register/runtime_blob_desc.h" namespace oneflow { @@ -240,9 +240,18 @@ void BuildAllReduceStruct( all_reduced_lbi, GetLastTouchedOpName); } -} // namespace +class AddAllReduceGroupPass final : public OpGraphPass { + public: + AddAllReduceGroupPass() = default; + ~AddAllReduceGroupPass() = default; + bool IsEnabled() const override { + return GlobalJobDesc().IsTrain() && !GlobalJobDesc().enable_non_distributed_optimizer() + && GlobalJobDesc().enable_all_reduce_group(); + } + void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; +}; -void AllReduceAddPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { +void AddAllReduceGroupPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { auto ProducerOpNode4Lbi = MakeGetterProducerOpNode4Lbi(op_graph); std::vector lbis; FindAllReducedLbis(job_builder->job(), op_graph, ProducerOpNode4Lbi, &lbis); @@ -286,4 +295,8 @@ void AllReduceAddPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) c }); } +} // namespace + +REGISTER_FUNCTION_PASS("AddAllReduceGroupPass", AddAllReduceGroupPass); + } // namespace oneflow diff --git a/oneflow/core/job_completer/add_lbi_diff_watcher.cpp b/oneflow/core/job_completer/add_lbi_diff_watcher.cpp index 62d8a3afeb..71a8bd9e92 100644 --- a/oneflow/core/job_completer/add_lbi_diff_watcher.cpp +++ b/oneflow/core/job_completer/add_lbi_diff_watcher.cpp @@ -1,10 +1,18 @@ -#include "oneflow/core/job_completer/add_lbi_diff_watcher.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/job/lbi_diff_watcher_info.pb.h" #include "oneflow/core/operator/operator.h" namespace oneflow { -void AddLbiDiffWatcherOpConfs(Job* job) { +namespace { + +class AddLbiDiffWatcherOpConfs final : public OpGraphPass { + public: + bool IsEnabled() const override { return GlobalJobDesc().IsTrain(); } + void Apply(Job* job) const override; +}; + +void AddLbiDiffWatcherOpConfs::Apply(Job* job) const { JobBuilder job_builder(job); const auto& map = Global::Get()->job_name2lbi_and_watcher_uuids(); if (map.find(GlobalJobDesc().job_name()) == map.end()) { return; } @@ -27,4 +35,8 @@ void AddLbiDiffWatcherOpConfs(Job* job) { } } +REGISTER_FUNCTION_PASS("AddLbiDiffWatcherOpConfs", AddLbiDiffWatcherOpConfs); + +} // namespace + } // namespace oneflow diff --git a/oneflow/core/job_completer/add_lbi_diff_watcher.h b/oneflow/core/job_completer/add_lbi_diff_watcher.h deleted file mode 100644 index 8c58191596..0000000000 --- a/oneflow/core/job_completer/add_lbi_diff_watcher.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef ONEFLOW_CORE_JOB_COMPLETER_ADD_LBI_DIFF_WATCHER_H_ -#define ONEFLOW_CORE_JOB_COMPLETER_ADD_LBI_DIFF_WATCHER_H_ - -#include "oneflow/core/job/job_builder.h" - -namespace oneflow { - -void AddLbiDiffWatcherOpConfs(Job* job); -} - -#endif // ONEFLOW_CORE_JOB_COMPLETER_ADD_LBI_DIFF_WATCHER_H_ diff --git a/oneflow/core/job_completer/all_reduce_add_pass.h b/oneflow/core/job_completer/all_reduce_add_pass.h deleted file mode 100644 index f908813298..0000000000 --- a/oneflow/core/job_completer/all_reduce_add_pass.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_ADD_PASS_H_ -#define ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_ADD_PASS_H_ - -#include "oneflow/core/job_completer/op_graph_pass.h" - -namespace oneflow { - -class OpGraph; - -class AllReduceAddPass final : public OpGraphPass { - public: - AllReduceAddPass() = default; - ~AllReduceAddPass() = default; - bool IsEnabled() const override { - return !GlobalJobDesc().enable_non_distributed_optimizer() - && GlobalJobDesc().enable_all_reduce_group(); - } - void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; -}; - -} // namespace oneflow - -#endif // ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_ADD_PASS_H_ diff --git a/oneflow/core/job_completer/all_reduce_sequence_pass.h b/oneflow/core/job_completer/all_reduce_sequence_pass.h deleted file mode 100644 index 3e81fe75cf..0000000000 --- a/oneflow/core/job_completer/all_reduce_sequence_pass.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_SEQUENCE_PASS_H_ -#define ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_SEQUENCE_PASS_H_ - -#include "oneflow/core/job/job.pb.h" -#include "oneflow/core/job_completer/op_graph_pass.h" - -namespace oneflow { - -class OpGraph; - -class AllReduceSequencePass final : public OpGraphPass { - public: - AllReduceSequencePass() = default; - ~AllReduceSequencePass() = default; - bool IsEnabled() const override { return !GlobalJobDesc().disable_all_reduce_sequence(); } - void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; -}; - -} // namespace oneflow - -#endif // ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_SEQUENCE_PASS_H_ diff --git a/oneflow/core/job_completer/auto_learning_rate.cpp b/oneflow/core/job_completer/auto_learning_rate.cpp index df6b824d17..6ec6694b19 100644 --- a/oneflow/core/job_completer/auto_learning_rate.cpp +++ b/oneflow/core/job_completer/auto_learning_rate.cpp @@ -1,9 +1,23 @@ -#include "oneflow/core/graph/op_graph.h" +#include "oneflow/core/common/util.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/job/job.pb.h" namespace oneflow { -void AutoLearningRate(const OpGraph& op_graph, Job* job) { +namespace { + +class AutoLearningRate final : public OpGraphPass { + public: + OF_DISALLOW_COPY_AND_MOVE(AutoLearningRate); + AutoLearningRate() = default; + ~AutoLearningRate() override = default; + + bool IsEnabled() const override { return GlobalJobDesc().IsTrain(); } + + void Apply(const OpGraph& op_graph, Job* job) const override; +}; + +void AutoLearningRate::Apply(const OpGraph& op_graph, Job* job) const { JobBuilder job_builder(job); const TrainConf& train_conf = job->job_conf().train_conf(); auto AddScheduleOp = [&](const std::string& op_name, const float learning_rate) -> std::string { @@ -58,4 +72,8 @@ void AutoLearningRate(const OpGraph& op_graph, Job* job) { } } +REGISTER_FUNCTION_PASS("AutoLearningRate", AutoLearningRate); + +} // namespace + } // namespace oneflow diff --git a/oneflow/core/job_completer/auto_learning_rate.h b/oneflow/core/job_completer/auto_learning_rate.h deleted file mode 100644 index 5ffdc5f729..0000000000 --- a/oneflow/core/job_completer/auto_learning_rate.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef ONEFLOW_CORE_JOB_COMPLETER_AUTO_LEARNING_RATE_H_ -#define ONEFLOW_CORE_JOB_COMPLETER_AUTO_LEARNING_RATE_H_ - -namespace oneflow { - -class OpGraph; -class Job; - -void AutoLearningRate(const OpGraph& op_graph, Job* job); - -} // namespace oneflow - -#endif // ONEFLOW_CORE_JOB_COMPLETER_AUTO_LEARNING_RATE_H_ diff --git a/oneflow/core/job_completer/auto_train_step.cpp b/oneflow/core/job_completer/auto_train_step.cpp index a3c01850ee..5af8d589f6 100644 --- a/oneflow/core/job_completer/auto_train_step.cpp +++ b/oneflow/core/job_completer/auto_train_step.cpp @@ -1,9 +1,22 @@ -#include "oneflow/core/graph/op_graph.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/job/job.pb.h" namespace oneflow { -void AutoTrainStep(const OpGraph& op_graph, Job* job) { +namespace { + +class AutoTrainStep final : public OpGraphPass { + public: + OF_DISALLOW_COPY_AND_MOVE(AutoTrainStep); + AutoTrainStep() = default; + ~AutoTrainStep() override = default; + + bool IsEnabled() const override { return GlobalJobDesc().IsTrain(); } + + void Apply(const OpGraph& op_graph, Job* job) const override; +}; + +void AutoTrainStep::Apply(const OpGraph& op_graph, Job* job) const { if (job->job_conf().train_conf().has_train_step_lbn()) { return; } OperatorConf variable_op_conf{}; const std::string train_step_name = "System-Train-TrainStep-" + job->job_conf().job_name(); @@ -42,4 +55,8 @@ void AutoTrainStep(const OpGraph& op_graph, Job* job) { job->mutable_job_conf()->mutable_train_conf()->set_train_step_lbn(train_step_lbn); } +REGISTER_FUNCTION_PASS("AutoTrainStep", AutoTrainStep); + +} // namespace + } // namespace oneflow diff --git a/oneflow/core/job_completer/auto_train_step.h b/oneflow/core/job_completer/auto_train_step.h deleted file mode 100644 index 7d16dc0d8c..0000000000 --- a/oneflow/core/job_completer/auto_train_step.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef ONEFLOW_CORE_JOB_COMPLETER_AUTO_TRAIN_STEP_H_ -#define ONEFLOW_CORE_JOB_COMPLETER_AUTO_TRAIN_STEP_H_ - -namespace oneflow { - -class OpGraph; -class Job; - -void AutoTrainStep(const OpGraph& op_graph, Job* job); - -} // namespace oneflow - -#endif // ONEFLOW_CORE_JOB_COMPLETER_AUTO_TRAIN_STEP_H_ diff --git a/oneflow/core/job_completer/generate_backward_and_optimizer_op_confs.cpp b/oneflow/core/job_completer/generate_backward_and_optimizer_op_confs.cpp new file mode 100644 index 0000000000..fafc479d45 --- /dev/null +++ b/oneflow/core/job_completer/generate_backward_and_optimizer_op_confs.cpp @@ -0,0 +1,98 @@ +#include "oneflow/core/job_completer/op_graph_pass.h" +#include "oneflow/core/job_completer/autograd.h" +#include "oneflow/core/job_completer/optimizer.h" + +namespace oneflow { + +namespace { + +void UpdateJobHelperConfProducedLbi2ConsumedDiffLbi( + const HashMap& lbi2diff_lbi, JobBuilder* job_builder) { + auto& mut_pairs = + (*job_builder->mutable_helper()->mutable_tag2lbi_relations())[kProducedLbi2ConsumedDiffLbi]; + for (const auto& pair : lbi2diff_lbi) { + auto* mut_pair = mut_pairs.add_pair(); + *mut_pair->mutable_first() = pair.first; + *mut_pair->mutable_second() = pair.second; + } +} + +void BindIdenticalSbpObaPairsBetweenIbns(const OpNode& op_node, JobBuilder* job_builder) { + HashMap> in_lbi2obas; + for (const std::string& ibn : op_node.op().input_bns()) { + in_lbi2obas[op_node.op().BnInOp2Lbi(ibn)].push_back(GenOpBlobArg(op_node.op().op_name(), ibn)); + } + for (const auto& pair : in_lbi2obas) { + if (pair.second.size() > 1) { + FOR_RANGE(int32_t, i, 1, pair.second.size()) { + job_builder->BindIdenticalSbpOpBlobArgPair(pair.second.at(0), pair.second.at(i)); + } + } + } +} + +void SetSbpSignatureHintByIdenticalSbpObaPairs(const OpGraph& op_graph, JobBuilder* job_builder) { + HashMap oba2sbp_parallel; + op_graph.ForEachNode([&](OpNode* op_node) { + auto ForEachBn = [&](const std::function& Handler) { + for (const auto& ibn : op_node->op().input_bns()) { Handler(ibn); } + for (const auto& obn : op_node->op().output_bns()) { Handler(obn); } + }; + ForEachBn([&](const std::string& bn_in_op) { + const auto& oba = GenOpBlobArg(op_node->op().op_name(), bn_in_op); + oba2sbp_parallel[oba] = &op_node->SbpParallel4Lbi(op_node->op().BnInOp2Lbi(bn_in_op)); + }); + }); + auto HasSbpParallel = [&](const OpBlobArg& oba) { + return oba2sbp_parallel.find(oba) != oba2sbp_parallel.end(); + }; + for (const auto& pair : job_builder->job().helper().identical_sbp_oba_pairs().pair()) { + const SbpParallel* sbp_parallel = nullptr; + if (HasSbpParallel(pair.first()) && HasSbpParallel(pair.second())) { + CHECK(oba2sbp_parallel.at(pair.first()) == oba2sbp_parallel.at(pair.second())); + sbp_parallel = oba2sbp_parallel.at(pair.first()); + } else if (HasSbpParallel(pair.first())) { + sbp_parallel = oba2sbp_parallel.at(pair.first()); + } else if (HasSbpParallel(pair.second())) { + sbp_parallel = oba2sbp_parallel.at(pair.second()); + } else { + UNIMPLEMENTED(); + } + *job_builder->MutSbpParallel4Oba(pair.first()) = *sbp_parallel; + *job_builder->MutSbpParallel4Oba(pair.second()) = *sbp_parallel; + } +} + +void UpdateOpSbpSignatureHint(const OpGraph& op_graph, JobBuilder* job_builder) { + op_graph.ForEachNode( + [&](OpNode* op_node) { BindIdenticalSbpObaPairsBetweenIbns(*op_node, job_builder); }); + SetSbpSignatureHintByIdenticalSbpObaPairs(op_graph, job_builder); +} + +class GenerateBackwardAndOptimizerOpConfs final : public OpGraphPass { + public: + bool IsEnabled() const override { return GlobalJobDesc().IsTrain(); } + OF_DISALLOW_COPY_AND_MOVE(GenerateBackwardAndOptimizerOpConfs); + GenerateBackwardAndOptimizerOpConfs() = default; + ~GenerateBackwardAndOptimizerOpConfs() override = default; + + void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; +}; + +void GenerateBackwardAndOptimizerOpConfs::Apply(const OpGraph& op_graph, + JobBuilder* job_builder) const { + LogicalBlobId total_loss_instance_num; + HashMap lbi2diff_lbi; + AutoGrad(op_graph, job_builder, &lbi2diff_lbi); + std::function LossInstanceNum4ParallelDesc; + AddTotalLossInstanceNumOpConf(op_graph, job_builder, lbi2diff_lbi, &LossInstanceNum4ParallelDesc); + AddOptimizerOpConf(op_graph, job_builder, lbi2diff_lbi, LossInstanceNum4ParallelDesc); + UpdateJobHelperConfProducedLbi2ConsumedDiffLbi(lbi2diff_lbi, job_builder); + UpdateOpSbpSignatureHint(op_graph, job_builder); +} + +REGISTER_FUNCTION_PASS("GenerateBackwardAndOptimizerOpConfs", GenerateBackwardAndOptimizerOpConfs); + +} // namespace + +} // namespace oneflow diff --git a/oneflow/core/job_completer/job_completer.cpp b/oneflow/core/job_completer/job_completer.cpp index d70c4f7495..b469295b89 100644 --- a/oneflow/core/job_completer/job_completer.cpp +++ b/oneflow/core/job_completer/job_completer.cpp @@ -1,18 +1,11 @@ #include "oneflow/core/job_completer/job_completer.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/job_completer/autograd.h" #include "oneflow/core/job_completer/autotick.h" #include "oneflow/core/job_completer/add_keep_header_only_op_conf.h" -#include "oneflow/core/job_completer/optimizer.h" #include "oneflow/core/job/job_desc.h" -#include "oneflow/core/job_completer/all_reduce_add_pass.h" -#include "oneflow/core/job_completer/all_reduce_sequence_pass.h" #include "oneflow/core/job_completer/group_boxing_by_dst_parallel.h" -#include "oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h" -#include "oneflow/core/job_completer/auto_train_step.h" -#include "oneflow/core/job_completer/auto_learning_rate.h" -#include "oneflow/core/job_completer/add_lbi_diff_watcher.h" #include "oneflow/core/framework/config_def.h" - #include "oneflow/core/job_completer/xrt_compilation.h" namespace oneflow { @@ -42,95 +35,6 @@ void WithOpGraphAndMutJobBuilder(Job* job, Handler(op_graph, &job_builder); } -void UpdateJobHelperConfProducedLbi2ConsumedDiffLbi( - const HashMap& lbi2diff_lbi, JobBuilder* job_builder) { - auto& mut_pairs = - (*job_builder->mutable_helper()->mutable_tag2lbi_relations())[kProducedLbi2ConsumedDiffLbi]; - for (const auto& pair : lbi2diff_lbi) { - auto* mut_pair = mut_pairs.add_pair(); - *mut_pair->mutable_first() = pair.first; - *mut_pair->mutable_second() = pair.second; - } -} - -void BindIdenticalSbpObaPairsBetweenIbns(const OpNode& op_node, JobBuilder* job_builder) { - HashMap> in_lbi2obas; - for (const std::string& ibn : op_node.op().input_bns()) { - in_lbi2obas[op_node.op().BnInOp2Lbi(ibn)].push_back(GenOpBlobArg(op_node.op().op_name(), ibn)); - } - for (const auto& pair : in_lbi2obas) { - if (pair.second.size() > 1) { - FOR_RANGE(int32_t, i, 1, pair.second.size()) { - job_builder->BindIdenticalSbpOpBlobArgPair(pair.second.at(0), pair.second.at(i)); - } - } - } -} - -void SetSbpSignatureHintByIdenticalSbpObaPairs(const OpGraph& op_graph, JobBuilder* job_builder) { - HashMap oba2sbp_parallel; - op_graph.ForEachNode([&](OpNode* op_node) { - auto ForEachBn = [&](const std::function& Handler) { - for (const auto& ibn : op_node->op().input_bns()) { Handler(ibn); } - for (const auto& obn : op_node->op().output_bns()) { Handler(obn); } - }; - ForEachBn([&](const std::string& bn_in_op) { - const auto& oba = GenOpBlobArg(op_node->op().op_name(), bn_in_op); - oba2sbp_parallel[oba] = &op_node->SbpParallel4Lbi(op_node->op().BnInOp2Lbi(bn_in_op)); - }); - }); - auto HasSbpParallel = [&](const OpBlobArg& oba) { - return oba2sbp_parallel.find(oba) != oba2sbp_parallel.end(); - }; - for (const auto& pair : job_builder->job().helper().identical_sbp_oba_pairs().pair()) { - const SbpParallel* sbp_parallel = nullptr; - if (HasSbpParallel(pair.first()) && HasSbpParallel(pair.second())) { - CHECK(oba2sbp_parallel.at(pair.first()) == oba2sbp_parallel.at(pair.second())); - sbp_parallel = oba2sbp_parallel.at(pair.first()); - } else if (HasSbpParallel(pair.first())) { - sbp_parallel = oba2sbp_parallel.at(pair.first()); - } else if (HasSbpParallel(pair.second())) { - sbp_parallel = oba2sbp_parallel.at(pair.second()); - } else { - UNIMPLEMENTED(); - } - *job_builder->MutSbpParallel4Oba(pair.first()) = *sbp_parallel; - *job_builder->MutSbpParallel4Oba(pair.second()) = *sbp_parallel; - } -} - -void UpdateOpSbpSignatureHint(const OpGraph& op_graph, JobBuilder* job_builder) { - op_graph.ForEachNode( - [&](OpNode* op_node) { BindIdenticalSbpObaPairsBetweenIbns(*op_node, job_builder); }); - SetSbpSignatureHintByIdenticalSbpObaPairs(op_graph, job_builder); -} - -void GenerateOpConf4Trainning(const OpGraph& op_graph, JobBuilder* job_builder) { - LogicalBlobId total_loss_instance_num; - HashMap lbi2diff_lbi; - AutoGrad(op_graph, job_builder, &lbi2diff_lbi); - std::function LossInstanceNum4ParallelDesc; - AddTotalLossInstanceNumOpConf(op_graph, job_builder, lbi2diff_lbi, &LossInstanceNum4ParallelDesc); - AddOptimizerOpConf(op_graph, job_builder, lbi2diff_lbi, LossInstanceNum4ParallelDesc); - UpdateJobHelperConfProducedLbi2ConsumedDiffLbi(lbi2diff_lbi, job_builder); - UpdateOpSbpSignatureHint(op_graph, job_builder); -} - -std::function MakeGetterMutParallelConf4OpName( - Placement* placement) { - auto op_name2parallel_conf = std::make_shared>(); - FOR_RANGE(int, idx, 0, placement->placement_group_size()) { - auto* placement_group = placement->mutable_placement_group(idx); - for (const std::string& op_name : placement_group->op_set().op_name()) { - ParallelConf* parallel_conf = placement_group->mutable_parallel_conf(); - CHECK(op_name2parallel_conf->emplace(op_name, parallel_conf).second); - } - } - return [op_name2parallel_conf](const std::string& op_name) { - return op_name2parallel_conf->at(op_name); - }; -} - void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder) { auto IsMutableConsumedLbi = [](const Operator& op, const LogicalBlobId& lbi) -> bool { for (const std::string& bn : op.input_bns()) { @@ -179,28 +83,20 @@ void DumpLogicalBlobDescAndSbpSignature(const OpGraph& op_graph, JobBuilder* job op_graph.DumpSbpSignature(job_builder); } -void MakeNcclTupleBroadcastReduceSequence(const OpGraph& op_graph, JobBuilder* job_builder) { - NcclTupleBroadcastReduceSequencePass().Apply(op_graph, job_builder); -} - } // namespace void JobCompleter::Complete(Job* job) const { - // complete variable ops FunctionPass("SetDefaultVariableConf")(job); FunctionPass("AutoMixedPrecision")(job); - if (GlobalJobDesc().IsTrain()) { - FunctionPass("TieUpChainHeadersUnReachableFromAnyVariableOps")(job); - FunctionPass("NonDistributedOptimizerPass")(job); - WithOpGraphAndMutJob(job, &AutoTrainStep); - WithOpGraphAndMutJob(job, &AutoLearningRate); - // complete ops for trainning - WithOpGraphAndMutJobBuilder(job, &GenerateOpConf4Trainning); - WithOpGraphAndMutJobBuilder(job, &MakeNcclTupleBroadcastReduceSequence); - AllReduceAddPass()(job); - AddLbiDiffWatcherOpConfs(job); - AllReduceSequencePass()(job); - } + FunctionPass("TieUpChainHeadersUnReachableFromAnyVariableOps")(job); + FunctionPass("NonDistributedOptimizerPass")(job); + FunctionPass("AutoTrainStep")(job); + FunctionPass("AutoLearningRate")(job); + FunctionPass("GenerateBackwardAndOptimizerOpConfs")(job); + FunctionPass("SequentializeNcclTupleBroadcastReducePass")(job); + FunctionPass("AddAllReduceGroupPass")(job); + FunctionPass("AddLbiDiffWatcherOpConfs")(job); + FunctionPass("SequentializeAllReduceGroupPass")(job); WithOpGraphAndMutJobBuilder(job, &DumpLogicalBlobDescAndSbpSignature); WithOpGraphAndMutJobBuilder(job, &GroupBoxingByDstParallel); WithOpGraphAndMutJobBuilder(job, &AddKeepHeaderOnlyOp); diff --git a/oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h b/oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h deleted file mode 100644 index 928bd70a66..0000000000 --- a/oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef ONEFLOW_CORE_JOB_COMPLETER_NCCL_TUPLE_BROADCAST_REDUCE_SEQUENCE_PASS_H_ -#define ONEFLOW_CORE_JOB_COMPLETER_NCCL_TUPLE_BROADCAST_REDUCE_SEQUENCE_PASS_H_ - -#include "oneflow/core/common/util.h" - -namespace oneflow { - -class OpGraph; -class JobBuilder; - -class NcclTupleBroadcastReduceSequencePass final { - public: - OF_DISALLOW_COPY_AND_MOVE(NcclTupleBroadcastReduceSequencePass); - NcclTupleBroadcastReduceSequencePass() = default; - ~NcclTupleBroadcastReduceSequencePass() = default; - - void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; -}; - -} // namespace oneflow - -#endif // #define ONEFLOW_CORE_JOB_COMPLETER_NCCL_TUPLE_BROADCAST_REDUCE_SEQUENCE_PASS_H_ diff --git a/oneflow/core/job_completer/non_distributed_optimizer_pass.cpp b/oneflow/core/job_completer/non_distributed_optimizer_pass.cpp index 9d52f51784..b2ea585a6d 100644 --- a/oneflow/core/job_completer/non_distributed_optimizer_pass.cpp +++ b/oneflow/core/job_completer/non_distributed_optimizer_pass.cpp @@ -35,7 +35,9 @@ class NonDistributedOptimizerPass final : public OpGraphPass { OF_DISALLOW_COPY_AND_MOVE(NonDistributedOptimizerPass); NonDistributedOptimizerPass() = default; ~NonDistributedOptimizerPass() = default; - bool IsEnabled() const override { return GlobalJobDesc().enable_non_distributed_optimizer(); } + bool IsEnabled() const override { + return GlobalJobDesc().IsTrain() && GlobalJobDesc().enable_non_distributed_optimizer(); + } void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; }; diff --git a/oneflow/core/job_completer/op_graph_pass.h b/oneflow/core/job_completer/op_graph_pass.h index 28948b7b65..4334a35867 100644 --- a/oneflow/core/job_completer/op_graph_pass.h +++ b/oneflow/core/job_completer/op_graph_pass.h @@ -9,6 +9,9 @@ namespace oneflow { class OpGraphPass { public: + OpGraphPass() = default; + virtual ~OpGraphPass() = default; + void operator()(Job* job) const { if (IsEnabled() == false) { return; } Apply(job); diff --git a/oneflow/core/job_completer/all_reduce_sequence_pass.cpp b/oneflow/core/job_completer/sequentialize_all_reduce_group_pass.cpp similarity index 79% rename from oneflow/core/job_completer/all_reduce_sequence_pass.cpp rename to oneflow/core/job_completer/sequentialize_all_reduce_group_pass.cpp index 6b565bef4e..6a697b0a35 100644 --- a/oneflow/core/job_completer/all_reduce_sequence_pass.cpp +++ b/oneflow/core/job_completer/sequentialize_all_reduce_group_pass.cpp @@ -1,4 +1,4 @@ -#include "oneflow/core/job_completer/all_reduce_sequence_pass.h" +#include "oneflow/core/job_completer/op_graph_pass.h" namespace oneflow { @@ -52,9 +52,18 @@ void ReOrderAllReduceGroups(std::vector* all_reduce_groups) { all_reduce_groups->end() - lazy_count); } -} // namespace +class SequentializeAllReduceGroupPass final : public OpGraphPass { + public: + SequentializeAllReduceGroupPass() = default; + ~SequentializeAllReduceGroupPass() = default; + bool IsEnabled() const override { + return GlobalJobDesc().IsTrain() && !GlobalJobDesc().disable_all_reduce_sequence(); + } + void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; +}; -void AllReduceSequencePass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { +void SequentializeAllReduceGroupPass::Apply(const OpGraph& op_graph, + JobBuilder* job_builder) const { std::vector all_reduce_groups; FindAllReduceGroups(op_graph, &all_reduce_groups); ReOrderAllReduceGroups(&all_reduce_groups); @@ -68,4 +77,8 @@ void AllReduceSequencePass::Apply(const OpGraph& op_graph, JobBuilder* job_build } } +REGISTER_FUNCTION_PASS("SequentializeAllReduceGroupPass", SequentializeAllReduceGroupPass); + +} // namespace + } // namespace oneflow diff --git a/oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.cpp b/oneflow/core/job_completer/sequentialize_nccl_tuple_broadcast_reduce_pass.cpp similarity index 67% rename from oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.cpp rename to oneflow/core/job_completer/sequentialize_nccl_tuple_broadcast_reduce_pass.cpp index 0d45f26b1f..b7ca1d1c2a 100644 --- a/oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.cpp +++ b/oneflow/core/job_completer/sequentialize_nccl_tuple_broadcast_reduce_pass.cpp @@ -1,10 +1,21 @@ -#include "oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/graph/op_graph.h" namespace oneflow { -void NcclTupleBroadcastReduceSequencePass::Apply(const OpGraph& op_graph, - JobBuilder* builder) const { +class SequentializeNcclTupleBroadcastReducePass final : public OpGraphPass { + public: + OF_DISALLOW_COPY_AND_MOVE(SequentializeNcclTupleBroadcastReducePass); + SequentializeNcclTupleBroadcastReducePass() = default; + ~SequentializeNcclTupleBroadcastReducePass() = default; + + bool IsEnabled() const override { return GlobalJobDesc().IsTrain(); } + + void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; +}; + +void SequentializeNcclTupleBroadcastReducePass::Apply(const OpGraph& op_graph, + JobBuilder* builder) const { std::vector broadcast_ops; std::vector reduce_ops; op_graph.ForEachNode([&](const OpNode* node) { @@ -41,4 +52,7 @@ void NcclTupleBroadcastReduceSequencePass::Apply(const OpGraph& op_graph, builder->MutOpsOnlyOnce(reduce_ops); } +REGISTER_FUNCTION_PASS("SequentializeNcclTupleBroadcastReducePass", + SequentializeNcclTupleBroadcastReducePass); + } // namespace oneflow diff --git a/oneflow/core/job_completer/tie_up_chain_headers.cpp b/oneflow/core/job_completer/tie_up_chain_headers.cpp index db6611487b..d86b8088b0 100644 --- a/oneflow/core/job_completer/tie_up_chain_headers.cpp +++ b/oneflow/core/job_completer/tie_up_chain_headers.cpp @@ -94,7 +94,9 @@ std::function MakePredicatorIsReachableFromAnyVariableOps(const O REGISTER_FUNCTION_CONFIG_DEF().Bool("enable_pseudo_chain_merge", false, "ties up chain headers unreachable from any variable ops"); class TieUpChainHeadersUnReachableFromAnyVariableOps final : public OpGraphPass { - bool IsEnabled() const override { return GlobalJobDesc().Bool("enable_pseudo_chain_merge"); } + bool IsEnabled() const override { + return GlobalJobDesc().IsTrain() && GlobalJobDesc().Bool("enable_pseudo_chain_merge"); + } void Apply(const OpGraph& op_graph, Job* job) const override { auto IsReachableFromAnyVariableOps = MakePredicatorIsReachableFromAnyVariableOps(op_graph); -- GitLab