diff --git a/pgl/tests/test_dataloader.py b/pgl/tests/test_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..363b82082be71bf5651662395d788b9df8925ccc --- /dev/null +++ b/pgl/tests/test_dataloader.py @@ -0,0 +1,139 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""test_dataloader""" + +import time +import unittest +import json +import os + +from pgl.utils.data.dataset import Dataset, StreamDataset +from pgl.utils.data.dataloader import Dataloader + +DATA_SIZE = 20 + + +class ListDataset(Dataset): + def __init__(self): + self.dataset = list(range(0, DATA_SIZE)) + + def __getitem__(self, idx): + return self._transform(self.dataset[idx]) + + def __len__(self): + return len(self.dataset) + + def _transform(self, example): + return example + + +class IterDataset(StreamDataset): + def __init__(self): + self.dataset = list(range(0, DATA_SIZE)) + + def __iter__(self): + for data in self.dataset: + yield data + + +class Collate_fn(object): + def __init__(self, config): + self.config = config + + def __call__(self, batch_examples): + feed_dict = {} + feed_dict['data'] = batch_examples + feed_dict['labels'] = [i for i in range(len(batch_examples))] + return feed_dict + + +class DataloaderTest(unittest.TestCase): + def test_ListDataset(self): + config = { + 'batch_size': 3, + 'drop_last': True, + 'shuffle': True, + 'num_workers': 2, + } + collate_fn = Collate_fn(config) + ds = ListDataset() + + # test batch_size + loader = Dataloader( + ds, + batch_size=config['batch_size'], + drop_last=config['drop_last'], + num_workers=config['num_workers'], + collate_fn=collate_fn) + + epochs = 1 + for e in range(epochs): + res = [] + for batch_data in loader: + res.extend(batch_data['data']) + self.assertEqual(len(batch_data['data']), config['batch_size']) + + # test shuffle + loader = Dataloader( + ds, + batch_size=3, + drop_last=False, + num_workers=1, + collate_fn=collate_fn) + + for e in range(epochs): + res = [] + for batch_data in loader: + res.extend(batch_data['data']) + self.assertEqual(set([i for i in range(DATA_SIZE)]), set(res)) + + def test_IterDataset(self): + config = { + 'batch_size': 3, + 'drop_last': True, + 'num_workers': 2, + } + collate_fn = Collate_fn(config) + ds = IterDataset() + loader = Dataloader( + ds, + batch_size=config['batch_size'], + drop_last=config['drop_last'], + num_workers=config['num_workers'], + collate_fn=collate_fn) + + epochs = 1 + for e in range(epochs): + res = [] + for batch_data in loader: + res.extend(batch_data['data']) + self.assertEqual(len(batch_data['data']), config['batch_size']) + + # test shuffle + loader = Dataloader( + ds, + batch_size=3, + drop_last=False, + num_workers=1, + collate_fn=collate_fn) + + for e in range(epochs): + res = [] + for batch_data in loader: + res.extend(batch_data['data']) + self.assertEqual(set([i for i in range(DATA_SIZE)]), set(res)) + + +if __name__ == "__main__": + unittest.main() diff --git a/pgl/utils/data/__init__.py b/pgl/utils/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..54cf127c22dc74c34b620f9eacb8a2a10f484f68 --- /dev/null +++ b/pgl/utils/data/__init__.py @@ -0,0 +1,3 @@ +#-*- coding: utf-8 -*- +from .dataset import Dataset, StreamDataset +from .dataloader import Dataloader diff --git a/pgl/utils/data/dataloader.py b/pgl/utils/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..371f1993146406f81e1153413019be623ab2a8f7 --- /dev/null +++ b/pgl/utils/data/dataloader.py @@ -0,0 +1,147 @@ +#-*- coding: utf-8 -*- +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""dataloader +""" + +import numpy as np + +import paddle +import paddle.fluid as F +import paddle.fluid.layers as L + +from pgl.utils import mp_reader +from pgl.utils.data.dataset import Dataset, StreamDataset +from pgl.utils.data.sampler import Sampler, StreamSampler + + +class Dataloader(object): + """Dataloader + """ + + def __init__( + self, + dataset, + batch_size=1, + drop_last=False, + shuffle=False, + num_workers=1, + collate_fn=None, + buf_size=1000, ): + + self.dataset = dataset + self.batch_size = batch_size + self.shuffle = shuffle + self.num_workers = num_workers + self.collate_fn = collate_fn + self.buf_size = buf_size + self.drop_last = drop_last + + def __len__(self): + if not isinstance(self.dataset, StreamDataset): + return len(self.sampler) + else: + raise "StreamDataset has no length" + + def __iter__(self): + # generating a iterable sequence for produce batch data without repetition + if isinstance(self.dataset, StreamDataset): # for stream data + self.sampler = StreamSampler( + self.dataset, + batch_size=self.batch_size, + drop_last=self.drop_last) + else: + self.sampler = Sampler( + self.dataset, + batch_size=self.batch_size, + drop_last=self.drop_last, + shuffle=self.shuffle) + + if self.num_workers == 1: + r = paddle.reader.buffered(_DataLoaderIter(self, 0), self.buf_size) + else: + worker_pool = [ + _DataLoaderIter(self, wid) for wid in range(self.num_workers) + ] + workers = mp_reader.multiprocess_reader( + worker_pool, use_pipe=True, queue_size=1000) + r = paddle.reader.buffered(workers, self.buf_size) + + for batch in r(): + yield batch + + def __call__(self): + return self.__iter__() + + +class _DataLoaderIter(object): + def __init__(self, dataloader, fid=0): + self.dataset = dataloader.dataset + self.sampler = dataloader.sampler + self.collate_fn = dataloader.collate_fn + self.num_workers = dataloader.num_workers + self.drop_last = dataloader.drop_last + self.fid = fid + self.count = 0 + + def _data_generator(self): + for indices in self.sampler: + + self.count += 1 + if self.count % self.num_workers != self.fid: + continue + + batch_data = [self.dataset[i] for i in indices] + + if self.collate_fn is not None: + yield self.collate_fn(batch_data) + else: + yield batch_data + + def _streamdata_generator(self): + dataset = iter(self.dataset) + for indices in self.sampler: + batch_data = [] + for _ in indices: + try: + batch_data.append(next(dataset)) + except StopIteration: + break + + if len(batch_data) == 0 or (self.drop_last and + len(batch_data) < len(indices)): + break + # raise StopIteration + + # make sure do not repeat in multiprocessing + self.count += 1 + if self.count % self.num_workers != self.fid: + continue + + if self.collate_fn is not None: + yield self.collate_fn(batch_data) + else: + yield batch_data + + def __iter__(self): + if isinstance(self.dataset, StreamDataset): + data_generator = self._streamdata_generator + else: + data_generator = self._data_generator + + for batch_data in data_generator(): + yield batch_data + + def __call__(self): + return self.__iter__() diff --git a/pgl/utils/data/dataset.py b/pgl/utils/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..bffda58c1f8b29be6c298a3bc0dc64a964c7bb3a --- /dev/null +++ b/pgl/utils/data/dataset.py @@ -0,0 +1,39 @@ +#-*- coding: utf-8 -*- +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""dataset +""" + + +class Dataset(object): + """An abstract class represening Dataset. + Generally, all datasets should subclass it. + All subclasses should overwrite :code:`__getitem__` and :code:`__len__`. + """ + + def __len__(self): + raise NotImplementedError + + def __getitem__(self, idx): + raise NotImplementedError + + +class StreamDataset(object): + """An abstract class represening StreamDataset which has unknown length. + Generally, all unknown length datasets should subclass it. + All subclasses should overwrite :code:`__iter__`. + """ + + def __iter__(self): + raise NotImplementedError diff --git a/pgl/utils/data/sampler.py b/pgl/utils/data/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..3cfd2bf430e340d5d093733812db2cd492a72ab0 --- /dev/null +++ b/pgl/utils/data/sampler.py @@ -0,0 +1,67 @@ +#-*- coding: utf-8 -*- +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""sampler +""" + +import time +import numpy as np + + +class Sampler(object): + def __init__(self, dataset, batch_size=1, drop_last=False, shuffle=False): + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + self.shuffle = shuffle + + length = len(self.dataset) + self.perm = np.arange(0, length) + + # shuffle one time whne Sampler is created + if self.shuffle: + seed = int(float(time.time()) * 1000) % 10000007 + np.random.seed(seed) + np.random.shuffle(self.perm) + + def __iter__(self): + batch = [] + for idx in self.perm: + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + batch = [] + + if len(batch) > 0 and not self.drop_last: + yield batch + + def __len__(self): + length = len(self.dataset) + if self.drop_last: + length = length // self.batch_size + else: + length = (length + self.batch_size - 1) // self.batch_size + return length + + +class StreamSampler(object): + def __init__(self, dataset, batch_size=1, drop_last=None): + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + + def __iter__(self): + batch = [i for i in range(self.batch_size)] + while True: + yield batch