Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
db4a4db0
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
db4a4db0
编写于
8月 16, 2019
作者:
R
rensilin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fs_bug
Change-Id: I7e92af98dc56e18b79640f070b13a26f9c94ab52
上级
6a15698f
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
84 addition
and
65 deletion
+84
-65
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc
...le/fluid/train/custom_trainer/feed/dataset/data_reader.cc
+57
-40
paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader_omp.cc
...rain/custom_trainer/feed/unit_test/test_datareader_omp.cc
+27
-25
未找到文件。
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc
浏览文件 @
db4a4db0
...
...
@@ -131,23 +131,28 @@ public:
return
read_all
(
file_list
,
data_channel
);
}
virtual
int
read_all
(
const
std
::
vector
<
std
::
string
>&
file_list
,
::
paddle
::
framework
::
Channel
<
DataItem
>
data_channel
)
{
auto
deleter
=
[](
framework
::
ChannelWriter
<
DataItem
>
*
writer
)
{
if
(
writer
)
{
writer
->
Flush
();
VLOG
(
3
)
<<
"writer auto flush"
;
}
delete
writer
;
};
std
::
unique_ptr
<
framework
::
ChannelWriter
<
DataItem
>
,
decltype
(
deleter
)
>
writer
(
new
framework
::
ChannelWriter
<
DataItem
>
(
data_channel
.
get
()),
deleter
);
DataItem
data_item
;
int
file_list_size
=
file_list
.
size
();
const
int
file_list_size
=
file_list
.
size
();
std
::
atomic
<
bool
>
is_failed
(
false
);
const
int
max_threads
=
omp_get_max_threads
();
std
::
vector
<
framework
::
ChannelWriter
<
DataItem
>>
writers
;
// writer is not thread safe
writers
.
reserve
(
max_threads
);
for
(
int
i
=
0
;
i
<
max_threads
;
++
i
)
{
writers
.
emplace_back
(
data_channel
.
get
());
}
VLOG
(
5
)
<<
"file_list: "
<<
string
::
join_strings
(
file_list
,
' '
);
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
file_list_size
;
++
i
)
{
if
(
is_failed
)
{
continue
;
}
const
int
thread_num
=
omp_get_thread_num
();
framework
::
ChannelWriter
<
DataItem
>
*
writer
=
nullptr
;
if
(
thread_num
<
max_threads
)
{
writer
=
&
writers
[
thread_num
];
}
const
auto
&
filepath
=
file_list
[
i
];
if
(
!
is_failed
)
{
std
::
shared_ptr
<
FILE
>
fin
=
_file_system
->
open_read
(
filepath
,
_pipeline_cmd
);
if
(
fin
==
nullptr
)
{
VLOG
(
2
)
<<
"fail to open file: "
<<
filepath
<<
", with cmd: "
<<
_pipeline_cmd
;
...
...
@@ -158,16 +163,25 @@ public:
size_t
buffer_size
=
0
;
ssize_t
line_len
=
0
;
while
((
line_len
=
getline
(
&
buffer
,
&
buffer_size
,
fin
.
get
()))
!=
-
1
)
{
// 去掉行位回车
if
(
line_len
>
0
&&
buffer
[
line_len
-
1
]
==
'\n'
)
{
buffer
[
--
line_len
]
=
'\0'
;
}
// 忽略空行
if
(
line_len
<=
0
)
{
continue
;
}
if
(
_parser
->
parse
(
buffer
,
line_len
,
data_item
)
==
0
)
{
VLOG
(
5
)
<<
"parse data: "
<<
data_item
.
id
<<
" "
<<
data_item
.
data
<<
", filename: "
<<
filepath
<<
", thread_num: "
<<
thread_num
<<
", max_threads: "
<<
max_threads
;
if
(
writer
==
nullptr
)
{
if
(
!
data_channel
->
Put
(
std
::
move
(
data_item
)))
{
VLOG
(
2
)
<<
"fail to put data, thread_num: "
<<
thread_num
;
}
}
else
{
(
*
writer
)
<<
std
::
move
(
data_item
);
}
}
}
if
(
buffer
!=
nullptr
)
{
free
(
buffer
);
buffer
=
nullptr
;
...
...
@@ -178,18 +192,21 @@ public:
is_failed
=
true
;
continue
;
}
}
if
(
_file_system
->
err_no
()
!=
0
)
{
_file_system
->
reset_err_no
();
is_failed
=
true
;
continue
;
}
}
writer
->
Flush
();
if
(
!
(
*
writer
))
{
VLOG
(
2
)
<<
"fail when write to channel"
;
// omp end
for
(
int
i
=
0
;
i
<
max_threads
;
++
i
)
{
writers
[
i
].
Flush
();
if
(
!
writers
[
i
])
{
VLOG
(
2
)
<<
"writer "
<<
i
<<
" is failed"
;
is_failed
=
true
;
}
}
data_channel
->
Close
();
return
is_failed
?
-
1
:
0
;
}
...
...
paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader_omp.cc
浏览文件 @
db4a4db0
...
...
@@ -38,6 +38,9 @@ class DataReaderOmpTest : public testing::Test {
public:
static
void
SetUpTestCase
()
{
std
::
unique_ptr
<
FileSystem
>
fs
(
CREATE_CLASS
(
FileSystem
,
"LocalFileSystem"
));
if
(
fs
->
exists
(
test_data_dir
))
{
fs
->
remove
(
test_data_dir
);
}
fs
->
mkdir
(
test_data_dir
);
shell_set_verbose
(
true
);
std_items
.
clear
();
...
...
@@ -92,11 +95,12 @@ public:
}
static
void
read_all
(
framework
::
Channel
<
DataItem
>&
channel
,
std
::
vector
<
DataItem
>&
items
)
{
framework
::
ChannelReader
<
DataItem
>
reader
(
channel
.
get
());
DataItem
data_item
;
while
(
reader
>>
data_item
)
{
items
.
push_back
(
std
::
move
(
data_item
));
}
channel
->
ReadAll
(
items
);
// framework::ChannelReader<DataItem> reader(channel.get());
// DataItem data_item;
// while (reader >> data_item) {
// items.push_back(std::move(data_item));
// }
}
static
bool
is_same_with_std_items
(
const
std
::
vector
<
DataItem
>&
items
)
{
...
...
@@ -107,6 +111,14 @@ public:
return
is_same
(
items
,
sorted_std_items
);
}
static
std
::
string
to_string
(
const
std
::
vector
<
DataItem
>&
items
)
{
std
::
string
items_str
=
""
;
for
(
const
auto
&
item
:
items
)
{
items_str
.
append
(
item
.
id
);
}
return
items_str
;
}
static
std
::
vector
<
DataItem
>
std_items
;
static
std
::
vector
<
DataItem
>
sorted_std_items
;
std
::
shared_ptr
<
TrainerContext
>
context_ptr
;
...
...
@@ -137,7 +149,6 @@ TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) {
ASSERT_EQ
(
string
::
format_string
(
"%s/%s.txt"
,
test_data_dir
,
std_items
[
i
].
id
.
c_str
()),
data_file_list
[
i
]);
}
int
same_count
=
0
;
for
(
int
i
=
0
;
i
<
n_run
;
++
i
)
{
auto
channel
=
framework
::
MakeChannel
<
DataItem
>
(
128
);
ASSERT_NE
(
nullptr
,
channel
);
...
...
@@ -146,13 +157,8 @@ TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) {
std
::
vector
<
DataItem
>
items
;
read_all
(
channel
,
items
);
if
(
is_same_with_std_items
(
items
))
{
++
same_count
;
}
ASSERT_TRUE
(
is_same_with_std_items
(
items
));
}
// n_run 次都相同
ASSERT_EQ
(
n_run
,
same_count
);
}
TEST_F
(
DataReaderOmpTest
,
LineDataReaderMuiltThread
)
{
...
...
@@ -188,36 +194,32 @@ TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) {
omp_set_num_threads
(
4
);
channel
->
SetBlockSize
(
1
);
ASSERT_EQ
(
0
,
data_reader
->
read_all
(
test_data_dir
,
channel
));
std
::
vector
<
DataItem
>
items
;
read_all
(
channel
,
items
);
ASSERT_EQ
(
std_items_size
,
items
.
size
());
if
(
is_same_with_std_items
(
items
))
{
++
same_count
;
}
VLOG
(
5
)
<<
"before sort items: "
<<
to_string
(
items
);
std
::
sort
(
items
.
begin
(),
items
.
end
(),
[]
(
const
DataItem
&
a
,
const
DataItem
&
b
)
{
return
a
.
id
<
b
.
id
;
});
if
(
is_same_with_sorted_std_items
(
items
))
{
++
sort_same_count
;
}
else
{
std
::
string
items_str
=
""
;
for
(
const
auto
&
item
:
items
)
{
items_str
.
append
(
item
.
id
);
bool
is_same_with_std
=
is_same_with_sorted_std_items
(
items
);
if
(
!
is_same_with_std
)
{
VLOG
(
5
)
<<
"after sort items: "
<<
to_string
(
items
);
}
VLOG
(
2
)
<<
"items: "
<<
items_str
;
}
// 排序后都是相同的
ASSERT_TRUE
(
is_same_with_std
);
}
// n_run次有不同的(证明是多线程)
ASSERT_EQ
(
4
,
omp_get_max_threads
());
ASSERT_GT
(
n_run
,
same_count
);
// 但排序后都是相同的
ASSERT_EQ
(
n_run
,
sort_same_count
);
}
}
// namespace feed
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录