PaddlePaddle / PGL, commit 350f2de3

Authored on Sep 17, 2020 by Webbley

1, add hadoop dataset;
2, add some comments.

Parent: 8d0e023e
Showing 3 changed files with 158 additions and 3 deletions (+158, -3):

    pgl/tests/test_dataloader.py    +7    -0
    pgl/utils/data/dataloader.py    +44   -3
    pgl/utils/data/dataset.py       +107  -0
pgl/tests/test_dataloader.py

@@ -35,15 +35,21 @@ class ListDataset(Dataset):
        return len(self.dataset)

    def _transform(self, example):
        time.sleep(0.1)
        return example


class IterDataset(StreamDataset):
    def __init__(self):
        self.dataset = list(range(0, DATA_SIZE))
        self.count = 0

    def __iter__(self):
        for data in self.dataset:
            self.count += 1
            if self.count % self._worker_info.num_workers != self._worker_info.fid:
                continue
            time.sleep(0.1)
            yield data

@@ -89,6 +95,7 @@ class DataloaderTest(unittest.TestCase):
            ds,
            batch_size=3,
            drop_last=False,
            shuffle=True,
            num_workers=1,
            collate_fn=collate_fn)
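The new IterDataset above shards its stream by a running counter so that parallel workers never emit the same example twice. A minimal standalone sketch of that `count % num_workers != fid` rule, in plain Python and not taken from the commit:

# Standalone sketch (not part of this commit) of the modulo sharding rule
# used by IterDataset: worker `fid` keeps every element whose running count
# falls on its slot, so the workers together read each element exactly once.
DATA_SIZE = 10
num_workers = 3

def shard(fid):
    count = 0
    kept = []
    for data in range(DATA_SIZE):
        count += 1
        if count % num_workers != fid:
            continue
        kept.append(data)
    return kept

shards = [shard(fid) for fid in range(num_workers)]
print(shards)    # [[2, 5, 8], [0, 3, 6, 9], [1, 4, 7]]
assert sorted(sum(shards, [])) == list(range(DATA_SIZE))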
pgl/utils/data/dataloader.py

@@ -16,6 +16,7 @@
"""
import numpy as np
from collections import namedtuple

import paddle
import paddle.fluid as F

@@ -25,9 +26,42 @@ from pgl.utils import mp_reader
from pgl.utils.data.dataset import Dataset, StreamDataset
from pgl.utils.data.sampler import Sampler, StreamSampler

WorkerInfo = namedtuple("WorkerInfo", ["num_workers", "fid"])


class Dataloader(object):
-    """Dataloader
+    """Dataloader for loading batch data

    Example:
        .. code-block:: python

            from pgl.utils.data import Dataset
            from pgl.utils.data.dataloader import Dataloader

            class MyDataset(Dataset):
                def __init__(self):
                    self.data = list(range(0, 40))

                def __getitem__(self, idx):
                    return self.data[idx]

                def __len__(self):
                    return len(self.data)

            def collate_fn(batch_examples):
                feed_dict = {}
                feed_dict['data'] = batch_examples
                return feed_dict

            dataset = MyDataset()
            loader = Dataloader(dataset,
                                batch_size=3,
                                drop_last=False,
                                shuffle=True,
                                num_workers=4,
                                collate_fn=collate_fn)

            for batch_data in loader:
                print(batch_data)
    """

    def __init__(

@@ -86,6 +120,9 @@ class Dataloader(object):

class _DataLoaderIter(object):
    """Iterable DataLoader Object
    """

    def __init__(self, dataloader, fid=0):
        self.dataset = dataloader.dataset
        self.sampler = dataloader.sampler

@@ -110,6 +147,10 @@ class _DataLoaderIter(object):
            yield batch_data

    def _streamdata_generator(self):
        self._worker_info = WorkerInfo(
            num_workers=self.num_workers, fid=self.fid)
        self.dataset._set_worker_info(self._worker_info)

        dataset = iter(self.dataset)
        for indices in self.sampler:
            batch_data = []

@@ -126,8 +167,8 @@ class _DataLoaderIter(object):
                    # make sure do not repeat in multiprocessing
                    self.count += 1
-                   if self.count % self.num_workers != self.fid:
-                       continue
+                   # if self.count % self.num_workers != self.fid:
+                   #     continue
            if self.collate_fn is not None:
                yield self.collate_fn(batch_data)
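The docstring added above only demonstrates a map-style Dataset. Below is a hypothetical sketch of driving the same Dataloader with a StreamDataset, showing where the new WorkerInfo ends up; the class and variable names are made up, while the imports and the Dataloader keyword arguments are taken from the docstrings in this commit:

# Hypothetical sketch: Dataloader + StreamDataset.  Before iterating, the
# loader calls dataset._set_worker_info(WorkerInfo(num_workers, fid)), so
# __iter__ can skip the examples that belong to other workers.
from pgl.utils.data import StreamDataset
from pgl.utils.data.dataloader import Dataloader

class MyStreamDataset(StreamDataset):
    def __init__(self):
        self.data = list(range(0, 40))
        self.count = 0

    def __iter__(self):
        for example in self.data:
            self.count += 1
            # self._worker_info is injected by the Dataloader's worker.
            if self.count % self._worker_info.num_workers != self._worker_info.fid:
                continue
            yield example

def collate_fn(batch_examples):
    return {"data": batch_examples}

loader = Dataloader(MyStreamDataset(),
                    batch_size=3,
                    drop_last=False,
                    shuffle=False,
                    num_workers=2,
                    collate_fn=collate_fn)

for batch_data in loader:
    print(batch_data)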
pgl/utils/data/dataset.py

@@ -15,11 +15,59 @@
"""dataset
"""

import os
import sys
import numpy as np
import json


class HadoopUtil(object):
    """Implementation of some common hadoop operations.
    """

    def __init__(self, hadoop_bin, fs_name, fs_ugi):
        self.hadoop_bin = hadoop_bin
        self.fs_name = fs_name
        self.fs_ugi = fs_ugi

    def ls(self, path):
        """ hdfs_ls """
        cmd = self.hadoop_bin + " fs -D fs.default.name=" + self.fs_name
        cmd += " -D hadoop.job.ugi=" + self.fs_ugi
        cmd += " -ls " + path
        cmd += " | grep part | awk '{print $8}'"
        with os.popen(cmd) as reader:
            filelist = reader.read().split()
        return filelist

    def open(self, filename):
        """ hdfs_file_open """
        cmd = self.hadoop_bin + " fs -D fs.default.name=" + self.fs_name
        cmd += " -D hadoop.job.ugi=" + self.fs_ugi
        cmd += " -cat " + filename
        p = os.popen(cmd)
        return p


class Dataset(object):
    """An abstract class representing a Dataset.

    Generally, all datasets should subclass it.
    All subclasses should overwrite :code:`__getitem__` and :code:`__len__`.

    Examples:
        .. code-block:: python

            from pgl.utils.data import Dataset

            class MyDataset(Dataset):
                def __init__(self):
                    self.data = list(range(0, 40))

                def __getitem__(self, idx):
                    return self.data[idx]

                def __len__(self):
                    return len(self.data)
    """

    def __len__(self):
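HadoopUtil above does not use an HDFS client library; it builds `hadoop fs` command lines and reads their output through os.popen. A hypothetical usage sketch of the two helpers follows; the binary name, fs_name, fs_ugi and the data path are placeholder values, not taken from the commit:

# Hypothetical values throughout; HadoopUtil just runs commands such as
# "hadoop fs -D fs.default.name=... -D hadoop.job.ugi=... -ls <path>".
from pgl.utils.data.dataset import HadoopUtil

util = HadoopUtil(hadoop_bin="hadoop",
                  fs_name="hdfs://example-cluster:9000",
                  fs_ugi="user,passwd")

# ls() lists the part-* files under a directory via `hadoop fs -ls | grep part`.
for filename in util.ls("/tmp/example_dataset"):
    # open() returns the pipe of `hadoop fs -cat <file>`, readable line by line.
    with util.open(filename) as f:
        for line in f:
            print(line.strip())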
@@ -33,7 +81,66 @@ class StreamDataset(object):
    """An abstract class representing a StreamDataset which has unknown length.

    Generally, all datasets of unknown length should subclass it.
    All subclasses should overwrite :code:`__iter__`.

    Examples:
        .. code-block:: python

            from pgl.utils.data import StreamDataset

            class MyStreamDataset(StreamDataset):
                def __init__(self):
                    self.data = list(range(0, 40))
                    self.count = 0

                def __iter__(self):
                    for data in self.data:
                        self.count += 1
                        if self.count % self._worker_info.num_workers != self._worker_info.fid:
                            continue
                        # do something with your data here (e.g. parse it)
                        time.sleep(0.1)
                        yield data
    """

    def __iter__(self):
        raise NotImplementedError

    def _set_worker_info(self, worker_info):
        self._worker_info = worker_info


class HadoopDataset(StreamDataset):
    """An abstract class representing a HadoopDataset which loads data from HDFS.

    All subclasses should overwrite :code:`__iter__`.

    Examples:
        .. code-block:: python

            from pgl.utils.data import HadoopDataset

            class MyHadoopDataset(HadoopDataset):
                def __init__(self, data_path, hadoop_bin, fs_name, fs_ugi):
                    super(MyHadoopDataset, self).__init__(hadoop_bin, fs_name, fs_ugi)
                    self.data_path = data_path

                def __iter__(self):
                    for line in self._line_data_generator():
                        yield line

                def _line_data_generator(self):
                    paths = self.hadoop_util.ls(self.data_path)
                    paths = sorted(paths)
                    for idx, filename in enumerate(paths):
                        if idx % self._worker_info.num_workers != self._worker_info.fid:
                            continue
                        with self.hadoop_util.open(filename) as f:
                            for line in f:
                                yield line
    """

    def __init__(self, hadoop_bin, fs_name, fs_ugi):
        self.hadoop_util = HadoopUtil(
            hadoop_bin=hadoop_bin, fs_name=fs_name, fs_ugi=fs_ugi)

    def __iter__(self):
        raise NotImplementedError
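Putting the new pieces together, here is a hypothetical end-to-end sketch of feeding a HadoopDataset subclass into the Dataloader. It follows the docstring example above, but the JSON record format, the field names, the collate_fn and all connection values are assumptions for illustration only:

# Hypothetical end-to-end sketch: stream JSON lines from HDFS, shard the
# part files across loader workers, and batch them with a collate_fn.
import json

import numpy as np

from pgl.utils.data import HadoopDataset
from pgl.utils.data.dataloader import Dataloader

class JsonHadoopDataset(HadoopDataset):
    def __init__(self, data_path, hadoop_bin, fs_name, fs_ugi):
        super(JsonHadoopDataset, self).__init__(hadoop_bin, fs_name, fs_ugi)
        self.data_path = data_path

    def __iter__(self):
        paths = sorted(self.hadoop_util.ls(self.data_path))
        for idx, filename in enumerate(paths):
            # each worker reads only the part files assigned to its fid
            if idx % self._worker_info.num_workers != self._worker_info.fid:
                continue
            with self.hadoop_util.open(filename) as f:
                for line in f:
                    yield json.loads(line)   # assumes one JSON object per line

def collate_fn(batch_examples):
    # "feature" and "label" are assumed field names of the stored records
    return {
        "feature": np.array([e["feature"] for e in batch_examples]),
        "label": np.array([e["label"] for e in batch_examples]),
    }

loader = Dataloader(JsonHadoopDataset("/tmp/example_dataset",
                                      hadoop_bin="hadoop",
                                      fs_name="hdfs://example-cluster:9000",
                                      fs_ugi="user,passwd"),
                    batch_size=128,
                    drop_last=False,
                    shuffle=False,
                    num_workers=4,
                    collate_fn=collate_fn)

for batch_data in loader:
    pass  # feed batch_data into a Paddle program here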