提交 93cab641 编写于 作者: F fengjiayi

Complete CreateRandomReaderOp

上级 d8cc21da
...@@ -33,8 +33,6 @@ class ReaderBase { ...@@ -33,8 +33,6 @@ class ReaderBase {
class FileReader : public ReaderBase { class FileReader : public ReaderBase {
public: public:
explicit FileReader(const std::vector<DDim>& shapes) : shapes_(shapes) {}
DDim shape(size_t idx) const override; DDim shape(size_t idx) const override;
std::vector<DDim> shapes() const override { return shapes_; } std::vector<DDim> shapes() const override { return shapes_; }
...@@ -44,8 +42,6 @@ class FileReader : public ReaderBase { ...@@ -44,8 +42,6 @@ class FileReader : public ReaderBase {
class ReaderDecorator : public ReaderBase { class ReaderDecorator : public ReaderBase {
public: public:
explicit ReaderDecorator(ReaderBase* reader) : reader_(reader) {}
bool HasNext() const override { return reader_->HasNext(); } bool HasNext() const override { return reader_->HasNext(); }
DDim shape(size_t idx) const override { return reader_->shape(idx); } DDim shape(size_t idx) const override { return reader_->shape(idx); }
...@@ -60,19 +56,19 @@ class ReaderDecorator : public ReaderBase { ...@@ -60,19 +56,19 @@ class ReaderDecorator : public ReaderBase {
template <typename T> template <typename T>
class RandomReader : public FileReader { class RandomReader : public FileReader {
public: public:
RandomReader(const std::vector<DDim>& shapes, float min, float max) void Initialize(const std::vector<DDim>& shapes, float min, float max) {
: FileReader(shapes), min_(min), max_(max) {
PADDLE_ENFORCE_LE(min, max, PADDLE_ENFORCE_LE(min, max,
"'min' should be less than or equal to 'max'.(%f vs %f)", "'min' should be less than or equal to 'max'.(%f vs %f)",
min, max); min, max);
shapes_ = shapes;
min_ = min;
max_ = max;
unsigned int seed = std::random_device()();
engine_.seed(seed);
dist_ = std::uniform_real_distribution<float>(min_, max_);
} }
std::vector<LoDTensor> ReadNext() override { std::vector<LoDTensor> ReadNext() override {
std::minstd_rand engine;
unsigned int seed = std::random_device()();
engine.seed(seed);
std::uniform_real_distribution<float> dist(min_, max_);
std::vector<LoDTensor> res; std::vector<LoDTensor> res;
res.reserve(shapes_.size()); res.reserve(shapes_.size());
for (const DDim& shape : shapes_) { for (const DDim& shape : shapes_) {
...@@ -85,7 +81,7 @@ class RandomReader : public FileReader { ...@@ -85,7 +81,7 @@ class RandomReader : public FileReader {
T* data = out.mutable_data<T>(platform::CPUPlace()); T* data = out.mutable_data<T>(platform::CPUPlace());
int64_t numel = product(shape); int64_t numel = product(shape);
for (int64_t i = 0; i < numel; ++i) { for (int64_t i = 0; i < numel; ++i) {
data[i] = dist(engine); data[i] = dist_(engine_);
} }
res.push_back(out); res.push_back(out);
} }
...@@ -97,16 +93,21 @@ class RandomReader : public FileReader { ...@@ -97,16 +93,21 @@ class RandomReader : public FileReader {
private: private:
float min_; float min_;
float max_; float max_;
std::minstd_rand engine_;
std::uniform_real_distribution<float> dist_;
}; };
// decorators // decorators
class ShuffleReader : public ReaderDecorator { class ShuffleReader : public ReaderDecorator {
public: public:
ShuffleReader(ReaderBase* reader, int buffer_size) void Initialize(ReaderBase* reader, int buffer_size) {
: ReaderDecorator(reader), buffer_size_(buffer_size), iteration_pos_(0) { reader_ = reader;
buffer_size_ = buffer_size;
iteration_pos_ = 0;
buffer_.reserve(buffer_size); buffer_.reserve(buffer_size);
} }
std::vector<LoDTensor> ReadNext() override; std::vector<LoDTensor> ReadNext() override;
private: private:
...@@ -117,8 +118,12 @@ class ShuffleReader : public ReaderDecorator { ...@@ -117,8 +118,12 @@ class ShuffleReader : public ReaderDecorator {
class BatchReader : public ReaderDecorator { class BatchReader : public ReaderDecorator {
public: public:
BatchReader(ReaderBase* reader, int batch_size) void Initialize(ReaderBase* reader, int batch_size) {
: ReaderDecorator(reader), batch_size_(batch_size) {} reader_ = reader;
batch_size_ = batch_size;
buffer_.reserve(batch_size_);
}
std::vector<LoDTensor> ReadNext() override; std::vector<LoDTensor> ReadNext() override;
private: private:
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/framework/op_registry.h"
#include "paddle/framework/reader.h"
namespace paddle {
namespace operators {
// general infershape
class CreateReaderInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of CreateReaderOp should not be null.");
}
};
template <typename T>
class CreateRandomReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
std::vector<framework::DDim> shapes;
int offset = 0;
for (int len : ranks) {
auto start_it = shape_concat.begin() + offset;
auto end_it = start_it + len;
shapes.push_back(
framework::make_ddim(std::vector<int>(start_it, end_it)));
offset += len;
}
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::RandomReader<T>>();
out->Initialize(shapes, Attr<float>("min"), Attr<float>("max"));
}
};
class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CreateRandomReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
AddOutput("Out", "(RandomReader) The created random reader.");
AddAttr<std::vector<int>>("shape_concat",
"The concat of all data's shapes.");
AddAttr<std::vector<int>>(
"ranks",
"The ranks of each data."
"e.g."
"shape_concat = [2,3,4,5,6]"
"ranks = [3,2]"
"It means the reader will generate two data each time,"
"whose shapes are [2,3,4] and [5,6] respectively.");
AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
AddComment(R"DOC(
CreateRandomReader Operator
This Op creates a random reader.
The reader generates random data instead of really reading from files.
Generated data follow an uniform distribution between 'min' and 'max'.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp<float>,
ops::CreateReaderInferShape, ops::CreateRandomReaderOpMaker,
paddle::framework::EmptyGradOpMaker);
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册