diff --git a/oneflow/core/job_completer/all_reduce_add_pass.cpp b/oneflow/core/job_completer/add_all_reduce_group_pass.cpp similarity index 95% rename from oneflow/core/job_completer/all_reduce_add_pass.cpp rename to oneflow/core/job_completer/add_all_reduce_group_pass.cpp index 2cf8012a3848567aba0b3315d1891353fcc475b7..77ab6bbedfc2f498b22df3ae3589042d5f233957 100644 --- a/oneflow/core/job_completer/all_reduce_add_pass.cpp +++ b/oneflow/core/job_completer/add_all_reduce_group_pass.cpp @@ -1,4 +1,4 @@ -#include "oneflow/core/job_completer/all_reduce_add_pass.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/register/runtime_blob_desc.h" namespace oneflow { @@ -240,9 +240,18 @@ void BuildAllReduceStruct( all_reduced_lbi, GetLastTouchedOpName); } -} // namespace +class AddAllReduceGroupPass final : public OpGraphPass { + public: + AddAllReduceGroupPass() = default; + ~AddAllReduceGroupPass() = default; + bool IsEnabled() const override { + return GlobalJobDesc().IsTrain() && !GlobalJobDesc().enable_non_distributed_optimizer() + && GlobalJobDesc().enable_all_reduce_group(); + } + void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; +}; -void AllReduceAddPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { +void AddAllReduceGroupPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { auto ProducerOpNode4Lbi = MakeGetterProducerOpNode4Lbi(op_graph); std::vector lbis; FindAllReducedLbis(job_builder->job(), op_graph, ProducerOpNode4Lbi, &lbis); @@ -286,4 +295,8 @@ void AllReduceAddPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) c }); } +} // namespace + +REGISTER_FUNCTION_PASS("AddAllReduceGroupPass", AddAllReduceGroupPass); + } // namespace oneflow diff --git a/oneflow/core/job_completer/add_lbi_diff_watcher.cpp b/oneflow/core/job_completer/add_lbi_diff_watcher.cpp index 62d8a3afebad0fe86328920401d248e0d23baeff..71a8bd9e925ef920ac5e1efa8e1e7351918bc6fc 100644 --- a/oneflow/core/job_completer/add_lbi_diff_watcher.cpp +++ b/oneflow/core/job_completer/add_lbi_diff_watcher.cpp @@ -1,10 +1,18 @@ -#include "oneflow/core/job_completer/add_lbi_diff_watcher.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/job/lbi_diff_watcher_info.pb.h" #include "oneflow/core/operator/operator.h" namespace oneflow { -void AddLbiDiffWatcherOpConfs(Job* job) { +namespace { + +class AddLbiDiffWatcherOpConfs final : public OpGraphPass { + public: + bool IsEnabled() const override { return GlobalJobDesc().IsTrain(); } + void Apply(Job* job) const override; +}; + +void AddLbiDiffWatcherOpConfs::Apply(Job* job) const { JobBuilder job_builder(job); const auto& map = Global::Get()->job_name2lbi_and_watcher_uuids(); if (map.find(GlobalJobDesc().job_name()) == map.end()) { return; } @@ -27,4 +35,8 @@ void AddLbiDiffWatcherOpConfs(Job* job) { } } +REGISTER_FUNCTION_PASS("AddLbiDiffWatcherOpConfs", AddLbiDiffWatcherOpConfs); + +} // namespace + } // namespace oneflow diff --git a/oneflow/core/job_completer/add_lbi_diff_watcher.h b/oneflow/core/job_completer/add_lbi_diff_watcher.h deleted file mode 100644 index 8c58191596a1107c85d431d6af9081f320cdc936..0000000000000000000000000000000000000000 --- a/oneflow/core/job_completer/add_lbi_diff_watcher.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef ONEFLOW_CORE_JOB_COMPLETER_ADD_LBI_DIFF_WATCHER_H_ -#define ONEFLOW_CORE_JOB_COMPLETER_ADD_LBI_DIFF_WATCHER_H_ - -#include "oneflow/core/job/job_builder.h" - -namespace oneflow { - -void AddLbiDiffWatcherOpConfs(Job* job); -} - -#endif // ONEFLOW_CORE_JOB_COMPLETER_ADD_LBI_DIFF_WATCHER_H_ diff --git a/oneflow/core/job_completer/all_reduce_add_pass.h b/oneflow/core/job_completer/all_reduce_add_pass.h deleted file mode 100644 index f908813298286fa8ccf133972431c86f6deb3ba8..0000000000000000000000000000000000000000 --- a/oneflow/core/job_completer/all_reduce_add_pass.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_ADD_PASS_H_ -#define ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_ADD_PASS_H_ - -#include "oneflow/core/job_completer/op_graph_pass.h" - -namespace oneflow { - -class OpGraph; - -class AllReduceAddPass final : public OpGraphPass { - public: - AllReduceAddPass() = default; - ~AllReduceAddPass() = default; - bool IsEnabled() const override { - return !GlobalJobDesc().enable_non_distributed_optimizer() - && GlobalJobDesc().enable_all_reduce_group(); - } - void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; -}; - -} // namespace oneflow - -#endif // ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_ADD_PASS_H_ diff --git a/oneflow/core/job_completer/all_reduce_sequence_pass.h b/oneflow/core/job_completer/all_reduce_sequence_pass.h deleted file mode 100644 index 3e81fe75cf00a3c0258be1153850ca38fcb9c9cf..0000000000000000000000000000000000000000 --- a/oneflow/core/job_completer/all_reduce_sequence_pass.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_SEQUENCE_PASS_H_ -#define ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_SEQUENCE_PASS_H_ - -#include "oneflow/core/job/job.pb.h" -#include "oneflow/core/job_completer/op_graph_pass.h" - -namespace oneflow { - -class OpGraph; - -class AllReduceSequencePass final : public OpGraphPass { - public: - AllReduceSequencePass() = default; - ~AllReduceSequencePass() = default; - bool IsEnabled() const override { return !GlobalJobDesc().disable_all_reduce_sequence(); } - void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; -}; - -} // namespace oneflow - -#endif // ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_SEQUENCE_PASS_H_ diff --git a/oneflow/core/job_completer/auto_learning_rate.cpp b/oneflow/core/job_completer/auto_learning_rate.cpp index df6b824d17c89badb769e237e8942e9ed3e3705f..6ec6694b19e29d93356b289911f0427c5310aaac 100644 --- a/oneflow/core/job_completer/auto_learning_rate.cpp +++ b/oneflow/core/job_completer/auto_learning_rate.cpp @@ -1,9 +1,23 @@ -#include "oneflow/core/graph/op_graph.h" +#include "oneflow/core/common/util.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/job/job.pb.h" namespace oneflow { -void AutoLearningRate(const OpGraph& op_graph, Job* job) { +namespace { + +class AutoLearningRate final : public OpGraphPass { + public: + OF_DISALLOW_COPY_AND_MOVE(AutoLearningRate); + AutoLearningRate() = default; + ~AutoLearningRate() override = default; + + bool IsEnabled() const override { return GlobalJobDesc().IsTrain(); } + + void Apply(const OpGraph& op_graph, Job* job) const override; +}; + +void AutoLearningRate::Apply(const OpGraph& op_graph, Job* job) const { JobBuilder job_builder(job); const TrainConf& train_conf = job->job_conf().train_conf(); auto AddScheduleOp = [&](const std::string& op_name, const float learning_rate) -> std::string { @@ -58,4 +72,8 @@ void AutoLearningRate(const OpGraph& op_graph, Job* job) { } } +REGISTER_FUNCTION_PASS("AutoLearningRate", AutoLearningRate); + +} // namespace + } // namespace oneflow diff --git a/oneflow/core/job_completer/auto_learning_rate.h b/oneflow/core/job_completer/auto_learning_rate.h deleted file mode 100644 index 5ffdc5f7296beaf1229f8223477d8bae60fa6479..0000000000000000000000000000000000000000 --- a/oneflow/core/job_completer/auto_learning_rate.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef ONEFLOW_CORE_JOB_COMPLETER_AUTO_LEARNING_RATE_H_ -#define ONEFLOW_CORE_JOB_COMPLETER_AUTO_LEARNING_RATE_H_ - -namespace oneflow { - -class OpGraph; -class Job; - -void AutoLearningRate(const OpGraph& op_graph, Job* job); - -} // namespace oneflow - -#endif // ONEFLOW_CORE_JOB_COMPLETER_AUTO_LEARNING_RATE_H_ diff --git a/oneflow/core/job_completer/auto_train_step.cpp b/oneflow/core/job_completer/auto_train_step.cpp index a3c01850eef40e0441d55914aea72364d79c2857..5af8d589f698e34ccba65aed8ceb38ccdcc98df3 100644 --- a/oneflow/core/job_completer/auto_train_step.cpp +++ b/oneflow/core/job_completer/auto_train_step.cpp @@ -1,9 +1,22 @@ -#include "oneflow/core/graph/op_graph.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/job/job.pb.h" namespace oneflow { -void AutoTrainStep(const OpGraph& op_graph, Job* job) { +namespace { + +class AutoTrainStep final : public OpGraphPass { + public: + OF_DISALLOW_COPY_AND_MOVE(AutoTrainStep); + AutoTrainStep() = default; + ~AutoTrainStep() override = default; + + bool IsEnabled() const override { return GlobalJobDesc().IsTrain(); } + + void Apply(const OpGraph& op_graph, Job* job) const override; +}; + +void AutoTrainStep::Apply(const OpGraph& op_graph, Job* job) const { if (job->job_conf().train_conf().has_train_step_lbn()) { return; } OperatorConf variable_op_conf{}; const std::string train_step_name = "System-Train-TrainStep-" + job->job_conf().job_name(); @@ -42,4 +55,8 @@ void AutoTrainStep(const OpGraph& op_graph, Job* job) { job->mutable_job_conf()->mutable_train_conf()->set_train_step_lbn(train_step_lbn); } +REGISTER_FUNCTION_PASS("AutoTrainStep", AutoTrainStep); + +} // namespace + } // namespace oneflow diff --git a/oneflow/core/job_completer/auto_train_step.h b/oneflow/core/job_completer/auto_train_step.h deleted file mode 100644 index 7d16dc0d8c901e3fa5b2b350060c15f58aecc64e..0000000000000000000000000000000000000000 --- a/oneflow/core/job_completer/auto_train_step.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef ONEFLOW_CORE_JOB_COMPLETER_AUTO_TRAIN_STEP_H_ -#define ONEFLOW_CORE_JOB_COMPLETER_AUTO_TRAIN_STEP_H_ - -namespace oneflow { - -class OpGraph; -class Job; - -void AutoTrainStep(const OpGraph& op_graph, Job* job); - -} // namespace oneflow - -#endif // ONEFLOW_CORE_JOB_COMPLETER_AUTO_TRAIN_STEP_H_ diff --git a/oneflow/core/job_completer/generate_backward_and_optimizer_op_confs.cpp b/oneflow/core/job_completer/generate_backward_and_optimizer_op_confs.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fafc479d4531fa6a6e10659d89e65d9a1b8ebaf6 --- /dev/null +++ b/oneflow/core/job_completer/generate_backward_and_optimizer_op_confs.cpp @@ -0,0 +1,98 @@ +#include "oneflow/core/job_completer/op_graph_pass.h" +#include "oneflow/core/job_completer/autograd.h" +#include "oneflow/core/job_completer/optimizer.h" + +namespace oneflow { + +namespace { + +void UpdateJobHelperConfProducedLbi2ConsumedDiffLbi( + const HashMap& lbi2diff_lbi, JobBuilder* job_builder) { + auto& mut_pairs = + (*job_builder->mutable_helper()->mutable_tag2lbi_relations())[kProducedLbi2ConsumedDiffLbi]; + for (const auto& pair : lbi2diff_lbi) { + auto* mut_pair = mut_pairs.add_pair(); + *mut_pair->mutable_first() = pair.first; + *mut_pair->mutable_second() = pair.second; + } +} + +void BindIdenticalSbpObaPairsBetweenIbns(const OpNode& op_node, JobBuilder* job_builder) { + HashMap> in_lbi2obas; + for (const std::string& ibn : op_node.op().input_bns()) { + in_lbi2obas[op_node.op().BnInOp2Lbi(ibn)].push_back(GenOpBlobArg(op_node.op().op_name(), ibn)); + } + for (const auto& pair : in_lbi2obas) { + if (pair.second.size() > 1) { + FOR_RANGE(int32_t, i, 1, pair.second.size()) { + job_builder->BindIdenticalSbpOpBlobArgPair(pair.second.at(0), pair.second.at(i)); + } + } + } +} + +void SetSbpSignatureHintByIdenticalSbpObaPairs(const OpGraph& op_graph, JobBuilder* job_builder) { + HashMap oba2sbp_parallel; + op_graph.ForEachNode([&](OpNode* op_node) { + auto ForEachBn = [&](const std::function& Handler) { + for (const auto& ibn : op_node->op().input_bns()) { Handler(ibn); } + for (const auto& obn : op_node->op().output_bns()) { Handler(obn); } + }; + ForEachBn([&](const std::string& bn_in_op) { + const auto& oba = GenOpBlobArg(op_node->op().op_name(), bn_in_op); + oba2sbp_parallel[oba] = &op_node->SbpParallel4Lbi(op_node->op().BnInOp2Lbi(bn_in_op)); + }); + }); + auto HasSbpParallel = [&](const OpBlobArg& oba) { + return oba2sbp_parallel.find(oba) != oba2sbp_parallel.end(); + }; + for (const auto& pair : job_builder->job().helper().identical_sbp_oba_pairs().pair()) { + const SbpParallel* sbp_parallel = nullptr; + if (HasSbpParallel(pair.first()) && HasSbpParallel(pair.second())) { + CHECK(oba2sbp_parallel.at(pair.first()) == oba2sbp_parallel.at(pair.second())); + sbp_parallel = oba2sbp_parallel.at(pair.first()); + } else if (HasSbpParallel(pair.first())) { + sbp_parallel = oba2sbp_parallel.at(pair.first()); + } else if (HasSbpParallel(pair.second())) { + sbp_parallel = oba2sbp_parallel.at(pair.second()); + } else { + UNIMPLEMENTED(); + } + *job_builder->MutSbpParallel4Oba(pair.first()) = *sbp_parallel; + *job_builder->MutSbpParallel4Oba(pair.second()) = *sbp_parallel; + } +} + +void UpdateOpSbpSignatureHint(const OpGraph& op_graph, JobBuilder* job_builder) { + op_graph.ForEachNode( + [&](OpNode* op_node) { BindIdenticalSbpObaPairsBetweenIbns(*op_node, job_builder); }); + SetSbpSignatureHintByIdenticalSbpObaPairs(op_graph, job_builder); +} + +class GenerateBackwardAndOptimizerOpConfs final : public OpGraphPass { + public: + bool IsEnabled() const override { return GlobalJobDesc().IsTrain(); } + OF_DISALLOW_COPY_AND_MOVE(GenerateBackwardAndOptimizerOpConfs); + GenerateBackwardAndOptimizerOpConfs() = default; + ~GenerateBackwardAndOptimizerOpConfs() override = default; + + void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; +}; + +void GenerateBackwardAndOptimizerOpConfs::Apply(const OpGraph& op_graph, + JobBuilder* job_builder) const { + LogicalBlobId total_loss_instance_num; + HashMap lbi2diff_lbi; + AutoGrad(op_graph, job_builder, &lbi2diff_lbi); + std::function LossInstanceNum4ParallelDesc; + AddTotalLossInstanceNumOpConf(op_graph, job_builder, lbi2diff_lbi, &LossInstanceNum4ParallelDesc); + AddOptimizerOpConf(op_graph, job_builder, lbi2diff_lbi, LossInstanceNum4ParallelDesc); + UpdateJobHelperConfProducedLbi2ConsumedDiffLbi(lbi2diff_lbi, job_builder); + UpdateOpSbpSignatureHint(op_graph, job_builder); +} + +REGISTER_FUNCTION_PASS("GenerateBackwardAndOptimizerOpConfs", GenerateBackwardAndOptimizerOpConfs); + +} // namespace + +} // namespace oneflow diff --git a/oneflow/core/job_completer/job_completer.cpp b/oneflow/core/job_completer/job_completer.cpp index d70c4f749571e062cd1d221fd9246a82d12614fd..b469295b89a31053b9c9f49841f3c7ca125a3439 100644 --- a/oneflow/core/job_completer/job_completer.cpp +++ b/oneflow/core/job_completer/job_completer.cpp @@ -1,18 +1,11 @@ #include "oneflow/core/job_completer/job_completer.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/job_completer/autograd.h" #include "oneflow/core/job_completer/autotick.h" #include "oneflow/core/job_completer/add_keep_header_only_op_conf.h" -#include "oneflow/core/job_completer/optimizer.h" #include "oneflow/core/job/job_desc.h" -#include "oneflow/core/job_completer/all_reduce_add_pass.h" -#include "oneflow/core/job_completer/all_reduce_sequence_pass.h" #include "oneflow/core/job_completer/group_boxing_by_dst_parallel.h" -#include "oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h" -#include "oneflow/core/job_completer/auto_train_step.h" -#include "oneflow/core/job_completer/auto_learning_rate.h" -#include "oneflow/core/job_completer/add_lbi_diff_watcher.h" #include "oneflow/core/framework/config_def.h" - #include "oneflow/core/job_completer/xrt_compilation.h" namespace oneflow { @@ -42,95 +35,6 @@ void WithOpGraphAndMutJobBuilder(Job* job, Handler(op_graph, &job_builder); } -void UpdateJobHelperConfProducedLbi2ConsumedDiffLbi( - const HashMap& lbi2diff_lbi, JobBuilder* job_builder) { - auto& mut_pairs = - (*job_builder->mutable_helper()->mutable_tag2lbi_relations())[kProducedLbi2ConsumedDiffLbi]; - for (const auto& pair : lbi2diff_lbi) { - auto* mut_pair = mut_pairs.add_pair(); - *mut_pair->mutable_first() = pair.first; - *mut_pair->mutable_second() = pair.second; - } -} - -void BindIdenticalSbpObaPairsBetweenIbns(const OpNode& op_node, JobBuilder* job_builder) { - HashMap> in_lbi2obas; - for (const std::string& ibn : op_node.op().input_bns()) { - in_lbi2obas[op_node.op().BnInOp2Lbi(ibn)].push_back(GenOpBlobArg(op_node.op().op_name(), ibn)); - } - for (const auto& pair : in_lbi2obas) { - if (pair.second.size() > 1) { - FOR_RANGE(int32_t, i, 1, pair.second.size()) { - job_builder->BindIdenticalSbpOpBlobArgPair(pair.second.at(0), pair.second.at(i)); - } - } - } -} - -void SetSbpSignatureHintByIdenticalSbpObaPairs(const OpGraph& op_graph, JobBuilder* job_builder) { - HashMap oba2sbp_parallel; - op_graph.ForEachNode([&](OpNode* op_node) { - auto ForEachBn = [&](const std::function& Handler) { - for (const auto& ibn : op_node->op().input_bns()) { Handler(ibn); } - for (const auto& obn : op_node->op().output_bns()) { Handler(obn); } - }; - ForEachBn([&](const std::string& bn_in_op) { - const auto& oba = GenOpBlobArg(op_node->op().op_name(), bn_in_op); - oba2sbp_parallel[oba] = &op_node->SbpParallel4Lbi(op_node->op().BnInOp2Lbi(bn_in_op)); - }); - }); - auto HasSbpParallel = [&](const OpBlobArg& oba) { - return oba2sbp_parallel.find(oba) != oba2sbp_parallel.end(); - }; - for (const auto& pair : job_builder->job().helper().identical_sbp_oba_pairs().pair()) { - const SbpParallel* sbp_parallel = nullptr; - if (HasSbpParallel(pair.first()) && HasSbpParallel(pair.second())) { - CHECK(oba2sbp_parallel.at(pair.first()) == oba2sbp_parallel.at(pair.second())); - sbp_parallel = oba2sbp_parallel.at(pair.first()); - } else if (HasSbpParallel(pair.first())) { - sbp_parallel = oba2sbp_parallel.at(pair.first()); - } else if (HasSbpParallel(pair.second())) { - sbp_parallel = oba2sbp_parallel.at(pair.second()); - } else { - UNIMPLEMENTED(); - } - *job_builder->MutSbpParallel4Oba(pair.first()) = *sbp_parallel; - *job_builder->MutSbpParallel4Oba(pair.second()) = *sbp_parallel; - } -} - -void UpdateOpSbpSignatureHint(const OpGraph& op_graph, JobBuilder* job_builder) { - op_graph.ForEachNode( - [&](OpNode* op_node) { BindIdenticalSbpObaPairsBetweenIbns(*op_node, job_builder); }); - SetSbpSignatureHintByIdenticalSbpObaPairs(op_graph, job_builder); -} - -void GenerateOpConf4Trainning(const OpGraph& op_graph, JobBuilder* job_builder) { - LogicalBlobId total_loss_instance_num; - HashMap lbi2diff_lbi; - AutoGrad(op_graph, job_builder, &lbi2diff_lbi); - std::function LossInstanceNum4ParallelDesc; - AddTotalLossInstanceNumOpConf(op_graph, job_builder, lbi2diff_lbi, &LossInstanceNum4ParallelDesc); - AddOptimizerOpConf(op_graph, job_builder, lbi2diff_lbi, LossInstanceNum4ParallelDesc); - UpdateJobHelperConfProducedLbi2ConsumedDiffLbi(lbi2diff_lbi, job_builder); - UpdateOpSbpSignatureHint(op_graph, job_builder); -} - -std::function MakeGetterMutParallelConf4OpName( - Placement* placement) { - auto op_name2parallel_conf = std::make_shared>(); - FOR_RANGE(int, idx, 0, placement->placement_group_size()) { - auto* placement_group = placement->mutable_placement_group(idx); - for (const std::string& op_name : placement_group->op_set().op_name()) { - ParallelConf* parallel_conf = placement_group->mutable_parallel_conf(); - CHECK(op_name2parallel_conf->emplace(op_name, parallel_conf).second); - } - } - return [op_name2parallel_conf](const std::string& op_name) { - return op_name2parallel_conf->at(op_name); - }; -} - void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder) { auto IsMutableConsumedLbi = [](const Operator& op, const LogicalBlobId& lbi) -> bool { for (const std::string& bn : op.input_bns()) { @@ -179,28 +83,20 @@ void DumpLogicalBlobDescAndSbpSignature(const OpGraph& op_graph, JobBuilder* job op_graph.DumpSbpSignature(job_builder); } -void MakeNcclTupleBroadcastReduceSequence(const OpGraph& op_graph, JobBuilder* job_builder) { - NcclTupleBroadcastReduceSequencePass().Apply(op_graph, job_builder); -} - } // namespace void JobCompleter::Complete(Job* job) const { - // complete variable ops FunctionPass("SetDefaultVariableConf")(job); FunctionPass("AutoMixedPrecision")(job); - if (GlobalJobDesc().IsTrain()) { - FunctionPass("TieUpChainHeadersUnReachableFromAnyVariableOps")(job); - FunctionPass("NonDistributedOptimizerPass")(job); - WithOpGraphAndMutJob(job, &AutoTrainStep); - WithOpGraphAndMutJob(job, &AutoLearningRate); - // complete ops for trainning - WithOpGraphAndMutJobBuilder(job, &GenerateOpConf4Trainning); - WithOpGraphAndMutJobBuilder(job, &MakeNcclTupleBroadcastReduceSequence); - AllReduceAddPass()(job); - AddLbiDiffWatcherOpConfs(job); - AllReduceSequencePass()(job); - } + FunctionPass("TieUpChainHeadersUnReachableFromAnyVariableOps")(job); + FunctionPass("NonDistributedOptimizerPass")(job); + FunctionPass("AutoTrainStep")(job); + FunctionPass("AutoLearningRate")(job); + FunctionPass("GenerateBackwardAndOptimizerOpConfs")(job); + FunctionPass("SequentializeNcclTupleBroadcastReducePass")(job); + FunctionPass("AddAllReduceGroupPass")(job); + FunctionPass("AddLbiDiffWatcherOpConfs")(job); + FunctionPass("SequentializeAllReduceGroupPass")(job); WithOpGraphAndMutJobBuilder(job, &DumpLogicalBlobDescAndSbpSignature); WithOpGraphAndMutJobBuilder(job, &GroupBoxingByDstParallel); WithOpGraphAndMutJobBuilder(job, &AddKeepHeaderOnlyOp); diff --git a/oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h b/oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h deleted file mode 100644 index 928bd70a664dff249011b6a0ffb94ee696479597..0000000000000000000000000000000000000000 --- a/oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef ONEFLOW_CORE_JOB_COMPLETER_NCCL_TUPLE_BROADCAST_REDUCE_SEQUENCE_PASS_H_ -#define ONEFLOW_CORE_JOB_COMPLETER_NCCL_TUPLE_BROADCAST_REDUCE_SEQUENCE_PASS_H_ - -#include "oneflow/core/common/util.h" - -namespace oneflow { - -class OpGraph; -class JobBuilder; - -class NcclTupleBroadcastReduceSequencePass final { - public: - OF_DISALLOW_COPY_AND_MOVE(NcclTupleBroadcastReduceSequencePass); - NcclTupleBroadcastReduceSequencePass() = default; - ~NcclTupleBroadcastReduceSequencePass() = default; - - void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; -}; - -} // namespace oneflow - -#endif // #define ONEFLOW_CORE_JOB_COMPLETER_NCCL_TUPLE_BROADCAST_REDUCE_SEQUENCE_PASS_H_ diff --git a/oneflow/core/job_completer/non_distributed_optimizer_pass.cpp b/oneflow/core/job_completer/non_distributed_optimizer_pass.cpp index 9d52f517843bc44843ebb46be12e1cb61925e35e..b2ea585a6d54cb74fc57067979ed1a3b65c935ed 100644 --- a/oneflow/core/job_completer/non_distributed_optimizer_pass.cpp +++ b/oneflow/core/job_completer/non_distributed_optimizer_pass.cpp @@ -35,7 +35,9 @@ class NonDistributedOptimizerPass final : public OpGraphPass { OF_DISALLOW_COPY_AND_MOVE(NonDistributedOptimizerPass); NonDistributedOptimizerPass() = default; ~NonDistributedOptimizerPass() = default; - bool IsEnabled() const override { return GlobalJobDesc().enable_non_distributed_optimizer(); } + bool IsEnabled() const override { + return GlobalJobDesc().IsTrain() && GlobalJobDesc().enable_non_distributed_optimizer(); + } void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; }; diff --git a/oneflow/core/job_completer/op_graph_pass.h b/oneflow/core/job_completer/op_graph_pass.h index 28948b7b65ffb38169c39809accef77fe360663c..4334a35867437cc4681d3101ff9e8e25d38f1770 100644 --- a/oneflow/core/job_completer/op_graph_pass.h +++ b/oneflow/core/job_completer/op_graph_pass.h @@ -9,6 +9,9 @@ namespace oneflow { class OpGraphPass { public: + OpGraphPass() = default; + virtual ~OpGraphPass() = default; + void operator()(Job* job) const { if (IsEnabled() == false) { return; } Apply(job); diff --git a/oneflow/core/job_completer/all_reduce_sequence_pass.cpp b/oneflow/core/job_completer/sequentialize_all_reduce_group_pass.cpp similarity index 79% rename from oneflow/core/job_completer/all_reduce_sequence_pass.cpp rename to oneflow/core/job_completer/sequentialize_all_reduce_group_pass.cpp index 6b565bef4e36e8ff64db7c06c5f2ae38fc35cc51..6a697b0a35482db39e7e66f27cab6a5a1a78e75a 100644 --- a/oneflow/core/job_completer/all_reduce_sequence_pass.cpp +++ b/oneflow/core/job_completer/sequentialize_all_reduce_group_pass.cpp @@ -1,4 +1,4 @@ -#include "oneflow/core/job_completer/all_reduce_sequence_pass.h" +#include "oneflow/core/job_completer/op_graph_pass.h" namespace oneflow { @@ -52,9 +52,18 @@ void ReOrderAllReduceGroups(std::vector* all_reduce_groups) { all_reduce_groups->end() - lazy_count); } -} // namespace +class SequentializeAllReduceGroupPass final : public OpGraphPass { + public: + SequentializeAllReduceGroupPass() = default; + ~SequentializeAllReduceGroupPass() = default; + bool IsEnabled() const override { + return GlobalJobDesc().IsTrain() && !GlobalJobDesc().disable_all_reduce_sequence(); + } + void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; +}; -void AllReduceSequencePass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { +void SequentializeAllReduceGroupPass::Apply(const OpGraph& op_graph, + JobBuilder* job_builder) const { std::vector all_reduce_groups; FindAllReduceGroups(op_graph, &all_reduce_groups); ReOrderAllReduceGroups(&all_reduce_groups); @@ -68,4 +77,8 @@ void AllReduceSequencePass::Apply(const OpGraph& op_graph, JobBuilder* job_build } } +REGISTER_FUNCTION_PASS("SequentializeAllReduceGroupPass", SequentializeAllReduceGroupPass); + +} // namespace + } // namespace oneflow diff --git a/oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.cpp b/oneflow/core/job_completer/sequentialize_nccl_tuple_broadcast_reduce_pass.cpp similarity index 67% rename from oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.cpp rename to oneflow/core/job_completer/sequentialize_nccl_tuple_broadcast_reduce_pass.cpp index 0d45f26b1fc2b6d6a6cbe0c85f85606f636420fd..b7ca1d1c2a8e27938e32a52d3c183fb131931a0d 100644 --- a/oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.cpp +++ b/oneflow/core/job_completer/sequentialize_nccl_tuple_broadcast_reduce_pass.cpp @@ -1,10 +1,21 @@ -#include "oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/graph/op_graph.h" namespace oneflow { -void NcclTupleBroadcastReduceSequencePass::Apply(const OpGraph& op_graph, - JobBuilder* builder) const { +class SequentializeNcclTupleBroadcastReducePass final : public OpGraphPass { + public: + OF_DISALLOW_COPY_AND_MOVE(SequentializeNcclTupleBroadcastReducePass); + SequentializeNcclTupleBroadcastReducePass() = default; + ~SequentializeNcclTupleBroadcastReducePass() = default; + + bool IsEnabled() const override { return GlobalJobDesc().IsTrain(); } + + void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; +}; + +void SequentializeNcclTupleBroadcastReducePass::Apply(const OpGraph& op_graph, + JobBuilder* builder) const { std::vector broadcast_ops; std::vector reduce_ops; op_graph.ForEachNode([&](const OpNode* node) { @@ -41,4 +52,7 @@ void NcclTupleBroadcastReduceSequencePass::Apply(const OpGraph& op_graph, builder->MutOpsOnlyOnce(reduce_ops); } +REGISTER_FUNCTION_PASS("SequentializeNcclTupleBroadcastReducePass", + SequentializeNcclTupleBroadcastReducePass); + } // namespace oneflow diff --git a/oneflow/core/job_completer/tie_up_chain_headers.cpp b/oneflow/core/job_completer/tie_up_chain_headers.cpp index db6611487b9e7db71b695c0a75dbd600fb4a0bfc..d86b8088b08ef67c39a47f24467612feb0046f19 100644 --- a/oneflow/core/job_completer/tie_up_chain_headers.cpp +++ b/oneflow/core/job_completer/tie_up_chain_headers.cpp @@ -94,7 +94,9 @@ std::function MakePredicatorIsReachableFromAnyVariableOps(const O REGISTER_FUNCTION_CONFIG_DEF().Bool("enable_pseudo_chain_merge", false, "ties up chain headers unreachable from any variable ops"); class TieUpChainHeadersUnReachableFromAnyVariableOps final : public OpGraphPass { - bool IsEnabled() const override { return GlobalJobDesc().Bool("enable_pseudo_chain_merge"); } + bool IsEnabled() const override { + return GlobalJobDesc().IsTrain() && GlobalJobDesc().Bool("enable_pseudo_chain_merge"); + } void Apply(const OpGraph& op_graph, Job* job) const override { auto IsReachableFromAnyVariableOps = MakePredicatorIsReachableFromAnyVariableOps(op_graph);