Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
beffe2d5
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 接近 3 年
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
beffe2d5
编写于
8月 17, 2017
作者:
W
willzhang4a58
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine job conf
上级
fa4085b1
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
56 addition
and
49 deletion
+56
-49
oneflow/core/graph/loss_accumulate_task_graph.cpp
oneflow/core/graph/loss_accumulate_task_graph.cpp
+1
-1
oneflow/core/graph/loss_record_task_graph.cpp
oneflow/core/graph/loss_record_task_graph.cpp
+2
-3
oneflow/core/graph/model_save_task_graph.cpp
oneflow/core/graph/model_save_task_graph.cpp
+2
-4
oneflow/core/graph/model_update_task_graph.cpp
oneflow/core/graph/model_update_task_graph.cpp
+1
-1
oneflow/core/job/dlnet_conf.proto
oneflow/core/job/dlnet_conf.proto
+0
-1
oneflow/core/job/job_conf.proto
oneflow/core/job/job_conf.proto
+25
-16
oneflow/core/job/job_desc.cpp
oneflow/core/job/job_desc.cpp
+1
-1
oneflow/core/job/job_desc.h
oneflow/core/job/job_desc.h
+9
-7
oneflow/core/job/keyword.cpp
oneflow/core/job/keyword.cpp
+4
-1
oneflow/core/job/keyword.h
oneflow/core/job/keyword.h
+1
-0
oneflow/core/job/parallel_desc.cpp
oneflow/core/job/parallel_desc.cpp
+2
-2
oneflow/core/job/placement.proto
oneflow/core/job/placement.proto
+1
-5
oneflow/core/operator/operator_manager.cpp
oneflow/core/operator/operator_manager.cpp
+7
-7
未找到文件。
oneflow/core/graph/loss_accumulate_task_graph.cpp
浏览文件 @
beffe2d5
...
...
@@ -20,7 +20,7 @@ void LossAccTaskGraph::BuildTaskGraph() {
// parallel_desc
ParallelConf
pr_conf
;
pr_conf
.
set_policy
(
kDataParallel
);
pr_conf
.
mutable_device_set
()
->
add_device_name
(
loss_task_
->
device_name
());
pr_conf
.
add_device_name
(
loss_task_
->
device_name
());
auto
pr_desc
=
std
::
make_shared
<
ParallelDesc
>
(
pr_conf
);
// faker chain
auto
chain_gph
=
of_make_unique
<
ChainGraph
>
();
...
...
oneflow/core/graph/loss_record_task_graph.cpp
浏览文件 @
beffe2d5
...
...
@@ -18,8 +18,7 @@ void LossRecordTaskGraph::BuildTaskGraph(
faker_pr_conf
.
set_policy
(
kFakerLossRecord
);
for
(
TaskNode
*
task
:
sorted_loss_acc_task
)
{
auto
loss_acc_task
=
static_cast
<
CompTaskNode
*>
(
task
);
faker_pr_conf
.
mutable_device_set
()
->
add_device_name
(
loss_acc_task
->
device_name
());
faker_pr_conf
.
add_device_name
(
loss_acc_task
->
device_name
());
sorted_loss_acc_tasks_
.
push_back
(
loss_acc_task
);
}
// faker chain
...
...
@@ -31,7 +30,7 @@ void LossRecordTaskGraph::BuildTaskGraph(
// loss_record_pr_conf
ParallelConf
loss_record_pr_conf
;
loss_record_pr_conf
.
set_policy
(
kDataParallel
);
loss_record_pr_conf
.
mutable_device_set
()
->
add_device_name
(
loss_record_pr_conf
.
add_device_name
(
IDMgr
::
Singleton
()
->
MachineName4MachineId
(
0
)
+
":persistence"
);
// loss record op
OperatorConf
op_conf
;
...
...
oneflow/core/graph/model_save_task_graph.cpp
浏览文件 @
beffe2d5
...
...
@@ -18,8 +18,7 @@ void MdSaveTaskGraph::BuildTaskGraph() {
ChainNode
*
faker_chain
=
chain_gph
->
NewNode
();
ParallelConf
faker_pr_conf
;
faker_pr_conf
.
set_policy
(
kDataParallel
);
faker_pr_conf
.
mutable_device_set
()
->
add_device_name
(
update_task_
->
device_name
());
faker_pr_conf
.
add_device_name
(
update_task_
->
device_name
());
faker_chain
->
mut_parallel_desc
().
reset
(
new
ParallelDesc
(
faker_pr_conf
));
faker_chain
->
mut_output_lbns
()
=
{
kPackedBlobName
};
// save
...
...
@@ -28,8 +27,7 @@ void MdSaveTaskGraph::BuildTaskGraph() {
GetMachineNameFromDeviceName
(
update_task_
->
device_name
());
ParallelConf
save_pr_conf
;
save_pr_conf
.
set_policy
(
kDataParallel
);
save_pr_conf
.
mutable_device_set
()
->
add_device_name
(
machine_name
+
":persistence"
);
save_pr_conf
.
add_device_name
(
machine_name
+
":persistence"
);
save_chain
->
mut_parallel_desc
().
reset
(
new
ParallelDesc
(
save_pr_conf
));
save_chain
->
mut_input_lbns
()
=
{
kPackedBlobName
};
//
...
...
oneflow/core/graph/model_update_task_graph.cpp
浏览文件 @
beffe2d5
...
...
@@ -19,7 +19,7 @@ void MdUpdtTaskGraph::BuildTaskGraph(uint32_t random_seed) {
ChainNode
*
updt_chain
=
chain_gph
->
NewNode
();
ParallelConf
updt_pr_conf
;
updt_pr_conf
.
set_policy
(
kDataParallel
);
updt_pr_conf
.
mutable_device_set
()
->
add_device_name
(
fw_task_
->
device_name
());
updt_pr_conf
.
add_device_name
(
fw_task_
->
device_name
());
updt_chain
->
mut_parallel_desc
().
reset
(
new
ParallelDesc
(
updt_pr_conf
));
updt_chain
->
mut_input_lbns
()
=
{
kPackedBlobName
};
updt_chain
->
mut_op_vec
()
=
{
OpMgr
::
Singleton
()
->
ModelUpdateOp
()};
...
...
oneflow/core/job/dlnet_conf.proto
浏览文件 @
beffe2d5
...
...
@@ -4,6 +4,5 @@ package oneflow;
import
"oneflow/core/operator/op_conf.proto"
;
message
DLNetConf
{
string
name
=
1
;
repeated
OperatorConf
op
=
100
;
}
oneflow/core/job/job_conf.proto
浏览文件 @
beffe2d5
...
...
@@ -8,25 +8,34 @@ enum FloatingPointTypeProto {
kDouble
=
1
;
}
message
JobConf
{
string
train_dlnet_conf_filepath
=
1
;
string
resource_filepath
=
2
;
string
placement_filepath
=
3
;
string
model_load_snapshot_path
=
4
;
string
model_save_snapshots_path
=
5
;
int32
piece_size
=
7
;
int32
num_of_pieces_in_batch
=
8
;
bool
is_train
=
9
;
FloatingPointTypeProto
floating_point_type
=
10
;
int32
num_of_batches_in_snapshot
=
11
;
int32
staleness
=
12
;
// at least 0. If set as 0, then it's BSP
int64
total_batch_num
=
13
;
FillConf
default_fill_conf
=
14
;
bool
use_async_cpu_stream
=
15
;
int32
piece_num_of_record_loss
=
16
;
message
TrainConf
{
string
model_save_snapshots_path
=
1
;
int32
num_of_batches_in_snapshot
=
2
;
int32
staleness
=
3
;
// at least 0. If set as 0, then it's BSP
int64
total_batch_num
=
4
;
FillConf
default_fill_conf
=
5
;
int32
piece_num_of_record_loss
=
6
;
oneof
ModelUpdateCase
{
NormalModelUpdateOpConf
normal_mdupdt_conf
=
1000
;
MomentumModelUpdateOpConf
momentum_mdupdt_conf
=
1001
;
RMSPropModelUpdateOpConf
rmsprop_mdupdt_conf
=
1002
;
}
}
message
PredictConf
{
}
message
JobConf
{
string
dlnet_filepath
=
1
;
string
resource_filepath
=
2
;
string
placement_filepath
=
3
;
string
model_load_snapshot_path
=
4
;
int32
piece_size
=
5
;
int32
num_of_pieces_in_batch
=
6
;
FloatingPointTypeProto
floating_point_type
=
7
;
bool
use_async_cpu_stream
=
8
;
oneof
JobType
{
TrainConf
train_conf
=
1000
;
PredictConf
predict_conf
=
1001
;
}
}
oneflow/core/job/job_desc.cpp
浏览文件 @
beffe2d5
...
...
@@ -6,7 +6,7 @@ namespace oneflow {
void
JobDesc
::
InitFromJobConf
(
const
JobConf
&
conf
)
{
LOG
(
INFO
)
<<
"Read JobConf"
;
job_conf_
=
conf
;
ParseProtoFromTextFile
(
conf
.
train_dlnet_conf
_filepath
(),
&
train_dlnet_conf_
);
ParseProtoFromTextFile
(
conf
.
dlnet
_filepath
(),
&
train_dlnet_conf_
);
ParseProtoFromTextFile
(
conf
.
resource_filepath
(),
&
resource_
);
ParseProtoFromTextFile
(
conf
.
placement_filepath
(),
&
placement_
);
}
...
...
oneflow/core/job/job_desc.h
浏览文件 @
beffe2d5
...
...
@@ -27,14 +27,14 @@ class JobDesc final {
return
job_conf_
.
model_load_snapshot_path
();
}
const
std
::
string
&
md_save_snapshots_path
()
{
return
job_conf_
.
model_save_snapshots_path
();
return
job_conf_
.
train_conf
().
model_save_snapshots_path
();
}
int32_t
piece_size
()
const
{
return
job_conf_
.
piece_size
();
}
int32_t
num_of_pieces_in_batch
()
const
{
return
job_conf_
.
num_of_pieces_in_batch
();
}
int32_t
batch_size
()
const
{
return
piece_size
()
*
num_of_pieces_in_batch
();
}
bool
is_train
()
const
{
return
job_conf_
.
is_train
();
}
bool
is_train
()
const
{
return
job_conf_
.
has_train_conf
();
}
FloatingPointTypeProto
floating_point_type
()
const
{
return
job_conf_
.
floating_point_type
();
}
...
...
@@ -48,19 +48,21 @@ class JobDesc final {
}
}
int32_t
num_of_batches_in_snapshot
()
const
{
return
job_conf_
.
num_of_batches_in_snapshot
();
return
job_conf_
.
train_conf
().
num_of_batches_in_snapshot
();
}
int32_t
staleness
()
const
{
return
job_conf_
.
train_conf
().
staleness
();
}
int64_t
total_batch_num
()
const
{
return
job_conf_
.
train_conf
().
total_batch_num
();
}
int32_t
staleness
()
const
{
return
job_conf_
.
staleness
();
}
int64_t
total_batch_num
()
const
{
return
job_conf_
.
total_batch_num
();
}
int64_t
total_piece_num
()
const
{
return
total_batch_num
()
*
num_of_pieces_in_batch
();
}
const
FillConf
*
default_fill_conf
()
const
{
return
OF_PB_POINTER_GET
(
job_conf_
,
default_fill_conf
);
return
OF_PB_POINTER_GET
(
job_conf_
.
train_conf
()
,
default_fill_conf
);
}
bool
use_async_cpu_stream
()
const
{
return
job_conf_
.
use_async_cpu_stream
();
}
int32_t
piece_num_of_record_loss
()
const
{
return
job_conf_
.
piece_num_of_record_loss
();
return
job_conf_
.
train_conf
().
piece_num_of_record_loss
();
}
private:
...
...
oneflow/core/job/keyword.cpp
浏览文件 @
beffe2d5
...
...
@@ -2,6 +2,9 @@
namespace
oneflow
{
const
char
*
kPackedBlobName
=
"_oneflow_PackedBlobName"
;
#define ONEFLOW_INTERNAL_PREFIX "OneFlowInternal"
const
char
*
kPackedBlobName
=
ONEFLOW_INTERNAL_PREFIX
"PackedBlobName"
;
const
char
*
kNullDataId
=
ONEFLOW_INTERNAL_PREFIX
"NullDataId"
;
}
// namespace oneflow
oneflow/core/job/keyword.h
浏览文件 @
beffe2d5
...
...
@@ -4,6 +4,7 @@
namespace
oneflow
{
extern
const
char
*
kPackedBlobName
;
extern
const
char
*
kNullDataId
;
}
// namespace oneflow
...
...
oneflow/core/job/parallel_desc.cpp
浏览文件 @
beffe2d5
...
...
@@ -13,8 +13,8 @@ std::pair<std::string, std::string> ParseDeviceNameConf(
ParallelDesc
::
ParallelDesc
(
const
ParallelConf
&
user_conf
)
{
policy_
=
user_conf
.
policy
();
device_type_
=
JobDesc
::
Singleton
()
->
resource
().
device_type
();
for
(
int64_t
i
=
0
;
i
<
user_conf
.
device_
set
().
device_
name_size
();
++
i
)
{
const
std
::
string
&
device_name
=
user_conf
.
device_
set
().
device_
name
(
i
);
for
(
int64_t
i
=
0
;
i
<
user_conf
.
device_name_size
();
++
i
)
{
const
std
::
string
&
device_name
=
user_conf
.
device_name
(
i
);
std
::
pair
<
std
::
string
,
std
::
string
>
machine_name_device_id
=
ParseDeviceNameConf
(
device_name
);
std
::
string
machine_name
=
machine_name_device_id
.
first
;
...
...
oneflow/core/job/placement.proto
浏览文件 @
beffe2d5
...
...
@@ -8,13 +8,9 @@ enum ParallelPolicy {
kFakerLossRecord
=
3
;
}
message
DeviceSet
{
repeated
string
device_name
=
1
;
}
message
ParallelConf
{
ParallelPolicy
policy
=
1
;
DeviceSet
device_set
=
2
;
repeated
string
device_name
=
2
;
}
message
OpNameSet
{
...
...
oneflow/core/operator/operator_manager.cpp
浏览文件 @
beffe2d5
...
...
@@ -34,16 +34,16 @@ std::shared_ptr<const Operator> OpMgr::ModelUpdateOp() {
if
(
!
model_update_op_
)
{
OperatorConf
mdupdt_conf
;
mdupdt_conf
.
set_name
(
"model_update"
);
const
JobConf
&
job_conf
=
JobDesc
::
Singleton
()
->
job
_conf
();
if
(
job
_conf
.
has_normal_mdupdt_conf
())
{
const
TrainConf
&
train_conf
=
JobDesc
::
Singleton
()
->
job_conf
().
train
_conf
();
if
(
train
_conf
.
has_normal_mdupdt_conf
())
{
*
(
mdupdt_conf
.
mutable_normal_mdupdt_conf
())
=
job
_conf
.
normal_mdupdt_conf
();
}
else
if
(
job
_conf
.
has_momentum_mdupdt_conf
())
{
train
_conf
.
normal_mdupdt_conf
();
}
else
if
(
train
_conf
.
has_momentum_mdupdt_conf
())
{
*
(
mdupdt_conf
.
mutable_momentum_mdupdt_conf
())
=
job
_conf
.
momentum_mdupdt_conf
();
}
else
if
(
job
_conf
.
has_rmsprop_mdupdt_conf
())
{
train
_conf
.
momentum_mdupdt_conf
();
}
else
if
(
train
_conf
.
has_rmsprop_mdupdt_conf
())
{
*
(
mdupdt_conf
.
mutable_rmsprop_mdupdt_conf
())
=
job
_conf
.
rmsprop_mdupdt_conf
();
train
_conf
.
rmsprop_mdupdt_conf
();
}
else
{
UNEXPECTED_RUN
();
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录