From d752177b8f1fafe3588fe7f77a4960813f1bab4f Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Mon, 7 Jan 2019 08:57:21 +0000 Subject: [PATCH] enforce_dim_check_in_data_feeder test=develop --- python/paddle/fluid/data_feeder.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py index c280ff21e..130152591 100644 --- a/python/paddle/fluid/data_feeder.py +++ b/python/paddle/fluid/data_feeder.py @@ -71,10 +71,20 @@ class DataToLoDTensorConverter(object): for each_data in data: self._feed_impl_(each_data, lod[1:], lod_level - 1) + def _check_shape_(self, shape): + for s1, s2 in zip(self.shape, shape): + if s1 != s2 and s1 >= 0 and s2 >= 0: + raise ValueError( + "Shape not match. What is defined in data layer is {}, but receive {}". + format(self.shape, shape)) + def done(self): arr = numpy.array(self.data, dtype=self.dtype) - if self.shape and len(arr.shape) != len(self.shape): - arr = arr.reshape(self.shape) + if self.shape: + if len(arr.shape) != len(self.shape): + arr = arr.reshape(self.shape) + else: + self._check_shape_(arr.shape) t = core.LoDTensor() t.set(arr, self.place) if self.lod_level > 0: @@ -152,17 +162,8 @@ class DataFeeder(object): raise TypeError("Feed list should contain a list of variable") self.feed_dtypes.append(each_var.dtype) self.feed_names.append(each_var.name) - shape = each_var.shape - batch_size_dim = -1 - for i, s in enumerate(shape): - if s < 0: - batch_size_dim = i - break - if batch_size_dim == -1: - raise ValueError("Variable {0} must has a batch size dimension", - each_var.name) self.feed_lod_level.append(each_var.lod_level) - self.feed_shapes.append(shape) + self.feed_shapes.append(each_var.shape) self.place = place -- GitLab