Commit 9889e82e, authored by Megvii Engine Team

docs(imperative): add example of dataset

GitOrigin-RevId: 70f2a513cfb5c6e48b30e0e56bf469f3ac99a86a
Parent 7d83a9ad
@@ -68,6 +68,10 @@ class DataLoader:
            batch from workers. Default: 0
        preload: whether to enable the preloading strategy of the dataloader.
            When enabling, the dataloader will preload one batch to the device memory to speed up the whole training process.
        parallel_stream: whether to split the workload across all workers when the dataset is a StreamDataset and num_workers > 0.
            When enabled, each worker collects data from a different dataset to speed up the whole loading process.
            See :ref:`streamdataset-example` for more details.
        .. admonition:: The effect of enabling preload
           :class: warning
......
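Note on `parallel_stream`: the idea described in the docstring, each worker reading from its own source chosen by its worker index, can be sketched in plain Python. This is only an illustration under stated assumptions, not MegEngine's implementation: `SOURCES`, `worker_stream`, and `batches` are hypothetical names, and how the real DataLoader merges or orders batches from its workers is not shown in this diff.

```python
from itertools import islice

# Hypothetical stand-ins: three independent data sources, one per worker.
SOURCES = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]


def worker_stream(worker_idx):
    """Yield samples from the source assigned to this worker index."""
    yield from SOURCES[worker_idx]


def batches(stream, batch_size):
    """Group a stream into lists of at most batch_size samples."""
    it = iter(stream)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch


# Simulate three workers sequentially; each one only touches its own source.
per_worker = {
    idx: list(batches(worker_stream(idx), 2)) for idx in range(len(SOURCES))
}
for idx, worker_batches in per_worker.items():
    print(idx, worker_batches)
```

Because every worker reads a disjoint source, no sample is produced twice, which is the property the per-worker indexing is meant to provide.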
@@ -58,6 +58,33 @@ class StreamDataset(Dataset):
    r"""An abstract class for stream data.
    __iter__ method is additionally needed.
    Examples:

        .. code-block:: python

            from megengine.data.dataset import StreamDataset
            from megengine.data.dataloader import DataLoader, get_worker_info
            from megengine.data.sampler import StreamSampler

            class MyStream(StreamDataset):
                def __init__(self):
                    self.data = [iter([1, 2, 3]), iter([4, 5, 6]), iter([7, 8, 9])]
                def __iter__(self):
                    worker_info = get_worker_info()
                    data_iter = self.data[worker_info.idx]
                    while True:
                        yield next(data_iter)

            dataloader = DataLoader(
                dataset = MyStream(),
                sampler = StreamSampler(batch_size=2),
                num_workers=3,
                parallel_stream = True,
            )

            for step, data in enumerate(dataloader):
                print(data)
    """
    @abstractmethod
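Note on the `while True: yield next(data_iter)` loop in the StreamDataset example: since Python 3.7 (PEP 479), a `StopIteration` raised inside a generator body is converted to `RuntimeError`, so the loop fails once a finite iterator is exhausted. The sketch below shows the behaviour in plain Python and one possible alternative. `unsafe_stream` and `safe_stream` are hypothetical names used for illustration only, and how the DataLoader itself reacts to either case is not covered by this diff.

```python
def unsafe_stream(data_iter):
    # Mirrors the loop in the docstring example: next() raises StopIteration
    # once data_iter is exhausted, which Python 3.7+ turns into RuntimeError.
    while True:
        yield next(data_iter)


def safe_stream(data_iter):
    # Delegating with `yield from` lets the generator finish cleanly instead.
    yield from data_iter


try:
    list(unsafe_stream(iter([1, 2, 3])))
    unsafe_error = None
except RuntimeError as exc:
    unsafe_error = exc

safe_items = list(safe_stream(iter([1, 2, 3])))
print(type(unsafe_error).__name__, safe_items)
```

For a truly endless stream the original loop never reaches this case; the difference only matters when the underlying iterators are finite, as they are in the example.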
@@ -80,6 +107,29 @@ class ArrayDataset(Dataset):
    One or more numpy arrays are needed to initialize the dataset.
    The dimension that represents the number of samples is expected to be the same for every array.
    Examples:

        .. code-block:: python

            import numpy as np

            from megengine.data.dataset import ArrayDataset
            from megengine.data.dataloader import DataLoader
            from megengine.data.sampler import SequentialSampler

            sample_num = 100  # assumed value; left undefined in the original example
            rand_data = np.random.randint(0, 255, size=(sample_num, 1, 32, 32), dtype=np.uint8)
            label = np.random.randint(0, 10, size=(sample_num,), dtype=int)
            dataset = ArrayDataset(rand_data, label)
            seque_sampler = SequentialSampler(dataset, batch_size=2)

            dataloader = DataLoader(
                dataset,
                sampler = seque_sampler,
                num_workers=3,
            )

            for step, data in enumerate(dataloader):
                print(data)
    """
    def __init__(self, *arrays):
......