未验证 提交 0aedd463 编写于 作者: M mls1999725 提交者: GitHub

Update get_worker_info API (#29190)

* Update get_worker_info API

* Update dataloader_iter.py

* Update dataloader_iter.py

* Update dataloader_iter.py
上级 c59b4f28
......@@ -153,8 +153,8 @@ def get_worker_info():
.. code-block:: python
import math
import paddle
import numpy as np
import paddle.fluid as fluid
from paddle.io import IterableDataset, DataLoader, get_worker_info
class SplitedIterableDataset(IterableDataset):
......@@ -178,8 +178,7 @@ def get_worker_info():
for i in range(iter_start, iter_end):
yield np.array([i])
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
place = paddle.CPUPlace()
dataset = SplitedIterableDataset(start=2, end=9)
dataloader = DataLoader(
dataset,
......@@ -188,7 +187,8 @@ def get_worker_info():
batch_size=1,
drop_last=True)
print(list(dataloader))
for data in dataloader:
print(data)
# outputs: [2, 5, 3, 6, 4, 7]
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册