PaddlePaddle / PaddleClas
Commit f2dde176
Authored December 16, 2021 by sibo2rr

multi scale sampler and dataset

Parent: 9e9a77f3

Showing 5 changed files with 362 additions and 0 deletions
ppcls/configs/multi_scale/MobileNetV1_multi_scale.yaml    +134  -0
ppcls/data/__init__.py                                    +2    -0
ppcls/data/dataloader/__init__.py                         +2    -0
ppcls/data/dataloader/multi_scale_dataset.py              +119  -0
ppcls/data/dataloader/multi_scale_sampler.py              +105  -0
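
The two new pieces are meant to be used together: MultiScaleSamplerDDP yields batches of (height, width, sample_index) tuples, and MultiScaleDataset consumes one such tuple per sample. A minimal wiring sketch, assuming the ILSVRC2012 paths and transforms from the config below; in practice the loader is assembled by PaddleClas's build_dataloader in ppcls/data/__init__.py, which this commit only extends with imports:

import paddle
from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSamplerDDP

# transform_ops mirroring the Train section of the YAML config below
transform_ops = [
    {'DecodeImage': {'to_rgb': True, 'channel_first': False}},
    {'RandCropImage': {'size': 224}},
    {'RandFlipImage': {'flip_code': 1}},
    {'NormalizeImage': {'scale': 1.0 / 255.0,
                        'mean': [0.485, 0.456, 0.406],
                        'std': [0.229, 0.224, 0.225],
                        'order': ''}},
]

dataset = MultiScaleDataset(
    image_root="./dataset/ILSVRC2012/",
    cls_label_path="./dataset/ILSVRC2012/train_list.txt",
    transform_ops=transform_ops)

# the sampler only needs the dataset for its length
sampler = MultiScaleSamplerDDP(
    data_source=dataset, scales=[224, 256], first_bs=4, is_training=True)

# used as a batch_sampler, it hands each (h, w, index) tuple to dataset.__getitem__
loader = paddle.io.DataLoader(
    dataset, batch_sampler=sampler, num_workers=4, use_shared_memory=True)

for images, labels in loader:
    ...  # one multi-scale batch per iteration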
ppcls/configs/multi_scale/MobileNetV1_multi_scale.yaml (new file, mode 100644)
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 120
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  # training model under @to_static
  to_static: False
  use_dali: True

# model architecture
Arch:
  name: MobileNetV1
  class_num: 100

# loss function config for training/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Piecewise
    learning_rate: 0.1
    decay_epochs: [30, 60, 90]
    values: [0.1, 0.01, 0.001, 0.0001]
  regularizer:
    name: 'L2'
    coeff: 0.00003

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: MultiScaleDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''

    sampler:
      name: MultiScaleSamplerDDP
      scales: [224, 256]
      first_bs: 4
      is_training: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: docs/images/whl/demo.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: Topk
    topk: 5
    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt

Metric:
  Train:
    - TopkAcc:
        topk: [1, 5]
  Eval:
    - TopkAcc:
        topk: [1, 5]
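
A note on how this sampler config behaves: MultiScaleSamplerDDP (added below) keeps the per-batch pixel budget roughly constant, so first_bs: 4 only applies at the base scale 224 and larger scales get proportionally smaller batches. A small worked example of that rule, mirroring the arithmetic in multi_scale_sampler.py:

scales = [224, 256]
first_bs = 4

# pixel budget of one batch at the base scale
base_elements = scales[0] * scales[0] * first_bs

for s in scales:
    print(s, int(max(1, base_elements / (s * s))))
# 224 -> 4 images per batch, 256 -> 3 images per batch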
ppcls/data/__init__.py

@@ -28,11 +28,13 @@ from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
 from ppcls.data.dataloader.logo_dataset import LogoDataset
 from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
 from ppcls.data.dataloader.mix_dataset import MixDataset
+from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
 # sampler
 from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
 from ppcls.data.dataloader.pk_sampler import PKSampler
 from ppcls.data.dataloader.mix_sampler import MixSampler
+from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSamplerDDP
 from ppcls.data import preprocess
 from ppcls.data.preprocess import transform
ppcls/data/dataloader/__init__.py

@@ -5,5 +5,7 @@ from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
 from ppcls.data.dataloader.logo_dataset import LogoDataset
 from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
 from ppcls.data.dataloader.mix_dataset import MixDataset
+from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
 from ppcls.data.dataloader.mix_sampler import MixSampler
+from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSamplerDDP
 from ppcls.data.dataloader.pk_sampler import PKSampler
ppcls/data/dataloader/multi_scale_dataset.py (new file, mode 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import os

from paddle.io import Dataset
from paddle.vision import transforms
import cv2
import warnings

from ppcls.data import preprocess
from ppcls.data.preprocess import transform
from ppcls.data.preprocess.ops.operators import DecodeImage
from ppcls.utils import logger


def create_operators(params):
    """
    create operators based on the config

    Args:
        params(list): a dict list, used to create some operators
    """
    assert isinstance(params, list), ('operator config should be a list')
    ops = []
    for operator in params:
        assert isinstance(operator,
                          dict) and len(operator) == 1, "yaml format error"
        op_name = list(operator)[0]
        param = {} if operator[op_name] is None else operator[op_name]
        op = getattr(preprocess, op_name)(**param)
        ops.append(op)

    return ops


class MultiScaleDataset(Dataset):
    def __init__(
            self,
            image_root,
            cls_label_path,
            transform_ops=None, ):
        self._img_root = image_root
        self._cls_path = cls_label_path
        self.transform_ops = transform_ops
        # if transform_ops:
        #     self._transform_ops = create_operators(transform_ops)

        self.images = []
        self.labels = []
        self._load_anno()

    def _load_anno(self, seed=None):
        assert os.path.exists(self._cls_path)
        assert os.path.exists(self._img_root)
        self.images = []
        self.labels = []

        with open(self._cls_path) as fd:
            lines = fd.readlines()
            if seed is not None:
                np.random.RandomState(seed).shuffle(lines)
            for l in lines:
                l = l.strip().split(" ")
                self.images.append(os.path.join(self._img_root, l[0]))
                self.labels.append(np.int64(l[1]))
                assert os.path.exists(self.images[-1])

    def __getitem__(self, properties):
        # properties is a tuple that contains (width, height, index)
        img_width = properties[0]
        img_height = properties[1]
        index = properties[2]
        has_crop = False
        if self.transform_ops:
            for i in range(len(self.transform_ops)):
                op = self.transform_ops[i]
                if 'RandCropImage' in op:
                    warnings.warn(
                        "Multi scale dataset will crop image according to the multi scale resolution"
                    )
                    self.transform_ops[i]['RandCropImage'] = {
                        'size': img_width
                    }
                    has_crop = True
        if not has_crop:
            raise RuntimeError("Multi scale dataset requires RandCropImage")
        self._transform_ops = create_operators(self.transform_ops)

        try:
            with open(self.images[index], 'rb') as f:
                img = f.read()
            if self._transform_ops:
                img = transform(img, self._transform_ops)
            img = img.transpose((2, 0, 1))
            return (img, self.labels[index])

        except Exception as ex:
            logger.error("Exception occurred when parsing line: {} with msg: {}".
                         format(self.images[index], ex))
            rnd_idx = np.random.randint(self.__len__())
            # retry with a random sample, keeping the requested resolution so the
            # (width, height, index) protocol of __getitem__ is preserved
            return self.__getitem__((img_width, img_height, rnd_idx))

    def __len__(self):
        return len(self.images)

    @property
    def class_num(self):
        return len(set(self.labels))
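
A usage note on the indexing protocol: unlike a standard map-style dataset, __getitem__ here takes a (width, height, index) tuple rather than a bare integer, and it rewrites the RandCropImage size for that sample before building the operators. A minimal sketch, assuming the ILSVRC2012 layout from the config above is present locally:

from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset

dataset = MultiScaleDataset(
    image_root="./dataset/ILSVRC2012/",
    cls_label_path="./dataset/ILSVRC2012/train_list.txt",
    transform_ops=[
        {'DecodeImage': {'to_rgb': True, 'channel_first': False}},
        {'RandCropImage': {'size': 224}},  # size is overridden per sample
    ])

# crop the first sample at 256x256 instead of the configured 224
img, label = dataset[(256, 256, 0)]
print(img.shape)  # (3, 256, 256) after the CHW transpose in __getitem__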
ppcls/data/dataloader/multi_scale_sampler.py (new file, mode 100644)

from paddle.io import Sampler
import paddle.distributed as dist

import math
import random
import numpy as np

from ppcls import data


class MultiScaleSamplerDDP(Sampler):
    def __init__(self, data_source, scales, first_bs, is_training=True):
        # min. and max. spatial dimensions
        self.data_source = data_source
        self.n_data_samples = len(self.data_source)

        if isinstance(scales[0], tuple):
            width_dims = [i[0] for i in scales]
            height_dims = [i[1] for i in scales]
        elif isinstance(scales[0], int):
            width_dims = scales
            height_dims = scales
        base_im_w = width_dims[0]
        base_im_h = height_dims[0]
        base_batch_size = first_bs

        # Get the GPU and node related information
        num_replicas = dist.get_world_size()
        rank = dist.get_rank()
        # adjust the total samples to avoid batch dropping
        num_samples_per_replica = int(
            math.ceil(self.n_data_samples * 1.0 / num_replicas))

        img_indices = [idx for idx in range(self.n_data_samples)]

        self.shuffle = False
        if is_training:
            # compute the spatial dimensions and corresponding batch size
            # ImageNet models down-sample images by a factor of 32.
            # Ensure that width and height dimensions are multiples of 32.
            width_dims = [int((w // 32) * 32) for w in width_dims]
            height_dims = [int((h // 32) * 32) for h in height_dims]

            img_batch_pairs = list()
            base_elements = base_im_w * base_im_h * base_batch_size
            for (h, w) in zip(height_dims, width_dims):
                batch_size = int(max(1, (base_elements / (h * w))))
                img_batch_pairs.append((h, w, batch_size))
            self.img_batch_pairs = img_batch_pairs
            self.shuffle = True
        else:
            self.img_batch_pairs = [(base_im_h, base_im_w, base_batch_size)]

        self.img_indices = img_indices
        self.n_samples_per_replica = num_samples_per_replica
        self.epoch = 0
        self.rank = rank
        self.num_replicas = num_replicas

        self.batch_list = []
        self.current = 0
        indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
                                          self.num_replicas]
        while self.current < self.n_samples_per_replica:
            curr_h, curr_w, curr_bsz = random.choice(self.img_batch_pairs)

            end_index = min(self.current + curr_bsz,
                            self.n_samples_per_replica)

            batch_ids = indices_rank_i[self.current:end_index]
            n_batch_samples = len(batch_ids)
            if n_batch_samples != curr_bsz:
                batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
            self.current += curr_bsz

            if len(batch_ids) > 0:
                batch = [curr_h, curr_w, len(batch_ids)]
                self.batch_list.append(batch)
        self.length = len(self.batch_list)

    def __iter__(self):
        if self.shuffle:
            random.seed(self.epoch)
            random.shuffle(self.img_indices)
            random.shuffle(self.img_batch_pairs)
            indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
                                              self.num_replicas]
        else:
            indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
                                              self.num_replicas]

        start_index = 0
        for batch_tuple in self.batch_list:
            curr_h, curr_w, curr_bsz = batch_tuple
            end_index = min(start_index + curr_bsz,
                            self.n_samples_per_replica)
            batch_ids = indices_rank_i[start_index:end_index]
            n_batch_samples = len(batch_ids)
            if n_batch_samples != curr_bsz:
                batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
            start_index += curr_bsz

            if len(batch_ids) > 0:
                batch = [(curr_h, curr_w, b_id) for b_id in batch_ids]
                yield batch

    def set_epoch(self, epoch: int):
        self.epoch = epoch

    def __len__(self):
        return self.length
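
The sampler can also be exercised on its own; data_source only needs a length. A small sketch (single process, so dist.get_world_size() falls back to 1):

from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSamplerDDP

# any sized object works as data_source; only its length is used
sampler = MultiScaleSamplerDDP(
    data_source=list(range(32)), scales=[224, 256], first_bs=4, is_training=True)

sampler.set_epoch(0)  # seeds the per-epoch shuffle in __iter__
for batch in sampler:
    # each batch is a list of (height, width, sample_index) tuples,
    # e.g. [(256, 256, 17), (256, 256, 3), (256, 256, 29)]
    print(len(batch), batch[0][:2])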