From c1349d98aa48060b449c4eea4dfc95a2989ad203 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 7 Feb 2018 14:43:14 +0800 Subject: [PATCH] fix compile errors --- paddle/framework/reader.cc | 2 ++ paddle/framework/reader.h | 11 ++++++++-- paddle/operators/CMakeLists.txt | 2 +- paddle/operators/create_reader_op.cc | 22 ++++++++++--------- paddle/operators/read_op.cc | 5 ++++- .../paddle/v2/fluid/tests/test_cpp_reader.py | 6 ++--- 6 files changed, 31 insertions(+), 17 deletions(-) diff --git a/paddle/framework/reader.cc b/paddle/framework/reader.cc index 86220cd0bba..928b661aaad 100644 --- a/paddle/framework/reader.cc +++ b/paddle/framework/reader.cc @@ -38,6 +38,8 @@ void ShuffleReader::ReadNext(std::vector* out) { break; } } + // TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be + // optimize. std::random_shuffle(buffer_.begin(), buffer_.end()); iteration_pos_ = 0; } diff --git a/paddle/framework/reader.h b/paddle/framework/reader.h index ff7153bc7bf..534894cfbd6 100644 --- a/paddle/framework/reader.h +++ b/paddle/framework/reader.h @@ -28,6 +28,8 @@ class ReaderBase { virtual void ReadNext(std::vector* out) = 0; virtual bool HasNext() const = 0; + virtual void ReInit() = 0; + DDim shape(size_t idx) const; std::vector shapes() const { return shapes_; } void set_shapes(const std::vector& shapes) { shapes_ = shapes; } @@ -52,6 +54,8 @@ class DecoratedReader : public ReaderBase { bool HasNext() const override { return reader_->HasNext(); } + void ReInit() override { reader_->ReInit(); } + protected: ReaderBase* reader_; }; @@ -59,9 +63,9 @@ class DecoratedReader : public ReaderBase { // file readers template -class RandomReader : public FileReader { +class RandomDataGenerator : public FileReader { public: - RandomReader(const std::vector& shapes, float min, float max) + RandomDataGenerator(const std::vector& shapes, float min, float max) : FileReader(shapes), min_(min), max_(max) { PADDLE_ENFORCE_LE( min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max); @@ -91,6 +95,8 @@ class RandomReader : public FileReader { bool HasNext() const override { return true; } + void ReInit() override { return; } + private: float min_; float max_; @@ -139,6 +145,7 @@ class ReaderHolder { void ReadNext(std::vector* out) { reader_->ReadNext(out); } bool HasNext() const { return reader_->HasNext(); } + void ReInit() { reader_->ReInit(); } DDim shape(size_t idx) const { return reader_->shape(idx); } std::vector shapes() const { return reader_->shapes(); } diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index e1dba8bb3f9..25bb7187d36 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -186,7 +186,7 @@ list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) op_library(${src}) endforeach() -file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(create_random_reader);\n") +file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(create_random_data_generator);\n") set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") diff --git a/paddle/operators/create_reader_op.cc b/paddle/operators/create_reader_op.cc index 11c77a06032..5ba2a25ab4c 100644 --- a/paddle/operators/create_reader_op.cc +++ b/paddle/operators/create_reader_op.cc @@ -18,8 +18,8 @@ namespace paddle { namespace operators { -std::vector RestoreShapes(const std::vector& shape_concat, - const std::vector& ranks) { +static std::vector RestoreShapes( + const std::vector& shape_concat, const std::vector& ranks) { std::vector res; int offset = 0; for (int len : ranks) { @@ -69,7 +69,7 @@ class CreateReaderInferVarType : public framework::VarTypeInference { }; template -class CreateRandomReaderOp : public framework::OperatorBase { +class CreateRandomDataGeneratorOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; void Run(const framework::Scope& scope, @@ -84,14 +84,15 @@ class CreateRandomReaderOp : public framework::OperatorBase { std::vector shapes = RestoreShapes(shape_concat, ranks); auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new framework::RandomReader(shapes, Attr("min"), - Attr("max"))); + out->Reset(new framework::RandomDataGenerator(shapes, Attr("min"), + Attr("max"))); } }; -class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker { +class CreateRandomDataGeneratorOpMaker + : public framework::OpProtoAndCheckerMaker { public: - CreateRandomReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) + CreateRandomDataGeneratorOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(op_proto, op_checker) { AddOutput("Out", "(ReaderHolder) The created random reader."); AddAttr>("shape_concat", @@ -107,7 +108,7 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("min", "The lower bound of reader's uniform distribution."); AddAttr("max", "The upper bound of reader's uniform distribution."); AddComment(R"DOC( - CreateRandomReader Operator + CreateRandomDataGenerator Operator This Op creates a random reader. The reader generates random data instead of really reading from files. @@ -186,9 +187,10 @@ class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp, +REGISTER_OPERATOR(create_random_data_generator, + ops::CreateRandomDataGeneratorOp, ops::CreateFileReaderInferShape, - ops::CreateRandomReaderOpMaker, + ops::CreateRandomDataGeneratorOpMaker, paddle::framework::EmptyGradOpMaker, ops::CreateReaderInferVarType); REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp, diff --git a/paddle/operators/read_op.cc b/paddle/operators/read_op.cc index 3d17b26c998..3ae454101f5 100644 --- a/paddle/operators/read_op.cc +++ b/paddle/operators/read_op.cc @@ -59,7 +59,10 @@ class ReadOp : public framework::OperatorBase { framework::ReaderHolder* reader = scope.FindVar(Input("Reader"))->GetMutable(); if (!reader->HasNext()) { - return; + reader->ReInit(); + PADDLE_ENFORCE( + reader->HasNext(), + "Reader can not read the next data even it has been re-initialized."); } std::vector out_arg_names = Outputs("Out"); std::vector ins; diff --git a/python/paddle/v2/fluid/tests/test_cpp_reader.py b/python/paddle/v2/fluid/tests/test_cpp_reader.py index 7efcb0c46d2..e71c3a290c9 100644 --- a/python/paddle/v2/fluid/tests/test_cpp_reader.py +++ b/python/paddle/v2/fluid/tests/test_cpp_reader.py @@ -20,11 +20,11 @@ prog = fluid.framework.Program() block = prog.current_block() random_reader = block.create_var( - type=fluid.core.VarDesc.VarType.READER, name="RandomReader") + type=fluid.core.VarDesc.VarType.READER, name="RandomDataGenerator") random_reader.desc.set_lod_levels([0, 0]) -create_random_reader_op = block.append_op( - type="create_random_reader", +create_random_data_generator_op = block.append_op( + type="create_random_data_generator", outputs={"Out": random_reader}, attrs={ "shape_concat": [1, 2, 1, 1], -- GitLab