MegEngine · Commit b85792ac
Authored Jan 10, 2023 by Megvii Engine Team

feat(imperative/data): add concat dataset

GitOrigin-RevId: a82b720998c797c45de8a396a0d80a5db68925ef
Parent: 55cbab7a

Showing 3 changed files with 91 additions and 2 deletions (+91 -2)
imperative/python/megengine/data/dataset/__init__.py       +1  -1
imperative/python/megengine/data/dataset/meta_dataset.py   +71 -0
imperative/python/test/unit/data/test_dataset.py           +19 -1
imperative/python/megengine/data/dataset/__init__.py

 # -*- coding: utf-8 -*-
-from .meta_dataset import ArrayDataset, Dataset, StreamDataset
+from .meta_dataset import ArrayDataset, ConcatDataset, Dataset, StreamDataset
 from .vision import *
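With the updated re-export, ConcatDataset becomes available directly from the megengine.data.dataset package rather than only from the meta_dataset submodule. A minimal import check, assuming a MegEngine build that contains this commit:

    # Both paths should resolve to the same class after this change.
    from megengine.data.dataset import ConcatDataset
    from megengine.data.dataset.meta_dataset import ConcatDataset as _ConcatDataset

    assert ConcatDataset is _ConcatDataset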
imperative/python/megengine/data/dataset/meta_dataset.py

 # -*- coding: utf-8 -*-
+import bisect
 from abc import ABC, abstractmethod
 from typing import Tuple
@@ -143,3 +144,73 @@ class ArrayDataset(Dataset):
     def __len__(self) -> int:
         return len(self.arrays[0])
+
+
+class ConcatDataset(Dataset):
+    r"""ConcatDataset is a concatenation of multiple datasets.
+
+    This dataset is used for assembling multiple map-style datasets.
+
+    Examples:
+
+        .. code-block:: python
+
+            import numpy as np
+            from megengine.data import DataLoader, SequentialSampler
+            from megengine.data.dataset import ArrayDataset, ConcatDataset
+
+            data1 = np.random.randint(0, 255, size=(2, 1, 32, 32), dtype=np.uint8)
+            data2 = np.random.randint(0, 255, size=(2, 1, 32, 32), dtype=np.uint8)
+            label1 = np.random.randint(0, 10, size=(2,), dtype=int)
+            label2 = np.random.randint(0, 10, size=(2,), dtype=int)
+            dataset1 = ArrayDataset(data1, label1)
+            dataset2 = ArrayDataset(data2, label2)
+            dataset = ConcatDataset([dataset1, dataset2])
+            seque_sampler = SequentialSampler(dataset, batch_size=2)
+            dataloader = DataLoader(dataset, sampler=seque_sampler, num_workers=3)
+            for step, data in enumerate(dataloader):
+                print(data)
+    """
+
+    def __init__(self, datasets):
+        super(ConcatDataset, self).__init__()
+        self.datasets = datasets
+
+        def cumsum(datasets):
+            r, s = [], 0
+            for e in datasets:
+                l = len(e)
+                r.append(l + s)
+                s += l
+            return r
+
+        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
+        for d in self.datasets:
+            assert not isinstance(
+                d, StreamDataset
+            ), "ConcatDataset does not support StreamDataset"
+        self.datasets = list(datasets)
+        self.cumulative_sizes = cumsum(self.datasets)
+
+    def __getitem__(self, idx):
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError("absolute value of index should not exceed dataset length")
+            idx = len(self) + idx
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+        return self.datasets[dataset_idx][sample_idx]
+
+    def __len__(self):
+        return self.cumulative_sizes[-1]
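The lookup in __getitem__ hinges on cumulative_sizes together with bisect.bisect_right: the first cumulative size strictly greater than the global index identifies the owning dataset, and subtracting the previous cumulative size yields the local index. A standalone sketch of that arithmetic, using hypothetical lengths of 10 and 20 to mirror the test below:

    import bisect

    # Hypothetical per-dataset lengths, matching test_concat_dataset: 10 and 20 samples.
    lengths = [10, 20]
    cumulative_sizes = []
    running = 0
    for n in lengths:
        running += n
        cumulative_sizes.append(running)   # -> [10, 30]

    def locate(idx):
        # Same arithmetic as ConcatDataset.__getitem__.
        dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
        sample_idx = idx if dataset_idx == 0 else idx - cumulative_sizes[dataset_idx - 1]
        return dataset_idx, sample_idx

    print(locate(9))    # (0, 9)  -> last sample of the first dataset
    print(locate(10))   # (1, 0)  -> first sample of the second dataset
    print(locate(15))   # (1, 5)  -> the index exercised by test_concat_dataset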
imperative/python/test/unit/data/test_dataset.py

@@ -5,7 +5,7 @@ import sys
 import numpy as np
 import pytest

-from megengine.data.dataset import ArrayDataset, Dataset, StreamDataset
+from megengine.data.dataset import ArrayDataset, ConcatDataset, Dataset, StreamDataset

 def test_abstract_cls():
@@ -32,3 +32,21 @@ def test_array_dataset_dim_error():
     label = np.random.randint(0, 9, (1,))
     with pytest.raises(ValueError):
         ArrayDataset(data, label)
+
+
+def test_concat_dataset():
+    size1 = (10,)
+    size2 = (20,)
+    data_shape1 = (3, 256, 256)
+    data_shape2 = (2, 128, 128)
+    label_shape1 = (1,)
+    label_shape2 = (2,)
+    data1 = np.random.randint(0, 255, size1 + data_shape1)
+    data2 = np.random.randint(0, 255, size2 + data_shape2)
+    label1 = np.random.randint(0, 9, size1 + label_shape1)
+    label2 = np.random.randint(0, 9, size2 + label_shape2)
+    dataset1 = ArrayDataset(data1, label1)
+    dataset2 = ArrayDataset(data2, label2)
+    dataset = ConcatDataset([dataset1, dataset2])
+    assert dataset[15][0].shape == data_shape2
+    assert dataset[15][1].shape == label_shape2
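The new test probes only a positive index: 15 falls into dataset2 because the cumulative sizes are [10, 30]. The implementation also handles negative indices, wrapping them via idx = len(self) + idx and rejecting out-of-range values. A hedged sketch of an additional test one could add, reusing the imports already present in test_dataset.py; this check is not part of this commit:

    def test_concat_dataset_negative_index():
        # Hypothetical follow-up test: negative indices should wrap around.
        data1 = np.random.randint(0, 255, (10, 3, 256, 256))
        data2 = np.random.randint(0, 255, (20, 2, 128, 128))
        label1 = np.random.randint(0, 9, (10, 1))
        label2 = np.random.randint(0, 9, (20, 2))
        dataset = ConcatDataset([ArrayDataset(data1, label1), ArrayDataset(data2, label2)])

        assert len(dataset) == 30
        assert dataset[-1][0].shape == (2, 128, 128)    # last sample comes from dataset2
        assert dataset[-30][0].shape == (3, 256, 256)   # wraps to the first sample of dataset1
        with pytest.raises(ValueError):
            dataset[-31]                                # exceeds the total length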