Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
f8638664
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f8638664
编写于
3月 20, 2018
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add an unitest
上级
02b7d8be
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
82 addition
and
9 deletion
+82
-9
paddle/fluid/operators/reader/open_files_op.cc
paddle/fluid/operators/reader/open_files_op.cc
+3
-1
paddle/fluid/operators/reader/reader_op_registry.cc
paddle/fluid/operators/reader/reader_op_registry.cc
+4
-5
paddle/fluid/operators/reader/reader_op_registry.h
paddle/fluid/operators/reader/reader_op_registry.h
+1
-1
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+3
-2
python/paddle/fluid/tests/unittests/test_multiple_reader.py
python/paddle/fluid/tests/unittests/test_multiple_reader.py
+71
-0
未找到文件。
paddle/fluid/operators/reader/open_files_op.cc
浏览文件 @
f8638664
...
...
@@ -94,7 +94,9 @@ void MultipleReader::EndScheduler() {
available_thread_idx_
->
Close
();
buffer_
->
Close
();
waiting_file_idx_
->
Close
();
scheduler_
.
join
();
if
(
scheduler_
.
joinable
())
{
scheduler_
.
join
();
}
delete
buffer_
;
delete
available_thread_idx_
;
delete
waiting_file_idx_
;
...
...
paddle/fluid/operators/reader/reader_op_registry.cc
浏览文件 @
f8638664
...
...
@@ -38,17 +38,16 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
std
::
unique_ptr
<
framework
::
ReaderBase
>
CreateReaderByFileName
(
const
std
::
string
&
file_name
,
const
std
::
vector
<
framework
::
DDim
>&
dims
)
{
size_t
separator_pos
=
file_name
.
find
(
kFileFormatSeparator
);
size_t
separator_pos
=
file_name
.
find
_last_of
(
kFileFormatSeparator
);
PADDLE_ENFORCE_NE
(
separator_pos
,
std
::
string
::
npos
,
"File name illegal! A legal file name should be like: "
"[file_format]:[file_name] (e.g., 'recordio:data_file')."
);
std
::
string
filetype
=
file_name
.
substr
(
0
,
separator_pos
);
std
::
string
f_name
=
file_name
.
substr
(
separator_pos
+
1
);
"[file_name].[file_format] (e.g., 'data_file.recordio')."
);
std
::
string
filetype
=
file_name
.
substr
(
separator_pos
+
1
);
auto
itor
=
FileReaderRegistry
().
find
(
filetype
);
PADDLE_ENFORCE
(
itor
!=
FileReaderRegistry
().
end
(),
"No file reader registered for '%s' format."
,
filetype
);
framework
::
ReaderBase
*
reader
=
(
itor
->
second
)(
f_name
,
dims
);
framework
::
ReaderBase
*
reader
=
(
itor
->
second
)(
f
ile
_name
,
dims
);
return
std
::
unique_ptr
<
framework
::
ReaderBase
>
(
reader
);
}
...
...
paddle/fluid/operators/reader/reader_op_registry.h
浏览文件 @
f8638664
...
...
@@ -21,7 +21,7 @@ namespace paddle {
namespace
operators
{
namespace
reader
{
static
constexpr
char
kFileFormatSeparator
[]
=
"
:
"
;
static
constexpr
char
kFileFormatSeparator
[]
=
"
.
"
;
using
FileReaderCreator
=
std
::
function
<
framework
::
ReaderBase
*
(
const
std
::
string
&
,
const
std
::
vector
<
framework
::
DDim
>&
)
>
;
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
f8638664
...
...
@@ -21,7 +21,8 @@ from ..executor import global_scope
__all__
=
[
'data'
,
'BlockGuardServ'
,
'ListenAndServ'
,
'Send'
,
'open_recordio_file'
,
'read_file'
,
'create_shuffle_reader'
,
'create_double_buffer_reader'
'open_files'
,
'read_file'
,
'create_shuffle_reader'
,
'create_double_buffer_reader'
]
...
...
@@ -307,7 +308,7 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
'shape_concat'
:
shape_concat
,
'lod_levels'
:
lod_levels
,
'ranks'
:
ranks
,
'file
name
'
:
filenames
,
'file
_names
'
:
filenames
,
'thread_num'
:
thread_num
})
...
...
python/paddle/fluid/tests/unittests/test_multiple_reader.py
0 → 100644
浏览文件 @
f8638664
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle.fluid
as
fluid
import
paddle.v2
as
paddle
import
paddle.v2.dataset.mnist
as
mnist
from
shutil
import
copyfile
class
TestMultipleReader
(
unittest
.
TestCase
):
def
setUp
(
self
):
# Convert mnist to recordio file
with
fluid
.
program_guard
(
fluid
.
Program
(),
fluid
.
Program
()):
reader
=
paddle
.
batch
(
mnist
.
train
(),
batch_size
=
32
)
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
# order is image and label
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
784
]),
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
),
],
place
=
fluid
.
CPUPlace
())
self
.
num_batch
=
fluid
.
recordio_writer
.
convert_reader_to_recordio_file
(
'./mnist_0.recordio'
,
reader
,
feeder
)
copyfile
(
'./mnist_0.recordio'
,
'./mnist_1.recordio'
)
copyfile
(
'./mnist_0.recordio'
,
'./mnist_2.recordio'
)
print
(
self
.
num_batch
)
def
test_multiple_reader
(
self
,
thread_num
=
3
):
file_list
=
[
'./mnist_0.recordio'
,
'./mnist_1.recordio'
,
'./mnist_2.recordio'
]
with
fluid
.
program_guard
(
fluid
.
Program
(),
fluid
.
Program
()):
data_files
=
fluid
.
layers
.
open_files
(
filenames
=
file_list
,
thread_num
=
thread_num
,
shapes
=
[(
-
1
,
784
),
(
-
1
,
1
)],
lod_levels
=
[
0
,
0
],
dtypes
=
[
'float32'
,
'int64'
])
img
,
label
=
fluid
.
layers
.
read_file
(
data_files
)
if
fluid
.
core
.
is_compiled_with_cuda
():
place
=
fluid
.
CUDAPlace
(
0
)
else
:
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
batch_count
=
0
while
not
data_files
.
eof
():
img_val
,
=
exe
.
run
(
fetch_list
=
[
img
])
batch_count
+=
1
print
(
batch_count
)
# data_files.reset()
print
(
"FUCK"
)
self
.
assertEqual
(
batch_count
,
self
.
num_batch
*
3
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录