未验证 提交 8473ee9d 编写于 作者: K Kaipeng Deng 提交者: GitHub

make places of DataLoader optional. (#27432)

* make places of DataLoader optional. test=develop
上级 3d552214
...@@ -167,10 +167,10 @@ class DataLoader(object): ...@@ -167,10 +167,10 @@ class DataLoader(object):
The variables should be created by :code:`fluid.data()`. The variables should be created by :code:`fluid.data()`.
:attr:`feed_list` must be set if :attr:`return_list` is :attr:`feed_list` must be set if :attr:`return_list` is
False. Default None. False. Default None.
places(list(Place)|tuple(Place)): a list of Place, to put data places(list(Place)|tuple(Place)|optional): a list of Place,
onto, :attr:`places` must be set in both static graph and to put data onto. :attr:`places` can be None; if
dynamic graph mode, in dynamic graph mode, place number must :attr:`places` is None, the default place (CPUPlace or CUDAPlace(0))
be 1. Default None. will be used. Default None.
return_list (bool): whether the return value on each device is return_list (bool): whether the return value on each device is
presented as a list. If :attr:`return_list=False`, the return presented as a list. If :attr:`return_list=False`, the return
value on each device would be a dict of str -> LoDTensor, where value on each device would be a dict of str -> LoDTensor, where
...@@ -222,6 +222,8 @@ class DataLoader(object): ...@@ -222,6 +222,8 @@ class DataLoader(object):
.. code-block:: python .. code-block:: python
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.io import Dataset, BatchSampler, DataLoader from paddle.io import Dataset, BatchSampler, DataLoader
...@@ -247,11 +249,48 @@ class DataLoader(object): ...@@ -247,11 +249,48 @@ class DataLoader(object):
def __len__(self): def __len__(self):
return self.num_samples return self.num_samples
dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
# get places # get places
places = fluid.cuda_places() if USE_GPU else fluid.cpu_places() places = fluid.cuda_places() if USE_GPU else fluid.cpu_places()
# --------------------- dygraph mode --------------------
class SimpleNet(fluid.dygraph.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc = fluid.dygraph.nn.Linear(IMAGE_SIZE, CLASS_NUM, act='softmax')
def forward(self, image, label=None):
return self.fc(image)
with fluid.dygraph.guard(places[0]):
simple_net = SimpleNet()
opt = fluid.optimizer.SGD(learning_rate=1e-3,
parameter_list=simple_net.parameters())
loader = DataLoader(dataset,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True,
num_workers=2)
for e in range(EPOCH_NUM):
for i, (image, label) in enumerate(loader()):
out = simple_net(image)
loss = fluid.layers.cross_entropy(out, label)
avg_loss = fluid.layers.reduce_mean(loss)
avg_loss.backward()
opt.minimize(avg_loss)
simple_net.clear_gradients()
print("Epoch {} batch {}: loss = {}".format(e, i, np.mean(loss.numpy())))
# -------------------------------------------------------
# -------------------- static graph --------------------- # -------------------- static graph ---------------------
paddle.enable_static()
def simple_net(image, label): def simple_net(image, label):
fc_tmp = fluid.layers.fc(image, size=CLASS_NUM, act='softmax') fc_tmp = fluid.layers.fc(image, size=CLASS_NUM, act='softmax')
cross_entropy = fluid.layers.softmax_with_cross_entropy(image, label) cross_entropy = fluid.layers.softmax_with_cross_entropy(image, label)
...@@ -270,11 +309,8 @@ class DataLoader(object): ...@@ -270,11 +309,8 @@ class DataLoader(object):
prog = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name) prog = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name)
dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
loader = DataLoader(dataset, loader = DataLoader(dataset,
feed_list=[image, label], feed_list=[image, label],
places=places,
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
shuffle=True, shuffle=True,
drop_last=True, drop_last=True,
...@@ -287,39 +323,6 @@ class DataLoader(object): ...@@ -287,39 +323,6 @@ class DataLoader(object):
# ------------------------------------------------------- # -------------------------------------------------------
# --------------------- dygraph mode --------------------
class SimpleNet(fluid.dygraph.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc = fluid.dygraph.nn.Linear(IMAGE_SIZE, CLASS_NUM, act='softmax')
def forward(self, image, label=None):
return self.fc(image)
with fluid.dygraph.guard(places[0]):
simple_net = SimpleNet()
opt = fluid.optimizer.SGD(learning_rate=1e-3,
parameter_list=simple_net.parameters())
loader = DataLoader(dataset,
places=places[0],
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True,
num_workers=2)
for e in range(EPOCH_NUM):
for i, (image, label) in enumerate(loader()):
out = simple_net(image)
loss = fluid.layers.cross_entropy(out, label)
avg_loss = fluid.layers.reduce_mean(loss)
avg_loss.backward()
opt.minimize(avg_loss)
simple_net.clear_gradients()
print("Epoch {} batch {}: loss = {}".format(e, i, np.mean(loss.numpy())))
# -------------------------------------------------------
.. note:: .. note::
For reading iterable dataset with multiprocess Dataloader, For reading iterable dataset with multiprocess Dataloader,
...@@ -356,11 +359,9 @@ class DataLoader(object): ...@@ -356,11 +359,9 @@ class DataLoader(object):
"feed_list should be set when return_list=False" "feed_list should be set when return_list=False"
self.feed_list = feed_list self.feed_list = feed_list
assert places is not None, "places cannot be None" if places is None:
places = _current_expected_place()
self.places = _convert_places(places) self.places = _convert_places(places)
if in_dygraph_mode():
assert len(self.places) == 1, \
"Number of places must be 1 in dygraph mode"
assert num_workers >= 0, "num_workers should be a non-negative value" assert num_workers >= 0, "num_workers should be a non-negative value"
if num_workers > 0 and (sys.platform == 'darwin' or if num_workers > 0 and (sys.platform == 'darwin' or
......
...@@ -76,7 +76,6 @@ class TestDygraphDataLoader(unittest.TestCase): ...@@ -76,7 +76,6 @@ class TestDygraphDataLoader(unittest.TestCase):
dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM) dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM)
dataloader = DataLoader( dataloader = DataLoader(
dataset, dataset,
places=places,
num_workers=num_workers, num_workers=num_workers,
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
drop_last=True) drop_last=True)
......
...@@ -76,7 +76,6 @@ class TestDygraphDataLoader(unittest.TestCase): ...@@ -76,7 +76,6 @@ class TestDygraphDataLoader(unittest.TestCase):
dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM) dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM)
dataloader = DataLoader( dataloader = DataLoader(
dataset, dataset,
places=places,
num_workers=num_workers, num_workers=num_workers,
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
drop_last=True) drop_last=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册