buffered_reader.cc 6.5 KB
Newer Older
Y
yuyang18 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/reader/buffered_reader.h"
C
chengduo 已提交
16
#include <memory>
Z
Zeng Jinle 已提交
17
#include <utility>
Y
yuyang18 已提交
18
#include <vector>
D
Dun Liang 已提交
19
#include "paddle/fluid/framework/data_type.h"
Y
yuyang18 已提交
20

21
#include "paddle/fluid/platform/profiler.h"
Y
yuyang18 已提交
22 23 24
namespace paddle {
namespace operators {
namespace reader {
F
fengjiayi 已提交
25
BufferedReader::~BufferedReader() {
  VLOG(1) << "~BufferedReader";
  // Shut the upstream reader down first, then block on every queued prefetch
  // future. The prefetch tasks capture `this` and write into the buffers, so
  // none of them may still be pending once the members are torn down.
  reader_->Shutdown();
  for (; !position_.empty(); position_.pop()) {
    position_.front().wait();
  }
#ifdef PADDLE_WITH_CUDA
  // The copy stream and per-slot events only exist for a GPU target place
  // (they are created in the constructor under the same condition).
  if (platform::is_gpu_place(place_)) {
    const auto gpu_place = boost::get<platform::CUDAPlace>(place_);
    platform::SetDeviceId(gpu_place.device);
    PADDLE_ENFORCE(cudaStreamDestroy(stream_));
    for (auto &evt : events_) {
      PADDLE_ENFORCE(cudaEventDestroy(evt));
    }
  }
#endif
}

Y
yuyang18 已提交
43 44 45 46 47 48 49
BufferedReader::BufferedReader(
    const std::shared_ptr<framework::ReaderBase> &reader,
    const platform::Place &place, size_t buffer_size)
    : framework::DecoratedReader(reader),
      // NOTE(review): presumably a single worker, so prefetch tasks never run
      // concurrently with each other — confirm against ThreadPool.
      thread_pool_(1),
      place_(place),
      buffer_size_(buffer_size) {
  VLOG(1) << "BufferedReader";
#ifdef PADDLE_WITH_CUDA
  if (platform::is_gpu_place(place_)) {
    platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device);
    // The pool returns the generic DeviceContext; for a CUDA place it is a
    // CUDADeviceContext. Its stream is remembered as compute_stream_ so that
    // ReadAsync can order its copies after work already recorded on it.
    auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
        platform::DeviceContextPool::Instance().Get(place_));
    compute_stream_ = dev_ctx->stream();
    // One event per buffer slot, used only for stream ordering (no timing).
    events_.resize(buffer_size);
    for (auto &event : events_) {
      PADDLE_ENFORCE(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
    }
    // Dedicated non-blocking stream for the host-to-device copies.
    PADDLE_ENFORCE(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
  }
#endif
  cpu_buffer_.resize(buffer_size);
  gpu_buffer_.resize(buffer_size);
  // Start prefetching into every slot immediately.
  ReadTillBufferFullAsync();
}
F
fengjiayi 已提交
69

Y
yuyang18 已提交
70
void BufferedReader::ReadTillBufferFullAsync() {
Y
yuyang18 已提交
71 72
  PADDLE_ENFORCE_EQ(position_.size(), 0U);
  for (size_t i = 0; i < buffer_size_; ++i) {
Y
yuyang18 已提交
73
    ReadAsync(i);
Y
yuyang18 已提交
74 75
  }
}
F
fengjiayi 已提交
76

Y
yuyang18 已提交
77
void BufferedReader::ReadAsync(size_t i) {
  // Queue a prefetch task for buffer slot `i` on thread_pool_ and remember its
  // future in position_. The task yields `i` once the slot holds a batch, or
  // -1UL when the upstream reader produced no tensors.
  position_.emplace(thread_pool_.enqueue([this, i]() -> size_t {
    TensorVec &cpu = cpu_buffer_[i];
    reader_->ReadNext(&cpu);

    if (cpu.empty()) {
      return -1UL;
    }

#ifdef PADDLE_WITH_CUDA
    // NOTE(liangdun): using async copy instead of TensorCopySync
    // TensorCopySync would block other stream, because TensorCopySync
    // issues the copying command to the default stream, it will make two
    // commands from different streams cannot run concurrently.
    if (platform::is_gpu_place(place_)) {
      const auto dst_place = boost::get<platform::CUDAPlace>(place_);
      TensorVec &gpu = gpu_buffer_[i];
      if (gpu.empty()) {
        gpu.resize(cpu.size());
      } else {
        PADDLE_ENFORCE_EQ(gpu.size(), cpu.size(),
                          "Input tensor number not matched");
      }

      // `j` indexes tensors within the batch; `i` stays the buffer slot.
      std::vector<void *> gpu_ptrs;
      gpu_ptrs.reserve(cpu.size());
      for (size_t j = 0; j < cpu.size(); ++j) {
        gpu[j].Resize(cpu[j].dims());
        gpu[j].set_layout(cpu[j].layout());
        gpu_ptrs.emplace_back(gpu[j].mutable_data(place_, cpu[j].type()));
      }

      // NOTE(zjl): cudaStreamWaitEvent() must be called after all
      // gpu[i].mutable_data() is called, since some ops release
      // gpu memory immediately without waiting gpu kernel ends
      platform::SetDeviceId(dst_place.device);
      PADDLE_ENFORCE(cudaEventRecord(events_[i], compute_stream_));
      PADDLE_ENFORCE(cudaStreamWaitEvent(stream_, events_[i], 0));

      platform::RecordEvent record_event("BufferedReader:MemoryCopy");
      for (size_t j = 0; j < cpu.size(); ++j) {
        auto cpu_place = cpu[j].place();
        auto cpu_ptr = cpu[j].data<void>();
        auto gpu_ptr = gpu_ptrs[j];
        auto size =
            cpu[j].numel() * paddle::framework::SizeOfType(cpu[j].type());
        if (platform::is_cuda_pinned_place(cpu_place)) {
          memory::Copy(dst_place, gpu_ptr,
                       boost::get<platform::CUDAPinnedPlace>(cpu_place),
                       cpu_ptr, size, stream_);
        } else if (platform::is_gpu_place(cpu_place)) {
          memory::Copy(dst_place, gpu_ptr,
                       boost::get<platform::CUDAPlace>(cpu_place), cpu_ptr,
                       size, stream_);
        } else {
          // Plain host memory: stage it through a temporary pinned tensor,
          // then copy pinned -> device on stream_.
          platform::CUDAPinnedPlace cuda_pinned_place;
          framework::LoDTensor cuda_pinned_tensor;
          cuda_pinned_tensor.Resize(cpu[j].dims());
          auto cuda_pinned_ptr =
              cuda_pinned_tensor.mutable_data(cuda_pinned_place, cpu[j].type());
          memory::Copy(cuda_pinned_place, cuda_pinned_ptr,
                       boost::get<platform::CPUPlace>(cpu_place), cpu_ptr,
                       size);
          memory::Copy(dst_place, gpu_ptr, cuda_pinned_place, cuda_pinned_ptr,
                       size, stream_);
          // The staging tensor is destroyed at the end of this scope while the
          // copy on stream_ may still be in flight, so wait for it here.
          PADDLE_ENFORCE(cudaStreamSynchronize(stream_),
                         "cuda stream sync error.");
        }
        gpu[j].set_lod(cpu[j].lod());
      }
      // Make sure the whole batch has landed on the device before the slot
      // index is handed back to the consumer.
      PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
    }
#endif
    return i;
  }));
}
F
fengjiayi 已提交
152

Y
yuyang18 已提交
153
void BufferedReader::ShutdownImpl() {
Q
Qiao Longfei 已提交
154
  VLOG(1) << "ShutdownImpl";
Y
yuyang18 已提交
155
  reader_->Shutdown();
Y
yuyang18 已提交
156 157 158
  while (!position_.empty()) {
    position_.pop();
  }
Y
yuyang18 已提交
159
  prev_pos_ = -1UL;
Y
yuyang18 已提交
160
}
F
fengjiayi 已提交
161

Y
yuyang18 已提交
162 163
void BufferedReader::StartImpl() {
  reader_->Start();
Y
yuyang18 已提交
164
  ReadTillBufferFullAsync();
Y
yuyang18 已提交
165
}
F
fengjiayi 已提交
166

Y
yuyang18 已提交
167
void BufferedReader::ReadNextImpl(std::vector<framework::LoDTensor> *out) {
  // Pop finished prefetch tasks in FIFO order until one produced data. A task
  // yields -1UL when the upstream reader returned no tensors; such entries are
  // skipped. If the queue runs dry, `out` is cleared and nothing is returned.
  size_t i = -1UL;
  while (i == -1UL) {
    if (position_.empty()) {
      out->clear();
      return;
    }
    i = position_.front().get();
    position_.pop();
  }

  *out = std::move(platform::is_gpu_place(place_) ? gpu_buffer_[i]
                                                  : cpu_buffer_[i]);

  // Do not push current position into ReadAsync. Push the previous position
  // Since all computation in fluid are async, change the data of
  // current position may cause data error.
  if (prev_pos_ != -1UL) {
    ReadAsync(prev_pos_);
  }
  prev_pos_ = i;
}

}  // namespace reader
}  // namespace operators
}  // namespace paddle