Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
17e0cb7c
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
17e0cb7c
编写于
8月 02, 2019
作者:
R
rensilin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
data_reader
Change-Id: Id829354b599d9d98824785be2a883480c94b9ffe
上级
1fcde8e9
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
170 addition
and
9 deletion
+170
-9
paddle/fluid/framework/io/fs.cc
paddle/fluid/framework/io/fs.cc
+20
-1
paddle/fluid/framework/io/fs.h
paddle/fluid/framework/io/fs.h
+4
-0
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc
...le/fluid/train/custom_trainer/feed/dataset/data_reader.cc
+142
-0
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
+4
-8
paddle/fluid/train/custom_trainer/feed/so/libpaddle_fluid_avx_mklml.so
...train/custom_trainer/feed/so/libpaddle_fluid_avx_mklml.so
+0
-0
未找到文件。
paddle/fluid/framework/io/fs.cc
浏览文件 @
17e0cb7c
...
...
@@ -149,7 +149,7 @@ std::vector<std::string> localfs_list(const std::string& path) {
std
::
shared_ptr
<
FILE
>
pipe
;
int
err_no
=
0
;
pipe
=
shell_popen
(
string
::
format_string
(
"find %s -
type f -maxdepth 1
"
,
path
.
c_str
()),
"r"
,
string
::
format_string
(
"find %s -
maxdepth 1 -type f
"
,
path
.
c_str
()),
"r"
,
&
err_no
);
string
::
LineFileReader
reader
;
std
::
vector
<
std
::
string
>
list
;
...
...
@@ -452,5 +452,24 @@ void fs_mkdir(const std::string& path) {
LOG
(
FATAL
)
<<
"Not supported"
;
}
}
std
::
string
fs_path_join
(
const
std
::
string
&
dir
,
const
std
::
string
&
path
)
{
if
(
dir
.
empty
())
{
return
path
;
}
if
(
dir
.
back
()
==
'/'
)
{
return
dir
+
path
;
}
return
dir
+
'/'
+
path
;
}
std
::
pair
<
std
::
string
,
std
::
string
>
fs_path_split
(
const
std
::
string
&
path
)
{
size_t
pos
=
path
.
find_last_of
(
'/'
);
if
(
pos
==
std
::
string
::
npos
)
{
return
{
"."
,
path
};
}
return
{
path
.
substr
(
0
,
pos
),
path
.
substr
(
pos
+
1
)};
}
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/io/fs.h
浏览文件 @
17e0cb7c
...
...
@@ -97,5 +97,9 @@ extern std::string fs_tail(const std::string& path);
extern
bool
fs_exists
(
const
std
::
string
&
path
);
extern
void
fs_mkdir
(
const
std
::
string
&
path
);
extern
std
::
string
fs_path_join
(
const
std
::
string
&
dir
,
const
std
::
string
&
path
);
extern
std
::
pair
<
std
::
string
,
std
::
string
>
fs_path_split
(
const
std
::
string
&
path
);
}
// namespace framework
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc
0 → 100644
浏览文件 @
17e0cb7c
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
#include <cstdio>
#include <glog/logging.h>
#include "paddle/fluid/framework/io/fs.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
class
LineDataParser
:
public
DataParser
{
public:
LineDataParser
()
{}
virtual
~
LineDataParser
()
{}
virtual
int
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
)
{
return
0
;
}
virtual
int
parse
(
const
char
*
str
,
size_t
len
,
DataItem
&
data
)
const
{
size_t
pos
=
0
;
while
(
str
[
pos
]
!=
' '
)
{
if
(
pos
>=
len
)
{
VLOG
(
2
)
<<
"fail to parse line, strlen: "
<<
len
;
return
-
1
;
}
++
pos
;
}
data
.
id
.
assign
(
str
,
pos
);
data
.
data
.
assign
(
str
+
pos
+
1
,
len
-
pos
-
1
);
return
0
;
}
virtual
int
parse
(
const
char
*
str
,
DataItem
&
data
)
const
{
size_t
pos
=
0
;
while
(
str
[
pos
]
!=
' '
)
{
if
(
str
[
pos
]
==
'\0'
)
{
VLOG
(
2
)
<<
"fail to parse line, get '
\\
0' at pos: "
<<
pos
;
return
-
1
;
}
++
pos
;
}
data
.
id
.
assign
(
str
,
pos
);
data
.
data
.
assign
(
str
+
pos
+
1
);
return
0
;
}
virtual
int
parse_to_sample
(
const
DataItem
&
data
,
SampleInstance
&
instance
)
const
{
return
0
;
}
};
REGISTER_CLASS
(
DataParser
,
LineDataParser
);
int
DataReader
::
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
)
{
_parser
.
reset
(
CREATE_CLASS
(
DataParser
,
config
[
"parser"
][
"class"
].
as
<
std
::
string
>
()));
if
(
_parser
==
nullptr
)
{
VLOG
(
2
)
<<
"fail to get parser: "
<<
config
[
"parser"
][
"class"
].
as
<
std
::
string
>
();
return
-
1
;
}
if
(
_parser
->
initialize
(
config
[
"parser"
],
context
)
!=
0
)
{
VLOG
(
2
)
<<
"fail to initialize parser"
<<
config
[
"parser"
][
"class"
].
as
<
std
::
string
>
();
return
-
1
;
}
_pipeline_cmd
=
config
[
"pipeline_cmd"
].
as
<
std
::
string
>
();
return
0
;
}
class
LineDataReader
:
public
DataReader
{
public:
LineDataReader
()
{}
virtual
~
LineDataReader
()
{}
virtual
int
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
)
{
if
(
DataReader
::
initialize
(
config
,
context
)
!=
0
)
{
return
-
1
;
}
_done_file_name
=
config
[
"done_file"
].
as
<
std
::
string
>
();
_buffer_size
=
config
[
"buffer_size"
].
as
<
int
>
(
1024
);
_buffer
.
reset
(
new
char
[
_buffer_size
]);
return
0
;
}
//判断样本数据是否已就绪,就绪表明可以开始download
virtual
bool
is_data_ready
(
const
std
::
string
&
data_dir
)
{
auto
done_file_path
=
::
paddle
::
framework
::
fs_path_join
(
data_dir
,
_done_file_name
);
if
(
::
paddle
::
framework
::
fs_exists
(
done_file_path
))
{
return
true
;
}
return
false
;
}
//读取数据样本流中
virtual
int
read_all
(
const
std
::
string
&
data_dir
,
::
paddle
::
framework
::
Channel
<
DataItem
>
data_channel
)
{
::
paddle
::
framework
::
ChannelWriter
<
DataItem
>
writer
(
data_channel
.
get
());
DataItem
data_item
;
if
(
_buffer_size
<=
0
||
_buffer
==
nullptr
)
{
VLOG
(
2
)
<<
"no buffer"
;
return
-
1
;
}
for
(
const
auto
&
filename
:
::
paddle
::
framework
::
fs_list
(
data_dir
))
{
if
(
::
paddle
::
framework
::
fs_path_split
(
filename
).
second
==
_done_file_name
)
{
continue
;
}
int
err_no
;
std
::
shared_ptr
<
FILE
>
fin
=
::
paddle
::
framework
::
fs_open_read
(
filename
,
&
err_no
,
_pipeline_cmd
);
if
(
err_no
!=
0
)
{
return
-
1
;
}
while
(
fgets
(
_buffer
.
get
(),
_buffer_size
,
fin
.
get
()))
{
if
(
_parser
->
parse
(
_buffer
.
get
(),
data_item
)
!=
0
)
{
return
-
1
;
}
writer
<<
std
::
move
(
data_item
);
}
if
(
ferror
(
fin
.
get
())
!=
0
)
{
VLOG
(
2
)
<<
"fail to read file: "
<<
filename
;
return
-
1
;
}
}
writer
.
Flush
();
if
(
!
writer
)
{
VLOG
(
2
)
<<
"fail when write to channel"
;
return
-
1
;
}
return
0
;
}
virtual
const
DataParser
*
get_parser
()
{
return
_parser
.
get
();
}
private:
std
::
string
_done_file_name
;
// without data_dir
int
_buffer_size
=
0
;
std
::
unique_ptr
<
char
[]
>
_buffer
;
};
REGISTER_CLASS
(
DataReader
,
LineDataReader
);
}
//namespace feed
}
//namespace custom_trainer
}
//namespace paddle
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
浏览文件 @
17e0cb7c
...
...
@@ -42,9 +42,10 @@ public:
virtual
~
DataParser
()
{}
virtual
int
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
)
=
0
;
virtual
int
parse
(
const
std
::
string
&
str
,
DataItem
&
data
)
const
{
return
parse
(
str
.
c_str
(),
str
.
size
(),
data
);
return
parse
(
str
.
c_str
(),
data
);
}
virtual
int
parse
(
const
char
*
str
,
size_t
len
,
DataItem
&
data
)
const
=
0
;
virtual
int
parse
(
const
char
*
str
,
DataItem
&
data
)
const
=
0
;
virtual
int
parse_to_sample
(
const
DataItem
&
data
,
SampleInstance
&
instance
)
const
=
0
;
};
REGISTER_REGISTERER
(
DataParser
);
...
...
@@ -53,7 +54,7 @@ class DataReader {
public:
DataReader
()
{}
virtual
~
DataReader
()
{}
virtual
int
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
)
=
0
;
virtual
int
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
);
//判断样本数据是否已就绪,就绪表明可以开始download
virtual
bool
is_data_ready
(
const
std
::
string
&
data_dir
)
=
0
;
//读取数据样本流中
...
...
@@ -61,17 +62,12 @@ public:
virtual
const
DataParser
*
get_parser
()
{
return
_parser
.
get
();
}
pr
ivate
:
pr
otected
:
std
::
shared_ptr
<
DataParser
>
_parser
;
//数据格式转换
std
::
string
_pipeline_cmd
;
//将文件流,重定向到pipeline_cmd,再读入
};
REGISTER_REGISTERER
(
DataReader
);
//TODO
//可读取HDFS/DISK上数据的Reader,数据按行分隔
//HDFS/DISK - FileLineReader
}
//namespace feed
}
//namespace custom_trainer
}
//namespace paddle
paddle/fluid/train/custom_trainer/feed/so/libpaddle_fluid_avx_mklml.so
浏览文件 @
17e0cb7c
无法预览此类型文件
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录