PaddlePaddle / PaddleClas, commit 3b4f5f4d (unverified)
Authored by littletomatodonkey on Jun 10, 2021; committed via GitHub on Jun 10, 2021.
add distillation and fix some apis (#810)

* fix save load and imagenet dataset
* refine trainer
Parent: b9786424

Showing 11 changed files with 585 additions and 247 deletions (+585, -247):
ppcls/arch/__init__.py (+47, -1)
ppcls/arch/loss_metrics/__init__.py (+0, -91)
ppcls/configs/ImageNet/Distillation/mv3_large_x1_0_distill_mv3_small_x1_0.yaml (+145, -0)
ppcls/data/dataloader/imagenet_dataset.py (+0, -2)
ppcls/engine/trainer.py (+11, -2)
ppcls/loss/__init__.py (+5, -0)
ppcls/loss/celoss.py (+28, -102)
ppcls/metric/__init__.py (+4, -2)
ppcls/metric/metrics.py (+17, -4)
ppcls/utils/download.py (+319, -0)
ppcls/utils/save_load.py (+9, -43)
ppcls/arch/__init__.py

@@ -21,8 +21,9 @@ from . import backbone, gears
 from .backbone import *
 from .gears import build_gear
 from .utils import *
+from ppcls.utils.save_load import load_dygraph_pretrain

-__all__ = ["build_model", "RecModel"]
+__all__ = ["build_model", "RecModel", "DistillationModel"]


 def build_model(config):
@@ -62,3 +63,48 @@ class RecModel(nn.Layer):
         else:
             y = None
         return {"features": x, "logits": y}
+
+
+class DistillationModel(nn.Layer):
+    def __init__(self,
+                 models=None,
+                 pretrained_list=None,
+                 freeze_params_list=None):
+        super().__init__()
+        assert isinstance(models, list)
+        self.model_list = []
+        self.model_name_list = []
+        if pretrained_list is not None:
+            assert len(pretrained_list) == len(models)
+
+        if freeze_params_list is None:
+            freeze_params_list = [False] * len(models)
+        assert len(freeze_params_list) == len(models)
+
+        for idx, model_config in enumerate(models):
+            assert len(model_config) == 1
+            key = list(model_config.keys())[0]
+            model_config = model_config[key]
+            print(model_config)
+            model_name = model_config.pop("name")
+            model = eval(model_name)(**model_config)
+            if freeze_params_list[idx]:
+                for param in model.parameters():
+                    param.trainable = False
+            self.model_list.append(self.add_sublayer(key, model))
+            self.model_name_list.append(key)
+
+        if pretrained_list is not None:
+            for idx, pretrained in enumerate(pretrained_list):
+                if pretrained is not None:
+                    load_dygraph_pretrain(
+                        self.model_name_list[idx], path=pretrained)
+
+    def forward(self, x, label=None):
+        result_dict = dict()
+        for idx, model_name in enumerate(self.model_name_list):
+            if label is None:
+                result_dict[model_name] = self.model_list[idx](x)
+            else:
+                result_dict[model_name] = self.model_list[idx](x)
+        return result_dict
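The new DistillationModel registers every entry of models as a named sublayer, and its forward returns a dict keyed by those names. A minimal usage sketch (not part of the commit; the backbone names and kwargs mirror the config file added below):

import paddle
from ppcls.arch import DistillationModel

model = DistillationModel(
    models=[
        {"Teacher": {"name": "MobileNetV3_large_x1_0", "pretrained": True}},
        {"Student": {"name": "MobileNetV3_small_x1_0", "pretrained": False}},
    ],
    freeze_params_list=[True, False])  # freeze the teacher, train the student

out = model(paddle.rand([1, 3, 224, 224]))
# out == {"Teacher": ..., "Student": ...}; the distillation losses and
# metrics below pick outputs out of this dict by model name.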
ppcls/arch/loss_metrics/__init__.py (deleted, 100644 → 0)

-#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
-#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
-
-import copy
-import sys
-
-import paddle
-import paddle.nn as nn
-import paddle.nn.functional as F
-
-
-# TODO: fix the format
-class CELoss(nn.Layer):
-    """
-    """
-
-    def __init__(self, name="loss", epsilon=None):
-        super().__init__()
-        self.name = name
-        if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
-            epsilon = None
-        self.epsilon = epsilon
-
-    def _labelsmoothing(self, target, class_num):
-        if target.shape[-1] != class_num:
-            one_hot_target = F.one_hot(target, class_num)
-        else:
-            one_hot_target = target
-        soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
-        soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
-        return soft_target
-
-    def forward(self, logits, label, mode="train"):
-        loss_dict = {}
-        if self.epsilon is not None:
-            class_num = logits.shape[-1]
-            label = self._labelsmoothing(label, class_num)
-            x = -F.log_softmax(logits, axis=-1)
-            loss = paddle.sum(logits * label, axis=-1)
-        else:
-            if label.shape[-1] == logits.shape[-1]:
-                label = F.softmax(label, axis=-1)
-                soft_label = True
-            else:
-                soft_label = False
-            loss = F.cross_entropy(logits, label=label, soft_label=soft_label)
-        loss_dict[self.name] = paddle.mean(loss)
-        return loss_dict
-
-
-# TODO: fix the format
-class Topk(nn.Layer):
-    def __init__(self, topk=[1, 5]):
-        super().__init__()
-        assert isinstance(topk, (int, list))
-        if isinstance(topk, int):
-            topk = [topk]
-        self.topk = topk
-
-    def forward(self, x, label):
-        if isinstance(x, dict):
-            x = x["logits"]
-        metric_dict = dict()
-        for k in self.topk:
-            metric_dict["top{}".format(k)] = paddle.metric.accuracy(
-                x, label, k=k)
-        return metric_dict
-
-
-# TODO: fix the format
-def build_loss(config):
-    loss_func = CELoss()
-    return loss_func
-
-
-# TODO: fix the format
-def build_metrics(config):
-    metrics_func = Topk()
-    return metrics_func
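The CELoss, Topk, build_loss, and build_metrics helpers removed here are superseded by the dedicated ppcls/loss and ppcls/metric packages updated below in this same commit.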
ppcls/configs/ImageNet/Distillation/mv3_large_x1_0_distill_mv3_small_x1_0.yaml (new file, 0 → 100644)

# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: "./output/"
  device: "gpu"
  class_num: 1000
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 120
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: "./inference"

# model architecture
Arch:
  name: "DistillationModel"
  # if not null, its length should be the same as models
  pretrained_list:
  # if not null, its length should be the same as models
  freeze_params_list:
  - True
  - False
  models:
    - Teacher:
        name: MobileNetV3_large_x1_0
        pretrained: True
        use_ssld: True
    - Student:
        name: MobileNetV3_small_x1_0
        pretrained: False

# loss function config for training/eval process
Loss:
  Train:
    - DistillationCELoss:
        weight: 1.0
        model_name_pairs:
        - ["Student", "Teacher"]
  Eval:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Cosine
    learning_rate: 1.3
    warmup_epoch: 5
  regularizer:
    name: 'L2'
    coeff: 0.00001

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: "./dataset/ILSVRC2012/"
      cls_label_path: "./dataset/ILSVRC2012/train_list.txt"
      transform_ops:
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - AutoAugment:
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 512
      drop_last: False
      shuffle: True
    loader:
      num_workers: 6
      use_shared_memory: True

  Eval:
    # TODO: modify to the latest trainer
    dataset:
      name: ImageNetDataset
      image_root: "./dataset/ILSVRC2012/"
      cls_label_path: "./dataset/ILSVRC2012/val_list.txt"
      transform_ops:
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: False
    loader:
      num_workers: 6
      use_shared_memory: True

Infer:
  infer_imgs: "docs/images/whl/demo.jpg"
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: Topk
    topk: 5
    class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"

Metric:
  Train:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
  Eval:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
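Under Loss.Train, DistillationCELoss pairs the Student output with the Teacher output by name. The class itself lives in ppcls/loss/distillationloss.py, which this commit only imports; judging from the soft-label branch of the rewritten CELoss below, the ["Student", "Teacher"] pair plausibly reduces to cross entropy against softmax-normalized teacher logits, roughly:

import paddle
import paddle.nn.functional as F

student_logits = paddle.rand([4, 1000])
teacher_logits = paddle.rand([4, 1000])

# The teacher logits have the same width as the student logits, so the
# label is treated as a soft distribution after a softmax.
soft_label = F.softmax(teacher_logits, axis=-1)
loss = F.cross_entropy(student_logits, label=soft_label, soft_label=True)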
ppcls/data/dataloader/imagenet_dataset.py

@@ -31,8 +31,6 @@ class ImageNetDataset(CommonDataset):
             lines = fd.readlines()
             if seed is not None:
                 np.random.RandomState(seed).shuffle(lines)
-            else:
-                np.random.shuffle(lines)
             for l in lines:
                 l = l.strip().split(" ")
                 self.images.append(os.path.join(self._img_root, l[0]))
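With the unconditional shuffle removed, the label list keeps file order unless a seed is passed explicitly, presumably so sample order is reproducible when no seed is given. A quick sketch of the surviving branch (the sample lines are illustrative):

import numpy as np

lines = ["n01440764/a.jpg 0\n", "n01440764/b.jpg 0\n", "n01530575/c.jpg 1\n"]
seed = 42
if seed is not None:
    np.random.RandomState(seed).shuffle(lines)  # deterministic for a given seed
print(lines)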
ppcls/engine/trainer.py

@@ -235,6 +235,8 @@ class Trainer(object):
                         self.output_dir,
                         model_name=self.config["Arch"]["name"],
                         prefix="best_model")
                 logger.info("[Eval][Epoch {}][best metric: {}]".format(
                     epoch_id, acc))
+            self.model.train()
+
             # save model
@@ -245,14 +247,21 @@
                         "epoch": epoch_id
                     },
                     self.output_dir,
                     model_name=self.config["Arch"]["name"],
-                    prefix="ppcls_epoch_{}".format(epoch_id))
+                    prefix="epoch_{}".format(epoch_id))
+                # save the latest model
+                save_load.save_model(
+                    self.model, optimizer, {"metric": acc,
+                                            "epoch": epoch_id},
+                    self.output_dir,
+                    model_name=self.config["Arch"]["name"],
+                    prefix="latest")

     def build_avg_metrics(self, info_dict):
         return {key: AverageMeter(key, '7.5f') for key in info_dict}

     @paddle.no_grad()
     def eval(self, epoch_id=0):
         self.model.eval()
         if self.eval_loss_func is None:
             loss_config = self.config.get("Loss", None)
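The epoch prefix drops "ppcls_", and a "latest" checkpoint is now written every epoch alongside it. A sketch of the names the trainer now hands to save_load.save_model (assuming, as the model_name argument suggests, that checkpoints are grouped per architecture under output_dir; save_model itself is unchanged by this commit):

import os

output_dir = "./output"            # Global.output_dir in the config
model_name = "DistillationModel"   # config["Arch"]["name"]
for prefix in ["best_model", "epoch_12", "latest"]:
    print(os.path.join(output_dir, model_name, prefix))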
ppcls/loss/__init__.py

@@ -13,7 +13,12 @@ from .trihardloss import TriHardLoss
 from .triplet import TripletLoss, TripletLossV2
 from .supconloss import SupConLoss
 from .pairwisecosface import PairwiseCosface
+from .dmlloss import DMLLoss
+from .distanceloss import DistanceLoss
+from .distillationloss import DistillationCELoss
+from .distillationloss import DistillationGTCELoss
+from .distillationloss import DistillationDMLLoss


 class CombinedLoss(nn.Layer):
ppcls/loss/celoss.py

-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,113 +13,39 @@
 # limitations under the License.

 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F

-__all__ = ['CELoss', 'JSDivLoss', 'KLDivLoss']
-
-
-class Loss(object):
-    """
-    Loss
-    """
-
-    def __init__(self, class_dim=1000, epsilon=None):
-        assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
-        self._class_dim = class_dim
-        if epsilon is not None and epsilon >= 0.0 and epsilon <= 1.0:
-            self._epsilon = epsilon
-            self._label_smoothing = True  #use label smoothing.(Actually, it is softmax label)
-        else:
-            self._epsilon = None
-            self._label_smoothing = False  #do label_smoothing
-
-    def _labelsmoothing(self, target):
-        if target.shape[-1] != self._class_dim:
-            one_hot_target = F.one_hot(target, self._class_dim)  #do ont hot(23,34,46)-> 3 * _class_dim
-        else:
-            one_hot_target = target
-        #do label_smooth
-        soft_target = F.label_smooth(one_hot_target, epsilon=self._epsilon)  #(1 - epsilon) * input + eposilon / K.
-        soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
-        return soft_target
-
-    def _crossentropy(self, input, target, use_pure_fp16=False):
-        if self._label_smoothing:
-            target = self._labelsmoothing(target)
-            input = -F.log_softmax(input, axis=-1)  #softmax and do log
-            cost = paddle.sum(target * input, axis=-1)  #sum
-        else:
-            cost = F.cross_entropy(input=input, label=target)
-        if use_pure_fp16:
-            avg_cost = paddle.sum(cost)
-        else:
-            avg_cost = paddle.mean(cost)
-        return avg_cost
-
-    def _kldiv(self, input, target, name=None):
-        eps = 1.0e-10
-        cost = target * paddle.log(
-            (target + eps) / (input + eps)) * self._class_dim
-        return cost
-
-    def _jsdiv(self, input, target):  #so the input and target is the fc output; no softmax
-        input = F.softmax(input)
-        target = F.softmax(target)  #two distribution
-        cost = self._kldiv(input, target) + self._kldiv(target, input)
-        cost = cost / 2
-        avg_cost = paddle.mean(cost)
-        return avg_cost
-
-    def __call__(self, input, target):
-        pass
-
-
-class CELoss(Loss):
-    """
-    Cross entropy loss
-    """
-
-    def __init__(self, class_dim=1000, epsilon=None):
-        super(CELoss, self).__init__(class_dim, epsilon)
-
-    def __call__(self, input, target, use_pure_fp16=False):
-        if type(input) is dict:
-            logits = input["logits"]
-        else:
-            logits = input
-        cost = self._crossentropy(logits, target, use_pure_fp16)
-        return {"CELoss": cost}
-
-
-class JSDivLoss(Loss):
-    """
-    JSDiv loss
-    """
-
-    def __init__(self, class_dim=1000, epsilon=None):
-        super(JSDivLoss, self).__init__(class_dim, epsilon)
-
-    def __call__(self, input, target):
-        cost = self._jsdiv(input, target)
-        return cost
-
-
-class KLDivLoss(paddle.nn.Layer):
-    def __init__(self):
-        super(KLDivLoss, self).__init__()
-
-    def __call__(self, p, q, is_logit=True):
-        if is_logit:
-            p = paddle.nn.functional.softmax(p)
-            q = paddle.nn.functional.softmax(q)
-        return -(p * paddle.log(q + 1e-8)).sum(1).mean()
+
+class CELoss(nn.Layer):
+    def __init__(self, epsilon=None):
+        super().__init__()
+        if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
+            epsilon = None
+        self.epsilon = epsilon
+
+    def _labelsmoothing(self, target, class_num):
+        if target.shape[-1] != class_num:
+            one_hot_target = F.one_hot(target, class_num)
+        else:
+            one_hot_target = target
+        soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
+        soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
+        return soft_target
+
+    def forward(self, x, label):
+        if isinstance(x, dict):
+            x = x["logits"]
+        if self.epsilon is not None:
+            class_num = x.shape[-1]
+            label = self._labelsmoothing(label, class_num)
+            x = -F.log_softmax(x, axis=-1)
+            loss = paddle.sum(x * label, axis=-1)
+        else:
+            if label.shape[-1] == x.shape[-1]:
+                label = F.softmax(label, axis=-1)
+                soft_label = True
+            else:
+                soft_label = False
+            loss = F.cross_entropy(x, label=label, soft_label=soft_label)
+        return {"CELoss": loss}
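A minimal sketch of the rewritten CELoss in use (shapes and values are illustrative):

import paddle
from ppcls.loss.celoss import CELoss

loss_fn = CELoss()                        # pass epsilon in (0, 1) for label smoothing
logits = paddle.rand([4, 1000])
labels = paddle.randint(0, 1000, [4, 1])  # hard labels take the soft_label=False path
print(loss_fn(logits, labels))            # {"CELoss": Tensor(...)}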
ppcls/metric/__init__.py

@@ -17,6 +17,8 @@ import copy
 from collections import OrderedDict
 from .metrics import TopkAcc, mAP, mINP, Recallk, RetriMetric
+from .metrics import DistillationTopkAcc
+

 class CombinedMetrics(nn.Layer):
     def __init__(self, config_list):
@@ -24,7 +26,7 @@ class CombinedMetrics(nn.Layer):
         self.metric_func_list = []
         assert isinstance(config_list, list), (
             'operator config should be a list')
-        self.retri_config = dict()
+        self.retri_config = dict()  # retrieval metrics config
         for config in config_list:
             assert isinstance(config,
@@ -35,7 +37,7 @@ class CombinedMetrics(nn.Layer):
                 continue
             metric_params = config[metric_name]
             self.metric_func_list.append(eval(metric_name)(**metric_params))
         if self.retri_config:
             self.metric_func_list.append(RetriMetric(self.retri_config))
ppcls/metric/metrics.py

@@ -18,7 +18,6 @@ import paddle.nn as nn
 from functools import lru_cache

-# TODO: fix the format
 class TopkAcc(nn.Layer):
     def __init__(self, topk=(1, 5)):
         super().__init__()
@@ -84,6 +83,7 @@ class Recallk(nn.Layer):
             metric_dict["recall{}".format(k)] = all_cmc[k - 1]
         return metric_dict

+# retrieval metrics
 class RetriMetric(nn.Layer):
     def __init__(self, config):
@@ -93,8 +93,8 @@ class RetriMetric(nn.Layer):
     def forward(self, similarities_matrix, query_img_id, gallery_img_id):
         metric_dict = dict()
-        all_cmc, all_AP, all_INP = get_metrics(similarities_matrix, query_img_id,
-                                               gallery_img_id, self.max_rank)
+        all_cmc, all_AP, all_INP = get_metrics(
+            similarities_matrix, query_img_id, gallery_img_id, self.max_rank)
         if "Recallk" in self.config.keys():
             topk = self.config['Recallk']['topk']
             assert isinstance(topk, (int, list, tuple))
@@ -109,7 +109,7 @@ class RetriMetric(nn.Layer):
             mINP = np.mean(all_INP)
             metric_dict["mINP"] = mINP
         return metric_dict

-
+@lru_cache()
 def get_metrics(similarities_matrix, query_img_id, gallery_img_id,
@@ -155,3 +155,16 @@ def get_metrics(similarities_matrix, query_img_id, gallery_img_id,
     all_cmc = all_cmc.sum(0) / num_valid_q

     return all_cmc, all_AP, all_INP
+
+
+class DistillationTopkAcc(TopkAcc):
+    def __init__(self, model_key, feature_key=None, topk=(1, 5)):
+        super().__init__(topk=topk)
+        self.model_key = model_key
+        self.feature_key = feature_key
+
+    def forward(self, x, label):
+        x = x[self.model_key]
+        if self.feature_key is not None:
+            x = x[self.feature_key]
+        return super().forward(x, label)
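Sketch of the new DistillationTopkAcc: it indexes the DistillationModel output dict by model name (and optionally by a feature key) before delegating to TopkAcc, which reports top-k accuracies much like the deleted Topk helper:

import paddle
from ppcls.metric.metrics import DistillationTopkAcc

metric = DistillationTopkAcc(model_key="Student", topk=[1, 5])
preds = {"Student": paddle.rand([8, 1000]),
         "Teacher": paddle.rand([8, 1000])}
labels = paddle.randint(0, 1000, [8, 1])
print(metric(preds, labels))  # e.g. {"top1": ..., "top5": ...}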
ppcls/utils/download.py (new file, 0 → 100644)

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import os.path as osp
import shutil
import requests
import hashlib
import tarfile
import zipfile
import time
from collections import OrderedDict
from tqdm import tqdm

from ppcls.utils import logger

__all__ = ['get_weights_path_from_url']

WEIGHTS_HOME = osp.expanduser("~/.paddleclas/weights")

DOWNLOAD_RETRY_LIMIT = 3


def is_url(path):
    """
    Whether path is URL.
    Args:
        path (string): URL string or not.
    """
    return path.startswith('http://') or path.startswith('https://')


def get_weights_path_from_url(url, md5sum=None):
    """Get weights path from WEIGHT_HOME, if not exists,
    download it from url.
    Args:
        url (str): download url
        md5sum (str): md5 sum of download package
    Returns:
        str: a local path to save downloaded weights.
    Examples:
        .. code-block:: python
            from paddle.utils.download import get_weights_path_from_url
            resnet18_pretrained_weight_url = 'https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams'
            local_weight_path = get_weights_path_from_url(resnet18_pretrained_weight_url)
    """
    path = get_path_from_url(url, WEIGHTS_HOME, md5sum)
    return path


def _map_path(url, root_dir):
    # parse path after download under root_dir
    fname = osp.split(url)[-1]
    fpath = fname
    return osp.join(root_dir, fpath)


def _get_unique_endpoints(trainer_endpoints):
    # Sorting is to avoid different environmental variables for each card
    trainer_endpoints.sort()
    ips = set()
    unique_endpoints = set()
    for endpoint in trainer_endpoints:
        ip = endpoint.split(":")[0]
        if ip in ips:
            continue
        ips.add(ip)
        unique_endpoints.add(endpoint)
    logger.info("unique_endpoints {}".format(unique_endpoints))
    return unique_endpoints


def get_path_from_url(url,
                      root_dir,
                      md5sum=None,
                      check_exist=True,
                      decompress=True):
    """ Download from given url to root_dir.
    if file or directory specified by url is exists under
    root_dir, return the path directly, otherwise download
    from url and decompress it, return the path.
    Args:
        url (str): download url
        root_dir (str): root dir for downloading, it should be
                        WEIGHTS_HOME or DATASET_HOME
        md5sum (str): md5 sum of download package
    Returns:
        str: a local path to save downloaded models & weights & datasets.
    """

    from paddle.fluid.dygraph.parallel import ParallelEnv

    assert is_url(url), "downloading from {} not a url".format(url)
    # parse path after download to decompress under root_dir
    fullpath = _map_path(url, root_dir)
    # Mainly used to solve the problem of downloading data from different
    # machines in the case of multiple machines. Different ips will download
    # data, and the same ip will only download data once.
    unique_endpoints = _get_unique_endpoints(ParallelEnv()
                                             .trainer_endpoints[:])
    if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
        logger.info("Found {}".format(fullpath))
    else:
        if ParallelEnv().current_endpoint in unique_endpoints:
            fullpath = _download(url, root_dir, md5sum)
        else:
            while not os.path.exists(fullpath):
                time.sleep(1)

    if ParallelEnv().current_endpoint in unique_endpoints:
        if decompress and (tarfile.is_tarfile(fullpath) or
                           zipfile.is_zipfile(fullpath)):
            fullpath = _decompress(fullpath)

    return fullpath


def _download(url, path, md5sum=None):
    """
    Download from url, save to path.
    url (str): download url
    path (str): download to given path
    """
    if not osp.exists(path):
        os.makedirs(path)

    fname = osp.split(url)[-1]
    fullname = osp.join(path, fname)
    retry_cnt = 0

    while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
            retry_cnt += 1
        else:
            raise RuntimeError("Download from {} failed. "
                               "Retry limit reached".format(url))

        logger.info("Downloading {} from {}".format(fname, url))

        try:
            req = requests.get(url, stream=True)
        except Exception as e:  # requests.exceptions.ConnectionError
            logger.info(
                "Downloading {} from {} failed {} times with exception {}".
                format(fname, url, retry_cnt + 1, str(e)))
            time.sleep(1)
            continue

        if req.status_code != 200:
            raise RuntimeError("Downloading from {} failed with code "
                               "{}!".format(url, req.status_code))

        # To protect against an interrupted download, download to
        # tmp_fullname first, then move tmp_fullname to fullname
        # once the download finishes
        tmp_fullname = fullname + "_tmp"
        total_size = req.headers.get('content-length')
        with open(tmp_fullname, 'wb') as f:
            if total_size:
                with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
                    for chunk in req.iter_content(chunk_size=1024):
                        f.write(chunk)
                        pbar.update(1)
            else:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        shutil.move(tmp_fullname, fullname)

    return fullname


def _md5check(fullname, md5sum=None):
    if md5sum is None:
        return True

    logger.info("File {} md5 checking...".format(fullname))
    md5 = hashlib.md5()
    with open(fullname, 'rb') as f:
        for chunk in iter(lambda: f.read(4096), b""):
            md5.update(chunk)
    calc_md5sum = md5.hexdigest()

    if calc_md5sum != md5sum:
        logger.info("File {} md5 check failed, {}(calc) != "
                    "{}(base)".format(fullname, calc_md5sum, md5sum))
        return False
    return True


def _decompress(fname):
    """
    Decompress for zip and tar file
    """
    logger.info("Decompressing {}...".format(fname))

    # To protect against an interrupted decompression, decompress to a
    # fpath_tmp directory first; if decompression succeeds, move the
    # decompressed files to fpath, delete fpath_tmp and remove the
    # downloaded compressed file.

    if tarfile.is_tarfile(fname):
        uncompressed_path = _uncompress_file_tar(fname)
    elif zipfile.is_zipfile(fname):
        uncompressed_path = _uncompress_file_zip(fname)
    else:
        raise TypeError("Unsupport compress file type {}".format(fname))

    return uncompressed_path


def _uncompress_file_zip(filepath):
    files = zipfile.ZipFile(filepath, 'r')
    file_list = files.namelist()

    file_dir = os.path.dirname(filepath)

    if _is_a_single_file(file_list):
        rootpath = file_list[0]
        uncompressed_path = os.path.join(file_dir, rootpath)

        for item in file_list:
            files.extract(item, file_dir)

    elif _is_a_single_dir(file_list):
        rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
        uncompressed_path = os.path.join(file_dir, rootpath)

        for item in file_list:
            files.extract(item, file_dir)

    else:
        rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
        uncompressed_path = os.path.join(file_dir, rootpath)
        if not os.path.exists(uncompressed_path):
            os.makedirs(uncompressed_path)
        for item in file_list:
            files.extract(item, os.path.join(file_dir, rootpath))

    files.close()

    return uncompressed_path


def _uncompress_file_tar(filepath, mode="r:*"):
    files = tarfile.open(filepath, mode)
    file_list = files.getnames()

    file_dir = os.path.dirname(filepath)

    if _is_a_single_file(file_list):
        rootpath = file_list[0]
        uncompressed_path = os.path.join(file_dir, rootpath)

        for item in file_list:
            files.extract(item, file_dir)

    elif _is_a_single_dir(file_list):
        rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
        uncompressed_path = os.path.join(file_dir, rootpath)

        for item in file_list:
            files.extract(item, file_dir)

    else:
        rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
        uncompressed_path = os.path.join(file_dir, rootpath)

        if not os.path.exists(uncompressed_path):
            os.makedirs(uncompressed_path)

        for item in file_list:
            files.extract(item, os.path.join(file_dir, rootpath))

    files.close()

    return uncompressed_path


def _is_a_single_file(file_list):
    if len(file_list) == 1 and file_list[0].find(os.sep) < -1:
        return True
    return False


def _is_a_single_dir(file_list):
    new_file_list = []
    for file_path in file_list:
        if '/' in file_path:
            file_path = file_path.replace('/', os.sep)
        elif '\\' in file_path:
            file_path = file_path.replace('\\', os.sep)
        new_file_list.append(file_path)

    file_name = new_file_list[0].split(os.sep)[0]
    for i in range(1, len(new_file_list)):
        if file_name != new_file_list[i].split(os.sep)[0]:
            return False
    return True
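This new module ports paddle's download utility into ppcls, so pretrained weights land under ~/.paddleclas/weights and are reused on later calls. Usage, with the URL taken from the docstring above:

from ppcls.utils.download import get_weights_path_from_url

url = "https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams"
local_path = get_weights_path_from_url(url)  # downloads once, cache hit after
print(local_path)  # e.g. ~/.paddleclas/weights/resnet18.pdparams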
ppcls/utils/save_load.py

@@ -23,10 +23,8 @@ import shutil
 import tempfile

 import paddle
-from paddle.static import load_program_state
-from paddle.utils.download import get_weights_path_from_url
 from ppcls.utils import logger
+from .download import get_weights_path_from_url

 __all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
@@ -47,70 +45,42 @@ def _mkdir_if_not_exist(path):
             raise OSError('Failed to mkdir {}'.format(path))


-def load_dygraph_pretrain(model, path=None, load_static_weights=False):
+def load_dygraph_pretrain(model, path=None):
     if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
         raise ValueError("Model pretrain path {} does not "
                          "exists.".format(path))
-    if load_static_weights:
-        pre_state_dict = load_program_state(path)
-        param_state_dict = {}
-        model_dict = model.state_dict()
-        for key in model_dict.keys():
-            weight_name = model_dict[key].name
-            if weight_name in pre_state_dict.keys():
-                logger.info('Load weight: {}, shape: {}'.format(
-                    weight_name, pre_state_dict[weight_name].shape))
-                param_state_dict[key] = pre_state_dict[weight_name]
-            else:
-                param_state_dict[key] = model_dict[key]
-        model.set_dict(param_state_dict)
-        return
     param_state_dict = paddle.load(path + ".pdparams")
     model.set_dict(param_state_dict)
     return


-def load_dygraph_pretrain_from_url(model,
-                                   pretrained_url,
-                                   use_ssld,
-                                   load_static_weights=False):
+def load_dygraph_pretrain_from_url(model, pretrained_url, use_ssld):
     if use_ssld:
         pretrained_url = pretrained_url.replace("_pretrained",
                                                 "_ssld_pretrained")
     local_weight_path = get_weights_path_from_url(pretrained_url).replace(
         ".pdparams", "")
-    load_dygraph_pretrain(
-        model, path=local_weight_path, load_static_weights=load_static_weights)
+    load_dygraph_pretrain(model, path=local_weight_path)
     return


-def load_distillation_model(model, pretrained_model, load_static_weights):
+def load_distillation_model(model, pretrained_model):
     logger.info("In distillation mode, teacher model will be "
                 "loaded firstly before student model.")

     if not isinstance(pretrained_model, list):
         pretrained_model = [pretrained_model]

-    if not isinstance(load_static_weights, list):
-        load_static_weights = [load_static_weights] * len(pretrained_model)

     teacher = model.teacher if hasattr(model,
                                        "teacher") else model._layers.teacher
     student = model.student if hasattr(model,
                                        "student") else model._layers.student
-    load_dygraph_pretrain(
-        teacher,
-        path=pretrained_model[0],
-        load_static_weights=load_static_weights[0])
+    load_dygraph_pretrain(teacher, path=pretrained_model[0])
     logger.info("Finish initing teacher model from {}".format(
         pretrained_model))
     # load student model
     if len(pretrained_model) >= 2:
-        load_dygraph_pretrain(
-            student,
-            path=pretrained_model[1],
-            load_static_weights=load_static_weights[1])
+        load_dygraph_pretrain(student, path=pretrained_model[1])
         logger.info("Finish initing student model from {}".format(
             pretrained_model))
@@ -134,16 +104,12 @@ def init_model(config, net, optimizer=None):
         return metric_dict

     pretrained_model = config.get('pretrained_model')
-    load_static_weights = config.get('load_static_weights', False)
     use_distillation = config.get('use_distillation', False)
     if pretrained_model:
         if use_distillation:
-            load_distillation_model(net, pretrained_model, load_static_weights)
+            load_distillation_model(net, pretrained_model)
         else:  # common load
-            load_dygraph_pretrain(
-                net,
-                path=pretrained_model,
-                load_static_weights=load_static_weights)
+            load_dygraph_pretrain(net, path=pretrained_model)
             logger.info(
                 logger.coloring("Finish load pretrained model from {}".
                                 format(pretrained_model), "HEADER"))
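After this commit, load_dygraph_pretrain handles dygraph weights only; the load_static_weights switch is gone from every call site. A minimal sketch (the checkpoint path is illustrative):

import paddle.nn as nn
from ppcls.utils.save_load import load_dygraph_pretrain

model = nn.Linear(8, 2)  # stands in for any ppcls model
# Expects "./pretrained/linear.pdparams" on disk; the suffix is appended
# inside load_dygraph_pretrain before paddle.load.
load_dygraph_pretrain(model, path="./pretrained/linear")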