Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
68d7bf3d
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
68d7bf3d
编写于
3月 20, 2019
作者:
D
dongdaxiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add fetch var function
test=develop
上级
767bf0c8
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
42 addition
and
27 deletion
+42
-27
paddle/fluid/framework/device_worker.h
paddle/fluid/framework/device_worker.h
+5
-6
paddle/fluid/framework/downpour_worker.cc
paddle/fluid/framework/downpour_worker.cc
+3
-7
paddle/fluid/framework/hogwild_worker.cc
paddle/fluid/framework/hogwild_worker.cc
+11
-9
paddle/fluid/framework/trainer_desc.proto
paddle/fluid/framework/trainer_desc.proto
+10
-3
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+7
-2
python/paddle/fluid/trainer_desc.py
python/paddle/fluid/trainer_desc.py
+6
-0
未找到文件。
paddle/fluid/framework/device_worker.h
浏览文件 @
68d7bf3d
...
...
@@ -96,7 +96,7 @@ class DeviceWorker {
virtual
void
Initialize
(
const
TrainerDesc
&
desc
)
=
0
;
virtual
void
SetDeviceIndex
(
int
tid
)
=
0
;
virtual
void
TrainFiles
()
=
0
;
virtual
void
PrintFetchVars
(
int
batch_cnt
)
=
0
;
virtual
void
PrintFetchVars
()
=
0
;
virtual
void
TrainFilesWithProfiler
()
=
0
;
virtual
void
CreateDeviceResource
(
const
ProgramDesc
&
main_prog
)
=
0
;
// will make this zero copy in the future
...
...
@@ -111,6 +111,8 @@ class DeviceWorker {
Scope
*
root_scope_
;
paddle
::
platform
::
Place
place_
;
std
::
shared_ptr
<
DataFeed
>
device_reader_
;
int64_t
batch_num_
;
FetchConfig
fetch_config_
;
};
class
CPUWorkerBase
:
public
DeviceWorker
{
...
...
@@ -120,7 +122,7 @@ class CPUWorkerBase : public DeviceWorker {
virtual
void
SetDeviceIndex
(
int
tid
)
{
thread_id_
=
tid
;
}
virtual
void
TrainFiles
()
=
0
;
virtual
void
TrainFilesWithProfiler
()
{}
virtual
void
PrintFetchVars
(
int
batch_cnt
)
{}
virtual
void
PrintFetchVars
()
{}
virtual
void
CreateDeviceResource
(
const
ProgramDesc
&
main_prog
)
{}
protected:
...
...
@@ -134,7 +136,7 @@ class HogwildWorker : public CPUWorkerBase {
virtual
void
Initialize
(
const
TrainerDesc
&
desc
);
virtual
void
TrainFiles
();
virtual
void
TrainFilesWithProfiler
();
virtual
void
PrintFetchVars
(
int
batch_cnt
);
virtual
void
PrintFetchVars
();
virtual
void
CreateDeviceResource
(
const
ProgramDesc
&
main_prog
);
virtual
void
BindingDataFeedMemory
();
...
...
@@ -144,9 +146,6 @@ class HogwildWorker : public CPUWorkerBase {
std
::
vector
<
std
::
string
>
op_names_
;
std
::
vector
<
OperatorBase
*>
ops_
;
Scope
*
thread_scope_
;
std
::
vector
<
std
::
string
>
fetch_var_names_
;
std
::
vector
<
std
::
vector
<
float
>>
fetch_values_
;
int
batch_cnt_per_print_
;
};
class
DownpourWorker
:
public
HogwildWorker
{
...
...
paddle/fluid/framework/downpour_worker.cc
浏览文件 @
68d7bf3d
...
...
@@ -58,14 +58,8 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
skip_ops_
[
i
]
=
param_
.
skip_ops
(
i
);
}
fetch_var_names_
.
resize
(
desc
.
fetch_var_names_size
());
for
(
size_t
i
=
0
;
i
<
desc
.
fetch_var_names_size
();
++
i
)
{
fetch_var_names_
[
i
]
=
desc
.
fetch_var_names
(
i
);
}
batch_cnt_per_print_
=
static_cast
<
int
>
(
desc
.
batch_per_print
());
skip_ops_
.
resize
(
param_
.
skip_ops_size
());
fleet_ptr_
=
FleetWrapper
::
GetInstance
();
fetch_config_
=
desc
.
fetch_config
();
}
void
DownpourWorker
::
CollectLabelInfo
(
size_t
table_idx
)
{
...
...
@@ -334,6 +328,7 @@ void DownpourWorker::TrainFilesWithProfiler() {
}
}
timeline
.
Start
();
PrintFetchVars
();
}
}
...
...
@@ -445,6 +440,7 @@ void DownpourWorker::TrainFiles() {
thread_scope_
->
DropKids
();
++
batch_cnt
;
PrintFetchVars
();
}
}
...
...
paddle/fluid/framework/hogwild_worker.cc
浏览文件 @
68d7bf3d
...
...
@@ -21,11 +21,7 @@ namespace paddle {
namespace
framework
{
void
HogwildWorker
::
Initialize
(
const
TrainerDesc
&
desc
)
{
fetch_var_names_
.
resize
(
desc
.
fetch_var_names_size
());
for
(
size_t
i
=
0
;
i
<
desc
.
fetch_var_names_size
();
++
i
)
{
fetch_var_names_
[
i
]
=
desc
.
fetch_var_names
(
i
);
}
batch_cnt_per_print_
=
static_cast
<
int
>
(
desc
.
batch_per_print
());
fetch_config_
=
desc
.
fetch_config
();
}
void
HogwildWorker
::
CreateThreadOperators
(
const
ProgramDesc
&
program
)
{
...
...
@@ -119,6 +115,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
}
}
timeline
.
Start
();
PrintFetchVars
();
}
}
...
...
@@ -136,15 +133,20 @@ void HogwildWorker::TrainFiles() {
++
batch_cnt
;
thread_scope_
->
DropKids
();
PrintFetchVars
();
}
}
void
HogwildWorker
::
PrintFetchVars
(
int
batch_cnt
)
{
void
HogwildWorker
::
PrintFetchVars
()
{
// call count
batch_num_
++
;
int
batch_per_print
=
fetch_config_
.
print_period
();
if
(
thread_id_
==
0
)
{
if
(
batch_
cnt
>
0
&&
batch_cnt
%
batch_cnt_per_print_
==
0
)
{
int
fetch_var_num
=
fetch_
var_names_
.
size
();
if
(
batch_
num_
%
batch_per_print
==
0
)
{
int
fetch_var_num
=
fetch_
config_
.
fetch_var_names_
size
();
for
(
int
i
=
0
;
i
<
fetch_var_num
;
++
i
)
{
platform
::
PrintVar
(
thread_scope_
,
fetch_var_names_
[
i
],
"None"
);
platform
::
PrintVar
(
thread_scope_
,
fetch_config_
.
fetch_var_names
(
i
),
"None"
);
}
}
}
...
...
paddle/fluid/framework/trainer_desc.proto
浏览文件 @
68d7bf3d
...
...
@@ -28,9 +28,8 @@ message TrainerDesc {
// if we need to binding cpu
optional
bool
binding_cpu
=
4
[
default
=
false
];
repeated
string
filelist
=
5
;
repeated
string
fetch_var_names
=
6
;
optional
int32
batch_per_print
=
7
[
default
=
100
];
optional
bool
debug
=
8
[
default
=
false
];
optional
bool
debug
=
6
[
default
=
false
];
optional
FetchConfig
fetch_config
=
7
;
// device worker parameters
optional
HogwildWorkerParameter
hogwild_param
=
101
;
...
...
@@ -49,6 +48,14 @@ message DownpourWorkerParameter {
repeated
ProgramConfig
program_config
=
4
;
}
message
FetchConfig
{
enum
Method
{
PRINT
=
0
;
}
repeated
string
fetch_var_names
=
1
;
optional
string
fetch_var_str_format
=
2
;
optional
int32
print_period
=
3
[
default
=
100
];
optional
Method
method
=
4
[
default
=
PRINT
];
}
message
ProgramConfig
{
required
string
program_id
=
1
;
repeated
int32
push_sparse_table_id
=
2
;
...
...
python/paddle/fluid/executor.py
浏览文件 @
68d7bf3d
...
...
@@ -621,13 +621,17 @@ class Executor(object):
opt_info
=
None
):
pass
fluid
.
Logger
(
"Loss: {0}"
,
loss
)
def
train_from_dataset
(
self
,
program
=
None
,
dataset
=
None
,
fetch_list
=
None
,
scope
=
None
,
thread
=
0
,
debug
=
False
):
debug
=
False
,
fetch_list
=
None
,
fetch_info
=
None
,
print_period
=
100
):
if
scope
is
None
:
scope
=
global_scope
()
if
fetch_list
is
None
:
...
...
@@ -650,6 +654,7 @@ class Executor(object):
else
:
trainer
.
set_thread
(
thread
)
trainer
.
set_debug
(
debug
)
trainer
.
set_fetch_var_and_info
(
fetch_list
,
fetch_info
,
print_period
)
trainer
.
gen_trainer_desc
()
dataset
.
_prepare_to_run
()
if
debug
:
...
...
python/paddle/fluid/trainer_desc.py
浏览文件 @
68d7bf3d
...
...
@@ -36,6 +36,12 @@ class TrainerDesc(object):
self
.
device_worker_
=
None
self
.
program_
=
None
def
set_fetch_var_and_info
(
self
,
fetch_vars
,
fetch_info
,
print_period
):
for
v
in
fetch_vars
:
self
.
proto_desc
.
fetch_config
.
fetch_var_names
.
extend
(
v
.
name
)
self
.
proto_desc
.
fetch_config
.
fetch_var_str_format
=
fetch_info
self
.
proto_desc
.
print_period
=
print_period
def
set_debug
(
self
,
debug
):
self
.
proto_desc
.
debug
=
debug
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录