提交 c1349d98 编写于 作者: F fengjiayi

fix compile errors

上级 b00cae60
......@@ -38,6 +38,8 @@ void ShuffleReader::ReadNext(std::vector<LoDTensor>* out) {
break;
}
}
// TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be
// optimize.
std::random_shuffle(buffer_.begin(), buffer_.end());
iteration_pos_ = 0;
}
......
......@@ -28,6 +28,8 @@ class ReaderBase {
virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
virtual bool HasNext() const = 0;
virtual void ReInit() = 0;
DDim shape(size_t idx) const;
std::vector<DDim> shapes() const { return shapes_; }
void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }
......@@ -52,6 +54,8 @@ class DecoratedReader : public ReaderBase {
bool HasNext() const override { return reader_->HasNext(); }
void ReInit() override { reader_->ReInit(); }
protected:
ReaderBase* reader_;
};
......@@ -59,9 +63,9 @@ class DecoratedReader : public ReaderBase {
// file readers
template <typename T>
class RandomReader : public FileReader {
class RandomDataGenerator : public FileReader {
public:
RandomReader(const std::vector<DDim>& shapes, float min, float max)
RandomDataGenerator(const std::vector<DDim>& shapes, float min, float max)
: FileReader(shapes), min_(min), max_(max) {
PADDLE_ENFORCE_LE(
min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max);
......@@ -91,6 +95,8 @@ class RandomReader : public FileReader {
bool HasNext() const override { return true; }
void ReInit() override { return; }
private:
float min_;
float max_;
......@@ -139,6 +145,7 @@ class ReaderHolder {
void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); }
bool HasNext() const { return reader_->HasNext(); }
void ReInit() { reader_->ReInit(); }
DDim shape(size_t idx) const { return reader_->shape(idx); }
std::vector<DDim> shapes() const { return reader_->shapes(); }
......
......@@ -186,7 +186,7 @@ list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
op_library(${src})
endforeach()
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(create_random_reader);\n")
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(create_random_data_generator);\n")
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
......
......@@ -18,8 +18,8 @@
namespace paddle {
namespace operators {
std::vector<framework::DDim> RestoreShapes(const std::vector<int>& shape_concat,
const std::vector<int>& ranks) {
static std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& shape_concat, const std::vector<int>& ranks) {
std::vector<framework::DDim> res;
int offset = 0;
for (int len : ranks) {
......@@ -69,7 +69,7 @@ class CreateReaderInferVarType : public framework::VarTypeInference {
};
template <typename T>
class CreateRandomReaderOp : public framework::OperatorBase {
class CreateRandomDataGeneratorOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope& scope,
......@@ -84,14 +84,15 @@ class CreateRandomReaderOp : public framework::OperatorBase {
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new framework::RandomReader<T>(shapes, Attr<float>("min"),
out->Reset(new framework::RandomDataGenerator<T>(shapes, Attr<float>("min"),
Attr<float>("max")));
}
};
class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker {
class CreateRandomDataGeneratorOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
CreateRandomReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
CreateRandomDataGeneratorOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
AddOutput("Out", "(ReaderHolder) The created random reader.");
AddAttr<std::vector<int>>("shape_concat",
......@@ -107,7 +108,7 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
AddComment(R"DOC(
CreateRandomReader Operator
CreateRandomDataGenerator Operator
This Op creates a random reader.
The reader generates random data instead of really reading from files.
......@@ -186,9 +187,10 @@ class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp<float>,
REGISTER_OPERATOR(create_random_data_generator,
ops::CreateRandomDataGeneratorOp<float>,
ops::CreateFileReaderInferShape,
ops::CreateRandomReaderOpMaker,
ops::CreateRandomDataGeneratorOpMaker,
paddle::framework::EmptyGradOpMaker,
ops::CreateReaderInferVarType);
REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
......
......@@ -59,7 +59,10 @@ class ReadOp : public framework::OperatorBase {
framework::ReaderHolder* reader =
scope.FindVar(Input("Reader"))->GetMutable<framework::ReaderHolder>();
if (!reader->HasNext()) {
return;
reader->ReInit();
PADDLE_ENFORCE(
reader->HasNext(),
"Reader can not read the next data even it has been re-initialized.");
}
std::vector<std::string> out_arg_names = Outputs("Out");
std::vector<framework::LoDTensor> ins;
......
......@@ -20,11 +20,11 @@ prog = fluid.framework.Program()
block = prog.current_block()
random_reader = block.create_var(
type=fluid.core.VarDesc.VarType.READER, name="RandomReader")
type=fluid.core.VarDesc.VarType.READER, name="RandomDataGenerator")
random_reader.desc.set_lod_levels([0, 0])
create_random_reader_op = block.append_op(
type="create_random_reader",
create_random_data_generator_op = block.append_op(
type="create_random_data_generator",
outputs={"Out": random_reader},
attrs={
"shape_concat": [1, 2, 1, 1],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册