Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
5ed713d5
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5ed713d5
编写于
7月 18, 2019
作者:
G
guru4elephant
提交者:
GitHub
7月 18, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove ctr reader, all functions are satisfied in dataset (#18672)
* remove ctr reader, all functions are satisfied in dataset
上级
898237c1
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
0 addition
and
989 deletion
+0
-989
paddle/fluid/API.spec
paddle/fluid/API.spec
+0
-1
paddle/fluid/operators/reader/CMakeLists.txt
paddle/fluid/operators/reader/CMakeLists.txt
+0
-6
paddle/fluid/operators/reader/ctr_reader.cc
paddle/fluid/operators/reader/ctr_reader.cc
+0
-398
paddle/fluid/operators/reader/ctr_reader.h
paddle/fluid/operators/reader/ctr_reader.h
+0
-189
paddle/fluid/operators/reader/ctr_reader_test.cc
paddle/fluid/operators/reader/ctr_reader_test.cc
+0
-229
python/paddle/fluid/contrib/reader/__init__.py
python/paddle/fluid/contrib/reader/__init__.py
+0
-2
python/paddle/fluid/contrib/reader/ctr_reader.py
python/paddle/fluid/contrib/reader/ctr_reader.py
+0
-164
未找到文件。
paddle/fluid/API.spec
浏览文件 @
5ed713d5
...
...
@@ -451,7 +451,6 @@ paddle.fluid.contrib.QuantizeTranspiler.convert_to_int8 (ArgSpec(args=['self', '
paddle.fluid.contrib.QuantizeTranspiler.freeze_program (ArgSpec(args=['self', 'program', 'place', 'fuse_bn', 'scope'], varargs=None, keywords=None, defaults=(False, None)), ('document', '909675a1ab055c69b436a7893fcae4fd'))
paddle.fluid.contrib.QuantizeTranspiler.training_transpile (ArgSpec(args=['self', 'program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None)), ('document', '6dd9909f10b283ba2892a99058a72884'))
paddle.fluid.contrib.distributed_batch_reader (ArgSpec(args=['batch_reader'], varargs=None, keywords=None, defaults=None), ('document', 'b60796eb0a481484dd34e345f0eaa4d5'))
paddle.fluid.contrib.reader.ctr_reader.ctr_reader (ArgSpec(args=['feed_dict', 'file_type', 'file_format', 'dense_slot_index', 'sparse_slot_index', 'capacity', 'thread_num', 'batch_size', 'file_list', 'slots', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'b2ebf3de2a6ef1af2c3b88d2db7591ab'))
paddle.fluid.contrib.Compressor ('paddle.fluid.contrib.slim.core.compressor.Compressor', ('document', 'a5417774a94aa9ae5560a42b96527e7d'))
paddle.fluid.contrib.Compressor.__init__ (ArgSpec(args=['self', 'place', 'scope', 'train_program', 'train_reader', 'train_feed_list', 'train_fetch_list', 'eval_program', 'eval_reader', 'eval_feed_list', 'eval_fetch_list', 'teacher_programs', 'checkpoint_path', 'train_optimizer', 'distiller_optimizer', 'search_space'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, [], None, None, None, None)), ('document', 'c195b3bba26169cff9439e8c467557c0'))
paddle.fluid.contrib.Compressor.config (ArgSpec(args=['self', 'config_file'], varargs=None, keywords=None, defaults=None), ('document', '780d9c007276ccbb95b292400d7807b0'))
...
...
paddle/fluid/operators/reader/CMakeLists.txt
浏览文件 @
5ed713d5
...
...
@@ -30,12 +30,6 @@ reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
reader_library
(
create_custom_reader_op SRCS create_custom_reader_op.cc
)
reader_library
(
create_py_reader_op SRCS create_py_reader_op.cc DEPS py_reader
)
if
(
NOT WIN32 AND NOT ON_INFER
)
cc_library
(
ctr_reader SRCS ctr_reader.cc DEPS gzstream reader zlib
)
cc_test
(
ctr_reader_test SRCS ctr_reader_test.cc DEPS ctr_reader
)
reader_library
(
create_ctr_reader_op SRCS create_ctr_reader_op.cc DEPS ctr_reader
)
endif
()
cc_test
(
reader_blocking_queue_test SRCS reader_blocking_queue_test.cc
)
# Export local libraries to parent
# set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
...
...
paddle/fluid/operators/reader/ctr_reader.cc
已删除
100644 → 0
浏览文件 @
898237c1
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reader/ctr_reader.h"
#include <gzstream.h>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <algorithm>
#include <random>
namespace
paddle
{
namespace
operators
{
namespace
reader
{
static
inline
void
string_split
(
const
std
::
string
&
s
,
const
char
delimiter
,
std
::
vector
<
std
::
string
>*
output
)
{
if
(
s
.
empty
())
return
;
size_t
start
=
0
;
size_t
end
=
s
.
find
(
delimiter
);
while
(
end
!=
std
::
string
::
npos
)
{
if
(
end
>
start
)
output
->
emplace_back
(
s
.
substr
(
start
,
end
-
start
));
start
=
end
+
1
;
end
=
s
.
find
(
delimiter
,
start
);
}
auto
term
=
s
.
substr
(
start
);
if
(
!
term
.
empty
())
output
->
emplace_back
(
term
);
}
static
inline
void
parse_line
(
const
std
::
string
&
line
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
slot_to_index
,
int64_t
*
label
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
int64_t
>>*
slot_to_data
)
{
std
::
vector
<
std
::
string
>
ret
;
string_split
(
line
,
' '
,
&
ret
);
*
label
=
std
::
stoi
(
ret
[
0
])
>
0
;
for
(
size_t
i
=
1
;
i
<
ret
.
size
();
++
i
)
{
const
std
::
string
&
item
=
ret
[
i
];
std
::
vector
<
std
::
string
>
feasign_and_slot
;
string_split
(
item
,
':'
,
&
feasign_and_slot
);
if
(
feasign_and_slot
.
size
()
==
2
&&
slot_to_index
.
find
(
feasign_and_slot
[
1
])
!=
slot_to_index
.
end
())
{
int64_t
feasign
=
std
::
strtoll
(
feasign_and_slot
[
0
].
c_str
(),
NULL
,
10
);
(
*
slot_to_data
)[
feasign_and_slot
[
1
]].
push_back
(
feasign
);
}
}
// NOTE:: if the slot has no value, then fill [0] as it's data.
for
(
auto
&
item
:
slot_to_index
)
{
if
(
slot_to_data
->
find
(
item
.
first
)
==
slot_to_data
->
end
())
{
(
*
slot_to_data
)[
item
.
first
].
push_back
(
0
);
}
}
}
// label slot1:fea_sign slot2:fea_sign slot1:fea_sign
static
inline
void
parse_svm_line
(
const
std
::
string
&
line
)
{}
class
Reader
{
public:
virtual
~
Reader
()
{}
virtual
bool
HasNext
()
=
0
;
virtual
void
NextLine
(
std
::
string
*
line
)
=
0
;
};
class
GzipReader
:
public
Reader
{
public:
explicit
GzipReader
(
const
std
::
string
&
file_name
)
:
gzstream_
(
file_name
.
c_str
())
{}
~
GzipReader
()
{}
bool
HasNext
()
override
{
return
gzstream_
.
peek
()
!=
EOF
;
}
void
NextLine
(
std
::
string
*
line
)
override
{
std
::
getline
(
gzstream_
,
*
line
);
}
private:
igzstream
gzstream_
;
};
class
PlainFileReader
:
public
Reader
{
public:
explicit
PlainFileReader
(
const
std
::
string
&
file_name
)
:
stream_
(
file_name
.
c_str
())
{}
~
PlainFileReader
()
{}
bool
HasNext
()
override
{
return
stream_
.
peek
()
!=
EOF
;
}
void
NextLine
(
std
::
string
*
line
)
override
{
std
::
getline
(
stream_
,
*
line
);
}
private:
std
::
ifstream
stream_
;
};
template
<
typename
SingleFileReader
>
class
MultiFileReader
:
public
Reader
{
public:
explicit
MultiFileReader
(
const
std
::
vector
<
std
::
string
>&
file_list
)
{
for
(
auto
&
file
:
file_list
)
{
readers_
.
emplace_back
(
std
::
make_shared
<
SingleFileReader
>
(
file
));
}
}
bool
HasNext
()
override
{
if
(
current_reader_index_
>=
readers_
.
size
())
{
return
false
;
}
if
(
!
readers_
[
current_reader_index_
]
->
HasNext
())
{
current_reader_index_
++
;
return
HasNext
();
}
return
true
;
}
void
NextLine
(
std
::
string
*
line
)
override
{
readers_
[
current_reader_index_
]
->
NextLine
(
line
);
}
private:
std
::
vector
<
std
::
shared_ptr
<
SingleFileReader
>>
readers_
;
size_t
current_reader_index_
=
0
;
};
void
MonitorThread
(
std
::
vector
<
ReaderThreadStatus
>*
thread_status
,
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
)
{
VLOG
(
3
)
<<
"monitor thread in"
;
bool
reader_thread_is_running
=
true
;
while
(
reader_thread_is_running
)
{
VLOG
(
3
)
<<
"reader_thread_is_running"
;
reader_thread_is_running
=
false
;
for
(
size_t
i
=
0
;
i
<
(
*
thread_status
).
size
();
++
i
)
{
if
((
*
thread_status
)[
i
]
==
Running
)
{
VLOG
(
3
)
<<
"reader is running!"
;
reader_thread_is_running
=
true
;
}
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
1000
));
}
VLOG
(
3
)
<<
"all reader thread is stopped, close the queue"
;
queue
->
Close
();
VLOG
(
3
)
<<
"monitor thread exited"
;
}
void
ReadSvmData
(
const
DataDesc
&
data_desc
,
std
::
shared_ptr
<
Reader
>
reader
,
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
)
{
std
::
unordered_map
<
std
::
string
,
size_t
>
slot_to_index
;
for
(
size_t
i
=
0
;
i
<
data_desc
.
sparse_slot_ids_
.
size
();
++
i
)
{
slot_to_index
[
data_desc
.
sparse_slot_ids_
[
i
]]
=
i
;
}
std
::
string
line
;
std
::
vector
<
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
int64_t
>>>
batch_data
;
std
::
vector
<
int64_t
>
batch_label
;
while
(
reader
->
HasNext
())
{
batch_data
.
clear
();
batch_data
.
reserve
(
data_desc
.
batch_size_
);
batch_label
.
clear
();
batch_label
.
reserve
(
data_desc
.
batch_size_
);
// read batch_size data
for
(
int
i
=
0
;
i
<
data_desc
.
batch_size_
;
++
i
)
{
if
(
reader
->
HasNext
())
{
reader
->
NextLine
(
&
line
);
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
int64_t
>>
slot_to_data
;
int64_t
label
;
parse_line
(
line
,
slot_to_index
,
&
label
,
&
slot_to_data
);
batch_data
.
push_back
(
slot_to_data
);
batch_label
.
push_back
(
label
);
}
else
{
break
;
}
}
std
::
vector
<
framework
::
LoDTensor
>
lod_datas
;
// first insert tensor for each sparse_slots
for
(
auto
&
slot
:
data_desc
.
sparse_slot_ids_
)
{
std
::
vector
<
size_t
>
lod_data
{
0
};
std
::
vector
<
int64_t
>
batch_feasign
;
for
(
size_t
i
=
0
;
i
<
batch_data
.
size
();
++
i
)
{
auto
&
feasign
=
batch_data
[
i
][
slot
];
lod_data
.
push_back
(
lod_data
.
back
()
+
feasign
.
size
());
batch_feasign
.
insert
(
batch_feasign
.
end
(),
feasign
.
begin
(),
feasign
.
end
());
}
framework
::
LoDTensor
lod_tensor
;
framework
::
LoD
lod
{
lod_data
};
lod_tensor
.
set_lod
(
lod
);
int64_t
*
tensor_data
=
lod_tensor
.
mutable_data
<
int64_t
>
(
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
batch_feasign
.
size
()),
1
}),
platform
::
CPUPlace
());
memcpy
(
tensor_data
,
batch_feasign
.
data
(),
batch_feasign
.
size
()
*
sizeof
(
int64_t
));
lod_datas
.
push_back
(
lod_tensor
);
}
// insert label tensor
framework
::
LoDTensor
label_tensor
;
auto
*
label_tensor_data
=
label_tensor
.
mutable_data
<
int64_t
>
(
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
batch_label
.
size
()),
1
}),
platform
::
CPUPlace
());
memcpy
(
label_tensor_data
,
batch_label
.
data
(),
batch_label
.
size
()
*
sizeof
(
int64_t
));
lod_datas
.
push_back
(
label_tensor
);
queue
->
Push
(
lod_datas
);
VLOG
(
4
)
<<
"push one data, queue_size="
<<
queue
->
Size
();
}
}
// label dense_fea,dense_fea sparse_fea,sparse_fea
static
inline
void
parse_csv_line
(
const
std
::
string
&
line
,
const
DataDesc
&
data_desc
,
int64_t
*
label
,
std
::
vector
<
std
::
vector
<
float
>>*
dense_datas
,
std
::
vector
<
std
::
vector
<
int64_t
>>*
sparse_datas
)
{
std
::
vector
<
std
::
string
>
ret
;
string_split
(
line
,
' '
,
&
ret
);
*
label
=
std
::
stol
(
ret
[
0
]);
dense_datas
->
resize
(
data_desc
.
dense_slot_index_
.
size
());
for
(
size_t
i
=
0
;
i
<
data_desc
.
dense_slot_index_
.
size
();
++
i
)
{
int
slot_idx
=
data_desc
.
dense_slot_index_
[
i
];
auto
&
slot_data
=
ret
[
slot_idx
];
std
::
vector
<
std
::
string
>
data_in_slot_str
;
string_split
(
slot_data
,
','
,
&
data_in_slot_str
);
std
::
vector
<
float
>
data_in_slot
;
for
(
auto
&
data_str
:
data_in_slot_str
)
{
(
*
dense_datas
)[
i
].
push_back
(
std
::
stof
(
data_str
));
}
}
sparse_datas
->
resize
(
data_desc
.
sparse_slot_index_
.
size
());
for
(
size_t
i
=
0
;
i
<
data_desc
.
sparse_slot_index_
.
size
();
++
i
)
{
int
slot_idx
=
data_desc
.
sparse_slot_index_
[
i
];
auto
&
slot_data
=
ret
[
slot_idx
];
std
::
vector
<
std
::
string
>
data_in_slot_str
;
string_split
(
slot_data
,
','
,
&
data_in_slot_str
);
std
::
vector
<
int64_t
>
data_in_slot
;
for
(
auto
&
data_str
:
data_in_slot_str
)
{
auto
id
=
std
::
stol
(
data_str
);
(
*
sparse_datas
)[
i
].
push_back
(
id
);
}
}
}
void
ReadCsvData
(
const
DataDesc
&
data_desc
,
std
::
shared_ptr
<
Reader
>
reader
,
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
)
{
std
::
string
line
;
while
(
reader
->
HasNext
())
{
std
::
vector
<
int64_t
>
batch_label
;
batch_label
.
reserve
(
data_desc
.
batch_size_
);
std
::
vector
<
std
::
vector
<
std
::
vector
<
float
>>>
batch_dense_data
;
batch_dense_data
.
reserve
(
data_desc
.
batch_size_
);
std
::
vector
<
std
::
vector
<
std
::
vector
<
int64_t
>>>
batch_sparse_data
;
batch_sparse_data
.
reserve
(
data_desc
.
batch_size_
);
// read batch_size data
for
(
int
i
=
0
;
i
<
data_desc
.
batch_size_
;
++
i
)
{
if
(
reader
->
HasNext
())
{
reader
->
NextLine
(
&
line
);
int64_t
label
;
std
::
vector
<
std
::
vector
<
float
>>
dense_datas
;
std
::
vector
<
std
::
vector
<
int64_t
>>
sparse_datas
;
parse_csv_line
(
line
,
data_desc
,
&
label
,
&
dense_datas
,
&
sparse_datas
);
batch_label
.
push_back
(
label
);
if
(
!
batch_dense_data
.
empty
())
{
PADDLE_ENFORCE_EQ
(
batch_dense_data
[
0
].
size
(),
dense_datas
.
size
(),
"dense data should have the same shape"
);
}
batch_dense_data
.
push_back
(
dense_datas
);
batch_sparse_data
.
push_back
(
sparse_datas
);
}
else
{
break
;
}
}
// the order of output data is label, dense_datas, sparse_datas
std
::
vector
<
framework
::
LoDTensor
>
lod_datas
;
// insert label tensor
framework
::
LoDTensor
label_tensor
;
auto
*
label_tensor_data
=
label_tensor
.
mutable_data
<
int64_t
>
(
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
batch_label
.
size
()),
1
}),
platform
::
CPUPlace
());
memcpy
(
label_tensor_data
,
batch_label
.
data
(),
batch_label
.
size
()
*
sizeof
(
int64_t
));
lod_datas
.
push_back
(
label_tensor
);
// insert tensor for each dense_slots
for
(
size_t
i
=
0
;
i
<
data_desc
.
dense_slot_index_
.
size
();
++
i
)
{
framework
::
LoDTensor
lod_tensor
;
size_t
width
=
batch_dense_data
[
0
][
i
].
size
();
auto
*
tensor_data
=
lod_tensor
.
mutable_data
<
float
>
(
framework
::
make_ddim
(
{
static_cast
<
int64_t
>
(
batch_dense_data
.
size
()),
// batch_size
static_cast
<
int64_t
>
(
width
)}),
platform
::
CPUPlace
());
for
(
size_t
j
=
0
;
j
<
batch_dense_data
.
size
();
++
j
)
{
auto
&
dense_data_row
=
batch_dense_data
[
j
][
i
];
memcpy
(
tensor_data
+
j
*
width
,
dense_data_row
.
data
(),
width
*
sizeof
(
float
));
}
lod_datas
.
push_back
(
lod_tensor
);
}
// insert tensor for each sparse_slots
for
(
size_t
i
=
0
;
i
<
data_desc
.
sparse_slot_index_
.
size
();
++
i
)
{
std
::
vector
<
size_t
>
lod_data
{
0
};
std
::
vector
<
int64_t
>
batch_feasign
;
for
(
size_t
row_idx
=
0
;
row_idx
<
batch_sparse_data
.
size
();
++
row_idx
)
{
auto
&
sparse_ids
=
batch_sparse_data
[
row_idx
][
i
];
lod_data
.
push_back
(
lod_data
.
back
()
+
sparse_ids
.
size
());
batch_feasign
.
insert
(
batch_feasign
.
end
(),
sparse_ids
.
begin
(),
sparse_ids
.
end
());
}
framework
::
LoDTensor
lod_tensor
;
framework
::
LoD
lod
{
lod_data
};
lod_tensor
.
set_lod
(
lod
);
int64_t
*
tensor_data
=
lod_tensor
.
mutable_data
<
int64_t
>
(
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
batch_feasign
.
size
()),
1
}),
platform
::
CPUPlace
());
memcpy
(
tensor_data
,
batch_feasign
.
data
(),
batch_feasign
.
size
()
*
sizeof
(
int64_t
));
lod_datas
.
push_back
(
lod_tensor
);
}
queue
->
Push
(
lod_datas
);
VLOG
(
4
)
<<
"push one data, queue_size="
<<
queue
->
Size
();
}
}
void
ReadThread
(
const
std
::
vector
<
std
::
string
>&
file_list
,
const
DataDesc
&
data_desc
,
int
thread_id
,
std
::
vector
<
ReaderThreadStatus
>*
thread_status
,
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
)
{
VLOG
(
3
)
<<
"["
<<
thread_id
<<
"]"
<<
" reader thread start! thread_id = "
<<
thread_id
;
for
(
auto
&
file
:
file_list
)
{
VLOG
(
3
)
<<
"["
<<
thread_id
<<
"]"
<<
" file "
<<
file
;
}
(
*
thread_status
)[
thread_id
]
=
Running
;
VLOG
(
3
)
<<
"set status to running"
;
std
::
shared_ptr
<
Reader
>
reader
;
if
(
data_desc
.
file_type_
==
"gzip"
)
{
reader
.
reset
(
new
MultiFileReader
<
GzipReader
>
(
file_list
));
}
else
if
(
data_desc
.
file_type_
==
"plain"
)
{
reader
.
reset
(
new
MultiFileReader
<
PlainFileReader
>
(
file_list
));
}
else
{
PADDLE_THROW
(
"do not support file format %s"
,
data_desc
.
file_type_
);
}
VLOG
(
3
)
<<
"reader inited"
;
if
(
data_desc
.
file_format_
==
"svm"
)
{
ReadSvmData
(
data_desc
,
reader
,
queue
);
}
else
if
(
data_desc
.
file_format_
==
"csv"
)
{
ReadCsvData
(
data_desc
,
reader
,
queue
);
}
(
*
thread_status
)[
thread_id
]
=
Stopped
;
VLOG
(
3
)
<<
"set status to stopped, thread "
<<
thread_id
<<
" exited"
;
}
}
// namespace reader
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/reader/ctr_reader.h
已删除
100644 → 0
浏览文件 @
898237c1
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <sys/time.h>
#include <algorithm>
#include <chrono> // NOLINT
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
namespace
paddle
{
namespace
operators
{
namespace
reader
{
enum
ReaderThreadStatus
{
Running
,
Stopped
};
struct
DataDesc
{
DataDesc
(
int
batch_size
,
const
std
::
vector
<
std
::
string
>&
file_names
,
const
std
::
string
&
file_type
,
const
std
::
string
&
file_format
,
const
std
::
vector
<
int
>&
dense_slot_index
,
const
std
::
vector
<
int
>&
sparse_slot_index
,
const
std
::
vector
<
std
::
string
>&
sparse_slot_ids
)
:
batch_size_
(
batch_size
),
file_names_
(
file_names
),
file_type_
(
file_type
),
file_format_
(
file_format
),
dense_slot_index_
(
dense_slot_index
),
sparse_slot_index_
(
sparse_slot_index
),
sparse_slot_ids_
(
sparse_slot_ids
)
{}
const
int
batch_size_
;
const
std
::
vector
<
std
::
string
>
file_names_
;
const
std
::
string
file_type_
;
// gzip or plain
const
std
::
string
file_format_
;
// csv or svm
// used for csv data format
const
std
::
vector
<
int
>
dense_slot_index_
;
const
std
::
vector
<
int
>
sparse_slot_index_
;
// used for svm data format
const
std
::
vector
<
std
::
string
>
sparse_slot_ids_
;
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
DataDesc
&
data_desc
)
{
os
<<
"data_desc:
\n
"
;
os
<<
"
\t
batch_size -> "
<<
data_desc
.
batch_size_
<<
"
\n
"
;
os
<<
"
\t
file_type -> "
<<
data_desc
.
file_type_
<<
"
\n
"
;
os
<<
"
\t
file_format -> "
<<
data_desc
.
file_format_
<<
"
\n
"
;
os
<<
"
\t
file_names -> {"
;
for
(
auto
&
file_name
:
data_desc
.
file_names_
)
{
os
<<
file_name
<<
","
;
}
os
<<
"}
\n
"
;
os
<<
"
\t
dense_slot_index -> {"
;
for
(
auto
&
slot
:
data_desc
.
dense_slot_index_
)
{
os
<<
slot
<<
","
;
}
os
<<
"}
\n
"
;
os
<<
"
\t
sparse_slot_index_ -> {"
;
for
(
auto
&
slot
:
data_desc
.
sparse_slot_index_
)
{
os
<<
slot
<<
","
;
}
os
<<
"}
\n
"
;
os
<<
"
\t
sparse_slot_ids_ -> {"
;
for
(
auto
&
slot
:
data_desc
.
sparse_slot_ids_
)
{
os
<<
slot
<<
","
;
}
os
<<
"}
\n
"
;
return
os
;
}
void
ReadThread
(
const
std
::
vector
<
std
::
string
>&
file_list
,
const
DataDesc
&
data_desc
,
int
thread_id
,
std
::
vector
<
ReaderThreadStatus
>*
thread_status
,
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
);
// monitor all running thread, if they are all stopped,
// then push an empty data into LoDTensorBlockingQueue
void
MonitorThread
(
std
::
vector
<
ReaderThreadStatus
>*
thread_status
,
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
);
class
CTRReader
:
public
framework
::
FileReader
{
public:
CTRReader
(
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
queue
,
int
thread_num
,
const
DataDesc
&
data_desc
)
:
data_desc_
(
data_desc
)
{
PADDLE_ENFORCE_GT
(
thread_num
,
0
,
"thread num should be larger then 0!"
);
PADDLE_ENFORCE
(
queue
!=
nullptr
,
"LoDTensorBlockingQueue must not be null"
);
PADDLE_ENFORCE_GT
(
data_desc_
.
file_names_
.
size
(),
0
,
"file list should not be empty"
);
thread_num_
=
std
::
min
<
size_t
>
(
data_desc_
.
file_names_
.
size
(),
thread_num
);
queue_
=
queue
;
SplitFiles
();
for
(
size_t
i
=
0
;
i
<
thread_num_
;
++
i
)
{
read_thread_status_
.
push_back
(
Stopped
);
}
}
~
CTRReader
()
{
Shutdown
();
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
bool
success
;
*
out
=
queue_
->
Pop
(
&
success
);
if
(
!
success
)
out
->
clear
();
}
void
Shutdown
()
override
{
VLOG
(
3
)
<<
"Shutdown reader"
;
if
(
status_
==
ReaderStatus
::
kStopped
)
{
return
;
}
// shutdown should stop all the reader thread
for
(
auto
&
read_thread
:
read_threads_
)
{
read_thread
->
join
();
}
if
(
monitor_thread_
)
{
monitor_thread_
->
join
();
}
read_threads_
.
clear
();
monitor_thread_
.
reset
(
nullptr
);
queue_
->
Close
();
status_
=
ReaderStatus
::
kStopped
;
}
void
Start
()
override
{
VLOG
(
3
)
<<
"Start reader"
;
PADDLE_ENFORCE_EQ
(
read_threads_
.
size
(),
0
,
"read thread should be empty!"
);
queue_
->
ReOpen
();
VLOG
(
3
)
<<
"reopen success"
;
VLOG
(
3
)
<<
"thread_num "
<<
thread_num_
;
for
(
size_t
thread_id
=
0
;
thread_id
<
thread_num_
;
thread_id
++
)
{
read_threads_
.
emplace_back
(
new
std
::
thread
(
std
::
bind
(
&
ReadThread
,
file_groups_
[
thread_id
],
data_desc_
,
static_cast
<
int
>
(
thread_id
),
&
read_thread_status_
,
queue_
)));
}
monitor_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
MonitorThread
,
&
read_thread_status_
,
queue_
)));
status_
=
ReaderStatus
::
kRunning
;
}
private:
void
SplitFiles
()
{
file_groups_
.
resize
(
thread_num_
);
for
(
size_t
i
=
0
;
i
<
data_desc_
.
file_names_
.
size
();
++
i
)
{
auto
&
file_name
=
data_desc_
.
file_names_
[
i
];
std
::
ifstream
f
(
file_name
.
c_str
());
PADDLE_ENFORCE
(
f
.
good
(),
"file %s not exist!"
,
file_name
);
file_groups_
[
i
%
thread_num_
].
push_back
(
file_name
);
}
}
private:
size_t
thread_num_
;
const
DataDesc
data_desc_
;
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue_
;
std
::
vector
<
std
::
unique_ptr
<
std
::
thread
>>
read_threads_
;
std
::
unique_ptr
<
std
::
thread
>
monitor_thread_
;
std
::
vector
<
ReaderThreadStatus
>
read_thread_status_
;
std
::
vector
<
std
::
vector
<
std
::
string
>>
file_groups_
;
};
}
// namespace reader
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/reader/ctr_reader_test.cc
已删除
100644 → 0
浏览文件 @
898237c1
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reader/ctr_reader.h"
#include <gzstream.h>
#include <time.h>
#include <math.h>
#include <stdio.h>
#include <cstring>
#include <fstream>
#include <tuple>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
using
paddle
::
operators
::
reader
::
LoDTensorBlockingQueue
;
using
paddle
::
operators
::
reader
::
LoDTensorBlockingQueueHolder
;
using
paddle
::
operators
::
reader
::
CTRReader
;
using
paddle
::
framework
::
LoDTensor
;
using
paddle
::
framework
::
LoD
;
using
paddle
::
framework
::
DDim
;
using
paddle
::
platform
::
CPUPlace
;
using
paddle
::
framework
::
make_ddim
;
using
paddle
::
operators
::
reader
::
DataDesc
;
static
void
generatedata
(
const
std
::
vector
<
std
::
string
>&
data
,
const
std
::
string
&
file_name
)
{
std
::
ifstream
in
(
file_name
.
c_str
());
if
(
in
.
good
())
{
VLOG
(
3
)
<<
"file "
<<
file_name
<<
" exist, delete it first!"
;
remove
(
file_name
.
c_str
());
}
else
{
in
.
close
();
}
ogzstream
out
(
file_name
.
c_str
());
PADDLE_ENFORCE
(
out
.
good
(),
"open file %s failed!"
,
file_name
);
for
(
auto
&
c
:
data
)
{
out
<<
c
;
}
out
.
close
();
PADDLE_ENFORCE
(
out
.
good
(),
"save file %s failed!"
,
file_name
);
}
static
inline
void
check_all_data
(
const
std
::
vector
<
std
::
string
>&
ctr_data
,
const
std
::
vector
<
std
::
string
>&
slots
,
const
std
::
vector
<
DDim
>&
label_dims
,
const
std
::
vector
<
int64_t
>&
label_value
,
const
std
::
vector
<
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>>&
data_slot_6002
,
const
std
::
vector
<
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>>&
data_slot_6003
,
size_t
batch_num
,
size_t
batch_size
,
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
,
CTRReader
*
reader
)
{
std
::
vector
<
LoDTensor
>
out
;
for
(
size_t
i
=
0
;
i
<
batch_num
;
++
i
)
{
reader
->
ReadNext
(
&
out
);
ASSERT_EQ
(
out
.
size
(),
slots
.
size
()
+
1
);
auto
&
label_tensor
=
out
.
back
();
ASSERT_EQ
(
label_tensor
.
dims
(),
label_dims
[
i
]);
for
(
size_t
j
=
0
;
j
<
batch_size
&&
i
*
batch_num
+
j
<
ctr_data
.
size
();
++
j
)
{
auto
&
label
=
label_tensor
.
data
<
int64_t
>
()[
j
];
ASSERT_TRUE
(
label
==
0
||
label
==
1
);
ASSERT_EQ
(
label
,
label_value
[
i
*
batch_size
+
j
]);
}
auto
&
tensor_6002
=
out
[
0
];
ASSERT_EQ
(
std
::
get
<
0
>
(
data_slot_6002
[
i
]),
tensor_6002
.
lod
());
ASSERT_EQ
(
std
::
memcmp
(
std
::
get
<
1
>
(
data_slot_6002
[
i
]).
data
(),
tensor_6002
.
data
<
int64_t
>
(),
tensor_6002
.
dims
()[
1
]
*
sizeof
(
int64_t
)),
0
);
}
reader
->
ReadNext
(
&
out
);
ASSERT_EQ
(
out
.
size
(),
0
);
ASSERT_EQ
(
queue
->
Size
(),
0
);
}
TEST
(
CTR_READER
,
read_data
)
{
const
std
::
vector
<
std
::
string
>
ctr_data
=
{
"0 0:6002 1:6003 2:6004 3:6005 4:6006
\n
"
,
"0 5:6003 6:6003 7:6003 8:6004 9:6004
\n
"
,
"1 10:6002 11:6002 12:6002 13:6002 14:6002
\n
"
,
"0 15:6003 16:6003 17:6003 18:6003 19:6004
\n
"
,
"1 20:6001 21:6001 22:6001 23:6001 24:6001
\n
"
,
"1 25:6004 26:6004 27:6004 28:6005 29:6005
\n
"
,
"0 30:6002 31:6003 32:6004 33:6004 34:6005
\n
"
,
"1 35:6003 36:6003 37:6005 38:6005 39:6005
\n
"
,
"1 40:6002 41:6003 42:6004 43:6004 44:6005
\n
"
,
"1 46:6006 45:6006 47:6003 48:6003 49:6003
\n
"
,
};
std
::
string
gz_file_name
=
"test_ctr_reader_data.gz"
;
generatedata
(
ctr_data
,
gz_file_name
);
std
::
vector
<
int64_t
>
label_value
=
{
0
,
0
,
1
,
0
,
1
,
1
,
0
,
1
,
1
,
1
};
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>
a1
({{
0
,
1
,
2
,
7
}},
{
0
,
0
,
10
,
11
,
12
,
13
,
14
});
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>
a2
({{
0
,
1
,
2
,
3
}},
{
0
,
0
,
0
});
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>
a3
({{
0
,
1
,
2
,
3
}},
{
30
,
0
,
40
});
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>
a4
({{
0
,
1
}},
{
0
});
std
::
vector
<
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>>
data_slot_6002
{
a1
,
a2
,
a3
,
a4
};
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>
b1
({{
0
,
1
,
4
,
5
}},
{
1
,
5
,
6
,
7
,
0
});
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>
b2
({{
0
,
4
,
5
,
6
}},
{
15
,
16
,
17
,
18
,
0
,
0
});
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>
b3
({{
0
,
1
,
3
,
4
}},
{
31
,
35
,
36
,
41
});
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>
b4
({{
0
,
3
}},
{
47
,
48
,
49
});
std
::
vector
<
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>>
data_slot_6003
{
b1
,
b2
,
b3
,
b4
};
std
::
vector
<
DDim
>
label_dims
=
{{
3
,
1
},
{
3
,
1
},
{
3
,
1
},
{
1
,
1
}};
LoDTensorBlockingQueueHolder
queue_holder
;
int
capacity
=
64
;
queue_holder
.
InitOnce
(
capacity
,
false
);
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
=
queue_holder
.
GetQueue
();
int
batch_size
=
3
;
int
thread_num
=
1
;
std
::
vector
<
std
::
string
>
sparse_slots
=
{
"6002"
,
"6003"
};
std
::
vector
<
std
::
string
>
file_list
;
for
(
int
i
=
0
;
i
<
thread_num
;
++
i
)
{
file_list
.
push_back
(
gz_file_name
);
}
DataDesc
data_desc
(
batch_size
,
file_list
,
"gzip"
,
"svm"
,
{},
{},
sparse_slots
);
CTRReader
reader
(
queue
,
thread_num
,
data_desc
);
reader
.
Start
();
size_t
batch_num
=
std
::
ceil
(
static_cast
<
float
>
(
ctr_data
.
size
())
/
batch_size
)
*
thread_num
;
check_all_data
(
ctr_data
,
sparse_slots
,
label_dims
,
label_value
,
data_slot_6002
,
data_slot_6003
,
batch_num
,
batch_size
,
queue
,
&
reader
);
reader
.
Shutdown
();
reader
.
Start
();
check_all_data
(
ctr_data
,
sparse_slots
,
label_dims
,
label_value
,
data_slot_6002
,
data_slot_6003
,
batch_num
,
batch_size
,
queue
,
&
reader
);
reader
.
Shutdown
();
}
static
void
GenereteCsvData
(
const
std
::
string
&
file_name
,
const
std
::
vector
<
std
::
string
>&
data
)
{
std
::
ofstream
out
(
file_name
.
c_str
());
PADDLE_ENFORCE
(
out
.
good
(),
"open file %s failed!"
,
file_name
);
for
(
auto
&
c
:
data
)
{
out
<<
c
;
}
out
.
close
();
PADDLE_ENFORCE
(
out
.
good
(),
"save file %s failed!"
,
file_name
);
}
static
void
CheckReadCsvOut
(
const
std
::
vector
<
LoDTensor
>&
out
)
{
ASSERT_EQ
(
out
.
size
(),
3
);
ASSERT_EQ
(
out
[
0
].
dims
()[
1
],
1
);
ASSERT_EQ
(
out
[
1
].
dims
()[
1
],
2
);
ASSERT_EQ
(
out
[
2
].
dims
()[
1
],
1
);
for
(
size_t
i
=
0
;
i
<
out
[
0
].
numel
();
++
i
)
{
int64_t
label
=
out
[
0
].
data
<
int64_t
>
()[
i
];
auto
&
dense_dim
=
out
[
1
].
dims
();
for
(
size_t
j
=
0
;
j
<
dense_dim
[
1
];
++
j
)
{
ASSERT_EQ
(
out
[
1
].
data
<
float
>
()[
i
*
dense_dim
[
1
]
+
j
],
static_cast
<
float
>
(
label
+
0.1
));
}
auto
&
sparse_lod
=
out
[
2
].
lod
();
for
(
size_t
j
=
sparse_lod
[
0
][
i
];
j
<
sparse_lod
[
0
][
i
+
1
];
++
j
)
{
ASSERT_EQ
(
out
[
2
].
data
<
int64_t
>
()[
j
],
label
);
}
}
}
TEST
(
CTR_READER
,
read_csv_data
)
{
std
::
string
file_name
=
"test_ctr_reader_data.csv"
;
const
std
::
vector
<
std
::
string
>
csv_data
=
{
"0 0.1,0.1 0,0,0,0
\n
"
,
"1 1.1,1.1 1,1,1,1
\n
"
,
"2 2.1,2.1 2,2,2,2
\n
"
,
"3 3.1,3.1 3,3,3,3
\n
"
,
};
GenereteCsvData
(
file_name
,
csv_data
);
LoDTensorBlockingQueueHolder
queue_holder
;
int
capacity
=
64
;
queue_holder
.
InitOnce
(
capacity
,
false
);
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
=
queue_holder
.
GetQueue
();
int
batch_size
=
3
;
int
thread_num
=
1
;
std
::
vector
<
std
::
string
>
file_list
;
for
(
int
i
=
0
;
i
<
thread_num
;
++
i
)
{
file_list
.
push_back
(
file_name
);
}
DataDesc
data_desc
(
batch_size
,
file_list
,
"plain"
,
"csv"
,
{
1
},
{
2
},
{});
CTRReader
reader
(
queue
,
thread_num
,
data_desc
);
for
(
size_t
i
=
0
;
i
<
2
;
++
i
)
{
reader
.
Start
();
std
::
vector
<
LoDTensor
>
out
;
while
(
true
)
{
reader
.
ReadNext
(
&
out
);
if
(
out
.
empty
())
{
break
;
}
CheckReadCsvOut
(
out
);
}
reader
.
Shutdown
();
}
}
python/paddle/fluid/contrib/reader/__init__.py
浏览文件 @
5ed713d5
...
...
@@ -14,9 +14,7 @@
from
__future__
import
print_function
from
.
import
ctr_reader
from
.distributed_reader
import
*
__all__
=
[]
__all__
+=
distributed_reader
.
__all__
__all__
+=
ctr_reader
.
__all__
python/paddle/fluid/contrib/reader/ctr_reader.py
已删除
100644 → 0
浏览文件 @
898237c1
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
paddle.fluid
import
core
from
paddle.fluid.executor
import
global_scope
from
paddle.fluid.framework
import
default_main_program
,
\
default_startup_program
,
Variable
from
paddle.fluid.unique_name
import
generate
as
unique_name
__all__
=
[
'ctr_reader'
]
def monkey_patch_reader_methods(reader):
    """Attach ``start``/``reset`` methods to a reader Variable.

    The real reader object lives in the C++ global scope; it is looked up
    lazily by this variable's name every time ``start()``/``reset()`` is
    invoked, so patching is cheap and safe before the scope is populated.
    Also marks the variable persistable with gradients stopped.
    Returns the same ``reader`` object, patched in place.
    """

    def _find_cpp_reader():
        # Resolve the underlying C++ reader held by the global scope
        # under this variable's name.
        reader_var = global_scope().find_var(reader.name)
        return reader_var.get_reader()

    def reset():
        return _find_cpp_reader().reset()

    def start():
        return _find_cpp_reader().start()

    reader.reset = reset
    reader.start = start
    reader.stop_gradient = True
    reader.persistable = True
    return reader
def _copy_reader_var_(block, var):
    """Create a READER variable in ``block`` mirroring ``var``.

    Copies the shapes and dtypes recorded on ``var``'s desc onto the new
    variable and marks it persistable, so the main program sees the same
    reader that the startup program created. Returns the new variable.
    """
    source_desc = var.desc
    copied = block.create_var(name=var.name, type=core.VarDesc.VarType.READER)
    copied.desc.set_shapes(source_desc.shapes())
    copied.desc.set_dtypes(source_desc.dtypes())
    copied.persistable = True
    return copied
def ctr_reader(
        feed_dict,
        file_type,  # gzip or plain
        file_format,  # csv or svm
        dense_slot_index,
        sparse_slot_index,
        capacity,
        thread_num,
        batch_size,
        file_list,
        slots,
        name=None):
    """
    Create a CTR reader for data feeding in Python

    This layer returns a Reader Variable.
    The Reader provides :code:`decorate_paddle_reader()` and
    :code:`decorate_tensor_provider()` to set a Python generator as the data
    source in Python side. When :code:`Executor::Run()` is invoked in C++
    side, the data from the generator would be read automatically. Unlike
    :code:`DataFeeder.feed()`, the data reading process and
    :code:`Executor::Run()` process can run in parallel using
    :code:`py_reader`. The :code:`start()` method of the Reader should be
    called when each pass begins, while the :code:`reset()` method should be
    called when the pass ends and :code:`fluid.core.EOFException` raises.

    Note that :code:`Program.clone()` method cannot clone :code:`py_reader`.

    Args:
       feed_dict(list(variable)): a list of data variable.
       file_type('gzip'|'plain'): the type of the data file
       file_format('csv'|'svm'): csv data or svm data format.
        the csv data format is :
            label dense_fea,dense_fea sparse_fea,sparse_fea
        the svm data format is :
            label slot1:fea_sign slot2:fea_sign slot1:fea_sign
       dense_slot_index(list(int)): the index of dense slots
       sparse_slot_index(list(int)): the index of sparse slots
       capacity(int): The buffer capacity maintained by :code:`py_reader`.
       thread_num(int): the thread num to read files by cpp reader.
       batch_size(int): batch size of data.
       file_list(list(str)): List of file names that need to read.
       slots(list(int64)): list of slot id.
       name(string): The prefix Python queue name and Reader name. None will
            be generated automatically.

    Returns:
       Variable: A Reader from which we can get feeding data.

    Examples:

        1. The basic usage of :code:`ctr_reader` is as follows:

         .. code-block:: python

            py_reader = fluid.contrib.ctr_reader.ctr_reader(
              feed_dict=datas, file_type='plain', file_format='csv',
              file_list=file_list, dense_slot_index=[1, 2, 3, 4], sparse_slot_index=[],
              capacity=64, thread_num=20, batch_size=1000, slots=[], name='ctr_reader')

    """
    # Derive deterministic queue/reader names from `name` when given,
    # otherwise generate collision-free ones with unique_name().
    if name is None:
        queue_name = unique_name('lod_tensor_blocking_queue')
        reader_name = unique_name('create_ctr_reader')
    else:
        queue_name = "_".join([name, "queue"])
        reader_name = "_".join([name, "reader"])

    # The C++ reader threads push batches into this blocking queue; it is
    # created in the global scope so the create_ctr_reader op can look it
    # up by name at runtime.
    var = global_scope().var(queue_name)
    feed_queue = core.init_lod_tensor_blocking_queue(var, capacity)

    # Register the create_ctr_reader op in the startup program; its attrs
    # carry the full file/slot configuration down to the C++ reader.
    startup_blk = default_startup_program().current_block()
    reader_var = startup_blk.create_var(name=reader_name)
    startup_blk.append_op(
        type='create_ctr_reader',
        inputs={'blocking_queue': [queue_name]},
        outputs={'Out': [reader_var]},
        attrs={
            'use_data_config': False,
            'thread_num': thread_num,
            'batch_size': batch_size,
            'file_list': file_list,
            'file_type': file_type,
            'file_format': file_format,
            'dense_slot_index': dense_slot_index,
            'sparse_slot_index': sparse_slot_index,
            'sparse_slots': slots,
            'ranks': [],
            'lod_levels': [],
            'shape_concat': []
        })

    # The reader's output dtypes mirror the feed variables, one per slot.
    dtypes = [data.dtype for data in feed_dict]
    reader_var.desc.set_dtypes(dtypes)
    reader_var.persistable = True

    # Mirror the startup-program reader variable into the main program so
    # the read op below can reference it.
    main_prog_reader_var = _copy_reader_var_(
        default_main_program().current_block(), reader_var)

    reader = monkey_patch_reader_methods(main_prog_reader_var)

    # monkey patch py_reader special methods
    reader.queue = feed_queue
    reader.exited = False

    # The read op pops one batch from the reader and writes it into the
    # feed variables each time the main program runs.
    main_blk = default_main_program().current_block()
    main_blk.append_op(
        type='read',
        inputs={'Reader': [reader]},
        attrs={'infer_out': False},
        outputs={'Out': feed_dict})

    return reader
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录