Commit 1010e39b, authored by fengjiayi

Add ReadOp

Parent: 6e6f5c7e
paddle/framework/framework.proto

@@ -116,7 +116,7 @@ message LoDTensorArrayDesc {
   optional int32 lod_level = 2 [ default = 0 ];
 }
 
-message Reader { repeated LoDTensorDesc lod_tensor = 1; }
+message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }
 
 message VarDesc {
   enum VarType {
@@ -136,7 +136,7 @@ message VarDesc {
   optional LoDTensorDesc lod_tensor = 4;
   optional TensorDesc selected_rows = 5;
   optional LoDTensorArrayDesc tensor_array = 6;
-  optional Reader reader = 7;
+  optional ReaderDesc reader = 7;
 }
 
 message BlockDesc {
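For orientation, the renamed message is reached through VarDesc's new `reader` field. A hedged sketch (not part of the commit) of inspecting it from C++; the function name is illustrative and `var_desc` is assumed to be filled elsewhere:

#include "paddle/framework/framework.pb.h"

void InspectReaderVar(const paddle::framework::proto::VarDesc& var_desc) {
  if (var_desc.has_reader()) {
    const paddle::framework::proto::ReaderDesc& reader_desc = var_desc.reader();
    for (const auto& lt : reader_desc.lod_tensor()) {
      // One LoDTensorDesc per tensor the reader yields; lod_level comes
      // straight from the message definition above.
      int lod_level = lt.lod_level();
      (void)lod_level;  // placeholder: real code would consume it
    }
  }
}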
paddle/framework/op_desc.cc

@@ -72,6 +72,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
   void SetDim(const std::string &name, const DDim &dim) override;
 
+  std::vector<DDim> GetRepeatedDim(const std::string &name) const override;
+
   const OpDesc &op_;
   const BlockDesc &block_;
 };
@@ -457,22 +459,37 @@ const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
 DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
   auto var = block_.FindVarRecursive(name);
   PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
+  DDim res;
   try {
     auto shape = var->GetShape();
-    if (shape.empty()) {
-      return framework::make_ddim({0UL});
-    } else {
-      return framework::make_ddim(var->GetShape());
-    }
+    res = shape.empty() ? make_ddim({0UL}) : make_ddim(shape);
   } catch (...) {
     VLOG(5) << "GetDim of variable " << name << " error";
     std::rethrow_exception(std::current_exception());
   }
+  return res;
+}
+
+std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDim(
+    const std::string &name) const {
+  auto var = block_.FindVarRecursive(name);
+  PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
+  std::vector<DDim> res;
+  try {
+    auto shapes = var->GetShapes();
+    for (const auto &s : shapes) {
+      res.push_back(s.empty() ? make_ddim({0UL}) : make_ddim(s));
+    }
+  } catch (...) {
+    VLOG(5) << "GetRepeatedDim of variable " << name << " error.";
+    std::rethrow_exception(std::current_exception());
+  }
+  return res;
 }
 
 void CompileTimeInferShapeContext::SetDim(const std::string &name,
                                           const DDim &dim) {
-  block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim));
+  block_.FindVarRecursive(name)->SetShape(vectorize(dim));
 }
 
 bool CompileTimeInferShapeContext::IsRuntime() const { return false; }
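Both compile-time getters share one convention: a variable whose desc carries an empty shape is mapped to the rank-1 placeholder dim {0}. A standalone illustration of that convention (not part of the commit):

#include "paddle/framework/ddim.h"

// make_ddim({0UL}) builds a rank-1 DDim holding a single zero -- the
// compile-time stand-in for "shape not yet determined".
paddle::framework::DDim placeholder = paddle::framework::make_ddim({0UL});
// placeholder.size() == 1 and placeholder[0] == 0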
paddle/framework/operator.cc

@@ -320,8 +320,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
     if (length == 0) {
       return false;
     }
-    PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs",
-                      name);
+    PADDLE_ENFORCE_EQ(length, 1UL,
+                      "Input %s should not have more than one input", name);
     auto ipt = ins[0];
     auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
     return var != nullptr;
@@ -333,8 +333,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
     if (length == 0) {
       return false;
     }
-    PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs",
-                      name);
+    PADDLE_ENFORCE_EQ(length, 1UL,
+                      "Output %s should not have more than one output", name);
     auto ipt = outs[0];
     auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
     return var != nullptr;
@@ -421,8 +421,22 @@ class RuntimeInferShapeContext : public InferShapeContext {
     } else if (var->IsType<SelectedRows>()) {
       return var->Get<SelectedRows>().GetCompleteDims();
     } else {
-      PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.",
-                   name, var->Type().name());
+      PADDLE_THROW(
+          "Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's "
+          "type_id is %s.",
+          name, var->Type().name());
     }
   }
 
+  std::vector<DDim> GetRepeatedDim(const std::string& name) const override {
+    Variable* var = scope_.FindVar(name);
+    if (var->IsType<ReaderHolder>()) {
+      return var->Get<ReaderHolder>().shapes();
+    } else {
+      PADDLE_THROW(
+          "Only ReaderHolder supports 'GetRepeatedDim', but Variable %s's "
+          "type_id is %s.",
+          name, var->Type().name());
+    }
+  }
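At runtime the repeated dims are served straight from the ReaderHolder stored in the scope, not from a VarDesc. A hedged sketch of that lookup path (the scope instance, the helper name, and the variable name "test_reader" are illustrative assumptions):

#include "paddle/framework/reader.h"
#include "paddle/framework/scope.h"

std::vector<paddle::framework::DDim> LookUpReaderDims(
    const paddle::framework::Scope& scope) {
  paddle::framework::Variable* var = scope.FindVar("test_reader");
  // Mirrors GetRepeatedDim above: only a ReaderHolder can answer this query.
  PADDLE_ENFORCE(var != nullptr &&
                 var->IsType<paddle::framework::ReaderHolder>());
  return var->Get<paddle::framework::ReaderHolder>().shapes();
}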
paddle/framework/reader.cc

@@ -25,13 +25,15 @@ DDim FileReader::shape(size_t idx) const {
   return shapes_[idx];
 }
 
-std::vector<LoDTensor> ShuffleReader::ReadNext() {
+void ShuffleReader::ReadNext(std::vector<LoDTensor>* out) {
   if (iteration_pos_ >= buffer_.size()) {
     // Reload buffer with new data
     buffer_.clear();
+    buffer_.reserve(buffer_size_);
     for (int i = 0; i < buffer_size_; ++i) {
       if (reader_->HasNext()) {
-        buffer_.push_back(reader_->ReadNext());
+        buffer_.push_back(std::vector<LoDTensor>());
+        reader_->ReadNext(&buffer_.back());
       } else {
         break;
       }
@@ -39,29 +41,32 @@ std::vector<LoDTensor> ShuffleReader::ReadNext() {
     std::random_shuffle(buffer_.begin(), buffer_.end());
     iteration_pos_ = 0;
   }
-  if (buffer_.empty()) {
-    std::vector<LoDTensor> empty_res;
-    return empty_res;
-  }
-  return buffer_[iteration_pos_++];
+  out->clear();
+  if (!buffer_.empty()) {
+    std::swap(*out, buffer_[iteration_pos_++]);
+  }
+  // If buffer_ is empty, 'out' is returned as an empty vector.
 }
 
-std::vector<LoDTensor> BatchReader::ReadNext() {
+void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
   buffer_.clear();
+  buffer_.reserve(batch_size_);
   for (int i = 0; i < batch_size_; ++i) {
     if (reader_->HasNext()) {
-      buffer_.push_back(reader_->ReadNext());
+      buffer_.push_back(std::vector<LoDTensor>());
+      reader_->ReadNext(&buffer_.back());
     } else {
       break;
     }
   }
   // Concat instances
-  std::vector<LoDTensor> res;
+  out->clear();
   if (buffer_.empty()) {
-    return res;
+    // If buffer_ is empty, 'out' is returned as an empty vector.
+    return;
   }
   int out_num = buffer_[0].size();
-  res.reserve(out_num);
+  out->reserve(out_num);
   for (int j = 0; j < out_num; ++j) {
     // Merge shape and check data type
     std::type_index batch_type = buffer_[0][j].type();
@@ -76,9 +81,9 @@ std::vector<LoDTensor> BatchReader::ReadNext() {
       batch_shape[0] += ins_shape[0];
     }
 
-    LoDTensor out;
-    out.Resize(batch_shape);
-    out.mutable_data(platform::CPUPlace(), batch_type);
+    LoDTensor out_tensor;
+    out_tensor.Resize(batch_shape);
+    out_tensor.mutable_data(platform::CPUPlace(), batch_type);
     int64_t dst_offset = 0;
 
     // Merge lod and data
@@ -102,15 +107,14 @@ std::vector<LoDTensor> BatchReader::ReadNext() {
           top_level_lod.back() +
           (ins_lod.empty() ? ins_shape[0] : (ins_lod[0].size() - 1)));
 
-      Tensor dst = out.Slice(dst_offset, dst_offset + ins_shape[0]);
+      Tensor dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]);
       Copy(buffer_[i][j], platform::CPUPlace(), &dst);
       dst_offset += ins_shape[0];
     }
     batch_lod.insert(batch_lod.begin(), top_level_lod);
-    out.set_lod(batch_lod);
-    res.push_back(out);
+    out_tensor.set_lod(batch_lod);
+    out->push_back(out_tensor);
   }
-  return res;
 }
 
 }  // namespace framework
 }  // namespace paddle
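The move from return-by-value to an out-parameter lets ShuffleReader hand each buffered instance over with std::swap instead of copying the tensors. A hedged caller-side sketch of the new signature (how 'reader' is constructed is assumed):

// Consuming any ReaderBase through the new pointer-based ReadNext.
std::vector<paddle::framework::LoDTensor> sample;
while (reader->HasNext()) {   // 'reader' is a ReaderBase*, set up elsewhere
  reader->ReadNext(&sample);  // implementations clear 'sample' before filling
  if (sample.empty()) {
    break;  // buffer exhausted mid-iteration: treat as end of data
  }
  // ... consume 'sample' ...
}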
paddle/framework/reader.h

@@ -15,14 +15,14 @@
 
 #pragma once
 
 #include "paddle/framework/ddim.h"
-#include "paddle/framework/lod_tensor.h"
+#include "paddle/framework/lod_tensor_array.h"
 
 namespace paddle {
 namespace framework {
 
 class ReaderBase {
  public:
-  virtual std::vector<LoDTensor> ReadNext() = 0;
+  virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
   virtual bool HasNext() const = 0;
 
   virtual DDim shape(size_t idx) const = 0;
@@ -73,24 +73,24 @@ class RandomReader : public FileReader {
     dist_ = std::uniform_real_distribution<float>(min_, max_);
   }
 
-  std::vector<LoDTensor> ReadNext() override {
-    std::vector<LoDTensor> res;
-    res.reserve(shapes_.size());
+  void ReadNext(std::vector<LoDTensor>* out) override {
+    out->clear();
+    out->reserve(shapes_.size());
     for (const DDim& shape : shapes_) {
       PADDLE_ENFORCE_GE(
           shape.size(), 2,
-          "The rank of input data should be 2 at least.(Now it's %d)",
+          "The rank of reader's output data should be at least 2. (Now it's %d)",
           shape.size());
-      LoDTensor out;
-      out.Resize(shape);
-      T* data = out.mutable_data<T>(platform::CPUPlace());
+      LoDTensor out_tensor;
+      out_tensor.Resize(shape);
+      T* data = out_tensor.mutable_data<T>(platform::CPUPlace());
       int64_t numel = product(shape);
       for (int64_t i = 0; i < numel; ++i) {
         data[i] = dist_(engine_);
       }
-      res.push_back(out);
+      out->push_back(out_tensor);
     }
-    return res;
   }
 
   bool HasNext() const override { return true; }
@@ -111,11 +111,11 @@ class ShuffleReader : public DecoratedReader {
     buffer_.reserve(buffer_size);
   }
 
-  std::vector<LoDTensor> ReadNext() override;
+  void ReadNext(std::vector<LoDTensor>* out) override;
 
  private:
   int buffer_size_;
   std::vector<std::vector<LoDTensor>> buffer_;
   size_t iteration_pos_;
 };
@@ -126,11 +126,11 @@ class BatchReader : public DecoratedReader {
     buffer_.reserve(batch_size_);
   }
 
-  std::vector<LoDTensor> ReadNext() override;
+  void ReadNext(std::vector<LoDTensor>* out) override;
 
  private:
   int batch_size_;
   std::vector<std::vector<LoDTensor>> buffer_;
 };
 
 // The ReaderHolder is used as readers' unified wrapper,
@@ -141,7 +141,7 @@ class ReaderHolder {
   ReaderBase* Get() const { return reader_.get(); }
 
-  std::vector<LoDTensor> ReadNext() { return reader_->ReadNext(); }
+  void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); }
   bool HasNext() const { return reader_->HasNext(); }
 
   DDim shape(size_t idx) const { return reader_->shape(idx); }
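Under the new interface a decorator only has to override the pointer-based ReadNext; HasNext and shape are presumably inherited from DecoratedReader, since ShuffleReader and BatchReader above override ReadNext alone. A minimal hedged sketch (the class is illustrative and the DecoratedReader constructor signature is an assumption):

// A do-nothing decorator under the new ReadNext signature.
class PassThroughReader : public paddle::framework::DecoratedReader {
 public:
  explicit PassThroughReader(paddle::framework::ReaderBase* reader)
      : DecoratedReader(reader) {}

  void ReadNext(std::vector<paddle::framework::LoDTensor>* out) override {
    reader_->ReadNext(out);  // forward unchanged to the wrapped reader
  }
};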
paddle/framework/shape_inference.cc

@@ -32,6 +32,16 @@ std::vector<DDim> InferShapeContext::GetInputsDim(
   return GetDims(arg_names);
 }
 
+std::vector<DDim> InferShapeContext::GetReaderDims(
+    const std::string &name) const {
+  const std::vector<std::string> &arg_names = Inputs(name);
+  PADDLE_ENFORCE_EQ(
+      arg_names.size(), 1UL,
+      "Reader input '%s' should hold one element, but now it holds %d", name,
+      arg_names.size());
+  return this->GetRepeatedDim(arg_names[0]);
+}
+
 DDim InferShapeContext::GetInputsElementDim(const std::string &name,
                                             int idx) const {
   const std::vector<std::string> &names = Inputs(name);
@@ -61,6 +71,7 @@ std::vector<DDim> InferShapeContext::GetDims(
       [this](const std::string &name) { return this->GetDim(name); });
   return ret;
 }
+
 void InferShapeContext::SetDims(const std::vector<std::string> &names,
                                 const std::vector<DDim> &dims) {
   size_t length = names.size();
@@ -72,14 +83,17 @@ void InferShapeContext::SetDims(const std::vector<std::string> &names,
     SetDim(names[i], dims[i]);
   }
 }
+
 std::vector<proto::VarDesc::VarType> InferShapeContext::GetInputsVarType(
     const std::string &name) const {
   return GetVarTypes(Inputs(name));
 }
+
 std::vector<proto::VarDesc::VarType> InferShapeContext::GetOutputsVarType(
     const std::string &name) const {
   return GetVarTypes(Outputs(name));
 }
+
 std::vector<proto::VarDesc::VarType> InferShapeContext::GetVarTypes(
     const std::vector<std::string> &names) const {
   std::vector<proto::VarDesc::VarType> retv;
paddle/framework/shape_inference.h

@@ -36,8 +36,8 @@ class InferShapeContext {
   virtual bool HasOutputs(const std::string &name) const = 0;
 
   DDim GetInputDim(const std::string &name) const;
   std::vector<DDim> GetInputsDim(const std::string &name) const;
+  std::vector<DDim> GetReaderDims(const std::string &name) const;
   DDim GetInputsElementDim(const std::string &name, int idx) const;
 
   void SetOutputDim(const std::string &name, const DDim &dim);
@@ -61,6 +61,7 @@ class InferShapeContext {
  protected:
   virtual DDim GetDim(const std::string &name) const = 0;
   virtual void SetDim(const std::string &name, const DDim &dim) = 0;
+  virtual std::vector<DDim> GetRepeatedDim(const std::string &name) const = 0;
 
   std::vector<DDim> GetDims(const std::vector<std::string> &names) const;
 
   std::vector<proto::VarDesc::VarType> GetVarTypes(
paddle/operators/read_op.cc (new file)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/framework/op_registry.h"
#include "paddle/framework/reader.h"

namespace paddle {
namespace operators {

class ReadInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Reader"),
                   "The ReadOp must take a reader as input.");
    PADDLE_ENFORCE(ctx->HasOutputs("Out"),
                   "The ReadOp should be assigned with output.");
    std::vector<framework::DDim> reader_dims = ctx->GetReaderDims("Reader");
    std::vector<std::string> out_names = ctx->Outputs("Out");
    PADDLE_ENFORCE_EQ(
        reader_dims.size(), out_names.size(),
        "The reader's dim number doesn't match the output number.");
    ctx->SetOutputsDim("Out", reader_dims);
  }
};

class ReadInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override {
    std::string reader_name = op_desc.Input("Reader")[0];
    std::vector<std::string> out_names = op_desc.Output("Out");
    framework::VarDesc* reader = block->FindVarRecursive(reader_name);
    auto dtypes = reader->GetDataTypes();
    PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
    for (size_t i = 0; i < dtypes.size(); ++i) {
      framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
      out.SetType(framework::proto::VarDesc::LOD_TENSOR);
      out.SetDataType(dtypes[i]);
    }
  }
};

class ReadOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;
  void Run(const framework::Scope& scope,
           const platform::Place& dev_place) const override {
    framework::ReaderHolder* reader =
        scope.FindVar(Input("Reader"))->GetMutable<framework::ReaderHolder>();
    if (!reader->HasNext()) {
      // what shall we do???
      return;
    }
    std::vector<std::string> out_arg_names = Outputs("Out");
    std::vector<framework::LoDTensor> ins;
    reader->ReadNext(&ins);
    PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
    for (size_t i = 0; i < ins.size(); ++i) {
      auto* out =
          scope.FindVar(out_arg_names[i])->GetMutable<framework::LoDTensor>();
      PADDLE_ENFORCE_EQ(ins[i].dims(), out->dims());
      out->ShareDataWith(ins[i]);
      out->set_lod(ins[i].lod());
    }
  }
};

class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ReadOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(op_proto, op_checker) {
    AddInput("Reader", "(ReaderHolder) The executed reader.");
    AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable();
    AddComment(R"DOC(
      Read Operator

      Execute a given reader once and output data.
    )DOC");
  }
};

}  // namespace operators
}  // namespace paddle
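The committed file ends at the namespace closers, so the operator is not yet registered. A hedged sketch of the registration that would normally wire these pieces together (the macro argument order here is an assumption, not part of this commit):

// Assumed, not in this commit: make "read" ops in a ProgramDesc resolve
// to ReadOp with its shape and var-type inference attached.
namespace ops = paddle::operators;
REGISTER_OPERATOR(read, ops::ReadOp, ops::ReadInferShape, ops::ReadOpMaker,
                  ops::ReadInferVarType);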