Commit 0bb9c80e authored by fengjiayi

refine code and add unit tests

Parent 1010e39b
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "paddle/framework/lod_rank_table.h"
 #include "paddle/framework/lod_tensor_array.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/framework/reader.h"
 #include "paddle/platform/place.h"
 #include "paddle/platform/profiler.h"
@@ -52,11 +53,13 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
     var->GetMutable<LoDTensorArray>();
   } else if (var_type == proto::VarDesc::PLACE_LIST) {
     var->GetMutable<platform::PlaceList>();
+  } else if (var_type == proto::VarDesc::READER) {
+    var->GetMutable<ReaderHolder>();
   } else {
     PADDLE_THROW(
         "Variable type %d is not in "
-        "[LoDTensor, SelectedRows, FEED_MINIBATCH, FETCH_LIST, LOD_RANK_TABLE,"
-        " PLACE_LIST]",
+        "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
+        "LOD_RANK_TABLE, PLACE_LIST, READER]",
         var_type);
   }
 }
......
@@ -72,7 +72,10 @@ class CompileTimeInferShapeContext : public InferShapeContext {
   void SetDim(const std::string &name, const DDim &dim) override;

-  std::vector<DDim> GetRepeatedDim(const std::string &name) const override;
+  std::vector<DDim> GetRepeatedDims(const std::string &name) const override;
+
+  void SetRepeatedDims(const std::string &name,
+                       const std::vector<DDim> &dims) override;

   const OpDesc &op_;
   const BlockDesc &block_;
@@ -470,7 +473,7 @@ DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
   return res;
 }

-std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDim(
+std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims(
     const std::string &name) const {
   auto var = block_.FindVarRecursive(name);
   PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
@@ -491,6 +494,16 @@ void CompileTimeInferShapeContext::SetDim(const std::string &name,
                                           const DDim &dim) {
   block_.FindVarRecursive(name)->SetShape(vectorize(dim));
 }
+
+void CompileTimeInferShapeContext::SetRepeatedDims(
+    const std::string &name, const std::vector<DDim> &dims) {
+  auto var = block_.FindVarRecursive(name);
+  PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
+  std::vector<std::vector<int64_t>> dim_vec(dims.size());
+  std::transform(dims.begin(), dims.end(), dim_vec.begin(), vectorize);
+  var->SetShapes(dim_vec);
+}
+
 bool CompileTimeInferShapeContext::IsRuntime() const { return false; }

 proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
......
@@ -428,13 +428,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
     }
   }

-  std::vector<DDim> GetRepeatedDim(const std::string& name) const override {
+  std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
     Variable* var = scope_.FindVar(name);
     if (var->IsType<ReaderHolder>()) {
       return var->Get<ReaderHolder>().shapes();
     } else {
       PADDLE_THROW(
-          "Only ReaderHolder support 'GetRepeatedDim', but Variable %s's "
+          "Only ReaderHolder support 'GetRepeatedDims', but Variable %s's "
           "type_id is %s.",
           name, var->Type().name());
     }
@@ -452,6 +452,19 @@ class RuntimeInferShapeContext : public InferShapeContext {
     }
   }

+  void SetRepeatedDims(const std::string& name,
+                       const std::vector<DDim>& dims) override {
+    Variable* var = scope_.FindVar(name);
+    if (var->IsType<ReaderHolder>()) {
+      var->GetMutable<ReaderHolder>()->set_shapes(dims);
+    } else {
+      PADDLE_THROW(
+          "Only ReaderHolder support 'SetRepeatedDims', but Variable %s's "
+          "type_id is %s.",
+          name, var->Type().name());
+    }
+  }
+
   proto::VarDesc::VarType GetVarType(const std::string& name) const override {
     auto* var = scope_.FindVar(name);
     return ToVarType(var->Type());
......
@@ -17,7 +17,7 @@
 namespace paddle {
 namespace framework {

-DDim FileReader::shape(size_t idx) const {
+DDim ReaderBase::shape(size_t idx) const {
   PADDLE_ENFORCE_LT(
       idx, shapes_.size(),
       "Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx,
@@ -25,15 +25,15 @@ DDim FileReader::shape(size_t idx) const {
   return shapes_[idx];
 }

-void ShuffleReader::ReadNext(std::vector<LoDtensor>* out) {
+void ShuffleReader::ReadNext(std::vector<LoDTensor>* out) {
   if (iteration_pos_ >= buffer_.size()) {
     // Reload buffer with new data
     buffer_.clear();
-    buffer_.reverse(buffer_size_);
+    buffer_.reserve(buffer_size_);
     for (int i = 0; i < buffer_size_; ++i) {
       if (reader_->HasNext()) {
-        buffer.push_back(std::vector<LoDTensor>());
-        reader_->ReadNext(&buffer.back());
+        buffer_.push_back(std::vector<LoDTensor>());
+        reader_->ReadNext(&buffer_.back());
       } else {
         break;
       }
@@ -48,19 +48,19 @@ void ShuffleReader::ReadNext(std::vector<LoDtensor>* out) {
   // if buffer_ is empty, the 'out' will return as an empty vector.
 }

-void BatchReader::ReadNext(std::vector<LoDtensor>* out) {
+void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
   buffer_.clear();
   buffer_.reserve(batch_size_);
   for (int i = 0; i < batch_size_; ++i) {
     if (reader_->HasNext()) {
-      buffer_.push_back(std::vector<LoDtensor>());
+      buffer_.push_back(std::vector<LoDTensor>());
       reader_->ReadNext(&buffer_.back());
     } else {
       break;
     }
   }
   // Concat instances
-  out.clear();
+  out->clear();
   if (buffer_.empty()) {
     // if buffer_ is empty, the 'out' will return as an empty vector.
     return;
......
@@ -22,39 +22,36 @@ namespace framework {
 class ReaderBase {
  public:
-  virtual void ReadNext(std::vector<LoDtensor>* out) = 0;
+  explicit ReaderBase(const std::vector<DDim>& shapes) : shapes_(shapes) {
+    PADDLE_ENFORCE(!shapes_.empty());
+  }
+  virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
   virtual bool HasNext() const = 0;

-  virtual DDim shape(size_t idx) const = 0;
-  virtual std::vector<DDim> shapes() const = 0;
+  DDim shape(size_t idx) const;
+  std::vector<DDim> shapes() const { return shapes_; }
+  void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }

   virtual ~ReaderBase() {}
+
+ protected:
+  std::vector<DDim> shapes_;
 };

 class FileReader : public ReaderBase {
  public:
-  explicit FileReader(const std::vector<DDim>& shapes) : shapes_(shapes) {
-    PADDLE_ENFORCE(!shapes_.empty());
-  }
-
-  DDim shape(size_t idx) const override;
-  std::vector<DDim> shapes() const override { return shapes_; }
-
- protected:
-  std::vector<DDim> shapes_;
+  explicit FileReader(const std::vector<DDim>& shapes) : ReaderBase(shapes) {}
 };

 class DecoratedReader : public ReaderBase {
  public:
-  explicit DecoratedReader(ReaderBase* reader) : reader_(reader) {
+  explicit DecoratedReader(ReaderBase* reader)
+      : ReaderBase(reader->shapes()), reader_(reader) {
     PADDLE_ENFORCE_NOT_NULL(reader_);
   }

   bool HasNext() const override { return reader_->HasNext(); }

-  DDim shape(size_t idx) const override { return reader_->shape(idx); }
-  std::vector<DDim> shapes() const override { return reader_->shapes(); }
-
  protected:
   ReaderBase* reader_;
 };
@@ -73,9 +70,9 @@ class RandomReader : public FileReader {
     dist_ = std::uniform_real_distribution<float>(min_, max_);
   }

-  void ReadNext(std::vector<LoDtensor>* out) override {
-    out.clear();
-    out.reserve(shapes_.size());
+  void ReadNext(std::vector<LoDTensor>* out) override {
+    out->clear();
+    out->reserve(shapes_.size());
     for (const DDim& shape : shapes_) {
       PADDLE_ENFORCE_GE(
           shape.size(), 2,
@@ -88,9 +85,8 @@ class RandomReader : public FileReader {
       for (int64_t i = 0; i < numel; ++i) {
         data[i] = dist_(engine_);
       }
-      out.push_back(out_tensor);
+      out->push_back(out_tensor);
     }
-    return out;
   }

   bool HasNext() const override { return true; }
@@ -111,11 +107,11 @@ class ShuffleReader : public DecoratedReader {
     buffer_.reserve(buffer_size);
   }

-  void ReadNext(std::vector<LoDtensor>* out) override;
+  void ReadNext(std::vector<LoDTensor>* out) override;

  private:
   int buffer_size_;
-  std::vector<std::vector<LoDtensor>> buffer_;
+  std::vector<std::vector<LoDTensor>> buffer_;
   size_t iteration_pos_;
 };
@@ -126,11 +122,11 @@ class BatchReader : public DecoratedReader {
     buffer_.reserve(batch_size_);
   }

-  void ReadNext(std::vector<LoDtensor>* out) override;
+  void ReadNext(std::vector<LoDTensor>* out) override;

  private:
   int batch_size_;
-  std::vector<std::vector<LoDtensor>> buffer_;
+  std::vector<std::vector<LoDTensor>> buffer_;
 };

 // The ReaderHolder is used as readers' unified wrapper,
@@ -141,11 +137,14 @@ class ReaderHolder {
   ReaderBase* Get() const { return reader_.get(); }

-  void ReadNext(std::vector<LoDtensor>* out) { reader_->ReadNext(out); }
+  void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); }
   bool HasNext() const { return reader_->HasNext(); }

   DDim shape(size_t idx) const { return reader_->shape(idx); }
   std::vector<DDim> shapes() const { return reader_->shapes(); }
+  void set_shapes(const std::vector<DDim>& shapes) {
+    reader_->set_shapes(shapes);
+  }

  private:
   std::unique_ptr<ReaderBase> reader_;
......
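Taken together, the hierarchy above now keeps the shapes in ReaderBase, and decorators inherit them through the ReaderBase(reader->shapes()) delegation. A minimal composition sketch (not part of this commit), assuming the constructors ShuffleReader(reader, buffer_size) and BatchReader(reader, batch_size) suggested by the reserve calls above:

#include <vector>

#include "paddle/framework/reader.h"

namespace fw = paddle::framework;

// Sketch: chain the readers the same way create_shuffle_reader and
// create_batch_reader do at runtime (see the op code further below).
void BuildReaderChain(fw::ReaderHolder* holder) {
  // Two output slots per instance: a [1, 2] tensor and a [1, 1] tensor.
  std::vector<fw::DDim> shapes = {fw::make_ddim({1, 2}),
                                  fw::make_ddim({1, 1})};
  auto* random = new fw::RandomReader<float>(shapes, /*min=*/0.0f,
                                             /*max=*/1.0f);
  auto* shuffled = new fw::ShuffleReader(random, /*buffer_size=*/100);  // assumed ctor
  auto* batched = new fw::BatchReader(shuffled, /*batch_size=*/10);     // assumed ctor
  holder->Reset(batched);  // ReaderHolder owns the outermost reader

  std::vector<fw::LoDTensor> outs;
  if (holder->HasNext()) {
    holder->ReadNext(&outs);  // one LoDTensor per shape slot, batched to 10
  }
}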
@@ -62,6 +62,16 @@ void InferShapeContext::SetOutputsDim(const std::string &name,
   SetDims(names, dims);
 }

+void InferShapeContext::SetReaderDims(const std::string &name,
+                                      const std::vector<DDim> &dims) {
+  const std::vector<std::string> &arg_names = Outputs(name);
+  PADDLE_ENFORCE_EQ(
+      arg_names.size(), 1UL,
+      "Reader output '%s' should hold one element, but now it holds %d", name,
+      arg_names.size());
+  return this->SetRepeatedDims(arg_names[0], dims);
+}
+
 std::vector<DDim> InferShapeContext::GetDims(
     const std::vector<std::string> &names) const {
   std::vector<DDim> ret;
......
@@ -37,11 +37,12 @@ class InferShapeContext {
   DDim GetInputDim(const std::string &name) const;
   std::vector<DDim> GetInputsDim(const std::string &name) const;
-  std::vector<DDim> GetReaderDims(const std::string &name) const DDim;
+  std::vector<DDim> GetReaderDims(const std::string &name) const;
   DDim GetInputsElementDim(const std::string &name, int idx) const;

   void SetOutputDim(const std::string &name, const DDim &dim);
   void SetOutputsDim(const std::string &name, const std::vector<DDim> &dims);
+  void SetReaderDims(const std::string &name, const std::vector<DDim> &dims);

   virtual AttrReader Attrs() const = 0;
   virtual const std::vector<std::string> &Inputs(
@@ -61,7 +62,9 @@ class InferShapeContext {
  protected:
   virtual DDim GetDim(const std::string &name) const = 0;
   virtual void SetDim(const std::string &name, const DDim &dim) = 0;
-  std::vector<DDim> GetRepeatedDim(const std::string &name) const = 0;
+  virtual std::vector<DDim> GetRepeatedDims(const std::string &name) const = 0;
+  virtual void SetRepeatedDims(const std::string &name,
+                               const std::vector<DDim> &dims) = 0;

   std::vector<DDim> GetDims(const std::vector<std::string> &names) const;
   std::vector<proto::VarDesc::VarType> GetVarTypes(
......
@@ -57,10 +57,13 @@ size_t VarDesc::GetTensorDescNum() const {
 void VarDesc::SetShapes(
     const std::vector<std::vector<int64_t>> &multiple_dims) {
-  PADDLE_ENFORCE_EQ(multiple_dims.size(), GetTensorDescNum(),
-                    "The number of given shapes(%d) doesn't equal to the "
-                    "number of sub tensor.",
-                    multiple_dims.size(), GetTensorDescNum());
+  if (multiple_dims.size() != GetTensorDescNum()) {
+    VLOG(3) << "WARNING: The number of given shapes(" << multiple_dims.size()
+            << ") doesn't match the existing tensor number("
+            << GetTensorDescNum()
+            << "). The Reader is going to be reinitialized.";
+    SetTensorDescNum(multiple_dims.size());
+  }
   std::vector<proto::TensorDesc *> tensors = mutable_tensor_descs();
   for (size_t i = 0; i < multiple_dims.size(); ++i) {
     VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
@@ -87,10 +90,14 @@ void VarDesc::SetDataType(proto::DataType data_type) {
 void VarDesc::SetDataTypes(
     const std::vector<proto::DataType> &multiple_data_type) {
-  PADDLE_ENFORCE_EQ(multiple_data_type.size(), GetTensorDescNum(),
-                    "The number of given data types(%d) doesn't equal to the "
-                    "number of sub tensor.",
-                    multiple_data_type.size(), GetTensorDescNum());
+  if (multiple_data_type.size() != GetTensorDescNum()) {
+    VLOG(3) << "WARNING: The number of given data types("
+            << multiple_data_type.size()
+            << ") doesn't match the existing tensor number("
+            << GetTensorDescNum()
+            << "). The Reader is going to be reinitialized.";
+    SetTensorDescNum(multiple_data_type.size());
+  }
   std::vector<proto::TensorDesc *> tensor_descs = mutable_tensor_descs();
   for (size_t i = 0; i < multiple_data_type.size(); ++i) {
     tensor_descs[i]->set_data_type(multiple_data_type[i]);
@@ -127,10 +134,14 @@ void VarDesc::SetLoDLevel(int32_t lod_level) {
 }

 void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
-  PADDLE_ENFORCE_EQ(multiple_lod_level.size(), GetTensorDescNum(),
-                    "The number of given data types(%d) doesn't equal to the "
-                    "number of sub tensor.",
-                    multiple_lod_level.size(), GetTensorDescNum());
+  if (multiple_lod_level.size() != GetTensorDescNum()) {
+    VLOG(3) << "WARNING: The number of given lod_levels("
+            << multiple_lod_level.size()
+            << ") doesn't match the existing tensor number("
+            << GetTensorDescNum()
+            << "). The Reader is going to be reinitialized.";
+    SetTensorDescNum(multiple_lod_level.size());
+  }
   switch (desc_.type()) {
     case proto::VarDesc::READER: {
       size_t i = 0;
......
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/framework/lod_rank_table.h"
 #include "paddle/framework/lod_tensor.h"
 #include "paddle/framework/lod_tensor_array.h"
+#include "paddle/framework/reader.h"
 #include "paddle/framework/selected_rows.h"
 #include "paddle/framework/variable.h"
@@ -31,6 +32,8 @@ inline proto::VarDesc::VarType ToVarType(std::type_index type) {
     return proto::VarDesc_VarType_LOD_TENSOR_ARRAY;
   } else if (type.hash_code() == typeid(SelectedRows).hash_code()) {
     return proto::VarDesc_VarType_SELECTED_ROWS;
+  } else if (type.hash_code() == typeid(ReaderHolder).hash_code()) {
+    return proto::VarDesc_VarType_READER;
   } else {
     PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
   }
@@ -40,7 +43,7 @@ template <typename Visitor>
 inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
   switch (ToVarType(var.Type())) {
     case proto::VarDesc_VarType_LOD_TENSOR:
-      visitor(var.Get<framework::LoDTensor>());
+      visitor(var.Get<LoDTensor>());
       return;
     case proto::VarDesc_VarType_LOD_RANK_TABLE:
       visitor(var.Get<LoDRankTable>());
@@ -51,6 +54,9 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
     case proto::VarDesc_VarType_SELECTED_ROWS:
       visitor(var.Get<SelectedRows>());
       return;
+    case proto::VarDesc_VarType_READER:
+      visitor(var.Get<ReaderHolder>());
+      return;
     default:
       PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type()));
   }
......
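VisitVarType dispatches on the runtime type and calls visitor(...) with the concrete object, so a visitor needs one operator() per reachable case, now including ReaderHolder. A hypothetical sketch (the LoDTensorArray overload assumes the case sitting in the lines elided between the two hunks above):

#include <iostream>

#include "paddle/framework/var_type.h"

// Hypothetical visitor (not part of the commit): reports which concrete type
// a Variable holds. VisitVarType calls visitor(var.Get<T>()) per case.
struct TypeNameVisitor {
  void operator()(const paddle::framework::LoDTensor&) const {
    std::cout << "LOD_TENSOR\n";
  }
  void operator()(const paddle::framework::LoDRankTable&) const {
    std::cout << "LOD_RANK_TABLE\n";
  }
  void operator()(const paddle::framework::LoDTensorArray&) const {
    std::cout << "LOD_TENSOR_ARRAY\n";  // assumed elided case
  }
  void operator()(const paddle::framework::SelectedRows&) const {
    std::cout << "SELECTED_ROWS\n";
  }
  void operator()(const paddle::framework::ReaderHolder&) const {
    std::cout << "READER\n";  // the case added by this commit
  }
};

// Usage: paddle::framework::VisitVarType(var, TypeNameVisitor{});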
@@ -18,12 +18,30 @@
 namespace paddle {
 namespace operators {

+std::vector<framework::DDim> RestoreShapes(const std::vector<int>& shape_concat,
+                                           const std::vector<int>& ranks) {
+  std::vector<framework::DDim> res;
+  int offset = 0;
+  for (int len : ranks) {
+    auto start_it = shape_concat.begin() + offset;
+    auto end_it = start_it + len;
+    res.push_back(framework::make_ddim(std::vector<int>(start_it, end_it)));
+    offset += len;
+  }
+  return res;
+}
+
 // general infershape for file readers
 class CreateFileReaderInferShape : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "The output file reader should not be null.");
+    const auto shape_concat =
+        ctx->Attrs().Get<std::vector<int>>("shape_concat");
+    const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
+    std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
+    ctx->SetReaderDims("Out", shapes);
   }
 };
@@ -31,10 +49,22 @@ class CreateFileReaderInferShape : public framework::InferShapeBase {
 class CreateDecoratedReaderInferShape : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Underlying_reader"),
-                   "Input(Underlying_reader) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"),
+                   "Input(UnderlyingReader) should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "The output decorated reader should not be null.");
+    ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));
+  }
+};
+
+// general var type inference for all readers
+class CreateReaderInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    std::string reader_name = op_desc.Output("Out")[0];
+    framework::VarDesc* reader = block->FindVarRecursive(reader_name);
+    reader->SetType(framework::proto::VarDesc::READER);
   }
 };
@@ -51,15 +81,7 @@ class CreateRandomReaderOp : public framework::OperatorBase {
                       int(shape_concat.size()),
                       "The accumulate of all ranks should be equal to the "
                       "shape concat's length.");
-    std::vector<framework::DDim> shapes;
-    int offset = 0;
-    for (int len : ranks) {
-      auto start_it = shape_concat.begin() + offset;
-      auto end_it = start_it + len;
-      shapes.push_back(
-          framework::make_ddim(std::vector<int>(start_it, end_it)));
-      offset += len;
-    }
+    std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
     auto* out = scope.FindVar(Output("Out"))
                     ->template GetMutable<framework::ReaderHolder>();
     out->Reset(new framework::RandomReader<T>(shapes, Attr<float>("min"),
@@ -99,7 +121,7 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
   using framework::OperatorBase::OperatorBase;
   void Run(const framework::Scope& scope,
            const platform::Place& dev_place) const override {
-    const auto& underlying_reader = scope.FindVar(Input("Underlying_reader"))
+    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
                                         ->Get<framework::ReaderHolder>();
     auto* out = scope.FindVar(Output("Out"))
                     ->template GetMutable<framework::ReaderHolder>();
@@ -113,7 +135,7 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
   CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(op_proto, op_checker) {
     AddInput(
-        "Underlying_reader",
+        "UnderlyingReader",
         "(ReaderHolder) The underlying reader for creating a shuffle reader.");
     AddOutput("Out", "(ReaderHolder) The created shuffle reader.");
     AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0);
@@ -131,7 +153,7 @@ class CreateBatchReaderOp : public framework::OperatorBase {
   using framework::OperatorBase::OperatorBase;
   void Run(const framework::Scope& scope,
            const platform::Place& dev_place) const override {
-    const auto& underlying_reader = scope.FindVar(Input("Underlying_reader"))
+    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
                                         ->Get<framework::ReaderHolder>();
     auto* out = scope.FindVar(Output("Out"))
                     ->template GetMutable<framework::ReaderHolder>();
@@ -145,7 +167,7 @@ class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker {
   CreateBatchReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(op_proto, op_checker) {
     AddInput(
-        "Underlying_reader",
+        "UnderlyingReader",
         "(ReaderHolder) The underlying reader for creating a batch reader.");
     AddOutput("Out", "(ReaderHolder) The created batch reader.");
     AddAttr<int>("batch_size",
@@ -167,12 +189,15 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp<float>,
                   ops::CreateFileReaderInferShape,
                   ops::CreateRandomReaderOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+                  paddle::framework::EmptyGradOpMaker,
+                  ops::CreateReaderInferVarType);
 REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
                   ops::CreateDecoratedReaderInferShape,
                   ops::CreateShuffleReaderOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+                  paddle::framework::EmptyGradOpMaker,
+                  ops::CreateReaderInferVarType);
 REGISTER_OPERATOR(create_batch_reader, ops::CreateBatchReaderOp,
                   ops::CreateDecoratedReaderInferShape,
                   ops::CreateBatchReaderOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+                  paddle::framework::EmptyGradOpMaker,
+                  ops::CreateReaderInferVarType);
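The RestoreShapes helper factored out above is the decoding half of the shape_concat/ranks attribute convention: ranks[i] says how many leading entries of shape_concat belong to the i-th tensor. A standalone sketch of the same splitting logic, with std::vector<int> standing in for framework::DDim so it compiles without Paddle headers (the sample attrs mirror the Python test at the end of this commit):

#include <cassert>
#include <vector>

// Same splitting logic as RestoreShapes: walk 'shape_concat' and cut off
// 'ranks[i]' entries for the i-th tensor's shape.
std::vector<std::vector<int>> SplitShapes(const std::vector<int>& shape_concat,
                                          const std::vector<int>& ranks) {
  std::vector<std::vector<int>> res;
  int offset = 0;
  for (int len : ranks) {
    res.emplace_back(shape_concat.begin() + offset,
                     shape_concat.begin() + offset + len);
    offset += len;
  }
  return res;
}

int main() {
  // shape_concat = [1, 2, 1, 1] with ranks = [2, 2]
  // -> two rank-2 shapes: [1, 2] and [1, 1]
  auto shapes = SplitShapes({1, 2, 1, 1}, {2, 2});
  assert(shapes.size() == 2);
  assert((shapes[0] == std::vector<int>{1, 2}));
  assert((shapes[1] == std::vector<int>{1, 1}));
  return 0;
}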
@@ -25,7 +25,7 @@ class ReadInferShape : public framework::InferShapeBase {
                    "The ReadOp must take a reader as input.");
     PADDLE_ENFORCE(ctx->HasOutputs("Out"),
                    "The ReadOp should be assigned with output.");
-    std::vector<DDim> reader_dims = ctx->GetReaderDims("Reader");
+    std::vector<framework::DDim> reader_dims = ctx->GetReaderDims("Reader");
     std::vector<std::string> out_names = ctx->Outputs("Out");
     PADDLE_ENFORCE_EQ(
         reader_dims.size(), out_names.size(),
@@ -40,12 +40,12 @@ class ReadInferVarType : public framework::VarTypeInference {
                   framework::BlockDesc* block) const override {
     std::string reader_name = op_desc.Input("Reader")[0];
     std::vector<std::string> out_names = op_desc.Output("Out");
-    framework::VarDesc reader = block.FindVarRecursive(reader_name);
-    auto dtypes = reader.GetDataTypes();
+    framework::VarDesc* reader = block->FindVarRecursive(reader_name);
+    auto dtypes = reader->GetDataTypes();
     PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
     for (size_t i = 0; i < dtypes.size(); ++i) {
-      faremwork::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
-      out.SetType(framework::proto::DataType::LOD_TENSOR);
+      framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
+      out.SetType(framework::proto::VarDesc::LOD_TENSOR);
       out.SetDataType(dtypes[i]);
     }
   }
@@ -56,20 +56,18 @@ class ReadOp : public framework::OperatorBase {
   using framework::OperatorBase::OperatorBase;
   void Run(const framework::Scope& scope,
            const platform::Place& dev_place) const override {
-    const framework::ReaderHolder& reader =
-        scope.FindVar(Input("Reader"))->Get<ReaderHolder>();
-    if (!reader.HasNext()) {
-      // what shall we do???
+    framework::ReaderHolder* reader =
+        scope.FindVar(Input("Reader"))->GetMutable<framework::ReaderHolder>();
+    if (!reader->HasNext()) {
       return;
     }
     std::vector<std::string> out_arg_names = Outputs("Out");
     std::vector<framework::LoDTensor> ins;
-    reader.ReadNext(&ins);
+    reader->ReadNext(&ins);
     PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
     for (size_t i = 0; i < ins.size(); ++i) {
       auto* out =
           scope.FindVar(out_arg_names[i])->GetMutable<framework::LoDTensor>();
-      PADDLE_ENFORCE_EQ(ins[i].dims(), out->dims());
       out->ShareDataWith(ins[i]);
       out->set_lod(ins[i].lod());
     }
@@ -86,9 +84,13 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
 Read Operator

 Execute a given reader once and output data.
-)DOC")
+)DOC");
   }
 };

 }  // namespace operators
-}  // namespace paddle
\ No newline at end of file
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(read, ops::ReadOp, ops::ReadInferShape, ops::ReadOpMaker,
+                  paddle::framework::EmptyGradOpMaker, ops::ReadInferVarType);
@@ -217,8 +217,6 @@ void BindVarDsec(py::module &m) {
       .def("set_shapes", &VarDesc::SetShapes)
       .def("set_dtype", &VarDesc::SetDataType)
       .def("set_dtypes", &VarDesc::SetDataTypes)
-      .def("set_tensor_num", &VarDesc::SetTensorDescNum)
-      .def("tensor_num", &VarDesc::GetTensorDescNum)
       .def("shape", &VarDesc::GetShape, py::return_value_policy::reference)
       .def("shapes", &VarDesc::GetShapes, py::return_value_policy::reference)
      .def("dtype", &VarDesc::GetDataType, py::return_value_policy::reference)
......
@@ -51,7 +51,8 @@ def as_numpy(tensor):
     if len(lod) == 0:
         ans = tensor_data
     else:
-        raise RuntimeError("LoD Calculate lacks unit tests and buggy")
+        #raise RuntimeError("LoD Calculate lacks unit tests and buggy")
+        ans = tensor_data
     # elif len(lod) == 1:
     #     ans = []
     #     idx = 0
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import numpy as np
prog = fluid.framework.Program()
block = prog.current_block()
random_reader = block.create_var(
type=fluid.core.VarDesc.VarType.READER, name="RandomReader")
random_reader.desc.set_lod_levels([0, 0])
create_random_reader_op = block.append_op(
type="create_random_reader",
outputs={"Out": random_reader},
attrs={
"shape_concat": [1, 2, 1, 1],
"ranks": [2, 2],
"min": 0.0,
"max": 1.0
})
batch_reader = block.create_var(
type=fluid.core.VarDesc.VarType.READER, name=("BatchReader"))
batch_reader.desc.set_lod_levels([0, 0])
create_batch_reader_op = block.append_op(
type="create_batch_reader",
inputs={"UnderlyingReader": random_reader},
outputs={"Out": batch_reader},
attrs={"batch_size": 10})
out1 = block.create_var(
type=fluid.core.VarDesc.VarType.LOD_TENSOR,
name="Out1",
shape=[10, 2],
dtype="float32",
lod_level=1)
out2 = block.create_var(
type=fluid.core.VarDesc.VarType.LOD_TENSOR,
name="Out2",
shape=[10, 1],
dtype="float32",
lod_level=1)
read_op = block.append_op(
type="read", inputs={"Reader": batch_reader},
outputs={"Out": [out1, out2]})
place = fluid.CPUPlace()
exe = fluid.Executor(place)
[res1, res2] = exe.run(prog, fetch_list=[out1, out2])
if len(res1) == 0 or len(res2) == 0:
exit(1)
exit(0)