create_shuffle_reader_op.cc 3.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

Y
Yu Yang 已提交
15 16 17
#include <algorithm>
#include <random>
#include <utility>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
18 19 20 21 22 23 24 25
#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

class ShuffleReader : public framework::DecoratedReader {
 public:
Y
Yu Yang 已提交
26 27 28 29 30 31 32
  ShuffleReader(ReaderBase* reader, size_t buffer_size, size_t seed = 0)
      : DecoratedReader(reader), buffer_size_(buffer_size), seed_(seed) {
    VLOG(10) << "Create shuffle reader of " << reader_;
    if (seed_ == 0) {
      std::random_device device;
      seed_ = device();
    }
F
fengjiayi 已提交
33
    ReloadBuffer();
34 35
  }

Y
Yu Yang 已提交
36
  void ReadNext(std::vector<framework::LoDTensor>* out) override {
F
fengjiayi 已提交
37
    out->clear();
Y
Yu Yang 已提交
38 39
    if (iteration_pos_ >= buffer_.size()) {
      VLOG(10) << "Resetting shuffle buffer";
F
fengjiayi 已提交
40 41 42 43
      ReloadBuffer();
      if (buffer_.empty()) {
        return;
      }
Y
Yu Yang 已提交
44 45 46
    }
    *out = buffer_[iteration_pos_++];
  }
47

Y
Yu Yang 已提交
48
 private:
F
fengjiayi 已提交
49
  void ReloadBuffer() {
50 51
    buffer_.clear();
    buffer_.reserve(buffer_size_);
Y
Yu Yang 已提交
52 53
    iteration_pos_ = 0;
    for (size_t i = 0; i < buffer_size_; ++i) {
F
fengjiayi 已提交
54 55 56
      std::vector<framework::LoDTensor> ins;
      reader_->ReadNext(&ins);
      if (ins.empty()) {
57 58
        break;
      }
F
fengjiayi 已提交
59
      buffer_.emplace_back(ins);
60
    }
Y
Yu Yang 已提交
61 62 63 64
    std::mt19937 g(seed_);
    std::shuffle(buffer_.begin(), buffer_.end(), g);
    seed_ = g();  // update seed_;
    VLOG(10) << "random buffer size = " << buffer_.size();
65
  }
Y
Yu Yang 已提交
66 67 68 69 70 71 72

  size_t buffer_size_;
  std::vector<std::vector<framework::LoDTensor>> buffer_;

  size_t iteration_pos_;
  size_t seed_;
};
73 74 75 76 77 78 79 80

// Operator that creates a ShuffleReader around the reader held by the
// "UnderlyingReader" input and stores it in the "Out" output variable.
class CreateShuffleReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  // Builds the ShuffleReader once: if the output holder already contains a
  // reader, the call returns without doing anything.
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    auto* out = detail::Ref(scope.FindVar(Output("Out")))
                    .GetMutable<framework::ReaderHolder>();
    if (out->Get() != nullptr) {
      return;
    }
    // Look the input variable up through detail::Ref as well, so a missing
    // variable is reported the same way as for "Out" instead of being
    // dereferenced as a null pointer.
    const auto& underlying_reader =
        detail::Ref(scope.FindVar(Input("UnderlyingReader")))
            .Get<framework::ReaderHolder>();
    // The seed argument is not passed, so ShuffleReader uses its default.
    out->Reset(
        new ShuffleReader(underlying_reader.Get(),
                          static_cast<size_t>(Attr<int>("buffer_size"))));
  }
};

// Maker paired with CreateShuffleReaderOp: declares the operator's
// "buffer_size" attribute and its documentation comment on top of what
// DecoratedReaderMakerBase already sets up.
class CreateShuffleReaderOpMaker : public DecoratedReaderMakerBase {
 public:
  CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : DecoratedReaderMakerBase(op_proto, op_checker) {
    // Declare the attribute, then attach its GreaterThan(0) constraint.
    auto& buffer_size_attr =
        AddAttr<int>("buffer_size", "The shuffle buffer size.");
    buffer_size_attr.GreaterThan(0);

    AddComment(R"DOC(
      CreateShuffleReader Operator

      A shuffle reader takes another reader as its 'underlying reader'
      and yields the underlying reader's outputs in a shuffled order.
    )DOC");
  }
};
}  // namespace reader
}  // namespace operators
}  // namespace paddle

// Short alias so the registration below can name the classes concisely.
namespace ops = paddle::operators::reader;
// Ties the operator name `create_shuffle_reader` to its implementation class
// and its maker class. The macro is defined outside this file (presumably in
// reader_op_registry.h -- confirm there for what else it registers).
REGISTER_DECORATED_READER_OPERATOR(create_shuffle_reader,
                                   ops::CreateShuffleReaderOp,
                                   ops::CreateShuffleReaderOpMaker);