提交 ee85ce06 编写于 作者: L lixinqi

more cases of REGISTER_FUNCTION_PASS

上级 c117d4ed
......@@ -6,7 +6,7 @@
#include "oneflow/core/job/improver.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/job_builder.h"
#include "oneflow/core/job_completer/user_job_completer.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/job/job_set.pb.h"
#include "oneflow/core/job/machine_context.h"
#include "oneflow/core/job/profiler.h"
......@@ -743,7 +743,7 @@ void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) {
AddJobName2JobId(jobs.at(job_id).job_conf().job_name(), job_id);
{
auto scope = std::make_unique<GlobalJobDescScope>(jobs.at(job_id).job_conf(), job_id);
UserJobCompleter().Complete(&jobs.at(job_id));
FunctionPass("CompleteOfrecordDecoder")(&jobs.at(job_id));
CompileCurJobOnMaster(&jobs.at(job_id), &sub_plans.at(job_id), true);
}
}
......
#include <algorithm>
#include "oneflow/core/job_completer/auto_mixed_precision.h"
#include "oneflow/core/job_completer/auto_mixed_precision_lists.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/device/cuda_util.h"
......@@ -175,7 +176,36 @@ void InsertCastOpImpl(bool f2h, const OpGraph& op_graph, const HashSet<OpNode*>&
job_builder->MutOpsOnlyOnce(dst_op_confs);
}
} // namespace
class AutoMixedPrecision final : public OpGraphPass {
public:
OF_DISALLOW_COPY_AND_MOVE(AutoMixedPrecision);
AutoMixedPrecision()
: white_list_(AutoMixedPrecisionLists::WhiteList()),
black_list_(AutoMixedPrecisionLists::BlackList()),
gray_list_(AutoMixedPrecisionLists::GrayList()),
clear_list_(AutoMixedPrecisionLists::ClearList()) {}
~AutoMixedPrecision() = default;
bool IsEnabled() const override { return GlobalJobDesc().enable_auto_mixed_precision(); }
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override;
private:
void FillBlackSet(const OpGraph& op_graph, HashSet<OpNode*>* black_set) const;
void FillWhiteSet(const OpGraph& op_graph, std::function<bool(OpNode*)> IsAllowedToRunWithHalf,
const HashSet<OpNode*>& black_set, HashSet<OpNode*>* white_set) const;
void PropagateWhiteThroughClearNodes(const OpGraph& op_graph,
std::function<bool(OpNode*)> IsAllowedToRunWithHalf,
const HashSet<OpNode*>& black_set,
HashSet<OpNode*>* white_set) const;
void InsertCastOp(const OpGraph& op_graph, const HashSet<OpNode*>& white_set,
JobBuilder* job_builder) const;
const AMPList& white_list_;
const AMPList& black_list_;
const AMPList& gray_list_;
const AMPList& clear_list_;
};
void AutoMixedPrecision::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {
CHECK_GE(CUDA_VERSION, 10000);
......@@ -286,4 +316,8 @@ void AutoMixedPrecision::InsertCastOp(const OpGraph& op_graph, const HashSet<OpN
InsertCastOpImpl(false, op_graph, white_set, job_builder);
}
REGISTER_FUNCTION_PASS("AutoMixedPrecision", AutoMixedPrecision);
} // namespace
} // namespace oneflow
#ifndef ONEFLOW_CORE_JOB_COMPLETER_AUTO_MIXED_PRECISION_H_
#define ONEFLOW_CORE_JOB_COMPLETER_AUTO_MIXED_PRECISION_H_
#include "oneflow/core/job_completer/auto_mixed_precision_lists.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
namespace oneflow {
class OpGraph;
class OpNode;
class Job;
class AutoMixedPrecision final : public OpGraphPass {
public:
OF_DISALLOW_COPY_AND_MOVE(AutoMixedPrecision);
AutoMixedPrecision()
: white_list_(AutoMixedPrecisionLists::WhiteList()),
black_list_(AutoMixedPrecisionLists::BlackList()),
gray_list_(AutoMixedPrecisionLists::GrayList()),
clear_list_(AutoMixedPrecisionLists::ClearList()) {}
~AutoMixedPrecision() = default;
bool IsEnabled() const override { return GlobalJobDesc().enable_auto_mixed_precision(); }
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override;
private:
void FillBlackSet(const OpGraph& op_graph, HashSet<OpNode*>* black_set) const;
void FillWhiteSet(const OpGraph& op_graph, std::function<bool(OpNode*)> IsAllowedToRunWithHalf,
const HashSet<OpNode*>& black_set, HashSet<OpNode*>* white_set) const;
void PropagateWhiteThroughClearNodes(const OpGraph& op_graph,
std::function<bool(OpNode*)> IsAllowedToRunWithHalf,
const HashSet<OpNode*>& black_set,
HashSet<OpNode*>* white_set) const;
void InsertCastOp(const OpGraph& op_graph, const HashSet<OpNode*>& white_set,
JobBuilder* job_builder) const;
const AMPList& white_list_;
const AMPList& black_list_;
const AMPList& gray_list_;
const AMPList& clear_list_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_AUTO_MIXED_PRECISION_H_
#include "oneflow/core/job_completer/user_job_completer.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/job/job_builder.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/job/parallel_desc.h"
......@@ -132,9 +132,15 @@ void AddRecordLoadOps(Job* job) {
} // namespace
void UserJobCompleter::Complete(Job* job) const {
SplitDecodeOps(job);
AddRecordLoadOps(job);
}
class CompleteOfrecordDecoder final : public OpGraphPass {
public:
bool IsEnabled() const override { return true; }
void Apply(Job* job) const override {
SplitDecodeOps(job);
AddRecordLoadOps(job);
}
};
REGISTER_FUNCTION_PASS("CompleteOfrecordDecoder", CompleteOfrecordDecoder);
} // namespace oneflow
......@@ -5,11 +5,8 @@
#include "oneflow/core/job_completer/optimizer.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job_completer/all_reduce_add_pass.h"
#include "oneflow/core/job_completer/set_default_variable_conf.h"
#include "oneflow/core/job_completer/all_reduce_sequence_pass.h"
#include "oneflow/core/job_completer/group_boxing_by_dst_parallel.h"
#include "oneflow/core/job_completer/auto_mixed_precision.h"
#include "oneflow/core/job_completer/non_distributed_optimizer_pass.h"
#include "oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h"
#include "oneflow/core/job_completer/auto_train_step.h"
#include "oneflow/core/job_completer/auto_learning_rate.h"
......@@ -190,11 +187,11 @@ void MakeNcclTupleBroadcastReduceSequence(const OpGraph& op_graph, JobBuilder* j
void JobCompleter::Complete(Job* job) const {
// complete variable ops
WithOpGraphAndMutJobBuilder(job, &SetDefaultVariableConf);
AutoMixedPrecision()(job);
FunctionPass("SetDefaultVariableConf")(job);
FunctionPass("AutoMixedPrecision")(job);
if (GlobalJobDesc().IsTrain()) {
FindFunctionPass("TieUpChainHeadersUnReachableFromAnyVariableOps")(job);
NonDistributedOptimizerPass()(job);
FunctionPass("TieUpChainHeadersUnReachableFromAnyVariableOps")(job);
FunctionPass("NonDistributedOptimizerPass")(job);
WithOpGraphAndMutJob(job, &AutoTrainStep);
WithOpGraphAndMutJob(job, &AutoLearningRate);
// complete ops for trainning
......
#include "oneflow/core/job_completer/non_distributed_optimizer_pass.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job/job_desc.h"
......@@ -29,7 +30,14 @@ ParallelConf NonDistributedParallelConf4ParallelId(const ParallelDesc& pd,
return parallel_conf;
}
} // namespace
class NonDistributedOptimizerPass final : public OpGraphPass {
public:
OF_DISALLOW_COPY_AND_MOVE(NonDistributedOptimizerPass);
NonDistributedOptimizerPass() = default;
~NonDistributedOptimizerPass() = default;
bool IsEnabled() const override { return GlobalJobDesc().enable_non_distributed_optimizer(); }
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override;
};
void NonDistributedOptimizerPass::Apply(const OpGraph& op_graph, JobBuilder* builder) const {
HashMap<ParallelDesc, HashMap<const OpNode*, std::vector<const OpNode*>>> pd2last_node2node_seqs;
......@@ -176,4 +184,8 @@ void NonDistributedOptimizerPass::Apply(const OpGraph& op_graph, JobBuilder* bui
}
}
REGISTER_FUNCTION_PASS("NonDistributedOptimizerPass", NonDistributedOptimizerPass);
} // namespace
} // namespace oneflow
#ifndef ONEFLOW_CORE_JOB_COMPLETER_NON_DISTRIBUTED_OPTIMIZER_PASS_H_
#define ONEFLOW_CORE_JOB_COMPLETER_NON_DISTRIBUTED_OPTIMIZER_PASS_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
namespace oneflow {
class OpGraph;
class JobBuilder;
class NonDistributedOptimizerPass final : public OpGraphPass {
public:
OF_DISALLOW_COPY_AND_MOVE(NonDistributedOptimizerPass);
NonDistributedOptimizerPass() = default;
~NonDistributedOptimizerPass() = default;
bool IsEnabled() const override { return GlobalJobDesc().enable_non_distributed_optimizer(); }
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_NON_DISTRIBUTED_OPTIMIZER_PASS_H_
......@@ -15,7 +15,11 @@ void RegisterFunctionPass(const std::string& pass_name, const OpGraphPass* pass)
CHECK(PassName2FunctionPass()->emplace(pass_name, pass).second);
}
const OpGraphPass& FindFunctionPass(const std::string& pass_name) {
bool HasFunctionPass(const std::string& pass_name) {
return PassName2FunctionPass()->find(pass_name) != PassName2FunctionPass()->end();
}
const OpGraphPass& FunctionPass(const std::string& pass_name) {
const auto& iter = PassName2FunctionPass()->find(pass_name);
CHECK(iter != PassName2FunctionPass()->end());
return *iter->second;
......
......@@ -11,10 +11,13 @@ class OpGraphPass {
public:
void operator()(Job* job) const {
if (IsEnabled() == false) { return; }
Apply(job);
}
virtual bool IsEnabled() const { return true; }
virtual void Apply(Job* job) const {
const OpGraph op_graph(*job);
Apply(op_graph, job);
}
virtual bool IsEnabled() const { return true; }
virtual void Apply(const OpGraph& op_graph, Job* job) const {
JobBuilder job_builder(job);
Apply(op_graph, &job_builder);
......@@ -26,7 +29,8 @@ class OpGraphPass {
COMMAND(RegisterFunctionPass(pass_name, new pass_type))
void RegisterFunctionPass(const std::string& pass_name, const OpGraphPass* pass);
const OpGraphPass& FindFunctionPass(const std::string& pass_name);
bool HasFunctionPass(const std::string& pass_name);
const OpGraphPass& FunctionPass(const std::string& pass_name);
} // namespace oneflow
......
#include "oneflow/core/job_completer/set_default_variable_conf.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/job/job_builder.h"
#include "oneflow/core/job/job_set_compile_ctx.h"
namespace oneflow {
void SetDefaultVariableConf(const OpGraph& op_graph, JobBuilder* job_builder) {
auto BlobDesc4ModelLbi = op_graph.MakeGetterBlobDesc4ModelLbi();
op_graph.ForEachNode([&](OpNode* op_node) {
if (op_node->op().op_conf().has_variable_conf()) {
OperatorConf variable_op_conf(op_node->op().op_conf());
VariableOpConf* variable_conf = variable_op_conf.mutable_variable_conf();
if (!variable_conf->has_data_type()) {
variable_conf->set_data_type(job_builder->job().job_conf().default_data_type());
}
if (!variable_conf->has_initializer() && !variable_conf->has_initialize_with_snapshot()) {
if (job_builder->job().job_conf().has_default_initializer_conf()) {
*variable_conf->mutable_initializer() =
job_builder->job().job_conf().default_initializer_conf();
} else if (job_builder->job().job_conf().has_default_initialize_with_snapshot_path()) {
variable_conf->mutable_initialize_with_snapshot()->set_path(
job_builder->job().job_conf().default_initialize_with_snapshot_path());
variable_conf->mutable_initialize_with_snapshot()->set_key(
GenLogicalBlobName(op_node->op().BnInOp2Lbi("out")));
namespace {
class SetDefaultVariableConf final : public OpGraphPass {
bool IsEnabled() const override { return true; }
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override {
auto BlobDesc4ModelLbi = op_graph.MakeGetterBlobDesc4ModelLbi();
op_graph.ForEachNode([&](OpNode* op_node) {
if (op_node->op().op_conf().has_variable_conf()) {
OperatorConf variable_op_conf(op_node->op().op_conf());
VariableOpConf* variable_conf = variable_op_conf.mutable_variable_conf();
if (!variable_conf->has_data_type()) {
variable_conf->set_data_type(job_builder->job().job_conf().default_data_type());
}
if (!variable_conf->has_initializer() && !variable_conf->has_initialize_with_snapshot()) {
if (job_builder->job().job_conf().has_default_initializer_conf()) {
*variable_conf->mutable_initializer() =
job_builder->job().job_conf().default_initializer_conf();
} else if (job_builder->job().job_conf().has_default_initialize_with_snapshot_path()) {
variable_conf->mutable_initialize_with_snapshot()->set_path(
job_builder->job().job_conf().default_initialize_with_snapshot_path());
variable_conf->mutable_initialize_with_snapshot()->set_key(
GenLogicalBlobName(op_node->op().BnInOp2Lbi("out")));
} else {
UNIMPLEMENTED();
}
}
int64_t random_seed;
auto* var_op_name2random = Global<JobSetCompileCtx>::Get()->GetVarOpName2randomSeed();
const std::string& var_op_name = variable_op_conf.name();
if (variable_conf->has_random_seed()) {
random_seed = variable_conf->random_seed();
} else {
UNIMPLEMENTED();
random_seed = NewRandomSeed();
}
const auto& pair = var_op_name2random->insert({var_op_name, random_seed});
if (variable_conf->has_random_seed()) {
CHECK_EQ(variable_conf->random_seed(), pair.first->second);
} else {
variable_conf->set_random_seed(pair.first->second);
}
job_builder->AddOrMutOpsOnlyOnce(op_node->parallel_desc().parallel_conf(),
{variable_op_conf});
}
int64_t random_seed;
auto* var_op_name2random = Global<JobSetCompileCtx>::Get()->GetVarOpName2randomSeed();
const std::string& var_op_name = variable_op_conf.name();
if (variable_conf->has_random_seed()) {
random_seed = variable_conf->random_seed();
} else {
random_seed = NewRandomSeed();
}
const auto& pair = var_op_name2random->insert({var_op_name, random_seed});
if (variable_conf->has_random_seed()) {
CHECK_EQ(variable_conf->random_seed(), pair.first->second);
} else {
variable_conf->set_random_seed(pair.first->second);
}
job_builder->AddOrMutOpsOnlyOnce(op_node->parallel_desc().parallel_conf(),
{variable_op_conf});
}
});
}
});
}
};
REGISTER_FUNCTION_PASS("SetDefaultVariableConf", SetDefaultVariableConf);
} // namespace
} // namespace oneflow
#ifndef ONEFLOW_CORE_JOB_COMPLETER_FILL_VARIABLE_CONF_H_
#define ONEFLOW_CORE_JOB_COMPLETER_FILL_VARIABLE_CONF_H_
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/graph/op_graph.h"
namespace oneflow {
void SetDefaultVariableConf(const OpGraph& op_graph, JobBuilder* job_builder);
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_FILL_VARIABLE_CONF_H_
......@@ -95,6 +95,7 @@ REGISTER_FUNCTION_CONFIG_DEF().Bool("enable_pseudo_chain_merge", false,
"ties up chain headers unreachable from any variable ops");
class TieUpChainHeadersUnReachableFromAnyVariableOps final : public OpGraphPass {
bool IsEnabled() const override { return GlobalJobDesc().Bool("enable_pseudo_chain_merge"); }
void Apply(const OpGraph& op_graph, Job* job) const override {
auto IsReachableFromAnyVariableOps = MakePredicatorIsReachableFromAnyVariableOps(op_graph);
auto GetSourceNodesAndEdges = [&](const HashSet<OpNode*>& chain_nodes,
......
#ifndef ONEFLOW_CORE_JOB_COMPLETER_USER_JOB_COMPLETER_H_
#define ONEFLOW_CORE_JOB_COMPLETER_USER_JOB_COMPLETER_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/job_desc.h"
namespace oneflow {
class UserJobCompleter final {
public:
OF_DISALLOW_COPY_AND_MOVE(UserJobCompleter);
UserJobCompleter() = default;
~UserJobCompleter() = default;
void Complete(Job* job) const;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_USER_JOB_COMPLETER_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册