Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
35b79ab8
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
35b79ab8
编写于
11月 28, 2018
作者:
Q
Qiao Longfei
提交者:
GitHub
11月 28, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #13983 from jacquesqiao/add-ctr-reader
Add ctr reader
上级
b1dbbb7f
da387720
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
782 addition
and
0 deletion
+782
-0
CMakeLists.txt
CMakeLists.txt
+1
-0
cmake/external/gzstream.cmake
cmake/external/gzstream.cmake
+47
-0
paddle/fluid/operators/reader/CMakeLists.txt
paddle/fluid/operators/reader/CMakeLists.txt
+6
-0
paddle/fluid/operators/reader/create_ctr_reader_op.cc
paddle/fluid/operators/reader/create_ctr_reader_op.cc
+79
-0
paddle/fluid/operators/reader/ctr_reader.cc
paddle/fluid/operators/reader/ctr_reader.cc
+238
-0
paddle/fluid/operators/reader/ctr_reader.h
paddle/fluid/operators/reader/ctr_reader.h
+133
-0
paddle/fluid/operators/reader/ctr_reader_test.cc
paddle/fluid/operators/reader/ctr_reader_test.cc
+155
-0
python/paddle/fluid/contrib/reader/ctr_reader.py
python/paddle/fluid/contrib/reader/ctr_reader.py
+123
-0
未找到文件。
CMakeLists.txt
浏览文件 @
35b79ab8
...
...
@@ -214,6 +214,7 @@ if (NOT WIN32)
# there is no official support of warpctc, nccl, cupti in windows
include
(
external/warpctc
)
# download, build, install warpctc
include
(
cupti
)
include
(
external/gzstream
)
endif
(
NOT WIN32
)
if
(
WITH_DISTRIBUTE
)
...
...
cmake/external/gzstream.cmake
0 → 100644
浏览文件 @
35b79ab8
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
IF
(
MOBILE_INFERENCE
)
return
()
ENDIF
()
include
(
ExternalProject
)
# NOTE: gzstream is needed when linking with ctr reader.
SET
(
GZSTREAM_SOURCES_DIR
${
THIRD_PARTY_PATH
}
/gzstream
)
SET
(
GZSTREAM_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/gzstream
)
SET
(
GZSTREAM_INCLUDE_DIR
"
${
GZSTREAM_INSTALL_DIR
}
/include/"
CACHE PATH
"gzstream include directory."
FORCE
)
ExternalProject_Add
(
extern_gzstream
GIT_REPOSITORY
"https://github.com/jacquesqiao/gzstream.git"
GIT_TAG
""
PREFIX
${
GZSTREAM_SOURCES_DIR
}
UPDATE_COMMAND
""
CONFIGURE_COMMAND
""
BUILD_IN_SOURCE 1
BUILD_COMMAND make -j8
INSTALL_COMMAND mkdir -p
${
GZSTREAM_INSTALL_DIR
}
/lib/ && mkdir -p
${
GZSTREAM_INSTALL_DIR
}
/include/
&& cp
${
GZSTREAM_SOURCES_DIR
}
/src/extern_gzstream/libgzstream.a
${
GZSTREAM_INSTALL_DIR
}
/lib
&& cp -r
${
GZSTREAM_SOURCES_DIR
}
/src/extern_gzstream/gzstream.h
${
GZSTREAM_INSTALL_DIR
}
/include
)
ADD_LIBRARY
(
gzstream STATIC IMPORTED GLOBAL
)
SET_PROPERTY
(
TARGET gzstream PROPERTY IMPORTED_LOCATION
"
${
GZSTREAM_INSTALL_DIR
}
/lib/libgzstream.a"
)
include_directories
(
${
GZSTREAM_INCLUDE_DIR
}
)
ADD_DEPENDENCIES
(
gzstream extern_gzstream zlib
)
paddle/fluid/operators/reader/CMakeLists.txt
浏览文件 @
35b79ab8
...
...
@@ -28,6 +28,12 @@ reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
reader_library
(
create_custom_reader_op SRCS create_custom_reader_op.cc
)
reader_library
(
create_py_reader_op SRCS create_py_reader_op.cc
)
if
(
NOT WIN32 AND NOT ON_INFER
)
cc_library
(
ctr_reader SRCS ctr_reader.cc DEPS gzstream reader zlib
)
cc_test
(
ctr_reader_test SRCS ctr_reader_test.cc DEPS ctr_reader
)
reader_library
(
create_ctr_reader_op SRCS create_ctr_reader_op.cc DEPS ctr_reader
)
endif
()
cc_test
(
reader_blocking_queue_test SRCS reader_blocking_queue_test.cc
)
# Export local libraries to parent
# set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
...
...
paddle/fluid/operators/reader/create_ctr_reader_op.cc
0 → 100644
浏览文件 @
35b79ab8
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reader/ctr_reader.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace
paddle
{
namespace
operators
{
namespace
reader
{
class
CreateCTRReaderOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
if
(
out
->
Get
()
!=
nullptr
)
return
;
const
std
::
string
&
queue_name
=
Input
(
"blocking_queue"
);
auto
*
queue_holder_var
=
scope
.
FindVar
(
queue_name
);
PADDLE_ENFORCE_NOT_NULL
(
queue_holder_var
,
"No LoDTensorBlockingQueueHolder variable with name %s found"
,
queue_name
);
auto
*
queue_holder
=
queue_holder_var
->
template
GetMutable
<
LoDTensorBlockingQueueHolder
>();
int
thread_num
=
Attr
<
int
>
(
"thread_num"
);
std
::
vector
<
std
::
string
>
slots
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"slots"
);
int
batch_size
=
Attr
<
int
>
(
"batch_size"
);
std
::
vector
<
std
::
string
>
file_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"file_list"
);
out
->
Reset
(
std
::
make_shared
<
CTRReader
>
(
queue_holder
->
GetQueue
(),
batch_size
,
thread_num
,
slots
,
file_list
));
}
};
class
CreateCTRReaderOpMaker
:
public
FileReaderMakerBase
{
protected:
void
Apply
()
override
{
AddInput
(
"blocking_queue"
,
"Name of the `LoDTensorBlockingQueueHolder` variable"
);
AddAttr
<
int
>
(
"thread_num"
,
"the thread num to read data"
);
AddAttr
<
int
>
(
"batch_size"
,
"the batch size of read data"
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"file_list"
,
"The list of files that need to read"
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"slots"
,
"the slots that should be extract from file"
);
AddComment
(
R"DOC(
Create CTRReader to support read ctr data with cpp.
)DOC"
);
}
};
}
// namespace reader
}
// namespace operators
}
// namespace paddle
namespace
reader
=
::
paddle
::
operators
::
reader
;
REGISTER_FILE_READER_OPERATOR
(
create_ctr_reader
,
reader
::
CreateCTRReaderOp
,
reader
::
CreateCTRReaderOpMaker
);
paddle/fluid/operators/reader/ctr_reader.cc
0 → 100644
浏览文件 @
35b79ab8
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reader/ctr_reader.h"
#include <gzstream.h>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <algorithm>
#include <random>
namespace
paddle
{
namespace
operators
{
namespace
reader
{
static
inline
void
string_split
(
const
std
::
string
&
s
,
const
char
delimiter
,
std
::
vector
<
std
::
string
>*
output
)
{
size_t
start
=
0
;
size_t
end
=
s
.
find_first_of
(
delimiter
);
while
(
end
<=
std
::
string
::
npos
)
{
output
->
emplace_back
(
s
.
substr
(
start
,
end
-
start
));
if
(
end
==
std
::
string
::
npos
)
{
break
;
}
start
=
end
+
1
;
end
=
s
.
find_first_of
(
delimiter
,
start
);
}
}
static
inline
void
parse_line
(
const
std
::
string
&
line
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
slot_to_index
,
int64_t
*
label
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
int64_t
>>*
slot_to_data
)
{
std
::
vector
<
std
::
string
>
ret
;
string_split
(
line
,
' '
,
&
ret
);
*
label
=
std
::
stoi
(
ret
[
2
])
>
0
;
for
(
size_t
i
=
3
;
i
<
ret
.
size
();
++
i
)
{
const
std
::
string
&
item
=
ret
[
i
];
std
::
vector
<
std
::
string
>
feasign_and_slot
;
string_split
(
item
,
':'
,
&
feasign_and_slot
);
if
(
feasign_and_slot
.
size
()
==
2
&&
slot_to_index
.
find
(
feasign_and_slot
[
1
])
!=
slot_to_index
.
end
())
{
int64_t
feasign
=
std
::
strtoll
(
feasign_and_slot
[
0
].
c_str
(),
NULL
,
10
);
(
*
slot_to_data
)[
feasign_and_slot
[
1
]].
push_back
(
feasign
);
}
}
// NOTE:: if the slot has no value, then fill [0] as it's data.
for
(
auto
&
item
:
slot_to_index
)
{
if
(
slot_to_data
->
find
(
item
.
first
)
==
slot_to_data
->
end
())
{
(
*
slot_to_data
)[
item
.
first
].
push_back
(
0
);
}
}
}
class
Reader
{
public:
virtual
~
Reader
()
{}
virtual
bool
HasNext
()
=
0
;
virtual
void
NextLine
(
std
::
string
*
line
)
=
0
;
};
class
GzipReader
:
public
Reader
{
public:
explicit
GzipReader
(
const
std
::
string
&
file_name
)
:
gzstream_
(
file_name
.
c_str
())
{}
~
GzipReader
()
{}
bool
HasNext
()
override
{
return
gzstream_
.
peek
()
!=
EOF
;
}
void
NextLine
(
std
::
string
*
line
)
override
{
std
::
getline
(
gzstream_
,
*
line
);
}
private:
igzstream
gzstream_
;
};
class
MultiGzipReader
:
public
Reader
{
public:
explicit
MultiGzipReader
(
const
std
::
vector
<
std
::
string
>&
file_list
)
{
for
(
auto
&
file
:
file_list
)
{
readers_
.
emplace_back
(
std
::
make_shared
<
GzipReader
>
(
file
));
}
}
bool
HasNext
()
override
{
if
(
current_reader_index_
>=
readers_
.
size
())
{
return
false
;
}
if
(
!
readers_
[
current_reader_index_
]
->
HasNext
())
{
current_reader_index_
++
;
return
HasNext
();
}
return
true
;
}
void
NextLine
(
std
::
string
*
line
)
override
{
readers_
[
current_reader_index_
]
->
NextLine
(
line
);
}
private:
std
::
vector
<
std
::
shared_ptr
<
GzipReader
>>
readers_
;
size_t
current_reader_index_
=
0
;
};
void
MonitorThread
(
std
::
vector
<
ReaderThreadStatus
>*
thread_status
,
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
)
{
VLOG
(
30
)
<<
"monitor thread in"
;
bool
reader_thread_is_running
=
true
;
while
(
reader_thread_is_running
)
{
VLOG
(
30
)
<<
"reader_thread_is_running"
;
reader_thread_is_running
=
false
;
for
(
size_t
i
=
0
;
i
<
(
*
thread_status
).
size
();
++
i
)
{
if
((
*
thread_status
)[
i
]
==
Running
)
{
VLOG
(
30
)
<<
"reader is running!"
;
reader_thread_is_running
=
true
;
}
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
}
VLOG
(
30
)
<<
"all reader thread is stopped, push empty data into queue"
;
queue
->
Push
({});
VLOG
(
30
)
<<
"monitor thread exited"
;
}
void
ReadThread
(
const
std
::
vector
<
std
::
string
>&
file_list
,
const
std
::
vector
<
std
::
string
>&
slots
,
int
batch_size
,
int
thread_id
,
std
::
vector
<
ReaderThreadStatus
>*
thread_status
,
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
)
{
VLOG
(
30
)
<<
"["
<<
thread_id
<<
"]"
<<
" reader thread start! thread_id = "
<<
thread_id
;
for
(
auto
&
file
:
file_list
)
{
VLOG
(
30
)
<<
"["
<<
thread_id
<<
"]"
<<
" file "
<<
file
;
}
(
*
thread_status
)[
thread_id
]
=
Running
;
VLOG
(
30
)
<<
"set status to running"
;
std
::
unordered_map
<
std
::
string
,
size_t
>
slot_to_index
;
for
(
size_t
i
=
0
;
i
<
slots
.
size
();
++
i
)
{
slot_to_index
[
slots
[
i
]]
=
i
;
}
std
::
string
line
;
std
::
vector
<
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
int64_t
>>>
batch_data
;
std
::
vector
<
int64_t
>
batch_label
;
MultiGzipReader
reader
(
file_list
);
VLOG
(
30
)
<<
"reader inited"
;
while
(
reader
.
HasNext
())
{
batch_data
.
clear
();
batch_data
.
reserve
(
batch_size
);
batch_label
.
clear
();
batch_label
.
reserve
(
batch_size
);
// read batch_size data
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
if
(
reader
.
HasNext
())
{
reader
.
NextLine
(
&
line
);
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
int64_t
>>
slot_to_data
;
int64_t
label
;
parse_line
(
line
,
slot_to_index
,
&
label
,
&
slot_to_data
);
batch_data
.
push_back
(
slot_to_data
);
batch_label
.
push_back
(
label
);
}
else
{
break
;
}
}
std
::
vector
<
framework
::
LoDTensor
>
lod_datas
;
// first insert tensor for each slots
for
(
auto
&
slot
:
slots
)
{
std
::
vector
<
size_t
>
lod_data
{
0
};
std
::
vector
<
int64_t
>
batch_feasign
;
for
(
size_t
i
=
0
;
i
<
batch_data
.
size
();
++
i
)
{
auto
&
feasign
=
batch_data
[
i
][
slot
];
lod_data
.
push_back
(
lod_data
.
back
()
+
feasign
.
size
());
batch_feasign
.
insert
(
batch_feasign
.
end
(),
feasign
.
begin
(),
feasign
.
end
());
}
framework
::
LoDTensor
lod_tensor
;
framework
::
LoD
lod
{
lod_data
};
lod_tensor
.
set_lod
(
lod
);
int64_t
*
tensor_data
=
lod_tensor
.
mutable_data
<
int64_t
>
(
framework
::
make_ddim
({
1
,
static_cast
<
int64_t
>
(
batch_feasign
.
size
())}),
platform
::
CPUPlace
());
memcpy
(
tensor_data
,
batch_feasign
.
data
(),
batch_feasign
.
size
()
*
sizeof
(
int64_t
));
lod_datas
.
push_back
(
lod_tensor
);
}
// insert label tensor
framework
::
LoDTensor
label_tensor
;
auto
*
label_tensor_data
=
label_tensor
.
mutable_data
<
int64_t
>
(
framework
::
make_ddim
({
1
,
static_cast
<
int64_t
>
(
batch_label
.
size
())}),
platform
::
CPUPlace
());
memcpy
(
label_tensor_data
,
batch_label
.
data
(),
batch_label
.
size
()
*
sizeof
(
int64_t
));
lod_datas
.
push_back
(
label_tensor
);
queue
->
Push
(
lod_datas
);
VLOG
(
40
)
<<
"push one data, queue_size="
<<
queue
->
Size
();
}
(
*
thread_status
)[
thread_id
]
=
Stopped
;
VLOG
(
30
)
<<
"set status to stopped, thread "
<<
thread_id
<<
" exited"
;
}
}
// namespace reader
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/reader/ctr_reader.h
0 → 100644
浏览文件 @
35b79ab8
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <sys/time.h>
#include <chrono> // NOLINT
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
namespace
paddle
{
namespace
operators
{
namespace
reader
{
enum
ReaderThreadStatus
{
Running
,
Stopped
};
void
ReadThread
(
const
std
::
vector
<
std
::
string
>&
file_list
,
const
std
::
vector
<
std
::
string
>&
slots
,
int
batch_size
,
int
thread_id
,
std
::
vector
<
ReaderThreadStatus
>*
thread_status
,
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
);
// monitor all running thread, if they are all stopped,
// then push an empty data into LoDTensorBlockingQueue
void
MonitorThread
(
std
::
vector
<
ReaderThreadStatus
>*
thread_status
,
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
);
class
CTRReader
:
public
framework
::
FileReader
{
public:
explicit
CTRReader
(
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
queue
,
int
batch_size
,
int
thread_num
,
const
std
::
vector
<
std
::
string
>&
slots
,
const
std
::
vector
<
std
::
string
>&
file_list
)
:
batch_size_
(
batch_size
),
slots_
(
slots
),
file_list_
(
file_list
)
{
PADDLE_ENFORCE_GT
(
thread_num
,
0
,
"thread num should be larger then 0!"
);
PADDLE_ENFORCE
(
queue
!=
nullptr
,
"LoDTensorBlockingQueue must not be null"
);
PADDLE_ENFORCE_GT
(
file_list
.
size
(),
0
,
"file list should not be empty"
);
thread_num_
=
file_list_
.
size
()
>
thread_num
?
thread_num
:
file_list_
.
size
();
queue_
=
queue
;
SplitFiles
();
for
(
size_t
i
=
0
;
i
<
thread_num_
;
++
i
)
{
read_thread_status_
.
push_back
(
Stopped
);
}
}
~
CTRReader
()
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
bool
success
;
*
out
=
queue_
->
Pop
(
&
success
);
if
(
!
success
)
out
->
clear
();
}
void
Shutdown
()
override
{
VLOG
(
3
)
<<
"Shutdown reader"
;
if
(
status_
==
ReaderStatus
::
kStopped
)
{
return
;
}
// shutdown should stop all the reader thread
for
(
auto
&
read_thread
:
read_threads_
)
{
read_thread
->
join
();
}
monitor_thread_
->
join
();
read_threads_
.
clear
();
monitor_thread_
.
reset
(
nullptr
);
queue_
->
Close
();
status_
=
ReaderStatus
::
kStopped
;
}
void
Start
()
override
{
VLOG
(
3
)
<<
"Start reader"
;
PADDLE_ENFORCE_EQ
(
read_threads_
.
size
(),
0
,
"read thread should be empty!"
);
queue_
->
ReOpen
();
VLOG
(
3
)
<<
"reopen success"
;
VLOG
(
3
)
<<
"thread_num "
<<
thread_num_
;
for
(
int
thread_id
=
0
;
thread_id
<
thread_num_
;
thread_id
++
)
{
read_threads_
.
emplace_back
(
new
std
::
thread
(
std
::
bind
(
&
ReadThread
,
file_groups_
[
thread_id
],
slots_
,
batch_size_
,
thread_id
,
&
read_thread_status_
,
queue_
)));
}
monitor_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
MonitorThread
,
&
read_thread_status_
,
queue_
)));
status_
=
ReaderStatus
::
kRunning
;
}
private:
void
SplitFiles
()
{
file_groups_
.
resize
(
thread_num_
);
for
(
size_t
i
=
0
;
i
<
file_list_
.
size
();
++
i
)
{
auto
&
file_name
=
file_list_
[
i
];
std
::
ifstream
f
(
file_name
.
c_str
());
PADDLE_ENFORCE
(
f
.
good
(),
"file %s not exist!"
,
file_name
);
file_groups_
[
i
%
thread_num_
].
push_back
(
file_name
);
}
}
private:
size_t
thread_num_
;
const
int
batch_size_
;
const
std
::
vector
<
std
::
string
>
slots_
;
const
std
::
vector
<
std
::
string
>
file_list_
;
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue_
;
std
::
vector
<
std
::
unique_ptr
<
std
::
thread
>>
read_threads_
;
std
::
unique_ptr
<
std
::
thread
>
monitor_thread_
;
std
::
vector
<
ReaderThreadStatus
>
read_thread_status_
;
std
::
vector
<
std
::
vector
<
std
::
string
>>
file_groups_
;
};
}
// namespace reader
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/reader/ctr_reader_test.cc
0 → 100644
浏览文件 @
35b79ab8
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reader/ctr_reader.h"
#include <gzstream.h>
#include <time.h>
#include <math.h>
#include <stdio.h>
#include <cstring>
#include <fstream>
#include <tuple>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
using
paddle
::
operators
::
reader
::
LoDTensorBlockingQueue
;
using
paddle
::
operators
::
reader
::
LoDTensorBlockingQueueHolder
;
using
paddle
::
operators
::
reader
::
CTRReader
;
using
paddle
::
framework
::
LoDTensor
;
using
paddle
::
framework
::
LoD
;
using
paddle
::
framework
::
DDim
;
using
paddle
::
platform
::
CPUPlace
;
using
paddle
::
framework
::
make_ddim
;
static
void
generatedata
(
const
std
::
vector
<
std
::
string
>&
data
,
const
std
::
string
&
file_name
)
{
std
::
ifstream
in
(
file_name
.
c_str
());
if
(
in
.
good
())
{
VLOG
(
3
)
<<
"file "
<<
file_name
<<
" exist, delete it first!"
;
remove
(
file_name
.
c_str
());
}
else
{
in
.
close
();
}
ogzstream
out
(
file_name
.
c_str
());
PADDLE_ENFORCE
(
out
.
good
(),
"open file %s failed!"
,
file_name
);
for
(
auto
&
c
:
data
)
{
out
<<
c
;
}
out
.
close
();
PADDLE_ENFORCE
(
out
.
good
(),
"save file %s failed!"
,
file_name
);
}
static
inline
void
check_all_data
(
const
std
::
vector
<
std
::
string
>&
ctr_data
,
const
std
::
vector
<
std
::
string
>&
slots
,
const
std
::
vector
<
DDim
>&
label_dims
,
const
std
::
vector
<
int64_t
>&
label_value
,
const
std
::
vector
<
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>>&
data_slot_6002
,
const
std
::
vector
<
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>>&
data_slot_6003
,
size_t
batch_num
,
size_t
batch_size
,
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
,
CTRReader
*
reader
)
{
std
::
vector
<
LoDTensor
>
out
;
for
(
size_t
i
=
0
;
i
<
batch_num
;
++
i
)
{
reader
->
ReadNext
(
&
out
);
ASSERT_EQ
(
out
.
size
(),
slots
.
size
()
+
1
);
auto
&
label_tensor
=
out
.
back
();
ASSERT_EQ
(
label_tensor
.
dims
(),
label_dims
[
i
]);
for
(
size_t
j
=
0
;
j
<
batch_size
&&
i
*
batch_num
+
j
<
ctr_data
.
size
();
++
j
)
{
auto
&
label
=
label_tensor
.
data
<
int64_t
>
()[
j
];
ASSERT_TRUE
(
label
==
0
||
label
==
1
);
ASSERT_EQ
(
label
,
label_value
[
i
*
batch_size
+
j
]);
}
auto
&
tensor_6002
=
out
[
0
];
ASSERT_EQ
(
std
::
get
<
0
>
(
data_slot_6002
[
i
]),
tensor_6002
.
lod
());
ASSERT_EQ
(
std
::
memcmp
(
std
::
get
<
1
>
(
data_slot_6002
[
i
]).
data
(),
tensor_6002
.
data
<
int64_t
>
(),
tensor_6002
.
dims
()[
1
]
*
sizeof
(
int64_t
)),
0
);
}
reader
->
ReadNext
(
&
out
);
ASSERT_EQ
(
out
.
size
(),
0
);
ASSERT_EQ
(
queue
->
Size
(),
0
);
}
TEST
(
CTR_READER
,
read_data
)
{
const
std
::
vector
<
std
::
string
>
ctr_data
=
{
"aaaa 1 0 0:6002 1:6003 2:6004 3:6005 4:6006 -1
\n
"
,
"bbbb 1 0 5:6003 6:6003 7:6003 8:6004 9:6004 -1
\n
"
,
"cccc 1 1 10:6002 11:6002 12:6002 13:6002 14:6002 -2
\n
"
,
"dddd 1 0 15:6003 16:6003 17:6003 18:6003 19:6004 -3
\n
"
,
"1111 1 1 20:6001 21:6001 22:6001 23:6001 24:6001 12
\n
"
,
"2222 1 1 25:6004 26:6004 27:6004 28:6005 29:6005 aa
\n
"
,
"3333 1 0 30:6002 31:6003 32:6004 33:6004 34:6005 er
\n
"
,
"eeee 1 1 35:6003 36:6003 37:6005 38:6005 39:6005 dd
\n
"
,
"ffff 1 1 40:6002 41:6003 42:6004 43:6004 44:6005 66
\n
"
,
"gggg 1 1 46:6006 45:6006 47:6003 48:6003 49:6003 ba
\n
"
,
};
std
::
string
gz_file_name
=
"test_ctr_reader_data.gz"
;
generatedata
(
ctr_data
,
gz_file_name
);
std
::
vector
<
int64_t
>
label_value
=
{
0
,
0
,
1
,
0
,
1
,
1
,
0
,
1
,
1
,
1
};
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>
a1
({{
0
,
1
,
2
,
7
}},
{
0
,
0
,
10
,
11
,
12
,
13
,
14
});
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>
a2
({{
0
,
1
,
2
,
3
}},
{
0
,
0
,
0
});
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>
a3
({{
0
,
1
,
2
,
3
}},
{
30
,
0
,
40
});
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>
a4
({{
0
,
1
}},
{
0
});
std
::
vector
<
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>>
data_slot_6002
{
a1
,
a2
,
a3
,
a4
};
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>
b1
({{
0
,
1
,
4
,
5
}},
{
1
,
5
,
6
,
7
,
0
});
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>
b2
({{
0
,
4
,
5
,
6
}},
{
15
,
16
,
17
,
18
,
0
,
0
});
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>
b3
({{
0
,
1
,
3
,
4
}},
{
31
,
35
,
36
,
41
});
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>
b4
({{
0
,
3
}},
{
47
,
48
,
49
});
std
::
vector
<
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>>
data_slot_6003
{
b1
,
b2
,
b3
,
b4
};
std
::
vector
<
DDim
>
label_dims
=
{{
1
,
3
},
{
1
,
3
},
{
1
,
3
},
{
1
,
1
}};
LoDTensorBlockingQueueHolder
queue_holder
;
int
capacity
=
64
;
queue_holder
.
InitOnce
(
capacity
,
{},
false
);
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
=
queue_holder
.
GetQueue
();
int
batch_size
=
3
;
int
thread_num
=
1
;
std
::
vector
<
std
::
string
>
slots
=
{
"6002"
,
"6003"
};
std
::
vector
<
std
::
string
>
file_list
;
for
(
int
i
=
0
;
i
<
thread_num
;
++
i
)
{
file_list
.
push_back
(
gz_file_name
);
}
CTRReader
reader
(
queue
,
batch_size
,
thread_num
,
slots
,
file_list
);
reader
.
Start
();
size_t
batch_num
=
std
::
ceil
(
static_cast
<
float
>
(
ctr_data
.
size
())
/
batch_size
)
*
thread_num
;
check_all_data
(
ctr_data
,
slots
,
label_dims
,
label_value
,
data_slot_6002
,
data_slot_6003
,
batch_num
,
batch_size
,
queue
,
&
reader
);
reader
.
Shutdown
();
reader
.
Start
();
check_all_data
(
ctr_data
,
slots
,
label_dims
,
label_value
,
data_slot_6002
,
data_slot_6003
,
batch_num
,
batch_size
,
queue
,
&
reader
);
reader
.
Shutdown
();
}
python/paddle/fluid/contrib/reader/ctr_reader.py
0 → 100644
浏览文件 @
35b79ab8
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
paddle.fluid
import
core
from
paddle.fluid.executor
import
global_scope
from
paddle.fluid.framework
import
default_main_program
,
\
default_startup_program
,
Variable
from
paddle.fluid.unique_name
import
generate
as
unique_name
def
monkey_patch_reader_methods
(
reader
):
def
__get_reader__
():
scope
=
global_scope
()
var
=
scope
.
find_var
(
reader
.
name
)
return
var
.
get_reader
()
def
reset
():
return
__get_reader__
().
reset
()
reader
.
reset
=
reset
reader
.
stop_gradient
=
True
reader
.
persistable
=
True
return
reader
def
_copy_reader_var_
(
block
,
var
):
new_var
=
block
.
create_var
(
name
=
var
.
name
,
type
=
core
.
VarDesc
.
VarType
.
READER
)
new_var
.
desc
.
set_shapes
(
var
.
desc
.
shapes
())
new_var
.
desc
.
set_dtypes
(
var
.
desc
.
dtypes
())
new_var
.
persistable
=
True
return
new_var
def
ctr_reader
(
feed_data
,
capacity
,
thread_num
,
batch_size
,
file_list
,
slots
,
name
=
None
):
"""
Create a CTR reader for data feeding in Python
This layer returns a Reader Variable.
The Reader provides :code:`decorate_paddle_reader()` and
:code:`decorate_tensor_provider()` to set a Python generator as the data
source in Python side. When :code:`Executor::Run()` is invoked in C++
side, the data from the generator would be read automatically. Unlike
:code:`DataFeeder.feed()`, the data reading process and
:code:`Executor::Run()` process can run in parallel using
:code:`py_reader`. The :code:`start()` method of the Reader should be
called when each pass begins, while the :code:`reset()` method should be
called when the pass ends and :code:`fluid.core.EOFException` raises.
Note that :code:`Program.clone()` method cannot clone :code:`py_reader`.
Args:
capacity(int): The buffer capacity maintained by :code:`py_reader`.
thread_num(list|tuple): List of tuples which declaring data shapes.
batch_size(list|tuple): List of strs which declaring data type.
file_list(list|tuple): List of ints which declaring data lod_level.
slots(bool): Whether use double buffer or not.
name(basestring): The prefix Python queue name and Reader name. None will
be generated automatically.
Returns:
Variable: A Reader from which we can get feeding data.
Examples:
1. The basic usage of :code:`py_reader` is as follows:
"""
if
name
is
None
:
queue_name
=
unique_name
(
'lod_tensor_blocking_queue'
)
reader_name
=
unique_name
(
'create_ctr_reader'
)
else
:
queue_name
=
"_"
.
join
([
name
,
"queue"
])
reader_name
=
"_"
.
join
([
name
,
"reader"
])
var
=
global_scope
().
var
(
queue_name
)
feed_queue
=
core
.
init_lod_tensor_blocking_queue
(
var
,
capacity
,
shapes
)
startup_blk
=
default_startup_program
().
current_block
()
reader_var
=
startup_blk
.
create_var
(
name
=
reader_name
)
startup_blk
.
append_op
(
type
=
'create_ctr_reader'
,
inputs
=
{
'blocking_queue'
:
[
queue_name
]},
outputs
=
{
'Out'
:
[
reader_var
]},
attrs
=
{
'thread_num'
:
thread_num
,
'batch_size'
:
batch_size
,
'file_list'
:
file_list
,
'slots'
:
slots
,
})
reader_var
.
persistable
=
True
main_prog_reader_var
=
_copy_reader_var_
(
default_main_program
().
current_block
(),
reader_var
)
reader
=
monkey_patch_reader_methods
(
main_prog_reader_var
)
# monkey patch py_reader special methods
reader
.
queue
=
feed_queue
reader
.
exited
=
False
main_blk
=
default_main_program
().
current_block
()
main_blk
.
append_op
(
type
=
'read'
,
inputs
=
{
'Reader'
:
[
reader
]},
outputs
=
{
'Out'
:
feed_data
})
return
reader
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录