diff --git a/pgl/tests/test_dataloader.py b/pgl/tests/test_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3323903c8e600ea8ce11767a35e44059c20862d
--- /dev/null
+++ b/pgl/tests/test_dataloader.py
@@ -0,0 +1,144 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""test_dataloader"""
+
+import time
+import unittest
+
+from pgl.utils.data.dataset import Dataset, StreamDataset
+from pgl.utils.data.dataloader import Dataloader
+
+DATA_SIZE = 20
+
+
+class ListDataset(Dataset):
+    def __init__(self):
+        self.dataset = list(range(0, DATA_SIZE))
+
+    def __getitem__(self, idx):
+        return self._transform(self.dataset[idx])
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def _transform(self, example):
+        time.sleep(0.1)
+        return example
+
+
+class IterDataset(StreamDataset):
+    def __init__(self):
+        self.dataset = list(range(0, DATA_SIZE))
+        self.count = 0
+
+    def __iter__(self):
+        for data in self.dataset:
+            self.count += 1
+            if self.count % self._worker_info.num_workers != self._worker_info.fid:
+                continue
+            time.sleep(0.1)
+            yield data
+
+
+class CollateFn(object):
+    def __init__(self, config):
+        self.config = config
+
+    def __call__(self, batch_examples):
+        feed_dict = {}
+        feed_dict['data'] = batch_examples
+        feed_dict['labels'] = list(range(len(batch_examples)))
+        return feed_dict
+
+
+class DataloaderTest(unittest.TestCase):
+    def test_ListDataset(self):
+        config = {
+            'batch_size': 3,
+            'drop_last': True,
+            'shuffle': True,
+            'num_workers': 2,
+        }
+        collate_fn = CollateFn(config)
+        ds = ListDataset()
+
+        # test batch_size
+        loader = Dataloader(
+            ds,
+            batch_size=config['batch_size'],
+            drop_last=config['drop_last'],
+            num_workers=config['num_workers'],
+            collate_fn=collate_fn)
+
+        epochs = 1
+        for e in range(epochs):
+            res = []
+            for batch_data in loader:
+                res.extend(batch_data['data'])
+                self.assertEqual(len(batch_data['data']), config['batch_size'])
+
+        # test shuffle
+        loader = Dataloader(
+            ds,
+            batch_size=3,
+            drop_last=False,
+            shuffle=True,
+            num_workers=1,
+            collate_fn=collate_fn)
+
+        for e in range(epochs):
+            res = []
+            for batch_data in loader:
+                res.extend(batch_data['data'])
+            self.assertEqual(set(range(DATA_SIZE)), set(res))
+
+    def test_IterDataset(self):
+        config = {
+            'batch_size': 3,
+            'drop_last': True,
+            'num_workers': 2,
+        }
+        collate_fn = CollateFn(config)
+        ds = IterDataset()
+        loader = Dataloader(
+            ds,
+            batch_size=config['batch_size'],
+            drop_last=config['drop_last'],
+            num_workers=config['num_workers'],
+            collate_fn=collate_fn)
+
+        epochs = 1
+        for e in range(epochs):
+            res = []
+            for batch_data in loader:
+                res.extend(batch_data['data'])
+                self.assertEqual(len(batch_data['data']), config['batch_size'])
+
+        # test drop_last=False (stream data cannot be shuffled by the sampler)
+        loader = Dataloader(
+            ds,
+            batch_size=3,
+            drop_last=False,
+            num_workers=1,
+            collate_fn=collate_fn)
+
+        for e in range(epochs):
+            res = []
+            for batch_data in loader:
+                res.extend(batch_data['data'])
+            self.assertEqual(set(range(DATA_SIZE)), set(res))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/pgl/utils/data/__init__.py b/pgl/utils/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..54cf127c22dc74c34b620f9eacb8a2a10f484f68
--- /dev/null
+++ b/pgl/utils/data/__init__.py
@@ -0,0 +1,3 @@
+#-*- coding: utf-8 -*-
+from .dataset import Dataset, StreamDataset
+from .dataloader import Dataloader
diff --git a/pgl/utils/data/dataloader.py b/pgl/utils/data/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..4100bc89ef9674d155aa27cefee8e80bec6759bb
--- /dev/null
+++ b/pgl/utils/data/dataloader.py
@@ -0,0 +1,253 @@
+#-*- coding: utf-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""dataloader
+"""
+import warnings
+import numpy as np
+from collections import namedtuple
+
+import paddle
+
+from pgl.utils import mp_reader
+from pgl.utils.data.dataset import StreamDataset
+from pgl.utils.data.sampler import Sampler, StreamSampler
+
+WorkerInfo = namedtuple("WorkerInfo", ["num_workers", "fid"])
+
+
+class Dataloader(object):
+    """Dataloader for producing batches from a Dataset or StreamDataset.
+
+    Example:
+        .. code-block:: python
+
+            from pgl.utils.data import Dataset
+            from pgl.utils.data.dataloader import Dataloader
+
+            class MyDataset(Dataset):
+                def __init__(self):
+                    self.data = list(range(0, 40))
+
+                def __getitem__(self, idx):
+                    return self.data[idx]
+
+                def __len__(self):
+                    return len(self.data)
+
+            def collate_fn(batch_examples):
+                feed_dict = {}
+                feed_dict['data'] = batch_examples
+                return feed_dict
+
+            dataset = MyDataset()
+            loader = Dataloader(dataset,
+                                batch_size=3,
+                                drop_last=False,
+                                shuffle=True,
+                                num_workers=4,
+                                collate_fn=collate_fn)
+
+            for batch_data in loader:
+                print(batch_data)
+    """
+
+    def __init__(self,
+                 dataset,
+                 batch_size=1,
+                 drop_last=False,
+                 shuffle=False,
+                 num_workers=1,
+                 collate_fn=None,
+                 buf_size=1000,
+                 stream_shuffle_size=0):
+
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.num_workers = num_workers
+        self.collate_fn = collate_fn
+        self.buf_size = buf_size
+        self.drop_last = drop_last
+        self.stream_shuffle_size = stream_shuffle_size
+
+        if self.shuffle and isinstance(self.dataset, StreamDataset):
+            warn_msg = "[shuffle] should not be True with StreamDataset. " \
+                       "It will be ignored. " \
+                       "You might want to set [stream_shuffle_size] with StreamDataset."
+            warnings.warn(warn_msg)
+
+        if self.stream_shuffle_size > 0 and self.batch_size >= self.stream_shuffle_size:
+            raise ValueError("stream_shuffle_size must be larger than batch_size, " \
+                             "but got [stream_shuffle_size=%s] <= [batch_size=%s]" \
+                             % (self.stream_shuffle_size, self.batch_size))
+
+        if self.num_workers < 1:
+            raise ValueError("num_workers should be a positive integer, " \
+                             "but got [num_workers=%s] < 1." % self.num_workers)
+
+    def __len__(self):
+        if isinstance(self.dataset, StreamDataset):
+            raise TypeError("StreamDataset has no length")
+        # computed from the dataset so len() works before __iter__() creates the sampler
+        if self.drop_last:
+            return len(self.dataset) // self.batch_size
+        return (len(self.dataset) + self.batch_size - 1) // self.batch_size
+
+    def __iter__(self):
+        # generate an iterable sequence that produces batches without repetition
+        if isinstance(self.dataset, StreamDataset):  # for stream data
+            self.sampler = StreamSampler(
+                self.dataset,
+                batch_size=self.batch_size,
+                drop_last=self.drop_last)
+        else:
+            self.sampler = Sampler(
+                self.dataset,
+                batch_size=self.batch_size,
+                drop_last=self.drop_last,
+                shuffle=self.shuffle)
+
+        if self.num_workers == 1:
+            r = paddle.reader.buffered(_DataLoaderIter(self, 0), self.buf_size)
+        else:
+            worker_pool = [
+                _DataLoaderIter(self, wid) for wid in range(self.num_workers)
+            ]
+            workers = mp_reader.multiprocess_reader(
+                worker_pool, use_pipe=True, queue_size=1000)
+            r = paddle.reader.buffered(workers, self.buf_size)
+
+        for batch in r():
+            yield batch
+
+    def __call__(self):
+        return self.__iter__()
+
+
+class _DataLoaderIter(object):
+    """Iterable dataloader worker that yields the batches of one shard (fid).
+    """
+
+    def __init__(self, dataloader, fid=0):
+        self.dataset = dataloader.dataset
+        self.sampler = dataloader.sampler
+        self.collate_fn = dataloader.collate_fn
+        self.num_workers = dataloader.num_workers
+        self.drop_last = dataloader.drop_last
+        self.batch_size = dataloader.batch_size
+        self.stream_shuffle_size = dataloader.stream_shuffle_size
+        self.fid = fid
+        self.count = 0
+
+    def _data_generator(self):
+        for indices in self.sampler:
+
+            self.count += 1
+            if self.count % self.num_workers != self.fid:
+                continue
+
+            batch_data = [self.dataset[i] for i in indices]
+
+            if self.collate_fn is not None:
+                yield self.collate_fn(batch_data)
+            else:
+                yield batch_data
+
+    def _streamdata_generator(self):
+        self._worker_info = WorkerInfo(
+            num_workers=self.num_workers, fid=self.fid)
+        self.dataset._set_worker_info(self._worker_info)
+
+        dataset = iter(self.dataset)
+        for indices in self.sampler:
+            batch_data = []
+            for _ in indices:
+                try:
+                    batch_data.append(next(dataset))
+                except StopIteration:
+                    break
+
+            if len(batch_data) == 0 or (self.drop_last and
+                                        len(batch_data) < len(indices)):
+                break
+
+            # count batches; worker sharding is already handled via _worker_info
+            self.count += 1
+
+            if self.collate_fn is not None:
+                yield self.collate_fn(batch_data)
+            else:
+                yield batch_data
+
+    def _stream_shuffle_data_generator(self):
+        # fill a buffer of stream_shuffle_size examples, shuffle it, then cut batches
+        def _stream_shuffle_index_generator():
+            shuffle_size = [i for i in range(self.stream_shuffle_size)]
+            while True:
+                yield shuffle_size
+
+        def _data_generator():
+            dataset = iter(self.dataset)
+            for shuffle_size in _stream_shuffle_index_generator():
+                shuffle_size_data = []
+                for idx in shuffle_size:
+                    try:
+                        shuffle_size_data.append(next(dataset))
+                    except StopIteration:
+                        break
+
+                if len(shuffle_size_data) == 0:
+                    break
+
+                yield shuffle_size_data
+
+        def _batch_data_generator():
+            batch_data = []
+            for shuffle_size_data in _data_generator():
+                np.random.shuffle(shuffle_size_data)
+
+                for d in shuffle_size_data:
+                    batch_data.append(d)
+                    if len(batch_data) == self.batch_size:
+                        yield batch_data
+                        batch_data = []
+
+            if not self.drop_last and len(batch_data) > 0:
+                yield batch_data
+
+        self._worker_info = WorkerInfo(
+            num_workers=self.num_workers, fid=self.fid)
+        self.dataset._set_worker_info(self._worker_info)
+
+        for batch_data in _batch_data_generator():
+            if self.collate_fn is not None:
+                yield self.collate_fn(batch_data)
+            else:
+                yield batch_data
+
+    def __iter__(self):
+        if isinstance(self.dataset, StreamDataset):
+            if self.stream_shuffle_size > 0:
+                data_generator = self._stream_shuffle_data_generator
+            else:
+                data_generator = self._streamdata_generator
+        else:
+            data_generator = self._data_generator
+
+        for batch_data in data_generator():
+            yield batch_data
+
+    def __call__(self):
+        return self.__iter__()
diff --git a/pgl/utils/data/dataset.py b/pgl/utils/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..184bb555388e58283435ab6e00058f699ddb12c8
--- /dev/null
+++ b/pgl/utils/data/dataset.py
@@ -0,0 +1,148 @@
+#-*- coding: utf-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""dataset
+"""
+
+import os
+import io
+from subprocess import Popen, PIPE
+
+
+class HadoopUtil(object):
+    """Implementation of some common Hadoop operations.
+
+    """
+
+    def __init__(self, hadoop_bin, fs_name, fs_ugi):
+        self.hadoop_bin = hadoop_bin
+        self.fs_name = fs_name
+        self.fs_ugi = fs_ugi
+
+    def ls(self, path):
+        """ hdfs_ls """
+        cmd = self.hadoop_bin + " fs -D fs.default.name=" + self.fs_name
+        cmd += " -D hadoop.job.ugi=" + self.fs_ugi
+        cmd += " -ls " + path
+        # keep only the part files; $8 is the path column of "hadoop fs -ls"
+        cmd += " | grep part | awk '{print $8}'"
+        with os.popen(cmd) as reader:
+            filelist = reader.read().split()
+        return filelist
+
+    def open(self, filename, encoding='utf-8'):
+        """ hdfs_file_open """
+        cmd = self.hadoop_bin + " fs -D fs.default.name=" + self.fs_name
+        cmd += " -D hadoop.job.ugi=" + self.fs_ugi
+        cmd += " -cat " + filename
+
+        p = Popen(cmd, shell=True, stdout=PIPE)
+        p = io.TextIOWrapper(p.stdout, encoding=encoding, errors='ignore')
+        return p
+
+
+class Dataset(object):
+    """An abstract class representing a Dataset.
+    Generally, all datasets should subclass it.
+    All subclasses should overwrite :code:`__getitem__` and :code:`__len__`.
+
+    Examples:
+        .. code-block:: python
+
+            from pgl.utils.data import Dataset
+
+            class MyDataset(Dataset):
+                def __init__(self):
+                    self.data = list(range(0, 40))
+
+                def __getitem__(self, idx):
+                    return self.data[idx]
+
+                def __len__(self):
+                    return len(self.data)
+    """
+
+    def __len__(self):
+        raise NotImplementedError
+
+    def __getitem__(self, idx):
+        raise NotImplementedError
+
+
+class StreamDataset(object):
+    """An abstract class representing a StreamDataset, whose length is unknown.
+    Generally, all datasets with unknown length should subclass it.
+    All subclasses should overwrite :code:`__iter__`.
+
+    Examples:
+        .. code-block:: python
+
+            from pgl.utils.data import StreamDataset
+
+            class MyStreamDataset(StreamDataset):
+                def __init__(self):
+                    self.data = list(range(0, 40))
+                    self.count = 0
+
+                def __iter__(self):
+                    for data in self.data:
+                        self.count += 1
+                        if self.count % self._worker_info.num_workers != self._worker_info.fid:
+                            continue
+                        # do something with the example here (e.g. parse it)
+                        yield data
+    """
+
+    def __iter__(self):
+        raise NotImplementedError
+
+    def _set_worker_info(self, worker_info):
+        self._worker_info = worker_info
+
+
+class HadoopDataset(StreamDataset):
+    """An abstract class representing a HadoopDataset, which loads data from HDFS.
+    All subclasses should overwrite :code:`__iter__`.
+
+    Examples:
+        .. code-block:: python
+
+            from pgl.utils.data import HadoopDataset
+
+            class MyHadoopDataset(HadoopDataset):
+                def __init__(self, data_path, hadoop_bin, fs_name, fs_ugi):
+                    super(MyHadoopDataset, self).__init__(hadoop_bin, fs_name, fs_ugi)
+                    self.data_path = data_path
+
+                def __iter__(self):
+                    for line in self._line_data_generator():
+                        yield line
+
+                def _line_data_generator(self):
+                    paths = self.hadoop_util.ls(self.data_path)
+                    paths = sorted(paths)
+                    for idx, filename in enumerate(paths):
+                        if idx % self._worker_info.num_workers != self._worker_info.fid:
+                            continue
+                        with self.hadoop_util.open(filename) as f:
+                            for line in f:
+                                yield line
+    """
+
+    def __init__(self, hadoop_bin, fs_name, fs_ugi):
+        self.hadoop_util = HadoopUtil(
+            hadoop_bin=hadoop_bin, fs_name=fs_name, fs_ugi=fs_ugi)
+
+    def __iter__(self):
+        raise NotImplementedError
diff --git a/pgl/utils/data/sampler.py b/pgl/utils/data/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cfd2bf430e340d5d093733812db2cd492a72ab0
--- /dev/null
+++ b/pgl/utils/data/sampler.py
@@ -0,0 +1,73 @@
+#-*- coding: utf-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""sampler +""" + +import time +import numpy as np + + +class Sampler(object): + def __init__(self, dataset, batch_size=1, drop_last=False, shuffle=False): + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + self.shuffle = shuffle + + length = len(self.dataset) + self.perm = np.arange(0, length) + + # shuffle one time whne Sampler is created + if self.shuffle: + seed = int(float(time.time()) * 1000) % 10000007 + np.random.seed(seed) + np.random.shuffle(self.perm) + + def __iter__(self): + batch = [] + for idx in self.perm: + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + batch = [] + + if len(batch) > 0 and not self.drop_last: + yield batch + + def __len__(self): + length = len(self.dataset) + if self.drop_last: + length = length // self.batch_size + else: + length = (length + self.batch_size - 1) // self.batch_size + return length + + +class StreamSampler(object): + def __init__(self, dataset, batch_size=1, drop_last=None): + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + + def __iter__(self): + batch = [i for i in range(self.batch_size)] + while True: + yield batch