BaiXuePrincess / PaddleRec (forked from PaddlePaddle / PaddleRec)
Commit 379235f4
Authored on Aug 06, 2019 by rensilin

file_system_ut

Change-Id: I96c0bc535a0f49a92e8b987df5cc06c5eca4758e
Parent: 501c9f25
Showing 5 changed files with 238 additions and 88 deletions (+238, -88):
- paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc (+64, -29)
- paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc (+2, -2)
- paddle/fluid/train/custom_trainer/feed/io/hadoop_file_system.cc (+53, -23)
- paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc (+112, -31)
- paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc (+7, -3)
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc

```diff
@@ -4,13 +4,13 @@
 #include <glog/logging.h>
-#include "paddle/fluid/framework/io/fs.h"
+#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"

 namespace paddle {
 namespace custom_trainer {
 namespace feed {

 class LineDataParser : public DataParser {
 public:
     LineDataParser() {}
```
```diff
@@ -29,7 +29,7 @@ public:
             VLOG(2) << "fail to parse line: " << std::string(str, len) << ", strlen: " << len;
             return -1;
         }
         VLOG(5) << "getline: " << str << " , pos: " << pos << ", len: " << len;
         data.id.assign(str, pos);
         data.data.assign(str + pos + 1, len - pos - 1);
         if (!data.data.empty() && data.data.back() == '\n') {
```
```diff
@@ -47,7 +47,7 @@ public:
             VLOG(2) << "fail to parse line: " << str << ", get '\\0' at pos: " << pos;
             return -1;
         }
         VLOG(5) << "getline: " << str << " , pos: " << pos;
         data.id.assign(str, pos);
         data.data.assign(str + pos + 1);
         if (!data.data.empty() && data.data.back() == '\n') {
```
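For reference, the contract these parse hunks revolve around: a line such as `abc 123456` splits into id `abc` and payload `123456`, with a trailing newline trimmed (the same expectations appear in test_datareader.cc below). A minimal standalone sketch, not the repo's code; `parse_line` and the space delimiter are illustrative assumptions:

```cpp
// Hypothetical sketch of the LineDataParser contract: id before the first
// space, data after it, trailing '\n' stripped.
#include <cassert>
#include <string>

struct DataItem { std::string id; std::string data; };

int parse_line(const char* str, DataItem& data) {
    std::string line(str);
    auto pos = line.find(' ');            // delimiter position
    if (pos == std::string::npos) return -1;
    data.id = line.substr(0, pos);
    data.data = line.substr(pos + 1);
    if (!data.data.empty() && data.data.back() == '\n') {
        data.data.pop_back();             // mirrors the back() == '\n' trim above
    }
    return 0;
}

int main() {
    DataItem item;
    assert(parse_line("abc 123456\n", item) == 0);
    assert(item.id == "abc" && item.data == "123456");
}
```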
```diff
@@ -88,13 +88,30 @@ public:
         _buffer_size = config["buffer_size"].as<int>(1024);
         _filename_prefix = config["filename_prefix"].as<std::string>("");
         _buffer.reset(new char[_buffer_size]);
+        if (config["file_system"] && config["file_system"]["class"]) {
+            _file_system.reset(CREATE_CLASS(FileSystem, config["file_system"]["class"].as<std::string>()));
+            if (_file_system == nullptr || _file_system->initialize(config["file_system"], context) != 0) {
+                VLOG(2) << "fail to create class: " << config["file_system"]["class"].as<std::string>();
+                return -1;
+            }
+        } else {
+            _file_system.reset(CREATE_CLASS(FileSystem, "LocalFileSystem"));
+            if (_file_system == nullptr || _file_system->initialize(YAML::Load(""), context) != 0) {
+                VLOG(2) << "fail to init file system";
+                return -1;
+            }
+        }
         return 0;
     }

     // Check whether the sample data is ready; ready means the download can start.
     virtual bool is_data_ready(const std::string& data_dir) {
-        auto done_file_path = framework::fs_path_join(data_dir, _done_file_name);
-        if (framework::fs_exists(done_file_path)) {
+        auto done_file_path = _file_system->path_join(data_dir, _done_file_name);
+        if (_file_system->exists(done_file_path)) {
             return true;
         }
         return false;
```
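With this hunk, `initialize` owns a pluggable `FileSystem` chosen by the `file_system.class` YAML key, falling back to `LocalFileSystem`. A hedged sketch of what such a string-keyed factory amounts to (the repo's real mechanism is its `CREATE_CLASS`/`REGISTER_CLASS` registry macros; `create_file_system` and the struct stubs here are hypothetical):

```cpp
// Minimal sketch of a name-to-implementation registry, assuming hypothetical
// FileSystem subclasses; an empty name falls back to LocalFileSystem.
#include <functional>
#include <map>
#include <memory>
#include <string>

struct FileSystem { virtual ~FileSystem() = default; };
struct LocalFileSystem : FileSystem {};
struct AutoFileSystem : FileSystem {};

std::unique_ptr<FileSystem> create_file_system(const std::string& name) {
    static const std::map<std::string, std::function<FileSystem*()>> registry = {
        {"LocalFileSystem", [] { return new LocalFileSystem(); }},
        {"AutoFileSystem",  [] { return new AutoFileSystem(); }},
    };
    auto it = registry.find(name.empty() ? "LocalFileSystem" : name);
    return it == registry.end() ? nullptr
                                : std::unique_ptr<FileSystem>(it->second());
}
```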
```diff
@@ -102,12 +119,13 @@ public:
     virtual std::vector<std::string> data_file_list(const std::string& data_dir) {
         if (_filename_prefix.empty()) {
-            return framework::fs_list(data_dir);
+            return _file_system->list(data_dir);
         }
         std::vector<std::string> data_files;
-        for (auto& filepath : framework::fs_list(data_dir)) {
-            auto filename = framework::fs_path_split(filepath).second;
+        for (auto& filepath : _file_system->list(data_dir)) {
+            auto filename = _file_system->path_split(filepath).second;
             if (filename.size() >= _filename_prefix.size()
                 && filename.substr(0, _filename_prefix.size()) == _filename_prefix) {
                 data_files.push_back(std::move(filepath));
             }
         }
```
```diff
@@ -116,35 +134,50 @@ public:
     // Read data into the sample stream.
     virtual int read_all(const std::string& data_dir, framework::Channel<DataItem> data_channel) {
-        framework::ChannelWriter<DataItem> writer(data_channel.get());
+        auto deleter = [](framework::ChannelWriter<DataItem>* writer) {
+            if (writer) {
+                writer->Flush();
+                VLOG(3) << "writer auto flush";
+            }
+            delete writer;
+        };
+        std::unique_ptr<framework::ChannelWriter<DataItem>, decltype(deleter)> writer(
+            new framework::ChannelWriter<DataItem>(data_channel.get()), deleter);
         DataItem data_item;
         if (_buffer_size <= 0 || _buffer == nullptr) {
             VLOG(2) << "no buffer";
             return -1;
         }
         for (const auto& filepath : data_file_list(data_dir)) {
-            if (framework::fs_path_split(filepath).second == _done_file_name) {
+            if (_file_system->path_split(filepath).second == _done_file_name) {
                 continue;
             }
-            int err_no = 0;
             {
-                std::shared_ptr<FILE> fin = framework::fs_open_read(filepath, &err_no, _pipeline_cmd);
-                if (err_no != 0) {
+                std::shared_ptr<FILE> fin = _file_system->open_read(filepath, _pipeline_cmd);
+                if (fin == nullptr) {
                     VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd;
                     return -1;
                 }
                 while (fgets(_buffer.get(), _buffer_size, fin.get())) {
+                    if (_buffer[0] == '\n') {
+                        continue;
+                    }
                     if (_parser->parse(_buffer.get(), data_item) != 0) {
                         return -1;
                     }
-                    writer << std::move(data_item);
+                    (*writer) << std::move(data_item);
                 }
                 if (ferror(fin.get()) != 0) {
                     VLOG(2) << "fail to read file: " << filepath;
+                    _file_system->reset_err_no();
                     return -1;
                 }
             }
         }
-        writer.Flush();
-        if (!writer) {
+        writer->Flush();
+        if (!(*writer)) {
             VLOG(2) << "fail when write to channel";
             return -1;
         }
```
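The rewritten `read_all` replaces the stack-allocated `ChannelWriter` with a `unique_ptr` carrying a custom deleter that flushes before destroying, so every early `return -1` still flushes buffered items. A minimal standalone sketch of that pattern (`Writer` and `do_work` are hypothetical stand-ins):

```cpp
// Flush-on-every-exit-path via a unique_ptr with a lambda deleter.
#include <cstdio>
#include <memory>

struct Writer {
    void Flush() { std::puts("flushed"); }
};

int do_work(bool fail) {
    auto deleter = [](Writer* w) {
        if (w) {
            w->Flush();   // runs on every exit path, including early returns
        }
        delete w;
    };
    std::unique_ptr<Writer, decltype(deleter)> writer(new Writer(), deleter);
    if (fail) {
        return -1;        // writer is still flushed here
    }
    return 0;
}
```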
```diff
@@ -155,14 +188,16 @@ public:
     virtual const DataParser* get_parser() {
         return _parser.get();
     }

 private:
     std::string _done_file_name;  // without data_dir
     int _buffer_size = 0;
     std::unique_ptr<char[]> _buffer;
     std::string _filename_prefix;
+    std::unique_ptr<FileSystem> _file_system;
 };
 REGISTER_CLASS(DataReader, LineDataReader);

 }  // namespace feed
 }  // namespace custom_trainer
 }  // namespace paddle
```
paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc

```diff
@@ -29,8 +29,8 @@ class AutoFileSystem : public FileSystem {
 public:
     int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) override {
         _file_system.clear();
-        if (config) {
-            for (auto& prefix_fs : config) {
+        if (config && config["file_systems"] && config["file_systems"].Type() == YAML::NodeType::Map) {
+            for (auto& prefix_fs : config["file_systems"]) {
                 std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, prefix_fs.second["class"].as<std::string>("")));
                 if (fs == nullptr) {
                     VLOG(2) << "fail to create class: " << prefix_fs.second["class"].as<std::string>("");
```
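The added guard checks that `file_systems` exists and is a YAML map before iterating its prefix-to-filesystem entries. A hedged, self-contained sketch of that yaml-cpp access pattern outside the trainer (assuming yaml-cpp is available):

```cpp
// Guard a node's presence and Map type, then walk its key/value pairs.
#include <iostream>
#include <string>
#include <yaml-cpp/yaml.h>

int main() {
    auto config = YAML::Load(
        "file_systems:\n"
        "    'hdfs:':\n"
        "        class: HadoopFileSystem\n"
        "    'local:':\n"
        "        class: LocalFileSystem\n");
    if (config["file_systems"] &&
        config["file_systems"].Type() == YAML::NodeType::Map) {
        for (const auto& prefix_fs : config["file_systems"]) {
            std::cout << prefix_fs.first.as<std::string>() << " -> "
                      << prefix_fs.second["class"].as<std::string>("") << "\n";
        }
    }
}
```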
paddle/fluid/train/custom_trainer/feed/io/hadoop_file_system.cc

```diff
@@ -16,9 +16,11 @@ limitations under the License. */
 #include <string>
 #include <unordered_map>
+#include <tuple>

 #include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
 #include "paddle/fluid/string/string_helper.h"
+#include "paddle/fluid/string/piece.h"
 #include "glog/logging.h"

 namespace paddle {
@@ -31,8 +33,10 @@ public:
         _buffer_size = config["buffer_size"].as<size_t>(0);
         _hdfs_command = config["hdfs_command"].as<std::string>("hadoop fs");
         _ugi.clear();
-        for (const auto& prefix_ugi : config["ugi"]) {
-            _ugi.emplace(prefix_ugi.first.as<std::string>(), prefix_ugi.second.as<std::string>());
+        if (config["ugis"] && config["ugis"].Type() == YAML::NodeType::Map) {
+            for (const auto& prefix_ugi : config["ugis"]) {
+                _ugi.emplace(prefix_ugi.first.as<std::string>(), prefix_ugi.second.as<std::string>());
+            }
         }
         if (_ugi.find("default") == _ugi.end()) {
             VLOG(2) << "fail to load default ugi";
```
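Each `ugis` entry maps a cluster host (or `default`) to a `user,password` string that later gets injected into the Hadoop command line. An illustrative sketch of the prefix this produces (a hypothetical free function mirroring `hdfs_command_with_ugi` below; the credential value is a placeholder):

```cpp
// Assemble the "-Dhadoop.job.ugi" prefix from a configured ugi string.
#include <cstdio>
#include <string>

std::string hdfs_command_with_ugi(const std::string& hdfs_command,
                                  const std::string& ugi) {
    return hdfs_command + " -Dhadoop.job.ugi=\"" + ugi + "\"";
}

int main() {
    // With hdfs_command "hadoop fs" and a "user,password" ugi string:
    std::puts(hdfs_command_with_ugi("hadoop fs", "user,passwd").c_str());
    // prints: hadoop fs -Dhadoop.job.ugi="user,passwd"
}
```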
```diff
@@ -48,8 +52,7 @@ public:
             cmd = string::format_string(
                 "%s -text \"%s\"", hdfs_command(path).c_str(), path.c_str());
         } else {
             cmd = string::format_string(
                 "%s -cat \"%s\"", hdfs_command(path).c_str(), path.c_str());
         }
         bool is_pipe = true;
@@ -59,7 +62,8 @@ public:
     std::shared_ptr<FILE> open_write(const std::string& path, const std::string& converter)
         override {
         std::string cmd = string::format_string(
             "%s -put - \"%s\"", hdfs_command(path).c_str(), path.c_str());
         bool is_pipe = true;
         if (string::end_with(path, ".gz\"")) {
@@ -89,12 +93,8 @@ public:
         if (path == "") {
             return {};
         }
-        std::string prefix = "hdfs:";
-        if (string::begin_with(path, "afs:")) {
-            prefix = "afs:";
-        }
+        auto paths = _split_path(path);
         int err_no = 0;
         std::vector<std::string> list;
         do {
@@ -115,7 +115,7 @@ public:
                 if (line.size() != 8) {
                     continue;
                 }
-                list.push_back(prefix + line[7]);
+                list.push_back(_get_prefix(paths) + line[7]);
             }
         } while (err_no == -1);
         return list;
@@ -146,30 +146,60 @@ public:
             return;
         }
         shell_execute(string::format_string(
             "%s -mkdir %s; true", hdfs_command(path).c_str(), path.c_str()));
     }

     std::string hdfs_command(const std::string& path) {
-        auto start_pos = path.find_first_of(':');
-        auto end_pos = path.find_first_of('/');
-        if (start_pos != std::string::npos && end_pos != std::string::npos && start_pos < end_pos) {
-            auto fs_path = path.substr(start_pos + 1, end_pos - start_pos - 1);
-            auto ugi_it = _ugi.find(fs_path);
-            if (ugi_it != _ugi.end()) {
-                return hdfs_command_with_ugi(ugi_it->second);
-            }
-        }
+        auto paths = _split_path(path);
+        auto it = _ugi.find(std::get<1>(paths).ToString());
+        if (it != _ugi.end()) {
+            return hdfs_command_with_ugi(it->second);
+        }
         VLOG(5) << "path: " << path << ", select default ugi";
         return hdfs_command_with_ugi(_ugi["default"]);
     }

     std::string hdfs_command_with_ugi(std::string ugi) {
         return string::format_string(
             "%s -Dhadoop.job.ugi=\"%s\"", _hdfs_command.c_str(), ugi.c_str());
     }

 private:
+    std::string _get_prefix(const std::tuple<string::Piece, string::Piece, string::Piece>& paths) {
+        if (std::get<1>(paths).len() == 0) {
+            return std::get<0>(paths).ToString();
+        }
+        return std::get<0>(paths).ToString() + "//" + std::get<1>(paths).ToString();
+    }
+
+    std::tuple<string::Piece, string::Piece, string::Piece> _split_path(string::Piece path) {
+        // parse "xxx://abc.def:8756/user" as "xxx:", "abc.def:8756", "/user"
+        // parse "xxx:/user" as "xxx:", "", "/user"
+        // parse "xxx://abc.def:8756" as "xxx:", "abc.def:8756", ""
+        // parse "other" as "", "", "other"
+        std::tuple<string::Piece, string::Piece, string::Piece> result{
+            string::SubStr(path, 0, 0), string::SubStr(path, 0, 0), path};
+        auto fs_pos = string::Find(path, ':', 0) + 1;
+        if (path.len() > fs_pos) {
+            std::get<0>(result) = string::SubStr(path, 0, fs_pos);
+            path = string::SkipPrefix(path, fs_pos);
+            if (string::HasPrefix(path, "//")) {
+                path = string::SkipPrefix(path, 2);
+                auto end_pos = string::Find(path, '/', 0);
+                if (end_pos != string::Piece::npos) {
+                    std::get<1>(result) = string::SubStr(path, 0, end_pos);
+                    std::get<2>(result) = string::SkipPrefix(path, end_pos);
+                } else {
+                    std::get<1>(result) = path;
+                }
+            } else {
+                std::get<2>(result) = path;
+            }
+        }
+        return result;
+    }
+
     size_t _buffer_size = 0;
     std::string _hdfs_command;
     std::unordered_map<std::string, std::string> _ugi;
```
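`_split_path` is the new core of both ugi selection and listing. A standalone sketch of its documented contract, using `std::string` in place of the repo's `string::Piece` and checked against the four parse examples from its comments:

```cpp
// std::string re-statement of the _split_path contract: scheme, authority,
// path. Behavior follows the comments above, not the Piece-based internals.
#include <cassert>
#include <string>
#include <tuple>

std::tuple<std::string, std::string, std::string> split_path(std::string path) {
    std::tuple<std::string, std::string, std::string> result{"", "", path};
    auto fs_pos = path.find(':') + 1;            // 0 when no ':' (npos + 1)
    if (fs_pos != 0 && path.size() > fs_pos) {
        std::get<0>(result) = path.substr(0, fs_pos);       // "xxx:"
        path = path.substr(fs_pos);
        if (path.rfind("//", 0) == 0) {                     // starts with "//"
            path = path.substr(2);
            auto end_pos = path.find('/');
            if (end_pos != std::string::npos) {
                std::get<1>(result) = path.substr(0, end_pos);  // host:port
                std::get<2>(result) = path.substr(end_pos);     // "/user..."
            } else {
                std::get<1>(result) = path;
                std::get<2>(result) = "";
            }
        } else {
            std::get<2>(result) = path;                     // "xxx:/user" case
        }
    }
    return result;
}

int main() {
    assert(split_path("xxx://abc.def:8756/user") ==
           std::make_tuple("xxx:", "abc.def:8756", "/user"));
    assert(split_path("xxx:/user") == std::make_tuple("xxx:", "", "/user"));
    assert(split_path("xxx://abc.def:8756") ==
           std::make_tuple("xxx:", "abc.def:8756", ""));
    assert(split_path("other") == std::make_tuple("", "", "other"));
}
```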
paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc

```diff
@@ -19,7 +19,8 @@ limitations under the License. */
 #include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/framework/program_desc.h"
-#include "paddle/fluid/framework/io/fs.h"
+#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
+#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
 #include "paddle/fluid/string/string_helper.h"
 #include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
```
```diff
@@ -31,50 +32,50 @@ namespace {
 const char test_data_dir[] = "test_data";
 }

 class DataReaderTest : public testing::Test {
 public:
     static void SetUpTestCase() {
-        framework::shell_set_verbose(true);
-        framework::localfs_mkdir(test_data_dir);
+        std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
+        fs->mkdir(test_data_dir);
+        shell_set_verbose(true);
         {
-            std::ofstream fout(framework::fs_path_join(test_data_dir, "a.txt"));
+            std::ofstream fout(fs->path_join(test_data_dir, "a.txt"));
             fout << "abc 123456" << std::endl;
             fout << "def 234567" << std::endl;
             fout.close();
         }
         {
-            std::ofstream fout(framework::fs_path_join(test_data_dir, "b.txt"));
+            std::ofstream fout(fs->path_join(test_data_dir, "b.txt"));
             fout << "ghi 345678" << std::endl;
             fout << "jkl 456789" << std::endl;
             fout.close();
         }
     }

     static void TearDownTestCase() {
-        framework::localfs_remove(test_data_dir);
+        std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
+        fs->remove(test_data_dir);
     }

     virtual void SetUp() {
+        fs.reset(CREATE_CLASS(FileSystem, "LocalFileSystem"));
         context_ptr.reset(new TrainerContext());
     }

     virtual void TearDown() {
+        fs = nullptr;
         context_ptr = nullptr;
     }

     std::shared_ptr<TrainerContext> context_ptr;
+    std::unique_ptr<FileSystem> fs;
 };

 TEST_F(DataReaderTest, LineDataParser) {
     std::unique_ptr<DataParser> data_parser(CREATE_CLASS(DataParser, "LineDataParser"));
     ASSERT_NE(nullptr, data_parser);
     auto config = YAML::Load("");
```
```diff
@@ -105,11 +106,12 @@ TEST_F(DataReaderTest, LineDataReader) {
     std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
     ASSERT_NE(nullptr, data_reader);
-    auto config = YAML::Load("parser:\n"
+    auto config = YAML::Load(
+        "parser:\n"
         "    class: LineDataParser\n"
         "pipeline_cmd: cat\n"
         "done_file: done_file\n"
         "buffer_size: 128");
     ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
     auto data_file_list = data_reader->data_file_list(test_data_dir);
     ASSERT_EQ(2, data_file_list.size());
```
```diff
@@ -117,7 +119,7 @@ TEST_F(DataReaderTest, LineDataReader) {
     ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "b.txt"), data_file_list[1]);

     ASSERT_FALSE(data_reader->is_data_ready(test_data_dir));
-    std::ofstream fout(framework::fs_path_join(test_data_dir, "done_file"));
+    std::ofstream fout(fs->path_join(test_data_dir, "done_file"));
     fout << "done";
     fout.close();
     ASSERT_TRUE(data_reader->is_data_ready(test_data_dir));
```
```diff
@@ -128,7 +130,7 @@ TEST_F(DataReaderTest, LineDataReader) {
     framework::ChannelReader<DataItem> reader(channel.get());
     DataItem data_item;
     reader >> data_item;
     ASSERT_TRUE(reader);
     ASSERT_STREQ("abc", data_item.id.c_str());
```
```diff
@@ -156,23 +158,24 @@ TEST_F(DataReaderTest, LineDataReader) {
 TEST_F(DataReaderTest, LineDataReader_filename_prefix) {
     std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
     ASSERT_NE(nullptr, data_reader);
-    auto config = YAML::Load("parser:\n"
+    auto config = YAML::Load(
+        "parser:\n"
         "    class: LineDataParser\n"
         "pipeline_cmd: cat\n"
         "done_file: done_file\n"
         "filename_prefix: a");
     ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
     auto data_file_list = data_reader->data_file_list(test_data_dir);
     ASSERT_EQ(1, data_file_list.size());
     ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "a.txt"), data_file_list[0]);

     auto channel = framework::MakeChannel<DataItem>(128);
     ASSERT_NE(nullptr, channel);
     ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
     framework::ChannelReader<DataItem> reader(channel.get());
     DataItem data_item;
     reader >> data_item;
     ASSERT_TRUE(reader);
     ASSERT_STREQ("abc", data_item.id.c_str());
```
```diff
@@ -187,6 +190,84 @@ TEST_F(DataReaderTest, LineDataReader_filename_prefix) {
     ASSERT_FALSE(reader);
 }

+TEST_F(DataReaderTest, LineDataReader_FileSystem) {
+    std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
+    ASSERT_NE(nullptr, data_reader);
+    auto config = YAML::Load(
+        "parser:\n"
+        "    class: LineDataParser\n"
+        "pipeline_cmd: cat\n"
+        "done_file: done_file\n"
+        "filename_prefix: a\n"
+        "file_system:\n"
+        "    class: AutoFileSystem\n"
+        "    file_systems:\n"
+        "        'afs:': &HDFS\n"
+        "            class: HadoopFileSystem\n"
+        "            hdfs_command: 'hadoop fs'\n"
+        "            ugis:\n"
+        "                'default': 'feed_video,D3a0z8'\n"
+        "                'xingtian.afs.baidu.com:9902': 'feed_video,D3a0z8'\n"
+        "\n"
+        "        'hdfs:': *HDFS\n");
+    ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
+    {
+        auto data_file_list = data_reader->data_file_list(test_data_dir);
+        ASSERT_EQ(1, data_file_list.size());
+        ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "a.txt"), data_file_list[0]);
+
+        auto channel = framework::MakeChannel<DataItem>(128);
+        ASSERT_NE(nullptr, channel);
+        ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
+        framework::ChannelReader<DataItem> reader(channel.get());
+        DataItem data_item;
+
+        reader >> data_item;
+        ASSERT_TRUE(reader);
+        ASSERT_STREQ("abc", data_item.id.c_str());
+        ASSERT_STREQ("123456", data_item.data.c_str());
+
+        reader >> data_item;
+        ASSERT_TRUE(reader);
+        ASSERT_STREQ("def", data_item.id.c_str());
+        ASSERT_STREQ("234567", data_item.data.c_str());
+
+        reader >> data_item;
+        ASSERT_FALSE(reader);
+    }
+    {
+        char test_hadoop_dir[] = "afs://xingtian.afs.baidu.com:9902/user/feed_video/user/rensilin/paddle_trainer_test_dir";
+        ASSERT_TRUE(data_reader->is_data_ready(test_hadoop_dir));
+        auto data_file_list = data_reader->data_file_list(test_hadoop_dir);
+        ASSERT_EQ(1, data_file_list.size());
+        ASSERT_EQ(string::format_string("%s/%s", test_hadoop_dir, "a.txt"), data_file_list[0]);
+
+        auto channel = framework::MakeChannel<DataItem>(128);
+        ASSERT_NE(nullptr, channel);
+        ASSERT_EQ(0, data_reader->read_all(test_hadoop_dir, channel));
+        framework::ChannelReader<DataItem> reader(channel.get());
+        DataItem data_item;
+
+        reader >> data_item;
+        ASSERT_TRUE(reader);
+        ASSERT_STREQ("hello", data_item.id.c_str());
+        ASSERT_STREQ("world", data_item.data.c_str());
+
+        reader >> data_item;
+        ASSERT_TRUE(reader);
+        ASSERT_STREQ("hello", data_item.id.c_str());
+        ASSERT_STREQ("hadoop", data_item.data.c_str());
+
+        reader >> data_item;
+        ASSERT_FALSE(reader);
+    }
+}
+
 }  // namespace feed
 }  // namespace custom_trainer
 }  // namespace paddle
```
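The new `LineDataReader_FileSystem` test's config reuses one HDFS definition for both the `afs:` and `hdfs:` prefixes through a YAML anchor (`&HDFS`) and alias (`*HDFS`). A hedged standalone sketch showing yaml-cpp resolving that aliasing at parse time (assuming yaml-cpp is available):

```cpp
// The *HDFS alias resolves to the &HDFS anchor's map, so both prefixes see
// the same class value.
#include <cassert>
#include <string>
#include <yaml-cpp/yaml.h>

int main() {
    auto node = YAML::Load(
        "file_systems:\n"
        "    'afs:': &HDFS\n"
        "        class: HadoopFileSystem\n"
        "    'hdfs:': *HDFS\n");
    assert(node["file_systems"]["hdfs:"]["class"].as<std::string>()
           == "HadoopFileSystem");
}
```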
paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc

```diff
@@ -19,7 +19,8 @@ limitations under the License. */
 #include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/framework/program_desc.h"
-#include "paddle/fluid/framework/io/fs.h"
+#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
+#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
 #include "paddle/fluid/string/string_helper.h"

 namespace paddle {
@@ -37,7 +38,9 @@ class SimpleExecutorTest : public testing::Test
 public:
     static void SetUpTestCase() {
-        ::paddle::framework::localfs_mkdir(test_data_dir);
+        std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
+        fs->mkdir(test_data_dir);
+        shell_set_verbose(true);
         {
             std::unique_ptr<paddle::framework::ProgramDesc> startup_program(
@@ -67,7 +70,8 @@ public:
     static void TearDownTestCase() {
-        ::paddle::framework::localfs_remove(test_data_dir);
+        std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
+        fs->remove(test_data_dir);
     }

     virtual void SetUp()
```