create_double_buffer_reader_op.cc 6.8 KB
Newer Older
F
fengjiayi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

F
fengjiayi 已提交
15
#include <thread>  // NOLINT
F
fengjiayi 已提交
16

17
#include "paddle/fluid/framework/channel.h"
F
fengjiayi 已提交
18 19 20 21 22 23
#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

J
JiayiFeng 已提交
24
// 'Double buffer' means we shall maintain two batches of input data at the
// same time. So the kCacheSize should be at least 2.
static constexpr size_t kCacheSize = 2;
// There will be two batches out of the channel during training:
// 1. the one waiting to be sent to the channel
// 2. the one just received from the channel, which is also being used by
//    subsequent operators.
// So the channel size should be kCacheSize - 2.
static constexpr size_t kChannelSize = 0;  // kCacheSize - 2
F
fengjiayi 已提交
33 34 35

class DoubleBufferReader : public framework::DecoratedReader {
 public:
Y
Yu Yang 已提交
36 37
  struct Item {
    Item() : ctx_(nullptr) {}
J
JiayiFeng 已提交
38 39 40 41 42 43 44 45 46
    Item(Item&& b) {
      payloads_ = std::move(b.payloads_);
      ctx_ = std::move(b.ctx_);
    }
    Item& operator=(Item&& b) {
      payloads_ = std::move(b.payloads_);
      ctx_ = std::move(b.ctx_);
      return *this;
    }
Y
Yu Yang 已提交
47 48 49 50 51

    std::vector<framework::LoDTensor> payloads_;
    platform::DeviceContext* ctx_;
  };

Y
Yu Yang 已提交
52 53 54
  explicit DoubleBufferReader(
      ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
      : DecoratedReader(reader), place_(target_place) {
Y
Yu Yang 已提交
55
#ifdef PADDLE_WITH_CUDA
J
JiayiFeng 已提交
56
    for (size_t i = 0; i < kCacheSize; ++i) {
F
fengjiayi 已提交
57
      if (platform::is_gpu_place(place_)) {
Y
Yu Yang 已提交
58 59 60 61
        ctxs_.emplace_back(new platform::CUDADeviceContext(
            boost::get<platform::CUDAPlace>(place_)));
      }
    }
F
fengjiayi 已提交
62 63
#endif
    StartPrefetcher();
F
fengjiayi 已提交
64 65
  }

F
fengjiayi 已提交
66
  bool HasNext() const override;
F
fengjiayi 已提交
67
  void ReadNext(std::vector<framework::LoDTensor>* out) override;
68 69
  void ReInit() override;

70 71 72
  ~DoubleBufferReader() { EndPrefetcher(); }

 private:
F
fengjiayi 已提交
73
  void StartPrefetcher() {
J
JiayiFeng 已提交
74
    channel_ = framework::MakeChannel<Item>(kChannelSize);
F
fengjiayi 已提交
75 76 77 78
    prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
  }

  void EndPrefetcher() {
J
JiayiFeng 已提交
79 80
    channel_->Close();
    if (prefetcher_.joinable()) {
F
fengjiayi 已提交
81 82
      prefetcher_.join();
    }
J
JiayiFeng 已提交
83 84
    delete channel_;
    channel_ = nullptr;
F
fengjiayi 已提交
85
  }
F
fengjiayi 已提交
86 87 88

  void PrefetchThreadFunc();

F
fengjiayi 已提交
89
  std::thread prefetcher_;
J
JiayiFeng 已提交
90
  framework::Channel<Item>* channel_;
Y
Yu Yang 已提交
91
  platform::Place place_;
Y
Yu Yang 已提交
92
  std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
F
fengjiayi 已提交
93 94 95 96 97 98 99 100 101 102 103
};

// Operator that wraps its underlying reader in a DoubleBufferReader placed on
// the device named by the "place" attribute.
class CreateDoubleBufferReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();
    // Already initialized by a previous run; creating again would leak the
    // running prefetch thread.
    if (out->Get() != nullptr) {
      return;
    }
    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
                                        ->Get<framework::ReaderHolder>();

    // "place" is either "CPU" or "CUDA:<dev_id>" (validated by the attribute
    // checker in the op maker).
    auto place_str = Attr<std::string>("place");
    platform::Place place;
    if (place_str == "CPU") {
      place = platform::CPUPlace();
    } else {
      std::istringstream sin(place_str);
      sin.seekg(std::string("CUDA:").size(), std::ios::beg);
      // Initialize so a failed extraction cannot leave 'num' indeterminate.
      size_t num = 0;
      sin >> num;
      place = platform::CUDAPlace(static_cast<int>(num));
    }

    out->Reset(new DoubleBufferReader(underlying_reader.Get(), place));
  }
};

// Op maker: documents the operator and declares the "place" attribute, which
// selects where the double-buffered batches live.
class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
 public:
  CreateDoubleBufferReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : DecoratedReaderMakerBase(op_proto, op_checker) {
    AddComment(R"DOC(
      CreateDoubleBufferReader Operator

      A double buffer reader takes another reader as its 'underlying reader'.
      It launches another thread to execute the 'underlying reader' asynchronously, 
      which prevents reading process from blocking subsequent training.
    )DOC");
    // Legal values for "place": "CPU" or "CUDA:0" .. "CUDA:127".
    constexpr size_t kMaxCUDADevs = 128;
    std::unordered_set<std::string> enum_range;
    for (size_t dev_id = 0; dev_id < kMaxCUDADevs; ++dev_id) {
      enum_range.insert(string::Sprintf("CUDA:%d", dev_id));
    }
    enum_range.insert("CPU");
    AddAttr<std::string>("place", "The double buffer place, default is CPU")
        .SetDefault("CPU")
        .InEnum({enum_range});
  }
};

F
fengjiayi 已提交
149
bool DoubleBufferReader::HasNext() const {
J
JiayiFeng 已提交
150
  while (!channel_->IsClosed() && !channel_->CanReceive()) {
F
fengjiayi 已提交
151
  }
J
JiayiFeng 已提交
152
  return channel_->CanReceive();
F
fengjiayi 已提交
153 154
}

F
fengjiayi 已提交
155
// Pops the next prefetched batch from the channel into *out. Throws when the
// reader is exhausted. For GPU batches, waits on the batch's device context so
// the asynchronous host-to-device copies have finished before returning.
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
  if (!HasNext()) {
    PADDLE_THROW("There is no next data!");
  }

  Item batch;
  channel_->Receive(&batch);
  // Move instead of copy: 'batch' is discarded right after this call, so
  // there is no need to duplicate the payload vector.
  *out = std::move(batch.payloads_);
  if (batch.ctx_) {
    batch.ctx_->Wait();
  }
}

168 169
// Restarts reading from the beginning: re-initializes the underlying reader,
// then tears down and recreates the channel and prefetch thread. The order
// matters — the old prefetcher must be stopped before a new one starts.
void DoubleBufferReader::ReInit() {
  reader_->ReInit();
  EndPrefetcher();
  StartPrefetcher();
}

void DoubleBufferReader::PrefetchThreadFunc() {
175
  VLOG(5) << "A new prefetch thread starts.";
F
fengjiayi 已提交
176 177 178
  std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(kCacheSize);
  std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(kCacheSize);
  size_t cached_tensor_id = 0;
179

Y
Yu Yang 已提交
180
  while (reader_->HasNext()) {
Y
Yu Yang 已提交
181
    Item batch;
F
fengjiayi 已提交
182 183
    auto& cpu_batch = cpu_tensor_cache[cached_tensor_id];
    reader_->ReadNext(&cpu_batch);
Y
Yu Yang 已提交
184
    if (platform::is_gpu_place(place_)) {
F
fengjiayi 已提交
185 186 187
      auto& gpu_batch = gpu_tensor_cache[cached_tensor_id];
      auto* gpu_ctx = ctxs_[cached_tensor_id].get();
      gpu_batch.resize(cpu_batch.size());
188 189
      for (size_t i = 0; i < cpu_batch.size(); ++i) {
        framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i]);
J
JiayiFeng 已提交
190
        gpu_batch[i].set_lod(cpu_batch[i].lod());
Y
Yu Yang 已提交
191
      }
J
JiayiFeng 已提交
192
      batch.payloads_ = gpu_batch;
F
fengjiayi 已提交
193 194 195
      batch.ctx_ = gpu_ctx;
    } else {
      // CPUPlace
J
JiayiFeng 已提交
196
      batch.payloads_ = cpu_batch;
F
fengjiayi 已提交
197
    }
F
fengjiayi 已提交
198 199
    ++cached_tensor_id;
    cached_tensor_id %= kCacheSize;
Y
Yu Yang 已提交
200

201
    try {
J
JiayiFeng 已提交
202
      channel_->Send(&batch);
203
    } catch (paddle::platform::EnforceNotMet e) {
204
      VLOG(5) << "WARNING: The double buffer channel has been closed. The "
F
fengjiayi 已提交
205
                 "prefetch thread will terminate.";
206
      break;
F
fengjiayi 已提交
207 208
    }
  }
J
JiayiFeng 已提交
209
  channel_->Close();
F
fengjiayi 已提交
210
  VLOG(5) << "Prefetch thread terminates.";
F
fengjiayi 已提交
211 212 213 214 215 216 217 218 219 220
}

}  // namespace reader
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators::reader;
// Makes the operator available under the type name "create_double_buffer_reader".
REGISTER_DECORATED_READER_OPERATOR(create_double_buffer_reader,
                                   ops::CreateDoubleBufferReaderOp,
                                   ops::CreateDoubleBufferReaderOpMaker);