diff --git a/pgl/tests/test_dataloader.py b/pgl/tests/test_dataloader.py index 363b82082be71bf5651662395d788b9df8925ccc..e3323903c8e600ea8ce11767a35e44059c20862d 100644 --- a/pgl/tests/test_dataloader.py +++ b/pgl/tests/test_dataloader.py @@ -35,15 +35,21 @@ class ListDataset(Dataset): return len(self.dataset) def _transform(self, example): + time.sleep(0.1) return example class IterDataset(StreamDataset): def __init__(self): self.dataset = list(range(0, DATA_SIZE)) + self.count = 0 def __iter__(self): for data in self.dataset: + self.count += 1 + if self.count % self._worker_info.num_workers != self._worker_info.fid: + continue + time.sleep(0.1) yield data @@ -89,6 +95,7 @@ class DataloaderTest(unittest.TestCase): ds, batch_size=3, drop_last=False, + shuffle=True, num_workers=1, collate_fn=collate_fn) diff --git a/pgl/utils/data/dataloader.py b/pgl/utils/data/dataloader.py index 371f1993146406f81e1153413019be623ab2a8f7..71528bfb83a9119e1afd7b1dd686877d94ab4b43 100644 --- a/pgl/utils/data/dataloader.py +++ b/pgl/utils/data/dataloader.py @@ -16,6 +16,7 @@ """ import numpy as np +from collections import namedtuple import paddle import paddle.fluid as F @@ -25,9 +26,42 @@ from pgl.utils import mp_reader from pgl.utils.data.dataset import Dataset, StreamDataset from pgl.utils.data.sampler import Sampler, StreamSampler +WorkerInfo = namedtuple("WorkerInfo", ["num_workers", "fid"]) + class Dataloader(object): - """Dataloader + """Dataloader for loading batch data + + Example: + .. code-block:: python + from pgl.utils.data import Dataset + from pgl.utils.data.dataloader import Dataloader + + class MyDataset(Dataset): + def __init__(self): + self.data = list(range(0, 40)) + + def __getitem__(self, idx): + return self.data[idx] + + def __len__(self): + return len(self.data) + + def collate_fn(batch_examples): + feed_dict = {} + feed_dict['data'] = batch_examples + return feed_dict + + dataset = MyDataset() + loader = Dataloader(dataset, + batch_size=3, + drop_last=False, + shuffle=True, + num_workers=4, + collate_fn=collate_fn) + + for batch_data in loader: + print(batch_data) """ def __init__( @@ -86,6 +120,9 @@ class Dataloader(object): class _DataLoaderIter(object): + """Iterable DataLoader Object + """ + def __init__(self, dataloader, fid=0): self.dataset = dataloader.dataset self.sampler = dataloader.sampler @@ -110,6 +147,10 @@ class _DataLoaderIter(object): yield batch_data def _streamdata_generator(self): + self._worker_info = WorkerInfo( + num_workers=self.num_workers, fid=self.fid) + self.dataset._set_worker_info(self._worker_info) + dataset = iter(self.dataset) for indices in self.sampler: batch_data = [] @@ -126,8 +167,8 @@ class _DataLoaderIter(object): # make sure do not repeat in multiprocessing self.count += 1 - if self.count % self.num_workers != self.fid: - continue + # if self.count % self.num_workers != self.fid: + # continue if self.collate_fn is not None: yield self.collate_fn(batch_data) diff --git a/pgl/utils/data/dataset.py b/pgl/utils/data/dataset.py index bffda58c1f8b29be6c298a3bc0dc64a964c7bb3a..a139cfe0c615574b3fd9fe194fc753f41f63f764 100644 --- a/pgl/utils/data/dataset.py +++ b/pgl/utils/data/dataset.py @@ -15,11 +15,59 @@ """dataset """ +import os +import sys +import numpy as np +import json + + +class HadoopUtil(object): + """Implementation of some common hadoop operations. + """ + + def __init__(self, hadoop_bin, fs_name, fs_ugi): + self.hadoop_bin = hadoop_bin + self.fs_name = fs_name + self.fs_ugi = fs_ugi + + def ls(self, path): + """ hdfs_ls """ + cmd = self.hadoop_bin + " fs -D fs.default.name=" + self.fs_name + cmd += " -D hadoop.job.ugi=" + self.fs_ugi + cmd += " -ls " + path + cmd += " | grep part | awk '{print $8}'" + with os.popen(cmd) as reader: + filelist = reader.read().split() + return filelist + + def open(self, filename): + """ hdfs_file_open """ + cmd = self.hadoop_bin + " fs -D fs.default.name=" + self.fs_name + cmd += " -D hadoop.job.ugi=" + self.fs_ugi + cmd += " -cat " + filename + p = os.popen(cmd) + return p + class Dataset(object): """An abstract class represening Dataset. Generally, all datasets should subclass it. All subclasses should overwrite :code:`__getitem__` and :code:`__len__`. + + Examples: + .. code-block:: python + + from pgl.utils.data import Dataset + + class MyDataset(Dataset): + def __init__(self): + self.data = list(range(0, 40)) + + def __getitem__(self, idx): + return self.data[idx] + + def __len__(self): + return len(self.data) """ def __len__(self): @@ -33,7 +81,66 @@ class StreamDataset(object): """An abstract class represening StreamDataset which has unknown length. Generally, all unknown length datasets should subclass it. All subclasses should overwrite :code:`__iter__`. + + Examples: + .. code-block:: python + + from pgl.utils.data import StreamDataset + + class MyStreamDataset(StreamDataset): + def __init__(self): + self.data = list(range(0, 40)) + self.count = 0 + + def __iter__(self): + for data in self.dataset: + self.count += 1 + if self.count % self._worker_info.num_workers != self._worker_info.fid: + continue + # do something (like parse data) of your data + time.sleep(0.1) + yield data + """ + + def __iter__(self): + raise NotImplementedError + + def _set_worker_info(self, worker_info): + self._worker_info = worker_info + + +class HadoopDataset(StreamDataset): + """An abstract class represening HadoopDataset which loads data from hdfs. + All subclasses should overwrite :code:`__iter__`. + + Examples: + .. code-block:: python + + from pgl.utils.data import HadoopDataset + + class MyHadoopDataset(HadoopDataset): + def __init__(self, data_path, hadoop_bin, fs_name, fs_ugi): + super(MyHadoopDataset, self).__init__(hadoop_bin, fs_name, fs_ugi) + self.data_path = data_path + + def __iter__(self): + for line in self._line_data_generator(): + yield line + + def _line_data_generator(self): + paths = self.hadoop_util.ls(self.data_path) + paths = sorted(paths) + for idx, filename in enumerate(paths): + if idx % self._worker_info.num_workers != self._worker_info.fid: + continue + with self.hadoop_util.open(filename) as f: + for line in f: + yield line """ + def __init__(self, hadoop_bin, fs_name, fs_ugi): + self.hadoop_util = HadoopUtil( + hadoop_bin=hadoop_bin, fs_name=fs_name, fs_ugi=fs_ugi) + def __iter__(self): raise NotImplementedError