提交 350f2de3 编写于 作者: W Webbley

1, add hadoop dataset;

2, add some comments.
上级 8d0e023e
......@@ -35,15 +35,21 @@ class ListDataset(Dataset):
return len(self.dataset)
def _transform(self, example):
time.sleep(0.1)
return example
class IterDataset(StreamDataset):
    """Stream-style dataset fixture for the dataloader tests.

    Examples are partitioned round-robin across dataloader workers using
    the ``_worker_info`` (num_workers, fid) that the loader injects via
    ``_set_worker_info`` before iteration starts.
    """

    def __init__(self):
        self.dataset = list(range(0, DATA_SIZE))
        self.count = 0  # running example index, shared across iterations

    def __iter__(self):
        for example in self.dataset:
            self.count += 1
            # Yield only the examples assigned to this worker's fid.
            if self.count % self._worker_info.num_workers == self._worker_info.fid:
                time.sleep(0.1)  # emulate per-example parsing cost
                yield example
......@@ -89,6 +95,7 @@ class DataloaderTest(unittest.TestCase):
ds,
batch_size=3,
drop_last=False,
shuffle=True,
num_workers=1,
collate_fn=collate_fn)
......
......@@ -16,6 +16,7 @@
"""
import numpy as np
from collections import namedtuple
import paddle
import paddle.fluid as F
......@@ -25,9 +26,42 @@ from pgl.utils import mp_reader
from pgl.utils.data.dataset import Dataset, StreamDataset
from pgl.utils.data.sampler import Sampler, StreamSampler
WorkerInfo = namedtuple("WorkerInfo", ["num_workers", "fid"])
class Dataloader(object):
"""Dataloader
"""Dataloader for loading batch data
Example:
.. code-block:: python
from pgl.utils.data import Dataset
from pgl.utils.data.dataloader import Dataloader
class MyDataset(Dataset):
def __init__(self):
self.data = list(range(0, 40))
def __getitem__(self, idx):
return self.data[idx]
def __len__(self):
return len(self.data)
def collate_fn(batch_examples):
feed_dict = {}
feed_dict['data'] = batch_examples
return feed_dict
dataset = MyDataset()
loader = Dataloader(dataset,
batch_size=3,
drop_last=False,
shuffle=True,
num_workers=4,
collate_fn=collate_fn)
for batch_data in loader:
print(batch_data)
"""
def __init__(
......@@ -86,6 +120,9 @@ class Dataloader(object):
class _DataLoaderIter(object):
"""Iterable DataLoader Object
"""
def __init__(self, dataloader, fid=0):
self.dataset = dataloader.dataset
self.sampler = dataloader.sampler
......@@ -110,6 +147,10 @@ class _DataLoaderIter(object):
yield batch_data
def _streamdata_generator(self):
self._worker_info = WorkerInfo(
num_workers=self.num_workers, fid=self.fid)
self.dataset._set_worker_info(self._worker_info)
dataset = iter(self.dataset)
for indices in self.sampler:
batch_data = []
......@@ -126,8 +167,8 @@ class _DataLoaderIter(object):
# make sure do not repeat in multiprocessing
self.count += 1
if self.count % self.num_workers != self.fid:
continue
# if self.count % self.num_workers != self.fid:
# continue
if self.collate_fn is not None:
yield self.collate_fn(batch_data)
......
......@@ -15,11 +15,59 @@
"""dataset
"""
import os
import sys
import numpy as np
import json
class HadoopUtil(object):
    """Thin wrapper over the ``hadoop`` command-line client for common HDFS ops.

    Every invocation shells out through ``os.popen`` and passes the
    configured filesystem name and ugi as ``-D`` options.

    Args:
        hadoop_bin: path to the ``hadoop`` executable.
        fs_name: value for the ``fs.default.name`` option.
        fs_ugi: value for the ``hadoop.job.ugi`` option.

    NOTE(review): arguments are interpolated into a shell command line
    unescaped, so only trusted paths/credentials should be passed in.
    """

    def __init__(self, hadoop_bin, fs_name, fs_ugi):
        self.hadoop_bin = hadoop_bin
        self.fs_name = fs_name
        self.fs_ugi = fs_ugi

    def _base_cmd(self):
        # Shared command prefix: binary plus filesystem/ugi configuration.
        return (self.hadoop_bin + " fs -D fs.default.name=" + self.fs_name +
                " -D hadoop.job.ugi=" + self.fs_ugi)

    def ls(self, path):
        """Return the list of part-file paths under ``path`` on HDFS.

        The listing is filtered through ``grep part`` and column 8 of the
        ``-ls`` output is taken as the file path.
        """
        cmd = (self._base_cmd() + " -ls " + path +
               " | grep part | awk '{print $8}'")
        with os.popen(cmd) as stream:
            return stream.read().split()

    def open(self, filename):
        """Return a readable pipe streaming ``filename`` via ``hadoop fs -cat``.

        The caller is responsible for closing the returned pipe (it
        supports use as a context manager).
        """
        return os.popen(self._base_cmd() + " -cat " + filename)
class Dataset(object):
"""An abstract class representing Dataset.
Generally, all datasets should subclass it.
All subclasses should overwrite :code:`__getitem__` and :code:`__len__`.
Examples:
.. code-block:: python
from pgl.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self):
self.data = list(range(0, 40))
def __getitem__(self, idx):
return self.data[idx]
def __len__(self):
return len(self.data)
"""
def __len__(self):
......@@ -33,7 +81,66 @@ class StreamDataset(object):
"""An abstract class representing StreamDataset which has unknown length.
Generally, all unknown length datasets should subclass it.
All subclasses should overwrite :code:`__iter__`.
Examples:
.. code-block:: python
from pgl.utils.data import StreamDataset
class MyStreamDataset(StreamDataset):
def __init__(self):
self.data = list(range(0, 40))
self.count = 0
def __iter__(self):
for data in self.data:
self.count += 1
if self.count % self._worker_info.num_workers != self._worker_info.fid:
continue
# do something (like parse data) of your data
time.sleep(0.1)
yield data
"""
def __iter__(self):
raise NotImplementedError
def _set_worker_info(self, worker_info):
self._worker_info = worker_info
class HadoopDataset(StreamDataset):
    """An abstract class representing a stream dataset loaded from HDFS.

    All subclasses should overwrite :code:`__iter__`.

    Examples:

        .. code-block:: python

            from pgl.utils.data import HadoopDataset

            class MyHadoopDataset(HadoopDataset):
                def __init__(self, data_path, hadoop_bin, fs_name, fs_ugi):
                    super(MyHadoopDataset, self).__init__(hadoop_bin, fs_name, fs_ugi)
                    self.data_path = data_path

                def __iter__(self):
                    for line in self._line_data_generator():
                        yield line

                def _line_data_generator(self):
                    paths = self.hadoop_util.ls(self.data_path)
                    paths = sorted(paths)
                    for idx, filename in enumerate(paths):
                        # shard files round-robin across workers
                        if idx % self._worker_info.num_workers != self._worker_info.fid:
                            continue
                        with self.hadoop_util.open(filename) as f:
                            for line in f:
                                yield line
    """

    def __init__(self, hadoop_bin, fs_name, fs_ugi):
        # Helper that shells out to the hadoop client for ls/cat operations.
        self.hadoop_util = HadoopUtil(
            hadoop_bin=hadoop_bin, fs_name=fs_name, fs_ugi=fs_ugi)

    def __iter__(self):
        raise NotImplementedError
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册