提交 1010e39b 编写于 作者: F fengjiayi

Add ReadOp

上级 6e6f5c7e
......@@ -116,7 +116,7 @@ message LoDTensorArrayDesc {
optional int32 lod_level = 2 [ default = 0 ];
}
message Reader { repeated LoDTensorDesc lod_tensor = 1; }
message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }
message VarDesc {
enum VarType {
......@@ -136,7 +136,7 @@ message VarDesc {
optional LoDTensorDesc lod_tensor = 4;
optional TensorDesc selected_rows = 5;
optional LoDTensorArrayDesc tensor_array = 6;
optional Reader reader = 7;
optional ReaderDesc reader = 7;
}
message BlockDesc {
......
......@@ -72,6 +72,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
void SetDim(const std::string &name, const DDim &dim) override;
std::vector<DDim> GetRepeatedDim(const std::string &name) const override;
const OpDesc &op_;
const BlockDesc &block_;
};
......@@ -457,22 +459,37 @@ const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
DDim res;
try {
auto shape = var->GetShape();
if (shape.empty()) {
return framework::make_ddim({0UL});
} else {
return framework::make_ddim(var->GetShape());
}
res = shape.empty() ? make_ddim({0UL}) : make_ddim(shape);
} catch (...) {
VLOG(5) << "GetDim of variable " << name << " error";
std::rethrow_exception(std::current_exception());
}
return res;
}
std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDim(
const std::string &name) const {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
std::vector<DDim> res;
try {
auto shapes = var->GetShapes();
for (const auto &s : shapes) {
res.push_back(s.empty() ? make_ddim({0UL}) : make_ddim(s));
}
} catch (...) {
VLOG(5) << "GetRepeatedDim of variable " << name << " error.";
std::rethrow_exception(std::current_exception());
}
return res;
}
void CompileTimeInferShapeContext::SetDim(const std::string &name,
const DDim &dim) {
block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim));
block_.FindVarRecursive(name)->SetShape(vectorize(dim));
}
bool CompileTimeInferShapeContext::IsRuntime() const { return false; }
......
......@@ -320,8 +320,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs",
name);
PADDLE_ENFORCE_EQ(length, 1UL,
"Input %s should not have more than one inputs", name);
auto ipt = ins[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
......@@ -333,8 +333,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs",
name);
PADDLE_ENFORCE_EQ(length, 1UL,
"Output %s should not have more than one inputs", name);
auto ipt = outs[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
......@@ -421,8 +421,22 @@ class RuntimeInferShapeContext : public InferShapeContext {
} else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims();
} else {
PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.",
name, var->Type().name());
PADDLE_THROW(
"Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
}
}
std::vector<DDim> GetRepeatedDim(const std::string& name) const override {
Variable* var = scope_.FindVar(name);
if (var->IsType<ReaderHolder>()) {
return var->Get<ReaderHolder>().shapes();
} else {
PADDLE_THROW(
"Only ReaderHolder support 'GetRepeatedDim', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
}
}
......
......@@ -25,13 +25,15 @@ DDim FileReader::shape(size_t idx) const {
return shapes_[idx];
}
std::vector<LoDTensor> ShuffleReader::ReadNext() {
void ShuffleReader::ReadNext(std::vector<LoDtensor>* out) {
if (iteration_pos_ >= buffer_.size()) {
// Reload buffer with new data
buffer_.clear();
buffer_.reverse(buffer_size_);
for (int i = 0; i < buffer_size_; ++i) {
if (reader_->HasNext()) {
buffer_.push_back(reader_->ReadNext());
buffer.push_back(std::vector<LoDTensor>());
reader_->ReadNext(&buffer.back());
} else {
break;
}
......@@ -39,29 +41,32 @@ std::vector<LoDTensor> ShuffleReader::ReadNext() {
std::random_shuffle(buffer_.begin(), buffer_.end());
iteration_pos_ = 0;
}
if (buffer_.empty()) {
std::vector<LoDTensor> empty_res;
return empty_res;
out->clear();
if (!buffer_.empty()) {
std::swap(*out, buffer_[iteration_pos_++]);
}
return buffer_[iteration_pos_++];
// if buffer_ is empty, the 'out' will return as an empty vector.
}
std::vector<LoDTensor> BatchReader::ReadNext() {
void BatchReader::ReadNext(std::vector<LoDtensor>* out) {
buffer_.clear();
buffer_.reserve(batch_size_);
for (int i = 0; i < batch_size_; ++i) {
if (reader_->HasNext()) {
buffer_.push_back(reader_->ReadNext());
buffer_.push_back(std::vector<LoDtensor>());
reader_->ReadNext(&buffer_.back());
} else {
break;
}
}
// Concat instances
std::vector<LoDTensor> res;
out.clear();
if (buffer_.empty()) {
return res;
// if buffer_ is empty, the 'out' will return as an empty vector.
return;
}
int out_num = buffer_[0].size();
res.reserve(out_num);
out->reserve(out_num);
for (int j = 0; j < out_num; ++j) {
// Merge shape and check data type
std::type_index batch_type = buffer_[0][j].type();
......@@ -76,9 +81,9 @@ std::vector<LoDTensor> BatchReader::ReadNext() {
batch_shape[0] += ins_shape[0];
}
LoDTensor out;
out.Resize(batch_shape);
out.mutable_data(platform::CPUPlace(), batch_type);
LoDTensor out_tensor;
out_tensor.Resize(batch_shape);
out_tensor.mutable_data(platform::CPUPlace(), batch_type);
int64_t dst_offset = 0;
// Merge lod and data
......@@ -102,15 +107,14 @@ std::vector<LoDTensor> BatchReader::ReadNext() {
top_level_lod.back() +
(ins_lod.empty() ? ins_shape[0] : (ins_lod[0].size() - 1)));
Tensor dst = out.Slice(dst_offset, dst_offset + ins_shape[0]);
Tensor dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]);
Copy(buffer_[i][j], platform::CPUPlace(), &dst);
dst_offset += ins_shape[0];
}
batch_lod.insert(batch_lod.begin(), top_level_lod);
out.set_lod(batch_lod);
res.push_back(out);
out_tensor.set_lod(batch_lod);
out->push_back(out_tensor);
}
return res;
}
} // namespace framework
} // namespace paddle
......@@ -15,14 +15,14 @@
#pragma once
#include "paddle/framework/ddim.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/lod_tensor_array.h"
namespace paddle {
namespace framework {
class ReaderBase {
public:
virtual std::vector<LoDTensor> ReadNext() = 0;
virtual void ReadNext(std::vector<LoDtensor>* out) = 0;
virtual bool HasNext() const = 0;
virtual DDim shape(size_t idx) const = 0;
......@@ -73,24 +73,24 @@ class RandomReader : public FileReader {
dist_ = std::uniform_real_distribution<float>(min_, max_);
}
std::vector<LoDTensor> ReadNext() override {
std::vector<LoDTensor> res;
res.reserve(shapes_.size());
void ReadNext(std::vector<LoDtensor>* out) override {
out.clear();
out.reserve(shapes_.size());
for (const DDim& shape : shapes_) {
PADDLE_ENFORCE_GE(
shape.size(), 2,
"The rank of input data should be 2 at least.(Now it's %d)",
"The rank of reader's output data should be 2 at least.(Now it's %d)",
shape.size());
LoDTensor out;
out.Resize(shape);
T* data = out.mutable_data<T>(platform::CPUPlace());
LoDTensor out_tensor;
out_tensor.Resize(shape);
T* data = out_tensor.mutable_data<T>(platform::CPUPlace());
int64_t numel = product(shape);
for (int64_t i = 0; i < numel; ++i) {
data[i] = dist_(engine_);
}
res.push_back(out);
out.push_back(out_tensor);
}
return res;
return out;
}
bool HasNext() const override { return true; }
......@@ -111,11 +111,11 @@ class ShuffleReader : public DecoratedReader {
buffer_.reserve(buffer_size);
}
std::vector<LoDTensor> ReadNext() override;
void ReadNext(std::vector<LoDtensor>* out) override;
private:
int buffer_size_;
std::vector<std::vector<LoDTensor>> buffer_;
std::vector<std::vector<LoDtensor>> buffer_;
size_t iteration_pos_;
};
......@@ -126,11 +126,11 @@ class BatchReader : public DecoratedReader {
buffer_.reserve(batch_size_);
}
std::vector<LoDTensor> ReadNext() override;
void ReadNext(std::vector<LoDtensor>* out) override;
private:
int batch_size_;
std::vector<std::vector<LoDTensor>> buffer_;
std::vector<std::vector<LoDtensor>> buffer_;
};
// The ReaderHolder is used as readers' unified wrapper,
......@@ -141,7 +141,7 @@ class ReaderHolder {
ReaderBase* Get() const { return reader_.get(); }
std::vector<LoDTensor> ReadNext() { return reader_->ReadNext(); }
void ReadNext(std::vector<LoDtensor>* out) { reader_->ReadNext(out); }
bool HasNext() const { return reader_->HasNext(); }
DDim shape(size_t idx) const { return reader_->shape(idx); }
......
......@@ -32,6 +32,16 @@ std::vector<DDim> InferShapeContext::GetInputsDim(
return GetDims(arg_names);
}
// Returns the dims reported for the reader bound to input slot `name`.
//
// The slot must resolve to exactly one argument name; otherwise
// PADDLE_ENFORCE_EQ raises an error. The lookup itself is delegated to the
// context-specific GetRepeatedDim implementation.
std::vector<DDim> InferShapeContext::GetReaderDims(
    const std::string &name) const {
  const std::vector<std::string> &arg_names = Inputs(name);
  PADDLE_ENFORCE_EQ(
      arg_names.size(), 1UL,
      "Reader input '%s' should hold one element, but now it holds %d", name,
      arg_names.size());
  // The virtual hook is declared as GetRepeatedDim (singular) on this class
  // and in both derived contexts.
  return this->GetRepeatedDim(arg_names[0]);
}
DDim InferShapeContext::GetInputsElementDim(const std::string &name,
int idx) const {
const std::vector<std::string> &names = Inputs(name);
......@@ -61,6 +71,7 @@ std::vector<DDim> InferShapeContext::GetDims(
[this](const std::string &name) { return this->GetDim(name); });
return ret;
}
void InferShapeContext::SetDims(const std::vector<std::string> &names,
const std::vector<DDim> &dims) {
size_t length = names.size();
......@@ -72,14 +83,17 @@ void InferShapeContext::SetDims(const std::vector<std::string> &names,
SetDim(names[i], dims[i]);
}
}
std::vector<proto::VarDesc::VarType> InferShapeContext::GetInputsVarType(
const std::string &name) const {
return GetVarTypes(Inputs(name));
}
std::vector<proto::VarDesc::VarType> InferShapeContext::GetOutputsVarType(
const std::string &name) const {
return GetVarTypes(Outputs(name));
}
std::vector<proto::VarDesc::VarType> InferShapeContext::GetVarTypes(
const std::vector<std::string> &names) const {
std::vector<proto::VarDesc::VarType> retv;
......
......@@ -36,8 +36,8 @@ class InferShapeContext {
virtual bool HasOutputs(const std::string &name) const = 0;
DDim GetInputDim(const std::string &name) const;
std::vector<DDim> GetInputsDim(const std::string &name) const;
std::vector<DDim> GetReaderDims(const std::string &name) const DDim;
DDim GetInputsElementDim(const std::string &name, int idx) const;
void SetOutputDim(const std::string &name, const DDim &dim);
......@@ -61,6 +61,7 @@ class InferShapeContext {
protected:
virtual DDim GetDim(const std::string &name) const = 0;
virtual void SetDim(const std::string &name, const DDim &dim) = 0;
std::vector<DDim> GetRepeatedDim(const std::string &name) const = 0;
std::vector<DDim> GetDims(const std::vector<std::string> &names) const;
std::vector<proto::VarDesc::VarType> GetVarTypes(
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/framework/op_registry.h"
#include "paddle/framework/reader.h"
namespace paddle {
namespace operators {
// Shape inference for ReadOp: copies the dims reported by the input reader
// onto the "Out" outputs, one dim per output.
class ReadInferShape : public framework::InferShapeBase {
 public:
  // Enforces that the "Reader" input and the "Out" outputs are present and
  // that the reader reports exactly as many dims as there are outputs.
  void operator()(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Reader"),
                   "The ReadOp must take a reader as input.");
    PADDLE_ENFORCE(ctx->HasOutputs("Out"),
                   "The ReadOp should be assigned with output.");
    // DDim lives in paddle::framework, so it must be qualified from
    // paddle::operators.
    std::vector<framework::DDim> reader_dims = ctx->GetReaderDims("Reader");
    std::vector<std::string> out_names = ctx->Outputs("Out");
    PADDLE_ENFORCE_EQ(
        reader_dims.size(), out_names.size(),
        "The reader's dim number doesn't match the output number.");
    ctx->SetOutputsDim("Out", reader_dims);
  }
};
// Variable-type inference for ReadOp: marks every "Out" variable as a
// LoDTensor and gives it the data type the reader reports at the same index.
class ReadInferVarType : public framework::VarTypeInference {
 public:
  // Raises (via PADDLE_ENFORCE*) when the reader variable cannot be found or
  // when the number of reader data types differs from the number of outputs.
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override {
    std::string reader_name = op_desc.Input("Reader")[0];
    std::vector<std::string> out_names = op_desc.Output("Out");
    // `block` is a pointer and FindVarRecursive yields a pointer that may be
    // null, so keep it as a pointer and check it before use.
    framework::VarDesc* reader = block->FindVarRecursive(reader_name);
    PADDLE_ENFORCE(reader != nullptr, "Cannot find variable %s", reader_name);
    auto dtypes = reader->GetDataTypes();
    PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
    for (size_t i = 0; i < dtypes.size(); ++i) {
      framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
      // LOD_TENSOR is a variable type (VarDesc::VarType), not a data type.
      out.SetType(framework::proto::VarDesc::LOD_TENSOR);
      out.SetDataType(dtypes[i]);
    }
  }
};
// Operator that runs the reader held by the "Reader" input once and hands the
// produced tensors to the "Out" variables.
class ReadOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

  // Looks up the reader in `scope`, reads one set of tensors from it and
  // shares their data and LoD with the output variables. `dev_place` is not
  // used by this operator.
  //
  // Raises (via PADDLE_ENFORCE*) when a variable is missing, when the input
  // is not a ReaderHolder, or when the tensor count or dims do not match the
  // outputs.
  void Run(const framework::Scope& scope,
           const platform::Place& dev_place) const override {
    const auto& reader_name = Input("Reader");
    auto* reader_var = scope.FindVar(reader_name);
    PADDLE_ENFORCE(reader_var != nullptr, "Cannot find variable %s",
                   reader_name);
    // ReadNext is a non-const member, so a mutable holder is needed. The type
    // is checked first to keep the guarantee the previous Get<> call implied.
    PADDLE_ENFORCE(reader_var->IsType<framework::ReaderHolder>(),
                   "Variable %s should hold a ReaderHolder.", reader_name);
    framework::ReaderHolder* reader =
        reader_var->GetMutable<framework::ReaderHolder>();
    if (!reader->HasNext()) {
      // TODO: the behavior for an exhausted reader is still undecided; for
      // now the outputs are left untouched.
      return;
    }
    std::vector<std::string> out_arg_names = Outputs("Out");
    std::vector<framework::LoDTensor> ins;
    reader->ReadNext(&ins);
    PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
    for (size_t i = 0; i < ins.size(); ++i) {
      auto* out_var = scope.FindVar(out_arg_names[i]);
      PADDLE_ENFORCE(out_var != nullptr, "Cannot find variable %s",
                     out_arg_names[i]);
      auto* out = out_var->GetMutable<framework::LoDTensor>();
      PADDLE_ENFORCE_EQ(ins[i].dims(), out->dims());
      out->ShareDataWith(ins[i]);
      out->set_lod(ins[i].lod());
    }
  }
};
// Declares the interface of ReadOp: one "Reader" input, a duplicable "Out"
// output and the operator's documentation string.
class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ReadOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(op_proto, op_checker) {
    AddInput("Reader", "(ReaderHolder) The executed reader.");
    AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable();
    // The statement terminator after the raw string was missing.
    AddComment(R"DOC(
Read Operator
Execute a given reader once and output data.
)DOC");
  }
};
} // namespace operators
} // namespace paddle
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册