提交 9de95540 编写于 作者: L Leo Chen 提交者: Zeng Jinle

use fluid.data() (#1495)

上级 c4f50b30
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
import paddle.fluid as fluid import paddle.fluid as fluid
image = fluid.layers.data(name='image', dtype='float32', shape=[784]) image = fluid.data(name='image', dtype='float32', shape=[None, 784])
label = fluid.layers.data(name='label', dtype='int64', shape=[1]) label = fluid.data(name='label', dtype='int64', shape=[None, 1])
ITERABLE = True ITERABLE = True
...@@ -43,8 +43,8 @@ ...@@ -43,8 +43,8 @@
import paddle.dataset.mnist as mnist import paddle.dataset.mnist as mnist
def network(): def network():
image = fluid.layers.data(name='image', dtype='float32', shape=[784]) image = fluid.data(name='image', dtype='float32', shape=[None, 784])
label = fluid.layers.data(name='label', dtype='int64', shape=[1]) label = fluid.data(name='label', dtype='int64', shape=[None, 1])
loader = fluid.io.DataLoader.from_generator(feed_list=[image, label], capacity=64) loader = fluid.io.DataLoader.from_generator(feed_list=[image, label], capacity=64)
# Definition of models # Definition of models
...@@ -112,14 +112,14 @@ DataLoader对象通过 :code:`set_sample_generator()` , :code:`set_sample_list ...@@ -112,14 +112,14 @@ DataLoader对象通过 :code:`set_sample_generator()` , :code:`set_sample_list
batch_label = np.random.random_integers(size=(batch_size, 1), low=0, high=9).astype('int64') batch_label = np.random.random_integers(size=(batch_size, 1), low=0, high=9).astype('int64')
yield batch_image, batch_label yield batch_image, batch_label
image1 = fluid.layers.data(name='image1', dtype='float32', shape=[784]) image1 = fluid.data(name='image1', dtype='float32', shape=[None, 784])
label1 = fluid.layers.data(name='label1', dtype='int64', shape=[1]) label1 = fluid.data(name='label1', dtype='int64', shape=[None, 1])
image2 = fluid.layers.data(name='image2', dtype='float32', shape=[784]) image2 = fluid.data(name='image2', dtype='float32', shape=[None, 784])
label2 = fluid.layers.data(name='label2', dtype='int64', shape=[1]) label2 = fluid.data(name='label2', dtype='int64', shape=[None, 1])
image3 = fluid.layers.data(name='image3', dtype='float32', shape=[784]) image3 = fluid.data(name='image3', dtype='float32', shape=[None, 784])
label3 = fluid.layers.data(name='label3', dtype='int64', shape=[1]) label3 = fluid.data(name='label3', dtype='int64', shape=[None, 1])
对应的DataLoader设置如下: 对应的DataLoader设置如下:
...@@ -178,8 +178,8 @@ DataLoader对象通过 :code:`set_sample_generator()` , :code:`set_sample_list ...@@ -178,8 +178,8 @@ DataLoader对象通过 :code:`set_sample_generator()` , :code:`set_sample_list
def network(): def network():
# 创建数据层对象 # 创建数据层对象
image = fluid.layers.data(name='image', dtype='float32', shape=[784]) image = fluid.data(name='image', dtype='float32', shape=[None, 784])
label = fluid.layers.data(name='label', dtype='int64', shape=[1]) label = fluid.data(name='label', dtype='int64', shape=[None, 1])
# 创建DataLoader对象 # 创建DataLoader对象
reader = fluid.io.DataLoader.from_generator(feed_list=[image, label], capacity=64, iterable=ITERABLE) reader = fluid.io.DataLoader.from_generator(feed_list=[image, label], capacity=64, iterable=ITERABLE)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册