提交 0bb9c80e 编写于 作者: F fengjiayi

refine code and add unit tests

上级 1010e39b
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/reader.h"
#include "paddle/platform/place.h"
#include "paddle/platform/profiler.h"
......@@ -52,11 +53,13 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
var->GetMutable<LoDTensorArray>();
} else if (var_type == proto::VarDesc::PLACE_LIST) {
var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarDesc::READER) {
var->GetMutable<ReaderHolder>();
} else {
PADDLE_THROW(
"Variable type %d is not in "
"[LoDTensor, SelectedRows, FEED_MINIBATCH, FETCH_LIST, LOD_RANK_TABLE,"
" PLACE_LIST]",
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER]",
var_type);
}
}
......
......@@ -72,7 +72,10 @@ class CompileTimeInferShapeContext : public InferShapeContext {
void SetDim(const std::string &name, const DDim &dim) override;
std::vector<DDim> GetRepeatedDim(const std::string &name) const override;
std::vector<DDim> GetRepeatedDims(const std::string &name) const override;
void SetRepeatedDims(const std::string &name,
const std::vector<DDim> &dims) override;
const OpDesc &op_;
const BlockDesc &block_;
......@@ -470,7 +473,7 @@ DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
return res;
}
std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDim(
std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims(
const std::string &name) const {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
......@@ -491,6 +494,16 @@ void CompileTimeInferShapeContext::SetDim(const std::string &name,
const DDim &dim) {
block_.FindVarRecursive(name)->SetShape(vectorize(dim));
}
void CompileTimeInferShapeContext::SetRepeatedDims(
const std::string &name, const std::vector<DDim> &dims) {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
std::vector<std::vector<int64_t>> dim_vec(dims.size());
std::transform(dims.begin(), dims.end(), dim_vec.begin(), vectorize);
var->SetShapes(dim_vec);
}
bool CompileTimeInferShapeContext::IsRuntime() const { return false; }
proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
......
......@@ -428,13 +428,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
}
std::vector<DDim> GetRepeatedDim(const std::string& name) const override {
std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
Variable* var = scope_.FindVar(name);
if (var->IsType<ReaderHolder>()) {
return var->Get<ReaderHolder>().shapes();
} else {
PADDLE_THROW(
"Only ReaderHolder support 'GetRepeatedDim', but Variable %s's "
"Only ReaderHolder support 'GetRepeatedDims', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
}
......@@ -452,6 +452,19 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
}
void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override {
Variable* var = scope_.FindVar(name);
if (var->IsType<ReaderHolder>()) {
var->GetMutable<ReaderHolder>()->set_shapes(dims);
} else {
PADDLE_THROW(
"Only ReaderHolder support 'SetRepeatedDims', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
}
}
proto::VarDesc::VarType GetVarType(const std::string& name) const override {
auto* var = scope_.FindVar(name);
return ToVarType(var->Type());
......
......@@ -17,7 +17,7 @@
namespace paddle {
namespace framework {
DDim FileReader::shape(size_t idx) const {
DDim ReaderBase::shape(size_t idx) const {
PADDLE_ENFORCE_LT(
idx, shapes_.size(),
"Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx,
......@@ -25,15 +25,15 @@ DDim FileReader::shape(size_t idx) const {
return shapes_[idx];
}
void ShuffleReader::ReadNext(std::vector<LoDtensor>* out) {
void ShuffleReader::ReadNext(std::vector<LoDTensor>* out) {
if (iteration_pos_ >= buffer_.size()) {
// Reload buffer with new data
buffer_.clear();
buffer_.reverse(buffer_size_);
buffer_.reserve(buffer_size_);
for (int i = 0; i < buffer_size_; ++i) {
if (reader_->HasNext()) {
buffer.push_back(std::vector<LoDTensor>());
reader_->ReadNext(&buffer.back());
buffer_.push_back(std::vector<LoDTensor>());
reader_->ReadNext(&buffer_.back());
} else {
break;
}
......@@ -48,19 +48,19 @@ void ShuffleReader::ReadNext(std::vector<LoDtensor>* out) {
// if buffer_ is empty, the 'out' will return as an empty vector.
}
void BatchReader::ReadNext(std::vector<LoDtensor>* out) {
void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
buffer_.clear();
buffer_.reserve(batch_size_);
for (int i = 0; i < batch_size_; ++i) {
if (reader_->HasNext()) {
buffer_.push_back(std::vector<LoDtensor>());
buffer_.push_back(std::vector<LoDTensor>());
reader_->ReadNext(&buffer_.back());
} else {
break;
}
}
// Concat instances
out.clear();
out->clear();
if (buffer_.empty()) {
// if buffer_ is empty, the 'out' will return as an empty vector.
return;
......
......@@ -22,39 +22,36 @@ namespace framework {
class ReaderBase {
public:
virtual void ReadNext(std::vector<LoDtensor>* out) = 0;
explicit ReaderBase(const std::vector<DDim>& shapes) : shapes_(shapes) {
PADDLE_ENFORCE(!shapes_.empty());
}
virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
virtual bool HasNext() const = 0;
virtual DDim shape(size_t idx) const = 0;
virtual std::vector<DDim> shapes() const = 0;
DDim shape(size_t idx) const;
std::vector<DDim> shapes() const { return shapes_; }
void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }
virtual ~ReaderBase() {}
protected:
std::vector<DDim> shapes_;
};
class FileReader : public ReaderBase {
public:
explicit FileReader(const std::vector<DDim>& shapes) : shapes_(shapes) {
PADDLE_ENFORCE(!shapes_.empty());
}
DDim shape(size_t idx) const override;
std::vector<DDim> shapes() const override { return shapes_; }
protected:
std::vector<DDim> shapes_;
explicit FileReader(const std::vector<DDim>& shapes) : ReaderBase(shapes) {}
};
class DecoratedReader : public ReaderBase {
public:
explicit DecoratedReader(ReaderBase* reader) : reader_(reader) {
explicit DecoratedReader(ReaderBase* reader)
: ReaderBase(reader->shapes()), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_);
}
bool HasNext() const override { return reader_->HasNext(); }
DDim shape(size_t idx) const override { return reader_->shape(idx); }
std::vector<DDim> shapes() const override { return reader_->shapes(); }
protected:
ReaderBase* reader_;
};
......@@ -73,9 +70,9 @@ class RandomReader : public FileReader {
dist_ = std::uniform_real_distribution<float>(min_, max_);
}
void ReadNext(std::vector<LoDtensor>* out) override {
out.clear();
out.reserve(shapes_.size());
void ReadNext(std::vector<LoDTensor>* out) override {
out->clear();
out->reserve(shapes_.size());
for (const DDim& shape : shapes_) {
PADDLE_ENFORCE_GE(
shape.size(), 2,
......@@ -88,9 +85,8 @@ class RandomReader : public FileReader {
for (int64_t i = 0; i < numel; ++i) {
data[i] = dist_(engine_);
}
out.push_back(out_tensor);
out->push_back(out_tensor);
}
return out;
}
bool HasNext() const override { return true; }
......@@ -111,11 +107,11 @@ class ShuffleReader : public DecoratedReader {
buffer_.reserve(buffer_size);
}
void ReadNext(std::vector<LoDtensor>* out) override;
void ReadNext(std::vector<LoDTensor>* out) override;
private:
int buffer_size_;
std::vector<std::vector<LoDtensor>> buffer_;
std::vector<std::vector<LoDTensor>> buffer_;
size_t iteration_pos_;
};
......@@ -126,11 +122,11 @@ class BatchReader : public DecoratedReader {
buffer_.reserve(batch_size_);
}
void ReadNext(std::vector<LoDtensor>* out) override;
void ReadNext(std::vector<LoDTensor>* out) override;
private:
int batch_size_;
std::vector<std::vector<LoDtensor>> buffer_;
std::vector<std::vector<LoDTensor>> buffer_;
};
// The ReaderHolder is used as readers' unified wrapper,
......@@ -141,11 +137,14 @@ class ReaderHolder {
ReaderBase* Get() const { return reader_.get(); }
void ReadNext(std::vector<LoDtensor>* out) { reader_->ReadNext(out); }
void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); }
bool HasNext() const { return reader_->HasNext(); }
DDim shape(size_t idx) const { return reader_->shape(idx); }
std::vector<DDim> shapes() const { return reader_->shapes(); }
void set_shapes(const std::vector<DDim>& shapes) {
reader_->set_shapes(shapes);
}
private:
std::unique_ptr<ReaderBase> reader_;
......
......@@ -62,6 +62,16 @@ void InferShapeContext::SetOutputsDim(const std::string &name,
SetDims(names, dims);
}
void InferShapeContext::SetReaderDims(const std::string &name,
const std::vector<DDim> &dims) {
const std::vector<std::string> &arg_names = Outputs(name);
PADDLE_ENFORCE_EQ(
arg_names.size(), 1UL,
"Reader output '%s' should hold one element, but now it holds %d", name,
arg_names.size());
return this->SetRepeatedDims(arg_names[0], dims);
}
std::vector<DDim> InferShapeContext::GetDims(
const std::vector<std::string> &names) const {
std::vector<DDim> ret;
......
......@@ -37,11 +37,12 @@ class InferShapeContext {
DDim GetInputDim(const std::string &name) const;
std::vector<DDim> GetInputsDim(const std::string &name) const;
std::vector<DDim> GetReaderDims(const std::string &name) const DDim;
std::vector<DDim> GetReaderDims(const std::string &name) const;
DDim GetInputsElementDim(const std::string &name, int idx) const;
void SetOutputDim(const std::string &name, const DDim &dim);
void SetOutputsDim(const std::string &name, const std::vector<DDim> &dims);
void SetReaderDims(const std::string &name, const std::vector<DDim> &dims);
virtual AttrReader Attrs() const = 0;
virtual const std::vector<std::string> &Inputs(
......@@ -61,7 +62,9 @@ class InferShapeContext {
protected:
virtual DDim GetDim(const std::string &name) const = 0;
virtual void SetDim(const std::string &name, const DDim &dim) = 0;
std::vector<DDim> GetRepeatedDim(const std::string &name) const = 0;
virtual std::vector<DDim> GetRepeatedDims(const std::string &name) const = 0;
virtual void SetRepeatedDims(const std::string &name,
const std::vector<DDim> &dims) = 0;
std::vector<DDim> GetDims(const std::vector<std::string> &names) const;
std::vector<proto::VarDesc::VarType> GetVarTypes(
......
......@@ -57,10 +57,13 @@ size_t VarDesc::GetTensorDescNum() const {
void VarDesc::SetShapes(
const std::vector<std::vector<int64_t>> &multiple_dims) {
PADDLE_ENFORCE_EQ(multiple_dims.size(), GetTensorDescNum(),
"The number of given shapes(%d) doesn't equal to the "
"number of sub tensor.",
multiple_dims.size(), GetTensorDescNum());
if (multiple_dims.size() != GetTensorDescNum()) {
VLOG(3) << "WARNING: The number of given shapes(" << multiple_dims.size()
<< ") doesn't match the existing tensor number("
<< GetTensorDescNum()
<< "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_dims.size());
}
std::vector<proto::TensorDesc *> tensors = mutable_tensor_descs();
for (size_t i = 0; i < multiple_dims.size(); ++i) {
VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
......@@ -87,10 +90,14 @@ void VarDesc::SetDataType(proto::DataType data_type) {
void VarDesc::SetDataTypes(
const std::vector<proto::DataType> &multiple_data_type) {
PADDLE_ENFORCE_EQ(multiple_data_type.size(), GetTensorDescNum(),
"The number of given data types(%d) doesn't equal to the "
"number of sub tensor.",
multiple_data_type.size(), GetTensorDescNum());
if (multiple_data_type.size() != GetTensorDescNum()) {
VLOG(3) << "WARNING: The number of given data types("
<< multiple_data_type.size()
<< ") doesn't match the existing tensor number("
<< GetTensorDescNum()
<< "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_data_type.size());
}
std::vector<proto::TensorDesc *> tensor_descs = mutable_tensor_descs();
for (size_t i = 0; i < multiple_data_type.size(); ++i) {
tensor_descs[i]->set_data_type(multiple_data_type[i]);
......@@ -127,10 +134,14 @@ void VarDesc::SetLoDLevel(int32_t lod_level) {
}
void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
PADDLE_ENFORCE_EQ(multiple_lod_level.size(), GetTensorDescNum(),
"The number of given data types(%d) doesn't equal to the "
"number of sub tensor.",
multiple_lod_level.size(), GetTensorDescNum());
if (multiple_lod_level.size() != GetTensorDescNum()) {
VLOG(3) << "WARNING: The number of given lod_levels("
<< multiple_lod_level.size()
<< ") doesn't match the existing tensor number("
<< GetTensorDescNum()
<< "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_lod_level.size());
}
switch (desc_.type()) {
case proto::VarDesc::READER: {
size_t i = 0;
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/reader.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/variable.h"
......@@ -31,6 +32,8 @@ inline proto::VarDesc::VarType ToVarType(std::type_index type) {
return proto::VarDesc_VarType_LOD_TENSOR_ARRAY;
} else if (type.hash_code() == typeid(SelectedRows).hash_code()) {
return proto::VarDesc_VarType_SELECTED_ROWS;
} else if (type.hash_code() == typeid(ReaderHolder).hash_code()) {
return proto::VarDesc_VarType_READER;
} else {
PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
}
......@@ -40,7 +43,7 @@ template <typename Visitor>
inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
switch (ToVarType(var.Type())) {
case proto::VarDesc_VarType_LOD_TENSOR:
visitor(var.Get<framework::LoDTensor>());
visitor(var.Get<LoDTensor>());
return;
case proto::VarDesc_VarType_LOD_RANK_TABLE:
visitor(var.Get<LoDRankTable>());
......@@ -51,6 +54,9 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
case proto::VarDesc_VarType_SELECTED_ROWS:
visitor(var.Get<SelectedRows>());
return;
case proto::VarDesc_VarType_READER:
visitor(var.Get<ReaderHolder>());
return;
default:
PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type()));
}
......
......@@ -18,12 +18,30 @@
namespace paddle {
namespace operators {
std::vector<framework::DDim> RestoreShapes(const std::vector<int>& shape_concat,
const std::vector<int>& ranks) {
std::vector<framework::DDim> res;
int offset = 0;
for (int len : ranks) {
auto start_it = shape_concat.begin() + offset;
auto end_it = start_it + len;
res.push_back(framework::make_ddim(std::vector<int>(start_it, end_it)));
offset += len;
}
return res;
}
// general infershape for file readers
class CreateFileReaderInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output file reader should not be null.");
const auto shape_concat =
ctx->Attrs().Get<std::vector<int>>("shape_concat");
const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
ctx->SetReaderDims("Out", shapes);
}
};
......@@ -31,10 +49,22 @@ class CreateFileReaderInferShape : public framework::InferShapeBase {
class CreateDecoratedReaderInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Underlying_reader"),
"Input(Underlying_reader) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"),
"Input(UnderlyingReader) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output decorated reader should not be null.");
ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));
}
};
// general var type inference for all readers
class CreateReaderInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
std::string reader_name = op_desc.Output("Out")[0];
framework::VarDesc* reader = block->FindVarRecursive(reader_name);
reader->SetType(framework::proto::VarDesc::READER);
}
};
......@@ -51,15 +81,7 @@ class CreateRandomReaderOp : public framework::OperatorBase {
int(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
std::vector<framework::DDim> shapes;
int offset = 0;
for (int len : ranks) {
auto start_it = shape_concat.begin() + offset;
auto end_it = start_it + len;
shapes.push_back(
framework::make_ddim(std::vector<int>(start_it, end_it)));
offset += len;
}
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new framework::RandomReader<T>(shapes, Attr<float>("min"),
......@@ -99,7 +121,7 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("Underlying_reader"))
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
......@@ -113,7 +135,7 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
AddInput(
"Underlying_reader",
"UnderlyingReader",
"(ReaderHolder) The underlying reader for creating a shuffle reader.");
AddOutput("Out", "(ReaderHolder) The created shuffle reader.");
AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0);
......@@ -131,7 +153,7 @@ class CreateBatchReaderOp : public framework::OperatorBase {
using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("Underlying_reader"))
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
......@@ -145,7 +167,7 @@ class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker {
CreateBatchReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
AddInput(
"Underlying_reader",
"UnderlyingReader",
"(ReaderHolder) The underlying reader for creating a batch reader.");
AddOutput("Out", "(ReaderHolder) The created batch reader.");
AddAttr<int>("batch_size",
......@@ -167,12 +189,15 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp<float>,
ops::CreateFileReaderInferShape,
ops::CreateRandomReaderOpMaker,
paddle::framework::EmptyGradOpMaker);
paddle::framework::EmptyGradOpMaker,
ops::CreateReaderInferVarType);
REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
ops::CreateDecoratedReaderInferShape,
ops::CreateShuffleReaderOpMaker,
paddle::framework::EmptyGradOpMaker);
paddle::framework::EmptyGradOpMaker,
ops::CreateReaderInferVarType);
REGISTER_OPERATOR(create_batch_reader, ops::CreateBatchReaderOp,
ops::CreateDecoratedReaderInferShape,
ops::CreateBatchReaderOpMaker,
paddle::framework::EmptyGradOpMaker);
paddle::framework::EmptyGradOpMaker,
ops::CreateReaderInferVarType);
......@@ -25,7 +25,7 @@ class ReadInferShape : public framework::InferShapeBase {
"The ReadOp must take a reader as input.");
PADDLE_ENFORCE(ctx->HasOutputs("Out"),
"The ReadOp should be assigned with output.");
std::vector<DDim> reader_dims = ctx->GetReaderDims("Reader");
std::vector<framework::DDim> reader_dims = ctx->GetReaderDims("Reader");
std::vector<std::string> out_names = ctx->Outputs("Out");
PADDLE_ENFORCE_EQ(
reader_dims.size(), out_names.size(),
......@@ -40,12 +40,12 @@ class ReadInferVarType : public framework::VarTypeInference {
framework::BlockDesc* block) const override {
std::string reader_name = op_desc.Input("Reader")[0];
std::vector<std::string> out_names = op_desc.Output("Out");
framework::VarDesc reader = block.FindVarRecursive(reader_name);
auto dtypes = reader.GetDataTypes();
framework::VarDesc* reader = block->FindVarRecursive(reader_name);
auto dtypes = reader->GetDataTypes();
PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
for (size_t i = 0; i < dtypes.size(); ++i) {
faremwork::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
out.SetType(framework::proto::DataType::LOD_TENSOR);
framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
out.SetType(framework::proto::VarDesc::LOD_TENSOR);
out.SetDataType(dtypes[i]);
}
}
......@@ -56,20 +56,18 @@ class ReadOp : public framework::OperatorBase {
using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const framework::ReaderHolder& reader =
scope.FindVar(Input("Reader"))->Get<ReaderHolder>();
if (!reader.HasNext()) {
// what shall we do???
framework::ReaderHolder* reader =
scope.FindVar(Input("Reader"))->GetMutable<framework::ReaderHolder>();
if (!reader->HasNext()) {
return;
}
std::vector<std::string> out_arg_names = Outputs("Out");
std::vector<framework::LoDTensor> ins;
reader.ReadNext(&ins);
reader->ReadNext(&ins);
PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
for (size_t i = 0; i < ins.size(); ++i) {
auto* out =
scope.FindVar(out_arg_names[i])->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(ins[i].dims(), out->dims());
out->ShareDataWith(ins[i]);
out->set_lod(ins[i].lod());
}
......@@ -86,9 +84,13 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
Read Operator
Execute a given reader once and output data.
)DOC")
)DOC");
}
};
} // namespace operators
} // namespace paddle
\ No newline at end of file
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(read, ops::ReadOp, ops::ReadInferShape, ops::ReadOpMaker,
paddle::framework::EmptyGradOpMaker, ops::ReadInferVarType);
......@@ -217,8 +217,6 @@ void BindVarDsec(py::module &m) {
.def("set_shapes", &VarDesc::SetShapes)
.def("set_dtype", &VarDesc::SetDataType)
.def("set_dtypes", &VarDesc::SetDataTypes)
.def("set_tensor_num", &VarDesc::SetTensorDescNum)
.def("tensor_num", &VarDesc::GetTensorDescNum)
.def("shape", &VarDesc::GetShape, py::return_value_policy::reference)
.def("shapes", &VarDesc::GetShapes, py::return_value_policy::reference)
.def("dtype", &VarDesc::GetDataType, py::return_value_policy::reference)
......
......@@ -51,7 +51,8 @@ def as_numpy(tensor):
if len(lod) == 0:
ans = tensor_data
else:
raise RuntimeError("LoD Calculate lacks unit tests and buggy")
#raise RuntimeError("LoD Calculate lacks unit tests and buggy")
ans = tensor_data
# elif len(lod) == 1:
# ans = []
# idx = 0
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import numpy as np
prog = fluid.framework.Program()
block = prog.current_block()
random_reader = block.create_var(
type=fluid.core.VarDesc.VarType.READER, name="RandomReader")
random_reader.desc.set_lod_levels([0, 0])
create_random_reader_op = block.append_op(
type="create_random_reader",
outputs={"Out": random_reader},
attrs={
"shape_concat": [1, 2, 1, 1],
"ranks": [2, 2],
"min": 0.0,
"max": 1.0
})
batch_reader = block.create_var(
type=fluid.core.VarDesc.VarType.READER, name=("BatchReader"))
batch_reader.desc.set_lod_levels([0, 0])
create_batch_reader_op = block.append_op(
type="create_batch_reader",
inputs={"UnderlyingReader": random_reader},
outputs={"Out": batch_reader},
attrs={"batch_size": 10})
out1 = block.create_var(
type=fluid.core.VarDesc.VarType.LOD_TENSOR,
name="Out1",
shape=[10, 2],
dtype="float32",
lod_level=1)
out2 = block.create_var(
type=fluid.core.VarDesc.VarType.LOD_TENSOR,
name="Out2",
shape=[10, 1],
dtype="float32",
lod_level=1)
read_op = block.append_op(
type="read", inputs={"Reader": batch_reader},
outputs={"Out": [out1, out2]})
place = fluid.CPUPlace()
exe = fluid.Executor(place)
[res1, res2] = exe.run(prog, fetch_list=[out1, out2])
if len(res1) == 0 or len(res2) == 0:
exit(1)
exit(0)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册