From 9fc11f6faa23648aab75bd89db4766a2712e3c98 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Thu, 24 Sep 2020 08:13:44 +0000 Subject: [PATCH] update example code. test=develop --- doc/paddle/api/paddle/io/DataLoader_cn.rst | 80 +++++++++++----------- 1 file changed, 41 insertions(+), 39 deletions(-) diff --git a/doc/paddle/api/paddle/io/DataLoader_cn.rst b/doc/paddle/api/paddle/io/DataLoader_cn.rst index 2012f53df..e129150f5 100644 --- a/doc/paddle/api/paddle/io/DataLoader_cn.rst +++ b/doc/paddle/api/paddle/io/DataLoader_cn.rst @@ -37,7 +37,10 @@ DataLoader当前仅支持 ``map-style`` 的数据集(可通过下标索引样本 .. code-block:: python + import numpy as np + + import paddle import paddle.fluid as fluid from paddle.io import Dataset, BatchSampler, DataLoader @@ -48,7 +51,7 @@ DataLoader当前仅支持 ``map-style`` 的数据集(可通过下标索引样本 IMAGE_SIZE = 784 CLASS_NUM = 10 - USE_GPU = True # whether use GPU to run model + USE_GPU = False # whether use GPU to run model # define a random dataset class RandomDataset(Dataset): @@ -63,11 +66,48 @@ DataLoader当前仅支持 ``map-style`` 的数据集(可通过下标索引样本 def __len__(self): return self.num_samples + dataset = RandomDataset(BATCH_NUM * BATCH_SIZE) + # get places places = fluid.cuda_places() if USE_GPU else fluid.cpu_places() + # --------------------- dygraph mode -------------------- + + class SimpleNet(fluid.dygraph.Layer): + def __init__(self): + super(SimpleNet, self).__init__() + self.fc = fluid.dygraph.nn.Linear(IMAGE_SIZE, CLASS_NUM, act='softmax') + + def forward(self, image, label=None): + return self.fc(image) + + with fluid.dygraph.guard(places[0]): + simple_net = SimpleNet() + opt = fluid.optimizer.SGD(learning_rate=1e-3, + parameter_list=simple_net.parameters()) + + loader = DataLoader(dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + num_workers=2) + + for e in range(EPOCH_NUM): + for i, (image, label) in enumerate(loader()): + out = simple_net(image) + loss = fluid.layers.cross_entropy(out, label) + avg_loss = fluid.layers.reduce_mean(loss) + avg_loss.backward() + opt.minimize(avg_loss) + simple_net.clear_gradients() + print("Epoch {} batch {}: loss = {}".format(e, i, np.mean(loss.numpy()))) + + # ------------------------------------------------------- + # -------------------- static graph --------------------- + paddle.enable_static() + def simple_net(image, label): fc_tmp = fluid.layers.fc(image, size=CLASS_NUM, act='softmax') cross_entropy = fluid.layers.softmax_with_cross_entropy(image, label) @@ -86,11 +126,8 @@ DataLoader当前仅支持 ``map-style`` 的数据集(可通过下标索引样本 prog = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name) - dataset = RandomDataset(BATCH_NUM * BATCH_SIZE) - loader = DataLoader(dataset, feed_list=[image, label], - places=places, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, @@ -102,41 +139,6 @@ DataLoader当前仅支持 ``map-style`` 的数据集(可通过下标索引样本 print("Epoch {} batch {}: loss = {}".format(e, i, l[0][0])) # ------------------------------------------------------- - - # -------------------- dynamic graph -------------------- - - class SimpleNet(fluid.dygraph.Layer): - def __init__(self): - super(SimpleNet, self).__init__() - self.fc = fluid.dygraph.nn.Linear(IMAGE_SIZE, CLASS_NUM, act='softmax') - - def forward(self, image, label=None): - return self.fc(image) - - with fluid.dygraph.guard(places[0]): - simple_net = SimpleNet() - opt = fluid.optimizer.SGD(learning_rate=1e-3, - parameter_list=simple_net.parameters()) - - loader = DataLoader(dataset, - places=places[0], - batch_size=BATCH_SIZE, - shuffle=True, - drop_last=True, - num_workers=2) - - for e in range(EPOCH_NUM): - for i, (image, label) in enumerate(loader()): - out = simple_net(image) - loss = fluid.layers.cross_entropy(out, label) - avg_loss = fluid.layers.reduce_mean(loss) - avg_loss.backward() - opt.minimize(avg_loss) - simple_net.clear_gradients() - print("Epoch {} batch {}: loss = {}".format(e, i, np.mean(loss.numpy()))) - - # ------------------------------------------------------- - .. py:method:: from_generator(feed_list=None, capacity=None, use_double_buffer=True, iterable=True, return_list=False, use_multiprocess=False, drop_last=True) -- GitLab