提交 c1349d98 编写于 作者: F fengjiayi

fix compile errors

上级 b00cae60
...@@ -38,6 +38,8 @@ void ShuffleReader::ReadNext(std::vector<LoDTensor>* out) { ...@@ -38,6 +38,8 @@ void ShuffleReader::ReadNext(std::vector<LoDTensor>* out) {
break; break;
} }
} }
// TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be
// optimize.
std::random_shuffle(buffer_.begin(), buffer_.end()); std::random_shuffle(buffer_.begin(), buffer_.end());
iteration_pos_ = 0; iteration_pos_ = 0;
} }
......
...@@ -28,6 +28,8 @@ class ReaderBase { ...@@ -28,6 +28,8 @@ class ReaderBase {
virtual void ReadNext(std::vector<LoDTensor>* out) = 0; virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
virtual bool HasNext() const = 0; virtual bool HasNext() const = 0;
virtual void ReInit() = 0;
DDim shape(size_t idx) const; DDim shape(size_t idx) const;
std::vector<DDim> shapes() const { return shapes_; } std::vector<DDim> shapes() const { return shapes_; }
void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; } void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }
...@@ -52,6 +54,8 @@ class DecoratedReader : public ReaderBase { ...@@ -52,6 +54,8 @@ class DecoratedReader : public ReaderBase {
bool HasNext() const override { return reader_->HasNext(); } bool HasNext() const override { return reader_->HasNext(); }
void ReInit() override { reader_->ReInit(); }
protected: protected:
ReaderBase* reader_; ReaderBase* reader_;
}; };
...@@ -59,9 +63,9 @@ class DecoratedReader : public ReaderBase { ...@@ -59,9 +63,9 @@ class DecoratedReader : public ReaderBase {
// file readers // file readers
template <typename T> template <typename T>
class RandomReader : public FileReader { class RandomDataGenerator : public FileReader {
public: public:
RandomReader(const std::vector<DDim>& shapes, float min, float max) RandomDataGenerator(const std::vector<DDim>& shapes, float min, float max)
: FileReader(shapes), min_(min), max_(max) { : FileReader(shapes), min_(min), max_(max) {
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max); min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max);
...@@ -91,6 +95,8 @@ class RandomReader : public FileReader { ...@@ -91,6 +95,8 @@ class RandomReader : public FileReader {
bool HasNext() const override { return true; } bool HasNext() const override { return true; }
void ReInit() override { return; }
private: private:
float min_; float min_;
float max_; float max_;
...@@ -139,6 +145,7 @@ class ReaderHolder { ...@@ -139,6 +145,7 @@ class ReaderHolder {
void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); } void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); }
bool HasNext() const { return reader_->HasNext(); } bool HasNext() const { return reader_->HasNext(); }
void ReInit() { reader_->ReInit(); }
DDim shape(size_t idx) const { return reader_->shape(idx); } DDim shape(size_t idx) const { return reader_->shape(idx); }
std::vector<DDim> shapes() const { return reader_->shapes(); } std::vector<DDim> shapes() const { return reader_->shapes(); }
......
...@@ -186,7 +186,7 @@ list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) ...@@ -186,7 +186,7 @@ list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS}) foreach(src ${GENERAL_OPS})
op_library(${src}) op_library(${src})
endforeach() endforeach()
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(create_random_reader);\n") file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(create_random_data_generator);\n")
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
......
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
std::vector<framework::DDim> RestoreShapes(const std::vector<int>& shape_concat, static std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& ranks) { const std::vector<int>& shape_concat, const std::vector<int>& ranks) {
std::vector<framework::DDim> res; std::vector<framework::DDim> res;
int offset = 0; int offset = 0;
for (int len : ranks) { for (int len : ranks) {
...@@ -69,7 +69,7 @@ class CreateReaderInferVarType : public framework::VarTypeInference { ...@@ -69,7 +69,7 @@ class CreateReaderInferVarType : public framework::VarTypeInference {
}; };
template <typename T> template <typename T>
class CreateRandomReaderOp : public framework::OperatorBase { class CreateRandomDataGeneratorOp : public framework::OperatorBase {
public: public:
using framework::OperatorBase::OperatorBase; using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope& scope, void Run(const framework::Scope& scope,
...@@ -84,14 +84,15 @@ class CreateRandomReaderOp : public framework::OperatorBase { ...@@ -84,14 +84,15 @@ class CreateRandomReaderOp : public framework::OperatorBase {
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks); std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new framework::RandomReader<T>(shapes, Attr<float>("min"), out->Reset(new framework::RandomDataGenerator<T>(shapes, Attr<float>("min"),
Attr<float>("max"))); Attr<float>("max")));
} }
}; };
class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker { class CreateRandomDataGeneratorOpMaker
: public framework::OpProtoAndCheckerMaker {
public: public:
CreateRandomReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) CreateRandomDataGeneratorOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) { : OpProtoAndCheckerMaker(op_proto, op_checker) {
AddOutput("Out", "(ReaderHolder) The created random reader."); AddOutput("Out", "(ReaderHolder) The created random reader.");
AddAttr<std::vector<int>>("shape_concat", AddAttr<std::vector<int>>("shape_concat",
...@@ -107,7 +108,7 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -107,7 +108,7 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<float>("min", "The lower bound of reader's uniform distribution."); AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
AddAttr<float>("max", "The upper bound of reader's uniform distribution."); AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
AddComment(R"DOC( AddComment(R"DOC(
CreateRandomReader Operator CreateRandomDataGenerator Operator
This Op creates a random reader. This Op creates a random reader.
The reader generates random data instead of really reading from files. The reader generates random data instead of really reading from files.
...@@ -186,9 +187,10 @@ class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -186,9 +187,10 @@ class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp<float>, REGISTER_OPERATOR(create_random_data_generator,
ops::CreateRandomDataGeneratorOp<float>,
ops::CreateFileReaderInferShape, ops::CreateFileReaderInferShape,
ops::CreateRandomReaderOpMaker, ops::CreateRandomDataGeneratorOpMaker,
paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker,
ops::CreateReaderInferVarType); ops::CreateReaderInferVarType);
REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp, REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
......
...@@ -59,7 +59,10 @@ class ReadOp : public framework::OperatorBase { ...@@ -59,7 +59,10 @@ class ReadOp : public framework::OperatorBase {
framework::ReaderHolder* reader = framework::ReaderHolder* reader =
scope.FindVar(Input("Reader"))->GetMutable<framework::ReaderHolder>(); scope.FindVar(Input("Reader"))->GetMutable<framework::ReaderHolder>();
if (!reader->HasNext()) { if (!reader->HasNext()) {
return; reader->ReInit();
PADDLE_ENFORCE(
reader->HasNext(),
"Reader can not read the next data even it has been re-initialized.");
} }
std::vector<std::string> out_arg_names = Outputs("Out"); std::vector<std::string> out_arg_names = Outputs("Out");
std::vector<framework::LoDTensor> ins; std::vector<framework::LoDTensor> ins;
......
...@@ -20,11 +20,11 @@ prog = fluid.framework.Program() ...@@ -20,11 +20,11 @@ prog = fluid.framework.Program()
block = prog.current_block() block = prog.current_block()
random_reader = block.create_var( random_reader = block.create_var(
type=fluid.core.VarDesc.VarType.READER, name="RandomReader") type=fluid.core.VarDesc.VarType.READER, name="RandomDataGenerator")
random_reader.desc.set_lod_levels([0, 0]) random_reader.desc.set_lod_levels([0, 0])
create_random_reader_op = block.append_op( create_random_data_generator_op = block.append_op(
type="create_random_reader", type="create_random_data_generator",
outputs={"Out": random_reader}, outputs={"Out": random_reader},
attrs={ attrs={
"shape_concat": [1, 2, 1, 1], "shape_concat": [1, 2, 1, 1],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册