diff --git a/oneflow/core/job/oneflow.cpp b/oneflow/core/job/oneflow.cpp index 48382c2cdf28540b66c670a995791c2ad87cf909..c66982991fa73388a92246cf6ca820a7a33d47f1 100644 --- a/oneflow/core/job/oneflow.cpp +++ b/oneflow/core/job/oneflow.cpp @@ -6,7 +6,7 @@ #include "oneflow/core/job/improver.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/job_builder.h" -#include "oneflow/core/job_completer/user_job_completer.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/job/job_set.pb.h" #include "oneflow/core/job/machine_context.h" #include "oneflow/core/job/profiler.h" @@ -743,7 +743,7 @@ void CompileAndMergePlanOnMaster(const PbRpf& conf_jobs, Plan* plan) { AddJobName2JobId(jobs.at(job_id).job_conf().job_name(), job_id); { auto scope = std::make_unique(jobs.at(job_id).job_conf(), job_id); - UserJobCompleter().Complete(&jobs.at(job_id)); + FunctionPass("CompleteOfrecordDecoder")(&jobs.at(job_id)); CompileCurJobOnMaster(&jobs.at(job_id), &sub_plans.at(job_id), true); } } diff --git a/oneflow/core/job_completer/auto_mixed_precision.cpp b/oneflow/core/job_completer/auto_mixed_precision.cpp index 7f5e89fdedd2ea8eaeb1af863a783cf13c3acd64..e5acd1b42ef3781decbe13e69c970a920e5239fb 100644 --- a/oneflow/core/job_completer/auto_mixed_precision.cpp +++ b/oneflow/core/job_completer/auto_mixed_precision.cpp @@ -1,5 +1,6 @@ #include -#include "oneflow/core/job_completer/auto_mixed_precision.h" +#include "oneflow/core/job_completer/auto_mixed_precision_lists.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/device/cuda_util.h" @@ -175,7 +176,36 @@ void InsertCastOpImpl(bool f2h, const OpGraph& op_graph, const HashSet& job_builder->MutOpsOnlyOnce(dst_op_confs); } -} // namespace +class AutoMixedPrecision final : public OpGraphPass { + public: + OF_DISALLOW_COPY_AND_MOVE(AutoMixedPrecision); + AutoMixedPrecision() + : white_list_(AutoMixedPrecisionLists::WhiteList()), + black_list_(AutoMixedPrecisionLists::BlackList()), + gray_list_(AutoMixedPrecisionLists::GrayList()), + clear_list_(AutoMixedPrecisionLists::ClearList()) {} + ~AutoMixedPrecision() = default; + + bool IsEnabled() const override { return GlobalJobDesc().enable_auto_mixed_precision(); } + + void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; + + private: + void FillBlackSet(const OpGraph& op_graph, HashSet* black_set) const; + void FillWhiteSet(const OpGraph& op_graph, std::function IsAllowedToRunWithHalf, + const HashSet& black_set, HashSet* white_set) const; + void PropagateWhiteThroughClearNodes(const OpGraph& op_graph, + std::function IsAllowedToRunWithHalf, + const HashSet& black_set, + HashSet* white_set) const; + void InsertCastOp(const OpGraph& op_graph, const HashSet& white_set, + JobBuilder* job_builder) const; + + const AMPList& white_list_; + const AMPList& black_list_; + const AMPList& gray_list_; + const AMPList& clear_list_; +}; void AutoMixedPrecision::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { CHECK_GE(CUDA_VERSION, 10000); @@ -286,4 +316,8 @@ void AutoMixedPrecision::InsertCastOp(const OpGraph& op_graph, const HashSet* black_set) const; - void FillWhiteSet(const OpGraph& op_graph, std::function IsAllowedToRunWithHalf, - const HashSet& black_set, HashSet* white_set) const; - void PropagateWhiteThroughClearNodes(const OpGraph& op_graph, - std::function IsAllowedToRunWithHalf, - const HashSet& black_set, - HashSet* white_set) const; - void InsertCastOp(const OpGraph& op_graph, const HashSet& white_set, - JobBuilder* job_builder) const; - - const AMPList& white_list_; - const AMPList& black_list_; - const AMPList& gray_list_; - const AMPList& clear_list_; -}; - -} // namespace oneflow - -#endif // ONEFLOW_CORE_JOB_COMPLETER_AUTO_MIXED_PRECISION_H_ diff --git a/oneflow/core/job_completer/user_job_completer.cpp b/oneflow/core/job_completer/complete_ofrecord_decoder.cpp similarity index 95% rename from oneflow/core/job_completer/user_job_completer.cpp rename to oneflow/core/job_completer/complete_ofrecord_decoder.cpp index ec58a10e115a3e4d51677bf9422b5c07b03cb476..8b70ee7247b474f46ee9d58b2852b96995333e82 100644 --- a/oneflow/core/job_completer/user_job_completer.cpp +++ b/oneflow/core/job_completer/complete_ofrecord_decoder.cpp @@ -1,4 +1,4 @@ -#include "oneflow/core/job_completer/user_job_completer.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/job/job_builder.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/job/parallel_desc.h" @@ -132,9 +132,15 @@ void AddRecordLoadOps(Job* job) { } // namespace -void UserJobCompleter::Complete(Job* job) const { - SplitDecodeOps(job); - AddRecordLoadOps(job); -} +class CompleteOfrecordDecoder final : public OpGraphPass { + public: + bool IsEnabled() const override { return true; } + void Apply(Job* job) const override { + SplitDecodeOps(job); + AddRecordLoadOps(job); + } +}; + +REGISTER_FUNCTION_PASS("CompleteOfrecordDecoder", CompleteOfrecordDecoder); } // namespace oneflow diff --git a/oneflow/core/job_completer/job_completer.cpp b/oneflow/core/job_completer/job_completer.cpp index b6ca098d4f9ac6055fc99b607ed4b837d39da10e..d70c4f749571e062cd1d221fd9246a82d12614fd 100644 --- a/oneflow/core/job_completer/job_completer.cpp +++ b/oneflow/core/job_completer/job_completer.cpp @@ -5,11 +5,8 @@ #include "oneflow/core/job_completer/optimizer.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job_completer/all_reduce_add_pass.h" -#include "oneflow/core/job_completer/set_default_variable_conf.h" #include "oneflow/core/job_completer/all_reduce_sequence_pass.h" #include "oneflow/core/job_completer/group_boxing_by_dst_parallel.h" -#include "oneflow/core/job_completer/auto_mixed_precision.h" -#include "oneflow/core/job_completer/non_distributed_optimizer_pass.h" #include "oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h" #include "oneflow/core/job_completer/auto_train_step.h" #include "oneflow/core/job_completer/auto_learning_rate.h" @@ -190,11 +187,11 @@ void MakeNcclTupleBroadcastReduceSequence(const OpGraph& op_graph, JobBuilder* j void JobCompleter::Complete(Job* job) const { // complete variable ops - WithOpGraphAndMutJobBuilder(job, &SetDefaultVariableConf); - AutoMixedPrecision()(job); + FunctionPass("SetDefaultVariableConf")(job); + FunctionPass("AutoMixedPrecision")(job); if (GlobalJobDesc().IsTrain()) { - FindFunctionPass("TieUpChainHeadersUnReachableFromAnyVariableOps")(job); - NonDistributedOptimizerPass()(job); + FunctionPass("TieUpChainHeadersUnReachableFromAnyVariableOps")(job); + FunctionPass("NonDistributedOptimizerPass")(job); WithOpGraphAndMutJob(job, &AutoTrainStep); WithOpGraphAndMutJob(job, &AutoLearningRate); // complete ops for trainning diff --git a/oneflow/core/job_completer/non_distributed_optimizer_pass.cpp b/oneflow/core/job_completer/non_distributed_optimizer_pass.cpp index 90e252d37a892490941e9dc09dc1d90e3ea6392f..9d52f517843bc44843ebb46be12e1cb61925e35e 100644 --- a/oneflow/core/job_completer/non_distributed_optimizer_pass.cpp +++ b/oneflow/core/job_completer/non_distributed_optimizer_pass.cpp @@ -1,4 +1,5 @@ -#include "oneflow/core/job_completer/non_distributed_optimizer_pass.h" +#include "oneflow/core/common/util.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job/job_desc.h" @@ -29,7 +30,14 @@ ParallelConf NonDistributedParallelConf4ParallelId(const ParallelDesc& pd, return parallel_conf; } -} // namespace +class NonDistributedOptimizerPass final : public OpGraphPass { + public: + OF_DISALLOW_COPY_AND_MOVE(NonDistributedOptimizerPass); + NonDistributedOptimizerPass() = default; + ~NonDistributedOptimizerPass() = default; + bool IsEnabled() const override { return GlobalJobDesc().enable_non_distributed_optimizer(); } + void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; +}; void NonDistributedOptimizerPass::Apply(const OpGraph& op_graph, JobBuilder* builder) const { HashMap>> pd2last_node2node_seqs; @@ -176,4 +184,8 @@ void NonDistributedOptimizerPass::Apply(const OpGraph& op_graph, JobBuilder* bui } } +REGISTER_FUNCTION_PASS("NonDistributedOptimizerPass", NonDistributedOptimizerPass); + +} // namespace + } // namespace oneflow diff --git a/oneflow/core/job_completer/non_distributed_optimizer_pass.h b/oneflow/core/job_completer/non_distributed_optimizer_pass.h deleted file mode 100644 index 62b1e68bcf05cdde9d7928bdc7d8e5530d3ebaa8..0000000000000000000000000000000000000000 --- a/oneflow/core/job_completer/non_distributed_optimizer_pass.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef ONEFLOW_CORE_JOB_COMPLETER_NON_DISTRIBUTED_OPTIMIZER_PASS_H_ -#define ONEFLOW_CORE_JOB_COMPLETER_NON_DISTRIBUTED_OPTIMIZER_PASS_H_ - -#include "oneflow/core/common/util.h" -#include "oneflow/core/job_completer/op_graph_pass.h" - -namespace oneflow { - -class OpGraph; -class JobBuilder; - -class NonDistributedOptimizerPass final : public OpGraphPass { - public: - OF_DISALLOW_COPY_AND_MOVE(NonDistributedOptimizerPass); - NonDistributedOptimizerPass() = default; - ~NonDistributedOptimizerPass() = default; - bool IsEnabled() const override { return GlobalJobDesc().enable_non_distributed_optimizer(); } - void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; -}; - -} // namespace oneflow - -#endif // ONEFLOW_CORE_JOB_COMPLETER_NON_DISTRIBUTED_OPTIMIZER_PASS_H_ diff --git a/oneflow/core/job_completer/op_graph_pass.cpp b/oneflow/core/job_completer/op_graph_pass.cpp index 6ad526ceb608790edf51c2a675c5e21bdeb8c881..f1c7ebea67ade031e86d30554acd47412b8a349f 100644 --- a/oneflow/core/job_completer/op_graph_pass.cpp +++ b/oneflow/core/job_completer/op_graph_pass.cpp @@ -15,7 +15,11 @@ void RegisterFunctionPass(const std::string& pass_name, const OpGraphPass* pass) CHECK(PassName2FunctionPass()->emplace(pass_name, pass).second); } -const OpGraphPass& FindFunctionPass(const std::string& pass_name) { +bool HasFunctionPass(const std::string& pass_name) { + return PassName2FunctionPass()->find(pass_name) != PassName2FunctionPass()->end(); +} + +const OpGraphPass& FunctionPass(const std::string& pass_name) { const auto& iter = PassName2FunctionPass()->find(pass_name); CHECK(iter != PassName2FunctionPass()->end()); return *iter->second; diff --git a/oneflow/core/job_completer/op_graph_pass.h b/oneflow/core/job_completer/op_graph_pass.h index d368be17526e69d8f1dc9180722f6fd256a73372..28948b7b65ffb38169c39809accef77fe360663c 100644 --- a/oneflow/core/job_completer/op_graph_pass.h +++ b/oneflow/core/job_completer/op_graph_pass.h @@ -11,10 +11,13 @@ class OpGraphPass { public: void operator()(Job* job) const { if (IsEnabled() == false) { return; } + Apply(job); + } + virtual bool IsEnabled() const { return true; } + virtual void Apply(Job* job) const { const OpGraph op_graph(*job); Apply(op_graph, job); } - virtual bool IsEnabled() const { return true; } virtual void Apply(const OpGraph& op_graph, Job* job) const { JobBuilder job_builder(job); Apply(op_graph, &job_builder); @@ -26,7 +29,8 @@ class OpGraphPass { COMMAND(RegisterFunctionPass(pass_name, new pass_type)) void RegisterFunctionPass(const std::string& pass_name, const OpGraphPass* pass); -const OpGraphPass& FindFunctionPass(const std::string& pass_name); +bool HasFunctionPass(const std::string& pass_name); +const OpGraphPass& FunctionPass(const std::string& pass_name); } // namespace oneflow diff --git a/oneflow/core/job_completer/set_default_variable_conf.cpp b/oneflow/core/job_completer/set_default_variable_conf.cpp index 4d265fd213b07c6b4faf57b0be66496ec923f093..55dbe3e25ac20cbe9ddb734d2c8fde370ff33163 100644 --- a/oneflow/core/job_completer/set_default_variable_conf.cpp +++ b/oneflow/core/job_completer/set_default_variable_conf.cpp @@ -1,49 +1,59 @@ -#include "oneflow/core/job_completer/set_default_variable_conf.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/job/job_builder.h" #include "oneflow/core/job/job_set_compile_ctx.h" namespace oneflow { -void SetDefaultVariableConf(const OpGraph& op_graph, JobBuilder* job_builder) { - auto BlobDesc4ModelLbi = op_graph.MakeGetterBlobDesc4ModelLbi(); - op_graph.ForEachNode([&](OpNode* op_node) { - if (op_node->op().op_conf().has_variable_conf()) { - OperatorConf variable_op_conf(op_node->op().op_conf()); - VariableOpConf* variable_conf = variable_op_conf.mutable_variable_conf(); - if (!variable_conf->has_data_type()) { - variable_conf->set_data_type(job_builder->job().job_conf().default_data_type()); - } - if (!variable_conf->has_initializer() && !variable_conf->has_initialize_with_snapshot()) { - if (job_builder->job().job_conf().has_default_initializer_conf()) { - *variable_conf->mutable_initializer() = - job_builder->job().job_conf().default_initializer_conf(); - } else if (job_builder->job().job_conf().has_default_initialize_with_snapshot_path()) { - variable_conf->mutable_initialize_with_snapshot()->set_path( - job_builder->job().job_conf().default_initialize_with_snapshot_path()); - variable_conf->mutable_initialize_with_snapshot()->set_key( - GenLogicalBlobName(op_node->op().BnInOp2Lbi("out"))); +namespace { + +class SetDefaultVariableConf final : public OpGraphPass { + bool IsEnabled() const override { return true; } + + void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override { + auto BlobDesc4ModelLbi = op_graph.MakeGetterBlobDesc4ModelLbi(); + op_graph.ForEachNode([&](OpNode* op_node) { + if (op_node->op().op_conf().has_variable_conf()) { + OperatorConf variable_op_conf(op_node->op().op_conf()); + VariableOpConf* variable_conf = variable_op_conf.mutable_variable_conf(); + if (!variable_conf->has_data_type()) { + variable_conf->set_data_type(job_builder->job().job_conf().default_data_type()); + } + if (!variable_conf->has_initializer() && !variable_conf->has_initialize_with_snapshot()) { + if (job_builder->job().job_conf().has_default_initializer_conf()) { + *variable_conf->mutable_initializer() = + job_builder->job().job_conf().default_initializer_conf(); + } else if (job_builder->job().job_conf().has_default_initialize_with_snapshot_path()) { + variable_conf->mutable_initialize_with_snapshot()->set_path( + job_builder->job().job_conf().default_initialize_with_snapshot_path()); + variable_conf->mutable_initialize_with_snapshot()->set_key( + GenLogicalBlobName(op_node->op().BnInOp2Lbi("out"))); + } else { + UNIMPLEMENTED(); + } + } + int64_t random_seed; + auto* var_op_name2random = Global::Get()->GetVarOpName2randomSeed(); + const std::string& var_op_name = variable_op_conf.name(); + if (variable_conf->has_random_seed()) { + random_seed = variable_conf->random_seed(); } else { - UNIMPLEMENTED(); + random_seed = NewRandomSeed(); } + const auto& pair = var_op_name2random->insert({var_op_name, random_seed}); + if (variable_conf->has_random_seed()) { + CHECK_EQ(variable_conf->random_seed(), pair.first->second); + } else { + variable_conf->set_random_seed(pair.first->second); + } + job_builder->AddOrMutOpsOnlyOnce(op_node->parallel_desc().parallel_conf(), + {variable_op_conf}); } - int64_t random_seed; - auto* var_op_name2random = Global::Get()->GetVarOpName2randomSeed(); - const std::string& var_op_name = variable_op_conf.name(); - if (variable_conf->has_random_seed()) { - random_seed = variable_conf->random_seed(); - } else { - random_seed = NewRandomSeed(); - } - const auto& pair = var_op_name2random->insert({var_op_name, random_seed}); - if (variable_conf->has_random_seed()) { - CHECK_EQ(variable_conf->random_seed(), pair.first->second); - } else { - variable_conf->set_random_seed(pair.first->second); - } - job_builder->AddOrMutOpsOnlyOnce(op_node->parallel_desc().parallel_conf(), - {variable_op_conf}); - } - }); -} + }); + } +}; + +REGISTER_FUNCTION_PASS("SetDefaultVariableConf", SetDefaultVariableConf); + +} // namespace } // namespace oneflow diff --git a/oneflow/core/job_completer/set_default_variable_conf.h b/oneflow/core/job_completer/set_default_variable_conf.h deleted file mode 100644 index 287d1d44ba21c705184275e2cd6031577856e6dd..0000000000000000000000000000000000000000 --- a/oneflow/core/job_completer/set_default_variable_conf.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef ONEFLOW_CORE_JOB_COMPLETER_FILL_VARIABLE_CONF_H_ -#define ONEFLOW_CORE_JOB_COMPLETER_FILL_VARIABLE_CONF_H_ - -#include "oneflow/core/job/job_desc.h" -#include "oneflow/core/operator/operator.h" -#include "oneflow/core/graph/op_graph.h" - -namespace oneflow { - -void SetDefaultVariableConf(const OpGraph& op_graph, JobBuilder* job_builder); - -} // namespace oneflow - -#endif // ONEFLOW_CORE_JOB_COMPLETER_FILL_VARIABLE_CONF_H_ diff --git a/oneflow/core/job_completer/tie_up_chain_headers.cpp b/oneflow/core/job_completer/tie_up_chain_headers.cpp index 5cdb514d160fccc07b2707c14821870253620dc2..db6611487b9e7db71b695c0a75dbd600fb4a0bfc 100644 --- a/oneflow/core/job_completer/tie_up_chain_headers.cpp +++ b/oneflow/core/job_completer/tie_up_chain_headers.cpp @@ -95,6 +95,7 @@ REGISTER_FUNCTION_CONFIG_DEF().Bool("enable_pseudo_chain_merge", false, "ties up chain headers unreachable from any variable ops"); class TieUpChainHeadersUnReachableFromAnyVariableOps final : public OpGraphPass { bool IsEnabled() const override { return GlobalJobDesc().Bool("enable_pseudo_chain_merge"); } + void Apply(const OpGraph& op_graph, Job* job) const override { auto IsReachableFromAnyVariableOps = MakePredicatorIsReachableFromAnyVariableOps(op_graph); auto GetSourceNodesAndEdges = [&](const HashSet& chain_nodes, diff --git a/oneflow/core/job_completer/user_job_completer.h b/oneflow/core/job_completer/user_job_completer.h deleted file mode 100644 index 0de8e59c8d821a0a2be2f7af604faff2e6b1acc1..0000000000000000000000000000000000000000 --- a/oneflow/core/job_completer/user_job_completer.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef ONEFLOW_CORE_JOB_COMPLETER_USER_JOB_COMPLETER_H_ -#define ONEFLOW_CORE_JOB_COMPLETER_USER_JOB_COMPLETER_H_ - -#include "oneflow/core/common/util.h" -#include "oneflow/core/job/job_desc.h" - -namespace oneflow { - -class UserJobCompleter final { - public: - OF_DISALLOW_COPY_AND_MOVE(UserJobCompleter); - UserJobCompleter() = default; - ~UserJobCompleter() = default; - - void Complete(Job* job) const; -}; - -} // namespace oneflow - -#endif // ONEFLOW_CORE_JOB_COMPLETER_USER_JOB_COMPLETER_H_