Commit 350f2de3 authored by Webbley

1. add hadoop dataset;

2. add some comments.
Parent 8d0e023e
...@@ -35,15 +35,21 @@ class ListDataset(Dataset):
        return len(self.dataset)

    def _transform(self, example):
        time.sleep(0.1)
        return example
class IterDataset(StreamDataset):
    def __init__(self):
        self.dataset = list(range(0, DATA_SIZE))
        self.count = 0

    def __iter__(self):
        for data in self.dataset:
            self.count += 1
            if self.count % self._worker_info.num_workers != self._worker_info.fid:
                continue
            time.sleep(0.1)
            yield data
...@@ -89,6 +95,7 @@ class DataloaderTest(unittest.TestCase):
            ds,
            batch_size=3,
            drop_last=False,
            shuffle=True,
            num_workers=1,
            collate_fn=collate_fn)
......
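The new worker filter in IterDataset is the heart of this change: each worker walks the full stream but yields only the examples whose running count maps to its own fid, so together the workers cover the stream exactly once. A minimal standalone sketch of that modulo sharding (plain Python; the DATA_SIZE value and the two-worker split are illustrative assumptions, not values from the diff):

# Illustrative sketch of the modulo sharding used by IterDataset above.
DATA_SIZE = 10  # assumed value, for demonstration only

def worker_stream(num_workers, fid):
    count = 0
    for data in range(DATA_SIZE):
        count += 1
        if count % num_workers != fid:
            continue  # this example belongs to another worker
        yield data

# Across two workers, every example is yielded exactly once overall.
print(list(worker_stream(2, 0)))  # [1, 3, 5, 7, 9]
print(list(worker_stream(2, 1)))  # [0, 2, 4, 6, 8]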
...@@ -16,6 +16,7 @@
"""
import numpy as np
from collections import namedtuple

import paddle
import paddle.fluid as F
...@@ -25,9 +26,42 @@ from pgl.utils import mp_reader
from pgl.utils.data.dataset import Dataset, StreamDataset
from pgl.utils.data.sampler import Sampler, StreamSampler

WorkerInfo = namedtuple("WorkerInfo", ["num_workers", "fid"])
class Dataloader(object):
    """Dataloader for loading batch data

    Example:
        .. code-block:: python

            from pgl.utils.data import Dataset
            from pgl.utils.data.dataloader import Dataloader

            class MyDataset(Dataset):
                def __init__(self):
                    self.data = list(range(0, 40))

                def __getitem__(self, idx):
                    return self.data[idx]

                def __len__(self):
                    return len(self.data)

            def collate_fn(batch_examples):
                feed_dict = {}
                feed_dict['data'] = batch_examples
                return feed_dict

            dataset = MyDataset()
            loader = Dataloader(dataset,
                                batch_size=3,
                                drop_last=False,
                                shuffle=True,
                                num_workers=4,
                                collate_fn=collate_fn)

            for batch_data in loader:
                print(batch_data)
    """
    def __init__(
...@@ -86,6 +120,9 @@ class Dataloader(object):

class _DataLoaderIter(object):
    """Iterable DataLoader Object
    """

    def __init__(self, dataloader, fid=0):
        self.dataset = dataloader.dataset
        self.sampler = dataloader.sampler
...@@ -110,6 +147,10 @@ class _DataLoaderIter(object):
                yield batch_data

    def _streamdata_generator(self):
        self._worker_info = WorkerInfo(
            num_workers=self.num_workers, fid=self.fid)
        self.dataset._set_worker_info(self._worker_info)
        dataset = iter(self.dataset)
        for indices in self.sampler:
            batch_data = []
...@@ -126,8 +167,8 @@ class _DataLoaderIter(object):
            # make sure we do not repeat batches in multiprocessing
            self.count += 1
            # if self.count % self.num_workers != self.fid:
            #     continue
            if self.collate_fn is not None:
                yield self.collate_fn(batch_data)
......
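Taken together, the dataloader changes move de-duplication from the batch level (the now-commented count check above) down to the example level: _streamdata_generator builds a WorkerInfo and hands it to the dataset via _set_worker_info before iterating, so the dataset itself decides which examples belong to this worker. A simplified sketch of that handshake (TinyStreamDataset is a hypothetical stand-in, not PGL code):

from collections import namedtuple

WorkerInfo = namedtuple("WorkerInfo", ["num_workers", "fid"])

class TinyStreamDataset(object):
    # Hypothetical minimal dataset following StreamDataset's contract.
    def _set_worker_info(self, worker_info):
        self._worker_info = worker_info

    def __iter__(self):
        count = 0
        for data in range(6):
            count += 1
            if count % self._worker_info.num_workers != self._worker_info.fid:
                continue  # not this worker's example
            yield data

# What _streamdata_generator now does before pulling data:
dataset = TinyStreamDataset()
dataset._set_worker_info(WorkerInfo(num_workers=2, fid=0))
print(list(dataset))  # [1, 3, 5]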
...@@ -15,11 +15,59 @@
"""dataset
"""
import os
import sys
import numpy as np
import json
class HadoopUtil(object):
    """Implementation of some common hadoop operations.
    """

    def __init__(self, hadoop_bin, fs_name, fs_ugi):
        self.hadoop_bin = hadoop_bin
        self.fs_name = fs_name
        self.fs_ugi = fs_ugi

    def ls(self, path):
        """ hdfs_ls """
        cmd = self.hadoop_bin + " fs -D fs.default.name=" + self.fs_name
        cmd += " -D hadoop.job.ugi=" + self.fs_ugi
        cmd += " -ls " + path
        cmd += " | grep part | awk '{print $8}'"
        with os.popen(cmd) as reader:
            filelist = reader.read().split()
        return filelist

    def open(self, filename):
        """ hdfs_file_open """
        cmd = self.hadoop_bin + " fs -D fs.default.name=" + self.fs_name
        cmd += " -D hadoop.job.ugi=" + self.fs_ugi
        cmd += " -cat " + filename
        p = os.popen(cmd)
        return p
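Since HadoopUtil just shells out to the hadoop client, its two methods compose naturally: ls lists the part files under a directory, and open streams one file's contents. A hedged usage sketch (the binary path, fs name, ugi, and data path below are placeholders, not values from the diff):

def process(line):
    print(line.rstrip())  # stand-in for real parsing

util = HadoopUtil(
    hadoop_bin="hadoop",                    # placeholder binary path
    fs_name="hdfs://example-cluster:9000",  # placeholder fs name
    fs_ugi="user,passwd")                   # placeholder ugi

for filename in util.ls("/path/to/dataset"):  # only part-* files are listed
    with util.open(filename) as f:
        for line in f:
            process(line)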
class Dataset(object):
    """An abstract class representing a Dataset.

    Generally, all datasets should subclass it.
    All subclasses should overwrite :code:`__getitem__` and :code:`__len__`.

    Examples:
        .. code-block:: python

            from pgl.utils.data import Dataset

            class MyDataset(Dataset):
                def __init__(self):
                    self.data = list(range(0, 40))

                def __getitem__(self, idx):
                    return self.data[idx]

                def __len__(self):
                    return len(self.data)
    """
    def __len__(self):
...@@ -33,7 +81,66 @@ class StreamDataset(object):
    """An abstract class representing a StreamDataset which has unknown length.

    Generally, all unknown-length datasets should subclass it.
    All subclasses should overwrite :code:`__iter__`.

    Examples:
        .. code-block:: python

            import time
            from pgl.utils.data import StreamDataset

            class MyStreamDataset(StreamDataset):
                def __init__(self):
                    self.data = list(range(0, 40))
                    self.count = 0

                def __iter__(self):
                    for data in self.data:
                        self.count += 1
                        if self.count % self._worker_info.num_workers != self._worker_info.fid:
                            continue
                        # do something with your data here (e.g. parse it)
                        time.sleep(0.1)
                        yield data
    """

    def __iter__(self):
        raise NotImplementedError

    def _set_worker_info(self, worker_info):
        self._worker_info = worker_info
class HadoopDataset(StreamDataset):
    """An abstract class representing a HadoopDataset which loads data from HDFS.

    All subclasses should overwrite :code:`__iter__`.

    Examples:
        .. code-block:: python

            from pgl.utils.data import HadoopDataset

            class MyHadoopDataset(HadoopDataset):
                def __init__(self, data_path, hadoop_bin, fs_name, fs_ugi):
                    super(MyHadoopDataset, self).__init__(hadoop_bin, fs_name, fs_ugi)
                    self.data_path = data_path

                def __iter__(self):
                    for line in self._line_data_generator():
                        yield line

                def _line_data_generator(self):
                    paths = self.hadoop_util.ls(self.data_path)
                    paths = sorted(paths)
                    for idx, filename in enumerate(paths):
                        if idx % self._worker_info.num_workers != self._worker_info.fid:
                            continue
                        with self.hadoop_util.open(filename) as f:
                            for line in f:
                                yield line
    """

    def __init__(self, hadoop_bin, fs_name, fs_ugi):
        self.hadoop_util = HadoopUtil(
            hadoop_bin=hadoop_bin, fs_name=fs_name, fs_ugi=fs_ugi)

    def __iter__(self):
        raise NotImplementedError
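For completeness, the new pieces are meant to compose: a HadoopDataset subclass streams lines from HDFS, sharding whole files across workers via WorkerInfo, and a Dataloader batches the stream. A hedged end-to-end sketch reusing MyHadoopDataset from the docstring above (connection values and batch settings are placeholders):

from pgl.utils.data.dataloader import Dataloader

# MyHadoopDataset as defined in the HadoopDataset docstring above;
# all constructor arguments here are placeholder values.
dataset = MyHadoopDataset(
    data_path="/path/to/dataset",
    hadoop_bin="hadoop",
    fs_name="hdfs://example-cluster:9000",
    fs_ugi="user,passwd")

loader = Dataloader(
    dataset,
    batch_size=128,                  # placeholder batch size
    num_workers=4,                   # each worker reads every 4th file
    collate_fn=lambda batch: batch)  # identity collate keeps raw lines

for batch_data in loader:
    pass  # feed batch_data to your model here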