Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
79133eae
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
79133eae
编写于
8月 19, 2019
作者:
R
rensilin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
clean code
Change-Id: I402094ec45a96482f0d57d1dbe37ac909e01246c
上级
d21f279a
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
51 addition
and
42 deletion
+51
-42
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
...luid/train/custom_trainer/feed/accessor/epoch_accessor.cc
+25
-18
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
...fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
+1
-0
paddle/fluid/train/custom_trainer/feed/conf/gflags.conf
paddle/fluid/train/custom_trainer/feed/conf/gflags.conf
+1
-0
paddle/fluid/train/custom_trainer/feed/conf/trainer.yaml
paddle/fluid/train/custom_trainer/feed/conf/trainer.yaml
+2
-2
paddle/fluid/train/custom_trainer/feed/dataset/dataset.cc
paddle/fluid/train/custom_trainer/feed/dataset/dataset.cc
+2
-2
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc
...id/train/custom_trainer/feed/dataset/dataset_container.cc
+3
-4
paddle/fluid/train/custom_trainer/feed/executor/executor.cc
paddle/fluid/train/custom_trainer/feed/executor/executor.cc
+2
-2
paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc
...le/fluid/train/custom_trainer/feed/io/auto_file_system.cc
+4
-4
paddle/fluid/train/custom_trainer/feed/io/hadoop_file_system.cc
.../fluid/train/custom_trainer/feed/io/hadoop_file_system.cc
+2
-2
paddle/fluid/train/custom_trainer/feed/main.cc
paddle/fluid/train/custom_trainer/feed/main.cc
+1
-0
paddle/fluid/train/custom_trainer/feed/process/init_env_process.cc
...uid/train/custom_trainer/feed/process/init_env_process.cc
+1
-1
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
...luid/train/custom_trainer/feed/process/learner_process.cc
+1
-1
paddle/fluid/train/custom_trainer/feed/scripts/start_feed_trainer.sh
...d/train/custom_trainer/feed/scripts/start_feed_trainer.sh
+2
-2
paddle/fluid/train/custom_trainer/feed/temp/feed_trainer.cpp
paddle/fluid/train/custom_trainer/feed/temp/feed_trainer.cpp
+1
-1
paddle/fluid/train/custom_trainer/feed/unit_test/test_create_programs.cc
...ain/custom_trainer/feed/unit_test/test_create_programs.cc
+3
-3
未找到文件。
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
浏览文件 @
79133eae
...
...
@@ -7,20 +7,25 @@ namespace custom_trainer {
namespace
feed
{
int
EpochAccessor
::
initialize
(
YAML
::
Node
config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
{
_model_root_path
=
config
[
"model_root_path"
].
as
<
std
::
string
>
()
+
"/"
;
_model_root_path
=
config
[
"model_root_path"
].
as
<
std
::
string
>
();
_trainer_context
=
context_ptr
.
get
();
if
(
context_ptr
->
file_system
==
nullptr
)
{
VLOG
(
0
)
<<
"file_system is not initialized"
;
return
-
1
;
}
_done_file_path
=
_model_root_path
;
if
(
config
[
"donefile"
])
{
_done_file_path
.
append
(
config
[
"donefile"
].
as
<
std
::
string
>
());
_done_file_path
=
_trainer_context
->
file_system
->
path_join
(
_model_root_path
,
config
[
"donefile"
].
as
<
std
::
string
>
());
}
else
{
_done_file_path
.
append
(
"epoch_donefile.txt"
);
_done_file_path
=
_trainer_context
->
file_system
->
path_join
(
_model_root_path
,
"epoch_donefile.txt"
);
}
if
(
!
context_ptr
->
file_system
->
exists
(
_done_file_path
))
{
if
(
!
_trainer_context
->
file_system
->
exists
(
_done_file_path
))
{
VLOG
(
0
)
<<
"missing done file, path:"
<<
_done_file_path
;
}
std
::
string
done_text
=
context_ptr
->
file_system
->
tail
(
_done_file_path
);
std
::
string
done_text
=
_trainer_context
->
file_system
->
tail
(
_done_file_path
);
_done_status
=
paddle
::
string
::
split_string
(
done_text
,
std
::
string
(
"
\t
"
));
_current_epoch_id
=
get_status
<
uint64_t
>
(
EpochStatusFiled
::
EpochIdField
);
_last_checkpoint_epoch_id
=
get_status
<
uint64_t
>
(
EpochStatusFiled
::
CheckpointIdField
);
...
...
@@ -67,23 +72,25 @@ namespace feed {
if
(
epoch_id
==
0
)
{
return
false
;
}
if
(
save_way
==
ModelSaveWay
::
ModelSaveInferenceDelta
)
{
switch
(
save_way
)
{
case
ModelSaveWay
::
ModelSaveInferenceDelta
:
return
true
;
}
else
if
(
save_way
==
ModelSaveWay
::
ModelSaveInferenceBase
)
{
case
ModelSaveWay
::
ModelSaveInferenceBase
:
return
is_last_epoch
(
epoch_id
);
}
else
if
(
save_way
==
ModelSaveWay
::
ModelSaveTrainCheckpoint
)
{
case
ModelSaveWay
::
ModelSaveTrainCheckpoint
:
return
((
epoch_id
/
3600
)
%
8
)
==
0
;
}
return
false
;
}
std
::
string
HourlyEpochAccessor
::
model_save_path
(
uint64_t
epoch_id
,
ModelSaveWay
save_way
)
{
if
(
save_way
==
ModelSaveWay
::
ModelSaveInferenceDelta
)
{
return
_model_root_path
+
"/xbox/delta-"
+
std
::
to_string
(
epoch_id
);
}
else
if
(
save_way
==
ModelSaveWay
::
ModelSaveInferenceBase
)
{
return
_model_root_path
+
"/xbox/base"
;
}
else
if
(
save_way
==
ModelSaveWay
::
ModelSaveTrainCheckpoint
)
{
return
_model_root_path
+
"/xbox/checkpoint"
;
switch
(
save_way
)
{
case
ModelSaveWay
::
ModelSaveInferenceDelta
:
return
_trainer_context
->
file_system
->
path_join
(
_model_root_path
,
"/xbox/delta-"
+
std
::
to_string
(
epoch_id
));
case
ModelSaveWay
::
ModelSaveInferenceBase
:
return
_trainer_context
->
file_system
->
path_join
(
_model_root_path
,
"/xbox/base"
);
case
ModelSaveWay
::
ModelSaveTrainCheckpoint
:
return
_trainer_context
->
file_system
->
path_join
(
_model_root_path
,
"/xbox/checkpoint"
);
}
return
""
;
}
...
...
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
浏览文件 @
79133eae
...
...
@@ -52,6 +52,7 @@ public:
virtual
bool
need_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
save_way
)
=
0
;
virtual
std
::
string
model_save_path
(
uint64_t
epoch_id
,
ModelSaveWay
save_way
)
=
0
;
protected:
TrainerContext
*
_trainer_context
;
std
::
string
_done_file_path
;
std
::
string
_model_root_path
;
uint64_t
_current_epoch_id
=
0
;
...
...
paddle/fluid/train/custom_trainer/feed/conf/gflags.conf
浏览文件 @
79133eae
-
log_dir
=
log
-
v
=
10
paddle/fluid/train/custom_trainer/feed/conf/trainer.yaml
浏览文件 @
79133eae
...
...
@@ -11,14 +11,14 @@ io :
ugis
:
'
default'
:
'
feed_video,D3a0z8'
'
xingtian.afs.baidu.com:9902'
:
'
feed_video,D3a0z8'
local
:
default
:
class
:
LocalFileSystem
buffer_size
:
1024000
dataset
:
data_list
:
train_sample
:
prefetch_num
:
2
root_path
:
./sample
root_path
:
[
./sample
]
data_spit_interval
:
300
data_path_formater
:
'
%Y%m%d/%H%M'
data_reader
:
LineDataReader
...
...
paddle/fluid/train/custom_trainer/feed/dataset/dataset.cc
浏览文件 @
79133eae
...
...
@@ -7,14 +7,14 @@ namespace feed {
int
Dataset
::
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
)
{
if
(
config
[
"data_list"
].
Type
()
!=
YAML
::
NodeType
::
Map
)
{
VLOG
(
0
)
<<
"miss data_list config in dataset, or type error please check"
;
LOG
(
FATAL
)
<<
"miss data_list config in dataset, or type error please check"
;
return
-
1
;
}
for
(
auto
&
data_config
:
config
[
"data_list"
])
{
std
::
string
name
=
data_config
.
first
.
as
<
std
::
string
>
();
auto
data_ptr
=
std
::
make_shared
<
DatasetContainer
>
();
if
(
data_ptr
->
initialize
(
data_config
.
second
,
context
)
!=
0
)
{
VLOG
(
0
)
<<
"dataset initialize failed, name:"
<<
name
;
LOG
(
FATAL
)
<<
"dataset initialize failed, name:"
<<
name
;
return
-
1
;
}
_data_containers
[
name
]
=
data_ptr
;
...
...
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc
浏览文件 @
79133eae
...
...
@@ -6,10 +6,10 @@
#include <vector>
#include <memory>
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h"
namespace
paddle
{
...
...
@@ -27,8 +27,7 @@ int DatasetContainer::initialize(
_dataset_list
[
i
].
reset
(
new
DatasetInfo
);
}
_data_root_paths
=
paddle
::
string
::
split_string
(
config
[
"root_path"
].
as
<
std
::
string
>
(),
" "
);
_data_root_paths
=
config
[
"root_path"
].
as
<
std
::
vector
<
std
::
string
>>
();
_data_split_interval
=
config
[
"data_spit_interval"
].
as
<
int
>
();
_data_path_formater
=
config
[
"data_path_formater"
].
as
<
std
::
string
>
();
std
::
string
data_reader_class
=
config
[
"data_reader"
].
as
<
std
::
string
>
();
...
...
@@ -66,7 +65,7 @@ void DatasetContainer::pre_detect_data(uint64_t epoch_id) {
for
(
int
i
=
0
;
i
<
_data_root_paths
.
size
()
&&
status
==
0
;
++
i
)
{
for
(
int
j
=
0
;
j
<
data_num
&&
status
==
0
;
++
j
)
{
std
::
string
path_suffix
=
format_timestamp
(
data_timestamp
+
j
*
_data_split_interval
,
_data_path_formater
);
std
::
string
data_dir
=
_
data_root_paths
[
i
]
+
"/"
+
path_suffix
;
std
::
string
data_dir
=
_
trainer_context
->
file_system
->
path_join
(
_data_root_paths
[
i
],
path_suffix
)
;
status
=
read_data_list
(
data_dir
,
data_path_list
);
}
}
...
...
paddle/fluid/train/custom_trainer/feed/executor/executor.cc
浏览文件 @
79133eae
...
...
@@ -17,7 +17,7 @@ namespace {
int
ReadBinaryFile
(
const
std
::
string
&
filename
,
std
::
string
*
contents
)
{
std
::
ifstream
fin
(
filename
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
if
(
!
fin
)
{
VLOG
(
2
)
<<
"Cannot open file "
<<
filename
;
LOG
(
FATAL
)
<<
"Cannot open file "
<<
filename
;
return
-
1
;
}
fin
.
seekg
(
0
,
std
::
ios
::
end
);
...
...
@@ -31,7 +31,7 @@ int ReadBinaryFile(const std::string& filename, std::string* contents) {
std
::
unique_ptr
<
paddle
::
framework
::
ProgramDesc
>
Load
(
paddle
::
framework
::
Executor
*
/*executor*/
,
const
std
::
string
&
model_filename
)
{
VLOG
(
3
)
<<
"loading model from "
<<
model_filename
;
LOG
(
INFO
)
<<
"loading model from "
<<
model_filename
;
std
::
string
program_desc_str
;
if
(
ReadBinaryFile
(
model_filename
,
&
program_desc_str
)
!=
0
)
{
return
nullptr
;
...
...
paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc
浏览文件 @
79133eae
...
...
@@ -19,17 +19,18 @@ public:
for
(
auto
&
prefix_fs
:
config
[
"file_systems"
])
{
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_CLASS
(
FileSystem
,
prefix_fs
.
second
[
"class"
].
as
<
std
::
string
>
(
""
)));
if
(
fs
==
nullptr
)
{
VLOG
(
2
)
<<
"fail to create class: "
<<
prefix_fs
.
second
[
"class"
].
as
<
std
::
string
>
(
""
);
LOG
(
FATAL
)
<<
"fail to create class: "
<<
prefix_fs
.
second
[
"class"
].
as
<
std
::
string
>
(
""
);
return
-
1
;
}
if
(
fs
->
initialize
(
prefix_fs
.
second
,
context
)
!=
0
)
{
VLOG
(
2
)
<<
"fail to initialize class: "
<<
prefix_fs
.
second
[
"class"
].
as
<
std
::
string
>
(
""
);
return
0
;
LOG
(
FATAL
)
<<
"fail to initialize class: "
<<
prefix_fs
.
second
[
"class"
].
as
<
std
::
string
>
(
""
);
return
-
1
;
}
_file_system
.
emplace
(
prefix_fs
.
first
.
as
<
std
::
string
>
(
""
),
std
::
move
(
fs
));
}
}
if
(
_file_system
.
find
(
"default"
)
==
_file_system
.
end
())
{
LOG
(
WARNING
)
<<
"miss default file_system, use LocalFileSystem as default"
;
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_CLASS
(
FileSystem
,
"LocalFileSystem"
));
if
(
fs
==
nullptr
||
fs
->
initialize
(
YAML
::
Load
(
""
),
context
)
!=
0
)
{
return
-
1
;
...
...
@@ -82,7 +83,6 @@ public:
return
fs_it
->
second
.
get
();
}
}
VLOG
(
5
)
<<
"path: "
<<
path
<<
", select default file system"
;
return
_file_system
[
"default"
].
get
();
}
...
...
paddle/fluid/train/custom_trainer/feed/io/hadoop_file_system.cc
浏览文件 @
79133eae
...
...
@@ -25,7 +25,7 @@ public:
}
}
if
(
_ugi
.
find
(
"default"
)
==
_ugi
.
end
())
{
VLOG
(
2
)
<<
"fail to load default ugi"
;
LOG
(
FATAL
)
<<
"fail to load default ugi"
;
return
-
1
;
}
return
0
;
...
...
@@ -62,7 +62,7 @@ public:
int64_t
file_size
(
const
std
::
string
&
path
)
override
{
_err_no
=
-
1
;
VLOG
(
2
)
<<
"not support"
;
LOG
(
FATAL
)
<<
"not support"
;
return
0
;
}
...
...
paddle/fluid/train/custom_trainer/feed/main.cc
浏览文件 @
79133eae
...
...
@@ -13,6 +13,7 @@ using namespace paddle::custom_trainer::feed;
DEFINE_string
(
feed_trainer_conf_path
,
"./conf/trainer.yaml"
,
"path of trainer conf"
);
int
main
(
int
argc
,
char
*
argv
[])
{
google
::
InitGoogleLogging
(
argv
[
0
]);
//gflags
google
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
std
::
string
gflag_conf
=
"./conf/gflags.conf"
;
...
...
paddle/fluid/train/custom_trainer/feed/process/init_env_process.cc
浏览文件 @
79133eae
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
浏览文件 @
79133eae
...
...
@@ -76,7 +76,7 @@ int LearnerProcess::run() {
uint64_t
epoch_id
=
epoch_accessor
->
current_epoch_id
();
environment
->
log
(
EnvironmentRole
::
WORKER
,
EnvironmentLogType
::
MASTER_LOG
,
EnvironmentLogLevel
::
NOTICE
,
"Resume traine with epoch_id:%d label:%s"
,
epoch_id
,
_context_ptr
->
epoch_accessor
->
text
(
epoch_id
).
c_str
());
"Resume traine
r
with epoch_id:%d label:%s"
,
epoch_id
,
_context_ptr
->
epoch_accessor
->
text
(
epoch_id
).
c_str
());
//判断是否先dump出base
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceBase
);
...
...
paddle/fluid/train/custom_trainer/feed/scripts/start_feed_trainer.sh
100644 → 100755
浏览文件 @
79133eae
#!bash
#!
/bin/
bash
export
LD_LIBRARY_PATH
=
LD_LIBRARY_PATH:./so
./bin/feed_trainer
./bin/feed_trainer
"
$@
"
paddle/fluid/train/custom_trainer/feed/temp/feed_trainer.cpp
浏览文件 @
79133eae
...
...
@@ -40,7 +40,7 @@ void ReadBinaryFile(const std::string& filename, std::string* contents) {
std
::
unique_ptr
<
paddle
::
framework
::
ProgramDesc
>
Load
(
paddle
::
framework
::
Executor
*
executor
,
const
std
::
string
&
model_filename
)
{
VLOG
(
3
)
<<
"loading model from "
<<
model_filename
;
LOG
(
DEBUG
)
<<
"loading model from "
<<
model_filename
;
std
::
string
program_desc_str
;
ReadBinaryFile
(
model_filename
,
&
program_desc_str
);
...
...
paddle/fluid/train/custom_trainer/feed/unit_test/test_create_programs.cc
浏览文件 @
79133eae
...
...
@@ -133,9 +133,9 @@ TEST_F(CreateProgramsTest, example_network) {
auto
output_var
=
executor
->
var
<::
paddle
::
framework
::
LoDTensor
>
(
output_name
);
auto
output
=
output_var
.
data
<
float
>
()[
0
];
VLOG
(
3
)
<<
"loss: "
<<
loss
<<
std
::
endl
;
VLOG
(
3
)
<<
"label: "
<<
label_data
[
0
]
<<
std
::
endl
;
VLOG
(
3
)
<<
"output: "
<<
output
<<
std
::
endl
;
LOG
(
INFO
)
<<
"loss: "
<<
loss
<<
std
::
endl
;
LOG
(
INFO
)
<<
"label: "
<<
label_data
[
0
]
<<
std
::
endl
;
LOG
(
INFO
)
<<
"output: "
<<
output
<<
std
::
endl
;
ASSERT_NEAR
(
loss
,
pow
(
output
-
label_data
[
0
],
2
),
1e-8
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录