create_shuffle_reader_op.cc 3.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

Y
Yu Yang 已提交
15 16 17
#include <random>
#include "glog/logging.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
18 19 20 21 22 23 24 25
#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

class ShuffleReader : public framework::DecoratedReader {
 public:
Y
Yu Yang 已提交
26 27 28 29 30 31 32 33
  ShuffleReader(ReaderBase* reader, size_t buffer_size, size_t seed = 0)
      : DecoratedReader(reader), buffer_size_(buffer_size), seed_(seed) {
    VLOG(10) << "Create shuffle reader of " << reader_;
    if (seed_ == 0) {
      std::random_device device;
      seed_ = device();
    }
    ReadIntoBuffers();
34 35
  }

Y
Yu Yang 已提交
36
  void ReadNext(std::vector<framework::LoDTensor>* out) override {
F
fengjiayi 已提交
37 38 39
    if (!HasNext()) {
      PADDLE_THROW("There is no next data!");
    }
Y
Yu Yang 已提交
40 41 42 43 44 45
    if (iteration_pos_ >= buffer_.size()) {
      VLOG(10) << "Resetting shuffle buffer";
      ReadIntoBuffers();
    }
    *out = buffer_[iteration_pos_++];
  }
46

Y
Yu Yang 已提交
47 48 49
  bool HasNext() const override {
    return iteration_pos_ < buffer_.size() || reader_->HasNext();
  }
50

Y
Yu Yang 已提交
51 52
 private:
  void ReadIntoBuffers() {
53 54
    buffer_.clear();
    buffer_.reserve(buffer_size_);
Y
Yu Yang 已提交
55 56 57
    iteration_pos_ = 0;
    for (size_t i = 0; i < buffer_size_; ++i) {
      if (!reader_->HasNext()) {
58 59
        break;
      }
Y
Yu Yang 已提交
60 61
      buffer_.emplace_back();
      reader_->ReadNext(&buffer_.back());
62
    }
Y
Yu Yang 已提交
63 64 65 66
    std::mt19937 g(seed_);
    std::shuffle(buffer_.begin(), buffer_.end(), g);
    seed_ = g();  // update seed_;
    VLOG(10) << "random buffer size = " << buffer_.size();
67
  }
Y
Yu Yang 已提交
68 69 70 71 72 73 74

  size_t buffer_size_;
  std::vector<std::vector<framework::LoDTensor>> buffer_;

  size_t iteration_pos_;
  size_t seed_;
};
75 76 77 78 79 80 81 82

class CreateShuffleReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
F
fengjiayi 已提交
83 84 85 86 87
    auto* out = detail::Ref(scope.FindVar(Output("Out")))
                    .GetMutable<framework::ReaderHolder>();
    if (out->Get() != nullptr) {
      return;
    }
88 89
    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
                                        ->Get<framework::ReaderHolder>();
F
fengjiayi 已提交
90
    out->Reset(
Y
Yu Yang 已提交
91 92
        new ShuffleReader(underlying_reader.Get(),
                          static_cast<size_t>(Attr<int>("buffer_size"))));
93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
  }
};

class CreateShuffleReaderOpMaker : public DecoratedReaderMakerBase {
 public:
  CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : DecoratedReaderMakerBase(op_proto, op_checker) {
    AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0);
    AddComment(R"DOC(
      CreateShuffleReader Operator

      A shuffle reader takes another reader as its 'underlying reader'
      and yields the underlying reader's outputs in a shuffled order.
    )DOC");
  }
};
}  // namespace reader
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators::reader;
REGISTER_DECORATED_READER_OPERATOR(create_shuffle_reader,
                                   ops::CreateShuffleReaderOp,
                                   ops::CreateShuffleReaderOpMaker);