Commit 825e947c
Authored Aug 20, 2019 by rensilin
dump params
Change-Id: I073a955ac9aa13e4afdfb869121b83f95062cdac
Parent: 79133eae
Showing 7 changed files with 16 additions and 14 deletions (+16 / -14)
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc     +5 -5
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc         +3 -2
paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc         +1 -1
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc     +2 -2
paddle/fluid/train/custom_trainer/feed/scripts/create_programs.py     +2 -1
paddle/fluid/train/custom_trainer/feed/temp/feed_trainer.cpp          +1 -1
paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc   +2 -2
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc

@@ -14,18 +14,18 @@ namespace feed {
             VLOG(0) << "file_system is not initialized";
             return -1;
         }
+        auto fs = _trainer_context->file_system.get();
         if (config["donefile"]) {
-            _done_file_path = _trainer_context->file_system->path_join(_model_root_path, config["donefile"].as<std::string>());
+            _done_file_path = fs->path_join(_model_root_path, config["donefile"].as<std::string>());
         } else {
-            _done_file_path = _trainer_context->file_system->path_join(_model_root_path, "epoch_donefile.txt");
+            _done_file_path = fs->path_join(_model_root_path, "epoch_donefile.txt");
         }
-        if (!_trainer_context->file_system->exists(_done_file_path)) {
+        if (!fs->exists(_done_file_path)) {
             VLOG(0) << "missing done file, path:" << _done_file_path;
         }
-        std::string done_text = _trainer_context->file_system->tail(_done_file_path);
+        std::string done_text = fs->tail(_done_file_path);
         _done_status = paddle::string::split_string(done_text, std::string("\t"));
         _current_epoch_id = get_status<uint64_t>(EpochStatusFiled::EpochIdField);
         _last_checkpoint_epoch_id = get_status<uint64_t>(EpochStatusFiled::CheckpointIdField);
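The functional core of this hunk is the new local fs, which caches the raw pointer owned by the file_system smart-pointer member so the repeated _trainer_context->file_system-> chains shorten to fs->. A minimal standalone sketch of that pattern, using hypothetical FileSystem and TrainerContext stand-ins rather than the repository's real types:

#include <iostream>
#include <memory>
#include <string>

// Hypothetical stand-ins for the repository's FileSystem / TrainerContext.
struct FileSystem {
    std::string path_join(const std::string& a, const std::string& b) const {
        return a + "/" + b;
    }
};

struct TrainerContext {
    std::unique_ptr<FileSystem> file_system = std::make_unique<FileSystem>();
};

int main() {
    TrainerContext context;
    // .get() hands back a non-owning raw pointer; the unique_ptr keeps
    // ownership, so it must outlive every use of fs.
    auto* fs = context.file_system.get();
    std::cout << fs->path_join("/model_root", "epoch_donefile.txt") << "\n";
    return 0;
}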
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc

@@ -534,7 +534,7 @@ public:
         size_t buffer_size = 0;
         ssize_t line_len = 0;
         while ((line_len = getline(&buffer, &buffer_size, fin.get())) != -1) {
-            // 去掉行位回车
+            // 去掉行尾回车
            if (line_len > 0 && buffer[line_len - 1] == '\n') {
                buffer[--line_len] = '\0';
            }
@@ -547,7 +547,8 @@ public:
            VLOG(5) << "parse data: " << data_item.id << " " << data_item.data << ", filename: " << filepath << ", thread_num: " << thread_num << ", max_threads: " << max_threads;
            if (writer == nullptr) {
                if (!data_channel->Put(std::move(data_item))) {
-                    VLOG(2) << "fail to put data, thread_num: " << thread_num;
+                    LOG(WARNING) << "fail to put data, thread_num: " << thread_num;
+                    is_failed = true;
                }
            } else {
                (*writer) << std::move(data_item);
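The second hunk upgrades the channel-put failure message from VLOG(2) to LOG(WARNING) and records the failure in is_failed. With the glog-style macros Paddle uses, VLOG output stays suppressed unless the verbosity flag is raised, while WARNING is emitted unconditionally; the feed_trainer.cpp hunk further down makes the reverse trade for a purely informational message. A small self-contained sketch of the difference, assuming glog is available:

#include <glog/logging.h>

int main(int argc, char* argv[]) {
    google::InitGoogleLogging(argv[0]);
    FLAGS_logtostderr = true;

    // Verbose logging: only emitted when the verbosity level is >= 2,
    // e.g. run with --v=2 or set FLAGS_v = 2. Silent by default.
    VLOG(2) << "fail to put data (easy to miss)";

    // Severity logging: always emitted at WARNING level.
    LOG(WARNING) << "fail to put data (always visible)";
    return 0;
}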
paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc

@@ -77,7 +77,7 @@ public:
    FileSystem* get_file_system(const std::string& path) {
        auto pos = path.find_first_of(":");
        if (pos != std::string::npos) {
-            auto substr = path.substr(0, pos + 1);
+            auto substr = path.substr(0, pos); // example: afs:/xxx -> afs
            auto fs_it = _file_system.find(substr);
            if (fs_it != _file_system.end()) {
                return fs_it->second.get();
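The one-character fix here changes the lookup key from substr(0, pos + 1) to substr(0, pos), so a path such as afs:/xxx now resolves to the registered scheme "afs" rather than "afs:", as the added comment notes; the test_datareader.cc hunk at the end of the commit drops the trailing ':' from the 'afs' and 'hdfs' keys in its YAML config to match. A minimal standalone sketch of the key extraction, with a hypothetical scheme_of helper standing in for get_file_system:

#include <iostream>
#include <string>

// Hypothetical helper mirroring the prefix extraction in get_file_system().
std::string scheme_of(const std::string& path) {
    auto pos = path.find_first_of(":");
    if (pos != std::string::npos) {
        return path.substr(0, pos);  // example: afs:/xxx -> afs
    }
    return "";  // no scheme prefix: would fall back to a default file system
}

int main() {
    std::cout << scheme_of("afs:/user/feed/data") << "\n";  // prints "afs"
    std::cout << scheme_of("hdfs:/some/path") << "\n";      // prints "hdfs"
    std::cout << scheme_of("/local/path") << "\n";          // prints an empty line
    return 0;
}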
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc

@@ -76,7 +76,7 @@ int LearnerProcess::run() {
    uint64_t epoch_id = epoch_accessor->current_epoch_id();
    environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
-        "Resume trainer with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
+        "Resume training with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
    //判断是否先dump出base
    wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
@@ -108,7 +108,7 @@ int LearnerProcess::run() {
    for (int thread_id = 0; thread_id < _train_thread_num; ++thread_id) {
        train_threads[i].reset(new std::thread([this](int exe_idx, int thread_idx) {
            auto* executor = _threads_executor[thread_idx][exe_idx].get();
            run_executor(executor);
        }, i, thread_id));
    }
    for (int i = 0; i < _train_thread_num; ++i) {
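The second hunk's loop starts one std::thread per training thread, handing the executor index and thread id to the lambda as thread arguments rather than capturing the loop variables, so each worker sees stable copies. A minimal standalone sketch of that launch pattern, with a hypothetical worker body and an assumed thread count:

#include <cstdio>
#include <memory>
#include <thread>
#include <vector>

int main() {
    const int train_thread_num = 4;  // assumed value for illustration
    std::vector<std::unique_ptr<std::thread>> train_threads(train_thread_num);

    for (int thread_id = 0; thread_id < train_thread_num; ++thread_id) {
        // Indices are passed as std::thread arguments, not captured by
        // reference, so later iterations cannot change what a worker sees.
        train_threads[thread_id].reset(new std::thread([](int exe_idx, int thread_idx) {
            std::printf("executor %d running on thread %d\n", exe_idx, thread_idx);
        }, /*exe_idx=*/0, thread_id));
    }

    for (auto& t : train_threads) {
        t->join();
    }
    return 0;
}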
paddle/fluid/train/custom_trainer/feed/scripts/create_programs.py

@@ -119,11 +119,12 @@ class ModelBuilder:
            'inputs': [{"name": var.name, "shape": var.shape} for var in inputs],
            'outputs': [{"name": var.name, "shape": var.shape} for var in outputs],
            'labels': [{"name": var.name, "shape": var.shape} for var in labels],
+            'vars': [{"name": var.name, "shape": var.shape} for var in main_program.list_vars() if fluid.io.is_parameter(var)],
            'loss': loss.name,
        }
        with open(model_desc_path, 'w') as f:
-            yaml.safe_dump(model_desc, f, encoding='utf-8', allow_unicode=True)
+            yaml.safe_dump(model_desc, f, encoding='utf-8', allow_unicode=True, default_flow_style=None)

 def main(argv):
paddle/fluid/train/custom_trainer/feed/temp/feed_trainer.cpp

@@ -40,7 +40,7 @@ void ReadBinaryFile(const std::string& filename, std::string* contents) {
 std::unique_ptr<paddle::framework::ProgramDesc> Load(
    paddle::framework::Executor* executor, const std::string& model_filename) {
-    LOG(DEBUG) << "loading model from " << model_filename;
+    VLOG(3) << "loading model from " << model_filename;
    std::string program_desc_str;
    ReadBinaryFile(model_filename, &program_desc_str);
paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc

@@ -193,14 +193,14 @@ TEST_F(DataReaderTest, LineDataReader_FileSystem) {
        "file_system:\n"
        "  class: AutoFileSystem\n"
        "  file_systems:\n"
-        "    'afs:': &HDFS\n"
+        "    'afs': &HDFS\n"
        "      class: HadoopFileSystem\n"
        "      hdfs_command: 'hadoop fs'\n"
        "      ugis:\n"
        "        'default': 'feed_video,D3a0z8'\n"
        "        'xingtian.afs.baidu.com:9902': 'feed_video,D3a0z8'\n"
        "\n"
-        "    'hdfs:': *HDFS\n");
+        "    'hdfs': *HDFS\n");
    ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
    {
        auto data_file_list = data_reader->data_file_list(test_data_dir);