Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
1f9f8ebb
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1f9f8ebb
编写于
8月 05, 2019
作者:
R
rensilin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
LineDataReader_filename_prefix
Change-Id: I47c64499e664cd2d4f403a9fb181333fe291acd6
上级
fae7e884
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
65 addition
and
16 deletion
+65
-16
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc
...le/fluid/train/custom_trainer/feed/dataset/data_reader.cc
+25
-9
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
+1
-0
paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc
...id/train/custom_trainer/feed/unit_test/test_datareader.cc
+39
-7
未找到文件。
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc
浏览文件 @
1f9f8ebb
...
...
@@ -86,35 +86,50 @@ public:
}
_done_file_name
=
config
[
"done_file"
].
as
<
std
::
string
>
();
_buffer_size
=
config
[
"buffer_size"
].
as
<
int
>
(
1024
);
_filename_prefix
=
config
[
"filename_prefix"
].
as
<
std
::
string
>
(
""
);
_buffer
.
reset
(
new
char
[
_buffer_size
]);
return
0
;
}
//判断样本数据是否已就绪,就绪表明可以开始download
virtual
bool
is_data_ready
(
const
std
::
string
&
data_dir
)
{
auto
done_file_path
=
::
paddle
::
framework
::
fs_path_join
(
data_dir
,
_done_file_name
);
if
(
::
paddle
::
framework
::
fs_exists
(
done_file_path
))
{
auto
done_file_path
=
framework
::
fs_path_join
(
data_dir
,
_done_file_name
);
if
(
framework
::
fs_exists
(
done_file_path
))
{
return
true
;
}
return
false
;
}
virtual
std
::
vector
<
std
::
string
>
data_file_list
(
const
std
::
string
&
data_dir
)
{
if
(
_filename_prefix
.
empty
())
{
return
framework
::
fs_list
(
data_dir
);
}
std
::
vector
<
std
::
string
>
data_files
;
for
(
auto
&
filepath
:
framework
::
fs_list
(
data_dir
))
{
auto
filename
=
framework
::
fs_path_split
(
filepath
).
second
;
if
(
filename
.
size
()
>=
_filename_prefix
.
size
()
&&
filename
.
substr
(
0
,
_filename_prefix
.
size
())
==
_filename_prefix
)
{
data_files
.
push_back
(
std
::
move
(
filepath
));
}
}
return
data_files
;
}
//读取数据样本流中
virtual
int
read_all
(
const
std
::
string
&
data_dir
,
::
paddle
::
framework
::
Channel
<
DataItem
>
data_channel
)
{
::
paddle
::
framework
::
ChannelWriter
<
DataItem
>
writer
(
data_channel
.
get
());
virtual
int
read_all
(
const
std
::
string
&
data_dir
,
framework
::
Channel
<
DataItem
>
data_channel
)
{
framework
::
ChannelWriter
<
DataItem
>
writer
(
data_channel
.
get
());
DataItem
data_item
;
if
(
_buffer_size
<=
0
||
_buffer
==
nullptr
)
{
VLOG
(
2
)
<<
"no buffer"
;
return
-
1
;
}
for
(
const
auto
&
file
name
:
::
paddle
::
framework
::
fs
_list
(
data_dir
))
{
if
(
::
paddle
::
framework
::
fs_path_split
(
filename
).
second
==
_done_file_name
)
{
for
(
const
auto
&
file
path
:
data_file
_list
(
data_dir
))
{
if
(
framework
::
fs_path_split
(
filepath
).
second
==
_done_file_name
)
{
continue
;
}
int
err_no
=
0
;
std
::
shared_ptr
<
FILE
>
fin
=
::
paddle
::
framework
::
fs_open_read
(
filename
,
&
err_no
,
_pipeline_cmd
);
std
::
shared_ptr
<
FILE
>
fin
=
framework
::
fs_open_read
(
filepath
,
&
err_no
,
_pipeline_cmd
);
if
(
err_no
!=
0
)
{
VLOG
(
2
)
<<
"fail to open file: "
<<
file
name
<<
", with cmd: "
<<
_pipeline_cmd
;
VLOG
(
2
)
<<
"fail to open file: "
<<
file
path
<<
", with cmd: "
<<
_pipeline_cmd
;
return
-
1
;
}
while
(
fgets
(
_buffer
.
get
(),
_buffer_size
,
fin
.
get
()))
{
...
...
@@ -124,7 +139,7 @@ public:
writer
<<
std
::
move
(
data_item
);
}
if
(
ferror
(
fin
.
get
())
!=
0
)
{
VLOG
(
2
)
<<
"fail to read file: "
<<
file
name
;
VLOG
(
2
)
<<
"fail to read file: "
<<
file
path
;
return
-
1
;
}
}
...
...
@@ -144,6 +159,7 @@ private:
std
::
string
_done_file_name
;
// without data_dir
int
_buffer_size
=
0
;
std
::
unique_ptr
<
char
[]
>
_buffer
;
std
::
string
_filename_prefix
;
};
REGISTER_CLASS
(
DataReader
,
LineDataReader
);
...
...
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
浏览文件 @
1f9f8ebb
...
...
@@ -59,6 +59,7 @@ public:
virtual
bool
is_data_ready
(
const
std
::
string
&
data_dir
)
=
0
;
//读取数据样本流中
virtual
int
read_all
(
const
std
::
string
&
data_dir
,
::
paddle
::
framework
::
Channel
<
DataItem
>
data_channel
)
=
0
;
virtual
std
::
vector
<
std
::
string
>
data_file_list
(
const
std
::
string
&
data_dir
)
=
0
;
virtual
const
DataParser
*
get_parser
()
{
return
_parser
.
get
();
}
...
...
paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc
浏览文件 @
1f9f8ebb
...
...
@@ -105,18 +105,16 @@ TEST_F(DataReaderTest, LineDataReader) {
std
::
unique_ptr
<
DataReader
>
data_reader
(
CREATE_CLASS
(
DataReader
,
"LineDataReader"
));
ASSERT_NE
(
nullptr
,
data_reader
);
YAML
::
Node
config
=
YAML
::
Load
(
"parser:
\n
"
" class: LineDataParser
\n
"
"pipeline_cmd: cat
\n
"
"done_file: done_file"
);
ASSERT_EQ
(
0
,
data_reader
->
initialize
(
config
,
context_ptr
));
config
=
YAML
::
Load
(
"parser:
\n
"
auto
config
=
YAML
::
Load
(
"parser:
\n
"
" class: LineDataParser
\n
"
"pipeline_cmd: cat
\n
"
"done_file: done_file
\n
"
"buffer_size: 128"
);
ASSERT_EQ
(
0
,
data_reader
->
initialize
(
config
,
context_ptr
));
auto
data_file_list
=
data_reader
->
data_file_list
(
test_data_dir
);
ASSERT_EQ
(
2
,
data_file_list
.
size
());
ASSERT_EQ
(
string
::
format_string
(
"%s/%s"
,
test_data_dir
,
"a.txt"
),
data_file_list
[
0
]);
ASSERT_EQ
(
string
::
format_string
(
"%s/%s"
,
test_data_dir
,
"b.txt"
),
data_file_list
[
1
]);
ASSERT_FALSE
(
data_reader
->
is_data_ready
(
test_data_dir
));
std
::
ofstream
fout
(
framework
::
fs_path_join
(
test_data_dir
,
"done_file"
));
...
...
@@ -155,6 +153,40 @@ TEST_F(DataReaderTest, LineDataReader) {
ASSERT_FALSE
(
reader
);
}
TEST_F
(
DataReaderTest
,
LineDataReader_filename_prefix
)
{
std
::
unique_ptr
<
DataReader
>
data_reader
(
CREATE_CLASS
(
DataReader
,
"LineDataReader"
));
ASSERT_NE
(
nullptr
,
data_reader
);
auto
config
=
YAML
::
Load
(
"parser:
\n
"
" class: LineDataParser
\n
"
"pipeline_cmd: cat
\n
"
"done_file: done_file
\n
"
"filename_prefix: a"
);
ASSERT_EQ
(
0
,
data_reader
->
initialize
(
config
,
context_ptr
));
auto
data_file_list
=
data_reader
->
data_file_list
(
test_data_dir
);
ASSERT_EQ
(
1
,
data_file_list
.
size
());
ASSERT_EQ
(
string
::
format_string
(
"%s/%s"
,
test_data_dir
,
"a.txt"
),
data_file_list
[
0
]);
auto
channel
=
framework
::
MakeChannel
<
DataItem
>
(
128
);
ASSERT_NE
(
nullptr
,
channel
);
ASSERT_EQ
(
0
,
data_reader
->
read_all
(
test_data_dir
,
channel
));
framework
::
ChannelReader
<
DataItem
>
reader
(
channel
.
get
());
DataItem
data_item
;
reader
>>
data_item
;
ASSERT_TRUE
(
reader
);
ASSERT_STREQ
(
"abc"
,
data_item
.
id
.
c_str
());
ASSERT_STREQ
(
"123456"
,
data_item
.
data
.
c_str
());
reader
>>
data_item
;
ASSERT_TRUE
(
reader
);
ASSERT_STREQ
(
"def"
,
data_item
.
id
.
c_str
());
ASSERT_STREQ
(
"234567"
,
data_item
.
data
.
c_str
());
reader
>>
data_item
;
ASSERT_FALSE
(
reader
);
}
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录