提交 8cc249ef 编写于 作者: F fengjiayi

make data_feeder support dynamic shape

上级 a29cb4be
...@@ -29,6 +29,14 @@ class DataToLoDTensorConverter(object): ...@@ -29,6 +29,14 @@ class DataToLoDTensorConverter(object):
self.place = place self.place = place
self.lod_level = lod_level self.lod_level = lod_level
self.shape = shape self.shape = shape
self.dynamic_shape = False
negtive_count = 0
for s in self.shape:
if s < 0:
negtive_count += 1
if negtive_count > 1:
self.shape = None
break
if dtype == core.VarDesc.VarType.FP32: if dtype == core.VarDesc.VarType.FP32:
self.dtype = 'float32' self.dtype = 'float32'
elif dtype == core.VarDesc.VarType.INT64: elif dtype == core.VarDesc.VarType.INT64:
...@@ -61,7 +69,9 @@ class DataToLoDTensorConverter(object): ...@@ -61,7 +69,9 @@ class DataToLoDTensorConverter(object):
self._feed_impl_(each_data, lod[1:], lod_level - 1) self._feed_impl_(each_data, lod[1:], lod_level - 1)
def done(self): def done(self):
arr = numpy.array(self.data, dtype=self.dtype).reshape(self.shape) arr = numpy.array(self.data, dtype=self.dtype)
if self.shape:
arr = arr.reshape(self.shape)
t = core.LoDTensor() t = core.LoDTensor()
t.set(arr, self.place) t.set(arr, self.place)
if self.lod_level > 0: if self.lod_level > 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册