提交 f8638664 编写于 作者: F fengjiayi

Add an unitest

上级 02b7d8be
......@@ -94,7 +94,9 @@ void MultipleReader::EndScheduler() {
available_thread_idx_->Close();
buffer_->Close();
waiting_file_idx_->Close();
scheduler_.join();
if (scheduler_.joinable()) {
scheduler_.join();
}
delete buffer_;
delete available_thread_idx_;
delete waiting_file_idx_;
......
......@@ -38,17 +38,16 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims) {
size_t separator_pos = file_name.find(kFileFormatSeparator);
size_t separator_pos = file_name.find_last_of(kFileFormatSeparator);
PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
"File name illegal! A legal file name should be like: "
"[file_format]:[file_name] (e.g., 'recordio:data_file').");
std::string filetype = file_name.substr(0, separator_pos);
std::string f_name = file_name.substr(separator_pos + 1);
"[file_name].[file_format] (e.g., 'data_file.recordio').");
std::string filetype = file_name.substr(separator_pos + 1);
auto itor = FileReaderRegistry().find(filetype);
PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
"No file reader registered for '%s' format.", filetype);
framework::ReaderBase* reader = (itor->second)(f_name, dims);
framework::ReaderBase* reader = (itor->second)(file_name, dims);
return std::unique_ptr<framework::ReaderBase>(reader);
}
......
......@@ -21,7 +21,7 @@ namespace paddle {
namespace operators {
namespace reader {
static constexpr char kFileFormatSeparator[] = ":";
static constexpr char kFileFormatSeparator[] = ".";
using FileReaderCreator = std::function<framework::ReaderBase*(
const std::string&, const std::vector<framework::DDim>&)>;
......
......@@ -21,7 +21,8 @@ from ..executor import global_scope
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'read_file', 'create_shuffle_reader', 'create_double_buffer_reader'
'open_files', 'read_file', 'create_shuffle_reader',
'create_double_buffer_reader'
]
......@@ -307,7 +308,7 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'ranks': ranks,
'filename': filenames,
'file_names': filenames,
'thread_num': thread_num
})
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import paddle.v2 as paddle
import paddle.v2.dataset.mnist as mnist
from shutil import copyfile
class TestMultipleReader(unittest.TestCase):
def setUp(self):
# Convert mnist to recordio file
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=32)
feeder = fluid.DataFeeder(
feed_list=[ # order is image and label
fluid.layers.data(
name='image', shape=[784]),
fluid.layers.data(
name='label', shape=[1], dtype='int64'),
],
place=fluid.CPUPlace())
self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file(
'./mnist_0.recordio', reader, feeder)
copyfile('./mnist_0.recordio', './mnist_1.recordio')
copyfile('./mnist_0.recordio', './mnist_2.recordio')
print(self.num_batch)
def test_multiple_reader(self, thread_num=3):
file_list = [
'./mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio'
]
with fluid.program_guard(fluid.Program(), fluid.Program()):
data_files = fluid.layers.open_files(
filenames=file_list,
thread_num=thread_num,
shapes=[(-1, 784), (-1, 1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
img, label = fluid.layers.read_file(data_files)
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
batch_count = 0
while not data_files.eof():
img_val, = exe.run(fetch_list=[img])
batch_count += 1
print(batch_count)
# data_files.reset()
print("FUCK")
self.assertEqual(batch_count, self.num_batch * 3)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册