未验证 提交 95ba4bd2 编写于 作者: H Huihuang Zheng 提交者: GitHub

Add shape and type check at read_op (#20754)

上级 bb8d7783
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
...@@ -28,6 +29,20 @@ namespace framework { ...@@ -28,6 +29,20 @@ namespace framework {
class ReaderBase { class ReaderBase {
public: public:
explicit ReaderBase(const std::vector<DDim>& shapes,
const std::vector<proto::VarType::Type>& var_types,
const std::vector<bool>& need_check_feed)
: shapes_(shapes),
var_types_(var_types),
need_check_feed_(need_check_feed) {
PADDLE_ENFORCE_EQ(shapes_.size(), need_check_feed_.size(),
"Construct ReaderBase with mismatched sizes of shapes "
"and need_check_feed");
PADDLE_ENFORCE_EQ(var_types_.size(), need_check_feed_.size(),
"Construct ReaderBase with mismatched sizes of var_types "
"and need_check_feed");
}
virtual void ReadNext(std::vector<LoDTensor>* out); virtual void ReadNext(std::vector<LoDTensor>* out);
virtual void Shutdown(); virtual void Shutdown();
...@@ -38,6 +53,18 @@ class ReaderBase { ...@@ -38,6 +53,18 @@ class ReaderBase {
// they are readers just before read op. // they are readers just before read op.
std::unordered_set<ReaderBase*> GetEndPoints(); std::unordered_set<ReaderBase*> GetEndPoints();
// Returns the shapes of the feeded variables
const std::vector<DDim>& Shapes() const { return shapes_; }
// Returns the dtypes of the feeded variables
const std::vector<proto::VarType::Type>& VarTypes() const {
return var_types_;
}
// For Backward compatibility, old fluid.layers.data doesn't check shape.
// This function returns whether you have the check shape for this Reader.
const std::vector<bool>& NeedCheckFeed() const { return need_check_feed_; }
virtual ~ReaderBase(); virtual ~ReaderBase();
protected: protected:
...@@ -53,6 +80,17 @@ class ReaderBase { ...@@ -53,6 +80,17 @@ class ReaderBase {
mutable std::mutex mu_; mutable std::mutex mu_;
// The shapes of the feeded variables.
std::vector<DDim> shapes_;
// The dtypes of the feeded variables.
std::vector<proto::VarType::Type> var_types_;
// Whether to check the shape and dtype of feeded variables.
// For Backward compatibility, variables created by old API fluid.layers.data
// doesn't check shape but fluid.data checks.
std::vector<bool> need_check_feed_;
private: private:
friend class DecoratedReader; friend class DecoratedReader;
// These methods can be only invoked inside DecoratedReader to record the // These methods can be only invoked inside DecoratedReader to record the
...@@ -67,7 +105,9 @@ class DecoratedReader : public ReaderBase, ...@@ -67,7 +105,9 @@ class DecoratedReader : public ReaderBase,
public std::enable_shared_from_this<DecoratedReader> { public std::enable_shared_from_this<DecoratedReader> {
public: public:
explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader) explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
: ReaderBase(), reader_(reader) { : ReaderBase(reader->Shapes(), reader->VarTypes(),
reader->NeedCheckFeed()),
reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
} }
...@@ -89,7 +129,13 @@ class DecoratedReader : public ReaderBase, ...@@ -89,7 +129,13 @@ class DecoratedReader : public ReaderBase,
}; };
// FileReader is just a conceptual class. // FileReader is just a conceptual class.
class FileReader : public ReaderBase {}; class FileReader : public ReaderBase {
public:
explicit FileReader(const std::vector<DDim>& shapes,
const std::vector<proto::VarType::Type>& var_types,
const std::vector<bool>& need_check_feed)
: ReaderBase(shapes, var_types, need_check_feed) {}
};
// The ReaderHolder is used as reader' unified wrapper, // The ReaderHolder is used as reader' unified wrapper,
// making it easier to access different type reader in Variables. // making it easier to access different type reader in Variables.
...@@ -134,6 +180,16 @@ class ReaderHolder { ...@@ -134,6 +180,16 @@ class ReaderHolder {
reader_->Start(); reader_->Start();
} }
const std::vector<DDim>& Shapes() const { return reader_->Shapes(); }
const std::vector<proto::VarType::Type>& VarTypes() const {
return reader_->VarTypes();
}
const std::vector<bool>& NeedCheckFeed() const {
return reader_->NeedCheckFeed();
}
operator const std::shared_ptr<ReaderBase>&() const { return this->reader_; } operator const std::shared_ptr<ReaderBase>&() const { return this->reader_; }
private: private:
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include <memory> #include <memory>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/ddim.h"
class StubDecoratedReader : public paddle::framework::DecoratedReader { class StubDecoratedReader : public paddle::framework::DecoratedReader {
public: public:
...@@ -26,11 +27,23 @@ class StubDecoratedReader : public paddle::framework::DecoratedReader { ...@@ -26,11 +27,23 @@ class StubDecoratedReader : public paddle::framework::DecoratedReader {
class StubRootReader : public paddle::framework::ReaderBase { class StubRootReader : public paddle::framework::ReaderBase {
public: public:
explicit StubRootReader(
const std::vector<paddle::framework::DDim> &dims,
const std::vector<paddle::framework::proto::VarType::Type> &var_types,
const std::vector<bool> &need_check_feed)
: paddle::framework::ReaderBase(dims, var_types, need_check_feed) {}
void ReadNextImpl(std::vector<paddle::framework::LoDTensor> *out) override {} void ReadNextImpl(std::vector<paddle::framework::LoDTensor> *out) override {}
}; };
TEST(READER, decorate_chain) { TEST(READER, decorate_chain) {
auto root = std::make_shared<StubRootReader>(); paddle::framework::proto::VarType::Type dtype =
paddle::framework::proto::VarType::FP32;
paddle::framework::DDim dim = paddle::framework::make_ddim({5, 7});
std::vector<paddle::framework::DDim> init_dims(4, dim);
std::vector<paddle::framework::proto::VarType::Type> init_types(4, dtype);
std::vector<bool> init_need_check(4, true);
auto root =
std::make_shared<StubRootReader>(init_dims, init_types, init_need_check);
auto end_point1 = auto end_point1 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root); paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
auto end_point2 = auto end_point2 =
...@@ -49,4 +62,10 @@ TEST(READER, decorate_chain) { ...@@ -49,4 +62,10 @@ TEST(READER, decorate_chain) {
ASSERT_EQ(root->GetEndPoints().size(), 3U); ASSERT_EQ(root->GetEndPoints().size(), 3U);
} }
{ ASSERT_EQ(root->GetEndPoints().size(), 2U); } { ASSERT_EQ(root->GetEndPoints().size(), 2U); }
{
ASSERT_EQ(end_point1->Shapes(), init_dims);
ASSERT_EQ(end_point1->VarTypes(), init_types);
ASSERT_EQ(end_point1->NeedCheckFeed(), init_need_check);
}
} }
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/operators/reader/py_reader.h" #include "paddle/fluid/operators/reader/py_reader.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h" #include "paddle/fluid/operators/reader/reader_op_registry.h"
...@@ -39,7 +41,38 @@ class CreatePyReaderOp : public framework::OperatorBase { ...@@ -39,7 +41,38 @@ class CreatePyReaderOp : public framework::OperatorBase {
auto* queue_holder = auto* queue_holder =
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>(); queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
out->Reset(std::make_shared<PyReader>(queue_holder->GetQueue())); /* Coverting shape_concat and ranks into DDim of each data.
shape_concat and ranks are shapes and shape ranks of each data.E.g.
shape_concat = [2,3,4,5,6], ranks = [3,2] means two data whose shapes are
[2,3,4] and [5,6] respectively. */
auto& shape_concat = Attr<std::vector<int>>("shape_concat");
auto& ranks = Attr<std::vector<int>>("ranks");
int shape_start_index = 0;
std::vector<framework::DDim> dims;
for (size_t i = 0; i < ranks.size(); ++i) {
int shape_end_index = shape_start_index + ranks[i];
auto shape = std::vector<int>(shape_concat.begin() + shape_start_index,
shape_concat.begin() + shape_end_index);
dims.push_back(framework::make_ddim(shape));
shape_start_index = shape_end_index;
}
// Converts VarType from int to enum
auto& dtype_int = Attr<std::vector<int>>("dtypes");
std::vector<framework::proto::VarType::Type> var_types;
for (size_t i = 0; i < dtype_int.size(); ++i) {
var_types.push_back(
static_cast<framework::proto::VarType::Type>(dtype_int[i]));
}
// Converts need_check_feed from int to bool
auto& need_check_feed_int = Attr<std::vector<int>>("need_check_feed");
std::vector<bool> need_check_feed;
for (size_t i = 0; i < need_check_feed_int.size(); ++i) {
need_check_feed.push_back(static_cast<bool>(need_check_feed_int[i]));
}
out->Reset(std::make_shared<PyReader>(queue_holder->GetQueue(), dims,
var_types, need_check_feed));
} }
}; };
......
...@@ -19,8 +19,12 @@ namespace paddle { ...@@ -19,8 +19,12 @@ namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
PyReader::PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue) PyReader::PyReader(
: framework::FileReader() { const std::shared_ptr<LoDTensorBlockingQueue>& queue,
const std::vector<framework::DDim>& dims,
const std::vector<framework::proto::VarType::Type>& var_types,
const std::vector<bool>& need_check_feed)
: framework::FileReader(dims, var_types, need_check_feed) {
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
queue_ = queue; queue_ = queue;
} }
......
...@@ -26,7 +26,11 @@ namespace reader { ...@@ -26,7 +26,11 @@ namespace reader {
class PyReader : public framework::FileReader { class PyReader : public framework::FileReader {
public: public:
explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue); explicit PyReader(
const std::shared_ptr<LoDTensorBlockingQueue>& queue,
const std::vector<framework::DDim>& dims,
const std::vector<framework::proto::VarType::Type>& var_types,
const std::vector<bool>& need_check_feed);
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNext(std::vector<framework::LoDTensor>* out) override;
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/detail/safe_ref.h"
...@@ -20,6 +21,26 @@ ...@@ -20,6 +21,26 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// Returns true if the two dimensions are compatible.
// A dimension is compatible with the other if:
// 1. The length of the dimensions are same.
// 2. Each non-negative number of the two dimentions are same.
// 3. For negative number in a dimention, it means unknown so it is compatible
// with any number.
bool DimensionIsCompatibleWith(const framework::DDim& first,
const framework::DDim& second) {
int dim_size = first.size();
if (dim_size != second.size()) {
return false;
}
for (int i = 0; i < dim_size; ++i) {
if (first[i] >= 0 && second[i] >= 0 && first[i] != second[i]) {
return false;
}
}
return true;
}
class ReadInferShape : public framework::InferShapeBase { class ReadInferShape : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext* ctx) const override { void operator()(framework::InferShapeContext* ctx) const override {
...@@ -89,10 +110,32 @@ class ReadOp : public framework::OperatorBase { ...@@ -89,10 +110,32 @@ class ReadOp : public framework::OperatorBase {
VLOG(3) << "throw_eof_exp"; VLOG(3) << "throw_eof_exp";
PADDLE_THROW_EOF(); PADDLE_THROW_EOF();
} }
PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size()); PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size(),
"input size and output size of read_op do not match");
const std::vector<framework::DDim>& shapes = reader->Shapes();
const std::vector<framework::proto::VarType::Type>& var_types =
reader->VarTypes();
const std::vector<bool>& need_check_feed = reader->NeedCheckFeed();
PADDLE_ENFORCE_EQ(out_arg_names.size(), need_check_feed.size(),
"output size of read_op and the number of feeded "
"variables of reader do not match");
for (size_t i = 0; i < out_arg_names.size(); ++i) { for (size_t i = 0; i < out_arg_names.size(); ++i) {
auto* out = auto* out =
scope.FindVar(out_arg_names[i])->GetMutable<framework::LoDTensor>(); scope.FindVar(out_arg_names[i])->GetMutable<framework::LoDTensor>();
if (need_check_feed[i]) {
auto in_dims = ins[i].dims();
PADDLE_ENFORCE_EQ(DimensionIsCompatibleWith(shapes[i], in_dims), true,
"The feeded Variable %s should have dimensions = %d, "
"shape = [%s], but received feeded shape [%s]",
out_arg_names[i], shapes[i].size(), shapes[i],
in_dims);
PADDLE_ENFORCE_EQ(
ins[i].type(), var_types[i],
"The data type of feeded Variable %s must be %s, but received %s",
out_arg_names[i], var_types[i], ins[i].type());
}
out->ShareDataWith(ins[i]); out->ShareDataWith(ins[i]);
out->set_lod(ins[i].lod()); out->set_lod(ins[i].lod());
} }
......
...@@ -50,6 +50,10 @@ void FileReaderMakerBase::Make() { ...@@ -50,6 +50,10 @@ void FileReaderMakerBase::Make() {
"It means the reader will generate two data each time," "It means the reader will generate two data each time,"
"whose shapes are [2,3,4] and [5,6] respectively."); "whose shapes are [2,3,4] and [5,6] respectively.");
AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data."); AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
AddAttr<std::vector<int>>("dtypes",
"The int value of enum dtypes of each data.");
AddAttr<std::vector<int>>("need_check_feed",
"Whether to check shape and dtypes of input");
AddAttr<bool>( AddAttr<bool>(
"use_data_config", "use_data_config",
"Use the config of all datas like shape_concat/ranks/lod_levels") "Use the config of all datas like shape_concat/ranks/lod_levels")
...@@ -77,6 +81,17 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const { ...@@ -77,6 +81,17 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
"The number of 'lod_levels'(%d) doesn't match the number " "The number of 'lod_levels'(%d) doesn't match the number "
"of 'shapes'(%d).", "of 'shapes'(%d).",
lod_levels.size(), shapes.size()); lod_levels.size(), shapes.size());
const auto dtypes = ctx->Attrs().Get<std::vector<int>>("dtypes");
PADDLE_ENFORCE_EQ(
dtypes.size(), shapes.size(),
"The number of 'dtypes'(%d) doesn't match the number of 'shapes'(%d).",
dtypes.size(), shapes.size());
const auto need_check_feed =
ctx->Attrs().Get<std::vector<int>>("need_check_feed");
PADDLE_ENFORCE_EQ(need_check_feed.size(), shapes.size(),
"The number of 'need_check_feed'(%d) doesn't match the "
"number of 'shapes'(%d).",
need_check_feed.size(), shapes.size());
framework::VarDesc* reader = framework::VarDesc* reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]); boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
reader->SetLoDLevels(lod_levels); reader->SetLoDLevels(lod_levels);
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "Python.h" #include "Python.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/operators/reader/buffered_reader.h" #include "paddle/fluid/operators/reader/buffered_reader.h"
#include "paddle/fluid/operators/reader/py_reader.h" #include "paddle/fluid/operators/reader/py_reader.h"
...@@ -40,12 +41,19 @@ class MultiDeviceFeedReader { ...@@ -40,12 +41,19 @@ class MultiDeviceFeedReader {
MultiDeviceFeedReader( MultiDeviceFeedReader(
const std::shared_ptr<operators::reader::LoDTensorBlockingQueue> &queue, const std::shared_ptr<operators::reader::LoDTensorBlockingQueue> &queue,
const std::vector<std::string> &names, const std::vector<std::string> &names,
const std::vector<std::vector<int>> &shapes,
const std::vector<framework::proto::VarType::Type> &dtypes,
const std::vector<bool> &need_check_feed,
const std::vector<platform::Place> &dst_places, bool use_double_buffer) const std::vector<platform::Place> &dst_places, bool use_double_buffer)
: queue_(queue), : queue_(queue),
names_(names), names_(names),
pool_(new ::ThreadPool(dst_places.size())) { pool_(new ::ThreadPool(dst_places.size())) {
std::vector<framework::DDim> dims;
for (auto &shape : shapes) {
dims.push_back(framework::make_ddim(shape));
}
std::shared_ptr<framework::ReaderBase> reader( std::shared_ptr<framework::ReaderBase> reader(
new operators::reader::PyReader(queue)); new operators::reader::PyReader(queue, dims, dtypes, need_check_feed));
readers_.reserve(dst_places.size()); readers_.reserve(dst_places.size());
for (auto &p : dst_places) { for (auto &p : dst_places) {
...@@ -206,9 +214,13 @@ void BindReader(py::module *module) { ...@@ -206,9 +214,13 @@ void BindReader(py::module *module) {
[](const std::shared_ptr<operators::reader::LoDTensorBlockingQueue> [](const std::shared_ptr<operators::reader::LoDTensorBlockingQueue>
&queue, &queue,
const std::vector<std::string> &names, const std::vector<std::string> &names,
const std::vector<std::vector<int>> &shapes,
const std::vector<framework::proto::VarType::Type> &dtypes,
const std::vector<bool> &need_check_feed,
const std::vector<platform::Place> &dst_places, const std::vector<platform::Place> &dst_places,
bool use_double_buffer) { bool use_double_buffer) {
return new MultiDeviceFeedReader(queue, names, dst_places, return new MultiDeviceFeedReader(queue, names, shapes, dtypes,
need_check_feed, dst_places,
use_double_buffer); use_double_buffer);
}, },
py::return_value_policy::take_ownership); py::return_value_policy::take_ownership);
......
...@@ -385,6 +385,7 @@ def _py_reader(capacity, ...@@ -385,6 +385,7 @@ def _py_reader(capacity,
shape_concat = [] shape_concat = []
ranks = [] ranks = []
shapes = [] shapes = []
need_check_feed = []
for feed_data in feed_list: for feed_data in feed_list:
dtypes.append(feed_data.dtype) dtypes.append(feed_data.dtype)
...@@ -392,8 +393,10 @@ def _py_reader(capacity, ...@@ -392,8 +393,10 @@ def _py_reader(capacity,
ranks.append(len(feed_data.shape)) ranks.append(len(feed_data.shape))
shapes.append(feed_data.shape) shapes.append(feed_data.shape)
lod_levels.append(feed_data.lod_level) lod_levels.append(feed_data.lod_level)
need_check_feed.append(int(feed_data.desc.need_check_feed()))
else: else:
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes] dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
need_check_feed = [0 for dt in dtypes]
shape_concat = [] shape_concat = []
ranks = [] ranks = []
...@@ -403,7 +406,7 @@ def _py_reader(capacity, ...@@ -403,7 +406,7 @@ def _py_reader(capacity,
if lod_levels is None: if lod_levels is None:
lod_levels = [0] * len(shapes) lod_levels = [0] * len(shapes)
dtype_int = [int(t) for t in dtypes]
if name is None: if name is None:
queue_name = unique_name('lod_tensor_blocking_queue') queue_name = unique_name('lod_tensor_blocking_queue')
reader_name = unique_name('create_py_reader') reader_name = unique_name('create_py_reader')
...@@ -425,6 +428,8 @@ def _py_reader(capacity, ...@@ -425,6 +428,8 @@ def _py_reader(capacity,
attrs={ attrs={
'shape_concat': shape_concat, 'shape_concat': shape_concat,
'lod_levels': lod_levels, 'lod_levels': lod_levels,
'dtypes': dtype_int,
'need_check_feed': need_check_feed,
'ranks': ranks 'ranks': ranks
}) })
......
...@@ -342,12 +342,21 @@ class GeneratorLoader(DataLoaderBase): ...@@ -342,12 +342,21 @@ class GeneratorLoader(DataLoaderBase):
self._wait_thread_ends() self._wait_thread_ends()
if in_dygraph_mode(): if in_dygraph_mode():
self._var_names = [] self._var_names = []
self._shapes = []
self._dtypes = []
self._need_check_feed = []
else: else:
self._var_names = [v.name for v in self._feed_list] self._var_names = [v.name for v in self._feed_list]
self._shapes = [v.shape for v in self._feed_list]
self._dtypes = [v.dtype for v in self._feed_list]
self._need_check_feed = [
v.desc.need_check_feed() for v in self._feed_list
]
self._queue = core.init_lod_tensor_blocking_queue(core.Variable(), self._queue = core.init_lod_tensor_blocking_queue(core.Variable(),
self._capacity) self._capacity)
self._reader = core.create_py_reader( self._reader = core.create_py_reader(
self.queue, self._var_names, self._places, self._use_double_buffer) self.queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_double_buffer)
def _init_non_iterable(self): def _init_non_iterable(self):
lod_levels = [] lod_levels = []
...@@ -355,6 +364,7 @@ class GeneratorLoader(DataLoaderBase): ...@@ -355,6 +364,7 @@ class GeneratorLoader(DataLoaderBase):
shape_concat = [] shape_concat = []
ranks = [] ranks = []
shapes = [] shapes = []
need_check_feed = []
for feed_data in self._feed_list: for feed_data in self._feed_list:
dtypes.append(feed_data.dtype) dtypes.append(feed_data.dtype)
...@@ -362,6 +372,7 @@ class GeneratorLoader(DataLoaderBase): ...@@ -362,6 +372,7 @@ class GeneratorLoader(DataLoaderBase):
ranks.append(len(feed_data.shape)) ranks.append(len(feed_data.shape))
shapes.append(feed_data.shape) shapes.append(feed_data.shape)
lod_levels.append(feed_data.lod_level) lod_levels.append(feed_data.lod_level)
need_check_feed.append(int(feed_data.desc.need_check_feed()))
queue_name = data_loader_unique_name_generator( queue_name = data_loader_unique_name_generator(
'lod_tensor_blocking_queue') 'lod_tensor_blocking_queue')
...@@ -374,6 +385,7 @@ class GeneratorLoader(DataLoaderBase): ...@@ -374,6 +385,7 @@ class GeneratorLoader(DataLoaderBase):
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=reader_name) startup_var = startup_blk.create_var(name=reader_name)
dtype_int = [int(t) for t in dtypes]
startup_blk.append_op( startup_blk.append_op(
type='create_py_reader', type='create_py_reader',
inputs={'blocking_queue': [queue_name]}, inputs={'blocking_queue': [queue_name]},
...@@ -381,6 +393,8 @@ class GeneratorLoader(DataLoaderBase): ...@@ -381,6 +393,8 @@ class GeneratorLoader(DataLoaderBase):
attrs={ attrs={
'shape_concat': shape_concat, 'shape_concat': shape_concat,
'lod_levels': lod_levels, 'lod_levels': lod_levels,
'dtypes': dtype_int,
'need_check_feed': need_check_feed,
'ranks': ranks 'ranks': ranks
}) })
......
...@@ -73,8 +73,7 @@ class TestPyReader(unittest.TestCase): ...@@ -73,8 +73,7 @@ class TestPyReader(unittest.TestCase):
for _ in range(10): for _ in range(10):
sample = np.random.uniform( sample = np.random.uniform(
low=0, high=1, size=[3, 2, 1]).astype("float32") low=0, high=1, size=[3, 2, 1]).astype("float32")
label = np.random.uniform( label = np.random.randint(low=0, high=10, dtype="int64")
low=0, high=10, size=[1]).astype("int64")
self.inputs.append((sample, label)) self.inputs.append((sample, label))
self.input_tensors = [] self.input_tensors = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册