Commit f32ca636 authored by F fengjiayi

draft of Reader classes

Parent 1acad21b
@@ -24,6 +24,8 @@ cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
cc_test(variable_test SRCS variable_test.cc)
cc_library(threadpool SRCS threadpool.cc DEPS enforce)
......
@@ -17,35 +17,100 @@
namespace paddle {
namespace framework {
DDim Reader::shape(int idx) const {
DDim Reader::shape(size_t idx) const {
PADDLE_ENFORCE_LT(
idx, shapes_.size(),
"Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx,
shapes_.size());
return shapes_[idx];
}
int RandomReader::ReadNext(std::vector<LoDTensor>* outs) {
  PADDLE_ENFORCE_EQ(
      shapes_.size(), outs->size(),
      "shapes_.size() is %d, while outs->size() is %d. They are not equal.",
      shapes_.size(), outs->size());
  std::minstd_rand engine;
  unsigned int seed = std::random_device()();
  engine.seed(seed);
  std::uniform_real_distribution<float> dist(min_, max_);
  for (size_t idx = 0; idx < shapes_.size(); ++idx) {
    DDim shape = shapes_[idx];
    LoDTensor* out = &(*outs)[idx];
    int64_t numel = out->numel();
    PADDLE_ENFORCE_EQ(product(shape), numel,
                      "The product of %d'th shape is %lld, while the "
                      "corresponding out's numel is %lld. They are not equal.",
                      idx, product(shape), numel);
    float* data = out->mutable_data<float>(platform::CPUPlace());
    for (int64_t i = 0; i < numel; ++i) {
      data[i] = dist(engine);
std::vector<LoDTensor> ShuffleReader::ReadNext() {
if (iteration_pos_ >= buffer_.size()) {
// Reload buffer with new data
buffer_.clear();
for (int i = 0; i < buffer_size_; ++i) {
if (reader_->HasNext()) {
buffer_.push_back(reader_->ReadNext());
} else {
break;
}
}
std::random_shuffle(buffer_.begin(), buffer_.end());
iteration_pos_ = 0;
}
return 0;
if (buffer_.empty()) {
std::vector<LoDTensor> empty_res;
return empty_res;
}
return buffer_[iteration_pos_++];
}
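Note: the refill-then-shuffle strategy above is easier to see in isolation. Below is a minimal standalone sketch of the same idea on plain integers (the function name and sizes are illustrative, not part of this commit): it drains a source in chunks of buffer_size, shuffles each chunk, and emits its elements in order, the way iteration_pos_ walks buffer_ across successive ReadNext() calls.

#include <algorithm>
#include <cstddef>
#include <random>
#include <vector>

// Stand-in for ShuffleReader's buffering: consume the source chunk by
// chunk, shuffling each chunk before handing out its elements.
std::vector<int> ShuffledStream(const std::vector<int>& source,
                                std::size_t buffer_size) {
  std::vector<int> out;
  std::mt19937 rng(std::random_device{}());
  std::size_t read_pos = 0;
  while (read_pos < source.size()) {
    // "Reload buffer with new data", as in the branch above.
    std::size_t end = std::min(read_pos + buffer_size, source.size());
    std::vector<int> buffer(source.begin() + read_pos, source.begin() + end);
    read_pos = end;
    std::shuffle(buffer.begin(), buffer.end(), rng);
    // iteration_pos_ would walk this shuffled chunk across calls.
    out.insert(out.end(), buffer.begin(), buffer.end());
  }
  return out;
}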
std::vector<LoDTensor> BatchReader::ReadNext() {
buffer_.clear();
for (int i = 0; i < batch_size_; ++i) {
if (reader_->HasNext()) {
buffer_.push_back(reader_->ReadNext());
} else {
break;
}
}
// Concat instances
std::vector<LoDTensor> res;
if (buffer_.empty()) {
return res;
}
int out_num = buffer_[0].size();
res.reserve(out_num);
for (int j = 0; j < out_num; ++j) {
// Merge shapes and check data types
std::type_index batch_type = buffer_[0][j].type();
DDim batch_shape = buffer_[0][j].dims();
for (size_t i = 1; i < buffer_.size(); ++i) {
std::type_index ins_type = buffer_[i][j].type();
DDim ins_shape = buffer_[i][j].dims();
PADDLE_ENFORCE_EQ(batch_type, ins_type);
PADDLE_ENFORCE_EQ(slice_ddim(batch_shape, 1, batch_shape.size()),
slice_ddim(ins_shape, 1, ins_shape.size()));
PADDLE_ENFORCE_GT(ins_shape[0], 0);
batch_shape[0] += ins_shape[0];
}
LoDTensor out;
out.Resize(batch_shape);
out.mutable_data(platform::CPUPlace(), batch_type);
int64_t dst_offset = 0;
// Merge lod and data
LoD batch_lod;
std::vector<size_t> top_level_lod({0});
for (size_t i = 0; i < buffer_.size(); ++i) {
DDim ins_shape = buffer_[i][j].dims();
LoD ins_lod = buffer_[i][j].lod();
if (i == 0) {
batch_lod = ins_lod;
} else {
PADDLE_ENFORCE_EQ(batch_lod.size(), ins_lod.size());
for (size_t level_idx = 0; level_idx < batch_lod.size(); ++level_idx) {
auto& lod_level = batch_lod[level_idx];
for (size_t k = 1; k < ins_lod[level_idx].size(); ++k) {
lod_level.push_back(ins_lod[level_idx][k] + lod_level.back());
}
}
}
top_level_lod.push_back(
top_level_lod.back() +
(ins_lod.empty() ? ins_shape[0] : (ins_lod[0].size() - 1)));
Tensor dst = out.Slice(dst_offset, dst_offset + ins_shape[0]);
Copy(buffer_[i][j], platform::CPUPlace(), &dst);
dst_offset += ins_shape[0];
}
batch_lod.insert(batch_lod.begin(), top_level_lod);
out.set_lod(batch_lod);
res.push_back(out);
}
return res;
}
} // namespace framework
} // namespace paddle
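Note: the LoD merging in BatchReader::ReadNext is plain offset arithmetic. Every appended instance's level offsets are shifted by the batch's running total length, and a separate top level records how many sequences each instance contributed. Here is a self-contained sketch of that shift on bare offset vectors (using std::vector<size_t> in place of the framework's LoD type; names are illustrative):

#include <cassert>
#include <cstddef>
#include <vector>

using Level = std::vector<std::size_t>;  // one LoD level: offsets starting at 0

// Append one instance's level to the batch level the way
// BatchReader::ReadNext does: skip ins[0] (always 0) and shift every
// remaining offset by the batch's current total length, batch->back().
void AppendLevel(Level* batch, const Level& ins) {
  for (std::size_t k = 1; k < ins.size(); ++k) {
    batch->push_back(ins[k] + batch->back());
  }
}

int main() {
  Level batch = {0, 2, 5};  // first instance: sequences [0,2) and [2,5)
  Level ins = {0, 3, 4};    // second instance: sequences [0,3) and [3,4)
  AppendLevel(&batch, ins);
  // 3 and 4 are shifted by 5, so the merged level is {0, 2, 5, 8, 9}.
  assert((batch == Level{0, 2, 5, 8, 9}));
  return 0;
}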
@@ -22,20 +22,61 @@ namespace framework {
class Reader {
public:
virtual int ReadNext(std::vector<LoDTensor>* outs) = 0;
DDim shape(int idx) const;
Reader() {}
explicit Reader(const std::vector<DDim>& shapes) : shapes_(shapes) {}
virtual std::vector<LoDTensor> ReadNext() = 0;
virtual bool HasNext() const = 0;
virtual DDim shape(size_t idx) const;
virtual std::vector<DDim> shapes() const { return shapes_; }
virtual ~Reader() {}
private:
// 'shapes_' is kept private to prevent direct access from decorators;
// a decorator should expose its underlying reader_'s shapes, not its own.
std::vector<DDim> shapes_;
};
// file readers
template <typename T>
class RandomReader : public Reader {
public:
RandomReader(const std::vector<DDim>& shapes, float min, float max)
: shapes_(shapes), min_(min), max_(max) {}
int ReadNext(std::vector<LoDTensor>* outs) override;
: Reader(shapes), min_(min), max_(max) {
    PADDLE_ENFORCE_LE(min, max,
                      "'min' should be less than or equal to 'max'. (%f vs %f)",
                      min, max);
}
std::vector<LoDTensor> ReadNext() override {
std::minstd_rand engine;
unsigned int seed = std::random_device()();
engine.seed(seed);
std::uniform_real_distribution<float> dist(min_, max_);
std::vector<LoDTensor> res;
res.reserve(shapes().size());
for (const DDim& shape : shapes()) {
      PADDLE_ENFORCE_GE(
          shape.size(), 2,
          "The rank of input data should be at least 2. (Now it's %d)",
          shape.size());
LoDTensor out;
out.Resize(shape);
T* data = out.mutable_data<T>(platform::CPUPlace());
int64_t numel = product(shape);
for (int64_t i = 0; i < numel; ++i) {
data[i] = dist(engine);
}
res.push_back(out);
}
return res;
}
bool HasNext() const override { return true; }
private:
float min_;
@@ -44,22 +85,40 @@ class RandomReader : public Reader {
// decorators
class BatchReader : public Reader {
class ShuffleReader : public Reader {
public:
BatchReader(const Reader* reader) : reader_(reader) {}
int ReadNext(std::vector<LoDTensor>* outs) override;
ShuffleReader(Reader* reader, int buffer_size)
: reader_(reader), buffer_size_(buffer_size), iteration_pos_(0) {
buffer_.reserve(buffer_size);
}
std::vector<LoDTensor> ReadNext() override;
bool HasNext() const override { return reader_->HasNext(); }
DDim shape(size_t idx) const override { return reader_->shape(idx); }
std::vector<DDim> shapes() const override { return reader_->shapes(); }
private:
const Reader* reader_;
Reader* reader_;
int buffer_size_;
std::vector<std::vector<LoDTensor>> buffer_;
size_t iteration_pos_;
};
class ShuffleReader : public Reader {
class BatchReader : public Reader {
public:
ShuffleReader(const Reader* reader) : reader_(reader) {}
int ReadNext(std::vector<LoDTensor>* outs) override;
BatchReader(Reader* reader, int batch_size)
: reader_(reader), batch_size_(batch_size) {}
std::vector<LoDTensor> ReadNext() override;
bool HasNext() const override { return reader_->HasNext(); }
DDim shape(size_t idx) const override { return reader_->shape(idx); }
std::vector<DDim> shapes() const override { return reader_->shapes(); }
private:
const Reader* reader_;
Reader* reader_;
int batch_size_;
std::vector<std::vector<LoDTensor>> buffer_;
};
} // namespace framework
} // namespace paddle
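Note: for reference, here is a minimal composition sketch using the classes declared above. The wiring (shapes, buffer and batch sizes, raw-pointer ownership of the wrapped readers, include path) is illustrative and not part of this commit:

// #include "paddle/framework/reader.h"  // path assumed
#include <vector>

void ReaderPipelineDemo() {
  using namespace paddle::framework;
  // Two inputs per instance: a 1x784 feature tensor and a 1x1 label tensor.
  std::vector<DDim> shapes = {make_ddim({1, 784}), make_ddim({1, 1})};
  RandomReader<float> random_reader(shapes, 0.0f, 1.0f);
  // Decorators take a raw pointer to the reader they wrap.
  ShuffleReader shuffle_reader(&random_reader, /*buffer_size=*/100);
  BatchReader batch_reader(&shuffle_reader, /*batch_size=*/32);
  if (batch_reader.HasNext()) {
    // One vector element per input; each LoDTensor holds a whole batch.
    std::vector<LoDTensor> batch = batch_reader.ReadNext();
    (void)batch;  // RandomReader::HasNext() is always true, so a real
                  // training loop needs its own stopping condition.
  }
}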