diff --git a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc index 4dac3831109beeed660d32f08fb27c7adf62ac2b..70e2f587dc414a850ddc341b98f26ae54636755c 100644 --- a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc +++ b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include "glog/logging.h" +#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/reader/reader_op_registry.h" namespace paddle { @@ -20,43 +23,53 @@ namespace reader { class ShuffleReader : public framework::DecoratedReader { public: - ShuffleReader(ReaderBase* reader, int buffer_size) - : DecoratedReader(reader), buffer_size_(buffer_size), iteration_pos_(0) { - buffer_.reserve(buffer_size); + ShuffleReader(ReaderBase* reader, size_t buffer_size, size_t seed = 0) + : DecoratedReader(reader), buffer_size_(buffer_size), seed_(seed) { + VLOG(10) << "Create shuffle reader of " << reader_; + if (seed_ == 0) { + std::random_device device; + seed_ = device(); + } + ReadIntoBuffers(); } - void ReadNext(std::vector* out) override; + void ReadNext(std::vector* out) override { + if (iteration_pos_ >= buffer_.size()) { + VLOG(10) << "Resetting shuffle buffer"; + ReadIntoBuffers(); + } + *out = buffer_[iteration_pos_++]; + } - private: - int buffer_size_; - std::vector> buffer_; - size_t iteration_pos_; -}; + bool HasNext() const override { + return iteration_pos_ < buffer_.size() || reader_->HasNext(); + } -void ShuffleReader::ReadNext(std::vector* out) { - if (iteration_pos_ >= buffer_.size()) { - // Reload buffer with new data + private: + void ReadIntoBuffers() { buffer_.clear(); buffer_.reserve(buffer_size_); - for (int i = 0; i < buffer_size_; ++i) { - buffer_.push_back(std::vector()); - reader_->ReadNext(&buffer_.back()); - if (buffer_.back().empty()) { - buffer_.pop_back(); + iteration_pos_ = 0; + PADDLE_ENFORCE(reader_->HasNext()); + for (size_t i = 0; i < buffer_size_; ++i) { + if (!reader_->HasNext()) { break; } + buffer_.emplace_back(); + reader_->ReadNext(&buffer_.back()); } - // TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be - // optimize. - std::random_shuffle(buffer_.begin(), buffer_.end()); - iteration_pos_ = 0; + std::mt19937 g(seed_); + std::shuffle(buffer_.begin(), buffer_.end(), g); + seed_ = g(); // update seed_; + VLOG(10) << "random buffer size = " << buffer_.size(); } - out->clear(); - if (!buffer_.empty()) { - std::swap(*out, buffer_[iteration_pos_++]); - } - // if buffer_ is empty, the 'out' will return as an empty vector. -} + + size_t buffer_size_; + std::vector> buffer_; + + size_t iteration_pos_; + size_t seed_; +}; class CreateShuffleReaderOp : public framework::OperatorBase { public: @@ -67,10 +80,10 @@ class CreateShuffleReaderOp : public framework::OperatorBase { const platform::Place& dev_place) const override { const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - auto* out = scope.FindVar(Output("Out")) - ->template GetMutable(); - out->Reset( - new ShuffleReader(underlying_reader.Get(), Attr("buffer_size"))); + auto& var = detail::Ref(scope.FindVar(Output("Out"))); + var.GetMutable()->Reset( + new ShuffleReader(underlying_reader.Get(), + static_cast(Attr("buffer_size")))); } }; diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index f1b2af70205ab40f08c11061a683b567f5bcbb7b..81dd9789495a685012f6848106a72556ed7df339 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -21,7 +21,7 @@ from ..executor import global_scope __all__ = [ 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', - 'read_file' + 'read_file', 'create_shuffle_reader' ] @@ -245,6 +245,8 @@ def monkey_patch_reader_methods(reader): reader.eof = eof reader.reset = reset + reader.stop_gradient = True + reader.persistable = True return reader @@ -285,6 +287,25 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes): startup_var) +def __create_decorated_reader__(op_type, reader, attrs): + var_name = unique_name(op_type) + startup_blk = default_startup_program().current_block() + startup_var = startup_blk.create_var(name=var_name) + startup_blk.append_op( + type=op_type, + inputs={'UnderlyingReader': reader}, + outputs={'Out': [startup_var]}, + attrs=attrs) + startup_var.persistable = True + return _copy_reader_var_(default_main_program().current_block(), + startup_var) + + +def create_shuffle_reader(reader, buffer_size): + return __create_decorated_reader__('create_shuffle_reader', reader, + {'buffer_size': int(buffer_size)}) + + def read_file(file_obj): helper = LayerHelper('read_file') out = [ diff --git a/python/paddle/fluid/recordio_writer.py b/python/paddle/fluid/recordio_writer.py index 9735df8c06113230af9695f76a7589ea9f50e527..5accaacd5361165d30b92c71ae4fd62e23e44e07 100644 --- a/python/paddle/fluid/recordio_writer.py +++ b/python/paddle/fluid/recordio_writer.py @@ -36,6 +36,7 @@ def convert_reader_to_recordio_file( feed_order=None): if feed_order is None: feed_order = feeder.feed_names + counter = 0 with create_recordio_writer(filename, compressor, max_num_records) as writer: for batch in reader_creator(): @@ -43,3 +44,5 @@ def convert_reader_to_recordio_file( for each in feed_order: writer.append_tensor(res[each]) writer.complete_append_tensor() + counter += 1 + return counter diff --git a/python/paddle/fluid/tests/unittests/test_recordio_reader.py b/python/paddle/fluid/tests/unittests/test_recordio_reader.py index d249742bd30ec41749f16beaa7076f7c6e8f063c..cdebda5b7df1c443f9ed7b58e64c6ce63d2adc31 100644 --- a/python/paddle/fluid/tests/unittests/test_recordio_reader.py +++ b/python/paddle/fluid/tests/unittests/test_recordio_reader.py @@ -31,10 +31,10 @@ class TestRecordIO(unittest.TestCase): name='label', shape=[1], dtype='int64'), ], place=fluid.CPUPlace()) - fluid.recordio_writer.convert_reader_to_recordio_file( + self.num_batches = fluid.recordio_writer.convert_reader_to_recordio_file( './mnist.recordio', reader, feeder) - def test_main(self): + def test_main(self, decorator_callback=None): # use new program with fluid.program_guard(fluid.Program(), fluid.Program()): data_file = fluid.layers.open_recordio_file( @@ -42,6 +42,8 @@ class TestRecordIO(unittest.TestCase): shapes=[[-1, 784], [-1, 1]], lod_levels=[0, 0], dtypes=['float32', 'int64']) + if decorator_callback is not None: + data_file = decorator_callback(data_file) img, label = fluid.layers.read_file(data_file) hidden = fluid.layers.fc(input=img, size=100, act='tanh') @@ -56,9 +58,14 @@ class TestRecordIO(unittest.TestCase): avg_loss_np = [] # train a pass + batch_id = 0 while not data_file.eof(): tmp, = exe.run(fetch_list=[avg_loss]) avg_loss_np.append(tmp) + batch_id += 1 data_file.reset() - + self.assertEqual(batch_id, self.num_batches) self.assertLess(avg_loss_np[-1], avg_loss_np[0]) + + def test_shuffle_reader(self): + self.test_main(decorator_callback=lambda reader: fluid.layers.create_shuffle_reader(reader, buffer_size=200))