diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 49cdf5365c996485ea31b9c5db3adbf3e1ba18d1..1ab4111efe80f573d12552d7de4e11707c23ff33 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -94,7 +94,9 @@ void MultipleReader::EndScheduler() { available_thread_idx_->Close(); buffer_->Close(); waiting_file_idx_->Close(); - scheduler_.join(); + if (scheduler_.joinable()) { + scheduler_.join(); + } delete buffer_; delete available_thread_idx_; delete waiting_file_idx_; diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc index 05d79c76d5ab0e48f441a7cc8a470bd99eb80ca8..fc8dc747ff0c2286f4516d8350f75d9887361924 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.cc +++ b/paddle/fluid/operators/reader/reader_op_registry.cc @@ -38,17 +38,16 @@ std::unordered_map& FileReaderRegistry() { std::unique_ptr CreateReaderByFileName( const std::string& file_name, const std::vector& dims) { - size_t separator_pos = file_name.find(kFileFormatSeparator); + size_t separator_pos = file_name.find_last_of(kFileFormatSeparator); PADDLE_ENFORCE_NE(separator_pos, std::string::npos, "File name illegal! A legal file name should be like: " - "[file_format]:[file_name] (e.g., 'recordio:data_file')."); - std::string filetype = file_name.substr(0, separator_pos); - std::string f_name = file_name.substr(separator_pos + 1); + "[file_name].[file_format] (e.g., 'data_file.recordio')."); + std::string filetype = file_name.substr(separator_pos + 1); auto itor = FileReaderRegistry().find(filetype); PADDLE_ENFORCE(itor != FileReaderRegistry().end(), "No file reader registered for '%s' format.", filetype); - framework::ReaderBase* reader = (itor->second)(f_name, dims); + framework::ReaderBase* reader = (itor->second)(file_name, dims); return std::unique_ptr(reader); } diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h index dd19b982dad8622c2c9cfd3395e6812acba26982..929d32ad8b367865e33530f8517343c513ee9878 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.h +++ b/paddle/fluid/operators/reader/reader_op_registry.h @@ -21,7 +21,7 @@ namespace paddle { namespace operators { namespace reader { -static constexpr char kFileFormatSeparator[] = ":"; +static constexpr char kFileFormatSeparator[] = "."; using FileReaderCreator = std::function&)>; diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 89153f325bed5b46fed6d196b4d0fb520513be26..f169642eaa44eadcef8ff0bc6745183a9fec20e8 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -21,7 +21,8 @@ from ..executor import global_scope __all__ = [ 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', - 'read_file', 'create_shuffle_reader', 'create_double_buffer_reader' + 'open_files', 'read_file', 'create_shuffle_reader', + 'create_double_buffer_reader' ] @@ -307,7 +308,7 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes): 'shape_concat': shape_concat, 'lod_levels': lod_levels, 'ranks': ranks, - 'filename': filenames, + 'file_names': filenames, 'thread_num': thread_num }) diff --git a/python/paddle/fluid/tests/unittests/test_multiple_reader.py b/python/paddle/fluid/tests/unittests/test_multiple_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..cb1aaaae5a7a459ae7a545cc65209dcfd7e80d89 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_multiple_reader.py @@ -0,0 +1,71 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle.fluid as fluid +import paddle.v2 as paddle +import paddle.v2.dataset.mnist as mnist +from shutil import copyfile + + +class TestMultipleReader(unittest.TestCase): + def setUp(self): + # Convert mnist to recordio file + with fluid.program_guard(fluid.Program(), fluid.Program()): + reader = paddle.batch(mnist.train(), batch_size=32) + feeder = fluid.DataFeeder( + feed_list=[ # order is image and label + fluid.layers.data( + name='image', shape=[784]), + fluid.layers.data( + name='label', shape=[1], dtype='int64'), + ], + place=fluid.CPUPlace()) + self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file( + './mnist_0.recordio', reader, feeder) + copyfile('./mnist_0.recordio', './mnist_1.recordio') + copyfile('./mnist_0.recordio', './mnist_2.recordio') + print(self.num_batch) + + def test_multiple_reader(self, thread_num=3): + file_list = [ + './mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio' + ] + with fluid.program_guard(fluid.Program(), fluid.Program()): + data_files = fluid.layers.open_files( + filenames=file_list, + thread_num=thread_num, + shapes=[(-1, 784), (-1, 1)], + lod_levels=[0, 0], + dtypes=['float32', 'int64']) + img, label = fluid.layers.read_file(data_files) + + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + else: + place = fluid.CPUPlace() + + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + batch_count = 0 + while not data_files.eof(): + img_val, = exe.run(fetch_list=[img]) + batch_count += 1 + print(batch_count) + # data_files.reset() + print("FUCK") + + self.assertEqual(batch_count, self.num_batch * 3)