Unverified Commit 95ba4bd2 authored by Huihuang Zheng, committed by GitHub

Add shape and type check at read_op (#20754)

Parent bb8d7783
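In outline, this change makes every ReaderBase carry the shapes, dtypes, and need_check_feed flags of the variables it feeds; a DecoratedReader copies that metadata from the reader it wraps, so the information survives arbitrary decoration and reaches read_op through ReaderHolder. Below is a minimal standalone sketch of that threading, with framework::DDim and proto::VarType::Type replaced by simple stand-ins; it illustrates the design rather than the actual Paddle classes.

#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

using Dim = std::vector<int64_t>;   // stand-in for framework::DDim
enum class DType { FP32, INT64 };   // stand-in for proto::VarType::Type

class ReaderBase {
 public:
  ReaderBase(std::vector<Dim> shapes, std::vector<DType> types,
             std::vector<bool> need_check)
      : shapes_(std::move(shapes)),
        types_(std::move(types)),
        need_check_(std::move(need_check)) {
    // Mirrors the PADDLE_ENFORCE_EQ size checks in the real constructor.
    assert(shapes_.size() == need_check_.size());
    assert(types_.size() == need_check_.size());
  }
  virtual ~ReaderBase() = default;

  const std::vector<Dim>& Shapes() const { return shapes_; }
  const std::vector<DType>& VarTypes() const { return types_; }
  const std::vector<bool>& NeedCheckFeed() const { return need_check_; }

 private:
  std::vector<Dim> shapes_;
  std::vector<DType> types_;
  std::vector<bool> need_check_;
};

// A decorated reader inherits the metadata of the reader it wraps.
class DecoratedReader : public ReaderBase {
 public:
  explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
      : ReaderBase(reader->Shapes(), reader->VarTypes(),
                   reader->NeedCheckFeed()),
        reader_(reader) {}

 private:
  std::shared_ptr<ReaderBase> reader_;
};

int main() {
  auto root = std::make_shared<ReaderBase>(
      std::vector<Dim>{{-1, 3}},        // batch dimension unknown
      std::vector<DType>{DType::FP32},
      std::vector<bool>{true});         // checking enabled, as with fluid.data
  DecoratedReader decorated(root);
  assert(decorated.Shapes() == root->Shapes());  // metadata survives decoration
  return 0;
}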
......@@ -20,6 +20,7 @@
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h"
......@@ -28,6 +29,20 @@ namespace framework {
class ReaderBase {
public:
explicit ReaderBase(const std::vector<DDim>& shapes,
const std::vector<proto::VarType::Type>& var_types,
const std::vector<bool>& need_check_feed)
: shapes_(shapes),
var_types_(var_types),
need_check_feed_(need_check_feed) {
PADDLE_ENFORCE_EQ(shapes_.size(), need_check_feed_.size(),
"Construct ReaderBase with mismatched sizes of shapes "
"and need_check_feed");
PADDLE_ENFORCE_EQ(var_types_.size(), need_check_feed_.size(),
"Construct ReaderBase with mismatched sizes of var_types "
"and need_check_feed");
}
virtual void ReadNext(std::vector<LoDTensor>* out);
virtual void Shutdown();
......@@ -38,6 +53,18 @@ class ReaderBase {
// they are readers just before read op.
std::unordered_set<ReaderBase*> GetEndPoints();
// Returns the shapes of the fed variables
const std::vector<DDim>& Shapes() const { return shapes_; }
// Returns the dtypes of the fed variables
const std::vector<proto::VarType::Type>& VarTypes() const {
return var_types_;
}
// For backward compatibility: variables created by the old fluid.layers.data
// API don't check shape. This function returns whether the shape and dtype
// check is enabled for each variable of this Reader.
const std::vector<bool>& NeedCheckFeed() const { return need_check_feed_; }
virtual ~ReaderBase();
protected:
......@@ -53,6 +80,17 @@ class ReaderBase {
mutable std::mutex mu_;
// The shapes of the fed variables.
std::vector<DDim> shapes_;
// The dtypes of the fed variables.
std::vector<proto::VarType::Type> var_types_;
// Whether to check the shape and dtype of the fed variables.
// For backward compatibility: variables created by the old API
// fluid.layers.data don't check shape, but those created by fluid.data do.
std::vector<bool> need_check_feed_;
private:
friend class DecoratedReader;
// These methods can only be invoked inside DecoratedReader to record the
......@@ -67,7 +105,9 @@ class DecoratedReader : public ReaderBase,
public std::enable_shared_from_this<DecoratedReader> {
public:
explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
: ReaderBase(), reader_(reader) {
: ReaderBase(reader->Shapes(), reader->VarTypes(),
reader->NeedCheckFeed()),
reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_);
}
......@@ -89,7 +129,13 @@ class DecoratedReader : public ReaderBase,
};
// FileReader is just a conceptual class.
class FileReader : public ReaderBase {};
class FileReader : public ReaderBase {
public:
explicit FileReader(const std::vector<DDim>& shapes,
const std::vector<proto::VarType::Type>& var_types,
const std::vector<bool>& need_check_feed)
: ReaderBase(shapes, var_types, need_check_feed) {}
};
// The ReaderHolder is used as the reader's unified wrapper,
// making it easier to access different types of readers in Variables.
......@@ -134,6 +180,16 @@ class ReaderHolder {
reader_->Start();
}
const std::vector<DDim>& Shapes() const { return reader_->Shapes(); }
const std::vector<proto::VarType::Type>& VarTypes() const {
return reader_->VarTypes();
}
const std::vector<bool>& NeedCheckFeed() const {
return reader_->NeedCheckFeed();
}
operator const std::shared_ptr<ReaderBase>&() const { return this->reader_; }
private:
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/reader.h"
#include <memory>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ddim.h"
class StubDecoratedReader : public paddle::framework::DecoratedReader {
public:
......@@ -26,11 +27,23 @@ class StubDecoratedReader : public paddle::framework::DecoratedReader {
class StubRootReader : public paddle::framework::ReaderBase {
public:
explicit StubRootReader(
const std::vector<paddle::framework::DDim> &dims,
const std::vector<paddle::framework::proto::VarType::Type> &var_types,
const std::vector<bool> &need_check_feed)
: paddle::framework::ReaderBase(dims, var_types, need_check_feed) {}
void ReadNextImpl(std::vector<paddle::framework::LoDTensor> *out) override {}
};
TEST(READER, decorate_chain) {
auto root = std::make_shared<StubRootReader>();
paddle::framework::proto::VarType::Type dtype =
paddle::framework::proto::VarType::FP32;
paddle::framework::DDim dim = paddle::framework::make_ddim({5, 7});
std::vector<paddle::framework::DDim> init_dims(4, dim);
std::vector<paddle::framework::proto::VarType::Type> init_types(4, dtype);
std::vector<bool> init_need_check(4, true);
auto root =
std::make_shared<StubRootReader>(init_dims, init_types, init_need_check);
auto end_point1 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
auto end_point2 =
......@@ -49,4 +62,10 @@ TEST(READER, decorate_chain) {
ASSERT_EQ(root->GetEndPoints().size(), 3U);
}
{ ASSERT_EQ(root->GetEndPoints().size(), 2U); }
{
ASSERT_EQ(end_point1->Shapes(), init_dims);
ASSERT_EQ(end_point1->VarTypes(), init_types);
ASSERT_EQ(end_point1->NeedCheckFeed(), init_need_check);
}
}
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/operators/reader/py_reader.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
......@@ -39,7 +41,38 @@ class CreatePyReaderOp : public framework::OperatorBase {
auto* queue_holder =
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
out->Reset(std::make_shared<PyReader>(queue_holder->GetQueue()));
/* Convert shape_concat and ranks into the DDim of each variable.
shape_concat and ranks are the concatenated shapes and the shape ranks of
all variables. E.g. shape_concat = [2,3,4,5,6], ranks = [3,2] means two
variables whose shapes are [2,3,4] and [5,6] respectively. */
auto& shape_concat = Attr<std::vector<int>>("shape_concat");
auto& ranks = Attr<std::vector<int>>("ranks");
int shape_start_index = 0;
std::vector<framework::DDim> dims;
for (size_t i = 0; i < ranks.size(); ++i) {
int shape_end_index = shape_start_index + ranks[i];
auto shape = std::vector<int>(shape_concat.begin() + shape_start_index,
shape_concat.begin() + shape_end_index);
dims.push_back(framework::make_ddim(shape));
shape_start_index = shape_end_index;
}
// Converts VarType from int to enum
auto& dtype_int = Attr<std::vector<int>>("dtypes");
std::vector<framework::proto::VarType::Type> var_types;
for (size_t i = 0; i < dtype_int.size(); ++i) {
var_types.push_back(
static_cast<framework::proto::VarType::Type>(dtype_int[i]));
}
// Converts need_check_feed from int to bool
auto& need_check_feed_int = Attr<std::vector<int>>("need_check_feed");
std::vector<bool> need_check_feed;
for (size_t i = 0; i < need_check_feed_int.size(); ++i) {
need_check_feed.push_back(static_cast<bool>(need_check_feed_int[i]));
}
out->Reset(std::make_shared<PyReader>(queue_holder->GetQueue(), dims,
var_types, need_check_feed));
}
};
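The comment in CreatePyReaderOp above describes how the flat shape_concat attribute, together with ranks, is split back into one shape per variable. A standalone sketch of that decoding, with std::vector<int> standing in for framework::DDim:

#include <cassert>
#include <vector>

int main() {
  // Two variables: the first has rank 3 and shape [2,3,4],
  // the second has rank 2 and shape [5,6].
  std::vector<int> shape_concat = {2, 3, 4, 5, 6};
  std::vector<int> ranks = {3, 2};

  std::vector<std::vector<int>> dims;
  int start = 0;
  for (size_t i = 0; i < ranks.size(); ++i) {
    int end = start + ranks[i];
    dims.emplace_back(shape_concat.begin() + start,
                      shape_concat.begin() + end);
    start = end;
  }

  assert((dims[0] == std::vector<int>{2, 3, 4}));
  assert((dims[1] == std::vector<int>{5, 6}));
  return 0;
}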
......
......@@ -19,8 +19,12 @@ namespace paddle {
namespace operators {
namespace reader {
PyReader::PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue)
: framework::FileReader() {
PyReader::PyReader(
const std::shared_ptr<LoDTensorBlockingQueue>& queue,
const std::vector<framework::DDim>& dims,
const std::vector<framework::proto::VarType::Type>& var_types,
const std::vector<bool>& need_check_feed)
: framework::FileReader(dims, var_types, need_check_feed) {
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
queue_ = queue;
}
......
......@@ -26,7 +26,11 @@ namespace reader {
class PyReader : public framework::FileReader {
public:
explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue);
explicit PyReader(
const std::shared_ptr<LoDTensorBlockingQueue>& queue,
const std::vector<framework::DDim>& dims,
const std::vector<framework::proto::VarType::Type>& var_types,
const std::vector<bool>& need_check_feed);
void ReadNext(std::vector<framework::LoDTensor>* out) override;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
......@@ -20,6 +21,26 @@
namespace paddle {
namespace operators {
// Returns true if the two dimensions are compatible.
// A dimension is compatible with the other if:
// 1. The lengths of the dimensions are the same.
// 2. Each non-negative number in the two dimensions is the same.
// 3. A negative number in a dimension means unknown, so it is compatible
// with any number.
bool DimensionIsCompatibleWith(const framework::DDim& first,
const framework::DDim& second) {
int dim_size = first.size();
if (dim_size != second.size()) {
return false;
}
for (int i = 0; i < dim_size; ++i) {
if (first[i] >= 0 && second[i] >= 0 && first[i] != second[i]) {
return false;
}
}
return true;
}
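A few concrete cases of the compatibility rule above, as a standalone sketch. The real function takes framework::DDim; plain std::vector<int64_t> is used here as an illustrative stand-in with the same logic.

#include <cassert>
#include <cstdint>
#include <vector>

// Same logic as DimensionIsCompatibleWith, over plain vectors.
bool Compatible(const std::vector<int64_t>& a, const std::vector<int64_t>& b) {
  if (a.size() != b.size()) return false;
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] >= 0 && b[i] >= 0 && a[i] != b[i]) return false;
  }
  return true;
}

int main() {
  assert(Compatible({-1, 3}, {32, 3}));    // -1 is unknown: any batch size fits
  assert(Compatible({5, 3}, {5, 3}));      // exact match
  assert(!Compatible({5, 3}, {5, 4}));     // mismatched extent
  assert(!Compatible({5, 3}, {5, 3, 1}));  // mismatched rank
  return 0;
}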
class ReadInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
......@@ -89,10 +110,32 @@ class ReadOp : public framework::OperatorBase {
VLOG(3) << "throw_eof_exp";
PADDLE_THROW_EOF();
}
PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size(),
"input size and output size of read_op do not match");
const std::vector<framework::DDim>& shapes = reader->Shapes();
const std::vector<framework::proto::VarType::Type>& var_types =
reader->VarTypes();
const std::vector<bool>& need_check_feed = reader->NeedCheckFeed();
PADDLE_ENFORCE_EQ(out_arg_names.size(), need_check_feed.size(),
"output size of read_op and the number of fed "
"variables of reader do not match");
for (size_t i = 0; i < out_arg_names.size(); ++i) {
auto* out =
scope.FindVar(out_arg_names[i])->GetMutable<framework::LoDTensor>();
if (need_check_feed[i]) {
auto in_dims = ins[i].dims();
PADDLE_ENFORCE_EQ(DimensionIsCompatibleWith(shapes[i], in_dims), true,
"The feeded Variable %s should have dimensions = %d, "
"shape = [%s], but received feeded shape [%s]",
out_arg_names[i], shapes[i].size(), shapes[i],
in_dims);
PADDLE_ENFORCE_EQ(
ins[i].type(), var_types[i],
"The data type of feeded Variable %s must be %s, but received %s",
out_arg_names[i], var_types[i], ins[i].type());
}
out->ShareDataWith(ins[i]);
out->set_lod(ins[i].lod());
}
......
......@@ -50,6 +50,10 @@ void FileReaderMakerBase::Make() {
"It means the reader will generate two data each time,"
"whose shapes are [2,3,4] and [5,6] respectively.");
AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
AddAttr<std::vector<int>>("dtypes",
"The int value of enum dtypes of each data.");
AddAttr<std::vector<int>>("need_check_feed",
"Whether to check shape and dtypes of input");
AddAttr<bool>(
"use_data_config",
"Use the config of all datas like shape_concat/ranks/lod_levels")
......@@ -77,6 +81,17 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
"The number of 'lod_levels'(%d) doesn't match the number "
"of 'shapes'(%d).",
lod_levels.size(), shapes.size());
const auto dtypes = ctx->Attrs().Get<std::vector<int>>("dtypes");
PADDLE_ENFORCE_EQ(
dtypes.size(), shapes.size(),
"The number of 'dtypes'(%d) doesn't match the number of 'shapes'(%d).",
dtypes.size(), shapes.size());
const auto need_check_feed =
ctx->Attrs().Get<std::vector<int>>("need_check_feed");
PADDLE_ENFORCE_EQ(need_check_feed.size(), shapes.size(),
"The number of 'need_check_feed'(%d) doesn't match the "
"number of 'shapes'(%d).",
need_check_feed.size(), shapes.size());
framework::VarDesc* reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
reader->SetLoDLevels(lod_levels);
......
......@@ -20,6 +20,7 @@
#include <utility>
#include <vector>
#include "Python.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/operators/reader/buffered_reader.h"
#include "paddle/fluid/operators/reader/py_reader.h"
......@@ -40,12 +41,19 @@ class MultiDeviceFeedReader {
MultiDeviceFeedReader(
const std::shared_ptr<operators::reader::LoDTensorBlockingQueue> &queue,
const std::vector<std::string> &names,
const std::vector<std::vector<int>> &shapes,
const std::vector<framework::proto::VarType::Type> &dtypes,
const std::vector<bool> &need_check_feed,
const std::vector<platform::Place> &dst_places, bool use_double_buffer)
: queue_(queue),
names_(names),
pool_(new ::ThreadPool(dst_places.size())) {
std::vector<framework::DDim> dims;
for (auto &shape : shapes) {
dims.push_back(framework::make_ddim(shape));
}
std::shared_ptr<framework::ReaderBase> reader(
new operators::reader::PyReader(queue));
new operators::reader::PyReader(queue, dims, dtypes, need_check_feed));
readers_.reserve(dst_places.size());
for (auto &p : dst_places) {
......@@ -206,9 +214,13 @@ void BindReader(py::module *module) {
[](const std::shared_ptr<operators::reader::LoDTensorBlockingQueue>
&queue,
const std::vector<std::string> &names,
const std::vector<std::vector<int>> &shapes,
const std::vector<framework::proto::VarType::Type> &dtypes,
const std::vector<bool> &need_check_feed,
const std::vector<platform::Place> &dst_places,
bool use_double_buffer) {
return new MultiDeviceFeedReader(queue, names, dst_places,
return new MultiDeviceFeedReader(queue, names, shapes, dtypes,
need_check_feed, dst_places,
use_double_buffer);
},
py::return_value_policy::take_ownership);
......
......@@ -385,6 +385,7 @@ def _py_reader(capacity,
shape_concat = []
ranks = []
shapes = []
need_check_feed = []
for feed_data in feed_list:
dtypes.append(feed_data.dtype)
......@@ -392,8 +393,10 @@ def _py_reader(capacity,
ranks.append(len(feed_data.shape))
shapes.append(feed_data.shape)
lod_levels.append(feed_data.lod_level)
need_check_feed.append(int(feed_data.desc.need_check_feed()))
else:
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
need_check_feed = [0 for dt in dtypes]
shape_concat = []
ranks = []
......@@ -403,7 +406,7 @@ def _py_reader(capacity,
if lod_levels is None:
lod_levels = [0] * len(shapes)
dtype_int = [int(t) for t in dtypes]
if name is None:
queue_name = unique_name('lod_tensor_blocking_queue')
reader_name = unique_name('create_py_reader')
......@@ -425,6 +428,8 @@ def _py_reader(capacity,
attrs={
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'dtypes': dtype_int,
'need_check_feed': need_check_feed,
'ranks': ranks
})
......
......@@ -331,7 +331,7 @@ class GeneratorLoader(DataLoaderBase):
self._init_non_iterable()
def _wait_thread_ends(self):
# Get self._thread first to prevent data race, because __thread_main__
# would set self._thread be None at the end
thread = self._thread
if thread is not None and self._iterable:
......@@ -342,12 +342,21 @@ class GeneratorLoader(DataLoaderBase):
self._wait_thread_ends()
if in_dygraph_mode():
self._var_names = []
self._shapes = []
self._dtypes = []
self._need_check_feed = []
else:
self._var_names = [v.name for v in self._feed_list]
self._shapes = [v.shape for v in self._feed_list]
self._dtypes = [v.dtype for v in self._feed_list]
self._need_check_feed = [
v.desc.need_check_feed() for v in self._feed_list
]
self._queue = core.init_lod_tensor_blocking_queue(core.Variable(),
self._capacity)
self._reader = core.create_py_reader(
self.queue, self._var_names, self._places, self._use_double_buffer)
self.queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_double_buffer)
def _init_non_iterable(self):
lod_levels = []
......@@ -355,6 +364,7 @@ class GeneratorLoader(DataLoaderBase):
shape_concat = []
ranks = []
shapes = []
need_check_feed = []
for feed_data in self._feed_list:
dtypes.append(feed_data.dtype)
......@@ -362,6 +372,7 @@ class GeneratorLoader(DataLoaderBase):
ranks.append(len(feed_data.shape))
shapes.append(feed_data.shape)
lod_levels.append(feed_data.lod_level)
need_check_feed.append(int(feed_data.desc.need_check_feed()))
queue_name = data_loader_unique_name_generator(
'lod_tensor_blocking_queue')
......@@ -374,6 +385,7 @@ class GeneratorLoader(DataLoaderBase):
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=reader_name)
dtype_int = [int(t) for t in dtypes]
startup_blk.append_op(
type='create_py_reader',
inputs={'blocking_queue': [queue_name]},
......@@ -381,6 +393,8 @@ class GeneratorLoader(DataLoaderBase):
attrs={
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'dtypes': dtype_int,
'need_check_feed': need_check_feed,
'ranks': ranks
})
......
......@@ -73,8 +73,7 @@ class TestPyReader(unittest.TestCase):
for _ in range(10):
sample = np.random.uniform(
low=0, high=1, size=[3, 2, 1]).astype("float32")
label = np.random.uniform(
low=0, high=10, size=[1]).astype("int64")
label = np.random.randint(low=0, high=10, dtype="int64")
self.inputs.append((sample, label))
self.input_tensors = []
......