From ee85ce060078b3baaf27ef0da931855d49b79cc8 Mon Sep 17 00:00:00 2001 From: lixinqi Date: Tue, 7 Jan 2020 18:48:31 +0800 Subject: [PATCH] more cases of REGISTER_FUNCTION_PASS --- oneflow/core/job/oneflow.cpp | 4 +- .../job_completer/auto_mixed_precision.cpp | 38 +++++++- .../core/job_completer/auto_mixed_precision.h | 45 ---------- ...eter.cpp => complete_ofrecord_decoder.cpp} | 16 ++-- oneflow/core/job_completer/job_completer.cpp | 11 +-- .../non_distributed_optimizer_pass.cpp | 16 +++- .../non_distributed_optimizer_pass.h | 23 ----- oneflow/core/job_completer/op_graph_pass.cpp | 6 +- oneflow/core/job_completer/op_graph_pass.h | 8 +- .../set_default_variable_conf.cpp | 88 +++++++++++-------- .../job_completer/set_default_variable_conf.h | 14 --- .../job_completer/tie_up_chain_headers.cpp | 1 + .../core/job_completer/user_job_completer.h | 20 ----- 13 files changed, 128 insertions(+), 162 deletions(-) delete mode 100644 oneflow/core/job_completer/auto_mixed_precision.h rename oneflow/core/job_completer/{user_job_completer.cpp => complete_ofrecord_decoder.cpp} (95%) delete mode 100644 oneflow/core/job_completer/non_distributed_optimizer_pass.h delete mode 100644 oneflow/core/job_completer/set_default_variable_conf.h delete mode 100644 oneflow/core/job_completer/user_job_completer.h diff --git a/oneflow/core/job/oneflow.cpp b/oneflow/core/job/oneflow.cpp index 48382c2cdf..c66982991f 100644 --- a/oneflow/core/job/oneflow.cpp +++ b/oneflow/core/job/oneflow.cpp @@ -6,7 +6,7 @@ #include "oneflow/core/job/improver.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/job_builder.h" -#include "oneflow/core/job_completer/user_job_completer.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/job/job_set.pb.h" #include "oneflow/core/job/machine_context.h" #include "oneflow/core/job/profiler.h" @@ -743,7 +743,7 @@ void CompileAndMergePlanOnMaster(const PbRpf& conf_jobs, Plan* plan) { AddJobName2JobId(jobs.at(job_id).job_conf().job_name(), job_id); { auto scope = std::make_unique(jobs.at(job_id).job_conf(), job_id); - UserJobCompleter().Complete(&jobs.at(job_id)); + FunctionPass("CompleteOfrecordDecoder")(&jobs.at(job_id)); CompileCurJobOnMaster(&jobs.at(job_id), &sub_plans.at(job_id), true); } } diff --git a/oneflow/core/job_completer/auto_mixed_precision.cpp b/oneflow/core/job_completer/auto_mixed_precision.cpp index 7f5e89fded..e5acd1b42e 100644 --- a/oneflow/core/job_completer/auto_mixed_precision.cpp +++ b/oneflow/core/job_completer/auto_mixed_precision.cpp @@ -1,5 +1,6 @@ #include -#include "oneflow/core/job_completer/auto_mixed_precision.h" +#include "oneflow/core/job_completer/auto_mixed_precision_lists.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/device/cuda_util.h" @@ -175,7 +176,36 @@ void InsertCastOpImpl(bool f2h, const OpGraph& op_graph, const HashSet& job_builder->MutOpsOnlyOnce(dst_op_confs); } -} // namespace +class AutoMixedPrecision final : public OpGraphPass { + public: + OF_DISALLOW_COPY_AND_MOVE(AutoMixedPrecision); + AutoMixedPrecision() + : white_list_(AutoMixedPrecisionLists::WhiteList()), + black_list_(AutoMixedPrecisionLists::BlackList()), + gray_list_(AutoMixedPrecisionLists::GrayList()), + clear_list_(AutoMixedPrecisionLists::ClearList()) {} + ~AutoMixedPrecision() = default; + + bool IsEnabled() const override { return GlobalJobDesc().enable_auto_mixed_precision(); } + + void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; + + private: + void FillBlackSet(const OpGraph& op_graph, HashSet* black_set) const; + void FillWhiteSet(const OpGraph& op_graph, std::function IsAllowedToRunWithHalf, + const HashSet& black_set, HashSet* white_set) const; + void PropagateWhiteThroughClearNodes(const OpGraph& op_graph, + std::function IsAllowedToRunWithHalf, + const HashSet& black_set, + HashSet* white_set) const; + void InsertCastOp(const OpGraph& op_graph, const HashSet& white_set, + JobBuilder* job_builder) const; + + const AMPList& white_list_; + const AMPList& black_list_; + const AMPList& gray_list_; + const AMPList& clear_list_; +}; void AutoMixedPrecision::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { CHECK_GE(CUDA_VERSION, 10000); @@ -286,4 +316,8 @@ void AutoMixedPrecision::InsertCastOp(const OpGraph& op_graph, const HashSet* black_set) const; - void FillWhiteSet(const OpGraph& op_graph, std::function IsAllowedToRunWithHalf, - const HashSet& black_set, HashSet* white_set) const; - void PropagateWhiteThroughClearNodes(const OpGraph& op_graph, - std::function IsAllowedToRunWithHalf, - const HashSet& black_set, - HashSet* white_set) const; - void InsertCastOp(const OpGraph& op_graph, const HashSet& white_set, - JobBuilder* job_builder) const; - - const AMPList& white_list_; - const AMPList& black_list_; - const AMPList& gray_list_; - const AMPList& clear_list_; -}; - -} // namespace oneflow - -#endif // ONEFLOW_CORE_JOB_COMPLETER_AUTO_MIXED_PRECISION_H_ diff --git a/oneflow/core/job_completer/user_job_completer.cpp b/oneflow/core/job_completer/complete_ofrecord_decoder.cpp similarity index 95% rename from oneflow/core/job_completer/user_job_completer.cpp rename to oneflow/core/job_completer/complete_ofrecord_decoder.cpp index ec58a10e11..8b70ee7247 100644 --- a/oneflow/core/job_completer/user_job_completer.cpp +++ b/oneflow/core/job_completer/complete_ofrecord_decoder.cpp @@ -1,4 +1,4 @@ -#include "oneflow/core/job_completer/user_job_completer.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/job/job_builder.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/job/parallel_desc.h" @@ -132,9 +132,15 @@ void AddRecordLoadOps(Job* job) { } // namespace -void UserJobCompleter::Complete(Job* job) const { - SplitDecodeOps(job); - AddRecordLoadOps(job); -} +class CompleteOfrecordDecoder final : public OpGraphPass { + public: + bool IsEnabled() const override { return true; } + void Apply(Job* job) const override { + SplitDecodeOps(job); + AddRecordLoadOps(job); + } +}; + +REGISTER_FUNCTION_PASS("CompleteOfrecordDecoder", CompleteOfrecordDecoder); } // namespace oneflow diff --git a/oneflow/core/job_completer/job_completer.cpp b/oneflow/core/job_completer/job_completer.cpp index b6ca098d4f..d70c4f7495 100644 --- a/oneflow/core/job_completer/job_completer.cpp +++ b/oneflow/core/job_completer/job_completer.cpp @@ -5,11 +5,8 @@ #include "oneflow/core/job_completer/optimizer.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job_completer/all_reduce_add_pass.h" -#include "oneflow/core/job_completer/set_default_variable_conf.h" #include "oneflow/core/job_completer/all_reduce_sequence_pass.h" #include "oneflow/core/job_completer/group_boxing_by_dst_parallel.h" -#include "oneflow/core/job_completer/auto_mixed_precision.h" -#include "oneflow/core/job_completer/non_distributed_optimizer_pass.h" #include "oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h" #include "oneflow/core/job_completer/auto_train_step.h" #include "oneflow/core/job_completer/auto_learning_rate.h" @@ -190,11 +187,11 @@ void MakeNcclTupleBroadcastReduceSequence(const OpGraph& op_graph, JobBuilder* j void JobCompleter::Complete(Job* job) const { // complete variable ops - WithOpGraphAndMutJobBuilder(job, &SetDefaultVariableConf); - AutoMixedPrecision()(job); + FunctionPass("SetDefaultVariableConf")(job); + FunctionPass("AutoMixedPrecision")(job); if (GlobalJobDesc().IsTrain()) { - FindFunctionPass("TieUpChainHeadersUnReachableFromAnyVariableOps")(job); - NonDistributedOptimizerPass()(job); + FunctionPass("TieUpChainHeadersUnReachableFromAnyVariableOps")(job); + FunctionPass("NonDistributedOptimizerPass")(job); WithOpGraphAndMutJob(job, &AutoTrainStep); WithOpGraphAndMutJob(job, &AutoLearningRate); // complete ops for trainning diff --git a/oneflow/core/job_completer/non_distributed_optimizer_pass.cpp b/oneflow/core/job_completer/non_distributed_optimizer_pass.cpp index 90e252d37a..9d52f51784 100644 --- a/oneflow/core/job_completer/non_distributed_optimizer_pass.cpp +++ b/oneflow/core/job_completer/non_distributed_optimizer_pass.cpp @@ -1,4 +1,5 @@ -#include "oneflow/core/job_completer/non_distributed_optimizer_pass.h" +#include "oneflow/core/common/util.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job/job_desc.h" @@ -29,7 +30,14 @@ ParallelConf NonDistributedParallelConf4ParallelId(const ParallelDesc& pd, return parallel_conf; } -} // namespace +class NonDistributedOptimizerPass final : public OpGraphPass { + public: + OF_DISALLOW_COPY_AND_MOVE(NonDistributedOptimizerPass); + NonDistributedOptimizerPass() = default; + ~NonDistributedOptimizerPass() = default; + bool IsEnabled() const override { return GlobalJobDesc().enable_non_distributed_optimizer(); } + void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; +}; void NonDistributedOptimizerPass::Apply(const OpGraph& op_graph, JobBuilder* builder) const { HashMap>> pd2last_node2node_seqs; @@ -176,4 +184,8 @@ void NonDistributedOptimizerPass::Apply(const OpGraph& op_graph, JobBuilder* bui } } +REGISTER_FUNCTION_PASS("NonDistributedOptimizerPass", NonDistributedOptimizerPass); + +} // namespace + } // namespace oneflow diff --git a/oneflow/core/job_completer/non_distributed_optimizer_pass.h b/oneflow/core/job_completer/non_distributed_optimizer_pass.h deleted file mode 100644 index 62b1e68bcf..0000000000 --- a/oneflow/core/job_completer/non_distributed_optimizer_pass.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef ONEFLOW_CORE_JOB_COMPLETER_NON_DISTRIBUTED_OPTIMIZER_PASS_H_ -#define ONEFLOW_CORE_JOB_COMPLETER_NON_DISTRIBUTED_OPTIMIZER_PASS_H_ - -#include "oneflow/core/common/util.h" -#include "oneflow/core/job_completer/op_graph_pass.h" - -namespace oneflow { - -class OpGraph; -class JobBuilder; - -class NonDistributedOptimizerPass final : public OpGraphPass { - public: - OF_DISALLOW_COPY_AND_MOVE(NonDistributedOptimizerPass); - NonDistributedOptimizerPass() = default; - ~NonDistributedOptimizerPass() = default; - bool IsEnabled() const override { return GlobalJobDesc().enable_non_distributed_optimizer(); } - void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override; -}; - -} // namespace oneflow - -#endif // ONEFLOW_CORE_JOB_COMPLETER_NON_DISTRIBUTED_OPTIMIZER_PASS_H_ diff --git a/oneflow/core/job_completer/op_graph_pass.cpp b/oneflow/core/job_completer/op_graph_pass.cpp index 6ad526ceb6..f1c7ebea67 100644 --- a/oneflow/core/job_completer/op_graph_pass.cpp +++ b/oneflow/core/job_completer/op_graph_pass.cpp @@ -15,7 +15,11 @@ void RegisterFunctionPass(const std::string& pass_name, const OpGraphPass* pass) CHECK(PassName2FunctionPass()->emplace(pass_name, pass).second); } -const OpGraphPass& FindFunctionPass(const std::string& pass_name) { +bool HasFunctionPass(const std::string& pass_name) { + return PassName2FunctionPass()->find(pass_name) != PassName2FunctionPass()->end(); +} + +const OpGraphPass& FunctionPass(const std::string& pass_name) { const auto& iter = PassName2FunctionPass()->find(pass_name); CHECK(iter != PassName2FunctionPass()->end()); return *iter->second; diff --git a/oneflow/core/job_completer/op_graph_pass.h b/oneflow/core/job_completer/op_graph_pass.h index d368be1752..28948b7b65 100644 --- a/oneflow/core/job_completer/op_graph_pass.h +++ b/oneflow/core/job_completer/op_graph_pass.h @@ -11,10 +11,13 @@ class OpGraphPass { public: void operator()(Job* job) const { if (IsEnabled() == false) { return; } + Apply(job); + } + virtual bool IsEnabled() const { return true; } + virtual void Apply(Job* job) const { const OpGraph op_graph(*job); Apply(op_graph, job); } - virtual bool IsEnabled() const { return true; } virtual void Apply(const OpGraph& op_graph, Job* job) const { JobBuilder job_builder(job); Apply(op_graph, &job_builder); @@ -26,7 +29,8 @@ class OpGraphPass { COMMAND(RegisterFunctionPass(pass_name, new pass_type)) void RegisterFunctionPass(const std::string& pass_name, const OpGraphPass* pass); -const OpGraphPass& FindFunctionPass(const std::string& pass_name); +bool HasFunctionPass(const std::string& pass_name); +const OpGraphPass& FunctionPass(const std::string& pass_name); } // namespace oneflow diff --git a/oneflow/core/job_completer/set_default_variable_conf.cpp b/oneflow/core/job_completer/set_default_variable_conf.cpp index 4d265fd213..55dbe3e25a 100644 --- a/oneflow/core/job_completer/set_default_variable_conf.cpp +++ b/oneflow/core/job_completer/set_default_variable_conf.cpp @@ -1,49 +1,59 @@ -#include "oneflow/core/job_completer/set_default_variable_conf.h" +#include "oneflow/core/job_completer/op_graph_pass.h" #include "oneflow/core/job/job_builder.h" #include "oneflow/core/job/job_set_compile_ctx.h" namespace oneflow { -void SetDefaultVariableConf(const OpGraph& op_graph, JobBuilder* job_builder) { - auto BlobDesc4ModelLbi = op_graph.MakeGetterBlobDesc4ModelLbi(); - op_graph.ForEachNode([&](OpNode* op_node) { - if (op_node->op().op_conf().has_variable_conf()) { - OperatorConf variable_op_conf(op_node->op().op_conf()); - VariableOpConf* variable_conf = variable_op_conf.mutable_variable_conf(); - if (!variable_conf->has_data_type()) { - variable_conf->set_data_type(job_builder->job().job_conf().default_data_type()); - } - if (!variable_conf->has_initializer() && !variable_conf->has_initialize_with_snapshot()) { - if (job_builder->job().job_conf().has_default_initializer_conf()) { - *variable_conf->mutable_initializer() = - job_builder->job().job_conf().default_initializer_conf(); - } else if (job_builder->job().job_conf().has_default_initialize_with_snapshot_path()) { - variable_conf->mutable_initialize_with_snapshot()->set_path( - job_builder->job().job_conf().default_initialize_with_snapshot_path()); - variable_conf->mutable_initialize_with_snapshot()->set_key( - GenLogicalBlobName(op_node->op().BnInOp2Lbi("out"))); +namespace { + +class SetDefaultVariableConf final : public OpGraphPass { + bool IsEnabled() const override { return true; } + + void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override { + auto BlobDesc4ModelLbi = op_graph.MakeGetterBlobDesc4ModelLbi(); + op_graph.ForEachNode([&](OpNode* op_node) { + if (op_node->op().op_conf().has_variable_conf()) { + OperatorConf variable_op_conf(op_node->op().op_conf()); + VariableOpConf* variable_conf = variable_op_conf.mutable_variable_conf(); + if (!variable_conf->has_data_type()) { + variable_conf->set_data_type(job_builder->job().job_conf().default_data_type()); + } + if (!variable_conf->has_initializer() && !variable_conf->has_initialize_with_snapshot()) { + if (job_builder->job().job_conf().has_default_initializer_conf()) { + *variable_conf->mutable_initializer() = + job_builder->job().job_conf().default_initializer_conf(); + } else if (job_builder->job().job_conf().has_default_initialize_with_snapshot_path()) { + variable_conf->mutable_initialize_with_snapshot()->set_path( + job_builder->job().job_conf().default_initialize_with_snapshot_path()); + variable_conf->mutable_initialize_with_snapshot()->set_key( + GenLogicalBlobName(op_node->op().BnInOp2Lbi("out"))); + } else { + UNIMPLEMENTED(); + } + } + int64_t random_seed; + auto* var_op_name2random = Global::Get()->GetVarOpName2randomSeed(); + const std::string& var_op_name = variable_op_conf.name(); + if (variable_conf->has_random_seed()) { + random_seed = variable_conf->random_seed(); } else { - UNIMPLEMENTED(); + random_seed = NewRandomSeed(); } + const auto& pair = var_op_name2random->insert({var_op_name, random_seed}); + if (variable_conf->has_random_seed()) { + CHECK_EQ(variable_conf->random_seed(), pair.first->second); + } else { + variable_conf->set_random_seed(pair.first->second); + } + job_builder->AddOrMutOpsOnlyOnce(op_node->parallel_desc().parallel_conf(), + {variable_op_conf}); } - int64_t random_seed; - auto* var_op_name2random = Global::Get()->GetVarOpName2randomSeed(); - const std::string& var_op_name = variable_op_conf.name(); - if (variable_conf->has_random_seed()) { - random_seed = variable_conf->random_seed(); - } else { - random_seed = NewRandomSeed(); - } - const auto& pair = var_op_name2random->insert({var_op_name, random_seed}); - if (variable_conf->has_random_seed()) { - CHECK_EQ(variable_conf->random_seed(), pair.first->second); - } else { - variable_conf->set_random_seed(pair.first->second); - } - job_builder->AddOrMutOpsOnlyOnce(op_node->parallel_desc().parallel_conf(), - {variable_op_conf}); - } - }); -} + }); + } +}; + +REGISTER_FUNCTION_PASS("SetDefaultVariableConf", SetDefaultVariableConf); + +} // namespace } // namespace oneflow diff --git a/oneflow/core/job_completer/set_default_variable_conf.h b/oneflow/core/job_completer/set_default_variable_conf.h deleted file mode 100644 index 287d1d44ba..0000000000 --- a/oneflow/core/job_completer/set_default_variable_conf.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef ONEFLOW_CORE_JOB_COMPLETER_FILL_VARIABLE_CONF_H_ -#define ONEFLOW_CORE_JOB_COMPLETER_FILL_VARIABLE_CONF_H_ - -#include "oneflow/core/job/job_desc.h" -#include "oneflow/core/operator/operator.h" -#include "oneflow/core/graph/op_graph.h" - -namespace oneflow { - -void SetDefaultVariableConf(const OpGraph& op_graph, JobBuilder* job_builder); - -} // namespace oneflow - -#endif // ONEFLOW_CORE_JOB_COMPLETER_FILL_VARIABLE_CONF_H_ diff --git a/oneflow/core/job_completer/tie_up_chain_headers.cpp b/oneflow/core/job_completer/tie_up_chain_headers.cpp index 5cdb514d16..db6611487b 100644 --- a/oneflow/core/job_completer/tie_up_chain_headers.cpp +++ b/oneflow/core/job_completer/tie_up_chain_headers.cpp @@ -95,6 +95,7 @@ REGISTER_FUNCTION_CONFIG_DEF().Bool("enable_pseudo_chain_merge", false, "ties up chain headers unreachable from any variable ops"); class TieUpChainHeadersUnReachableFromAnyVariableOps final : public OpGraphPass { bool IsEnabled() const override { return GlobalJobDesc().Bool("enable_pseudo_chain_merge"); } + void Apply(const OpGraph& op_graph, Job* job) const override { auto IsReachableFromAnyVariableOps = MakePredicatorIsReachableFromAnyVariableOps(op_graph); auto GetSourceNodesAndEdges = [&](const HashSet& chain_nodes, diff --git a/oneflow/core/job_completer/user_job_completer.h b/oneflow/core/job_completer/user_job_completer.h deleted file mode 100644 index 0de8e59c8d..0000000000 --- a/oneflow/core/job_completer/user_job_completer.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef ONEFLOW_CORE_JOB_COMPLETER_USER_JOB_COMPLETER_H_ -#define ONEFLOW_CORE_JOB_COMPLETER_USER_JOB_COMPLETER_H_ - -#include "oneflow/core/common/util.h" -#include "oneflow/core/job/job_desc.h" - -namespace oneflow { - -class UserJobCompleter final { - public: - OF_DISALLOW_COPY_AND_MOVE(UserJobCompleter); - UserJobCompleter() = default; - ~UserJobCompleter() = default; - - void Complete(Job* job) const; -}; - -} // namespace oneflow - -#endif // ONEFLOW_CORE_JOB_COMPLETER_USER_JOB_COMPLETER_H_ -- GitLab