Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
803e2ed9
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
803e2ed9
编写于
10月 19, 2018
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add ctr_reader_test and fix bug
上级
c8bd5210
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
108 addition
and
22 deletion
+108
-22
paddle/fluid/operators/reader/CMakeLists.txt
paddle/fluid/operators/reader/CMakeLists.txt
+1
-0
paddle/fluid/operators/reader/ctr_reader.cc
paddle/fluid/operators/reader/ctr_reader.cc
+51
-17
paddle/fluid/operators/reader/ctr_reader.h
paddle/fluid/operators/reader/ctr_reader.h
+11
-5
paddle/fluid/operators/reader/ctr_reader_test.cc
paddle/fluid/operators/reader/ctr_reader_test.cc
+45
-0
未找到文件。
paddle/fluid/operators/reader/CMakeLists.txt
浏览文件 @
803e2ed9
...
@@ -17,6 +17,7 @@ endfunction()
...
@@ -17,6 +17,7 @@ endfunction()
cc_library
(
buffered_reader SRCS buffered_reader.cc DEPS reader simple_threadpool
)
cc_library
(
buffered_reader SRCS buffered_reader.cc DEPS reader simple_threadpool
)
cc_library
(
ctr_reader SRCS ctr_reader.cc DEPS reader simple_threadpool boost gzstream
)
cc_library
(
ctr_reader SRCS ctr_reader.cc DEPS reader simple_threadpool boost gzstream
)
cc_test
(
ctr_reader_test SRCS ctr_reader_test.cc DEPS ctr_reader
)
reader_library
(
open_files_op SRCS open_files_op.cc DEPS buffered_reader
)
reader_library
(
open_files_op SRCS open_files_op.cc DEPS buffered_reader
)
reader_library
(
create_ctr_reader_op SRCS create_ctr_reader_op.cc DEPS ctr_reader
)
reader_library
(
create_ctr_reader_op SRCS create_ctr_reader_op.cc DEPS ctr_reader
)
reader_library
(
create_random_data_generator_op SRCS create_random_data_generator_op.cc
)
reader_library
(
create_random_data_generator_op SRCS create_random_data_generator_op.cc
)
...
...
paddle/fluid/operators/reader/ctr_reader.cc
浏览文件 @
803e2ed9
...
@@ -46,32 +46,47 @@ static inline void string_split(const std::string& s, const char delimiter,
...
@@ -46,32 +46,47 @@ static inline void string_split(const std::string& s, const char delimiter,
}
}
static
inline
void
parse_line
(
static
inline
void
parse_line
(
const
std
::
string
&
line
,
const
std
::
vector
<
std
::
string
>&
slots
,
const
std
::
string
&
line
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
slot_to_index
,
int64_t
*
label
,
int64_t
*
label
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
int64_t
>>*
slot
s
_to_data
)
{
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
int64_t
>>*
slot_to_data
)
{
std
::
vector
<
std
::
string
>
ret
;
std
::
vector
<
std
::
string
>
ret
;
string_split
(
line
,
' '
,
&
ret
);
string_split
(
line
,
' '
,
&
ret
);
*
label
=
std
::
stoi
(
ret
[
2
])
>
0
;
*
label
=
std
::
stoi
(
ret
[
2
])
>
0
;
for
(
size_t
i
=
3
;
i
<
ret
.
size
();
++
i
)
{
for
(
size_t
i
=
3
;
i
<
ret
.
size
();
++
i
)
{
const
std
::
string
&
item
=
ret
[
i
];
const
std
::
string
&
item
=
ret
[
i
];
std
::
vector
<
std
::
string
>
slot_and_feasign
;
std
::
vector
<
std
::
string
>
feasign_and_slot
;
string_split
(
item
,
':'
,
&
slot_and_feasign
);
string_split
(
item
,
':'
,
&
feasign_and_slot
);
if
(
slot_and_feasign
.
size
()
==
2
)
{
auto
&
slot
=
feasign_and_slot
[
1
];
const
std
::
string
&
slot
=
slot_and_feasign
[
1
];
if
(
feasign_and_slot
.
size
()
==
2
&&
int64_t
feasign
=
std
::
strtoll
(
slot_and_feasign
[
0
].
c_str
(),
NULL
,
10
);
slot_to_index
.
find
(
slot
)
!=
slot_to_index
.
end
())
{
(
*
slots_to_data
)[
slot_and_feasign
[
1
]].
push_back
(
feasign
);
const
std
::
string
&
slot
=
feasign_and_slot
[
1
];
int64_t
feasign
=
std
::
strtoll
(
feasign_and_slot
[
0
].
c_str
(),
NULL
,
10
);
(
*
slot_to_data
)[
feasign_and_slot
[
1
]].
push_back
(
feasign
);
}
}
}
}
// NOTE: if the slot has no value, then fill [0] as its data.
// NOTE: if the slot has no value, then fill [0] as its data.
for
(
auto
&
slot
:
slots
)
{
for
(
auto
&
item
:
slot_to_index
)
{
if
(
slot
s_to_data
->
find
(
slot
)
==
slots
_to_data
->
end
())
{
if
(
slot
_to_data
->
find
(
item
.
first
)
==
slot
_to_data
->
end
())
{
(
*
slot
s_to_data
)[
slo
t
].
push_back
(
0
);
(
*
slot
_to_data
)[
item
.
firs
t
].
push_back
(
0
);
}
}
}
}
}
}
// Debug helper: writes every entry of `map` to std::cout, one entry per line,
// formatted as `<key> -> [<v0> <v1> ... ]`.
//
// Each value is followed by a single space, so a non-empty list ends with a
// space before the closing bracket and an empty list prints as `[]`.
// Entries are emitted in the container's iteration order.
// A null `map` prints nothing.
static void print_map(
    std::unordered_map<std::string, std::vector<int64_t>>* map) {
  if (map == nullptr) {
    return;
  }
  for (const auto& entry : *map) {
    std::cout << entry.first << " -> [";
    for (const int64_t value : entry.second) {
      std::cout << value << " ";
    }
    // Plain '\n' (not std::endl): terminates the line without forcing a flush.
    std::cout << "]\n";
  }
}
class
Reader
{
class
Reader
{
public:
public:
virtual
~
Reader
()
{}
virtual
~
Reader
()
{}
...
@@ -126,7 +141,14 @@ void ReadThread(const std::vector<std::string>& file_list,
...
@@ -126,7 +141,14 @@ void ReadThread(const std::vector<std::string>& file_list,
const
std
::
vector
<
std
::
string
>&
slots
,
int
batch_size
,
const
std
::
vector
<
std
::
string
>&
slots
,
int
batch_size
,
int
thread_id
,
std
::
vector
<
ReaderThreadStatus
>*
thread_status
,
int
thread_id
,
std
::
vector
<
ReaderThreadStatus
>*
thread_status
,
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
)
{
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
)
{
VLOG
(
3
)
<<
"reader thread start! thread_id = "
<<
thread_id
;
(
*
thread_status
)[
thread_id
]
=
Running
;
(
*
thread_status
)[
thread_id
]
=
Running
;
VLOG
(
3
)
<<
"set status to running"
;
std
::
unordered_map
<
std
::
string
,
size_t
>
slot_to_index
;
for
(
size_t
i
=
0
;
i
<
slots
.
size
();
++
i
)
{
slot_to_index
[
slots
[
i
]]
=
i
;
}
std
::
string
line
;
std
::
string
line
;
...
@@ -135,21 +157,29 @@ void ReadThread(const std::vector<std::string>& file_list,
...
@@ -135,21 +157,29 @@ void ReadThread(const std::vector<std::string>& file_list,
MultiGzipReader
reader
(
file_list
);
MultiGzipReader
reader
(
file_list
);
VLOG
(
3
)
<<
"reader inited"
;
while
(
reader
.
HasNext
())
{
while
(
reader
.
HasNext
())
{
// read all files
batch_data
.
clear
();
batch_label
.
clear
();
// read batch_size data
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
if
(
reader
.
HasNext
())
{
if
(
reader
.
HasNext
())
{
reader
.
NextLine
(
&
line
);
reader
.
NextLine
(
&
line
);
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
int64_t
>>
slot
s
_to_data
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
int64_t
>>
slot_to_data
;
int64_t
label
;
int64_t
label
;
parse_line
(
line
,
slot
s
,
&
label
,
&
slots
_to_data
);
parse_line
(
line
,
slot
_to_index
,
&
label
,
&
slot
_to_data
);
batch_data
.
push_back
(
slot
s
_to_data
);
batch_data
.
push_back
(
slot_to_data
);
batch_label
.
push_back
(
label
);
batch_label
.
push_back
(
label
);
}
else
{
}
else
{
break
;
break
;
}
}
}
}
VLOG
(
3
)
<<
"read one batch, batch_size = "
<<
batch_data
.
size
();
print_map
(
&
batch_data
[
0
]);
std
::
vector
<
framework
::
LoDTensor
>
lod_datas
;
std
::
vector
<
framework
::
LoDTensor
>
lod_datas
;
// first insert a tensor for each slot
// first insert a tensor for each slot
...
@@ -159,9 +189,9 @@ void ReadThread(const std::vector<std::string>& file_list,
...
@@ -159,9 +189,9 @@ void ReadThread(const std::vector<std::string>& file_list,
for
(
size_t
i
=
0
;
i
<
batch_data
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
batch_data
.
size
();
++
i
)
{
auto
&
feasign
=
batch_data
[
i
][
slot
];
auto
&
feasign
=
batch_data
[
i
][
slot
];
lod_data
.
push_back
(
lod_data
.
back
()
+
feasign
.
size
());
lod_data
.
push_back
(
lod_data
.
back
()
+
feasign
.
size
());
batch_feasign
.
insert
(
feasign
.
end
(),
feasign
.
begin
(),
feasign
.
end
());
batch_feasign
.
insert
(
batch_feasign
.
end
(),
feasign
.
begin
(),
feasign
.
end
());
}
}
framework
::
LoDTensor
lod_tensor
;
framework
::
LoDTensor
lod_tensor
;
...
@@ -174,6 +204,8 @@ void ReadThread(const std::vector<std::string>& file_list,
...
@@ -174,6 +204,8 @@ void ReadThread(const std::vector<std::string>& file_list,
lod_datas
.
push_back
(
lod_tensor
);
lod_datas
.
push_back
(
lod_tensor
);
}
}
VLOG
(
3
)
<<
"convert data to tensor"
;
// insert label tensor
// insert label tensor
framework
::
LoDTensor
label_tensor
;
framework
::
LoDTensor
label_tensor
;
int64_t
*
label_tensor_data
=
label_tensor
.
mutable_data
<
int64_t
>
(
int64_t
*
label_tensor_data
=
label_tensor
.
mutable_data
<
int64_t
>
(
...
@@ -182,10 +214,12 @@ void ReadThread(const std::vector<std::string>& file_list,
...
@@ -182,10 +214,12 @@ void ReadThread(const std::vector<std::string>& file_list,
memcpy
(
label_tensor_data
,
batch_label
.
data
(),
batch_label
.
size
());
memcpy
(
label_tensor_data
,
batch_label
.
data
(),
batch_label
.
size
());
lod_datas
.
push_back
(
label_tensor
);
lod_datas
.
push_back
(
label_tensor
);
VLOG
(
3
)
<<
"push one data"
;
queue
->
Push
(
lod_datas
);
queue
->
Push
(
lod_datas
);
}
}
(
*
thread_status
)[
thread_id
]
=
Stopped
;
(
*
thread_status
)[
thread_id
]
=
Stopped
;
VLOG
(
3
)
<<
"thread "
<<
thread_id
<<
" exited"
;
}
}
}
// namespace reader
}
// namespace reader
...
...
paddle/fluid/operators/reader/ctr_reader.h
浏览文件 @
803e2ed9
...
@@ -47,15 +47,15 @@ class CTRReader : public framework::FileReader {
...
@@ -47,15 +47,15 @@ class CTRReader : public framework::FileReader {
PADDLE_ENFORCE
(
queue
!=
nullptr
,
"LoDTensorBlockingQueue must not be null"
);
PADDLE_ENFORCE
(
queue
!=
nullptr
,
"LoDTensorBlockingQueue must not be null"
);
PADDLE_ENFORCE_GT
(
file_list
.
size
(),
0
,
"file list should not be empty"
);
PADDLE_ENFORCE_GT
(
file_list
.
size
(),
0
,
"file list should not be empty"
);
thread_num_
=
thread_num_
=
file_list_
.
size
()
>
thread_num
_
?
thread_num_
:
file_list_
.
size
();
file_list_
.
size
()
>
thread_num
?
thread_num
:
file_list_
.
size
();
queue_
=
queue
;
queue_
=
queue
;
SplitFiles
();
SplitFiles
();
for
(
int
i
=
0
;
i
<
thread_num
;
++
i
)
{
for
(
int
i
=
0
;
i
<
thread_num
_
;
++
i
)
{
read_thread_status_
.
push_back
(
Stopped
);
read_thread_status_
.
push_back
(
Stopped
);
}
}
}
}
~
CTRReader
()
{
queue_
->
Close
();
}
~
CTRReader
()
{
Shutdown
();
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
bool
success
;
bool
success
;
...
@@ -74,8 +74,11 @@ class CTRReader : public framework::FileReader {
...
@@ -74,8 +74,11 @@ class CTRReader : public framework::FileReader {
void
Start
()
override
{
void
Start
()
override
{
VLOG
(
3
)
<<
"Start reader"
;
VLOG
(
3
)
<<
"Start reader"
;
PADDLE_ENFORCE_EQ
(
read_threads_
.
size
(),
0
,
"read thread should be empty!"
);
queue_
->
ReOpen
();
queue_
->
ReOpen
();
for
(
int
thread_id
=
0
;
thread_id
<
file_groups_
.
size
();
thread_id
++
)
{
VLOG
(
3
)
<<
"reopen success"
;
VLOG
(
3
)
<<
"thread_num "
<<
thread_num_
;
for
(
int
thread_id
=
0
;
thread_id
<
thread_num_
;
thread_id
++
)
{
read_threads_
.
emplace_back
(
new
std
::
thread
(
read_threads_
.
emplace_back
(
new
std
::
thread
(
std
::
bind
(
&
ReadThread
,
file_groups_
[
thread_id
],
slots_
,
batch_size_
,
std
::
bind
(
&
ReadThread
,
file_groups_
[
thread_id
],
slots_
,
batch_size_
,
thread_id
,
&
read_thread_status_
,
queue_
)));
thread_id
,
&
read_thread_status_
,
queue_
)));
...
@@ -86,7 +89,10 @@ class CTRReader : public framework::FileReader {
...
@@ -86,7 +89,10 @@ class CTRReader : public framework::FileReader {
void
SplitFiles
()
{
void
SplitFiles
()
{
file_groups_
.
resize
(
thread_num_
);
file_groups_
.
resize
(
thread_num_
);
for
(
int
i
=
0
;
i
<
file_list_
.
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
file_list_
.
size
();
++
i
)
{
file_groups_
[
i
%
thread_num_
].
push_back
(
file_list_
[
i
]);
auto
&
file_name
=
file_list_
[
i
];
std
::
ifstream
f
(
file_name
.
c_str
());
PADDLE_ENFORCE
(
f
.
good
(),
"file %s not exist!"
,
file_name
);
file_groups_
[
i
%
thread_num_
].
push_back
(
file_name
);
}
}
}
}
...
...
paddle/fluid/operators/reader/ctr_reader_test.cc
0 → 100644
浏览文件 @
803e2ed9
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/operators/reader/ctr_reader.h"
using
paddle
::
operators
::
reader
::
LoDTensorBlockingQueue
;
using
paddle
::
operators
::
reader
::
LoDTensorBlockingQueueHolder
;
using
paddle
::
operators
::
reader
::
CTRReader
;
// Smoke test: constructs a CTRReader on top of a LoDTensorBlockingQueue and
// calls Start() on it. No batch is read back and nothing is asserted; the
// ReadNext call at the bottom is still commented out.
//
// NOTE(review): both input entries are the same absolute path under one
// developer's home directory, so this test presumably only works on that
// machine -- confirm and replace with a checked-in fixture.
TEST(CTR_READER, read_data) {
  // Reader configuration.
  int batch_size = 10;
  int thread_num = 1;
  std::vector<std::string> slot_names = {"6003", "6004"};
  std::vector<std::string> gz_files = {
      "/Users/qiaolongfei/project/gzip_test/part-00000-A.gz",
      "/Users/qiaolongfei/project/gzip_test/part-00000-A.gz"};

  // Queue the reader is constructed with.
  int queue_capacity = 64;
  LoDTensorBlockingQueueHolder queue_holder;
  queue_holder.InitOnce(queue_capacity, {}, false);
  std::shared_ptr<LoDTensorBlockingQueue> tensor_queue =
      queue_holder.GetQueue();

  CTRReader ctr_reader(tensor_queue, batch_size, thread_num, slot_names,
                       gz_files);
  ctr_reader.Start();
  //
  // std::vector<LoDTensor> out;
  // reader.ReadNext(&out);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录