Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
bc5b5e0f
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 接近 3 年
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
bc5b5e0f
编写于
1月 07, 2020
作者:
L
lixinqi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor JobCompleter::Complete
上级
4c031f17
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
219 addition
and
234 deletion
+219
-234
oneflow/core/job_completer/add_all_reduce_group_pass.cpp
oneflow/core/job_completer/add_all_reduce_group_pass.cpp
+16
-3
oneflow/core/job_completer/add_lbi_diff_watcher.cpp
oneflow/core/job_completer/add_lbi_diff_watcher.cpp
+14
-2
oneflow/core/job_completer/add_lbi_diff_watcher.h
oneflow/core/job_completer/add_lbi_diff_watcher.h
+0
-11
oneflow/core/job_completer/all_reduce_add_pass.h
oneflow/core/job_completer/all_reduce_add_pass.h
+0
-23
oneflow/core/job_completer/all_reduce_sequence_pass.h
oneflow/core/job_completer/all_reduce_sequence_pass.h
+0
-21
oneflow/core/job_completer/auto_learning_rate.cpp
oneflow/core/job_completer/auto_learning_rate.cpp
+20
-2
oneflow/core/job_completer/auto_learning_rate.h
oneflow/core/job_completer/auto_learning_rate.h
+0
-13
oneflow/core/job_completer/auto_train_step.cpp
oneflow/core/job_completer/auto_train_step.cpp
+19
-2
oneflow/core/job_completer/auto_train_step.h
oneflow/core/job_completer/auto_train_step.h
+0
-13
oneflow/core/job_completer/generate_backward_and_optimizer_op_confs.cpp
...ob_completer/generate_backward_and_optimizer_op_confs.cpp
+98
-0
oneflow/core/job_completer/job_completer.cpp
oneflow/core/job_completer/job_completer.cpp
+10
-114
oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h
...job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h
+0
-22
oneflow/core/job_completer/non_distributed_optimizer_pass.cpp
...low/core/job_completer/non_distributed_optimizer_pass.cpp
+3
-1
oneflow/core/job_completer/op_graph_pass.h
oneflow/core/job_completer/op_graph_pass.h
+3
-0
oneflow/core/job_completer/sequentialize_all_reduce_group_pass.cpp
...ore/job_completer/sequentialize_all_reduce_group_pass.cpp
+16
-3
oneflow/core/job_completer/sequentialize_nccl_tuple_broadcast_reduce_pass.cpp
...pleter/sequentialize_nccl_tuple_broadcast_reduce_pass.cpp
+17
-3
oneflow/core/job_completer/tie_up_chain_headers.cpp
oneflow/core/job_completer/tie_up_chain_headers.cpp
+3
-1
未找到文件。
oneflow/core/job_completer/a
ll_reduce_add
_pass.cpp
→
oneflow/core/job_completer/a
dd_all_reduce_group
_pass.cpp
浏览文件 @
bc5b5e0f
#include "oneflow/core/job_completer/
all_reduce_add
_pass.h"
#include "oneflow/core/job_completer/
op_graph
_pass.h"
#include "oneflow/core/register/runtime_blob_desc.h"
#include "oneflow/core/register/runtime_blob_desc.h"
namespace
oneflow
{
namespace
oneflow
{
...
@@ -240,9 +240,18 @@ void BuildAllReduceStruct(
...
@@ -240,9 +240,18 @@ void BuildAllReduceStruct(
all_reduced_lbi
,
GetLastTouchedOpName
);
all_reduced_lbi
,
GetLastTouchedOpName
);
}
}
}
// namespace
class
AddAllReduceGroupPass
final
:
public
OpGraphPass
{
public:
AddAllReduceGroupPass
()
=
default
;
~
AddAllReduceGroupPass
()
=
default
;
bool
IsEnabled
()
const
override
{
return
GlobalJobDesc
().
IsTrain
()
&&
!
GlobalJobDesc
().
enable_non_distributed_optimizer
()
&&
GlobalJobDesc
().
enable_all_reduce_group
();
}
void
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
override
;
};
void
A
llReduceAdd
Pass
::
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
{
void
A
ddAllReduceGroup
Pass
::
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
{
auto
ProducerOpNode4Lbi
=
MakeGetterProducerOpNode4Lbi
(
op_graph
);
auto
ProducerOpNode4Lbi
=
MakeGetterProducerOpNode4Lbi
(
op_graph
);
std
::
vector
<
LogicalBlobId
>
lbis
;
std
::
vector
<
LogicalBlobId
>
lbis
;
FindAllReducedLbis
(
job_builder
->
job
(),
op_graph
,
ProducerOpNode4Lbi
,
&
lbis
);
FindAllReducedLbis
(
job_builder
->
job
(),
op_graph
,
ProducerOpNode4Lbi
,
&
lbis
);
...
@@ -286,4 +295,8 @@ void AllReduceAddPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) c
...
@@ -286,4 +295,8 @@ void AllReduceAddPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) c
});
});
}
}
}
// namespace
REGISTER_FUNCTION_PASS
(
"AddAllReduceGroupPass"
,
AddAllReduceGroupPass
);
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/job_completer/add_lbi_diff_watcher.cpp
浏览文件 @
bc5b5e0f
#include "oneflow/core/job_completer/
add_lbi_diff_watcher
.h"
#include "oneflow/core/job_completer/
op_graph_pass
.h"
#include "oneflow/core/job/lbi_diff_watcher_info.pb.h"
#include "oneflow/core/job/lbi_diff_watcher_info.pb.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/operator/operator.h"
namespace
oneflow
{
namespace
oneflow
{
void
AddLbiDiffWatcherOpConfs
(
Job
*
job
)
{
namespace
{
class
AddLbiDiffWatcherOpConfs
final
:
public
OpGraphPass
{
public:
bool
IsEnabled
()
const
override
{
return
GlobalJobDesc
().
IsTrain
();
}
void
Apply
(
Job
*
job
)
const
override
;
};
void
AddLbiDiffWatcherOpConfs
::
Apply
(
Job
*
job
)
const
{
JobBuilder
job_builder
(
job
);
JobBuilder
job_builder
(
job
);
const
auto
&
map
=
Global
<
LbiDiffWatcherInfo
>::
Get
()
->
job_name2lbi_and_watcher_uuids
();
const
auto
&
map
=
Global
<
LbiDiffWatcherInfo
>::
Get
()
->
job_name2lbi_and_watcher_uuids
();
if
(
map
.
find
(
GlobalJobDesc
().
job_name
())
==
map
.
end
())
{
return
;
}
if
(
map
.
find
(
GlobalJobDesc
().
job_name
())
==
map
.
end
())
{
return
;
}
...
@@ -27,4 +35,8 @@ void AddLbiDiffWatcherOpConfs(Job* job) {
...
@@ -27,4 +35,8 @@ void AddLbiDiffWatcherOpConfs(Job* job) {
}
}
}
}
REGISTER_FUNCTION_PASS
(
"AddLbiDiffWatcherOpConfs"
,
AddLbiDiffWatcherOpConfs
);
}
// namespace
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/job_completer/add_lbi_diff_watcher.h
已删除
100644 → 0
浏览文件 @
4c031f17
#ifndef ONEFLOW_CORE_JOB_COMPLETER_ADD_LBI_DIFF_WATCHER_H_
#define ONEFLOW_CORE_JOB_COMPLETER_ADD_LBI_DIFF_WATCHER_H_
#include "oneflow/core/job/job_builder.h"
namespace
oneflow
{
void
AddLbiDiffWatcherOpConfs
(
Job
*
job
);
}
#endif // ONEFLOW_CORE_JOB_COMPLETER_ADD_LBI_DIFF_WATCHER_H_
oneflow/core/job_completer/all_reduce_add_pass.h
已删除
100644 → 0
浏览文件 @
4c031f17
#ifndef ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_ADD_PASS_H_
#define ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_ADD_PASS_H_
#include "oneflow/core/job_completer/op_graph_pass.h"
namespace
oneflow
{
class
OpGraph
;
class
AllReduceAddPass
final
:
public
OpGraphPass
{
public:
AllReduceAddPass
()
=
default
;
~
AllReduceAddPass
()
=
default
;
bool
IsEnabled
()
const
override
{
return
!
GlobalJobDesc
().
enable_non_distributed_optimizer
()
&&
GlobalJobDesc
().
enable_all_reduce_group
();
}
void
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
override
;
};
}
// namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_ADD_PASS_H_
oneflow/core/job_completer/all_reduce_sequence_pass.h
已删除
100644 → 0
浏览文件 @
4c031f17
#ifndef ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_SEQUENCE_PASS_H_
#define ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_SEQUENCE_PASS_H_
#include "oneflow/core/job/job.pb.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
namespace
oneflow
{
class
OpGraph
;
class
AllReduceSequencePass
final
:
public
OpGraphPass
{
public:
AllReduceSequencePass
()
=
default
;
~
AllReduceSequencePass
()
=
default
;
bool
IsEnabled
()
const
override
{
return
!
GlobalJobDesc
().
disable_all_reduce_sequence
();
}
void
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
override
;
};
}
// namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_ALL_REDUCE_SEQUENCE_PASS_H_
oneflow/core/job_completer/auto_learning_rate.cpp
浏览文件 @
bc5b5e0f
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/job/job.pb.h"
#include "oneflow/core/job/job.pb.h"
namespace
oneflow
{
namespace
oneflow
{
void
AutoLearningRate
(
const
OpGraph
&
op_graph
,
Job
*
job
)
{
namespace
{
class
AutoLearningRate
final
:
public
OpGraphPass
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
AutoLearningRate
);
AutoLearningRate
()
=
default
;
~
AutoLearningRate
()
override
=
default
;
bool
IsEnabled
()
const
override
{
return
GlobalJobDesc
().
IsTrain
();
}
void
Apply
(
const
OpGraph
&
op_graph
,
Job
*
job
)
const
override
;
};
void
AutoLearningRate
::
Apply
(
const
OpGraph
&
op_graph
,
Job
*
job
)
const
{
JobBuilder
job_builder
(
job
);
JobBuilder
job_builder
(
job
);
const
TrainConf
&
train_conf
=
job
->
job_conf
().
train_conf
();
const
TrainConf
&
train_conf
=
job
->
job_conf
().
train_conf
();
auto
AddScheduleOp
=
[
&
](
const
std
::
string
&
op_name
,
const
float
learning_rate
)
->
std
::
string
{
auto
AddScheduleOp
=
[
&
](
const
std
::
string
&
op_name
,
const
float
learning_rate
)
->
std
::
string
{
...
@@ -58,4 +72,8 @@ void AutoLearningRate(const OpGraph& op_graph, Job* job) {
...
@@ -58,4 +72,8 @@ void AutoLearningRate(const OpGraph& op_graph, Job* job) {
}
}
}
}
REGISTER_FUNCTION_PASS
(
"AutoLearningRate"
,
AutoLearningRate
);
}
// namespace
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/job_completer/auto_learning_rate.h
已删除
100644 → 0
浏览文件 @
4c031f17
#ifndef ONEFLOW_CORE_JOB_COMPLETER_AUTO_LEARNING_RATE_H_
#define ONEFLOW_CORE_JOB_COMPLETER_AUTO_LEARNING_RATE_H_
namespace
oneflow
{
class
OpGraph
;
class
Job
;
void
AutoLearningRate
(
const
OpGraph
&
op_graph
,
Job
*
job
);
}
// namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_AUTO_LEARNING_RATE_H_
oneflow/core/job_completer/auto_train_step.cpp
浏览文件 @
bc5b5e0f
#include "oneflow/core/
graph/op_graph
.h"
#include "oneflow/core/
job_completer/op_graph_pass
.h"
#include "oneflow/core/job/job.pb.h"
#include "oneflow/core/job/job.pb.h"
namespace
oneflow
{
namespace
oneflow
{
void
AutoTrainStep
(
const
OpGraph
&
op_graph
,
Job
*
job
)
{
namespace
{
class
AutoTrainStep
final
:
public
OpGraphPass
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
AutoTrainStep
);
AutoTrainStep
()
=
default
;
~
AutoTrainStep
()
override
=
default
;
bool
IsEnabled
()
const
override
{
return
GlobalJobDesc
().
IsTrain
();
}
void
Apply
(
const
OpGraph
&
op_graph
,
Job
*
job
)
const
override
;
};
void
AutoTrainStep
::
Apply
(
const
OpGraph
&
op_graph
,
Job
*
job
)
const
{
if
(
job
->
job_conf
().
train_conf
().
has_train_step_lbn
())
{
return
;
}
if
(
job
->
job_conf
().
train_conf
().
has_train_step_lbn
())
{
return
;
}
OperatorConf
variable_op_conf
{};
OperatorConf
variable_op_conf
{};
const
std
::
string
train_step_name
=
"System-Train-TrainStep-"
+
job
->
job_conf
().
job_name
();
const
std
::
string
train_step_name
=
"System-Train-TrainStep-"
+
job
->
job_conf
().
job_name
();
...
@@ -42,4 +55,8 @@ void AutoTrainStep(const OpGraph& op_graph, Job* job) {
...
@@ -42,4 +55,8 @@ void AutoTrainStep(const OpGraph& op_graph, Job* job) {
job
->
mutable_job_conf
()
->
mutable_train_conf
()
->
set_train_step_lbn
(
train_step_lbn
);
job
->
mutable_job_conf
()
->
mutable_train_conf
()
->
set_train_step_lbn
(
train_step_lbn
);
}
}
REGISTER_FUNCTION_PASS
(
"AutoTrainStep"
,
AutoTrainStep
);
}
// namespace
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/job_completer/auto_train_step.h
已删除
100644 → 0
浏览文件 @
4c031f17
#ifndef ONEFLOW_CORE_JOB_COMPLETER_AUTO_TRAIN_STEP_H_
#define ONEFLOW_CORE_JOB_COMPLETER_AUTO_TRAIN_STEP_H_
namespace
oneflow
{
class
OpGraph
;
class
Job
;
void
AutoTrainStep
(
const
OpGraph
&
op_graph
,
Job
*
job
);
}
// namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_AUTO_TRAIN_STEP_H_
oneflow/core/job_completer/generate_backward_and_optimizer_op_confs.cpp
0 → 100644
浏览文件 @
bc5b5e0f
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/job_completer/autograd.h"
#include "oneflow/core/job_completer/optimizer.h"
namespace
oneflow
{
namespace
{
void
UpdateJobHelperConfProducedLbi2ConsumedDiffLbi
(
const
HashMap
<
LogicalBlobId
,
LogicalBlobId
>&
lbi2diff_lbi
,
JobBuilder
*
job_builder
)
{
auto
&
mut_pairs
=
(
*
job_builder
->
mutable_helper
()
->
mutable_tag2lbi_relations
())[
kProducedLbi2ConsumedDiffLbi
];
for
(
const
auto
&
pair
:
lbi2diff_lbi
)
{
auto
*
mut_pair
=
mut_pairs
.
add_pair
();
*
mut_pair
->
mutable_first
()
=
pair
.
first
;
*
mut_pair
->
mutable_second
()
=
pair
.
second
;
}
}
void
BindIdenticalSbpObaPairsBetweenIbns
(
const
OpNode
&
op_node
,
JobBuilder
*
job_builder
)
{
HashMap
<
LogicalBlobId
,
std
::
vector
<
OpBlobArg
>>
in_lbi2obas
;
for
(
const
std
::
string
&
ibn
:
op_node
.
op
().
input_bns
())
{
in_lbi2obas
[
op_node
.
op
().
BnInOp2Lbi
(
ibn
)].
push_back
(
GenOpBlobArg
(
op_node
.
op
().
op_name
(),
ibn
));
}
for
(
const
auto
&
pair
:
in_lbi2obas
)
{
if
(
pair
.
second
.
size
()
>
1
)
{
FOR_RANGE
(
int32_t
,
i
,
1
,
pair
.
second
.
size
())
{
job_builder
->
BindIdenticalSbpOpBlobArgPair
(
pair
.
second
.
at
(
0
),
pair
.
second
.
at
(
i
));
}
}
}
}
void
SetSbpSignatureHintByIdenticalSbpObaPairs
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
{
HashMap
<
OpBlobArg
,
const
SbpParallel
*>
oba2sbp_parallel
;
op_graph
.
ForEachNode
([
&
](
OpNode
*
op_node
)
{
auto
ForEachBn
=
[
&
](
const
std
::
function
<
void
(
const
std
::
string
&
)
>&
Handler
)
{
for
(
const
auto
&
ibn
:
op_node
->
op
().
input_bns
())
{
Handler
(
ibn
);
}
for
(
const
auto
&
obn
:
op_node
->
op
().
output_bns
())
{
Handler
(
obn
);
}
};
ForEachBn
([
&
](
const
std
::
string
&
bn_in_op
)
{
const
auto
&
oba
=
GenOpBlobArg
(
op_node
->
op
().
op_name
(),
bn_in_op
);
oba2sbp_parallel
[
oba
]
=
&
op_node
->
SbpParallel4Lbi
(
op_node
->
op
().
BnInOp2Lbi
(
bn_in_op
));
});
});
auto
HasSbpParallel
=
[
&
](
const
OpBlobArg
&
oba
)
{
return
oba2sbp_parallel
.
find
(
oba
)
!=
oba2sbp_parallel
.
end
();
};
for
(
const
auto
&
pair
:
job_builder
->
job
().
helper
().
identical_sbp_oba_pairs
().
pair
())
{
const
SbpParallel
*
sbp_parallel
=
nullptr
;
if
(
HasSbpParallel
(
pair
.
first
())
&&
HasSbpParallel
(
pair
.
second
()))
{
CHECK
(
oba2sbp_parallel
.
at
(
pair
.
first
())
==
oba2sbp_parallel
.
at
(
pair
.
second
()));
sbp_parallel
=
oba2sbp_parallel
.
at
(
pair
.
first
());
}
else
if
(
HasSbpParallel
(
pair
.
first
()))
{
sbp_parallel
=
oba2sbp_parallel
.
at
(
pair
.
first
());
}
else
if
(
HasSbpParallel
(
pair
.
second
()))
{
sbp_parallel
=
oba2sbp_parallel
.
at
(
pair
.
second
());
}
else
{
UNIMPLEMENTED
();
}
*
job_builder
->
MutSbpParallel4Oba
(
pair
.
first
())
=
*
sbp_parallel
;
*
job_builder
->
MutSbpParallel4Oba
(
pair
.
second
())
=
*
sbp_parallel
;
}
}
void
UpdateOpSbpSignatureHint
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
{
op_graph
.
ForEachNode
(
[
&
](
OpNode
*
op_node
)
{
BindIdenticalSbpObaPairsBetweenIbns
(
*
op_node
,
job_builder
);
});
SetSbpSignatureHintByIdenticalSbpObaPairs
(
op_graph
,
job_builder
);
}
class
GenerateBackwardAndOptimizerOpConfs
final
:
public
OpGraphPass
{
public:
bool
IsEnabled
()
const
override
{
return
GlobalJobDesc
().
IsTrain
();
}
OF_DISALLOW_COPY_AND_MOVE
(
GenerateBackwardAndOptimizerOpConfs
);
GenerateBackwardAndOptimizerOpConfs
()
=
default
;
~
GenerateBackwardAndOptimizerOpConfs
()
override
=
default
;
void
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
override
;
};
void
GenerateBackwardAndOptimizerOpConfs
::
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
{
LogicalBlobId
total_loss_instance_num
;
HashMap
<
LogicalBlobId
,
LogicalBlobId
>
lbi2diff_lbi
;
AutoGrad
(
op_graph
,
job_builder
,
&
lbi2diff_lbi
);
std
::
function
<
const
LogicalBlobId
&
(
const
ParallelDesc
&
)
>
LossInstanceNum4ParallelDesc
;
AddTotalLossInstanceNumOpConf
(
op_graph
,
job_builder
,
lbi2diff_lbi
,
&
LossInstanceNum4ParallelDesc
);
AddOptimizerOpConf
(
op_graph
,
job_builder
,
lbi2diff_lbi
,
LossInstanceNum4ParallelDesc
);
UpdateJobHelperConfProducedLbi2ConsumedDiffLbi
(
lbi2diff_lbi
,
job_builder
);
UpdateOpSbpSignatureHint
(
op_graph
,
job_builder
);
}
REGISTER_FUNCTION_PASS
(
"GenerateBackwardAndOptimizerOpConfs"
,
GenerateBackwardAndOptimizerOpConfs
);
}
// namespace
}
// namespace oneflow
oneflow/core/job_completer/job_completer.cpp
浏览文件 @
bc5b5e0f
#include "oneflow/core/job_completer/job_completer.h"
#include "oneflow/core/job_completer/job_completer.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/job_completer/autograd.h"
#include "oneflow/core/job_completer/autograd.h"
#include "oneflow/core/job_completer/autotick.h"
#include "oneflow/core/job_completer/autotick.h"
#include "oneflow/core/job_completer/add_keep_header_only_op_conf.h"
#include "oneflow/core/job_completer/add_keep_header_only_op_conf.h"
#include "oneflow/core/job_completer/optimizer.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job_completer/all_reduce_add_pass.h"
#include "oneflow/core/job_completer/all_reduce_sequence_pass.h"
#include "oneflow/core/job_completer/group_boxing_by_dst_parallel.h"
#include "oneflow/core/job_completer/group_boxing_by_dst_parallel.h"
#include "oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h"
#include "oneflow/core/job_completer/auto_train_step.h"
#include "oneflow/core/job_completer/auto_learning_rate.h"
#include "oneflow/core/job_completer/add_lbi_diff_watcher.h"
#include "oneflow/core/framework/config_def.h"
#include "oneflow/core/framework/config_def.h"
#include "oneflow/core/job_completer/xrt_compilation.h"
#include "oneflow/core/job_completer/xrt_compilation.h"
namespace
oneflow
{
namespace
oneflow
{
...
@@ -42,95 +35,6 @@ void WithOpGraphAndMutJobBuilder(Job* job,
...
@@ -42,95 +35,6 @@ void WithOpGraphAndMutJobBuilder(Job* job,
Handler
(
op_graph
,
&
job_builder
);
Handler
(
op_graph
,
&
job_builder
);
}
}
void
UpdateJobHelperConfProducedLbi2ConsumedDiffLbi
(
const
HashMap
<
LogicalBlobId
,
LogicalBlobId
>&
lbi2diff_lbi
,
JobBuilder
*
job_builder
)
{
auto
&
mut_pairs
=
(
*
job_builder
->
mutable_helper
()
->
mutable_tag2lbi_relations
())[
kProducedLbi2ConsumedDiffLbi
];
for
(
const
auto
&
pair
:
lbi2diff_lbi
)
{
auto
*
mut_pair
=
mut_pairs
.
add_pair
();
*
mut_pair
->
mutable_first
()
=
pair
.
first
;
*
mut_pair
->
mutable_second
()
=
pair
.
second
;
}
}
void
BindIdenticalSbpObaPairsBetweenIbns
(
const
OpNode
&
op_node
,
JobBuilder
*
job_builder
)
{
HashMap
<
LogicalBlobId
,
std
::
vector
<
OpBlobArg
>>
in_lbi2obas
;
for
(
const
std
::
string
&
ibn
:
op_node
.
op
().
input_bns
())
{
in_lbi2obas
[
op_node
.
op
().
BnInOp2Lbi
(
ibn
)].
push_back
(
GenOpBlobArg
(
op_node
.
op
().
op_name
(),
ibn
));
}
for
(
const
auto
&
pair
:
in_lbi2obas
)
{
if
(
pair
.
second
.
size
()
>
1
)
{
FOR_RANGE
(
int32_t
,
i
,
1
,
pair
.
second
.
size
())
{
job_builder
->
BindIdenticalSbpOpBlobArgPair
(
pair
.
second
.
at
(
0
),
pair
.
second
.
at
(
i
));
}
}
}
}
void
SetSbpSignatureHintByIdenticalSbpObaPairs
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
{
HashMap
<
OpBlobArg
,
const
SbpParallel
*>
oba2sbp_parallel
;
op_graph
.
ForEachNode
([
&
](
OpNode
*
op_node
)
{
auto
ForEachBn
=
[
&
](
const
std
::
function
<
void
(
const
std
::
string
&
)
>&
Handler
)
{
for
(
const
auto
&
ibn
:
op_node
->
op
().
input_bns
())
{
Handler
(
ibn
);
}
for
(
const
auto
&
obn
:
op_node
->
op
().
output_bns
())
{
Handler
(
obn
);
}
};
ForEachBn
([
&
](
const
std
::
string
&
bn_in_op
)
{
const
auto
&
oba
=
GenOpBlobArg
(
op_node
->
op
().
op_name
(),
bn_in_op
);
oba2sbp_parallel
[
oba
]
=
&
op_node
->
SbpParallel4Lbi
(
op_node
->
op
().
BnInOp2Lbi
(
bn_in_op
));
});
});
auto
HasSbpParallel
=
[
&
](
const
OpBlobArg
&
oba
)
{
return
oba2sbp_parallel
.
find
(
oba
)
!=
oba2sbp_parallel
.
end
();
};
for
(
const
auto
&
pair
:
job_builder
->
job
().
helper
().
identical_sbp_oba_pairs
().
pair
())
{
const
SbpParallel
*
sbp_parallel
=
nullptr
;
if
(
HasSbpParallel
(
pair
.
first
())
&&
HasSbpParallel
(
pair
.
second
()))
{
CHECK
(
oba2sbp_parallel
.
at
(
pair
.
first
())
==
oba2sbp_parallel
.
at
(
pair
.
second
()));
sbp_parallel
=
oba2sbp_parallel
.
at
(
pair
.
first
());
}
else
if
(
HasSbpParallel
(
pair
.
first
()))
{
sbp_parallel
=
oba2sbp_parallel
.
at
(
pair
.
first
());
}
else
if
(
HasSbpParallel
(
pair
.
second
()))
{
sbp_parallel
=
oba2sbp_parallel
.
at
(
pair
.
second
());
}
else
{
UNIMPLEMENTED
();
}
*
job_builder
->
MutSbpParallel4Oba
(
pair
.
first
())
=
*
sbp_parallel
;
*
job_builder
->
MutSbpParallel4Oba
(
pair
.
second
())
=
*
sbp_parallel
;
}
}
void
UpdateOpSbpSignatureHint
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
{
op_graph
.
ForEachNode
(
[
&
](
OpNode
*
op_node
)
{
BindIdenticalSbpObaPairsBetweenIbns
(
*
op_node
,
job_builder
);
});
SetSbpSignatureHintByIdenticalSbpObaPairs
(
op_graph
,
job_builder
);
}
void
GenerateOpConf4Trainning
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
{
LogicalBlobId
total_loss_instance_num
;
HashMap
<
LogicalBlobId
,
LogicalBlobId
>
lbi2diff_lbi
;
AutoGrad
(
op_graph
,
job_builder
,
&
lbi2diff_lbi
);
std
::
function
<
const
LogicalBlobId
&
(
const
ParallelDesc
&
)
>
LossInstanceNum4ParallelDesc
;
AddTotalLossInstanceNumOpConf
(
op_graph
,
job_builder
,
lbi2diff_lbi
,
&
LossInstanceNum4ParallelDesc
);
AddOptimizerOpConf
(
op_graph
,
job_builder
,
lbi2diff_lbi
,
LossInstanceNum4ParallelDesc
);
UpdateJobHelperConfProducedLbi2ConsumedDiffLbi
(
lbi2diff_lbi
,
job_builder
);
UpdateOpSbpSignatureHint
(
op_graph
,
job_builder
);
}
std
::
function
<
ParallelConf
*
(
const
std
::
string
&
)
>
MakeGetterMutParallelConf4OpName
(
Placement
*
placement
)
{
auto
op_name2parallel_conf
=
std
::
make_shared
<
HashMap
<
std
::
string
,
ParallelConf
*>>
();
FOR_RANGE
(
int
,
idx
,
0
,
placement
->
placement_group_size
())
{
auto
*
placement_group
=
placement
->
mutable_placement_group
(
idx
);
for
(
const
std
::
string
&
op_name
:
placement_group
->
op_set
().
op_name
())
{
ParallelConf
*
parallel_conf
=
placement_group
->
mutable_parallel_conf
();
CHECK
(
op_name2parallel_conf
->
emplace
(
op_name
,
parallel_conf
).
second
);
}
}
return
[
op_name2parallel_conf
](
const
std
::
string
&
op_name
)
{
return
op_name2parallel_conf
->
at
(
op_name
);
};
}
void
SetCtrlInOpName4VariableOp
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
{
void
SetCtrlInOpName4VariableOp
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
{
auto
IsMutableConsumedLbi
=
[](
const
Operator
&
op
,
const
LogicalBlobId
&
lbi
)
->
bool
{
auto
IsMutableConsumedLbi
=
[](
const
Operator
&
op
,
const
LogicalBlobId
&
lbi
)
->
bool
{
for
(
const
std
::
string
&
bn
:
op
.
input_bns
())
{
for
(
const
std
::
string
&
bn
:
op
.
input_bns
())
{
...
@@ -179,28 +83,20 @@ void DumpLogicalBlobDescAndSbpSignature(const OpGraph& op_graph, JobBuilder* job
...
@@ -179,28 +83,20 @@ void DumpLogicalBlobDescAndSbpSignature(const OpGraph& op_graph, JobBuilder* job
op_graph
.
DumpSbpSignature
(
job_builder
);
op_graph
.
DumpSbpSignature
(
job_builder
);
}
}
void
MakeNcclTupleBroadcastReduceSequence
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
{
NcclTupleBroadcastReduceSequencePass
().
Apply
(
op_graph
,
job_builder
);
}
}
// namespace
}
// namespace
void
JobCompleter
::
Complete
(
Job
*
job
)
const
{
void
JobCompleter
::
Complete
(
Job
*
job
)
const
{
// complete variable ops
FunctionPass
(
"SetDefaultVariableConf"
)(
job
);
FunctionPass
(
"SetDefaultVariableConf"
)(
job
);
FunctionPass
(
"AutoMixedPrecision"
)(
job
);
FunctionPass
(
"AutoMixedPrecision"
)(
job
);
if
(
GlobalJobDesc
().
IsTrain
())
{
FunctionPass
(
"TieUpChainHeadersUnReachableFromAnyVariableOps"
)(
job
);
FunctionPass
(
"TieUpChainHeadersUnReachableFromAnyVariableOps"
)(
job
);
FunctionPass
(
"NonDistributedOptimizerPass"
)(
job
);
FunctionPass
(
"NonDistributedOptimizerPass"
)(
job
);
FunctionPass
(
"AutoTrainStep"
)(
job
);
WithOpGraphAndMutJob
(
job
,
&
AutoTrainStep
);
FunctionPass
(
"AutoLearningRate"
)(
job
);
WithOpGraphAndMutJob
(
job
,
&
AutoLearningRate
);
FunctionPass
(
"GenerateBackwardAndOptimizerOpConfs"
)(
job
);
// complete ops for trainning
FunctionPass
(
"SequentializeNcclTupleBroadcastReducePass"
)(
job
);
WithOpGraphAndMutJobBuilder
(
job
,
&
GenerateOpConf4Trainning
);
FunctionPass
(
"AddAllReduceGroupPass"
)(
job
);
WithOpGraphAndMutJobBuilder
(
job
,
&
MakeNcclTupleBroadcastReduceSequence
);
FunctionPass
(
"AddLbiDiffWatcherOpConfs"
)(
job
);
AllReduceAddPass
()(
job
);
FunctionPass
(
"SequentializeAllReduceGroupPass"
)(
job
);
AddLbiDiffWatcherOpConfs
(
job
);
AllReduceSequencePass
()(
job
);
}
WithOpGraphAndMutJobBuilder
(
job
,
&
DumpLogicalBlobDescAndSbpSignature
);
WithOpGraphAndMutJobBuilder
(
job
,
&
DumpLogicalBlobDescAndSbpSignature
);
WithOpGraphAndMutJobBuilder
(
job
,
&
GroupBoxingByDstParallel
);
WithOpGraphAndMutJobBuilder
(
job
,
&
GroupBoxingByDstParallel
);
WithOpGraphAndMutJobBuilder
(
job
,
&
AddKeepHeaderOnlyOp
);
WithOpGraphAndMutJobBuilder
(
job
,
&
AddKeepHeaderOnlyOp
);
...
...
oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h
已删除
100644 → 0
浏览文件 @
4c031f17
#ifndef ONEFLOW_CORE_JOB_COMPLETER_NCCL_TUPLE_BROADCAST_REDUCE_SEQUENCE_PASS_H_
#define ONEFLOW_CORE_JOB_COMPLETER_NCCL_TUPLE_BROADCAST_REDUCE_SEQUENCE_PASS_H_
#include "oneflow/core/common/util.h"
namespace
oneflow
{
class
OpGraph
;
class
JobBuilder
;
class
NcclTupleBroadcastReduceSequencePass
final
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
NcclTupleBroadcastReduceSequencePass
);
NcclTupleBroadcastReduceSequencePass
()
=
default
;
~
NcclTupleBroadcastReduceSequencePass
()
=
default
;
void
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
;
};
}
// namespace oneflow
#endif // #define ONEFLOW_CORE_JOB_COMPLETER_NCCL_TUPLE_BROADCAST_REDUCE_SEQUENCE_PASS_H_
oneflow/core/job_completer/non_distributed_optimizer_pass.cpp
浏览文件 @
bc5b5e0f
...
@@ -35,7 +35,9 @@ class NonDistributedOptimizerPass final : public OpGraphPass {
...
@@ -35,7 +35,9 @@ class NonDistributedOptimizerPass final : public OpGraphPass {
OF_DISALLOW_COPY_AND_MOVE
(
NonDistributedOptimizerPass
);
OF_DISALLOW_COPY_AND_MOVE
(
NonDistributedOptimizerPass
);
NonDistributedOptimizerPass
()
=
default
;
NonDistributedOptimizerPass
()
=
default
;
~
NonDistributedOptimizerPass
()
=
default
;
~
NonDistributedOptimizerPass
()
=
default
;
bool
IsEnabled
()
const
override
{
return
GlobalJobDesc
().
enable_non_distributed_optimizer
();
}
bool
IsEnabled
()
const
override
{
return
GlobalJobDesc
().
IsTrain
()
&&
GlobalJobDesc
().
enable_non_distributed_optimizer
();
}
void
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
override
;
void
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
override
;
};
};
...
...
oneflow/core/job_completer/op_graph_pass.h
浏览文件 @
bc5b5e0f
...
@@ -9,6 +9,9 @@ namespace oneflow {
...
@@ -9,6 +9,9 @@ namespace oneflow {
class
OpGraphPass
{
class
OpGraphPass
{
public:
public:
OpGraphPass
()
=
default
;
virtual
~
OpGraphPass
()
=
default
;
void
operator
()(
Job
*
job
)
const
{
void
operator
()(
Job
*
job
)
const
{
if
(
IsEnabled
()
==
false
)
{
return
;
}
if
(
IsEnabled
()
==
false
)
{
return
;
}
Apply
(
job
);
Apply
(
job
);
...
...
oneflow/core/job_completer/
all_reduce_sequence
_pass.cpp
→
oneflow/core/job_completer/
sequentialize_all_reduce_group
_pass.cpp
浏览文件 @
bc5b5e0f
#include "oneflow/core/job_completer/
all_reduce_sequence
_pass.h"
#include "oneflow/core/job_completer/
op_graph
_pass.h"
namespace
oneflow
{
namespace
oneflow
{
...
@@ -52,9 +52,18 @@ void ReOrderAllReduceGroups(std::vector<AllReduceGroup>* all_reduce_groups) {
...
@@ -52,9 +52,18 @@ void ReOrderAllReduceGroups(std::vector<AllReduceGroup>* all_reduce_groups) {
all_reduce_groups
->
end
()
-
lazy_count
);
all_reduce_groups
->
end
()
-
lazy_count
);
}
}
}
// namespace
class
SequentializeAllReduceGroupPass
final
:
public
OpGraphPass
{
public:
SequentializeAllReduceGroupPass
()
=
default
;
~
SequentializeAllReduceGroupPass
()
=
default
;
bool
IsEnabled
()
const
override
{
return
GlobalJobDesc
().
IsTrain
()
&&
!
GlobalJobDesc
().
disable_all_reduce_sequence
();
}
void
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
override
;
};
void
AllReduceSequencePass
::
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
{
void
SequentializeAllReduceGroupPass
::
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
{
std
::
vector
<
AllReduceGroup
>
all_reduce_groups
;
std
::
vector
<
AllReduceGroup
>
all_reduce_groups
;
FindAllReduceGroups
(
op_graph
,
&
all_reduce_groups
);
FindAllReduceGroups
(
op_graph
,
&
all_reduce_groups
);
ReOrderAllReduceGroups
(
&
all_reduce_groups
);
ReOrderAllReduceGroups
(
&
all_reduce_groups
);
...
@@ -68,4 +77,8 @@ void AllReduceSequencePass::Apply(const OpGraph& op_graph, JobBuilder* job_build
...
@@ -68,4 +77,8 @@ void AllReduceSequencePass::Apply(const OpGraph& op_graph, JobBuilder* job_build
}
}
}
}
REGISTER_FUNCTION_PASS
(
"SequentializeAllReduceGroupPass"
,
SequentializeAllReduceGroupPass
);
}
// namespace
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/job_completer/
nccl_tuple_broadcast_reduce_sequen
ce_pass.cpp
→
oneflow/core/job_completer/
sequentialize_nccl_tuple_broadcast_redu
ce_pass.cpp
浏览文件 @
bc5b5e0f
#include "oneflow/core/job_completer/
nccl_tuple_broadcast_reduce_sequence
_pass.h"
#include "oneflow/core/job_completer/
op_graph
_pass.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/graph/op_graph.h"
namespace
oneflow
{
namespace
oneflow
{
void
NcclTupleBroadcastReduceSequencePass
::
Apply
(
const
OpGraph
&
op_graph
,
class
SequentializeNcclTupleBroadcastReducePass
final
:
public
OpGraphPass
{
JobBuilder
*
builder
)
const
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
SequentializeNcclTupleBroadcastReducePass
);
SequentializeNcclTupleBroadcastReducePass
()
=
default
;
~
SequentializeNcclTupleBroadcastReducePass
()
=
default
;
bool
IsEnabled
()
const
override
{
return
GlobalJobDesc
().
IsTrain
();
}
void
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
override
;
};
void
SequentializeNcclTupleBroadcastReducePass
::
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
builder
)
const
{
std
::
vector
<
OperatorConf
>
broadcast_ops
;
std
::
vector
<
OperatorConf
>
broadcast_ops
;
std
::
vector
<
OperatorConf
>
reduce_ops
;
std
::
vector
<
OperatorConf
>
reduce_ops
;
op_graph
.
ForEachNode
([
&
](
const
OpNode
*
node
)
{
op_graph
.
ForEachNode
([
&
](
const
OpNode
*
node
)
{
...
@@ -41,4 +52,7 @@ void NcclTupleBroadcastReduceSequencePass::Apply(const OpGraph& op_graph,
...
@@ -41,4 +52,7 @@ void NcclTupleBroadcastReduceSequencePass::Apply(const OpGraph& op_graph,
builder
->
MutOpsOnlyOnce
(
reduce_ops
);
builder
->
MutOpsOnlyOnce
(
reduce_ops
);
}
}
REGISTER_FUNCTION_PASS
(
"SequentializeNcclTupleBroadcastReducePass"
,
SequentializeNcclTupleBroadcastReducePass
);
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/job_completer/tie_up_chain_headers.cpp
浏览文件 @
bc5b5e0f
...
@@ -94,7 +94,9 @@ std::function<bool(OpNode*)> MakePredicatorIsReachableFromAnyVariableOps(const O
...
@@ -94,7 +94,9 @@ std::function<bool(OpNode*)> MakePredicatorIsReachableFromAnyVariableOps(const O
REGISTER_FUNCTION_CONFIG_DEF
().
Bool
(
"enable_pseudo_chain_merge"
,
false
,
REGISTER_FUNCTION_CONFIG_DEF
().
Bool
(
"enable_pseudo_chain_merge"
,
false
,
"ties up chain headers unreachable from any variable ops"
);
"ties up chain headers unreachable from any variable ops"
);
class
TieUpChainHeadersUnReachableFromAnyVariableOps
final
:
public
OpGraphPass
{
class
TieUpChainHeadersUnReachableFromAnyVariableOps
final
:
public
OpGraphPass
{
bool
IsEnabled
()
const
override
{
return
GlobalJobDesc
().
Bool
(
"enable_pseudo_chain_merge"
);
}
bool
IsEnabled
()
const
override
{
return
GlobalJobDesc
().
IsTrain
()
&&
GlobalJobDesc
().
Bool
(
"enable_pseudo_chain_merge"
);
}
void
Apply
(
const
OpGraph
&
op_graph
,
Job
*
job
)
const
override
{
void
Apply
(
const
OpGraph
&
op_graph
,
Job
*
job
)
const
override
{
auto
IsReachableFromAnyVariableOps
=
MakePredicatorIsReachableFromAnyVariableOps
(
op_graph
);
auto
IsReachableFromAnyVariableOps
=
MakePredicatorIsReachableFromAnyVariableOps
(
op_graph
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录