From 1acad21bbf7a7eea1dc5cb9a68057d35210f7cdb Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 30 Jan 2018 20:27:38 +0800 Subject: [PATCH] init reader.h and reader.cc files --- paddle/framework/reader.cc | 51 ++++++++++++++++++++++++++++++ paddle/framework/reader.h | 65 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+) create mode 100644 paddle/framework/reader.cc create mode 100644 paddle/framework/reader.h diff --git a/paddle/framework/reader.cc b/paddle/framework/reader.cc new file mode 100644 index 00000000000..7f80dd7fc10 --- /dev/null +++ b/paddle/framework/reader.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/framework/reader.h" + +namespace paddle { +namespace framework { + +DDim Reader::shape(int idx) const { + PADDLE_ENFORCE_LT( + idx, shapes_.size(), + "Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx, + shapes_.size()); +} + +int RandomReader::ReadNext(std::vector* outs) { + PADDLE_ENFORCE_EQ( + shapes_.size(), outs.size(), + "shapes_.size() is %d, while outs.size() is %d. They are not equal.", + shapes_.size(), outs.size()); + std::minstd_rand engine; + unsigned int seed = std::random_device()(); + engine.seed(seed); + std::uniform_real_distribution dist(min_, max_); + for (int idx = 0; idx < shapes_.size(); ++idx) { + DDim shape = shapes_[idx]; + LoDTensor* out = outs[idx]; + int64_t numel = out->numel(); + PADDLE_ENFORCE_EQ(product(shape), numel, + "The product of %d'th shape is %lld, while the " + "corresponding out's numel is %lld. They are not equal.", + idx, product(shape), numel); + for (int64_t i = 0; i < numel, ++i) { + out[i] = dist(engine); + } + } + return 0; +} +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/reader.h b/paddle/framework/reader.h new file mode 100644 index 00000000000..eed9c18d087 --- /dev/null +++ b/paddle/framework/reader.h @@ -0,0 +1,65 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/framework/ddim.h" +#include "paddle/framework/lod_tensor.h" + +namespace paddle { +namespace framework { + +class Reader { + public: + virtual int ReadNext(std::vector* outs) = 0; + DDim shape(int idx) const; + + private: + std::vector shapes_; +}; + +// file readers + +class RandomReader : public Reader { + public: + RandomReader(const std::vector& shapes, float min, float max) + : shapes_(shapes), min_(min), max_(max) {} + int ReadNext(std::vector* outs) override; + + private: + float min_; + float max_; +}; + +// decorators + +class BatchReader : public Reader { + public: + BatchReader(const Reader* reader) : reader_(reader) {} + int ReadNext(std::vector* outs) override; + + private: + const Reader* reader_; +}; + +class ShuffleReader : public Reader { + public: + ShuffleReader(const Reader* reader) : reader_(reader) {} + int ReadNext(std::vector* outs) override; + + private: + const Reader* reader_; +}; +} // namespace framework +} // namespace paddle -- GitLab