Unverified Commit 7e5da5f5 authored by Huang Zhengjie, committed by GitHub

Merge pull request #125 from Liwb5/main

add dataloader
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""test_dataloader"""
import time
import unittest
import json
import os
from pgl.utils.data.dataset import Dataset, StreamDataset
from pgl.utils.data.dataloader import Dataloader
DATA_SIZE = 20
class ListDataset(Dataset):
def __init__(self):
self.dataset = list(range(0, DATA_SIZE))
def __getitem__(self, idx):
return self._transform(self.dataset[idx])
def __len__(self):
return len(self.dataset)
def _transform(self, example):
time.sleep(0.1)
return example
class IterDataset(StreamDataset):
def __init__(self):
self.dataset = list(range(0, DATA_SIZE))
self.count = 0
def __iter__(self):
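        # Shard the stream across workers: each worker keeps only the
        # examples whose running count maps to its fid, so multiprocess
        # workers never yield duplicate data.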
for data in self.dataset:
self.count += 1
if self.count % self._worker_info.num_workers != self._worker_info.fid:
continue
time.sleep(0.1)
yield data
class Collate_fn(object):
def __init__(self, config):
self.config = config
def __call__(self, batch_examples):
feed_dict = {}
feed_dict['data'] = batch_examples
        feed_dict['labels'] = list(range(len(batch_examples)))
return feed_dict
class DataloaderTest(unittest.TestCase):
def test_ListDataset(self):
config = {
'batch_size': 3,
'drop_last': True,
'shuffle': True,
'num_workers': 2,
}
collate_fn = Collate_fn(config)
ds = ListDataset()
# test batch_size
loader = Dataloader(
ds,
batch_size=config['batch_size'],
drop_last=config['drop_last'],
num_workers=config['num_workers'],
collate_fn=collate_fn)
epochs = 1
for e in range(epochs):
res = []
for batch_data in loader:
res.extend(batch_data['data'])
self.assertEqual(len(batch_data['data']), config['batch_size'])
# test shuffle
loader = Dataloader(
ds,
batch_size=3,
drop_last=False,
shuffle=True,
num_workers=1,
collate_fn=collate_fn)
for e in range(epochs):
res = []
for batch_data in loader:
res.extend(batch_data['data'])
            self.assertEqual(set(range(DATA_SIZE)), set(res))
def test_IterDataset(self):
config = {
'batch_size': 3,
'drop_last': True,
'num_workers': 2,
}
collate_fn = Collate_fn(config)
ds = IterDataset()
loader = Dataloader(
ds,
batch_size=config['batch_size'],
drop_last=config['drop_last'],
num_workers=config['num_workers'],
collate_fn=collate_fn)
epochs = 1
for e in range(epochs):
res = []
for batch_data in loader:
res.extend(batch_data['data'])
self.assertEqual(len(batch_data['data']), config['batch_size'])
        # test that all examples are produced when drop_last is False
loader = Dataloader(
ds,
batch_size=3,
drop_last=False,
num_workers=1,
collate_fn=collate_fn)
for e in range(epochs):
res = []
for batch_data in loader:
res.extend(batch_data['data'])
            self.assertEqual(set(range(DATA_SIZE)), set(res))
if __name__ == "__main__":
unittest.main()
#-*- coding: utf-8 -*-
from .dataset import Dataset, StreamDataset
from .dataloader import Dataloader
#-*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""dataloader
"""
import warnings
import numpy as np
from collections import namedtuple
import paddle
import paddle.fluid as F
import paddle.fluid.layers as L
from pgl.utils import mp_reader
from pgl.utils.data.dataset import Dataset, StreamDataset
from pgl.utils.data.sampler import Sampler, StreamSampler
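# Sharding metadata handed to each worker: the total number of workers
# and this worker's id (fid).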
WorkerInfo = namedtuple("WorkerInfo", ["num_workers", "fid"])
class Dataloader(object):
"""Dataloader for loading batch data
Example:
.. code-block:: python
from pgl.utils.data import Dataset
from pgl.utils.data.dataloader import Dataloader
class MyDataset(Dataset):
def __init__(self):
self.data = list(range(0, 40))
def __getitem__(self, idx):
return self.data[idx]
def __len__(self):
return len(self.data)
def collate_fn(batch_examples):
feed_dict = {}
feed_dict['data'] = batch_examples
return feed_dict
dataset = MyDataset()
loader = Dataloader(dataset,
batch_size=3,
drop_last=False,
shuffle=True,
num_workers=4,
collate_fn=collate_fn)
for batch_data in loader:
print(batch_data)
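    For a StreamDataset, [shuffle] is ignored; a minimal sketch of stream
    usage (assuming a MyStreamDataset subclass like the one in the
    StreamDataset docstring), which shuffles within a rolling buffer via
    [stream_shuffle_size] instead:
        .. code-block:: python
            stream_loader = Dataloader(MyStreamDataset(),
                                       batch_size=3,
                                       num_workers=2,
                                       stream_shuffle_size=100)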
"""
def __init__(self,
dataset,
batch_size=1,
drop_last=False,
shuffle=False,
num_workers=1,
collate_fn=None,
buf_size=1000,
stream_shuffle_size=0):
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
self.num_workers = num_workers
self.collate_fn = collate_fn
self.buf_size = buf_size
self.drop_last = drop_last
self.stream_shuffle_size = stream_shuffle_size
if self.shuffle and isinstance(self.dataset, StreamDataset):
warn_msg = "[shuffle] should not be True with StreamDataset. " \
"It will be ignored. " \
"You might want to set [stream_shuffle_size] with StreamDataset."
warnings.warn(warn_msg)
        if self.stream_shuffle_size > 0 and self.batch_size >= self.stream_shuffle_size:
            raise ValueError("stream_shuffle_size must be larger than batch_size, " \
                    "but got [stream_shuffle_size=%s] <= [batch_size=%s]" \
                    % (self.stream_shuffle_size, self.batch_size))
        if self.num_workers < 1:
            raise ValueError("num_workers (default: 1) should be larger than 0, " \
                    "but got [num_workers=%s] < 1." % self.num_workers)
    def __len__(self):
        if not isinstance(self.dataset, StreamDataset):
            return len(self.sampler)
        else:
            raise TypeError("StreamDataset has no length")
def __iter__(self):
        # Generate an iterable sequence that produces batch data without repetition.
if isinstance(self.dataset, StreamDataset): # for stream data
self.sampler = StreamSampler(
self.dataset,
batch_size=self.batch_size,
drop_last=self.drop_last)
else:
self.sampler = Sampler(
self.dataset,
batch_size=self.batch_size,
drop_last=self.drop_last,
shuffle=self.shuffle)
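        # With a single worker, wrap one iterator in a buffered reader;
        # otherwise run one _DataLoaderIter per worker in separate
        # processes and merge their outputs through pipes.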
if self.num_workers == 1:
r = paddle.reader.buffered(_DataLoaderIter(self, 0), self.buf_size)
else:
worker_pool = [
_DataLoaderIter(self, wid) for wid in range(self.num_workers)
]
workers = mp_reader.multiprocess_reader(
worker_pool, use_pipe=True, queue_size=1000)
r = paddle.reader.buffered(workers, self.buf_size)
for batch in r():
yield batch
def __call__(self):
return self.__iter__()
class _DataLoaderIter(object):
"""Iterable DataLoader Object
"""
def __init__(self, dataloader, fid=0):
self.dataset = dataloader.dataset
self.sampler = dataloader.sampler
self.collate_fn = dataloader.collate_fn
self.num_workers = dataloader.num_workers
self.drop_last = dataloader.drop_last
self.batch_size = dataloader.batch_size
self.stream_shuffle_size = dataloader.stream_shuffle_size
self.fid = fid
self.count = 0
def _data_generator(self):
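        # Index batches are dealt round-robin: this worker only processes
        # the batches whose running count maps to its fid, so workers do
        # not emit the same batch twice.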
for indices in self.sampler:
self.count += 1
if self.count % self.num_workers != self.fid:
continue
batch_data = [self.dataset[i] for i in indices]
if self.collate_fn is not None:
yield self.collate_fn(batch_data)
else:
yield batch_data
def _streamdata_generator(self):
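        # Hand worker info to the dataset so it can shard its own stream;
        # the sampler's index batches are dummies that only fix the batch
        # size, so just pull len(indices) examples from the stream.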
self._worker_info = WorkerInfo(
num_workers=self.num_workers, fid=self.fid)
self.dataset._set_worker_info(self._worker_info)
dataset = iter(self.dataset)
for indices in self.sampler:
batch_data = []
for _ in indices:
try:
batch_data.append(next(dataset))
except StopIteration:
break
if len(batch_data) == 0 or (self.drop_last and
len(batch_data) < len(indices)):
break
# raise StopIteration
            # make sure data is not repeated across multiprocess workers
self.count += 1
if self.collate_fn is not None:
yield self.collate_fn(batch_data)
else:
yield batch_data
def _stream_shuffle_data_generator(self):
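        # Approximate shuffling for streams: repeatedly fill a buffer of
        # stream_shuffle_size examples, shuffle the buffer in place, then
        # cut it into batches of batch_size.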
def _stream_shuffle_index_generator():
            shuffle_size = list(range(self.stream_shuffle_size))
while True:
yield shuffle_size
def _data_generator():
dataset = iter(self.dataset)
for shuffle_size in _stream_shuffle_index_generator():
shuffle_size_data = []
for idx in shuffle_size:
try:
shuffle_size_data.append(next(dataset))
except StopIteration:
break
if len(shuffle_size_data) == 0:
break
yield shuffle_size_data
def _batch_data_generator():
batch_data = []
for shuffle_size_data in _data_generator():
np.random.shuffle(shuffle_size_data)
for d in shuffle_size_data:
batch_data.append(d)
if len(batch_data) == self.batch_size:
yield batch_data
batch_data = []
if not self.drop_last and len(batch_data) > 0:
yield batch_data
self._worker_info = WorkerInfo(
num_workers=self.num_workers, fid=self.fid)
self.dataset._set_worker_info(self._worker_info)
for batch_data in _batch_data_generator():
if self.collate_fn is not None:
yield self.collate_fn(batch_data)
else:
yield batch_data
def __iter__(self):
if isinstance(self.dataset, StreamDataset):
if self.stream_shuffle_size > 0:
data_generator = self._stream_shuffle_data_generator
else:
data_generator = self._streamdata_generator
else:
data_generator = self._data_generator
for batch_data in data_generator():
yield batch_data
def __call__(self):
return self.__iter__()
#-*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""dataset
"""
import os
import sys
import numpy as np
import json
import io
from subprocess import Popen, PIPE
class HadoopUtil(object):
"""Implementation of some common hadoop operations.
"""
def __init__(self, hadoop_bin, fs_name, fs_ugi):
self.hadoop_bin = hadoop_bin
self.fs_name = fs_name
self.fs_ugi = fs_ugi
def ls(self, path):
""" hdfs_ls """
cmd = self.hadoop_bin + " fs -D fs.default.name=" + self.fs_name
cmd += " -D hadoop.job.ugi=" + self.fs_ugi
cmd += " -ls " + path
cmd += " | grep part | awk '{print $8}'"
with os.popen(cmd) as reader:
filelist = reader.read().split()
return filelist
def open(self, filename, encoding='utf-8'):
""" hdfs_file_open """
cmd = self.hadoop_bin + " fs -D fs.default.name=" + self.fs_name
cmd += " -D hadoop.job.ugi=" + self.fs_ugi
cmd += " -cat " + filename
p = Popen(cmd, shell=True, stdout=PIPE)
p = io.TextIOWrapper(p.stdout, encoding=encoding, errors='ignore')
return p
class Dataset(object):
"""An abstract class represening Dataset.
Generally, all datasets should subclass it.
All subclasses should overwrite :code:`__getitem__` and :code:`__len__`.
Examples:
.. code-block:: python
from pgl.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self):
self.data = list(range(0, 40))
def __getitem__(self, idx):
return self.data[idx]
def __len__(self):
return len(self.data)
"""
def __len__(self):
raise NotImplementedError
def __getitem__(self, idx):
raise NotImplementedError
class StreamDataset(object):
"""An abstract class represening StreamDataset which has unknown length.
Generally, all unknown length datasets should subclass it.
All subclasses should overwrite :code:`__iter__`.
Examples:
.. code-block:: python
            import time
            from pgl.utils.data import StreamDataset
class MyStreamDataset(StreamDataset):
def __init__(self):
self.data = list(range(0, 40))
self.count = 0
def __iter__(self):
                    for data in self.data:
self.count += 1
if self.count % self._worker_info.num_workers != self._worker_info.fid:
continue
                        # do something with each example here (e.g. parse it)
time.sleep(0.1)
yield data
"""
def __iter__(self):
raise NotImplementedError
def _set_worker_info(self, worker_info):
self._worker_info = worker_info
class HadoopDataset(StreamDataset):
"""An abstract class represening HadoopDataset which loads data from hdfs.
All subclasses should overwrite :code:`__iter__`.
Examples:
.. code-block:: python
from pgl.utils.data import HadoopDataset
class MyHadoopDataset(HadoopDataset):
def __init__(self, data_path, hadoop_bin, fs_name, fs_ugi):
super(MyHadoopDataset, self).__init__(hadoop_bin, fs_name, fs_ugi)
self.data_path = data_path
def __iter__(self):
for line in self._line_data_generator():
yield line
def _line_data_generator(self):
paths = self.hadoop_util.ls(self.data_path)
paths = sorted(paths)
for idx, filename in enumerate(paths):
if idx % self._worker_info.num_workers != self._worker_info.fid:
continue
with self.hadoop_util.open(filename) as f:
for line in f:
yield line
"""
def __init__(self, hadoop_bin, fs_name, fs_ugi):
self.hadoop_util = HadoopUtil(
hadoop_bin=hadoop_bin, fs_name=fs_name, fs_ugi=fs_ugi)
def __iter__(self):
raise NotImplementedError
#-*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""sampler
"""
import time
import numpy as np
class Sampler(object):
def __init__(self, dataset, batch_size=1, drop_last=False, shuffle=False):
self.dataset = dataset
self.batch_size = batch_size
self.drop_last = drop_last
self.shuffle = shuffle
length = len(self.dataset)
self.perm = np.arange(0, length)
        # shuffle once when the Sampler is created
if self.shuffle:
seed = int(float(time.time()) * 1000) % 10000007
np.random.seed(seed)
np.random.shuffle(self.perm)
def __iter__(self):
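        # Cut the (possibly shuffled) permutation of indices into batches.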
batch = []
for idx in self.perm:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self):
length = len(self.dataset)
if self.drop_last:
length = length // self.batch_size
else:
length = (length + self.batch_size - 1) // self.batch_size
return length
class StreamSampler(object):
def __init__(self, dataset, batch_size=1, drop_last=None):
self.dataset = dataset
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self):
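        # A stream has no known length, so yield a dummy index list of
        # batch_size forever; iteration stops when the underlying stream
        # is exhausted.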
        batch = list(range(self.batch_size))
while True:
yield batch