//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/reader.h"

#include <algorithm>  // for std::random_shuffle

namespace paddle {
namespace framework {

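// Returns the shape of the idx-th output of this reader; enforces that idx
// is within the bounds of 'shapes_'.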
DDim ReaderBase::shape(size_t idx) const {
  PADDLE_ENFORCE_LT(
      idx, shapes_.size(),
      "Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx,
      shapes_.size());
  return shapes_[idx];
}

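// ShuffleReader yields instances from an in-memory buffer that is refilled
// from the underlying reader and reshuffled whenever it runs out. Shuffling
// is therefore local to one buffer of at most 'buffer_size_' instances, not
// global over the whole data set.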
void ShuffleReader::ReadNext(std::vector<LoDTensor>* out) {
  if (iteration_pos_ >= buffer_.size()) {
    // Reload buffer with new data
    buffer_.clear();
    buffer_.reserve(buffer_size_);
    for (int i = 0; i < buffer_size_; ++i) {
      if (reader_->HasNext()) {
        buffer_.push_back(std::vector<LoDTensor>());
        reader_->ReadNext(&buffer_.back());
      } else {
        break;
      }
    }
    // TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be
    // optimized.
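    // ('std::random_shuffle' is also deprecated in C++14 and removed in
    // C++17; 'std::shuffle' with an explicit random engine is the portable
    // replacement.)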
    std::random_shuffle(buffer_.begin(), buffer_.end());
    iteration_pos_ = 0;
  }
  out->clear();
  if (!buffer_.empty()) {
    std::swap(*out, buffer_[iteration_pos_++]);
  }
  // If buffer_ is empty, 'out' is returned as an empty vector.
}

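// BatchReader reads up to 'batch_size_' instances from the underlying reader
// and concatenates them along dimension 0, producing one LoDTensor per output
// slot with the instances' LoDs merged.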
void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
  buffer_.clear();
  buffer_.reserve(batch_size_);
  for (int i = 0; i < batch_size_; ++i) {
    if (reader_->HasNext()) {
      buffer_.push_back(std::vector<LoDTensor>());
      reader_->ReadNext(&buffer_.back());
    } else {
      break;
    }
  }
  // Concat instances
  out->clear();
  if (buffer_.empty()) {
    // If buffer_ is empty, 'out' is returned as an empty vector.
    return;
  }
  int out_num = buffer_[0].size();
  out->reserve(out_num);
  for (int j = 0; j < out_num; ++j) {
    // Merge shapes and check data types
    std::type_index batch_type = buffer_[0][j].type();
    DDim batch_shape = buffer_[0][j].dims();
    for (size_t i = 1; i < buffer_.size(); ++i) {
      std::type_index ins_type = buffer_[i][j].type();
      DDim ins_shape = buffer_[i][j].dims();
      PADDLE_ENFORCE_EQ(batch_type, ins_type);
      PADDLE_ENFORCE_EQ(slice_ddim(batch_shape, 1, batch_shape.size()),
                        slice_ddim(ins_shape, 1, ins_shape.size()));
      PADDLE_ENFORCE_GT(ins_shape[0], 0);
      batch_shape[0] += ins_shape[0];
    }

    LoDTensor out_tensor;
    out_tensor.Resize(batch_shape);
    out_tensor.mutable_data(platform::CPUPlace(), batch_type);
    int64_t dst_offset = 0;

    // Merge lod and data
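    // E.g. if the first instance's lod is {{0, 2, 5}} and the second's is
    // {{0, 3}}, the merged batch lod becomes {{0, 2, 5, 8}}: each appended
    // offset is shifted by the last offset already in that level.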
    LoD batch_lod;
    for (size_t i = 0; i < buffer_.size(); ++i) {
      DDim ins_shape = buffer_[i][j].dims();
      LoD ins_lod = buffer_[i][j].lod();
      if (i == 0) {
        batch_lod = ins_lod;
      } else {
        PADDLE_ENFORCE_EQ(batch_lod.size(), ins_lod.size());
        for (size_t level_idx = 0; level_idx < batch_lod.size(); ++level_idx) {
          auto& lod_level = batch_lod[level_idx];
          for (size_t k = 1; k < ins_lod[level_idx].size(); ++k) {
            lod_level.push_back(ins_lod[level_idx][k] + lod_level.back());
          }
        }
      }
      Tensor dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]);
      TensorCopy(buffer_[i][j], platform::CPUPlace(), &dst);
      dst_offset += ins_shape[0];
    }
    out_tensor.set_lod(batch_lod);
    out->push_back(out_tensor);
  }
}

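// Consumer side of the double buffer: blocks until the producer thread has
// filled at least one slot (the ring buffer is empty when the read and write
// positions coincide), copies that slot into 'out', advances the read
// position, and wakes the producer.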
void DoubleBufferReader::ReadNext(std::vector<LoDTensor>* out) {
  std::unique_lock<std::mutex> lck(mtx_);
  while (write_pos_ == read_pos_) {
    buffer_not_empty_.wait(lck);
  }

  out->clear();
  out->resize(buffer_[read_pos_].size());
  // TODO(fengjiayi): This copy shall be reduced.
  for (size_t i = 0; i < buffer_[read_pos_].size(); ++i) {
    TensorCopy(buffer_[read_pos_][i], platform::CPUPlace(), &(*out)[i]);
    (*out)[i].set_lod(buffer_[read_pos_][i].lod());
  }

  ++read_pos_;
  if (read_pos_ >= kDoubleBufferSize) {
    read_pos_ = 0;
  }
  buffer_not_full_.notify_all();
}

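// There is more data as long as the underlying reader can still produce
// instances or the ring buffer holds batches that have not been consumed yet
// (the pre-sized ring buffer itself is never 'empty()', so the read and
// write positions are compared instead).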
bool DoubleBufferReader::HasNext() {
  return reader_->HasNext() || read_pos_ != write_pos_;
}

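// Producer side: runs in a background thread, filling ring-buffer slots from
// the underlying reader. It blocks while the buffer is full, i.e. while
// advancing 'write_pos_' would make it collide with 'read_pos_'.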
void DoubleBufferReader::ProducerThreadFunc() {
  while (reader_->HasNext()) {
    std::unique_lock<std::mutex> lck(mtx_);
    while (((write_pos_ + 1) % kDoubleBufferSize) == read_pos_) {
      buffer_not_full_.wait(lck);
    }
    reader_->ReadNext(&buffer_[write_pos_]);
    ++write_pos_;
    if (write_pos_ >= kDoubleBufferSize) {
      write_pos_ = 0;
    }
    buffer_not_empty_.notify_all();
  }
}

}  // namespace framework
}  // namespace paddle