提交 8d0e023e 编写于 作者: W Webbley

add dataloader

上级 970bc3f1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""test_dataloader"""
import time
import unittest
import json
import os
from pgl.utils.data.dataset import Dataset, StreamDataset
from pgl.utils.data.dataloader import Dataloader
# Number of integer examples held by the in-memory test datasets in this file.
DATA_SIZE = 20
class ListDataset(Dataset):
    """Indexable test dataset holding the integers ``0 .. DATA_SIZE - 1``."""

    def __init__(self):
        self.dataset = list(range(DATA_SIZE))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        example = self.dataset[idx]
        return self._transform(example)

    def _transform(self, example):
        """Return ``example`` unchanged; hook for per-example processing."""
        return example
class IterDataset(StreamDataset):
    """Iterable-only test dataset yielding ``0 .. DATA_SIZE - 1`` in order."""

    def __init__(self):
        self.dataset = list(range(DATA_SIZE))

    def __iter__(self):
        # Generator: each call starts a fresh pass over the stored examples.
        for example in self.dataset:
            yield example
class Collate_fn(object):
    """Callable that packs a list of examples into a feed dictionary.

    The returned dict has two keys: ``'data'`` holds the batch exactly as it
    was passed in, and ``'labels'`` holds the positions ``0 .. len(batch) - 1``.
    """

    def __init__(self, config):
        # Stored for reference only; __call__ does not read it.
        self.config = config

    def __call__(self, batch_examples):
        positions = list(range(len(batch_examples)))
        return {'data': batch_examples, 'labels': positions}
class DataloaderTest(unittest.TestCase):
    """Exercise ``Dataloader`` over an indexable and a stream dataset."""

    def _assert_full_batches(self, ds, collate_fn, config):
        """Check that every batch has ``batch_size`` items when dropping the last.

        The loader is built from ``config`` (``batch_size``, ``drop_last``,
        ``num_workers``), so this also covers the multi-worker path when
        ``num_workers`` is greater than one.
        """
        loader = Dataloader(
            ds,
            batch_size=config['batch_size'],
            drop_last=config['drop_last'],
            num_workers=config['num_workers'],
            collate_fn=collate_fn)
        for batch_data in loader:
            self.assertEqual(len(batch_data['data']), config['batch_size'])

    def _assert_covers_all(self, ds, collate_fn, **loader_kwargs):
        """Check that a single-worker loader without drop_last sees every example.

        Extra keyword arguments (for example ``shuffle``) are forwarded to
        ``Dataloader``.
        """
        loader = Dataloader(
            ds,
            batch_size=3,
            drop_last=False,
            num_workers=1,
            collate_fn=collate_fn,
            **loader_kwargs)
        res = []
        for batch_data in loader:
            res.extend(batch_data['data'])
        self.assertEqual(set(range(DATA_SIZE)), set(res))

    def test_ListDataset(self):
        config = {
            'batch_size': 3,
            'drop_last': True,
            'shuffle': True,
            'num_workers': 2,
        }
        collate_fn = Collate_fn(config)
        ds = ListDataset()

        # test batch_size
        self._assert_full_batches(ds, collate_fn, config)

        # test shuffle: the flag must actually reach the loader, and a
        # shuffled pass must still visit every example.
        self._assert_covers_all(ds, collate_fn, shuffle=config['shuffle'])

    def test_IterDataset(self):
        config = {
            'batch_size': 3,
            'drop_last': True,
            'num_workers': 2,
        }
        collate_fn = Collate_fn(config)
        ds = IterDataset()

        # test batch_size
        self._assert_full_batches(ds, collate_fn, config)

        # test coverage: a stream has no shuffle option, so only check that
        # nothing is lost when the trailing partial batch is kept.
        self._assert_covers_all(ds, collate_fn)
# Run the test cases when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
#-*- coding: utf-8 -*-
from .dataset import Dataset, StreamDataset
from .dataloader import Dataloader
#-*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""dataloader
"""
import numpy as np
import paddle
import paddle.fluid as F
import paddle.fluid.layers as L
from pgl.utils import mp_reader
from pgl.utils.data.dataset import Dataset, StreamDataset
from pgl.utils.data.sampler import Sampler, StreamSampler
class Dataloader(object):
    """Iterate over a dataset in batches, optionally using several workers.

    Args:
        dataset: A ``Dataset`` (indexable, with a length) or a
            ``StreamDataset`` (iterable only).
        batch_size: Number of examples requested per batch.
        drop_last: If true, a trailing batch smaller than ``batch_size`` is
            not emitted.
        shuffle: Forwarded to ``Sampler`` for ``Dataset`` inputs; not used
            for ``StreamDataset`` inputs.
        num_workers: With ``1`` a single ``_DataLoaderIter`` is read
            in-process; otherwise one ``_DataLoaderIter`` per worker is
            combined through ``mp_reader.multiprocess_reader``.
        collate_fn: Optional callable applied to each list of examples
            before it is yielded.
        buf_size: Buffer size passed to ``paddle.reader.buffered``.
    """

    def __init__(
            self,
            dataset,
            batch_size=1,
            drop_last=False,
            shuffle=False,
            num_workers=1,
            collate_fn=None,
            buf_size=1000, ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.buf_size = buf_size
        self.drop_last = drop_last

    def __len__(self):
        """Return the number of batches one pass over a ``Dataset`` produces.

        Raises:
            TypeError: If the underlying dataset is a ``StreamDataset``,
                whose length is unknown.
        """
        if isinstance(self.dataset, StreamDataset):
            # A bare string is not a valid exception object; raise a real one.
            raise TypeError("StreamDataset has no length")

        # Computed from the dataset directly (same arithmetic as
        # Sampler.__len__) because self.sampler only exists once __iter__
        # has been called.
        num_examples = len(self.dataset)
        if self.drop_last:
            return num_examples // self.batch_size
        return (num_examples + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        # Build a fresh sampler for this pass; it produces the index batches
        # that the worker readers consume.
        if isinstance(self.dataset, StreamDataset):  # for stream data
            self.sampler = StreamSampler(
                self.dataset,
                batch_size=self.batch_size,
                drop_last=self.drop_last)
        else:
            self.sampler = Sampler(
                self.dataset,
                batch_size=self.batch_size,
                drop_last=self.drop_last,
                shuffle=self.shuffle)

        if self.num_workers == 1:
            r = paddle.reader.buffered(_DataLoaderIter(self, 0), self.buf_size)
        else:
            # One reader per worker id; each reader keeps only its own share
            # of the batches (see _DataLoaderIter).
            worker_pool = [
                _DataLoaderIter(self, wid) for wid in range(self.num_workers)
            ]
            workers = mp_reader.multiprocess_reader(
                worker_pool, use_pipe=True, queue_size=1000)
            r = paddle.reader.buffered(workers, self.buf_size)

        for batch in r():
            yield batch

    def __call__(self):
        """Allow the loader to be used as a reader creator: ``loader()``."""
        return self.__iter__()
class _DataLoaderIter(object):
    """Reader for a single worker: yields the batches owned by worker ``fid``.

    Batches are numbered from 1 in the order the sampler produces them, and
    batch ``k`` is emitted only by the reader whose ``fid`` equals
    ``k % num_workers``. Readers with different ``fid`` values therefore
    split the batches between them without repetition.

    Args:
        dataloader: The owning loader; its ``dataset``, ``sampler``,
            ``collate_fn``, ``num_workers`` and ``drop_last`` are copied.
        fid: Worker id of this reader, in ``range(num_workers)``.
    """

    def __init__(self, dataloader, fid=0):
        self.dataset = dataloader.dataset
        self.sampler = dataloader.sampler
        self.collate_fn = dataloader.collate_fn
        self.num_workers = dataloader.num_workers
        self.drop_last = dataloader.drop_last
        self.fid = fid
        self.count = 0

    def _collate(self, batch_data):
        """Apply ``collate_fn`` when one was given, else return the raw list."""
        if self.collate_fn is not None:
            return self.collate_fn(batch_data)
        return batch_data

    def _data_generator(self):
        """Yield this worker's batches from an indexable dataset."""
        # Restart the numbering on every pass; otherwise a second pass over
        # the same reader would continue counting and pick a different shard.
        self.count = 0
        for indices in self.sampler:
            self.count += 1
            if self.count % self.num_workers != self.fid:
                continue
            yield self._collate([self.dataset[i] for i in indices])

    def _streamdata_generator(self):
        """Yield this worker's batches from a stream dataset.

        The stream is consumed in full by this reader; batches that belong to
        other workers are read and then discarded.
        """
        self.count = 0  # same reasoning as in _data_generator
        stream = iter(self.dataset)
        for indices in self.sampler:
            # The sampler only dictates how many examples to pull.
            batch_data = []
            for _ in indices:
                try:
                    batch_data.append(next(stream))
                except StopIteration:
                    break

            # Stop once the stream is exhausted, or when only a partial
            # batch is left and partial batches are to be dropped.
            if not batch_data or (self.drop_last and
                                  len(batch_data) < len(indices)):
                break

            # make sure do not repeat in multiprocessing
            self.count += 1
            if self.count % self.num_workers != self.fid:
                continue
            yield self._collate(batch_data)

    def __iter__(self):
        if isinstance(self.dataset, StreamDataset):
            data_generator = self._streamdata_generator
        else:
            data_generator = self._data_generator

        for batch_data in data_generator():
            yield batch_data

    def __call__(self):
        """Allow the reader to be invoked as a reader creator."""
        return self.__iter__()
#-*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""dataset
"""
class Dataset(object):
    """Abstract base class for datasets that support indexing and ``len()``.

    Subclasses are expected to override :code:`__getitem__` and
    :code:`__len__`; the implementations here only raise
    ``NotImplementedError``.
    """

    def __getitem__(self, idx):
        raise NotImplementedError()

    def __len__(self):
        raise NotImplementedError()
class StreamDataset(object):
    """Abstract base class for datasets whose length is not known.

    Subclasses are expected to override :code:`__iter__`; the implementation
    here only raises ``NotImplementedError``.
    """

    def __iter__(self):
        raise NotImplementedError()
#-*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""sampler
"""
import time
import numpy as np
class Sampler(object):
    """Produce batches of indices into a dataset of known length.

    Args:
        dataset: Any object supporting ``len()``.
        batch_size: Number of indices per batch.
        drop_last: If true, a trailing batch shorter than ``batch_size`` is
            not emitted.
        shuffle: If true, the index order is shuffled once, at construction,
            using a seed derived from the current time.
    """

    def __init__(self, dataset, batch_size=1, drop_last=False, shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle

        self.perm = np.arange(0, len(self.dataset))

        # shuffle one time when Sampler is created
        if self.shuffle:
            seed = int(float(time.time()) * 1000) % 10000007
            # Shuffle with a private generator: reseeding the global NumPy
            # RNG here would silently override any seed the caller has set.
            np.random.RandomState(seed).shuffle(self.perm)

    def __iter__(self):
        batch = []
        for idx in self.perm:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []

        # Emit the trailing partial batch unless the caller asked to drop it.
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        """Return the number of batches ``__iter__`` yields."""
        length = len(self.dataset)
        if self.drop_last:
            return length // self.batch_size
        # Ceiling division: a partial trailing batch counts as one batch.
        return (length + self.batch_size - 1) // self.batch_size
class StreamSampler(object):
    """Endless source of index batches for datasets of unknown length.

    Every item yielded is the same list ``[0, 1, ..., batch_size - 1]``; the
    iteration never terminates on its own, so the consumer decides when to
    stop.
    """

    def __init__(self, dataset, batch_size=1, drop_last=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        # Built once per pass and yielded repeatedly (same list object).
        indices = list(range(self.batch_size))
        while True:
            yield indices
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册