Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleRec
提交
379235f4
P
PaddleRec
项目概览
PaddlePaddle
/
PaddleRec
通知
68
Star
12
Fork
5
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
27
列表
看板
标记
里程碑
合并请求
10
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
27
Issue
27
列表
看板
标记
里程碑
合并请求
10
合并请求
10
Pages
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
379235f4
编写于
8月 06, 2019
作者:
R
rensilin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
file_system_ut
Change-Id: I96c0bc535a0f49a92e8b987df5cc06c5eca4758e
上级
501c9f25
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
238 addition
and
88 deletion
+238
-88
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc
...le/fluid/train/custom_trainer/feed/dataset/data_reader.cc
+64
-29
paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc
...le/fluid/train/custom_trainer/feed/io/auto_file_system.cc
+2
-2
paddle/fluid/train/custom_trainer/feed/io/hadoop_file_system.cc
.../fluid/train/custom_trainer/feed/io/hadoop_file_system.cc
+53
-23
paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc
...id/train/custom_trainer/feed/unit_test/test_datareader.cc
+112
-31
paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc
...luid/train/custom_trainer/feed/unit_test/test_executor.cc
+7
-3
未找到文件。
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc
浏览文件 @
379235f4
...
...
@@ -4,13 +4,13 @@
#include <glog/logging.h>
#include "paddle/fluid/
framework/io/fs
.h"
#include "paddle/fluid/
train/custom_trainer/feed/io/file_system
.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
class
LineDataParser
:
public
DataParser
{
class
LineDataParser
:
public
DataParser
{
public:
LineDataParser
()
{}
...
...
@@ -29,7 +29,7 @@ public:
VLOG
(
2
)
<<
"fail to parse line: "
<<
std
::
string
(
str
,
len
)
<<
", strlen: "
<<
len
;
return
-
1
;
}
VLOG
(
5
)
<<
"getline: "
<<
str
<<
" , pos: "
<<
pos
<<
", len: "
<<
len
;
VLOG
(
5
)
<<
"getline: "
<<
str
<<
" , pos: "
<<
pos
<<
", len: "
<<
len
;
data
.
id
.
assign
(
str
,
pos
);
data
.
data
.
assign
(
str
+
pos
+
1
,
len
-
pos
-
1
);
if
(
!
data
.
data
.
empty
()
&&
data
.
data
.
back
()
==
'\n'
)
{
...
...
@@ -47,7 +47,7 @@ public:
VLOG
(
2
)
<<
"fail to parse line: "
<<
str
<<
", get '
\\
0' at pos: "
<<
pos
;
return
-
1
;
}
VLOG
(
5
)
<<
"getline: "
<<
str
<<
" , pos: "
<<
pos
;
VLOG
(
5
)
<<
"getline: "
<<
str
<<
" , pos: "
<<
pos
;
data
.
id
.
assign
(
str
,
pos
);
data
.
data
.
assign
(
str
+
pos
+
1
);
if
(
!
data
.
data
.
empty
()
&&
data
.
data
.
back
()
==
'\n'
)
{
...
...
@@ -88,13 +88,30 @@ public:
_buffer_size
=
config
[
"buffer_size"
].
as
<
int
>
(
1024
);
_filename_prefix
=
config
[
"filename_prefix"
].
as
<
std
::
string
>
(
""
);
_buffer
.
reset
(
new
char
[
_buffer_size
]);
if
(
config
[
"file_system"
]
&&
config
[
"file_system"
][
"class"
])
{
_file_system
.
reset
(
CREATE_CLASS
(
FileSystem
,
config
[
"file_system"
][
"class"
].
as
<
std
::
string
>
()));
if
(
_file_system
==
nullptr
||
_file_system
->
initialize
(
config
[
"file_system"
],
context
)
!=
0
)
{
VLOG
(
2
)
<<
"fail to create class: "
<<
config
[
"file_system"
][
"class"
].
as
<
std
::
string
>
();
return
-
1
;
}
}
else
{
_file_system
.
reset
(
CREATE_CLASS
(
FileSystem
,
"LocalFileSystem"
));
if
(
_file_system
==
nullptr
||
_file_system
->
initialize
(
YAML
::
Load
(
""
),
context
)
!=
0
)
{
VLOG
(
2
)
<<
"fail to init file system"
;
return
-
1
;
}
}
return
0
;
}
//判断样本数据是否已就绪,就绪表明可以开始download
virtual
bool
is_data_ready
(
const
std
::
string
&
data_dir
)
{
auto
done_file_path
=
framework
::
fs_
path_join
(
data_dir
,
_done_file_name
);
if
(
framework
::
fs_
exists
(
done_file_path
))
{
auto
done_file_path
=
_file_system
->
path_join
(
data_dir
,
_done_file_name
);
if
(
_file_system
->
exists
(
done_file_path
))
{
return
true
;
}
return
false
;
...
...
@@ -102,12 +119,13 @@ public:
virtual
std
::
vector
<
std
::
string
>
data_file_list
(
const
std
::
string
&
data_dir
)
{
if
(
_filename_prefix
.
empty
())
{
return
framework
::
fs_
list
(
data_dir
);
return
_file_system
->
list
(
data_dir
);
}
std
::
vector
<
std
::
string
>
data_files
;
for
(
auto
&
filepath
:
framework
::
fs_list
(
data_dir
))
{
auto
filename
=
framework
::
fs_path_split
(
filepath
).
second
;
if
(
filename
.
size
()
>=
_filename_prefix
.
size
()
&&
filename
.
substr
(
0
,
_filename_prefix
.
size
())
==
_filename_prefix
)
{
for
(
auto
&
filepath
:
_file_system
->
list
(
data_dir
))
{
auto
filename
=
_file_system
->
path_split
(
filepath
).
second
;
if
(
filename
.
size
()
>=
_filename_prefix
.
size
()
&&
filename
.
substr
(
0
,
_filename_prefix
.
size
())
==
_filename_prefix
)
{
data_files
.
push_back
(
std
::
move
(
filepath
));
}
}
...
...
@@ -116,35 +134,50 @@ public:
//读取数据样本流中
virtual
int
read_all
(
const
std
::
string
&
data_dir
,
framework
::
Channel
<
DataItem
>
data_channel
)
{
framework
::
ChannelWriter
<
DataItem
>
writer
(
data_channel
.
get
());
auto
deleter
=
[](
framework
::
ChannelWriter
<
DataItem
>
*
writer
)
{
if
(
writer
)
{
writer
->
Flush
();
VLOG
(
3
)
<<
"writer auto flush"
;
}
delete
writer
;
};
std
::
unique_ptr
<
framework
::
ChannelWriter
<
DataItem
>
,
decltype
(
deleter
)
>
writer
(
new
framework
::
ChannelWriter
<
DataItem
>
(
data_channel
.
get
()),
deleter
);
DataItem
data_item
;
if
(
_buffer_size
<=
0
||
_buffer
==
nullptr
)
{
VLOG
(
2
)
<<
"no buffer"
;
return
-
1
;
}
for
(
const
auto
&
filepath
:
data_file_list
(
data_dir
))
{
if
(
framework
::
fs_
path_split
(
filepath
).
second
==
_done_file_name
)
{
if
(
_file_system
->
path_split
(
filepath
).
second
==
_done_file_name
)
{
continue
;
}
int
err_no
=
0
;
std
::
shared_ptr
<
FILE
>
fin
=
framework
::
fs_open_read
(
filepath
,
&
err_no
,
_pipeline_cmd
);
if
(
err_no
!=
0
)
{
VLOG
(
2
)
<<
"fail to open file: "
<<
filepath
<<
", with cmd: "
<<
_pipeline_cmd
;
return
-
1
;
}
while
(
fgets
(
_buffer
.
get
(),
_buffer_size
,
fin
.
get
()))
{
if
(
_parser
->
parse
(
_buffer
.
get
(),
data_item
)
!=
0
)
{
{
std
::
shared_ptr
<
FILE
>
fin
=
_file_system
->
open_read
(
filepath
,
_pipeline_cmd
);
if
(
fin
==
nullptr
)
{
VLOG
(
2
)
<<
"fail to open file: "
<<
filepath
<<
", with cmd: "
<<
_pipeline_cmd
;
return
-
1
;
}
while
(
fgets
(
_buffer
.
get
(),
_buffer_size
,
fin
.
get
()))
{
if
(
_buffer
[
0
]
==
'\n'
)
{
continue
;
}
if
(
_parser
->
parse
(
_buffer
.
get
(),
data_item
)
!=
0
)
{
return
-
1
;
}
(
*
writer
)
<<
std
::
move
(
data_item
);
}
if
(
ferror
(
fin
.
get
())
!=
0
)
{
VLOG
(
2
)
<<
"fail to read file: "
<<
filepath
;
return
-
1
;
}
writer
<<
std
::
move
(
data_item
);
}
if
(
ferror
(
fin
.
get
())
!=
0
)
{
VLOG
(
2
)
<<
"fail to read file: "
<<
filepath
;
if
(
!
_file_system
)
{
_file_system
->
reset_err_no
()
;
return
-
1
;
}
}
writer
.
Flush
();
if
(
!
writer
)
{
writer
->
Flush
();
if
(
!
(
*
writer
)
)
{
VLOG
(
2
)
<<
"fail when write to channel"
;
return
-
1
;
}
...
...
@@ -155,14 +188,16 @@ public:
virtual
const
DataParser
*
get_parser
()
{
return
_parser
.
get
();
}
private:
std
::
string
_done_file_name
;
// without data_dir
std
::
string
_done_file_name
;
// without data_dir
int
_buffer_size
=
0
;
std
::
unique_ptr
<
char
[]
>
_buffer
;
std
::
string
_filename_prefix
;
std
::
unique_ptr
<
FileSystem
>
_file_system
;
};
REGISTER_CLASS
(
DataReader
,
LineDataReader
);
}
//
namespace feed
}
//
namespace custom_trainer
}
//
namespace paddle
}
//
namespace feed
}
//
namespace custom_trainer
}
//
namespace paddle
paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc
浏览文件 @
379235f4
...
...
@@ -29,8 +29,8 @@ class AutoFileSystem : public FileSystem {
public:
int
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
)
override
{
_file_system
.
clear
();
if
(
config
)
{
for
(
auto
&
prefix_fs
:
config
)
{
if
(
config
&&
config
[
"file_systems"
]
&&
config
[
"file_systems"
].
Type
()
==
YAML
::
NodeType
::
Map
)
{
for
(
auto
&
prefix_fs
:
config
[
"file_systems"
]
)
{
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_CLASS
(
FileSystem
,
prefix_fs
.
second
[
"class"
].
as
<
std
::
string
>
(
""
)));
if
(
fs
==
nullptr
)
{
VLOG
(
2
)
<<
"fail to create class: "
<<
prefix_fs
.
second
[
"class"
].
as
<
std
::
string
>
(
""
);
...
...
paddle/fluid/train/custom_trainer/feed/io/hadoop_file_system.cc
浏览文件 @
379235f4
...
...
@@ -16,9 +16,11 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include <tuple>
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/string/piece.h"
#include "glog/logging.h"
namespace
paddle
{
...
...
@@ -31,8 +33,10 @@ public:
_buffer_size
=
config
[
"buffer_size"
].
as
<
size_t
>
(
0
);
_hdfs_command
=
config
[
"hdfs_command"
].
as
<
std
::
string
>
(
"hadoop fs"
);
_ugi
.
clear
();
for
(
const
auto
&
prefix_ugi
:
config
[
"ugi"
])
{
_ugi
.
emplace
(
prefix_ugi
.
first
.
as
<
std
::
string
>
(),
prefix_ugi
.
second
.
as
<
std
::
string
>
());
if
(
config
[
"ugis"
]
&&
config
[
"ugis"
].
Type
()
==
YAML
::
NodeType
::
Map
)
{
for
(
const
auto
&
prefix_ugi
:
config
[
"ugis"
])
{
_ugi
.
emplace
(
prefix_ugi
.
first
.
as
<
std
::
string
>
(),
prefix_ugi
.
second
.
as
<
std
::
string
>
());
}
}
if
(
_ugi
.
find
(
"default"
)
==
_ugi
.
end
())
{
VLOG
(
2
)
<<
"fail to load default ugi"
;
...
...
@@ -48,8 +52,7 @@ public:
cmd
=
string
::
format_string
(
"%s -text
\"
%s
\"
"
,
hdfs_command
(
path
).
c_str
(),
path
.
c_str
());
}
else
{
cmd
=
string
::
format_string
(
"%s -cat
\"
%s
\"
"
,
hdfs_command
(
path
).
c_str
(),
path
.
c_str
());
cmd
=
string
::
format_string
(
"%s -cat
\"
%s
\"
"
,
hdfs_command
(
path
).
c_str
(),
path
.
c_str
());
}
bool
is_pipe
=
true
;
...
...
@@ -59,7 +62,8 @@ public:
std
::
shared_ptr
<
FILE
>
open_write
(
const
std
::
string
&
path
,
const
std
::
string
&
converter
)
override
{
std
::
string
cmd
=
string
::
format_string
(
"%s -put -
\"
%s
\"
"
,
hdfs_command
(
path
).
c_str
(),
path
.
c_str
());
std
::
string
cmd
=
string
::
format_string
(
"%s -put -
\"
%s
\"
"
,
hdfs_command
(
path
).
c_str
(),
path
.
c_str
());
bool
is_pipe
=
true
;
if
(
string
::
end_with
(
path
,
".gz
\"
"
))
{
...
...
@@ -89,12 +93,8 @@ public:
if
(
path
==
""
)
{
return
{};
}
auto
paths
=
_split_path
(
path
);
std
::
string
prefix
=
"hdfs:"
;
if
(
string
::
begin_with
(
path
,
"afs:"
))
{
prefix
=
"afs:"
;
}
int
err_no
=
0
;
std
::
vector
<
std
::
string
>
list
;
do
{
...
...
@@ -115,7 +115,7 @@ public:
if
(
line
.
size
()
!=
8
)
{
continue
;
}
list
.
push_back
(
prefix
+
line
[
7
]);
list
.
push_back
(
_get_prefix
(
paths
)
+
line
[
7
]);
}
}
while
(
err_no
==
-
1
);
return
list
;
...
...
@@ -146,30 +146,60 @@ public:
return
;
}
shell_execute
(
string
::
format_string
(
"%s -mkdir %s; true"
,
hdfs_command
(
path
).
c_str
(),
path
.
c_str
()));
shell_execute
(
string
::
format_string
(
"%s -mkdir %s; true"
,
hdfs_command
(
path
).
c_str
(),
path
.
c_str
()));
}
std
::
string
hdfs_command
(
const
std
::
string
&
path
)
{
auto
start_pos
=
path
.
find_first_of
(
':'
);
auto
end_pos
=
path
.
find_first_of
(
'/'
);
if
(
start_pos
!=
std
::
string
::
npos
&&
end_pos
!=
std
::
string
::
npos
&&
start_pos
<
end_pos
)
{
auto
fs_path
=
path
.
substr
(
start_pos
+
1
,
end_pos
-
start_pos
-
1
);
auto
ugi_it
=
_ugi
.
find
(
fs_path
);
if
(
ugi_it
!=
_ugi
.
end
())
{
return
hdfs_command_with_ugi
(
ugi_it
->
second
);
}
auto
paths
=
_split_path
(
path
);
auto
it
=
_ugi
.
find
(
std
::
get
<
1
>
(
paths
).
ToString
());
if
(
it
!=
_ugi
.
end
())
{
return
hdfs_command_with_ugi
(
it
->
second
);
}
VLOG
(
5
)
<<
"path: "
<<
path
<<
", select default ugi"
;
return
hdfs_command_with_ugi
(
_ugi
[
"default"
]);
}
std
::
string
hdfs_command_with_ugi
(
std
::
string
ugi
)
{
return
string
::
format_string
(
"%s -Dhadoop.job.ugi=
\"
%s
\"
"
,
_hdfs_command
.
c_str
(),
ugi
.
c_str
());
return
string
::
format_string
(
"%s -Dhadoop.job.ugi=
\"
%s
\"
"
,
_hdfs_command
.
c_str
(),
ugi
.
c_str
());
}
private:
std
::
string
_get_prefix
(
const
std
::
tuple
<
string
::
Piece
,
string
::
Piece
,
string
::
Piece
>&
paths
)
{
if
(
std
::
get
<
1
>
(
paths
).
len
()
==
0
)
{
return
std
::
get
<
0
>
(
paths
).
ToString
();
}
return
std
::
get
<
0
>
(
paths
).
ToString
()
+
"//"
+
std
::
get
<
1
>
(
paths
).
ToString
();
}
std
::
tuple
<
string
::
Piece
,
string
::
Piece
,
string
::
Piece
>
_split_path
(
string
::
Piece
path
)
{
// parse "xxx://abc.def:8756/user" as "xxx:", "abc.def:8756", "/user"
// parse "xxx:/user" as "xxx:", "", "/user"
// parse "xxx://abc.def:8756" as "xxx:", "abc.def:8756", ""
// parse "other" as "", "", "other"
std
::
tuple
<
string
::
Piece
,
string
::
Piece
,
string
::
Piece
>
result
{
string
::
SubStr
(
path
,
0
,
0
),
string
::
SubStr
(
path
,
0
,
0
),
path
};
auto
fs_pos
=
string
::
Find
(
path
,
':'
,
0
)
+
1
;
if
(
path
.
len
()
>
fs_pos
)
{
std
::
get
<
0
>
(
result
)
=
string
::
SubStr
(
path
,
0
,
fs_pos
);
path
=
string
::
SkipPrefix
(
path
,
fs_pos
);
if
(
string
::
HasPrefix
(
path
,
"//"
))
{
path
=
string
::
SkipPrefix
(
path
,
2
);
auto
end_pos
=
string
::
Find
(
path
,
'/'
,
0
);
if
(
end_pos
!=
string
::
Piece
::
npos
)
{
std
::
get
<
1
>
(
result
)
=
string
::
SubStr
(
path
,
0
,
end_pos
);
std
::
get
<
2
>
(
result
)
=
string
::
SkipPrefix
(
path
,
end_pos
);
}
else
{
std
::
get
<
1
>
(
result
)
=
path
;
}
}
else
{
std
::
get
<
2
>
(
result
)
=
path
;
}
}
return
result
;
}
size_t
_buffer_size
=
0
;
std
::
string
_hdfs_command
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
_ugi
;
...
...
paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc
浏览文件 @
379235f4
...
...
@@ -19,7 +19,8 @@ limitations under the License. */
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
...
...
@@ -31,50 +32,50 @@ namespace {
const
char
test_data_dir
[]
=
"test_data"
;
}
class
DataReaderTest
:
public
testing
::
Test
{
class
DataReaderTest
:
public
testing
::
Test
{
public:
static
void
SetUpTestCase
()
{
f
ramework
::
shell_set_verbose
(
true
);
framework
::
localfs_mkdir
(
test_data_dir
);
static
void
SetUpTestCase
()
{
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_CLASS
(
FileSystem
,
"LocalFileSystem"
));
f
s
->
mkdir
(
test_data_dir
);
shell_set_verbose
(
true
);
{
std
::
ofstream
fout
(
f
ramework
::
fs_
path_join
(
test_data_dir
,
"a.txt"
));
std
::
ofstream
fout
(
f
s
->
path_join
(
test_data_dir
,
"a.txt"
));
fout
<<
"abc 123456"
<<
std
::
endl
;
fout
<<
"def 234567"
<<
std
::
endl
;
fout
.
close
();
}
{
std
::
ofstream
fout
(
f
ramework
::
fs_
path_join
(
test_data_dir
,
"b.txt"
));
std
::
ofstream
fout
(
f
s
->
path_join
(
test_data_dir
,
"b.txt"
));
fout
<<
"ghi 345678"
<<
std
::
endl
;
fout
<<
"jkl 456789"
<<
std
::
endl
;
fout
.
close
();
}
}
static
void
TearDownTestCase
()
{
f
ramework
::
localfs_
remove
(
test_data_dir
);
static
void
TearDownTestCase
()
{
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_CLASS
(
FileSystem
,
"LocalFileSystem"
));
f
s
->
remove
(
test_data_dir
);
}
virtual
void
SetUp
()
{
virtual
void
SetUp
()
{
fs
.
reset
(
CREATE_CLASS
(
FileSystem
,
"LocalFileSystem"
));
context_ptr
.
reset
(
new
TrainerContext
());
}
virtual
void
TearDown
()
{
virtual
void
TearDown
()
{
fs
=
nullptr
;
context_ptr
=
nullptr
;
}
std
::
shared_ptr
<
TrainerContext
>
context_ptr
;
std
::
unique_ptr
<
FileSystem
>
fs
;
};
TEST_F
(
DataReaderTest
,
LineDataParser
)
{
std
::
unique_ptr
<
DataParser
>
data_parser
(
CREATE_CLASS
(
DataParser
,
"LineDataParser"
));
ASSERT_NE
(
nullptr
,
data_parser
);
auto
config
=
YAML
::
Load
(
""
);
...
...
@@ -105,11 +106,12 @@ TEST_F(DataReaderTest, LineDataReader) {
std
::
unique_ptr
<
DataReader
>
data_reader
(
CREATE_CLASS
(
DataReader
,
"LineDataReader"
));
ASSERT_NE
(
nullptr
,
data_reader
);
auto
config
=
YAML
::
Load
(
"parser:
\n
"
" class: LineDataParser
\n
"
"pipeline_cmd: cat
\n
"
"done_file: done_file
\n
"
"buffer_size: 128"
);
auto
config
=
YAML
::
Load
(
"parser:
\n
"
" class: LineDataParser
\n
"
"pipeline_cmd: cat
\n
"
"done_file: done_file
\n
"
"buffer_size: 128"
);
ASSERT_EQ
(
0
,
data_reader
->
initialize
(
config
,
context_ptr
));
auto
data_file_list
=
data_reader
->
data_file_list
(
test_data_dir
);
ASSERT_EQ
(
2
,
data_file_list
.
size
());
...
...
@@ -117,7 +119,7 @@ TEST_F(DataReaderTest, LineDataReader) {
ASSERT_EQ
(
string
::
format_string
(
"%s/%s"
,
test_data_dir
,
"b.txt"
),
data_file_list
[
1
]);
ASSERT_FALSE
(
data_reader
->
is_data_ready
(
test_data_dir
));
std
::
ofstream
fout
(
f
ramework
::
fs_
path_join
(
test_data_dir
,
"done_file"
));
std
::
ofstream
fout
(
f
s
->
path_join
(
test_data_dir
,
"done_file"
));
fout
<<
"done"
;
fout
.
close
();
ASSERT_TRUE
(
data_reader
->
is_data_ready
(
test_data_dir
));
...
...
@@ -128,7 +130,7 @@ TEST_F(DataReaderTest, LineDataReader) {
framework
::
ChannelReader
<
DataItem
>
reader
(
channel
.
get
());
DataItem
data_item
;
reader
>>
data_item
;
ASSERT_TRUE
(
reader
);
ASSERT_STREQ
(
"abc"
,
data_item
.
id
.
c_str
());
...
...
@@ -156,23 +158,24 @@ TEST_F(DataReaderTest, LineDataReader) {
TEST_F
(
DataReaderTest
,
LineDataReader_filename_prefix
)
{
std
::
unique_ptr
<
DataReader
>
data_reader
(
CREATE_CLASS
(
DataReader
,
"LineDataReader"
));
ASSERT_NE
(
nullptr
,
data_reader
);
auto
config
=
YAML
::
Load
(
"parser:
\n
"
" class: LineDataParser
\n
"
"pipeline_cmd: cat
\n
"
"done_file: done_file
\n
"
"filename_prefix: a"
);
auto
config
=
YAML
::
Load
(
"parser:
\n
"
" class: LineDataParser
\n
"
"pipeline_cmd: cat
\n
"
"done_file: done_file
\n
"
"filename_prefix: a"
);
ASSERT_EQ
(
0
,
data_reader
->
initialize
(
config
,
context_ptr
));
auto
data_file_list
=
data_reader
->
data_file_list
(
test_data_dir
);
ASSERT_EQ
(
1
,
data_file_list
.
size
());
ASSERT_EQ
(
string
::
format_string
(
"%s/%s"
,
test_data_dir
,
"a.txt"
),
data_file_list
[
0
]);
auto
channel
=
framework
::
MakeChannel
<
DataItem
>
(
128
);
ASSERT_NE
(
nullptr
,
channel
);
ASSERT_EQ
(
0
,
data_reader
->
read_all
(
test_data_dir
,
channel
));
framework
::
ChannelReader
<
DataItem
>
reader
(
channel
.
get
());
DataItem
data_item
;
reader
>>
data_item
;
ASSERT_TRUE
(
reader
);
ASSERT_STREQ
(
"abc"
,
data_item
.
id
.
c_str
());
...
...
@@ -187,6 +190,84 @@ TEST_F(DataReaderTest, LineDataReader_filename_prefix) {
ASSERT_FALSE
(
reader
);
}
TEST_F
(
DataReaderTest
,
LineDataReader_FileSystem
)
{
std
::
unique_ptr
<
DataReader
>
data_reader
(
CREATE_CLASS
(
DataReader
,
"LineDataReader"
));
ASSERT_NE
(
nullptr
,
data_reader
);
auto
config
=
YAML
::
Load
(
"parser:
\n
"
" class: LineDataParser
\n
"
"pipeline_cmd: cat
\n
"
"done_file: done_file
\n
"
"filename_prefix: a
\n
"
"file_system:
\n
"
" class: AutoFileSystem
\n
"
" file_systems:
\n
"
" 'afs:': &HDFS
\n
"
" class: HadoopFileSystem
\n
"
" hdfs_command: 'hadoop fs'
\n
"
" ugis:
\n
"
" 'default': 'feed_video,D3a0z8'
\n
"
" 'xingtian.afs.baidu.com:9902': 'feed_video,D3a0z8'
\n
"
"
\n
"
" 'hdfs:': *HDFS
\n
"
);
ASSERT_EQ
(
0
,
data_reader
->
initialize
(
config
,
context_ptr
));
{
auto
data_file_list
=
data_reader
->
data_file_list
(
test_data_dir
);
ASSERT_EQ
(
1
,
data_file_list
.
size
());
ASSERT_EQ
(
string
::
format_string
(
"%s/%s"
,
test_data_dir
,
"a.txt"
),
data_file_list
[
0
]);
auto
channel
=
framework
::
MakeChannel
<
DataItem
>
(
128
);
ASSERT_NE
(
nullptr
,
channel
);
ASSERT_EQ
(
0
,
data_reader
->
read_all
(
test_data_dir
,
channel
));
framework
::
ChannelReader
<
DataItem
>
reader
(
channel
.
get
());
DataItem
data_item
;
reader
>>
data_item
;
ASSERT_TRUE
(
reader
);
ASSERT_STREQ
(
"abc"
,
data_item
.
id
.
c_str
());
ASSERT_STREQ
(
"123456"
,
data_item
.
data
.
c_str
());
reader
>>
data_item
;
ASSERT_TRUE
(
reader
);
ASSERT_STREQ
(
"def"
,
data_item
.
id
.
c_str
());
ASSERT_STREQ
(
"234567"
,
data_item
.
data
.
c_str
());
reader
>>
data_item
;
ASSERT_FALSE
(
reader
);
}
{
char
test_hadoop_dir
[]
=
"afs://xingtian.afs.baidu.com:9902/user/feed_video/user/rensilin/paddle_trainer_test_dir"
;
ASSERT_TRUE
(
data_reader
->
is_data_ready
(
test_hadoop_dir
));
auto
data_file_list
=
data_reader
->
data_file_list
(
test_hadoop_dir
);
ASSERT_EQ
(
1
,
data_file_list
.
size
());
ASSERT_EQ
(
string
::
format_string
(
"%s/%s"
,
test_hadoop_dir
,
"a.txt"
),
data_file_list
[
0
]);
auto
channel
=
framework
::
MakeChannel
<
DataItem
>
(
128
);
ASSERT_NE
(
nullptr
,
channel
);
ASSERT_EQ
(
0
,
data_reader
->
read_all
(
test_hadoop_dir
,
channel
));
framework
::
ChannelReader
<
DataItem
>
reader
(
channel
.
get
());
DataItem
data_item
;
reader
>>
data_item
;
ASSERT_TRUE
(
reader
);
ASSERT_STREQ
(
"hello"
,
data_item
.
id
.
c_str
());
ASSERT_STREQ
(
"world"
,
data_item
.
data
.
c_str
());
reader
>>
data_item
;
ASSERT_TRUE
(
reader
);
ASSERT_STREQ
(
"hello"
,
data_item
.
id
.
c_str
());
ASSERT_STREQ
(
"hadoop"
,
data_item
.
data
.
c_str
());
reader
>>
data_item
;
ASSERT_FALSE
(
reader
);
}
}
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc
浏览文件 @
379235f4
...
...
@@ -19,7 +19,8 @@ limitations under the License. */
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
namespace
paddle
{
...
...
@@ -37,7 +38,9 @@ class SimpleExecutorTest : public testing::Test
public:
static
void
SetUpTestCase
()
{
::
paddle
::
framework
::
localfs_mkdir
(
test_data_dir
);
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_CLASS
(
FileSystem
,
"LocalFileSystem"
));
fs
->
mkdir
(
test_data_dir
);
shell_set_verbose
(
true
);
{
std
::
unique_ptr
<
paddle
::
framework
::
ProgramDesc
>
startup_program
(
...
...
@@ -67,7 +70,8 @@ public:
static
void
TearDownTestCase
()
{
::
paddle
::
framework
::
localfs_remove
(
test_data_dir
);
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_CLASS
(
FileSystem
,
"LocalFileSystem"
));
fs
->
remove
(
test_data_dir
);
}
virtual
void
SetUp
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录