提交 46ae4075 编写于 作者: Y Yu Yang

Polish ShuffleReader and test

上级 7eedced8
...@@ -12,6 +12,9 @@ ...@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <random>
#include "glog/logging.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h" #include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle { namespace paddle {
...@@ -20,43 +23,53 @@ namespace reader { ...@@ -20,43 +23,53 @@ namespace reader {
class ShuffleReader : public framework::DecoratedReader { class ShuffleReader : public framework::DecoratedReader {
public: public:
ShuffleReader(ReaderBase* reader, int buffer_size) ShuffleReader(ReaderBase* reader, size_t buffer_size, size_t seed = 0)
: DecoratedReader(reader), buffer_size_(buffer_size), iteration_pos_(0) { : DecoratedReader(reader), buffer_size_(buffer_size), seed_(seed) {
buffer_.reserve(buffer_size); VLOG(10) << "Create shuffle reader of " << reader_;
if (seed_ == 0) {
std::random_device device;
seed_ = device();
}
ReadIntoBuffers();
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (iteration_pos_ >= buffer_.size()) {
VLOG(10) << "Resetting shuffle buffer";
ReadIntoBuffers();
}
*out = buffer_[iteration_pos_++];
}
private: bool HasNext() const override {
int buffer_size_; return iteration_pos_ < buffer_.size() || reader_->HasNext();
std::vector<std::vector<framework::LoDTensor>> buffer_; }
size_t iteration_pos_;
};
void ShuffleReader::ReadNext(std::vector<framework::LoDTensor>* out) { private:
if (iteration_pos_ >= buffer_.size()) { void ReadIntoBuffers() {
// Reload buffer with new data
buffer_.clear(); buffer_.clear();
buffer_.reserve(buffer_size_); buffer_.reserve(buffer_size_);
for (int i = 0; i < buffer_size_; ++i) { iteration_pos_ = 0;
buffer_.push_back(std::vector<framework::LoDTensor>()); PADDLE_ENFORCE(reader_->HasNext());
reader_->ReadNext(&buffer_.back()); for (size_t i = 0; i < buffer_size_; ++i) {
if (buffer_.back().empty()) { if (!reader_->HasNext()) {
buffer_.pop_back();
break; break;
} }
buffer_.emplace_back();
reader_->ReadNext(&buffer_.back());
} }
// TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be std::mt19937 g(seed_);
// optimize. std::shuffle(buffer_.begin(), buffer_.end(), g);
std::random_shuffle(buffer_.begin(), buffer_.end()); seed_ = g(); // update seed_;
iteration_pos_ = 0; VLOG(10) << "random buffer size = " << buffer_.size();
} }
out->clear();
if (!buffer_.empty()) { size_t buffer_size_;
std::swap(*out, buffer_[iteration_pos_++]); std::vector<std::vector<framework::LoDTensor>> buffer_;
}
// if buffer_ is empty, the 'out' will return as an empty vector. size_t iteration_pos_;
} size_t seed_;
};
class CreateShuffleReaderOp : public framework::OperatorBase { class CreateShuffleReaderOp : public framework::OperatorBase {
public: public:
...@@ -67,10 +80,10 @@ class CreateShuffleReaderOp : public framework::OperatorBase { ...@@ -67,10 +80,10 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out")) auto& var = detail::Ref(scope.FindVar(Output("Out")));
->template GetMutable<framework::ReaderHolder>(); var.GetMutable<framework::ReaderHolder>()->Reset(
out->Reset( new ShuffleReader(underlying_reader.Get(),
new ShuffleReader(underlying_reader.Get(), Attr<int>("buffer_size"))); static_cast<size_t>(Attr<int>("buffer_size"))));
} }
}; };
......
...@@ -21,7 +21,7 @@ from ..executor import global_scope ...@@ -21,7 +21,7 @@ from ..executor import global_scope
__all__ = [ __all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'read_file' 'read_file', 'create_shuffle_reader'
] ]
...@@ -245,6 +245,8 @@ def monkey_patch_reader_methods(reader): ...@@ -245,6 +245,8 @@ def monkey_patch_reader_methods(reader):
reader.eof = eof reader.eof = eof
reader.reset = reset reader.reset = reset
reader.stop_gradient = True
reader.persistable = True
return reader return reader
...@@ -285,6 +287,25 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes): ...@@ -285,6 +287,25 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var) startup_var)
def __create_decorated_reader__(op_type, reader, attrs):
var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
startup_blk.append_op(
type=op_type,
inputs={'UnderlyingReader': reader},
outputs={'Out': [startup_var]},
attrs=attrs)
startup_var.persistable = True
return _copy_reader_var_(default_main_program().current_block(),
startup_var)
def create_shuffle_reader(reader, buffer_size):
return __create_decorated_reader__('create_shuffle_reader', reader,
{'buffer_size': int(buffer_size)})
def read_file(file_obj): def read_file(file_obj):
helper = LayerHelper('read_file') helper = LayerHelper('read_file')
out = [ out = [
......
...@@ -36,6 +36,7 @@ def convert_reader_to_recordio_file( ...@@ -36,6 +36,7 @@ def convert_reader_to_recordio_file(
feed_order=None): feed_order=None):
if feed_order is None: if feed_order is None:
feed_order = feeder.feed_names feed_order = feeder.feed_names
counter = 0
with create_recordio_writer(filename, compressor, with create_recordio_writer(filename, compressor,
max_num_records) as writer: max_num_records) as writer:
for batch in reader_creator(): for batch in reader_creator():
...@@ -43,3 +44,5 @@ def convert_reader_to_recordio_file( ...@@ -43,3 +44,5 @@ def convert_reader_to_recordio_file(
for each in feed_order: for each in feed_order:
writer.append_tensor(res[each]) writer.append_tensor(res[each])
writer.complete_append_tensor() writer.complete_append_tensor()
counter += 1
return counter
...@@ -31,10 +31,10 @@ class TestRecordIO(unittest.TestCase): ...@@ -31,10 +31,10 @@ class TestRecordIO(unittest.TestCase):
name='label', shape=[1], dtype='int64'), name='label', shape=[1], dtype='int64'),
], ],
place=fluid.CPUPlace()) place=fluid.CPUPlace())
fluid.recordio_writer.convert_reader_to_recordio_file( self.num_batches = fluid.recordio_writer.convert_reader_to_recordio_file(
'./mnist.recordio', reader, feeder) './mnist.recordio', reader, feeder)
def test_main(self): def test_main(self, decorator_callback=None):
# use new program # use new program
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
data_file = fluid.layers.open_recordio_file( data_file = fluid.layers.open_recordio_file(
...@@ -42,6 +42,8 @@ class TestRecordIO(unittest.TestCase): ...@@ -42,6 +42,8 @@ class TestRecordIO(unittest.TestCase):
shapes=[[-1, 784], [-1, 1]], shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0], lod_levels=[0, 0],
dtypes=['float32', 'int64']) dtypes=['float32', 'int64'])
if decorator_callback is not None:
data_file = decorator_callback(data_file)
img, label = fluid.layers.read_file(data_file) img, label = fluid.layers.read_file(data_file)
hidden = fluid.layers.fc(input=img, size=100, act='tanh') hidden = fluid.layers.fc(input=img, size=100, act='tanh')
...@@ -56,9 +58,14 @@ class TestRecordIO(unittest.TestCase): ...@@ -56,9 +58,14 @@ class TestRecordIO(unittest.TestCase):
avg_loss_np = [] avg_loss_np = []
# train a pass # train a pass
batch_id = 0
while not data_file.eof(): while not data_file.eof():
tmp, = exe.run(fetch_list=[avg_loss]) tmp, = exe.run(fetch_list=[avg_loss])
avg_loss_np.append(tmp) avg_loss_np.append(tmp)
batch_id += 1
data_file.reset() data_file.reset()
self.assertEqual(batch_id, self.num_batches)
self.assertLess(avg_loss_np[-1], avg_loss_np[0]) self.assertLess(avg_loss_np[-1], avg_loss_np[0])
def test_shuffle_reader(self):
self.test_main(decorator_callback=lambda reader: fluid.layers.create_shuffle_reader(reader, buffer_size=200))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册