create_shuffle_reader_op.cc 3.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

Y
Yu Yang 已提交
15 16 17
#include <algorithm>
#include <memory>
#include <random>
#include <utility>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

class ShuffleReader : public framework::DecoratedReader {
 public:
F
fengjiayi 已提交
26 27
  ShuffleReader(const std::shared_ptr<ReaderBase>& reader, size_t buffer_size,
                size_t seed = 0)
Y
Yu Yang 已提交
28 29 30 31 32 33
      : DecoratedReader(reader), buffer_size_(buffer_size), seed_(seed) {
    VLOG(10) << "Create shuffle reader of " << reader_;
    if (seed_ == 0) {
      std::random_device device;
      seed_ = device();
    }
F
fengjiayi 已提交
34
    ReloadBuffer();
35 36
  }

Y
Yu Yang 已提交
37
  void ReadNext(std::vector<framework::LoDTensor>* out) override {
F
fengjiayi 已提交
38
    out->clear();
Y
Yu Yang 已提交
39 40
    if (iteration_pos_ >= buffer_.size()) {
      VLOG(10) << "Resetting shuffle buffer";
F
fengjiayi 已提交
41 42 43 44
      ReloadBuffer();
      if (buffer_.empty()) {
        return;
      }
Y
Yu Yang 已提交
45 46 47
    }
    *out = buffer_[iteration_pos_++];
  }
48

Y
Yu Yang 已提交
49
 private:
F
fengjiayi 已提交
50
  void ReloadBuffer() {
51 52
    buffer_.clear();
    buffer_.reserve(buffer_size_);
Y
Yu Yang 已提交
53 54
    iteration_pos_ = 0;
    for (size_t i = 0; i < buffer_size_; ++i) {
F
fengjiayi 已提交
55 56 57
      std::vector<framework::LoDTensor> ins;
      reader_->ReadNext(&ins);
      if (ins.empty()) {
58 59
        break;
      }
F
fengjiayi 已提交
60
      buffer_.emplace_back(ins);
61
    }
Y
Yu Yang 已提交
62 63 64 65
    std::mt19937 g(seed_);
    std::shuffle(buffer_.begin(), buffer_.end(), g);
    seed_ = g();  // update seed_;
    VLOG(10) << "random buffer size = " << buffer_.size();
66
  }
Y
Yu Yang 已提交
67 68 69 70 71 72 73

  size_t buffer_size_;
  std::vector<std::vector<framework::LoDTensor>> buffer_;

  size_t iteration_pos_;
  size_t seed_;
};
74 75 76 77 78 79 80 81

// Operator that wraps the reader held by the "UnderlyingReader" input variable
// in a ShuffleReader and stores the result in the "Out" output variable.
class CreateShuffleReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  // Creates the ShuffleReader on first execution. If the output holder already
  // contains a reader, the call is a no-op.
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    auto* out = detail::Ref(scope.FindVar(Output("Out")))
                    .GetMutable<framework::ReaderHolder>();
    if (out->Get() != nullptr) {
      return;
    }
    // Look the input variable up through detail::Ref, the same way as the
    // output above, rather than dereferencing the FindVar() result directly.
    // NOTE(review): detail::Ref presumably rejects a null pointer — confirm
    // in safe_ref.h.
    const auto& underlying_reader =
        detail::Ref(scope.FindVar(Input("UnderlyingReader")))
            .Get<framework::ReaderHolder>();
    out->Reset(
        new ShuffleReader(underlying_reader.Get(),
                          static_cast<size_t>(Attr<int>("buffer_size"))));
  }
};

class CreateShuffleReaderOpMaker : public DecoratedReaderMakerBase {
Y
Yu Yang 已提交
96 97
 protected:
  void Apply() override {
98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
    AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0);
    AddComment(R"DOC(
      CreateShuffleReader Operator

      A shuffle reader takes another reader as its 'underlying reader'
      and yields the underlying reader's outputs in a shuffled order.
    )DOC");
  }
};
}  // namespace reader
}  // namespace operators
}  // namespace paddle

// Short alias for the namespace that contains the operator classes above.
namespace ops = paddle::operators::reader;
// Passes the operator class and its maker to the decorated-reader registration
// macro under the identifier create_shuffle_reader.
// NOTE(review): the macro is defined outside this file (presumably in
// reader_op_registry.h) — confirm the exact registration semantics there.
REGISTER_DECORATED_READER_OPERATOR(create_shuffle_reader,
                                   ops::CreateShuffleReaderOp,
                                   ops::CreateShuffleReaderOpMaker);