Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
ee85ce06
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
ee85ce06
编写于
1月 07, 2020
作者:
L
lixinqi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
more cases of REGISTER_FUNCTION_PASS
上级
c117d4ed
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
128 addition
and
162 deletion
+128
-162
oneflow/core/job/oneflow.cpp
oneflow/core/job/oneflow.cpp
+2
-2
oneflow/core/job_completer/auto_mixed_precision.cpp
oneflow/core/job_completer/auto_mixed_precision.cpp
+36
-2
oneflow/core/job_completer/auto_mixed_precision.h
oneflow/core/job_completer/auto_mixed_precision.h
+0
-45
oneflow/core/job_completer/complete_ofrecord_decoder.cpp
oneflow/core/job_completer/complete_ofrecord_decoder.cpp
+11
-5
oneflow/core/job_completer/job_completer.cpp
oneflow/core/job_completer/job_completer.cpp
+4
-7
oneflow/core/job_completer/non_distributed_optimizer_pass.cpp
...low/core/job_completer/non_distributed_optimizer_pass.cpp
+14
-2
oneflow/core/job_completer/non_distributed_optimizer_pass.h
oneflow/core/job_completer/non_distributed_optimizer_pass.h
+0
-23
oneflow/core/job_completer/op_graph_pass.cpp
oneflow/core/job_completer/op_graph_pass.cpp
+5
-1
oneflow/core/job_completer/op_graph_pass.h
oneflow/core/job_completer/op_graph_pass.h
+6
-2
oneflow/core/job_completer/set_default_variable_conf.cpp
oneflow/core/job_completer/set_default_variable_conf.cpp
+49
-39
oneflow/core/job_completer/set_default_variable_conf.h
oneflow/core/job_completer/set_default_variable_conf.h
+0
-14
oneflow/core/job_completer/tie_up_chain_headers.cpp
oneflow/core/job_completer/tie_up_chain_headers.cpp
+1
-0
oneflow/core/job_completer/user_job_completer.h
oneflow/core/job_completer/user_job_completer.h
+0
-20
未找到文件。
oneflow/core/job/oneflow.cpp
浏览文件 @
ee85ce06
...
...
@@ -6,7 +6,7 @@
#include "oneflow/core/job/improver.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/job_builder.h"
#include "oneflow/core/job_completer/
user_job_completer
.h"
#include "oneflow/core/job_completer/
op_graph_pass
.h"
#include "oneflow/core/job/job_set.pb.h"
#include "oneflow/core/job/machine_context.h"
#include "oneflow/core/job/profiler.h"
...
...
@@ -743,7 +743,7 @@ void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) {
AddJobName2JobId
(
jobs
.
at
(
job_id
).
job_conf
().
job_name
(),
job_id
);
{
auto
scope
=
std
::
make_unique
<
GlobalJobDescScope
>
(
jobs
.
at
(
job_id
).
job_conf
(),
job_id
);
UserJobCompleter
().
Complete
(
&
jobs
.
at
(
job_id
));
FunctionPass
(
"CompleteOfrecordDecoder"
)
(
&
jobs
.
at
(
job_id
));
CompileCurJobOnMaster
(
&
jobs
.
at
(
job_id
),
&
sub_plans
.
at
(
job_id
),
true
);
}
}
...
...
oneflow/core/job_completer/auto_mixed_precision.cpp
浏览文件 @
ee85ce06
#include <algorithm>
#include "oneflow/core/job_completer/auto_mixed_precision.h"
#include "oneflow/core/job_completer/auto_mixed_precision_lists.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/device/cuda_util.h"
...
...
@@ -175,7 +176,36 @@ void InsertCastOpImpl(bool f2h, const OpGraph& op_graph, const HashSet<OpNode*>&
job_builder
->
MutOpsOnlyOnce
(
dst_op_confs
);
}
}
// namespace
class
AutoMixedPrecision
final
:
public
OpGraphPass
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
AutoMixedPrecision
);
AutoMixedPrecision
()
:
white_list_
(
AutoMixedPrecisionLists
::
WhiteList
()),
black_list_
(
AutoMixedPrecisionLists
::
BlackList
()),
gray_list_
(
AutoMixedPrecisionLists
::
GrayList
()),
clear_list_
(
AutoMixedPrecisionLists
::
ClearList
())
{}
~
AutoMixedPrecision
()
=
default
;
bool
IsEnabled
()
const
override
{
return
GlobalJobDesc
().
enable_auto_mixed_precision
();
}
void
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
override
;
private:
void
FillBlackSet
(
const
OpGraph
&
op_graph
,
HashSet
<
OpNode
*>*
black_set
)
const
;
void
FillWhiteSet
(
const
OpGraph
&
op_graph
,
std
::
function
<
bool
(
OpNode
*
)
>
IsAllowedToRunWithHalf
,
const
HashSet
<
OpNode
*>&
black_set
,
HashSet
<
OpNode
*>*
white_set
)
const
;
void
PropagateWhiteThroughClearNodes
(
const
OpGraph
&
op_graph
,
std
::
function
<
bool
(
OpNode
*
)
>
IsAllowedToRunWithHalf
,
const
HashSet
<
OpNode
*>&
black_set
,
HashSet
<
OpNode
*>*
white_set
)
const
;
void
InsertCastOp
(
const
OpGraph
&
op_graph
,
const
HashSet
<
OpNode
*>&
white_set
,
JobBuilder
*
job_builder
)
const
;
const
AMPList
&
white_list_
;
const
AMPList
&
black_list_
;
const
AMPList
&
gray_list_
;
const
AMPList
&
clear_list_
;
};
void
AutoMixedPrecision
::
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
{
CHECK_GE
(
CUDA_VERSION
,
10000
);
...
...
@@ -286,4 +316,8 @@ void AutoMixedPrecision::InsertCastOp(const OpGraph& op_graph, const HashSet<OpN
InsertCastOpImpl
(
false
,
op_graph
,
white_set
,
job_builder
);
}
REGISTER_FUNCTION_PASS
(
"AutoMixedPrecision"
,
AutoMixedPrecision
);
}
// namespace
}
// namespace oneflow
oneflow/core/job_completer/auto_mixed_precision.h
已删除
100644 → 0
浏览文件 @
c117d4ed
#ifndef ONEFLOW_CORE_JOB_COMPLETER_AUTO_MIXED_PRECISION_H_
#define ONEFLOW_CORE_JOB_COMPLETER_AUTO_MIXED_PRECISION_H_
#include "oneflow/core/job_completer/auto_mixed_precision_lists.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
namespace
oneflow
{
class
OpGraph
;
class
OpNode
;
class
Job
;
class
AutoMixedPrecision
final
:
public
OpGraphPass
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
AutoMixedPrecision
);
AutoMixedPrecision
()
:
white_list_
(
AutoMixedPrecisionLists
::
WhiteList
()),
black_list_
(
AutoMixedPrecisionLists
::
BlackList
()),
gray_list_
(
AutoMixedPrecisionLists
::
GrayList
()),
clear_list_
(
AutoMixedPrecisionLists
::
ClearList
())
{}
~
AutoMixedPrecision
()
=
default
;
bool
IsEnabled
()
const
override
{
return
GlobalJobDesc
().
enable_auto_mixed_precision
();
}
void
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
override
;
private:
void
FillBlackSet
(
const
OpGraph
&
op_graph
,
HashSet
<
OpNode
*>*
black_set
)
const
;
void
FillWhiteSet
(
const
OpGraph
&
op_graph
,
std
::
function
<
bool
(
OpNode
*
)
>
IsAllowedToRunWithHalf
,
const
HashSet
<
OpNode
*>&
black_set
,
HashSet
<
OpNode
*>*
white_set
)
const
;
void
PropagateWhiteThroughClearNodes
(
const
OpGraph
&
op_graph
,
std
::
function
<
bool
(
OpNode
*
)
>
IsAllowedToRunWithHalf
,
const
HashSet
<
OpNode
*>&
black_set
,
HashSet
<
OpNode
*>*
white_set
)
const
;
void
InsertCastOp
(
const
OpGraph
&
op_graph
,
const
HashSet
<
OpNode
*>&
white_set
,
JobBuilder
*
job_builder
)
const
;
const
AMPList
&
white_list_
;
const
AMPList
&
black_list_
;
const
AMPList
&
gray_list_
;
const
AMPList
&
clear_list_
;
};
}
// namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_AUTO_MIXED_PRECISION_H_
oneflow/core/job_completer/
user_job_complet
er.cpp
→
oneflow/core/job_completer/
complete_ofrecord_decod
er.cpp
浏览文件 @
ee85ce06
#include "oneflow/core/job_completer/
user_job_completer
.h"
#include "oneflow/core/job_completer/
op_graph_pass
.h"
#include "oneflow/core/job/job_builder.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/job/parallel_desc.h"
...
...
@@ -132,9 +132,15 @@ void AddRecordLoadOps(Job* job) {
}
// namespace
void
UserJobCompleter
::
Complete
(
Job
*
job
)
const
{
SplitDecodeOps
(
job
);
AddRecordLoadOps
(
job
);
}
class
CompleteOfrecordDecoder
final
:
public
OpGraphPass
{
public:
bool
IsEnabled
()
const
override
{
return
true
;
}
void
Apply
(
Job
*
job
)
const
override
{
SplitDecodeOps
(
job
);
AddRecordLoadOps
(
job
);
}
};
REGISTER_FUNCTION_PASS
(
"CompleteOfrecordDecoder"
,
CompleteOfrecordDecoder
);
}
// namespace oneflow
oneflow/core/job_completer/job_completer.cpp
浏览文件 @
ee85ce06
...
...
@@ -5,11 +5,8 @@
#include "oneflow/core/job_completer/optimizer.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job_completer/all_reduce_add_pass.h"
#include "oneflow/core/job_completer/set_default_variable_conf.h"
#include "oneflow/core/job_completer/all_reduce_sequence_pass.h"
#include "oneflow/core/job_completer/group_boxing_by_dst_parallel.h"
#include "oneflow/core/job_completer/auto_mixed_precision.h"
#include "oneflow/core/job_completer/non_distributed_optimizer_pass.h"
#include "oneflow/core/job_completer/nccl_tuple_broadcast_reduce_sequence_pass.h"
#include "oneflow/core/job_completer/auto_train_step.h"
#include "oneflow/core/job_completer/auto_learning_rate.h"
...
...
@@ -190,11 +187,11 @@ void MakeNcclTupleBroadcastReduceSequence(const OpGraph& op_graph, JobBuilder* j
void
JobCompleter
::
Complete
(
Job
*
job
)
const
{
// complete variable ops
WithOpGraphAndMutJobBuilder
(
job
,
&
SetDefaultVariableConf
);
AutoMixedPrecision
(
)(
job
);
FunctionPass
(
"SetDefaultVariableConf"
)(
job
);
FunctionPass
(
"AutoMixedPrecision"
)(
job
);
if
(
GlobalJobDesc
().
IsTrain
())
{
F
indF
unctionPass
(
"TieUpChainHeadersUnReachableFromAnyVariableOps"
)(
job
);
NonDistributedOptimizerPass
(
)(
job
);
FunctionPass
(
"TieUpChainHeadersUnReachableFromAnyVariableOps"
)(
job
);
FunctionPass
(
"NonDistributedOptimizerPass"
)(
job
);
WithOpGraphAndMutJob
(
job
,
&
AutoTrainStep
);
WithOpGraphAndMutJob
(
job
,
&
AutoLearningRate
);
// complete ops for trainning
...
...
oneflow/core/job_completer/non_distributed_optimizer_pass.cpp
浏览文件 @
ee85ce06
#include "oneflow/core/job_completer/non_distributed_optimizer_pass.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job/job_desc.h"
...
...
@@ -29,7 +30,14 @@ ParallelConf NonDistributedParallelConf4ParallelId(const ParallelDesc& pd,
return
parallel_conf
;
}
}
// namespace
class
NonDistributedOptimizerPass
final
:
public
OpGraphPass
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
NonDistributedOptimizerPass
);
NonDistributedOptimizerPass
()
=
default
;
~
NonDistributedOptimizerPass
()
=
default
;
bool
IsEnabled
()
const
override
{
return
GlobalJobDesc
().
enable_non_distributed_optimizer
();
}
void
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
override
;
};
void
NonDistributedOptimizerPass
::
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
builder
)
const
{
HashMap
<
ParallelDesc
,
HashMap
<
const
OpNode
*
,
std
::
vector
<
const
OpNode
*>>>
pd2last_node2node_seqs
;
...
...
@@ -176,4 +184,8 @@ void NonDistributedOptimizerPass::Apply(const OpGraph& op_graph, JobBuilder* bui
}
}
REGISTER_FUNCTION_PASS
(
"NonDistributedOptimizerPass"
,
NonDistributedOptimizerPass
);
}
// namespace
}
// namespace oneflow
oneflow/core/job_completer/non_distributed_optimizer_pass.h
已删除
100644 → 0
浏览文件 @
c117d4ed
#ifndef ONEFLOW_CORE_JOB_COMPLETER_NON_DISTRIBUTED_OPTIMIZER_PASS_H_
#define ONEFLOW_CORE_JOB_COMPLETER_NON_DISTRIBUTED_OPTIMIZER_PASS_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
namespace
oneflow
{
class
OpGraph
;
class
JobBuilder
;
class
NonDistributedOptimizerPass
final
:
public
OpGraphPass
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
NonDistributedOptimizerPass
);
NonDistributedOptimizerPass
()
=
default
;
~
NonDistributedOptimizerPass
()
=
default
;
bool
IsEnabled
()
const
override
{
return
GlobalJobDesc
().
enable_non_distributed_optimizer
();
}
void
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
override
;
};
}
// namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_NON_DISTRIBUTED_OPTIMIZER_PASS_H_
oneflow/core/job_completer/op_graph_pass.cpp
浏览文件 @
ee85ce06
...
...
@@ -15,7 +15,11 @@ void RegisterFunctionPass(const std::string& pass_name, const OpGraphPass* pass)
CHECK
(
PassName2FunctionPass
()
->
emplace
(
pass_name
,
pass
).
second
);
}
const
OpGraphPass
&
FindFunctionPass
(
const
std
::
string
&
pass_name
)
{
bool
HasFunctionPass
(
const
std
::
string
&
pass_name
)
{
return
PassName2FunctionPass
()
->
find
(
pass_name
)
!=
PassName2FunctionPass
()
->
end
();
}
const
OpGraphPass
&
FunctionPass
(
const
std
::
string
&
pass_name
)
{
const
auto
&
iter
=
PassName2FunctionPass
()
->
find
(
pass_name
);
CHECK
(
iter
!=
PassName2FunctionPass
()
->
end
());
return
*
iter
->
second
;
...
...
oneflow/core/job_completer/op_graph_pass.h
浏览文件 @
ee85ce06
...
...
@@ -11,10 +11,13 @@ class OpGraphPass {
public:
void
operator
()(
Job
*
job
)
const
{
if
(
IsEnabled
()
==
false
)
{
return
;
}
Apply
(
job
);
}
virtual
bool
IsEnabled
()
const
{
return
true
;
}
virtual
void
Apply
(
Job
*
job
)
const
{
const
OpGraph
op_graph
(
*
job
);
Apply
(
op_graph
,
job
);
}
virtual
bool
IsEnabled
()
const
{
return
true
;
}
virtual
void
Apply
(
const
OpGraph
&
op_graph
,
Job
*
job
)
const
{
JobBuilder
job_builder
(
job
);
Apply
(
op_graph
,
&
job_builder
);
...
...
@@ -26,7 +29,8 @@ class OpGraphPass {
COMMAND(RegisterFunctionPass(pass_name, new pass_type))
void
RegisterFunctionPass
(
const
std
::
string
&
pass_name
,
const
OpGraphPass
*
pass
);
const
OpGraphPass
&
FindFunctionPass
(
const
std
::
string
&
pass_name
);
bool
HasFunctionPass
(
const
std
::
string
&
pass_name
);
const
OpGraphPass
&
FunctionPass
(
const
std
::
string
&
pass_name
);
}
// namespace oneflow
...
...
oneflow/core/job_completer/set_default_variable_conf.cpp
浏览文件 @
ee85ce06
#include "oneflow/core/job_completer/
set_default_variable_conf
.h"
#include "oneflow/core/job_completer/
op_graph_pass
.h"
#include "oneflow/core/job/job_builder.h"
#include "oneflow/core/job/job_set_compile_ctx.h"
namespace
oneflow
{
void
SetDefaultVariableConf
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
{
auto
BlobDesc4ModelLbi
=
op_graph
.
MakeGetterBlobDesc4ModelLbi
();
op_graph
.
ForEachNode
([
&
](
OpNode
*
op_node
)
{
if
(
op_node
->
op
().
op_conf
().
has_variable_conf
())
{
OperatorConf
variable_op_conf
(
op_node
->
op
().
op_conf
());
VariableOpConf
*
variable_conf
=
variable_op_conf
.
mutable_variable_conf
();
if
(
!
variable_conf
->
has_data_type
())
{
variable_conf
->
set_data_type
(
job_builder
->
job
().
job_conf
().
default_data_type
());
}
if
(
!
variable_conf
->
has_initializer
()
&&
!
variable_conf
->
has_initialize_with_snapshot
())
{
if
(
job_builder
->
job
().
job_conf
().
has_default_initializer_conf
())
{
*
variable_conf
->
mutable_initializer
()
=
job_builder
->
job
().
job_conf
().
default_initializer_conf
();
}
else
if
(
job_builder
->
job
().
job_conf
().
has_default_initialize_with_snapshot_path
())
{
variable_conf
->
mutable_initialize_with_snapshot
()
->
set_path
(
job_builder
->
job
().
job_conf
().
default_initialize_with_snapshot_path
());
variable_conf
->
mutable_initialize_with_snapshot
()
->
set_key
(
GenLogicalBlobName
(
op_node
->
op
().
BnInOp2Lbi
(
"out"
)));
namespace
{
class
SetDefaultVariableConf
final
:
public
OpGraphPass
{
bool
IsEnabled
()
const
override
{
return
true
;
}
void
Apply
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
)
const
override
{
auto
BlobDesc4ModelLbi
=
op_graph
.
MakeGetterBlobDesc4ModelLbi
();
op_graph
.
ForEachNode
([
&
](
OpNode
*
op_node
)
{
if
(
op_node
->
op
().
op_conf
().
has_variable_conf
())
{
OperatorConf
variable_op_conf
(
op_node
->
op
().
op_conf
());
VariableOpConf
*
variable_conf
=
variable_op_conf
.
mutable_variable_conf
();
if
(
!
variable_conf
->
has_data_type
())
{
variable_conf
->
set_data_type
(
job_builder
->
job
().
job_conf
().
default_data_type
());
}
if
(
!
variable_conf
->
has_initializer
()
&&
!
variable_conf
->
has_initialize_with_snapshot
())
{
if
(
job_builder
->
job
().
job_conf
().
has_default_initializer_conf
())
{
*
variable_conf
->
mutable_initializer
()
=
job_builder
->
job
().
job_conf
().
default_initializer_conf
();
}
else
if
(
job_builder
->
job
().
job_conf
().
has_default_initialize_with_snapshot_path
())
{
variable_conf
->
mutable_initialize_with_snapshot
()
->
set_path
(
job_builder
->
job
().
job_conf
().
default_initialize_with_snapshot_path
());
variable_conf
->
mutable_initialize_with_snapshot
()
->
set_key
(
GenLogicalBlobName
(
op_node
->
op
().
BnInOp2Lbi
(
"out"
)));
}
else
{
UNIMPLEMENTED
();
}
}
int64_t
random_seed
;
auto
*
var_op_name2random
=
Global
<
JobSetCompileCtx
>::
Get
()
->
GetVarOpName2randomSeed
();
const
std
::
string
&
var_op_name
=
variable_op_conf
.
name
();
if
(
variable_conf
->
has_random_seed
())
{
random_seed
=
variable_conf
->
random_seed
();
}
else
{
UNIMPLEMENTED
();
random_seed
=
NewRandomSeed
();
}
const
auto
&
pair
=
var_op_name2random
->
insert
({
var_op_name
,
random_seed
});
if
(
variable_conf
->
has_random_seed
())
{
CHECK_EQ
(
variable_conf
->
random_seed
(),
pair
.
first
->
second
);
}
else
{
variable_conf
->
set_random_seed
(
pair
.
first
->
second
);
}
job_builder
->
AddOrMutOpsOnlyOnce
(
op_node
->
parallel_desc
().
parallel_conf
(),
{
variable_op_conf
});
}
int64_t
random_seed
;
auto
*
var_op_name2random
=
Global
<
JobSetCompileCtx
>::
Get
()
->
GetVarOpName2randomSeed
();
const
std
::
string
&
var_op_name
=
variable_op_conf
.
name
();
if
(
variable_conf
->
has_random_seed
())
{
random_seed
=
variable_conf
->
random_seed
();
}
else
{
random_seed
=
NewRandomSeed
();
}
const
auto
&
pair
=
var_op_name2random
->
insert
({
var_op_name
,
random_seed
});
if
(
variable_conf
->
has_random_seed
())
{
CHECK_EQ
(
variable_conf
->
random_seed
(),
pair
.
first
->
second
);
}
else
{
variable_conf
->
set_random_seed
(
pair
.
first
->
second
);
}
job_builder
->
AddOrMutOpsOnlyOnce
(
op_node
->
parallel_desc
().
parallel_conf
(),
{
variable_op_conf
});
}
});
}
});
}
};
REGISTER_FUNCTION_PASS
(
"SetDefaultVariableConf"
,
SetDefaultVariableConf
);
}
// namespace
}
// namespace oneflow
oneflow/core/job_completer/set_default_variable_conf.h
已删除
100644 → 0
浏览文件 @
c117d4ed
#ifndef ONEFLOW_CORE_JOB_COMPLETER_FILL_VARIABLE_CONF_H_
#define ONEFLOW_CORE_JOB_COMPLETER_FILL_VARIABLE_CONF_H_
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/graph/op_graph.h"
namespace
oneflow
{
void
SetDefaultVariableConf
(
const
OpGraph
&
op_graph
,
JobBuilder
*
job_builder
);
}
// namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_FILL_VARIABLE_CONF_H_
oneflow/core/job_completer/tie_up_chain_headers.cpp
浏览文件 @
ee85ce06
...
...
@@ -95,6 +95,7 @@ REGISTER_FUNCTION_CONFIG_DEF().Bool("enable_pseudo_chain_merge", false,
"ties up chain headers unreachable from any variable ops"
);
class
TieUpChainHeadersUnReachableFromAnyVariableOps
final
:
public
OpGraphPass
{
bool
IsEnabled
()
const
override
{
return
GlobalJobDesc
().
Bool
(
"enable_pseudo_chain_merge"
);
}
void
Apply
(
const
OpGraph
&
op_graph
,
Job
*
job
)
const
override
{
auto
IsReachableFromAnyVariableOps
=
MakePredicatorIsReachableFromAnyVariableOps
(
op_graph
);
auto
GetSourceNodesAndEdges
=
[
&
](
const
HashSet
<
OpNode
*>&
chain_nodes
,
...
...
oneflow/core/job_completer/user_job_completer.h
已删除
100644 → 0
浏览文件 @
c117d4ed
#ifndef ONEFLOW_CORE_JOB_COMPLETER_USER_JOB_COMPLETER_H_
#define ONEFLOW_CORE_JOB_COMPLETER_USER_JOB_COMPLETER_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/job_desc.h"
namespace
oneflow
{
class
UserJobCompleter
final
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
UserJobCompleter
);
UserJobCompleter
()
=
default
;
~
UserJobCompleter
()
=
default
;
void
Complete
(
Job
*
job
)
const
;
};
}
// namespace oneflow
#endif // ONEFLOW_CORE_JOB_COMPLETER_USER_JOB_COMPLETER_H_
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录