create_double_buffer_reader_op.cc 6.3 KB
Newer Older
F
fengjiayi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

F
fengjiayi 已提交
15
#include <thread>
16
#include "paddle/fluid/framework/channel.h"
F
fengjiayi 已提交
17 18 19 20 21 22
#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

J
JiayiFeng 已提交
23 24
// Number of batch buffers the prefetch thread rotates through (see
// DoubleBufferReader::PrefetchThreadFunc).
static constexpr size_t kCacheSize = 2;
// Capacity of the channel between the prefetch thread and the consumer.
// Two cache slots are always in use (one handed to the consumer, one being
// filled by the prefetch thread), so only the remaining slots may be queued
// in the channel. Deriving the value keeps this invariant if kCacheSize grows.
static_assert(kCacheSize >= 2, "kCacheSize must be at least 2");
static constexpr size_t kChannelSize = kCacheSize - 2;
F
fengjiayi 已提交
25 26 27

class DoubleBufferReader : public framework::DecoratedReader {
 public:
Y
Yu Yang 已提交
28 29
  struct Item {
    Item() : ctx_(nullptr) {}
J
JiayiFeng 已提交
30 31 32 33 34 35 36 37 38
    Item(Item&& b) {
      payloads_ = std::move(b.payloads_);
      ctx_ = std::move(b.ctx_);
    }
    Item& operator=(Item&& b) {
      payloads_ = std::move(b.payloads_);
      ctx_ = std::move(b.ctx_);
      return *this;
    }
Y
Yu Yang 已提交
39 40 41 42 43

    std::vector<framework::LoDTensor> payloads_;
    platform::DeviceContext* ctx_;
  };

Y
Yu Yang 已提交
44 45 46
  explicit DoubleBufferReader(
      ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
      : DecoratedReader(reader), place_(target_place) {
Y
Yu Yang 已提交
47
#ifdef PADDLE_WITH_CUDA
J
JiayiFeng 已提交
48
    for (size_t i = 0; i < kCacheSize; ++i) {
F
fengjiayi 已提交
49
      if (platform::is_gpu_place(place_)) {
Y
Yu Yang 已提交
50 51 52 53
        ctxs_.emplace_back(new platform::CUDADeviceContext(
            boost::get<platform::CUDAPlace>(place_)));
      }
    }
F
fengjiayi 已提交
54 55
#endif
    StartPrefetcher();
F
fengjiayi 已提交
56 57
  }

F
fengjiayi 已提交
58
  bool HasNext() const override;
F
fengjiayi 已提交
59
  void ReadNext(std::vector<framework::LoDTensor>* out) override;
60 61
  void ReInit() override;

F
fengjiayi 已提交
62
  void StartPrefetcher() {
J
JiayiFeng 已提交
63
    channel_ = framework::MakeChannel<Item>(kChannelSize);
F
fengjiayi 已提交
64 65 66 67
    prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
  }

  void EndPrefetcher() {
J
JiayiFeng 已提交
68 69
    channel_->Close();
    if (prefetcher_.joinable()) {
F
fengjiayi 已提交
70 71
      prefetcher_.join();
    }
J
JiayiFeng 已提交
72 73
    delete channel_;
    channel_ = nullptr;
F
fengjiayi 已提交
74
  }
F
fengjiayi 已提交
75

F
fengjiayi 已提交
76
  ~DoubleBufferReader() { EndPrefetcher(); }
Y
Yu Yang 已提交
77

F
fengjiayi 已提交
78 79 80
 private:
  void PrefetchThreadFunc();

F
fengjiayi 已提交
81
  std::thread prefetcher_;
J
JiayiFeng 已提交
82
  framework::Channel<Item>* channel_;
Y
Yu Yang 已提交
83
  platform::Place place_;
Y
Yu Yang 已提交
84
  std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
F
fengjiayi 已提交
85 86 87 88 89 90 91 92 93 94 95 96 97
};

// Operator that wraps the reader held in "UnderlyingReader" with a
// DoubleBufferReader and stores the result in "Out".
class CreateDoubleBufferReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
                                        ->Get<framework::ReaderHolder>();
    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();

    // The "place" attribute is either "CPU" or "CUDA:<device id>".
    auto place_str = Attr<std::string>("place");
    platform::Place place;
    if (place_str == "CPU") {
      place = platform::CPUPlace();
    } else {
      const std::string cuda_prefix = "CUDA:";
      PADDLE_ENFORCE(
          place_str.compare(0, cuda_prefix.size(), cuda_prefix) == 0,
          "Invalid place '%s': expected 'CPU' or 'CUDA:<id>'.", place_str);
      std::istringstream sin(place_str.substr(cuda_prefix.size()));
      // Initialized so a failed parse can never leave it indeterminate.
      int dev_id = -1;
      sin >> dev_id;
      PADDLE_ENFORCE(!sin.fail() && dev_id >= 0,
                     "Invalid CUDA device id in place '%s'.", place_str);
      place = platform::CUDAPlace(dev_id);
    }

    out->Reset(new DoubleBufferReader(underlying_reader.Get(), place));
  }
};

// Proto maker for create_double_buffer_reader: documents the op and declares
// the "place" attribute, restricted to "CPU" or "CUDA:0" .. "CUDA:127".
class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
 public:
  CreateDoubleBufferReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : DecoratedReaderMakerBase(op_proto, op_checker) {
    AddComment(R"DOC(
      CreateDoubleBufferReader Operator

      A double buffer reader takes another reader as its 'underlying reader'.
      It launches another thread to execute the 'underlying reader' asynchronously, 
      which prevents reading process from blocking subsequent training.
    )DOC");
    // Highest number of CUDA devices a "place" value may refer to.
    constexpr size_t kMaxCUDADevs = 128;
    std::unordered_set<std::string> valid_places{"CPU"};
    for (size_t dev_id = 0; dev_id != kMaxCUDADevs; ++dev_id) {
      valid_places.insert(string::Sprintf("CUDA:%d", dev_id));
    }
    AddAttr<std::string>("place", "The double buffer place, default is CPU")
        .SetDefault("CPU")
        .InEnum({valid_places});
  }
};

F
fengjiayi 已提交
138
bool DoubleBufferReader::HasNext() const {
J
JiayiFeng 已提交
139
  while (!channel_->IsClosed() && !channel_->CanReceive()) {
F
fengjiayi 已提交
140
  }
J
JiayiFeng 已提交
141
  return channel_->CanReceive();
F
fengjiayi 已提交
142 143
}

F
fengjiayi 已提交
144
// Receives the next prefetched batch into *out.
// Throws (PADDLE_THROW) when no further data is available.
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
  if (!HasNext()) {
    PADDLE_THROW("There is no next data!");
  }

  Item batch;
  channel_->Receive(&batch);
  // For GPU places the prefetch thread issued the TensorCopy calls on
  // batch.ctx_. Wait() is expected to block until that queued work finishes,
  // so the caller never observes a partially copied batch.
  if (batch.ctx_) {
    batch.ctx_->Wait();
  }
  // `batch` is a local about to be destroyed: move the tensors out rather
  // than copying the vector.
  *out = std::move(batch.payloads_);
}

157 158
void DoubleBufferReader::ReInit() {
  reader_->ReInit();
F
fengjiayi 已提交
159 160
  EndPrefetcher();
  StartPrefetcher();
F
fengjiayi 已提交
161 162 163
}

void DoubleBufferReader::PrefetchThreadFunc() {
164
  VLOG(5) << "A new prefetch thread starts.";
F
fengjiayi 已提交
165 166 167
  std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(kCacheSize);
  std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(kCacheSize);
  size_t cached_tensor_id = 0;
168

Y
Yu Yang 已提交
169
  while (reader_->HasNext()) {
Y
Yu Yang 已提交
170
    Item batch;
F
fengjiayi 已提交
171 172
    auto& cpu_batch = cpu_tensor_cache[cached_tensor_id];
    reader_->ReadNext(&cpu_batch);
Y
Yu Yang 已提交
173
    if (platform::is_gpu_place(place_)) {
F
fengjiayi 已提交
174 175 176
      auto& gpu_batch = gpu_tensor_cache[cached_tensor_id];
      auto* gpu_ctx = ctxs_[cached_tensor_id].get();
      gpu_batch.resize(cpu_batch.size());
177 178
      for (size_t i = 0; i < cpu_batch.size(); ++i) {
        framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i]);
J
JiayiFeng 已提交
179
        gpu_batch[i].set_lod(cpu_batch[i].lod());
Y
Yu Yang 已提交
180
      }
J
JiayiFeng 已提交
181
      batch.payloads_ = gpu_batch;
F
fengjiayi 已提交
182 183 184
      batch.ctx_ = gpu_ctx;
    } else {
      // CPUPlace
J
JiayiFeng 已提交
185
      batch.payloads_ = cpu_batch;
F
fengjiayi 已提交
186
    }
F
fengjiayi 已提交
187 188
    ++cached_tensor_id;
    cached_tensor_id %= kCacheSize;
Y
Yu Yang 已提交
189

190
    try {
J
JiayiFeng 已提交
191
      channel_->Send(&batch);
192
    } catch (paddle::platform::EnforceNotMet e) {
193
      VLOG(5) << "WARNING: The double buffer channel has been closed. The "
F
fengjiayi 已提交
194
                 "prefetch thread will terminate.";
195
      break;
F
fengjiayi 已提交
196 197
    }
  }
J
JiayiFeng 已提交
198
  channel_->Close();
F
fengjiayi 已提交
199
  VLOG(5) << "Prefetch thread terminates.";
F
fengjiayi 已提交
200 201 202 203 204 205 206 207 208 209
}

}  // namespace reader
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators::reader;
// Register the "create_double_buffer_reader" operator via the project's
// decorated-reader registration macro: CreateDoubleBufferReaderOp is the
// operator class and CreateDoubleBufferReaderOpMaker supplies its
// documentation and attributes.
REGISTER_DECORATED_READER_OPERATOR(create_double_buffer_reader,
                                   ops::CreateDoubleBufferReaderOp,
                                   ops::CreateDoubleBufferReaderOpMaker);