提交 ee85ce06 编写于 作者: L lixinqi

more cases of REGISTER_FUNCTION_PASS

上级 c117d4ed
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "oneflow/core/job/improver.h" #include "oneflow/core/job/improver.h"
#include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/job_builder.h" #include "oneflow/core/job/job_builder.h"
#include "oneflow/core/job_completer/user_job_completer.h" #include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/job/job_set.pb.h" #include "oneflow/core/job/job_set.pb.h"
#include "oneflow/core/job/machine_context.h" #include "oneflow/core/job/machine_context.h"
#include "oneflow/core/job/profiler.h" #include "oneflow/core/job/profiler.h"
...@@ -743,7 +743,7 @@ void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) { ...@@ -743,7 +743,7 @@ void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) {
AddJobName2JobId(jobs.at(job_id).job_conf().job_name(), job_id); AddJobName2JobId(jobs.at(job_id).job_conf().job_name(), job_id);
{ {
auto scope = std::make_unique<GlobalJobDescScope>(jobs.at(job_id).job_conf(), job_id); auto scope = std::make_unique<GlobalJobDescScope>(jobs.at(job_id).job_conf(), job_id);
UserJobCompleter().Complete(&jobs.at(job_id)); FunctionPass("CompleteOfrecordDecoder")(&jobs.at(job_id));
CompileCurJobOnMaster(&jobs.at(job_id), &sub_plans.at(job_id), true); CompileCurJobOnMaster(&jobs.at(job_id), &sub_plans.at(job_id), true);
} }
} }
......
#include <algorithm> #include <algorithm>
#include "oneflow/core/job_completer/auto_mixed_precision.h" #include "oneflow/core/job_completer/auto_mixed_precision_lists.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/job_desc.h"
#include "oneflow/core/device/cuda_util.h" #include "oneflow/core/device/cuda_util.h"
...@@ -175,7 +176,36 @@ void InsertCastOpImpl(bool f2h, const OpGraph& op_graph, const HashSet<OpNode*>& ...@@ -175,7 +176,36 @@ void InsertCastOpImpl(bool f2h, const OpGraph& op_graph, const HashSet<OpNode*>&
job_builder->MutOpsOnlyOnce(dst_op_confs); job_builder->MutOpsOnlyOnce(dst_op_confs);
} }
} // namespace class AutoMixedPrecision final : public OpGraphPass {
public:
OF_DISALLOW_COPY_AND_MOVE(AutoMixedPrecision);
AutoMixedPrecision()
: white_list_(AutoMixedPrecisionLists::WhiteList()),
black_list_(AutoMixedPrecisionLists::BlackList()),
gray_list_(AutoMixedPrecisionLists::GrayList()),
clear_list_(AutoMixedPrecisionLists::ClearList()) {}
~AutoMixedPrecision() = default;
bool IsEnabled() const override { return GlobalJobDesc().enable_auto_mixed_precision(); }
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override;
private:
void FillBlackSet(const OpGraph& op_graph, HashSet<OpNode*>* black_set) const;
void FillWhiteSet(const OpGraph& op_graph, std::function<bool(OpNode*)> IsAllowedToRunWithHalf,
const HashSet<OpNode*>& black_set, HashSet<OpNode*>* white_set) const;
void PropagateWhiteThroughClearNodes(const OpGraph& op_graph,
std::function<bool(OpNode*)> IsAllowedToRunWithHalf,
const HashSet<OpNode*>& black_set,
HashSet<OpNode*>* white_set) const;
void InsertCastOp(const OpGraph& op_graph, const HashSet<OpNode*>& white_set,
JobBuilder* job_builder) const;
const AMPList& white_list_;
const AMPList& black_list_;
const AMPList& gray_list_;
const AMPList& clear_list_;
};
void AutoMixedPrecision::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { void AutoMixedPrecision::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {
CHECK_GE(CUDA_VERSION, 10000); CHECK_GE(CUDA_VERSION, 10000);
...@@ -286,4 +316,8 @@ void AutoMixedPrecision::InsertCastOp(const OpGraph& op_graph, const HashSet<OpN ...@@ -286,4 +316,8 @@ void AutoMixedPrecision::InsertCastOp(const OpGraph& op_graph, const HashSet<OpN
InsertCastOpImpl(false, op_graph, white_set, job_builder); InsertCastOpImpl(false, op_graph, white_set, job_builder);
} }
REGISTER_FUNCTION_PASS("AutoMixedPrecision", AutoMixedPrecision);
} // namespace
} // namespace oneflow } // namespace oneflow
#ifndef ONEFLOW_CORE_JOB_COMPLETER_AUTO_MIXED_PRECISION_H_
#define ONEFLOW_CORE_JOB_COMPLETER_AUTO_MIXED_PRECISION_H_
#include "oneflow/core/job_completer/auto_mixed_precision_lists.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
namespace oneflow {
class OpGraph;
class OpNode;
class Job;
// Op-graph pass for automatic mixed precision, implemented as an OpGraphPass
// subclass. The pass only runs when the current job description reports
// enable_auto_mixed_precision() (see IsEnabled below).
class AutoMixedPrecision final : public OpGraphPass {
public:
// Macro defined elsewhere; by its name it disables copying and moving of this
// class -- confirm against its definition.
OF_DISALLOW_COPY_AND_MOVE(AutoMixedPrecision);
// Binds the four list members to whatever AutoMixedPrecisionLists returns for
// the white / black / gray / clear lists. Nothing is copied here: the members
// are references.
AutoMixedPrecision()
: white_list_(AutoMixedPrecisionLists::WhiteList()),
black_list_(AutoMixedPrecisionLists::BlackList()),
gray_list_(AutoMixedPrecisionLists::GrayList()),
clear_list_(AutoMixedPrecisionLists::ClearList()) {}
~AutoMixedPrecision() = default;
// Returns the enable_auto_mixed_precision() flag of the global job description.
bool IsEnabled() const override { return GlobalJobDesc().enable_auto_mixed_precision(); }
// Entry point of the pass; receives the op graph and a JobBuilder used to edit
// the job. Declared here only -- the definition lives out of line.
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override;
private:
// The helpers below are declared here only; their definitions are out of line.
// Judging by the signatures, each one fills the set passed by pointer
// (black_set / white_set) or edits the job through job_builder -- TODO confirm
// the exact selection rules against the implementation.
void FillBlackSet(const OpGraph& op_graph, HashSet<OpNode*>* black_set) const;
// IsAllowedToRunWithHalf is a caller-supplied predicate over op nodes.
void FillWhiteSet(const OpGraph& op_graph, std::function<bool(OpNode*)> IsAllowedToRunWithHalf,
const HashSet<OpNode*>& black_set, HashSet<OpNode*>* white_set) const;
void PropagateWhiteThroughClearNodes(const OpGraph& op_graph,
std::function<bool(OpNode*)> IsAllowedToRunWithHalf,
const HashSet<OpNode*>& black_set,
HashSet<OpNode*>* white_set) const;
void InsertCastOp(const OpGraph& op_graph, const HashSet<OpNode*>& white_set,
JobBuilder* job_builder) const;
// Held by reference, not by value.
// NOTE(review): this assumes the objects returned by AutoMixedPrecisionLists
// outlive the pass instance -- confirm they are long-lived (e.g. static).
const AMPList& white_list_;
const AMPList& black_list_;
const AMPList& gray_list_;
const AMPList& clear_list_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_AUTO_MIXED_PRECISION_H_
#include "oneflow/core/job_completer/user_job_completer.h" #include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/job/job_builder.h" #include "oneflow/core/job/job_builder.h"
#include "oneflow/core/operator/operator.h" #include "oneflow/core/operator/operator.h"
#include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/parallel_desc.h"
...@@ -132,9 +132,15 @@ void AddRecordLoadOps(Job* job) { ...@@ -132,9 +132,15 @@ void AddRecordLoadOps(Job* job) {
} // namespace } // namespace
void UserJobCompleter::Complete(Job* job) const { class CompleteOfrecordDecoder final : public OpGraphPass {
SplitDecodeOps(job); public:
AddRecordLoadOps(job); bool IsEnabled() const override { return true; }
} void Apply(Job* job) const override {
SplitDecodeOps(job);
AddRecordLoadOps(job);
}
};
REGISTER_FUNCTION_PASS("CompleteOfrecordDecoder", CompleteOfrecordDecoder);
} // namespace oneflow } // namespace oneflow
...@@ -5,11 +5,8 @@ ...@@ -5,11 +5,8 @@
#include "oneflow/core/job_completer/optimizer.h" #include "oneflow/core/job_completer/optimizer.h"
#include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job_completer/all_reduce_add_pass.h" #include "oneflow/core/job_completer/all_reduce_add_pass.h"
#include "oneflow/core/job_completer/set_default_variable_conf.h"
#include "oneflow/core/job_completer/all_reduce_sequence_pass.h" #include "oneflow/core/job_completer/all_reduce_sequence_pass.h"
#include "oneflow/core/job_completer/group_boxing_by_dst_parallel.h" #include "oneflow/core/job_completer/group_boxing_by_dst_parallel.h"
#include "oneflow/core/job_completer/auto_mixed_precision.h"
#include "oneflow/core/job_completer/non_distributed_optimizer_pass.h"
#include "oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h" #include "oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h"
#include "oneflow/core/job_completer/auto_train_step.h" #include "oneflow/core/job_completer/auto_train_step.h"
#include "oneflow/core/job_completer/auto_learning_rate.h" #include "oneflow/core/job_completer/auto_learning_rate.h"
...@@ -190,11 +187,11 @@ void MakeNcclTupleBroadcastReduceSequence(const OpGraph& op_graph, JobBuilder* j ...@@ -190,11 +187,11 @@ void MakeNcclTupleBroadcastReduceSequence(const OpGraph& op_graph, JobBuilder* j
void JobCompleter::Complete(Job* job) const { void JobCompleter::Complete(Job* job) const {
// complete variable ops // complete variable ops
WithOpGraphAndMutJobBuilder(job, &SetDefaultVariableConf); FunctionPass("SetDefaultVariableConf")(job);
AutoMixedPrecision()(job); FunctionPass("AutoMixedPrecision")(job);
if (GlobalJobDesc().IsTrain()) { if (GlobalJobDesc().IsTrain()) {
FindFunctionPass("TieUpChainHeadersUnReachableFromAnyVariableOps")(job); FunctionPass("TieUpChainHeadersUnReachableFromAnyVariableOps")(job);
NonDistributedOptimizerPass()(job); FunctionPass("NonDistributedOptimizerPass")(job);
WithOpGraphAndMutJob(job, &AutoTrainStep); WithOpGraphAndMutJob(job, &AutoTrainStep);
WithOpGraphAndMutJob(job, &AutoLearningRate); WithOpGraphAndMutJob(job, &AutoLearningRate);
// complete ops for training // complete ops for training
......
#include "oneflow/core/job_completer/non_distributed_optimizer_pass.h" #include "oneflow/core/common/util.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/graph/op_graph.h" #include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/job_desc.h"
...@@ -29,7 +30,14 @@ ParallelConf NonDistributedParallelConf4ParallelId(const ParallelDesc& pd, ...@@ -29,7 +30,14 @@ ParallelConf NonDistributedParallelConf4ParallelId(const ParallelDesc& pd,
return parallel_conf; return parallel_conf;
} }
} // namespace class NonDistributedOptimizerPass final : public OpGraphPass {
public:
OF_DISALLOW_COPY_AND_MOVE(NonDistributedOptimizerPass);
NonDistributedOptimizerPass() = default;
~NonDistributedOptimizerPass() = default;
bool IsEnabled() const override { return GlobalJobDesc().enable_non_distributed_optimizer(); }
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override;
};
void NonDistributedOptimizerPass::Apply(const OpGraph& op_graph, JobBuilder* builder) const { void NonDistributedOptimizerPass::Apply(const OpGraph& op_graph, JobBuilder* builder) const {
HashMap<ParallelDesc, HashMap<const OpNode*, std::vector<const OpNode*>>> pd2last_node2node_seqs; HashMap<ParallelDesc, HashMap<const OpNode*, std::vector<const OpNode*>>> pd2last_node2node_seqs;
...@@ -176,4 +184,8 @@ void NonDistributedOptimizerPass::Apply(const OpGraph& op_graph, JobBuilder* bui ...@@ -176,4 +184,8 @@ void NonDistributedOptimizerPass::Apply(const OpGraph& op_graph, JobBuilder* bui
} }
} }
REGISTER_FUNCTION_PASS("NonDistributedOptimizerPass", NonDistributedOptimizerPass);
} // namespace
} // namespace oneflow } // namespace oneflow
#ifndef ONEFLOW_CORE_JOB_COMPLETER_NON_DISTRIBUTED_OPTIMIZER_PASS_H_
#define ONEFLOW_CORE_JOB_COMPLETER_NON_DISTRIBUTED_OPTIMIZER_PASS_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
namespace oneflow {
class OpGraph;
class JobBuilder;
// Op-graph pass declared as an OpGraphPass subclass. It only runs when the
// global job description reports enable_non_distributed_optimizer().
class NonDistributedOptimizerPass final : public OpGraphPass {
public:
// Macro defined elsewhere; by its name it disables copying and moving of this
// class -- confirm against its definition.
OF_DISALLOW_COPY_AND_MOVE(NonDistributedOptimizerPass);
NonDistributedOptimizerPass() = default;
~NonDistributedOptimizerPass() = default;
// Returns the enable_non_distributed_optimizer() flag of the global job description.
bool IsEnabled() const override { return GlobalJobDesc().enable_non_distributed_optimizer(); }
// Entry point of the pass; receives the op graph and a JobBuilder used to edit
// the job. Declared here only -- the definition lives out of line.
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_NON_DISTRIBUTED_OPTIMIZER_PASS_H_
...@@ -15,7 +15,11 @@ void RegisterFunctionPass(const std::string& pass_name, const OpGraphPass* pass) ...@@ -15,7 +15,11 @@ void RegisterFunctionPass(const std::string& pass_name, const OpGraphPass* pass)
CHECK(PassName2FunctionPass()->emplace(pass_name, pass).second); CHECK(PassName2FunctionPass()->emplace(pass_name, pass).second);
} }
const OpGraphPass& FindFunctionPass(const std::string& pass_name) { bool HasFunctionPass(const std::string& pass_name) {
return PassName2FunctionPass()->find(pass_name) != PassName2FunctionPass()->end();
}
const OpGraphPass& FunctionPass(const std::string& pass_name) {
const auto& iter = PassName2FunctionPass()->find(pass_name); const auto& iter = PassName2FunctionPass()->find(pass_name);
CHECK(iter != PassName2FunctionPass()->end()); CHECK(iter != PassName2FunctionPass()->end());
return *iter->second; return *iter->second;
......
...@@ -11,10 +11,13 @@ class OpGraphPass { ...@@ -11,10 +11,13 @@ class OpGraphPass {
public: public:
void operator()(Job* job) const { void operator()(Job* job) const {
if (IsEnabled() == false) { return; } if (IsEnabled() == false) { return; }
Apply(job);
}
virtual bool IsEnabled() const { return true; }
virtual void Apply(Job* job) const {
const OpGraph op_graph(*job); const OpGraph op_graph(*job);
Apply(op_graph, job); Apply(op_graph, job);
} }
virtual bool IsEnabled() const { return true; }
virtual void Apply(const OpGraph& op_graph, Job* job) const { virtual void Apply(const OpGraph& op_graph, Job* job) const {
JobBuilder job_builder(job); JobBuilder job_builder(job);
Apply(op_graph, &job_builder); Apply(op_graph, &job_builder);
...@@ -26,7 +29,8 @@ class OpGraphPass { ...@@ -26,7 +29,8 @@ class OpGraphPass {
COMMAND(RegisterFunctionPass(pass_name, new pass_type)) COMMAND(RegisterFunctionPass(pass_name, new pass_type))
void RegisterFunctionPass(const std::string& pass_name, const OpGraphPass* pass); void RegisterFunctionPass(const std::string& pass_name, const OpGraphPass* pass);
const OpGraphPass& FindFunctionPass(const std::string& pass_name); bool HasFunctionPass(const std::string& pass_name);
const OpGraphPass& FunctionPass(const std::string& pass_name);
} // namespace oneflow } // namespace oneflow
......
#include "oneflow/core/job_completer/set_default_variable_conf.h" #include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/job/job_builder.h" #include "oneflow/core/job/job_builder.h"
#include "oneflow/core/job/job_set_compile_ctx.h" #include "oneflow/core/job/job_set_compile_ctx.h"
namespace oneflow { namespace oneflow {
void SetDefaultVariableConf(const OpGraph& op_graph, JobBuilder* job_builder) { namespace {
auto BlobDesc4ModelLbi = op_graph.MakeGetterBlobDesc4ModelLbi();
op_graph.ForEachNode([&](OpNode* op_node) { class SetDefaultVariableConf final : public OpGraphPass {
if (op_node->op().op_conf().has_variable_conf()) { bool IsEnabled() const override { return true; }
OperatorConf variable_op_conf(op_node->op().op_conf());
VariableOpConf* variable_conf = variable_op_conf.mutable_variable_conf(); void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override {
if (!variable_conf->has_data_type()) { auto BlobDesc4ModelLbi = op_graph.MakeGetterBlobDesc4ModelLbi();
variable_conf->set_data_type(job_builder->job().job_conf().default_data_type()); op_graph.ForEachNode([&](OpNode* op_node) {
} if (op_node->op().op_conf().has_variable_conf()) {
if (!variable_conf->has_initializer() && !variable_conf->has_initialize_with_snapshot()) { OperatorConf variable_op_conf(op_node->op().op_conf());
if (job_builder->job().job_conf().has_default_initializer_conf()) { VariableOpConf* variable_conf = variable_op_conf.mutable_variable_conf();
*variable_conf->mutable_initializer() = if (!variable_conf->has_data_type()) {
job_builder->job().job_conf().default_initializer_conf(); variable_conf->set_data_type(job_builder->job().job_conf().default_data_type());
} else if (job_builder->job().job_conf().has_default_initialize_with_snapshot_path()) { }
variable_conf->mutable_initialize_with_snapshot()->set_path( if (!variable_conf->has_initializer() && !variable_conf->has_initialize_with_snapshot()) {
job_builder->job().job_conf().default_initialize_with_snapshot_path()); if (job_builder->job().job_conf().has_default_initializer_conf()) {
variable_conf->mutable_initialize_with_snapshot()->set_key( *variable_conf->mutable_initializer() =
GenLogicalBlobName(op_node->op().BnInOp2Lbi("out"))); job_builder->job().job_conf().default_initializer_conf();
} else if (job_builder->job().job_conf().has_default_initialize_with_snapshot_path()) {
variable_conf->mutable_initialize_with_snapshot()->set_path(
job_builder->job().job_conf().default_initialize_with_snapshot_path());
variable_conf->mutable_initialize_with_snapshot()->set_key(
GenLogicalBlobName(op_node->op().BnInOp2Lbi("out")));
} else {
UNIMPLEMENTED();
}
}
int64_t random_seed;
auto* var_op_name2random = Global<JobSetCompileCtx>::Get()->GetVarOpName2randomSeed();
const std::string& var_op_name = variable_op_conf.name();
if (variable_conf->has_random_seed()) {
random_seed = variable_conf->random_seed();
} else { } else {
UNIMPLEMENTED(); random_seed = NewRandomSeed();
} }
const auto& pair = var_op_name2random->insert({var_op_name, random_seed});
if (variable_conf->has_random_seed()) {
CHECK_EQ(variable_conf->random_seed(), pair.first->second);
} else {
variable_conf->set_random_seed(pair.first->second);
}
job_builder->AddOrMutOpsOnlyOnce(op_node->parallel_desc().parallel_conf(),
{variable_op_conf});
} }
int64_t random_seed; });
auto* var_op_name2random = Global<JobSetCompileCtx>::Get()->GetVarOpName2randomSeed(); }
const std::string& var_op_name = variable_op_conf.name(); };
if (variable_conf->has_random_seed()) {
random_seed = variable_conf->random_seed(); REGISTER_FUNCTION_PASS("SetDefaultVariableConf", SetDefaultVariableConf);
} else {
random_seed = NewRandomSeed(); } // namespace
}
const auto& pair = var_op_name2random->insert({var_op_name, random_seed});
if (variable_conf->has_random_seed()) {
CHECK_EQ(variable_conf->random_seed(), pair.first->second);
} else {
variable_conf->set_random_seed(pair.first->second);
}
job_builder->AddOrMutOpsOnlyOnce(op_node->parallel_desc().parallel_conf(),
{variable_op_conf});
}
});
}
} // namespace oneflow } // namespace oneflow
#ifndef ONEFLOW_CORE_JOB_COMPLETER_FILL_VARIABLE_CONF_H_
#define ONEFLOW_CORE_JOB_COMPLETER_FILL_VARIABLE_CONF_H_
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/graph/op_graph.h"
namespace oneflow {
// Visits every node of op_graph whose op conf has a variable_conf and fills in
// fields that are unset: data type (from the job's default_data_type),
// initializer or initialize-with-snapshot (from the job's defaults), and the
// random seed. The updated op conf is handed back through
// job_builder->AddOrMutOpsOnlyOnce. Defined out of line.
void SetDefaultVariableConf(const OpGraph& op_graph, JobBuilder* job_builder);
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_FILL_VARIABLE_CONF_H_
...@@ -95,6 +95,7 @@ REGISTER_FUNCTION_CONFIG_DEF().Bool("enable_pseudo_chain_merge", false, ...@@ -95,6 +95,7 @@ REGISTER_FUNCTION_CONFIG_DEF().Bool("enable_pseudo_chain_merge", false,
"ties up chain headers unreachable from any variable ops"); "ties up chain headers unreachable from any variable ops");
class TieUpChainHeadersUnReachableFromAnyVariableOps final : public OpGraphPass { class TieUpChainHeadersUnReachableFromAnyVariableOps final : public OpGraphPass {
bool IsEnabled() const override { return GlobalJobDesc().Bool("enable_pseudo_chain_merge"); } bool IsEnabled() const override { return GlobalJobDesc().Bool("enable_pseudo_chain_merge"); }
void Apply(const OpGraph& op_graph, Job* job) const override { void Apply(const OpGraph& op_graph, Job* job) const override {
auto IsReachableFromAnyVariableOps = MakePredicatorIsReachableFromAnyVariableOps(op_graph); auto IsReachableFromAnyVariableOps = MakePredicatorIsReachableFromAnyVariableOps(op_graph);
auto GetSourceNodesAndEdges = [&](const HashSet<OpNode*>& chain_nodes, auto GetSourceNodesAndEdges = [&](const HashSet<OpNode*>& chain_nodes,
......
#ifndef ONEFLOW_CORE_JOB_COMPLETER_USER_JOB_COMPLETER_H_
#define ONEFLOW_CORE_JOB_COMPLETER_USER_JOB_COMPLETER_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/job_desc.h"
namespace oneflow {
// Stateless helper that completes a user-provided Job in place.
class UserJobCompleter final {
public:
// Macro defined elsewhere; by its name it disables copying and moving of this
// class -- confirm against its definition.
OF_DISALLOW_COPY_AND_MOVE(UserJobCompleter);
UserJobCompleter() = default;
~UserJobCompleter() = default;
// Mutates *job. Declared here only; the out-of-line definition calls
// SplitDecodeOps(job) followed by AddRecordLoadOps(job).
void Complete(Job* job) const;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_USER_JOB_COMPLETER_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册