提交 bc5b5e0f 编写于 作者: L lixinqi

refactor JobCompleter::Complete

上级 4c031f17
#include "oneflow/core/job_completer/all_reduce_add_pass.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/register/runtime_blob_desc.h"
namespace oneflow {
......@@ -240,9 +240,18 @@ void BuildAllReduceStruct(
all_reduced_lbi, GetLastTouchedOpName);
}
} // namespace
class AddAllReduceGroupPass final : public OpGraphPass {
public:
AddAllReduceGroupPass() = default;
~AddAllReduceGroupPass() = default;
bool IsEnabled() const override {
return GlobalJobDesc().IsTrain() && !GlobalJobDesc().enable_non_distributed_optimizer()
&& GlobalJobDesc().enable_all_reduce_group();
}
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override;
};
void AllReduceAddPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {
void AddAllReduceGroupPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {
auto ProducerOpNode4Lbi = MakeGetterProducerOpNode4Lbi(op_graph);
std::vector<LogicalBlobId> lbis;
FindAllReducedLbis(job_builder->job(), op_graph, ProducerOpNode4Lbi, &lbis);
......@@ -286,4 +295,8 @@ void AllReduceAddPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) c
});
}
} // namespace
REGISTER_FUNCTION_PASS("AddAllReduceGroupPass", AddAllReduceGroupPass);
} // namespace oneflow
#include "oneflow/core/job_completer/add_lbi_diff_watcher.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/job/lbi_diff_watcher_info.pb.h"
#include "oneflow/core/operator/operator.h"
namespace oneflow {
void AddLbiDiffWatcherOpConfs(Job* job) {
namespace {
// OpGraphPass subclass whose work is done in Apply(Job*), defined out of line.
class AddLbiDiffWatcherOpConfs final : public OpGraphPass {
public:
// Enabled only when the global job description reports a training job.
bool IsEnabled() const override { return GlobalJobDesc().IsTrain(); }
// Takes the Job directly (no OpGraph / JobBuilder arguments).
void Apply(Job* job) const override;
};
void AddLbiDiffWatcherOpConfs::Apply(Job* job) const {
JobBuilder job_builder(job);
const auto& map = Global<LbiDiffWatcherInfo>::Get()->job_name2lbi_and_watcher_uuids();
if (map.find(GlobalJobDesc().job_name()) == map.end()) { return; }
......@@ -27,4 +35,8 @@ void AddLbiDiffWatcherOpConfs(Job* job) {
}
}
REGISTER_FUNCTION_PASS("AddLbiDiffWatcherOpConfs", AddLbiDiffWatcherOpConfs);
} // namespace
} // namespace oneflow
#ifndef ONEFLOW_CORE_JOB_COMPLETER_ADD_LBI_DIFF_WATCHER_H_
#define ONEFLOW_CORE_JOB_COMPLETER_ADD_LBI_DIFF_WATCHER_H_
#include "oneflow/core/job/job_builder.h"
namespace oneflow {
// Free-function form taking a mutable Job pointer; only declared in this header.
void AddLbiDiffWatcherOpConfs(Job* job);
}
#endif // ONEFLOW_CORE_JOB_COMPLETER_ADD_LBI_DIFF_WATCHER_H_
#ifndef ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_ADD_PASS_H_
#define ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_ADD_PASS_H_
#include "oneflow/core/job_completer/op_graph_pass.h"
namespace oneflow {
class OpGraph;
class AllReduceAddPass final : public OpGraphPass {
public:
AllReduceAddPass() = default;
~AllReduceAddPass() = default;
bool IsEnabled() const override {
return !GlobalJobDesc().enable_non_distributed_optimizer()
&& GlobalJobDesc().enable_all_reduce_group();
}
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_ADD_PASS_H_
#ifndef ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_SEQUENCE_PASS_H_
#define ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_SEQUENCE_PASS_H_
#include "oneflow/core/job/job.pb.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
namespace oneflow {
class OpGraph;
class AllReduceSequencePass final : public OpGraphPass {
public:
AllReduceSequencePass() = default;
~AllReduceSequencePass() = default;
bool IsEnabled() const override { return !GlobalJobDesc().disable_all_reduce_sequence(); }
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_SEQUENCE_PASS_H_
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/job/job.pb.h"
namespace oneflow {
void AutoLearningRate(const OpGraph& op_graph, Job* job) {
namespace {
// OpGraphPass subclass registered under the name "AutoLearningRate".
class AutoLearningRate final : public OpGraphPass {
public:
// Instances are neither copyable nor movable.
OF_DISALLOW_COPY_AND_MOVE(AutoLearningRate);
AutoLearningRate() = default;
~AutoLearningRate() override = default;
// Enabled only when the global job description reports a training job.
bool IsEnabled() const override { return GlobalJobDesc().IsTrain(); }
// Receives the op graph read-only and the Job to edit; defined out of line.
void Apply(const OpGraph& op_graph, Job* job) const override;
};
void AutoLearningRate::Apply(const OpGraph& op_graph, Job* job) const {
JobBuilder job_builder(job);
const TrainConf& train_conf = job->job_conf().train_conf();
auto AddScheduleOp = [&](const std::string& op_name, const float learning_rate) -> std::string {
......@@ -58,4 +72,8 @@ void AutoLearningRate(const OpGraph& op_graph, Job* job) {
}
}
REGISTER_FUNCTION_PASS("AutoLearningRate", AutoLearningRate);
} // namespace
} // namespace oneflow
#ifndef ONEFLOW_CORE_JOB_COMPLETER_AUTO_LEARNING_RATE_H_
#define ONEFLOW_CORE_JOB_COMPLETER_AUTO_LEARNING_RATE_H_
namespace oneflow {
class OpGraph;
class Job;
// Free-function form taking a read-only OpGraph and a mutable Job pointer; only declared here.
void AutoLearningRate(const OpGraph& op_graph, Job* job);
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_AUTO_LEARNING_RATE_H_
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/job/job.pb.h"
namespace oneflow {
void AutoTrainStep(const OpGraph& op_graph, Job* job) {
namespace {
// OpGraphPass subclass registered under the name "AutoTrainStep".
class AutoTrainStep final : public OpGraphPass {
public:
// Instances are neither copyable nor movable.
OF_DISALLOW_COPY_AND_MOVE(AutoTrainStep);
AutoTrainStep() = default;
~AutoTrainStep() override = default;
// Enabled only when the global job description reports a training job.
bool IsEnabled() const override { return GlobalJobDesc().IsTrain(); }
// Receives the op graph read-only and the Job to edit; defined out of line.
void Apply(const OpGraph& op_graph, Job* job) const override;
};
void AutoTrainStep::Apply(const OpGraph& op_graph, Job* job) const {
if (job->job_conf().train_conf().has_train_step_lbn()) { return; }
OperatorConf variable_op_conf{};
const std::string train_step_name = "System-Train-TrainStep-" + job->job_conf().job_name();
......@@ -42,4 +55,8 @@ void AutoTrainStep(const OpGraph& op_graph, Job* job) {
job->mutable_job_conf()->mutable_train_conf()->set_train_step_lbn(train_step_lbn);
}
REGISTER_FUNCTION_PASS("AutoTrainStep", AutoTrainStep);
} // namespace
} // namespace oneflow
#ifndef ONEFLOW_CORE_JOB_COMPLETER_AUTO_TRAIN_STEP_H_
#define ONEFLOW_CORE_JOB_COMPLETER_AUTO_TRAIN_STEP_H_
namespace oneflow {
class OpGraph;
class Job;
// Free-function form taking a read-only OpGraph and a mutable Job pointer; only declared here.
void AutoTrainStep(const OpGraph& op_graph, Job* job);
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_AUTO_TRAIN_STEP_H_
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/job_completer/autograd.h"
#include "oneflow/core/job_completer/optimizer.h"
namespace oneflow {
namespace {
// Appends one (first, second) relation per entry of `lbi2diff_lbi` to the job
// helper's lbi-relation list stored under the kProducedLbi2ConsumedDiffLbi tag.
void UpdateJobHelperConfProducedLbi2ConsumedDiffLbi(
    const HashMap<LogicalBlobId, LogicalBlobId>& lbi2diff_lbi, JobBuilder* job_builder) {
  auto* tag2relations = job_builder->mutable_helper()->mutable_tag2lbi_relations();
  // operator[] creates the entry for the tag if it does not exist yet.
  auto& relations = (*tag2relations)[kProducedLbi2ConsumedDiffLbi];
  for (const auto& lbi_and_diff : lbi2diff_lbi) {
    auto* relation = relations.add_pair();
    *relation->mutable_first() = lbi_and_diff.first;
    *relation->mutable_second() = lbi_and_diff.second;
  }
}
void BindIdenticalSbpObaPairsBetweenIbns(const OpNode& op_node, JobBuilder* job_builder) {
HashMap<LogicalBlobId, std::vector<OpBlobArg>> in_lbi2obas;
for (const std::string& ibn : op_node.op().input_bns()) {
in_lbi2obas[op_node.op().BnInOp2Lbi(ibn)].push_back(GenOpBlobArg(op_node.op().op_name(), ibn));
}
for (const auto& pair : in_lbi2obas) {
if (pair.second.size() > 1) {
FOR_RANGE(int32_t, i, 1, pair.second.size()) {
job_builder->BindIdenticalSbpOpBlobArgPair(pair.second.at(0), pair.second.at(i));
}
}
}
}
// Writes an SbpParallel to both sides of every identical-SBP OpBlobArg pair
// recorded in the job helper.
// Step 1 indexes the SbpParallel of every input and output blob name of every
// op in `op_graph`; step 2 picks, for each recorded pair, the SbpParallel known
// for either side and assigns it to both sides through `job_builder`.
void SetSbpSignatureHintByIdenticalSbpObaPairs(const OpGraph& op_graph, JobBuilder* job_builder) {
// Values are addresses of the references returned by OpNode::SbpParallel4Lbi.
HashMap<OpBlobArg, const SbpParallel*> oba2sbp_parallel;
op_graph.ForEachNode([&](OpNode* op_node) {
// Visits the op's input blob names first, then its output blob names.
auto ForEachBn = [&](const std::function<void(const std::string&)>& Handler) {
for (const auto& ibn : op_node->op().input_bns()) { Handler(ibn); }
for (const auto& obn : op_node->op().output_bns()) { Handler(obn); }
};
ForEachBn([&](const std::string& bn_in_op) {
const auto& oba = GenOpBlobArg(op_node->op().op_name(), bn_in_op);
oba2sbp_parallel[oba] = &op_node->SbpParallel4Lbi(op_node->op().BnInOp2Lbi(bn_in_op));
});
});
auto HasSbpParallel = [&](const OpBlobArg& oba) {
return oba2sbp_parallel.find(oba) != oba2sbp_parallel.end();
};
for (const auto& pair : job_builder->job().helper().identical_sbp_oba_pairs().pair()) {
const SbpParallel* sbp_parallel = nullptr;
if (HasSbpParallel(pair.first()) && HasSbpParallel(pair.second())) {
// This compares pointers, so both sides must map to the very same SbpParallel
// object, not merely to equal values.
// NOTE(review): confirm that identity rather than value equality is intended.
CHECK(oba2sbp_parallel.at(pair.first()) == oba2sbp_parallel.at(pair.second()));
sbp_parallel = oba2sbp_parallel.at(pair.first());
} else if (HasSbpParallel(pair.first())) {
sbp_parallel = oba2sbp_parallel.at(pair.first());
} else if (HasSbpParallel(pair.second())) {
sbp_parallel = oba2sbp_parallel.at(pair.second());
} else {
// Neither side of the pair was found in the op graph index.
// NOTE(review): presumably UNIMPLEMENTED() aborts; otherwise the null
// sbp_parallel would be dereferenced below — confirm.
UNIMPLEMENTED();
}
*job_builder->MutSbpParallel4Oba(pair.first()) = *sbp_parallel;
*job_builder->MutSbpParallel4Oba(pair.second()) = *sbp_parallel;
}
}
// Runs the two SBP-hint steps in order: first records identical-SBP pairs
// between the inputs of every op node, then applies hints for all recorded pairs.
void UpdateOpSbpSignatureHint(const OpGraph& op_graph, JobBuilder* job_builder) {
  const auto BindInputsOfNode = [&](OpNode* node) {
    BindIdenticalSbpObaPairsBetweenIbns(*node, job_builder);
  };
  op_graph.ForEachNode(BindInputsOfNode);
  SetSbpSignatureHintByIdenticalSbpObaPairs(op_graph, job_builder);
}
// OpGraphPass subclass registered under the name "GenerateBackwardAndOptimizerOpConfs".
class GenerateBackwardAndOptimizerOpConfs final : public OpGraphPass {
public:
// Enabled only when the global job description reports a training job.
bool IsEnabled() const override { return GlobalJobDesc().IsTrain(); }
// Instances are neither copyable nor movable.
OF_DISALLOW_COPY_AND_MOVE(GenerateBackwardAndOptimizerOpConfs);
GenerateBackwardAndOptimizerOpConfs() = default;
~GenerateBackwardAndOptimizerOpConfs() override = default;
// Receives the op graph read-only and a mutable JobBuilder; defined out of line.
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override;
};
// Runs the training-completion steps in a fixed order:
//   1. AutoGrad, with `lbi2diff_lbi` passed as its out-argument;
//   2. AddTotalLossInstanceNumOpConf, which receives a pointer to the
//      LossInstanceNum4ParallelDesc getter;
//   3. AddOptimizerOpConf, which consumes that getter;
//   4. recording of the lbi -> diff lbi relations in the job helper;
//   5. the SBP signature hint update.
// The previously declared but never used local `total_loss_instance_num` has been removed.
void GenerateBackwardAndOptimizerOpConfs::Apply(const OpGraph& op_graph,
                                                JobBuilder* job_builder) const {
  HashMap<LogicalBlobId, LogicalBlobId> lbi2diff_lbi;
  AutoGrad(op_graph, job_builder, &lbi2diff_lbi);
  std::function<const LogicalBlobId&(const ParallelDesc&)> LossInstanceNum4ParallelDesc;
  AddTotalLossInstanceNumOpConf(op_graph, job_builder, lbi2diff_lbi, &LossInstanceNum4ParallelDesc);
  AddOptimizerOpConf(op_graph, job_builder, lbi2diff_lbi, LossInstanceNum4ParallelDesc);
  UpdateJobHelperConfProducedLbi2ConsumedDiffLbi(lbi2diff_lbi, job_builder);
  UpdateOpSbpSignatureHint(op_graph, job_builder);
}
REGISTER_FUNCTION_PASS("GenerateBackwardAndOptimizerOpConfs", GenerateBackwardAndOptimizerOpConfs);
} // namespace
} // namespace oneflow
#include "oneflow/core/job_completer/job_completer.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/job_completer/autograd.h"
#include "oneflow/core/job_completer/autotick.h"
#include "oneflow/core/job_completer/add_keep_header_only_op_conf.h"
#include "oneflow/core/job_completer/optimizer.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job_completer/all_reduce_add_pass.h"
#include "oneflow/core/job_completer/all_reduce_sequence_pass.h"
#include "oneflow/core/job_completer/group_boxing_by_dst_parallel.h"
#include "oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h"
#include "oneflow/core/job_completer/auto_train_step.h"
#include "oneflow/core/job_completer/auto_learning_rate.h"
#include "oneflow/core/job_completer/add_lbi_diff_watcher.h"
#include "oneflow/core/framework/config_def.h"
#include "oneflow/core/job_completer/xrt_compilation.h"
namespace oneflow {
......@@ -42,95 +35,6 @@ void WithOpGraphAndMutJobBuilder(Job* job,
Handler(op_graph, &job_builder);
}
// Appends one (first, second) relation per entry of `lbi2diff_lbi` to the job
// helper's lbi-relation list stored under the kProducedLbi2ConsumedDiffLbi tag.
void UpdateJobHelperConfProducedLbi2ConsumedDiffLbi(
    const HashMap<LogicalBlobId, LogicalBlobId>& lbi2diff_lbi, JobBuilder* job_builder) {
  auto* tag2relations = job_builder->mutable_helper()->mutable_tag2lbi_relations();
  // operator[] creates the entry for the tag if it does not exist yet.
  auto& relations = (*tag2relations)[kProducedLbi2ConsumedDiffLbi];
  for (const auto& lbi_and_diff : lbi2diff_lbi) {
    auto* relation = relations.add_pair();
    *relation->mutable_first() = lbi_and_diff.first;
    *relation->mutable_second() = lbi_and_diff.second;
  }
}
void BindIdenticalSbpObaPairsBetweenIbns(const OpNode& op_node, JobBuilder* job_builder) {
HashMap<LogicalBlobId, std::vector<OpBlobArg>> in_lbi2obas;
for (const std::string& ibn : op_node.op().input_bns()) {
in_lbi2obas[op_node.op().BnInOp2Lbi(ibn)].push_back(GenOpBlobArg(op_node.op().op_name(), ibn));
}
for (const auto& pair : in_lbi2obas) {
if (pair.second.size() > 1) {
FOR_RANGE(int32_t, i, 1, pair.second.size()) {
job_builder->BindIdenticalSbpOpBlobArgPair(pair.second.at(0), pair.second.at(i));
}
}
}
}
// Writes an SbpParallel to both sides of every identical-SBP OpBlobArg pair
// recorded in the job helper.
// Step 1 indexes the SbpParallel of every input and output blob name of every
// op in `op_graph`; step 2 picks, for each recorded pair, the SbpParallel known
// for either side and assigns it to both sides through `job_builder`.
void SetSbpSignatureHintByIdenticalSbpObaPairs(const OpGraph& op_graph, JobBuilder* job_builder) {
// Values are addresses of the references returned by OpNode::SbpParallel4Lbi.
HashMap<OpBlobArg, const SbpParallel*> oba2sbp_parallel;
op_graph.ForEachNode([&](OpNode* op_node) {
// Visits the op's input blob names first, then its output blob names.
auto ForEachBn = [&](const std::function<void(const std::string&)>& Handler) {
for (const auto& ibn : op_node->op().input_bns()) { Handler(ibn); }
for (const auto& obn : op_node->op().output_bns()) { Handler(obn); }
};
ForEachBn([&](const std::string& bn_in_op) {
const auto& oba = GenOpBlobArg(op_node->op().op_name(), bn_in_op);
oba2sbp_parallel[oba] = &op_node->SbpParallel4Lbi(op_node->op().BnInOp2Lbi(bn_in_op));
});
});
auto HasSbpParallel = [&](const OpBlobArg& oba) {
return oba2sbp_parallel.find(oba) != oba2sbp_parallel.end();
};
for (const auto& pair : job_builder->job().helper().identical_sbp_oba_pairs().pair()) {
const SbpParallel* sbp_parallel = nullptr;
if (HasSbpParallel(pair.first()) && HasSbpParallel(pair.second())) {
// This compares pointers, so both sides must map to the very same SbpParallel
// object, not merely to equal values.
// NOTE(review): confirm that identity rather than value equality is intended.
CHECK(oba2sbp_parallel.at(pair.first()) == oba2sbp_parallel.at(pair.second()));
sbp_parallel = oba2sbp_parallel.at(pair.first());
} else if (HasSbpParallel(pair.first())) {
sbp_parallel = oba2sbp_parallel.at(pair.first());
} else if (HasSbpParallel(pair.second())) {
sbp_parallel = oba2sbp_parallel.at(pair.second());
} else {
// Neither side of the pair was found in the op graph index.
// NOTE(review): presumably UNIMPLEMENTED() aborts; otherwise the null
// sbp_parallel would be dereferenced below — confirm.
UNIMPLEMENTED();
}
*job_builder->MutSbpParallel4Oba(pair.first()) = *sbp_parallel;
*job_builder->MutSbpParallel4Oba(pair.second()) = *sbp_parallel;
}
}
// Runs the two SBP-hint steps in order: first records identical-SBP pairs
// between the inputs of every op node, then applies hints for all recorded pairs.
void UpdateOpSbpSignatureHint(const OpGraph& op_graph, JobBuilder* job_builder) {
  const auto BindInputsOfNode = [&](OpNode* node) {
    BindIdenticalSbpObaPairsBetweenIbns(*node, job_builder);
  };
  op_graph.ForEachNode(BindInputsOfNode);
  SetSbpSignatureHintByIdenticalSbpObaPairs(op_graph, job_builder);
}
// Runs the training-completion steps in a fixed order:
//   1. AutoGrad, with `lbi2diff_lbi` passed as its out-argument;
//   2. AddTotalLossInstanceNumOpConf, which receives a pointer to the
//      LossInstanceNum4ParallelDesc getter;
//   3. AddOptimizerOpConf, which consumes that getter;
//   4. recording of the lbi -> diff lbi relations in the job helper;
//   5. the SBP signature hint update.
// The previously declared but never used local `total_loss_instance_num` has been removed.
void GenerateOpConf4Trainning(const OpGraph& op_graph, JobBuilder* job_builder) {
  HashMap<LogicalBlobId, LogicalBlobId> lbi2diff_lbi;
  AutoGrad(op_graph, job_builder, &lbi2diff_lbi);
  std::function<const LogicalBlobId&(const ParallelDesc&)> LossInstanceNum4ParallelDesc;
  AddTotalLossInstanceNumOpConf(op_graph, job_builder, lbi2diff_lbi, &LossInstanceNum4ParallelDesc);
  AddOptimizerOpConf(op_graph, job_builder, lbi2diff_lbi, LossInstanceNum4ParallelDesc);
  UpdateJobHelperConfProducedLbi2ConsumedDiffLbi(lbi2diff_lbi, job_builder);
  UpdateOpSbpSignatureHint(op_graph, job_builder);
}
// Builds an op-name -> mutable ParallelConf lookup over all placement groups of
// `placement` and returns a getter backed by it. Each op name must appear in at
// most one group (enforced by CHECK). The returned getter shares ownership of
// the lookup table and uses at(), so it fails on an unknown op name.
std::function<ParallelConf*(const std::string&)> MakeGetterMutParallelConf4OpName(
    Placement* placement) {
  auto name2conf = std::make_shared<HashMap<std::string, ParallelConf*>>();
  const int group_count = placement->placement_group_size();
  for (int group_idx = 0; group_idx < group_count; ++group_idx) {
    auto* group = placement->mutable_placement_group(group_idx);
    for (const std::string& op_name : group->op_set().op_name()) {
      // mutable_parallel_conf() is deliberately called per op name, as before,
      // so groups without op names are left untouched.
      ParallelConf* conf = group->mutable_parallel_conf();
      const bool inserted = name2conf->emplace(op_name, conf).second;
      CHECK(inserted);
    }
  }
  return [name2conf](const std::string& op_name) { return name2conf->at(op_name); };
}
void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder) {
auto IsMutableConsumedLbi = [](const Operator& op, const LogicalBlobId& lbi) -> bool {
for (const std::string& bn : op.input_bns()) {
......@@ -179,28 +83,20 @@ void DumpLogicalBlobDescAndSbpSignature(const OpGraph& op_graph, JobBuilder* job
op_graph.DumpSbpSignature(job_builder);
}
void MakeNcclTupleBroadcastReduceSequence(const OpGraph& op_graph, JobBuilder* job_builder) {
NcclTupleBroadcastReduceSequencePass().Apply(op_graph, job_builder);
}
} // namespace
void JobCompleter::Complete(Job* job) const {
// complete variable ops
FunctionPass("SetDefaultVariableConf")(job);
FunctionPass("AutoMixedPrecision")(job);
if (GlobalJobDesc().IsTrain()) {
FunctionPass("TieUpChainHeadersUnReachableFromAnyVariableOps")(job);
FunctionPass("NonDistributedOptimizerPass")(job);
WithOpGraphAndMutJob(job, &AutoTrainStep);
WithOpGraphAndMutJob(job, &AutoLearningRate);
// complete ops for trainning
WithOpGraphAndMutJobBuilder(job, &GenerateOpConf4Trainning);
WithOpGraphAndMutJobBuilder(job, &MakeNcclTupleBroadcastReduceSequence);
AllReduceAddPass()(job);
AddLbiDiffWatcherOpConfs(job);
AllReduceSequencePass()(job);
}
FunctionPass("TieUpChainHeadersUnReachableFromAnyVariableOps")(job);
FunctionPass("NonDistributedOptimizerPass")(job);
FunctionPass("AutoTrainStep")(job);
FunctionPass("AutoLearningRate")(job);
FunctionPass("GenerateBackwardAndOptimizerOpConfs")(job);
FunctionPass("SequentializeNcclTupleBroadcastReducePass")(job);
FunctionPass("AddAllReduceGroupPass")(job);
FunctionPass("AddLbiDiffWatcherOpConfs")(job);
FunctionPass("SequentializeAllReduceGroupPass")(job);
WithOpGraphAndMutJobBuilder(job, &DumpLogicalBlobDescAndSbpSignature);
WithOpGraphAndMutJobBuilder(job, &GroupBoxingByDstParallel);
WithOpGraphAndMutJobBuilder(job, &AddKeepHeaderOnlyOp);
......
#ifndef ONEFLOW_CORE_JOB_COMPLETER_NCCL_TUPLE_BROADCAST_REDUCE_SEQUENCE_PASS_H_
#define ONEFLOW_CORE_JOB_COMPLETER_NCCL_TUPLE_BROADCAST_REDUCE_SEQUENCE_PASS_H_
#include "oneflow/core/common/util.h"
namespace oneflow {
class OpGraph;
class JobBuilder;
// Standalone pass class (it does not derive from any base class) exposing a single Apply method.
class NcclTupleBroadcastReduceSequencePass final {
public:
// Instances are neither copyable nor movable.
OF_DISALLOW_COPY_AND_MOVE(NcclTupleBroadcastReduceSequencePass);
NcclTupleBroadcastReduceSequencePass() = default;
~NcclTupleBroadcastReduceSequencePass() = default;
// Receives the op graph read-only and a mutable JobBuilder; defined out of line.
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;
};
} // namespace oneflow
#endif // #define ONEFLOW_CORE_JOB_COMPLETER_NCCL_TUPLE_BROADCAST_REDUCE_SEQUENCE_PASS_H_
......@@ -35,7 +35,9 @@ class NonDistributedOptimizerPass final : public OpGraphPass {
OF_DISALLOW_COPY_AND_MOVE(NonDistributedOptimizerPass);
NonDistributedOptimizerPass() = default;
~NonDistributedOptimizerPass() = default;
bool IsEnabled() const override { return GlobalJobDesc().enable_non_distributed_optimizer(); }
bool IsEnabled() const override {
return GlobalJobDesc().IsTrain() && GlobalJobDesc().enable_non_distributed_optimizer();
}
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override;
};
......
......@@ -9,6 +9,9 @@ namespace oneflow {
class OpGraphPass {
public:
OpGraphPass() = default;
virtual ~OpGraphPass() = default;
void operator()(Job* job) const {
if (IsEnabled() == false) { return; }
Apply(job);
......
#include "oneflow/core/job_completer/all_reduce_sequence_pass.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
namespace oneflow {
......@@ -52,9 +52,18 @@ void ReOrderAllReduceGroups(std::vector<AllReduceGroup>* all_reduce_groups) {
all_reduce_groups->end() - lazy_count);
}
} // namespace
class SequentializeAllReduceGroupPass final : public OpGraphPass {
public:
SequentializeAllReduceGroupPass() = default;
~SequentializeAllReduceGroupPass() = default;
bool IsEnabled() const override {
return GlobalJobDesc().IsTrain() && !GlobalJobDesc().disable_all_reduce_sequence();
}
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override;
};
void AllReduceSequencePass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {
void SequentializeAllReduceGroupPass::Apply(const OpGraph& op_graph,
JobBuilder* job_builder) const {
std::vector<AllReduceGroup> all_reduce_groups;
FindAllReduceGroups(op_graph, &all_reduce_groups);
ReOrderAllReduceGroups(&all_reduce_groups);
......@@ -68,4 +77,8 @@ void AllReduceSequencePass::Apply(const OpGraph& op_graph, JobBuilder* job_build
}
}
REGISTER_FUNCTION_PASS("SequentializeAllReduceGroupPass", SequentializeAllReduceGroupPass);
} // namespace
} // namespace oneflow
#include "oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/graph/op_graph.h"
namespace oneflow {
void NcclTupleBroadcastReduceSequencePass::Apply(const OpGraph& op_graph,
JobBuilder* builder) const {
class SequentializeNcclTupleBroadcastReducePass final : public OpGraphPass {
public:
OF_DISALLOW_COPY_AND_MOVE(SequentializeNcclTupleBroadcastReducePass);
SequentializeNcclTupleBroadcastReducePass() = default;
~SequentializeNcclTupleBroadcastReducePass() = default;
bool IsEnabled() const override { return GlobalJobDesc().IsTrain(); }
void Apply(const OpGraph& op_graph, JobBuilder* job_builder) const override;
};
void SequentializeNcclTupleBroadcastReducePass::Apply(const OpGraph& op_graph,
JobBuilder* builder) const {
std::vector<OperatorConf> broadcast_ops;
std::vector<OperatorConf> reduce_ops;
op_graph.ForEachNode([&](const OpNode* node) {
......@@ -41,4 +52,7 @@ void NcclTupleBroadcastReduceSequencePass::Apply(const OpGraph& op_graph,
builder->MutOpsOnlyOnce(reduce_ops);
}
REGISTER_FUNCTION_PASS("SequentializeNcclTupleBroadcastReducePass",
SequentializeNcclTupleBroadcastReducePass);
} // namespace oneflow
......@@ -94,7 +94,9 @@ std::function<bool(OpNode*)> MakePredicatorIsReachableFromAnyVariableOps(const O
REGISTER_FUNCTION_CONFIG_DEF().Bool("enable_pseudo_chain_merge", false,
"ties up chain headers unreachable from any variable ops");
class TieUpChainHeadersUnReachableFromAnyVariableOps final : public OpGraphPass {
bool IsEnabled() const override { return GlobalJobDesc().Bool("enable_pseudo_chain_merge"); }
bool IsEnabled() const override {
return GlobalJobDesc().IsTrain() && GlobalJobDesc().Bool("enable_pseudo_chain_merge");
}
void Apply(const OpGraph& op_graph, Job* job) const override {
auto IsReachableFromAnyVariableOps = MakePredicatorIsReachableFromAnyVariableOps(op_graph);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册