Commit 35d2be15 (unverified signature)
Authored by Bin Lu on Aug 25, 2021; committed via GitHub on Aug 25, 2021.

Merge branch 'PaddlePaddle:develop' into develop

Parents: 860e4aa4, 17a06daf

Showing 13 changed files with 860 additions and 86 deletions (+860, −86).
Changed files:

ppcls/arch/backbone/model_zoo/swin_transformer.py    +2    −2
ppcls/engine/engine.py                               +391  −0
ppcls/engine/evaluation/__init__.py                  +16   −0
ppcls/engine/evaluation/classification.py            +114  −0
ppcls/engine/evaluation/retrieval.py                 +154  −0
ppcls/engine/slim/__init__.py                        +0    −0
ppcls/engine/train/__init__.py                       +14   −0
ppcls/engine/train/train.py                          +85   −0
ppcls/engine/train/utils.py                          +72   −0
tools/eval.py                                        +3    −3
tools/export_model.py                                +3    −74
tools/infer.py                                       +3    −4
tools/train.py                                       +3    −3
ppcls/arch/backbone/model_zoo/swin_transformer.py

@@ -33,9 +33,9 @@ MODEL_URLS = {
     "SwinTransformer_base_patch4_window12_384":
     "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_base_patch4_window12_384_pretrained.pdparams",
     "SwinTransformer_large_patch4_window7_224":
-    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_large_patch4_window7_224_pretrained.pdparams",
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_large_patch4_window7_224_22kto1k_pretrained.pdparams",
     "SwinTransformer_large_patch4_window12_384":
-    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_large_patch4_window12_384_pretrained.pdparams",
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/SwinTransformer_large_patch4_window12_384_22kto1k_pretrained.pdparams",
 }
 
 __all__ = list(MODEL_URLS.keys())
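Editor's note: the only functional change in this file is that the two SwinTransformer-large entries now point at 22k-to-1k finetuned pretrained weights. A minimal sketch of how these entries are exposed, assuming an installed ppcls package (nothing below is specific to this commit): the module's public names are exactly the MODEL_URLS keys, which is how a backbone name in a config resolves to one of these URLs when pretrained weights are requested.

# Minimal sketch, assuming ppcls is importable in the current environment.
from ppcls.arch.backbone.model_zoo import swin_transformer

print(swin_transformer.__all__)  # the MODEL_URLS keys
print(swin_transformer.MODEL_URLS["SwinTransformer_large_patch4_window7_224"])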
ppcls/engine/trainer.py → ppcls/engine/engine.py (renamed)

This diff is collapsed in the rendered view (+391, −0; see the file list above).
ppcls/engine/evaluation/__init__.py (new file, mode 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ppcls.engine.evaluation.classification import classification_eval
from ppcls.engine.evaluation.retrieval import retrieval_eval
ppcls/engine/evaluation/classification.py (new file, mode 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import platform
import paddle

from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger


def classification_eval(evaler, epoch_id=0):
    output_info = dict()
    time_info = {
        "batch_cost": AverageMeter(
            "batch_cost", '.5f', postfix=" s,"),
        "reader_cost": AverageMeter(
            "reader_cost", ".5f", postfix=" s,"),
    }
    print_batch_step = evaler.config["Global"]["print_batch_step"]

    metric_key = None
    tic = time.time()
    eval_dataloader = evaler.eval_dataloader if evaler.use_dali else evaler.eval_dataloader(
    )
    max_iter = len(evaler.eval_dataloader) - 1 if platform.system(
    ) == "Windows" else len(evaler.eval_dataloader)
    for iter_id, batch in enumerate(eval_dataloader):
        if iter_id >= max_iter:
            break
        if iter_id == 5:
            for key in time_info:
                time_info[key].reset()
        if evaler.use_dali:
            batch = [
                paddle.to_tensor(batch[0]['data']),
                paddle.to_tensor(batch[0]['label'])
            ]
        time_info["reader_cost"].update(time.time() - tic)
        batch_size = batch[0].shape[0]
        batch[0] = paddle.to_tensor(batch[0]).astype("float32")
        batch[1] = batch[1].reshape([-1, 1]).astype("int64")
        # image input
        out = evaler.model(batch[0])
        # calc loss
        if evaler.eval_loss_func is not None:
            loss_dict = evaler.eval_loss_func(out, batch[1])
            for key in loss_dict:
                if key not in output_info:
                    output_info[key] = AverageMeter(key, '7.5f')
                output_info[key].update(loss_dict[key].numpy()[0], batch_size)
        # calc metric
        if evaler.eval_metric_func is not None:
            metric_dict = evaler.eval_metric_func(out, batch[1])
            if paddle.distributed.get_world_size() > 1:
                for key in metric_dict:
                    paddle.distributed.all_reduce(
                        metric_dict[key], op=paddle.distributed.ReduceOp.SUM)
                    metric_dict[key] = metric_dict[
                        key] / paddle.distributed.get_world_size()
            for key in metric_dict:
                if metric_key is None:
                    metric_key = key
                if key not in output_info:
                    output_info[key] = AverageMeter(key, '7.5f')

                output_info[key].update(metric_dict[key].numpy()[0],
                                        batch_size)

        time_info["batch_cost"].update(time.time() - tic)

        if iter_id % print_batch_step == 0:
            time_msg = "s, ".join([
                "{}: {:.5f}".format(key, time_info[key].avg)
                for key in time_info
            ])

            ips_msg = "ips: {:.5f} images/sec".format(
                batch_size / time_info["batch_cost"].avg)

            metric_msg = ", ".join([
                "{}: {:.5f}".format(key, output_info[key].val)
                for key in output_info
            ])
            logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
                epoch_id, iter_id,
                len(evaler.eval_dataloader), metric_msg, time_msg, ips_msg))

        tic = time.time()
    if evaler.use_dali:
        evaler.eval_dataloader.reset()
    metric_msg = ", ".join([
        "{}: {:.5f}".format(key, output_info[key].avg) for key in output_info
    ])
    logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

    # do not try to save best eval.model
    if evaler.eval_metric_func is None:
        return -1
    # return 1st metric in the dict
    return output_info[metric_key].avg
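Editor's note: classification_eval accumulates every loss and metric key into an AverageMeter and reports running averages every print_batch_step iterations. A minimal sketch of that bookkeeping pattern, using only the AverageMeter calls visible in the file above (toy numbers, not from a real run):

from ppcls.utils.misc import AverageMeter

# One meter per key, exactly as output_info[key] is created above.
acc_meter = AverageMeter("top1", '7.5f')
for batch_acc, batch_size in [(0.90, 64), (0.94, 64), (0.92, 32)]:
    acc_meter.update(batch_acc, batch_size)  # weighted by batch size
print(acc_meter.val)  # value from the last update, used in the per-iteration log line
print(acc_meter.avg)  # running average, used in the [Avg] summary line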
ppcls/engine/evaluation/retrieval.py (new file, mode 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import platform
import paddle
from ppcls.utils import logger


def retrieval_eval(evaler, epoch_id=0):
    evaler.model.eval()
    # step1. build gallery
    gallery_feas, gallery_img_id, gallery_unique_id = cal_feature(
        evaler, name='gallery')
    query_feas, query_img_id, query_query_id = cal_feature(
        evaler, name='query')

    # step2. do evaluation
    sim_block_size = evaler.config["Global"].get("sim_block_size", 64)
    sections = [sim_block_size] * (len(query_feas) // sim_block_size)
    if len(query_feas) % sim_block_size:
        sections.append(len(query_feas) % sim_block_size)
    fea_blocks = paddle.split(query_feas, num_or_sections=sections)
    if query_query_id is not None:
        query_id_blocks = paddle.split(
            query_query_id, num_or_sections=sections)
    image_id_blocks = paddle.split(query_img_id, num_or_sections=sections)
    metric_key = None

    if evaler.eval_loss_func is None:
        metric_dict = {metric_key: 0.}
    else:
        metric_dict = dict()
        for block_idx, block_fea in enumerate(fea_blocks):
            similarity_matrix = paddle.matmul(
                block_fea, gallery_feas, transpose_y=True)
            if query_query_id is not None:
                query_id_block = query_id_blocks[block_idx]
                query_id_mask = (query_id_block != gallery_unique_id.t())

                image_id_block = image_id_blocks[block_idx]
                image_id_mask = (image_id_block != gallery_img_id.t())

                keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
                similarity_matrix = similarity_matrix * keep_mask.astype(
                    "float32")
            else:
                keep_mask = None

            metric_tmp = evaler.eval_metric_func(similarity_matrix,
                                                 image_id_blocks[block_idx],
                                                 gallery_img_id, keep_mask)

            for key in metric_tmp:
                if key not in metric_dict:
                    metric_dict[key] = metric_tmp[key] * block_fea.shape[
                        0] / len(query_feas)
                else:
                    metric_dict[key] += metric_tmp[key] * block_fea.shape[
                        0] / len(query_feas)

    metric_info_list = []
    for key in metric_dict:
        if metric_key is None:
            metric_key = key
        metric_info_list.append("{}: {:.5f}".format(key, metric_dict[key]))
    metric_msg = ", ".join(metric_info_list)
    logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

    return metric_dict[metric_key]


def cal_feature(evaler, name='gallery'):
    all_feas = None
    all_image_id = None
    all_unique_id = None
    has_unique_id = False

    if name == 'gallery':
        dataloader = evaler.gallery_dataloader
    elif name == 'query':
        dataloader = evaler.query_dataloader
    else:
        raise RuntimeError("Only support gallery or query dataset")

    max_iter = len(dataloader) - 1 if platform.system(
    ) == "Windows" else len(dataloader)
    dataloader_tmp = dataloader if evaler.use_dali else dataloader()
    for idx, batch in enumerate(dataloader_tmp):  # load is very time-consuming
        if idx >= max_iter:
            break
        if idx % evaler.config["Global"]["print_batch_step"] == 0:
            logger.info(
                f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
            )
        if evaler.use_dali:
            batch = [
                paddle.to_tensor(batch[0]['data']),
                paddle.to_tensor(batch[0]['label'])
            ]
        batch = [paddle.to_tensor(x) for x in batch]
        batch[1] = batch[1].reshape([-1, 1]).astype("int64")
        if len(batch) == 3:
            has_unique_id = True
            batch[2] = batch[2].reshape([-1, 1]).astype("int64")
        out = evaler.model(batch[0], batch[1])
        batch_feas = out["features"]

        # do norm
        if evaler.config["Global"].get("feature_normalize", True):
            feas_norm = paddle.sqrt(
                paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True))
            batch_feas = paddle.divide(batch_feas, feas_norm)

        if all_feas is None:
            all_feas = batch_feas
            if has_unique_id:
                all_unique_id = batch[2]
            all_image_id = batch[1]
        else:
            all_feas = paddle.concat([all_feas, batch_feas])
            all_image_id = paddle.concat([all_image_id, batch[1]])
            if has_unique_id:
                all_unique_id = paddle.concat([all_unique_id, batch[2]])

    if evaler.use_dali:
        dataloader_tmp.reset()

    if paddle.distributed.get_world_size() > 1:
        feat_list = []
        img_id_list = []
        unique_id_list = []
        paddle.distributed.all_gather(feat_list, all_feas)
        paddle.distributed.all_gather(img_id_list, all_image_id)
        all_feas = paddle.concat(feat_list, axis=0)
        all_image_id = paddle.concat(img_id_list, axis=0)
        if has_unique_id:
            paddle.distributed.all_gather(unique_id_list, all_unique_id)
            all_unique_id = paddle.concat(unique_id_list, axis=0)

    logger.info("Build {} done, all feat shape: {}, begin to eval..".format(
        name, all_feas.shape))
    return all_feas, all_image_id, all_unique_id
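Editor's note: retrieval_eval never materialises the full query-by-gallery similarity matrix. Query features are split into blocks of sim_block_size, each block is scored against the whole gallery, and per-block metrics are averaged weighted by block size. A self-contained sketch of that splitting logic with toy tensors (shapes are placeholders):

import paddle

sim_block_size = 64
query_feas = paddle.rand([150, 128])     # toy query features
gallery_feas = paddle.rand([1000, 128])  # toy gallery features

# Same section computation as in retrieval_eval above.
sections = [sim_block_size] * (len(query_feas) // sim_block_size)
if len(query_feas) % sim_block_size:
    sections.append(len(query_feas) % sim_block_size)

for block_fea in paddle.split(query_feas, num_or_sections=sections):
    # One block of the query-by-gallery similarity matrix at a time.
    similarity_matrix = paddle.matmul(block_fea, gallery_feas, transpose_y=True)
    print(similarity_matrix.shape)  # [64, 1000], [64, 1000], [22, 1000]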
deploy/auto_log.log → ppcls/engine/slim/__init__.py (file moved)
ppcls/engine/train/__init__.py (new file, mode 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ppcls.engine.train.train import train_epoch
ppcls/engine/train/train.py (new file, mode 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function

import time
import paddle
from ppcls.engine.train.utils import update_loss, update_metric, log_info


def train_epoch(trainer, epoch_id, print_batch_step):
    tic = time.time()
    train_dataloader = trainer.train_dataloader if trainer.use_dali else trainer.train_dataloader(
    )
    for iter_id, batch in enumerate(train_dataloader):
        if iter_id >= trainer.max_iter:
            break
        if iter_id == 5:
            for key in trainer.time_info:
                trainer.time_info[key].reset()
        trainer.time_info["reader_cost"].update(time.time() - tic)
        if trainer.use_dali:
            batch = [
                paddle.to_tensor(batch[0]['data']),
                paddle.to_tensor(batch[0]['label'])
            ]
        batch_size = batch[0].shape[0]
        batch[1] = batch[1].reshape([-1, 1]).astype("int64")

        trainer.global_step += 1
        # image input
        if trainer.amp:
            with paddle.amp.auto_cast(custom_black_list={
                    "flatten_contiguous_range", "greater_than"
            }):
                out = forward(trainer, batch)
                loss_dict = trainer.train_loss_func(out, batch[1])
        else:
            out = forward(trainer, batch)

            # calc loss
            if trainer.config["DataLoader"]["Train"]["dataset"].get(
                    "batch_transform_ops", None):
                loss_dict = trainer.train_loss_func(out, batch[1:])
            else:
                loss_dict = trainer.train_loss_func(out, batch[1])

        # step opt and lr
        if trainer.amp:
            scaled = trainer.scaler.scale(loss_dict["loss"])
            scaled.backward()
            trainer.scaler.minimize(trainer.optimizer, scaled)
        else:
            loss_dict["loss"].backward()
            trainer.optimizer.step()
        trainer.optimizer.clear_grad()
        trainer.lr_sch.step()

        # below code just for logging
        # update metric_for_logger
        update_metric(trainer, out, batch, batch_size)
        # update_loss_for_logger
        update_loss(trainer, loss_dict, batch_size)
        trainer.time_info["batch_cost"].update(time.time() - tic)
        if iter_id % print_batch_step == 0:
            log_info(trainer, batch_size, epoch_id, iter_id)
        tic = time.time()


def forward(trainer, batch):
    if trainer.eval_mode == "classification":
        return trainer.model(batch[0])
    else:
        return trainer.model(batch[0], batch[1])
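Editor's note: when trainer.amp is set, the forward pass and loss run under paddle.amp.auto_cast and the backward pass goes through the trainer's GradScaler. A minimal, self-contained sketch of that AMP branch with a toy model; the real code takes the model, optimizer, and scaler from the trainer, and the GradScaler construction below is an assumption (engine.py, where it is built, is collapsed in this view):

import paddle

model = paddle.nn.Linear(4, 2)
optimizer = paddle.optimizer.SGD(learning_rate=0.1, parameters=model.parameters())
scaler = paddle.amp.GradScaler()  # assumed stand-in for trainer.scaler

x = paddle.rand([8, 4])
label = paddle.randint(0, 2, [8, 1])

# Mirrors the `if trainer.amp:` branch of train_epoch above.
with paddle.amp.auto_cast(custom_black_list={
        "flatten_contiguous_range", "greater_than"
}):
    out = model(x)
    loss = paddle.nn.functional.cross_entropy(out, label)

scaled = scaler.scale(loss)         # scale the loss to avoid fp16 underflow
scaled.backward()
scaler.minimize(optimizer, scaled)  # unscale gradients and step the optimizer
optimizer.clear_grad()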
ppcls/engine/train/utils.py (new file, mode 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function

import datetime
from ppcls.utils import logger
from ppcls.utils.misc import AverageMeter


def update_metric(trainer, out, batch, batch_size):
    # calc metric
    if trainer.train_metric_func is not None:
        metric_dict = trainer.train_metric_func(out, batch[-1])
        for key in metric_dict:
            if key not in trainer.output_info:
                trainer.output_info[key] = AverageMeter(key, '7.5f')
            trainer.output_info[key].update(metric_dict[key].numpy()[0],
                                            batch_size)


def update_loss(trainer, loss_dict, batch_size):
    # update_output_info
    for key in loss_dict:
        if key not in trainer.output_info:
            trainer.output_info[key] = AverageMeter(key, '7.5f')
        trainer.output_info[key].update(loss_dict[key].numpy()[0], batch_size)


def log_info(trainer, batch_size, epoch_id, iter_id):
    lr_msg = "lr: {:.5f}".format(trainer.lr_sch.get_lr())
    metric_msg = ", ".join([
        "{}: {:.5f}".format(key, trainer.output_info[key].avg)
        for key in trainer.output_info
    ])
    time_msg = "s, ".join([
        "{}: {:.5f}".format(key, trainer.time_info[key].avg)
        for key in trainer.time_info
    ])

    ips_msg = "ips: {:.5f} images/sec".format(
        batch_size / trainer.time_info["batch_cost"].avg)
    eta_sec = ((trainer.config["Global"]["epochs"] - epoch_id + 1
                ) * len(trainer.train_dataloader) - iter_id
               ) * trainer.time_info["batch_cost"].avg
    eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))
    logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".format(
        epoch_id, trainer.config["Global"]["epochs"], iter_id,
        len(trainer.train_dataloader), lr_msg, metric_msg, time_msg, ips_msg,
        eta_msg))

    logger.scaler(
        name="lr",
        value=trainer.lr_sch.get_lr(),
        step=trainer.global_step,
        writer=trainer.vdl_writer)
    for key in trainer.output_info:
        logger.scaler(
            name="train_{}".format(key),
            value=trainer.output_info[key].avg,
            step=trainer.global_step,
            writer=trainer.vdl_writer)
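Editor's note: log_info estimates the remaining training time as (remaining iterations) × (average batch cost). A small worked version of the eta_sec formula above, with toy numbers in place of the trainer state:

import datetime

epochs, epoch_id, iter_id = 120, 3, 250
steps_per_epoch = 1000   # stands in for len(trainer.train_dataloader)
avg_batch_cost = 0.42    # stands in for trainer.time_info["batch_cost"].avg

eta_sec = ((epochs - epoch_id + 1) * steps_per_epoch - iter_id) * avg_batch_cost
print("eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec)))))  # eta: 13:44:15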
tools/eval.py

@@ -21,11 +21,11 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
 
 from ppcls.utils import config
-from ppcls.engine.trainer import Trainer
+from ppcls.engine.engine import Engine
 
 if __name__ == "__main__":
     args = config.parse_args()
     config = config.get_config(
         args.config, overrides=args.override, show=False)
-    trainer = Trainer(config, mode="eval")
-    trainer.eval()
+    engine = Engine(config, mode="eval")
+    engine.eval()
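Editor's note: tools/eval.py (and the other tools/ entry points below) now builds an Engine instead of a Trainer. A minimal sketch of the new entry point outside the CLI wrapper, assuming an installed PaddleClas and a valid YAML config; the config path is a placeholder:

from ppcls.utils import config
from ppcls.engine.engine import Engine

cfg = config.get_config("configs/example.yaml", overrides=None, show=False)  # placeholder path
engine = Engine(cfg, mode="eval")  # "train", "infer" and "export" follow the same pattern
engine.eval()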
tools/export_model.py

@@ -24,82 +24,11 @@ import paddle
 import paddle.nn as nn
 
 from ppcls.utils import config
-from ppcls.utils.logger import init_logger
-from ppcls.utils.config import print_config
-from ppcls.arch import build_model, RecModel, DistillationModel
-from ppcls.utils.save_load import load_dygraph_pretrain
-from ppcls.arch.gears.identity_head import IdentityHead
-
-
-class ExportModel(nn.Layer):
-    """
-    ExportModel: add softmax onto the model
-    """
-
-    def __init__(self, config):
-        super().__init__()
-        self.base_model = build_model(config)
-
-        # we should choose a final model to export
-        if isinstance(self.base_model, DistillationModel):
-            self.infer_model_name = config["infer_model_name"]
-        else:
-            self.infer_model_name = None
-
-        self.infer_output_key = config.get("infer_output_key", None)
-        if self.infer_output_key == "features" and isinstance(self.base_model,
-                                                              RecModel):
-            self.base_model.head = IdentityHead()
-        if config.get("infer_add_softmax", True):
-            self.softmax = nn.Softmax(axis=-1)
-        else:
-            self.softmax = None
-
-    def eval(self):
-        self.training = False
-        for layer in self.sublayers():
-            layer.training = False
-            layer.eval()
-
-    def forward(self, x):
-        x = self.base_model(x)
-        if isinstance(x, list):
-            x = x[0]
-        if self.infer_model_name is not None:
-            x = x[self.infer_model_name]
-        if self.infer_output_key is not None:
-            x = x[self.infer_output_key]
-        if self.softmax is not None:
-            x = self.softmax(x)
-        return x
-
+from ppcls.engine.engine import Engine
 
 if __name__ == "__main__":
     args = config.parse_args()
     config = config.get_config(
         args.config, overrides=args.override, show=False)
-    log_file = os.path.join(config['Global']['output_dir'],
-                            config["Arch"]["name"], "export.log")
-    init_logger(name='root', log_file=log_file)
-    print_config(config)
-
-    # set device
-    assert config["Global"]["device"] in ["cpu", "gpu", "xpu"]
-    device = paddle.set_device(config["Global"]["device"])
-    model = ExportModel(config["Arch"])
-    if config["Global"]["pretrained_model"] is not None:
-        load_dygraph_pretrain(model.base_model,
-                              config["Global"]["pretrained_model"])
-
-    model.eval()
-
-    model = paddle.jit.to_static(
-        model,
-        input_spec=[
-            paddle.static.InputSpec(
-                shape=[None] + config["Global"]["image_shape"],
-                dtype='float32')
-        ])
-    paddle.jit.save(model,
-                    os.path.join(config["Global"]["save_inference_dir"],
-                                 "inference"))
+    engine = Engine(config, mode="export")
+    engine.export()
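Editor's note: the ExportModel wrapper and the inline static-graph export logic are removed here; tools/export_model.py now delegates to Engine(config, mode="export").export(), which presumably performs the export inside the (collapsed) engine.py. For reference, a minimal, self-contained sketch of the to_static + InputSpec + jit.save pattern the deleted code used, with a toy model and a placeholder output path:

import paddle

model = paddle.nn.Linear(3 * 224 * 224, 10)
model.eval()

# Same export pattern as the removed __main__ block above, on a toy layer.
static_model = paddle.jit.to_static(
    model,
    input_spec=[
        paddle.static.InputSpec(shape=[None, 3 * 224 * 224], dtype='float32')
    ])
paddle.jit.save(static_model, "./inference_sketch/inference")  # placeholder path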
tools/infer.py

@@ -21,12 +21,11 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
 
 from ppcls.utils import config
-from ppcls.engine.trainer import Trainer
+from ppcls.engine.engine import Engine
 
 if __name__ == "__main__":
     args = config.parse_args()
     config = config.get_config(
         args.config, overrides=args.override, show=False)
-    trainer = Trainer(config, mode="infer")
-    trainer.infer()
+    engine = Engine(config, mode="infer")
+    engine.infer()
tools/train.py

@@ -21,11 +21,11 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
 
 from ppcls.utils import config
-from ppcls.engine.trainer import Trainer
+from ppcls.engine.engine import Engine
 
 if __name__ == "__main__":
     args = config.parse_args()
     config = config.get_config(
         args.config, overrides=args.override, show=False)
-    trainer = Trainer(config, mode="train")
-    trainer.train()
+    engine = Engine(config, mode="train")
+    engine.train()