create_shuffle_reader_op.cc 3.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

Y
Yu Yang 已提交
15 16 17
#include <algorithm>
#include <memory>
#include <random>
#include <utility>
#include <vector>

#include "glog/logging.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
18 19 20 21 22 23 24 25
#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

class ShuffleReader : public framework::DecoratedReader {
 public:
F
fengjiayi 已提交
26 27
  ShuffleReader(const std::shared_ptr<ReaderBase>& reader, size_t buffer_size,
                size_t seed = 0)
Y
Yu Yang 已提交
28
      : DecoratedReader(reader), buffer_size_(buffer_size), seed_(seed) {
M
minqiyang 已提交
29
    VLOG(10) << "Create shuffle reader of " << reader_;
Y
Yu Yang 已提交
30 31 32 33
    if (seed_ == 0) {
      std::random_device device;
      seed_ = device();
    }
F
fengjiayi 已提交
34
    ReloadBuffer();
35 36
  }

37
  void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
F
fengjiayi 已提交
38
    out->clear();
Y
Yu Yang 已提交
39
    if (iteration_pos_ >= buffer_.size()) {
M
minqiyang 已提交
40
      VLOG(10) << "Resetting shuffle buffer";
F
fengjiayi 已提交
41 42 43 44
      ReloadBuffer();
      if (buffer_.empty()) {
        return;
      }
Y
Yu Yang 已提交
45 46 47
    }
    *out = buffer_[iteration_pos_++];
  }
48

Y
Yu Yang 已提交
49
 private:
50
  void ShutdownImpl() override {
Y
yuyang18 已提交
51
    reader_->Shutdown();
52 53 54 55 56 57 58 59 60
    buffer_.clear();
    iteration_pos_ = 0;
  }

  void StartImpl() override {
    reader_->Start();
    ReloadBuffer();
  }

F
fengjiayi 已提交
61
  void ReloadBuffer() {
62 63
    buffer_.clear();
    buffer_.reserve(buffer_size_);
Y
Yu Yang 已提交
64 65
    iteration_pos_ = 0;
    for (size_t i = 0; i < buffer_size_; ++i) {
F
fengjiayi 已提交
66 67 68
      std::vector<framework::LoDTensor> ins;
      reader_->ReadNext(&ins);
      if (ins.empty()) {
69 70
        break;
      }
F
fengjiayi 已提交
71
      buffer_.emplace_back(ins);
72
    }
Y
Yu Yang 已提交
73 74 75
    std::mt19937 g(seed_);
    std::shuffle(buffer_.begin(), buffer_.end(), g);
    seed_ = g();  // update seed_;
M
minqiyang 已提交
76
    VLOG(10) << "random buffer size = " << buffer_.size();
77
  }
Y
Yu Yang 已提交
78 79 80 81 82 83 84

  size_t buffer_size_;
  std::vector<std::vector<framework::LoDTensor>> buffer_;

  size_t iteration_pos_;
  size_t seed_;
};
85 86 87 88 89 90 91 92

// Operator that wraps the reader held in the "UnderlyingReader" input
// variable with a ShuffleReader and stores the result in the "Out" variable.
// The shuffle buffer size comes from the "buffer_size" attribute.
class CreateShuffleReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    auto* out = detail::Ref(scope.FindVar(Output("Out")))
                    .GetMutable<framework::ReaderHolder>();
    // The reader is created only once; later runs keep the existing one.
    if (out->Get() != nullptr) {
      return;
    }
    // Look the input up through detail::Ref as well, like "Out" above.
    // detail::Ref null-checks the pointer (see safe_ref.h), so a missing
    // variable is reported with this message instead of being dereferenced
    // as a null pointer.
    const auto& underlying_reader =
        detail::Ref(scope.FindVar(Input("UnderlyingReader")),
                    "Cannot find the underlying reader variable %s",
                    Input("UnderlyingReader"))
            .Get<framework::ReaderHolder>();
    out->Reset(framework::MakeDecoratedReader<ShuffleReader>(
        underlying_reader, static_cast<size_t>(Attr<int>("buffer_size"))));
  }
};

class CreateShuffleReaderOpMaker : public DecoratedReaderMakerBase {
Y
Yu Yang 已提交
106 107
 protected:
  void Apply() override {
108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
    AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0);
    AddComment(R"DOC(
      CreateShuffleReader Operator

      A shuffle reader takes another reader as its 'underlying reader'
      and yields the underlying reader's outputs in a shuffled order.
    )DOC");
  }
};
}  // namespace reader
}  // namespace operators
}  // namespace paddle

// Register create_shuffle_reader with its operator and maker classes.
namespace ops = paddle::operators::reader;
REGISTER_DECORATED_READER_OPERATOR(
    create_shuffle_reader, ops::CreateShuffleReaderOp,
    ops::CreateShuffleReaderOpMaker);