Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
b6144fb7
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
接近 2 年 前同步成功
通知
116
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b6144fb7
编写于
9月 04, 2021
作者:
W
weishengyu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add mix dataloader and mix sampler
上级
ce39aea9
变更
3
隐藏空白更改
内联
并排
Showing
3 changed files
with
129 additions
and
0 deletions
+129
-0
ppcls/data/dataloader/__init__.py
ppcls/data/dataloader/__init__.py
+6
-0
ppcls/data/dataloader/mix_dataset.py
ppcls/data/dataloader/mix_dataset.py
+49
-0
ppcls/data/dataloader/mix_sampler.py
ppcls/data/dataloader/mix_sampler.py
+74
-0
未找到文件。
ppcls/data/dataloader/__init__.py
浏览文件 @
b6144fb7
from
ppcls.data.dataloader.imagenet_dataset
import
ImageNetDataset
from
ppcls.data.dataloader.multilabel_dataset
import
MultiLabelDataset
from
ppcls.data.dataloader.common_dataset
import
create_operators
from
ppcls.data.dataloader.vehicle_dataset
import
CompCars
,
VeriWild
from
ppcls.data.dataloader.logo_dataset
import
LogoDataset
from
ppcls.data.dataloader.icartoon_dataset
import
ICartoonDataset
ppcls/data/dataloader/mix_dataset.py
0 → 100644
浏览文件 @
b6144fb7
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
numpy
as
np
import
os
from
paddle.io
import
Dataset
from
..
import
dataloader
class MixDataset(Dataset):
    """Expose several datasets as one dataset with a contiguous index range.

    Every entry of ``datasets_config`` is a mapping whose ``name`` key selects
    an attribute of the ``dataloader`` package; the remaining keys are passed
    to it as keyword arguments. The resulting datasets are laid end to end in
    configuration order.

    Args:
        datasets_config: iterable of mappings, each with a ``name`` key plus
            the keyword arguments for that dataset.
    """

    def __init__(self, datasets_config):
        super(MixDataset, self).__init__()
        # Each entry is [end_idx, start_idx, dataset]; end_idx is exclusive.
        self.dataset_list = []
        start_idx = 0
        for config_i in datasets_config:
            # Work on a shallow copy so that popping 'name' does not alter
            # the caller's configuration (it may be reused later).
            dataset_kwargs = dict(config_i)
            dataset_name = dataset_kwargs.pop('name')
            dataset = getattr(dataloader, dataset_name)(**dataset_kwargs)
            end_idx = start_idx + len(dataset)
            self.dataset_list.append([end_idx, start_idx, dataset])
            start_idx = end_idx
        # Total number of samples over all sub-datasets.
        self.length = start_idx

    def __getitem__(self, idx):
        """Return the sample stored at global index ``idx``.

        Negative indices count from the end of the concatenated range.

        Raises:
            IndexError: if ``idx`` lies outside ``[-len(self), len(self))``.
        """
        global_idx = idx + self.length if idx < 0 else idx
        if not 0 <= global_idx < self.length:
            raise IndexError(
                "MixDataset index {} out of range for length {}".format(
                    idx, self.length))
        for end_idx, start_idx, dataset in self.dataset_list:
            if end_idx > global_idx:
                # Translate the global index into the sub-dataset's own range.
                return dataset[global_idx - start_idx]
        # Only reachable if dataset_list was modified after construction.
        raise IndexError(
            "MixDataset index {} not covered by any dataset".format(idx))

    def __len__(self):
        """Return the total number of samples over all sub-datasets."""
        return self.length

    def get_dataset_list(self):
        """Return the list of ``[end_idx, start_idx, dataset]`` entries."""
        return self.dataset_list
ppcls/data/dataloader/mix_sampler.py
0 → 100644
浏览文件 @
b6144fb7
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
paddle.io
import
DistributedBatchSampler
,
Sampler
from
ppcls.utils
import
logger
from
ppcls.data.dataloader.mix_dataset
import
MixDataset
from
ppcls.data
import
dataloader
class MixSampler(DistributedBatchSampler):
    """Batch sampler that builds every batch from several sub-samplers.

    One sub-sampler is created per sub-dataset of a ``MixDataset``. Each
    sub-sampler contributes a fixed share of every batch, and the indices it
    yields are shifted by that sub-dataset's start offset so they address the
    ``MixDataset`` as a whole.

    Args:
        dataset: the ``MixDataset`` to sample from.
        batch_size: total number of indices in every yielded batch.
        sample_configs: sequence of mappings, one per sub-dataset and in the
            same order. Each needs a ``name`` key (``"DistributedBatchSampler"``
            or an attribute of the ``dataloader`` package) and a ``ratio`` key
            (share of ``batch_size``); other keys are passed to the sampler.
            The last entry receives whatever batch size is left over, so the
            per-sampler sizes always add up to ``batch_size``.
        iter_per_epoch: base value for the sampler length.
    """

    def __init__(self, dataset, batch_size, sample_configs, iter_per_epoch):
        super(MixSampler, self).__init__(dataset, batch_size)
        assert isinstance(dataset,
                          MixDataset), "MixSampler only support MixDataset"
        self.sampler_list = []
        self.batch_size = batch_size
        # Global start offset of each sub-dataset, used to shift indices.
        self.start_list = []
        self.length = iter_per_epoch
        dataset_list = dataset.get_dataset_list()
        batch_size_left = self.batch_size
        self.iter_list = []
        last_i = len(sample_configs) - 1
        for i, config_i in enumerate(sample_configs):
            _, start_idx, dataset_i = dataset_list[i]
            # Without this offset __iter__ cannot map local indices to the
            # global index range of the MixDataset.
            self.start_list.append(start_idx)
            # Copy so that the caller's configuration is left untouched.
            sampler_kwargs = dict(config_i)
            sample_method = sampler_kwargs.pop("name")
            ratio_i = sampler_kwargs.pop("ratio")
            if i < last_i:
                # A batch size is a count: truncate the fractional share.
                batch_size_i = int(self.batch_size * ratio_i)
                batch_size_left -= batch_size_i
            else:
                # The last sampler absorbs the remainder.
                batch_size_i = batch_size_left
            assert batch_size_i <= len(dataset_i)
            sampler_kwargs["batch_size"] = batch_size_i
            if sample_method == "DistributedBatchSampler":
                sampler_i = DistributedBatchSampler(dataset_i,
                                                    **sampler_kwargs)
            else:
                sampler_i = getattr(dataloader, sample_method)(
                    dataset_i, **sampler_kwargs)
            self.sampler_list.append(sampler_i)
            self.iter_list.append(iter(sampler_i))
            # NOTE(review): this adds a sample count scaled by the ratio to
            # an iteration count; kept as in the original -- confirm intent.
            self.length += len(dataset_i) * ratio_i
        # __len__ must return an int. Round up so the value equals the number
        # of batches the ``iter_counter < length`` loop in __iter__ yields.
        whole_length = int(self.length)
        if whole_length < self.length:
            whole_length += 1
        self.length = whole_length
        self.iter_counter = 0

    def __iter__(self):
        """Yield ``len(self)`` batches of global indices.

        A sub-sampler that runs out is restarted. A combined batch whose size
        differs from ``batch_size`` is dropped and does not count towards the
        number of yielded batches.
        """
        while self.iter_counter < self.length:
            batch = []
            for i, iter_i in enumerate(self.iter_list):
                batch_i = next(iter_i, None)
                if batch_i is None:
                    # This sub-sampler is exhausted: start a new pass.
                    iter_i = iter(self.sampler_list[i])
                    self.iter_list[i] = iter_i
                    batch_i = next(iter_i, None)
                    assert batch_i is not None, "dataset {} return None".format(
                        i)
                # Shift local indices into the MixDataset index range.
                batch += [idx + self.start_list[i] for idx in batch_i]
            if len(batch) == self.batch_size:
                self.iter_counter += 1
                yield batch
            else:
                logger.info("Some dataset reaches end")
        # Reset so that the sampler can be iterated again.
        self.iter_counter = 0

    def __len__(self):
        """Return the number of batches yielded per iteration pass."""
        return self.length
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录