提交 9889e82e 编写于 作者: M Megvii Engine Team

docs(imperative): add example of dataset

GitOrigin-RevId: 70f2a513cfb5c6e48b30e0e56bf469f3ac99a86a
上级 7d83a9ad
......@@ -68,6 +68,10 @@ class DataLoader:
batch from workers. Default: 0
preload: whether to enable the preloading strategy of the dataloader.
When enabling, the dataloader will preload one batch to the device memory to speed up the whole training process.
parallel_stream: whether to splitting workload across all workers when dataset is streamdataset and num_workers > 0.
When enabling, each worker will collect data from different dataset in order to speed up the whole loading process.
See ref:`streamdataset-example` for more details
.. admonition:: The effect of enabling preload
:class: warning
......
......@@ -58,6 +58,33 @@ class StreamDataset(Dataset):
r"""An abstract class for stream data.
__iter__ method is aditionally needed.
Examples:
.. code-block:: python
from megengine.data.dataset import StreamDataset
from megengine.data.dataloader import DataLoader, get_worker_info
from megengine.data.sampler import StreamSampler
class MyStream(StreamDataset):
def __init__(self):
self.data = [iter([1, 2, 3]), iter([4, 5, 6]), iter([7, 8, 9])]
def __iter__(self):
worker_info = get_worker_info()
data_iter = self.data[worker_info.idx]
while True:
yield next(data_iter)
dataloader = DataLoader(
dataset = MyStream(),
sampler = StreamSampler(batch_size=2),
num_workers=3,
parallel_stream = True,
)
for step, data in enumerate(dataloader):
print(data)
"""
@abstractmethod
......@@ -80,6 +107,29 @@ class ArrayDataset(Dataset):
One or more numpy arrays are needed to initiate the dataset.
And the dimensions represented sample number are expected to be the same.
Examples:
.. code-block:: python
from megengine.data.dataset import ArrayDataset
from megengine.data.dataloader import DataLoader
from megengine.data.sampler import SequentialSampler
rand_data = np.random.randint(0, 255, size=(sample_num, 1, 32, 32), dtype=np.uint8)
label = np.random.randint(0, 10, size=(sample_num,), dtype=int)
dataset = ArrayDataset(rand_data, label)
seque_sampler = SequentialSampler(dataset, batch_size=2)
dataloader = DataLoader(
dataset,
sampler = seque_sampler,
num_workers=3,
)
for step, data in enumerate(dataloader):
print(data)
"""
def __init__(self, *arrays):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册