Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
68d7bf3d
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
68d7bf3d
编写于
3月 20, 2019
作者:
D
dongdaxiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add fetch var function
test=develop
上级
767bf0c8
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
42 addition
and
27 deletion
+42
-27
paddle/fluid/framework/device_worker.h
paddle/fluid/framework/device_worker.h
+5
-6
paddle/fluid/framework/downpour_worker.cc
paddle/fluid/framework/downpour_worker.cc
+3
-7
paddle/fluid/framework/hogwild_worker.cc
paddle/fluid/framework/hogwild_worker.cc
+11
-9
paddle/fluid/framework/trainer_desc.proto
paddle/fluid/framework/trainer_desc.proto
+10
-3
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+7
-2
python/paddle/fluid/trainer_desc.py
python/paddle/fluid/trainer_desc.py
+6
-0
未找到文件。
paddle/fluid/framework/device_worker.h
浏览文件 @
68d7bf3d
...
@@ -96,7 +96,7 @@ class DeviceWorker {
...
@@ -96,7 +96,7 @@ class DeviceWorker {
virtual
void
Initialize
(
const
TrainerDesc
&
desc
)
=
0
;
virtual
void
Initialize
(
const
TrainerDesc
&
desc
)
=
0
;
virtual
void
SetDeviceIndex
(
int
tid
)
=
0
;
virtual
void
SetDeviceIndex
(
int
tid
)
=
0
;
virtual
void
TrainFiles
()
=
0
;
virtual
void
TrainFiles
()
=
0
;
virtual
void
PrintFetchVars
(
int
batch_cnt
)
=
0
;
virtual
void
PrintFetchVars
()
=
0
;
virtual
void
TrainFilesWithProfiler
()
=
0
;
virtual
void
TrainFilesWithProfiler
()
=
0
;
virtual
void
CreateDeviceResource
(
const
ProgramDesc
&
main_prog
)
=
0
;
virtual
void
CreateDeviceResource
(
const
ProgramDesc
&
main_prog
)
=
0
;
// will make this zero copy in the future
// will make this zero copy in the future
...
@@ -111,6 +111,8 @@ class DeviceWorker {
...
@@ -111,6 +111,8 @@ class DeviceWorker {
Scope
*
root_scope_
;
Scope
*
root_scope_
;
paddle
::
platform
::
Place
place_
;
paddle
::
platform
::
Place
place_
;
std
::
shared_ptr
<
DataFeed
>
device_reader_
;
std
::
shared_ptr
<
DataFeed
>
device_reader_
;
int64_t
batch_num_
;
FetchConfig
fetch_config_
;
};
};
class
CPUWorkerBase
:
public
DeviceWorker
{
class
CPUWorkerBase
:
public
DeviceWorker
{
...
@@ -120,7 +122,7 @@ class CPUWorkerBase : public DeviceWorker {
...
@@ -120,7 +122,7 @@ class CPUWorkerBase : public DeviceWorker {
virtual
void
SetDeviceIndex
(
int
tid
)
{
thread_id_
=
tid
;
}
virtual
void
SetDeviceIndex
(
int
tid
)
{
thread_id_
=
tid
;
}
virtual
void
TrainFiles
()
=
0
;
virtual
void
TrainFiles
()
=
0
;
virtual
void
TrainFilesWithProfiler
()
{}
virtual
void
TrainFilesWithProfiler
()
{}
virtual
void
PrintFetchVars
(
int
batch_cnt
)
{}
virtual
void
PrintFetchVars
()
{}
virtual
void
CreateDeviceResource
(
const
ProgramDesc
&
main_prog
)
{}
virtual
void
CreateDeviceResource
(
const
ProgramDesc
&
main_prog
)
{}
protected:
protected:
...
@@ -134,7 +136,7 @@ class HogwildWorker : public CPUWorkerBase {
...
@@ -134,7 +136,7 @@ class HogwildWorker : public CPUWorkerBase {
virtual
void
Initialize
(
const
TrainerDesc
&
desc
);
virtual
void
Initialize
(
const
TrainerDesc
&
desc
);
virtual
void
TrainFiles
();
virtual
void
TrainFiles
();
virtual
void
TrainFilesWithProfiler
();
virtual
void
TrainFilesWithProfiler
();
virtual
void
PrintFetchVars
(
int
batch_cnt
);
virtual
void
PrintFetchVars
();
virtual
void
CreateDeviceResource
(
const
ProgramDesc
&
main_prog
);
virtual
void
CreateDeviceResource
(
const
ProgramDesc
&
main_prog
);
virtual
void
BindingDataFeedMemory
();
virtual
void
BindingDataFeedMemory
();
...
@@ -144,9 +146,6 @@ class HogwildWorker : public CPUWorkerBase {
...
@@ -144,9 +146,6 @@ class HogwildWorker : public CPUWorkerBase {
std
::
vector
<
std
::
string
>
op_names_
;
std
::
vector
<
std
::
string
>
op_names_
;
std
::
vector
<
OperatorBase
*>
ops_
;
std
::
vector
<
OperatorBase
*>
ops_
;
Scope
*
thread_scope_
;
Scope
*
thread_scope_
;
std
::
vector
<
std
::
string
>
fetch_var_names_
;
std
::
vector
<
std
::
vector
<
float
>>
fetch_values_
;
int
batch_cnt_per_print_
;
};
};
class
DownpourWorker
:
public
HogwildWorker
{
class
DownpourWorker
:
public
HogwildWorker
{
...
...
paddle/fluid/framework/downpour_worker.cc
浏览文件 @
68d7bf3d
...
@@ -58,14 +58,8 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
...
@@ -58,14 +58,8 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
skip_ops_
[
i
]
=
param_
.
skip_ops
(
i
);
skip_ops_
[
i
]
=
param_
.
skip_ops
(
i
);
}
}
fetch_var_names_
.
resize
(
desc
.
fetch_var_names_size
());
for
(
size_t
i
=
0
;
i
<
desc
.
fetch_var_names_size
();
++
i
)
{
fetch_var_names_
[
i
]
=
desc
.
fetch_var_names
(
i
);
}
batch_cnt_per_print_
=
static_cast
<
int
>
(
desc
.
batch_per_print
());
skip_ops_
.
resize
(
param_
.
skip_ops_size
());
fleet_ptr_
=
FleetWrapper
::
GetInstance
();
fleet_ptr_
=
FleetWrapper
::
GetInstance
();
fetch_config_
=
desc
.
fetch_config
();
}
}
void
DownpourWorker
::
CollectLabelInfo
(
size_t
table_idx
)
{
void
DownpourWorker
::
CollectLabelInfo
(
size_t
table_idx
)
{
...
@@ -334,6 +328,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
...
@@ -334,6 +328,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
}
}
}
}
timeline
.
Start
();
timeline
.
Start
();
PrintFetchVars
();
}
}
}
}
...
@@ -445,6 +440,7 @@ void DownpourWorker::TrainFiles() {
...
@@ -445,6 +440,7 @@ void DownpourWorker::TrainFiles() {
thread_scope_
->
DropKids
();
thread_scope_
->
DropKids
();
++
batch_cnt
;
++
batch_cnt
;
PrintFetchVars
();
}
}
}
}
...
...
paddle/fluid/framework/hogwild_worker.cc
浏览文件 @
68d7bf3d
...
@@ -21,11 +21,7 @@ namespace paddle {
...
@@ -21,11 +21,7 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
void
HogwildWorker
::
Initialize
(
const
TrainerDesc
&
desc
)
{
void
HogwildWorker
::
Initialize
(
const
TrainerDesc
&
desc
)
{
fetch_var_names_
.
resize
(
desc
.
fetch_var_names_size
());
fetch_config_
=
desc
.
fetch_config
();
for
(
size_t
i
=
0
;
i
<
desc
.
fetch_var_names_size
();
++
i
)
{
fetch_var_names_
[
i
]
=
desc
.
fetch_var_names
(
i
);
}
batch_cnt_per_print_
=
static_cast
<
int
>
(
desc
.
batch_per_print
());
}
}
void
HogwildWorker
::
CreateThreadOperators
(
const
ProgramDesc
&
program
)
{
void
HogwildWorker
::
CreateThreadOperators
(
const
ProgramDesc
&
program
)
{
...
@@ -119,6 +115,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
...
@@ -119,6 +115,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
}
}
}
}
timeline
.
Start
();
timeline
.
Start
();
PrintFetchVars
();
}
}
}
}
...
@@ -136,15 +133,20 @@ void HogwildWorker::TrainFiles() {
...
@@ -136,15 +133,20 @@ void HogwildWorker::TrainFiles() {
++
batch_cnt
;
++
batch_cnt
;
thread_scope_
->
DropKids
();
thread_scope_
->
DropKids
();
PrintFetchVars
();
}
}
}
}
void
HogwildWorker
::
PrintFetchVars
(
int
batch_cnt
)
{
void
HogwildWorker
::
PrintFetchVars
()
{
// call count
batch_num_
++
;
int
batch_per_print
=
fetch_config_
.
print_period
();
if
(
thread_id_
==
0
)
{
if
(
thread_id_
==
0
)
{
if
(
batch_
cnt
>
0
&&
batch_cnt
%
batch_cnt_per_print_
==
0
)
{
if
(
batch_
num_
%
batch_per_print
==
0
)
{
int
fetch_var_num
=
fetch_
var_names_
.
size
();
int
fetch_var_num
=
fetch_
config_
.
fetch_var_names_
size
();
for
(
int
i
=
0
;
i
<
fetch_var_num
;
++
i
)
{
for
(
int
i
=
0
;
i
<
fetch_var_num
;
++
i
)
{
platform
::
PrintVar
(
thread_scope_
,
fetch_var_names_
[
i
],
"None"
);
platform
::
PrintVar
(
thread_scope_
,
fetch_config_
.
fetch_var_names
(
i
),
"None"
);
}
}
}
}
}
}
...
...
paddle/fluid/framework/trainer_desc.proto
浏览文件 @
68d7bf3d
...
@@ -28,9 +28,8 @@ message TrainerDesc {
...
@@ -28,9 +28,8 @@ message TrainerDesc {
// if we need to binding cpu
// if we need to binding cpu
optional
bool
binding_cpu
=
4
[
default
=
false
];
optional
bool
binding_cpu
=
4
[
default
=
false
];
repeated
string
filelist
=
5
;
repeated
string
filelist
=
5
;
repeated
string
fetch_var_names
=
6
;
optional
bool
debug
=
6
[
default
=
false
];
optional
int32
batch_per_print
=
7
[
default
=
100
];
optional
FetchConfig
fetch_config
=
7
;
optional
bool
debug
=
8
[
default
=
false
];
// device worker parameters
// device worker parameters
optional
HogwildWorkerParameter
hogwild_param
=
101
;
optional
HogwildWorkerParameter
hogwild_param
=
101
;
...
@@ -49,6 +48,14 @@ message DownpourWorkerParameter {
...
@@ -49,6 +48,14 @@ message DownpourWorkerParameter {
repeated
ProgramConfig
program_config
=
4
;
repeated
ProgramConfig
program_config
=
4
;
}
}
message
FetchConfig
{
enum
Method
{
PRINT
=
0
;
}
repeated
string
fetch_var_names
=
1
;
optional
string
fetch_var_str_format
=
2
;
optional
int32
print_period
=
3
[
default
=
100
];
optional
Method
method
=
4
[
default
=
PRINT
];
}
message
ProgramConfig
{
message
ProgramConfig
{
required
string
program_id
=
1
;
required
string
program_id
=
1
;
repeated
int32
push_sparse_table_id
=
2
;
repeated
int32
push_sparse_table_id
=
2
;
...
...
python/paddle/fluid/executor.py
浏览文件 @
68d7bf3d
...
@@ -621,13 +621,17 @@ class Executor(object):
...
@@ -621,13 +621,17 @@ class Executor(object):
opt_info
=
None
):
opt_info
=
None
):
pass
pass
fluid
.
Logger
(
"Loss: {0}"
,
loss
)
def
train_from_dataset
(
self
,
def
train_from_dataset
(
self
,
program
=
None
,
program
=
None
,
dataset
=
None
,
dataset
=
None
,
fetch_list
=
None
,
scope
=
None
,
scope
=
None
,
thread
=
0
,
thread
=
0
,
debug
=
False
):
debug
=
False
,
fetch_list
=
None
,
fetch_info
=
None
,
print_period
=
100
):
if
scope
is
None
:
if
scope
is
None
:
scope
=
global_scope
()
scope
=
global_scope
()
if
fetch_list
is
None
:
if
fetch_list
is
None
:
...
@@ -650,6 +654,7 @@ class Executor(object):
...
@@ -650,6 +654,7 @@ class Executor(object):
else
:
else
:
trainer
.
set_thread
(
thread
)
trainer
.
set_thread
(
thread
)
trainer
.
set_debug
(
debug
)
trainer
.
set_debug
(
debug
)
trainer
.
set_fetch_var_and_info
(
fetch_list
,
fetch_info
,
print_period
)
trainer
.
gen_trainer_desc
()
trainer
.
gen_trainer_desc
()
dataset
.
_prepare_to_run
()
dataset
.
_prepare_to_run
()
if
debug
:
if
debug
:
...
...
python/paddle/fluid/trainer_desc.py
浏览文件 @
68d7bf3d
...
@@ -36,6 +36,12 @@ class TrainerDesc(object):
...
@@ -36,6 +36,12 @@ class TrainerDesc(object):
self
.
device_worker_
=
None
self
.
device_worker_
=
None
self
.
program_
=
None
self
.
program_
=
None
def
set_fetch_var_and_info
(
self
,
fetch_vars
,
fetch_info
,
print_period
):
for
v
in
fetch_vars
:
self
.
proto_desc
.
fetch_config
.
fetch_var_names
.
extend
(
v
.
name
)
self
.
proto_desc
.
fetch_config
.
fetch_var_str_format
=
fetch_info
self
.
proto_desc
.
print_period
=
print_period
def
set_debug
(
self
,
debug
):
def
set_debug
(
self
,
debug
):
self
.
proto_desc
.
debug
=
debug
self
.
proto_desc
.
debug
=
debug
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录