Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
aaea8a39
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
aaea8a39
编写于
9月 09, 2019
作者:
X
xiexionghang
浏览文件
操作
浏览文件
下载
差异文件
merge master
上级
6efc30f5
5beaf463
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
84 addition
and
11 deletion
+84
-11
paddle/fluid/train/custom_trainer/feed/accessor/dense_input_accessor.cc
...rain/custom_trainer/feed/accessor/dense_input_accessor.cc
+20
-5
paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h
.../train/custom_trainer/feed/accessor/input_data_accessor.h
+11
-0
paddle/fluid/train/custom_trainer/feed/conf/trainer.yaml
paddle/fluid/train/custom_trainer/feed/conf/trainer.yaml
+2
-0
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc
...id/train/custom_trainer/feed/dataset/dataset_container.cc
+1
-1
paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc
...ain/custom_trainer/feed/executor/multi_thread_executor.cc
+28
-1
paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.h
...rain/custom_trainer/feed/executor/multi_thread_executor.h
+3
-0
paddle/fluid/train/custom_trainer/feed/io/file_system.h
paddle/fluid/train/custom_trainer/feed/io/file_system.h
+4
-0
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
...luid/train/custom_trainer/feed/process/learner_process.cc
+5
-1
paddle/fluid/train/custom_trainer/feed/scripts/create_programs.py
...luid/train/custom_trainer/feed/scripts/create_programs.py
+8
-0
paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc
paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc
+2
-3
未找到文件。
paddle/fluid/train/custom_trainer/feed/accessor/dense_input_accessor.cc
浏览文件 @
aaea8a39
...
...
@@ -69,6 +69,15 @@ int32_t DenseInputAccessor::pull_dense(size_t table_id) {
int32_t
DenseInputAccessor
::
forward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
)
{
collect_persistables
(
scope
);
if
(
_need_async_pull
)
{
++
_pull_request_num
;
}
return
0
;
}
int32_t
DenseInputAccessor
::
collect_persistables
(
paddle
::
framework
::
Scope
*
scope
)
{
// 首次同步pull,之后异步pull
if
(
_data_buffer
==
nullptr
)
{
_pull_mutex
.
lock
();
...
...
@@ -94,7 +103,9 @@ int32_t DenseInputAccessor::forward(SampleInstance* samples, size_t num,
paddle
::
framework
::
DDim
ddim
(
shape_ptr
,
variable
.
shape
.
size
());
auto
*
tensor
=
ScopeHelper
::
resize_lod_tensor
(
scope
,
variable
.
name
,
ddim
);
auto
*
grad_tensor
=
ScopeHelper
::
resize_lod_tensor
(
scope
,
variable
.
gradient_name
,
ddim
);
VLOG
(
5
)
<<
"fill scope variable:"
<<
variable
.
name
<<
", "
<<
variable
.
gradient_name
;
VLOG
(
5
)
<<
"fill scope variable:"
<<
variable
.
name
<<
", "
<<
variable
.
gradient_name
<<
", data_buffer: "
<<
_data_buffer
+
data_buffer_idx
<<
", dim: "
<<
variable
.
dim
*
sizeof
(
float
);
auto
*
var_data
=
tensor
->
mutable_data
<
float
>
(
_trainer_context
->
cpu_place
);
memcpy
(
var_data
,
_data_buffer
+
data_buffer_idx
,
variable
.
dim
*
sizeof
(
float
));
data_buffer_idx
+=
variable
.
dim
;
...
...
@@ -107,8 +118,12 @@ int32_t DenseInputAccessor::forward(SampleInstance* samples, size_t num,
VLOG
(
2
)
<<
"[Debug][PullDense]"
<<
ScopeHelper
::
to_string
(
scope
,
variable
.
name
);
}
}
if
(
_need_async_pull
)
{
++
_pull_request_num
;
return
0
;
}
int32_t
DenseInputAccessor
::
collect_persistables_name
(
std
::
vector
<
std
::
string
>&
persistables
)
{
for
(
auto
&
variable
:
_x_variables
)
{
persistables
.
push_back
(
variable
.
name
);
}
return
0
;
}
...
...
paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h
浏览文件 @
aaea8a39
...
...
@@ -38,6 +38,12 @@ public:
// 后向,一般用于更新梯度,在训练网络执行后调用
virtual
int32_t
backward
(
SampleInstance
*
samples
,
size_t
num
,
::
paddle
::
framework
::
Scope
*
scope
)
=
0
;
// 收集持久化变量的名称, 并将值拷贝到Scope
virtual
int32_t
collect_persistables_name
(
std
::
vector
<
std
::
string
>&
persistables
)
{
return
0
;}
// 填充持久化变量的值,用于保存
virtual
int32_t
collect_persistables
(
paddle
::
framework
::
Scope
*
scope
)
{
return
0
;}
protected:
size_t
_table_id
=
0
;
bool
_need_gradient
=
false
;
...
...
@@ -144,6 +150,11 @@ public:
virtual
int32_t
backward
(
SampleInstance
*
samples
,
size_t
num
,
paddle
::
framework
::
Scope
*
scope
);
virtual
int32_t
collect_persistables_name
(
std
::
vector
<
std
::
string
>&
persistables
);
virtual
int32_t
collect_persistables
(
paddle
::
framework
::
Scope
*
scope
);
protected:
virtual
int32_t
pull_dense
(
size_t
table_id
);
...
...
paddle/fluid/train/custom_trainer/feed/conf/trainer.yaml
浏览文件 @
aaea8a39
...
...
@@ -30,6 +30,8 @@ dataset:
pipeline_cmd
:
'
./tool/ins_weight.py
|
awk
-f
./tool/format_newcate_hotnews.awk'
parser
:
class
:
AbacusTextDataParser
shuffler
:
class
:
LocalShuffler
epoch
:
epoch_class
:
TimelyEpochAccessor
...
...
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc
浏览文件 @
aaea8a39
...
...
@@ -31,7 +31,7 @@ int DatasetContainer::initialize(
_data_root_paths
=
config
[
"root_path"
].
as
<
std
::
vector
<
std
::
string
>>
();
_data_split_interval
=
config
[
"data_spit_interval"
].
as
<
int
>
();
_data_path_formater
=
config
[
"data_path_formater"
].
as
<
std
::
string
>
();
std
::
string
shuffler
=
config
[
"shuffler"
][
"
name
"
].
as
<
std
::
string
>
();
std
::
string
shuffler
=
config
[
"shuffler"
][
"
class
"
].
as
<
std
::
string
>
();
_shuffler
.
reset
(
CREATE_INSTANCE
(
Shuffler
,
shuffler
));
_shuffler
->
initialize
(
config
,
context
);
std
::
string
data_reader_class
=
config
[
"data_reader"
].
as
<
std
::
string
>
();
...
...
paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc
浏览文件 @
aaea8a39
...
...
@@ -2,6 +2,8 @@
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/monitor/monitor.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
namespace
paddle
{
namespace
custom_trainer
{
...
...
@@ -55,6 +57,7 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config,
CHECK
(
_trainer_context
->
file_system
->
exists
(
model_config_path
))
<<
"miss model config file:"
<<
model_config_path
;
_model_config
=
YAML
::
LoadFile
(
model_config_path
);
_persistables
.
clear
();
for
(
const
auto
&
accessor_config
:
_model_config
[
"input_accessor"
])
{
auto
accessor_class
=
accessor_config
[
"class"
].
as
<
std
::
string
>
();
auto
*
accessor_ptr
=
CREATE_INSTANCE
(
DataInputAccessor
,
accessor_class
);
...
...
@@ -69,7 +72,10 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config,
_table_to_accessors
[
table_id
]
=
{
accessor_ptr
};
}
}
CHECK
(
accessor_ptr
->
collect_persistables_name
(
_persistables
)
==
0
)
<<
"collect_persistables Failed, class:"
<<
accessor_class
;
}
// std::sort(_persistables.begin(), _persistables.end()); // 持久化变量名一定要排序
// Monitor组件
for
(
const
auto
&
monitor_config
:
_model_config
[
"monitor"
])
{
...
...
@@ -82,6 +88,27 @@ int MultiThreadExecutor::initialize(YAML::Node exe_config,
return
ret
;
}
int32_t
MultiThreadExecutor
::
save_persistables
(
const
std
::
string
&
filename
)
{
// auto fs = _trainer_context->file_system;
// fs->mkdir(fs->path_split(filename).first);
auto
scope_obj
=
_scope_obj_pool
->
get
();
for
(
size_t
i
=
0
;
i
<
_input_accessors
.
size
();
++
i
)
{
_input_accessors
[
i
]
->
collect_persistables
(
scope_obj
.
get
());
}
framework
::
ProgramDesc
prog
;
auto
*
block
=
prog
.
MutableBlock
(
0
);
auto
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"save_combine"
);
op
->
SetInput
(
"X"
,
_persistables
);
op
->
SetAttr
(
"file_path"
,
filename
);
op
->
CheckAttrs
();
platform
::
CPUPlace
place
;
framework
::
Executor
exe
(
place
);
exe
.
Run
(
prog
,
scope_obj
.
get
(),
0
,
true
,
true
);
return
0
;
}
paddle
::
framework
::
Channel
<
DataItem
>
MultiThreadExecutor
::
run
(
paddle
::
framework
::
Channel
<
DataItem
>
input
,
const
DataParser
*
parser
)
{
...
...
paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.h
浏览文件 @
aaea8a39
...
...
@@ -47,6 +47,8 @@ public:
virtual
paddle
::
framework
::
Channel
<
DataItem
>
run
(
paddle
::
framework
::
Channel
<
DataItem
>
input
,
const
DataParser
*
parser
);
virtual
int32_t
save_persistables
(
const
std
::
string
&
filename
);
virtual
bool
is_dump_all_model
()
{
return
_need_dump_all_model
;
}
...
...
@@ -80,6 +82,7 @@ protected:
std
::
vector
<
std
::
shared_ptr
<
DataInputAccessor
>>
_input_accessors
;
std
::
map
<
uint32_t
,
std
::
vector
<
DataInputAccessor
*>>
_table_to_accessors
;
std
::
shared_ptr
<
paddle
::
ps
::
ObjectPool
<::
paddle
::
framework
::
Scope
>>
_scope_obj_pool
;
std
::
vector
<
std
::
string
>
_persistables
;
};
}
// namespace feed
...
...
paddle/fluid/train/custom_trainer/feed/io/file_system.h
浏览文件 @
aaea8a39
...
...
@@ -25,6 +25,10 @@ public:
virtual
bool
exists
(
const
std
::
string
&
path
)
=
0
;
virtual
void
mkdir
(
const
std
::
string
&
path
)
=
0
;
virtual
std
::
string
path_join
(
const
std
::
string
&
dir
,
const
std
::
string
&
path
);
template
<
class
...
STRS
>
std
::
string
path_join
(
const
std
::
string
&
dir
,
const
std
::
string
&
path
,
const
STRS
&
...
paths
)
{
return
path_join
(
path_join
(
dir
,
path
),
paths
...);
}
virtual
std
::
pair
<
std
::
string
,
std
::
string
>
path_split
(
const
std
::
string
&
path
);
protected:
};
...
...
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
浏览文件 @
aaea8a39
...
...
@@ -27,6 +27,7 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
}
int
LearnerProcess
::
wait_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
)
{
auto
fs
=
_context_ptr
->
file_system
;
auto
*
ps_client
=
_context_ptr
->
pslib
->
ps_client
();
auto
*
environment
=
_context_ptr
->
environment
.
get
();
auto
*
epoch_accessor
=
_context_ptr
->
epoch_accessor
.
get
();
...
...
@@ -39,18 +40,21 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
paddle
::
platform
::
Timer
timer
;
timer
.
Start
();
std
::
set
<
uint32_t
>
table_set
;
auto
model_dir
=
epoch_accessor
->
model_save_path
(
epoch_id
,
way
);
for
(
auto
&
executor
:
_executors
)
{
const
auto
&
table_accessors
=
executor
->
table_accessors
();
for
(
auto
&
itr
:
table_accessors
)
{
table_set
.
insert
(
itr
.
first
);
}
auto
save_path
=
fs
->
path_join
(
model_dir
,
executor
->
train_exe_name
()
+
"_param"
);
VLOG
(
2
)
<<
"Start save model, save_path:"
<<
save_path
;
executor
->
save_persistables
(
save_path
);
}
int
ret_size
=
0
;
auto
table_num
=
table_set
.
size
();
std
::
future
<
int
>
rets
[
table_num
];
for
(
auto
table_id
:
table_set
)
{
VLOG
(
2
)
<<
"Start save model, table_id:"
<<
table_id
;
auto
model_dir
=
epoch_accessor
->
model_save_path
(
epoch_id
,
way
);
rets
[
ret_size
++
]
=
ps_client
->
save
(
table_id
,
model_dir
,
std
::
to_string
((
int
)
way
));
}
int
all_ret
=
0
;
...
...
paddle/fluid/train/custom_trainer/feed/scripts/create_programs.py
浏览文件 @
aaea8a39
...
...
@@ -124,6 +124,14 @@ class ModelBuilder:
with
open
(
os
.
path
.
join
(
self
.
_save_path
,
name
+
'.pbtxt'
),
'w'
)
as
fout
:
fout
.
write
(
str
(
program
))
fluid
.
io
.
save_inference_model
(
self
.
_save_path
,
[
var
.
name
for
var
in
inputs
],
outputs
,
executor
=
None
,
main_program
=
test_program
,
model_filename
=
'inference_program'
,
program_only
=
True
)
params
=
filter
(
fluid
.
io
.
is_parameter
,
main_program
.
list_vars
())
vars
=
[]
sums
=
[]
...
...
paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.cc
浏览文件 @
aaea8a39
#pragma once
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.h"
...
...
@@ -30,7 +29,7 @@ public:
return
0
;
}
};
REGIST_CLASS
(
DataPars
er
,
LocalShuffler
);
REGIST_CLASS
(
Shuffl
er
,
LocalShuffler
);
class
GlobalShuffler
:
public
Shuffler
{
public:
...
...
@@ -109,7 +108,7 @@ private:
uint32_t
_max_concurrent_num
=
0
;
};
REGIST_CLASS
(
DataPars
er
,
GlobalShuffler
);
REGIST_CLASS
(
Shuffl
er
,
GlobalShuffler
);
}
// namespace feed
}
// namespace custom_trainer
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录