Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
d32846e8
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
接近 2 年 前同步成功
通知
116
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d32846e8
编写于
5月 31, 2021
作者:
F
Felix
提交者:
GitHub
5月 31, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update __init__.py
上级
806ec9a1
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
122 addition
and
2 deletion
+122
-2
ppcls/data/__init__.py
ppcls/data/__init__.py
+122
-2
未找到文件。
ppcls/data/__init__.py
浏览文件 @
d32846e8
# Copyright (c) 202
0
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -11,5 +11,125 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
copy
import
paddle
import
numpy
as
np
from
paddle.io
import
DistributedBatchSampler
,
BatchSampler
,
DataLoader
from
.reader
import
Reader
from
ppcls.utils
import
logger
from
.
import
datasets
from
.
import
imaug
from
.
import
samplers
# dataset
from
.datasets.imagenet_dataset
import
ImageNetDataset
from
.dataset.multilabel_dataset
import
MultiLabelDataset
# sampler
from
.samplers
import
DistributedRandomIdentitySampler
from
.preprocess
import
transform
def
create_operators
(
params
):
"""
create operators based on the config
Args:
params(list): a dict list, used to create some operators
"""
assert
isinstance
(
params
,
list
),
(
'operator config should be a list'
)
ops
=
[]
for
operator
in
params
:
print
(
operator
)
assert
isinstance
(
operator
,
dict
)
and
len
(
operator
)
==
1
,
"yaml format error"
op_name
=
list
(
operator
)[
0
]
param
=
{}
if
operator
[
op_name
]
is
None
else
operator
[
op_name
]
op
=
getattr
(
preprocess
,
op_name
)(
**
param
)
ops
.
append
(
op
)
return
ops
def
build_dataloader
(
config
,
mode
,
device
,
seed
=
None
):
assert
mode
in
[
'Train'
,
'Eval'
,
'Test'
],
"Mode should be Train, Eval or Test."
# build dataset
config_dataset
=
config
[
mode
][
'dataset'
]
dataset_name
=
config_dataset
.
pop
(
'name'
)
if
'batch_transform_ops'
in
config_dataset
:
batch_transform
=
config_dataset
.
pop
(
'batch_transform_ops'
)
else
:
batch_transform
=
None
dataset
=
eval
(
dataset_name
)(
**
config_dataset
)
logger
.
info
(
"build dataset({}) success..."
.
format
(
dataset
))
# build sampler
config_sampler
=
config
[
mode
][
'sampler'
]
if
"name"
not
in
config_sampler
:
batch_sampler
=
None
batch_size
=
config_sampler
[
"batch_size"
]
drop_last
=
config_sampler
[
"drop_last"
]
shuffle
=
config_sampler
[
"shuffle"
]
else
:
sampler_name
=
config_sampler
.
pop
(
"name"
)
batch_sampler
=
eval
(
sampler_name
)(
dataset
,
**
config_sampler
)
logger
.
info
(
"build batch_sampler({}) success..."
.
format
(
batch_sampler
))
# build batch operator
def
mix_collate_fn
(
batch
):
batch
=
transform
(
batch
,
batch_ops
)
# batch each field
slots
=
[]
for
items
in
batch
:
for
i
,
item
in
enumerate
(
items
):
if
len
(
slots
)
<
len
(
items
):
slots
.
append
([
item
])
else
:
slots
[
i
].
append
(
item
)
return
[
np
.
stack
(
slot
,
axis
=
0
)
for
slot
in
slots
]
if
isinstance
(
batch_transform
,
list
):
batch_ops
=
create_operators
(
batch_transform
)
batch_collate_fn
=
mix_collate_fn
else
:
batch_collate_fn
=
None
# build dataloader
config_loader
=
config
[
mode
][
'loader'
]
num_workers
=
config_loader
[
"num_workers"
]
use_shared_memory
=
config_loader
[
"use_shared_memory"
]
if
batch_sampler
is
None
:
data_loader
=
DataLoader
(
dataset
=
dataset
,
places
=
device
,
num_workers
=
num_workers
,
return_list
=
True
,
use_shared_memory
=
use_shared_memory
,
batch_size
=
batch_size
,
shuffle
=
shuffle
,
drop_last
=
drop_last
,
collate_fn
=
batch_collate_fn
)
else
:
data_loader
=
DataLoader
(
dataset
=
dataset
,
places
=
device
,
num_workers
=
num_workers
,
return_list
=
True
,
use_shared_memory
=
use_shared_memory
,
batch_sampler
=
batch_sampler
,
collate_fn
=
batch_collate_fn
)
logger
.
info
(
"build data_loader({}) success..."
.
format
(
data_loader
))
'''
# TODO: fix the format
def build_dataloader(config, mode, device, seed=None):
from . import reader
from .reader import Reader
dataloader = Reader(config, mode=mode, places=device)()
return dataloader
return data_loader
'''
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录