Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
15f6f581
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
15f6f581
编写于
8月 24, 2021
作者:
D
dongshuilong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor trainer v2
上级
ebde0e13
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
34 addition
and
124 deletion
+34
-124
ppcls/engine/engine.py
ppcls/engine/engine.py
+10
-15
ppcls/engine/evaluation/__init__.py
ppcls/engine/evaluation/__init__.py
+2
-2
ppcls/engine/evaluation/classification.py
ppcls/engine/evaluation/classification.py
+0
-0
ppcls/engine/evaluation/retrieval.py
ppcls/engine/evaluation/retrieval.py
+0
-0
ppcls/engine/train/__init__.py
ppcls/engine/train/__init__.py
+1
-2
ppcls/engine/train/classification.py
ppcls/engine/train/classification.py
+0
-89
ppcls/engine/train/train.py
ppcls/engine/train/train.py
+9
-4
tools/eval.py
tools/eval.py
+3
-3
tools/export_model.py
tools/export_model.py
+3
-3
tools/infer.py
tools/infer.py
+3
-3
tools/train.py
tools/train.py
+3
-3
未找到文件。
ppcls/engine/
cor
e.py
→
ppcls/engine/
engin
e.py
浏览文件 @
15f6f581
...
...
@@ -47,19 +47,18 @@ from ppcls.utils import save_load
from
ppcls.data.utils.get_image_list
import
get_image_list
from
ppcls.data.postprocess
import
build_postprocess
from
ppcls.data
import
create_operators
from
ppcls.engine.train
import
classification_train
,
retrieval_train
from
ppcls.engine
.eval
import
classification_eval
,
retrieval_eval
from
ppcls.engine.train
import
train_epoch
from
ppcls.engine
import
evaluation
from
ppcls.arch.gears.identity_head
import
IdentityHead
class
Cor
e
(
object
):
class
Engin
e
(
object
):
def
__init__
(
self
,
config
,
mode
=
"train"
):
assert
mode
in
[
'train'
,
'eval'
,
'infer'
,
'export'
]
assert
mode
in
[
"train"
,
"eval"
,
"infer"
,
"export"
]
self
.
mode
=
mode
self
.
config
=
config
self
.
eval_mode
=
self
.
config
[
"Global"
].
get
(
"eval_mode"
,
"classification"
)
# init logger
self
.
output_dir
=
self
.
config
[
'Global'
][
'output_dir'
]
log_file
=
os
.
path
.
join
(
self
.
output_dir
,
self
.
config
[
"Arch"
][
"name"
],
...
...
@@ -68,14 +67,10 @@ class Core(object):
print_config
(
config
)
# init train_func and eval_func
if
self
.
eval_mode
==
"classification"
:
self
.
evaler
=
classification_eval
self
.
trainer
=
classification_train
elif
self
.
eval_mode
==
"retrieval"
:
self
.
trainer
=
retrieval_train
self
.
evaler
=
retrieval_eval
else
:
logger
.
warning
(
"Invalid eval mode: {}"
.
format
(
self
.
eval_mode
))
assert
self
.
eval_mode
in
[
"classification"
,
"retrieval"
],
logger
.
error
(
"Invalid eval mode: {}"
.
format
(
self
.
eval_mode
))
self
.
train_epoch_func
=
train_epoch
self
.
eval_func
=
getattr
(
evaluation
,
self
.
eval_mode
+
"_eval"
)
self
.
use_dali
=
self
.
config
[
'Global'
].
get
(
"use_dali"
,
False
)
# for visualdl
...
...
@@ -242,7 +237,7 @@ class Core(object):
self
.
config
[
"Global"
][
"epochs"
]
+
1
):
acc
=
0.0
# for one epoch train
self
.
train
er
(
self
,
epoch_id
,
print_batch_step
)
self
.
train
_epoch_func
(
self
,
epoch_id
,
print_batch_step
)
if
self
.
use_dali
:
self
.
train_dataloader
.
reset
()
...
...
@@ -304,7 +299,7 @@ class Core(object):
def
eval
(
self
,
epoch_id
=
0
):
assert
self
.
mode
in
[
"train"
,
"eval"
]
self
.
model
.
eval
()
eval_result
=
self
.
eval
er
(
self
,
epoch_id
)
eval_result
=
self
.
eval
_func
(
self
,
epoch_id
)
self
.
model
.
train
()
return
eval_result
...
...
ppcls/engine/eval/__init__.py
→
ppcls/engine/eval
uation
/__init__.py
浏览文件 @
15f6f581
...
...
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
ppcls.engine.eval.classification
import
classification_eval
from
ppcls.engine.eval.retrieval
import
retrieval_eval
from
ppcls.engine.eval
uation
.classification
import
classification_eval
from
ppcls.engine.eval
uation
.retrieval
import
retrieval_eval
ppcls/engine/eval/classification.py
→
ppcls/engine/eval
uation
/classification.py
浏览文件 @
15f6f581
文件已移动
ppcls/engine/eval/retrieval.py
→
ppcls/engine/eval
uation
/retrieval.py
浏览文件 @
15f6f581
文件已移动
ppcls/engine/train/__init__.py
浏览文件 @
15f6f581
...
...
@@ -11,5 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
ppcls.engine.train.classification
import
classification_train
from
ppcls.engine.train.retrieval
import
retrieval_train
from
ppcls.engine.train.train
import
train_epoch
ppcls/engine/train/classification.py
已删除
100644 → 0
浏览文件 @
ebde0e13
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
,
division
,
print_function
import
datetime
import
os
import
platform
import
sys
import
time
import
numpy
as
np
import
paddle
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../../../'
)))
from
ppcls.utils
import
logger
from
ppcls.utils.misc
import
AverageMeter
from
ppcls.engine.train.utils
import
update_loss
,
update_metric
,
log_info
def
classification_train
(
trainer
,
epoch_id
,
print_batch_step
):
tic
=
time
.
time
()
train_dataloader
=
trainer
.
train_dataloader
if
trainer
.
use_dali
else
trainer
.
train_dataloader
(
)
for
iter_id
,
batch
in
enumerate
(
train_dataloader
):
if
iter_id
>=
trainer
.
max_iter
:
break
if
iter_id
==
5
:
for
key
in
trainer
.
time_info
:
trainer
.
time_info
[
key
].
reset
()
trainer
.
time_info
[
"reader_cost"
].
update
(
time
.
time
()
-
tic
)
if
trainer
.
use_dali
:
batch
=
[
paddle
.
to_tensor
(
batch
[
0
][
'data'
]),
paddle
.
to_tensor
(
batch
[
0
][
'label'
])
]
batch_size
=
batch
[
0
].
shape
[
0
]
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
]).
astype
(
"int64"
)
trainer
.
global_step
+=
1
# image input
if
trainer
.
amp
:
with
paddle
.
amp
.
auto_cast
(
custom_black_list
=
{
"flatten_contiguous_range"
,
"greater_than"
}):
out
=
trainer
.
model
(
batch
[
0
])
loss_dict
=
trainer
.
train_loss_func
(
out
,
batch
[
1
])
else
:
out
=
trainer
.
model
(
batch
[
0
])
# calc loss
if
trainer
.
config
[
"DataLoader"
][
"Train"
][
"dataset"
].
get
(
"batch_transform_ops"
,
None
):
loss_dict
=
trainer
.
train_loss_func
(
out
,
batch
[
1
:])
else
:
loss_dict
=
trainer
.
train_loss_func
(
out
,
batch
[
1
])
# step opt and lr
if
trainer
.
amp
:
scaled
=
trainer
.
scaler
.
scale
(
loss_dict
[
"loss"
])
scaled
.
backward
()
trainer
.
scaler
.
minimize
(
trainer
.
optimizer
,
scaled
)
else
:
loss_dict
[
"loss"
].
backward
()
trainer
.
optimizer
.
step
()
trainer
.
optimizer
.
clear_grad
()
trainer
.
lr_sch
.
step
()
# below code just for logging
# update metric_for_logger
update_metric
(
trainer
,
out
,
batch
,
batch_size
)
# update_loss_for_logger
update_loss
(
trainer
,
loss_dict
,
batch_size
)
trainer
.
time_info
[
"batch_cost"
].
update
(
time
.
time
()
-
tic
)
if
iter_id
%
print_batch_step
==
0
:
log_info
(
trainer
,
batch_size
,
epoch_id
,
iter_id
)
tic
=
time
.
time
()
ppcls/engine/train/
retrieval
.py
→
ppcls/engine/train/
train
.py
浏览文件 @
15f6f581
...
...
@@ -29,7 +29,7 @@ from ppcls.utils.misc import AverageMeter
from
ppcls.engine.train.utils
import
update_loss
,
update_metric
,
log_info
def
retrieval_train
(
trainer
,
epoch_id
,
print_batch_step
):
def
train_epoch
(
trainer
,
epoch_id
,
print_batch_step
):
tic
=
time
.
time
()
train_dataloader
=
trainer
.
train_dataloader
if
trainer
.
use_dali
else
trainer
.
train_dataloader
(
...
...
@@ -55,10 +55,10 @@ def retrieval_train(trainer, epoch_id, print_batch_step):
with
paddle
.
amp
.
auto_cast
(
custom_black_list
=
{
"flatten_contiguous_range"
,
"greater_than"
}):
out
=
trainer
.
model
(
batch
[
0
],
batch
[
1
]
)
out
=
forward
(
trainer
,
batch
)
loss_dict
=
trainer
.
train_loss_func
(
out
,
batch
[
1
])
else
:
out
=
trainer
.
model
(
batch
[
0
],
batch
[
1
]
)
out
=
forward
(
trainer
,
batch
)
# calc loss
if
trainer
.
config
[
"DataLoader"
][
"Train"
][
"dataset"
].
get
(
...
...
@@ -81,10 +81,15 @@ def retrieval_train(trainer, epoch_id, print_batch_step):
# below code just for logging
# update metric_for_logger
update_metric
(
trainer
,
out
,
batch
,
batch_size
)
# update_loss_for_logger
update_loss
(
trainer
,
loss_dict
,
batch_size
)
trainer
.
time_info
[
"batch_cost"
].
update
(
time
.
time
()
-
tic
)
if
iter_id
%
print_batch_step
==
0
:
log_info
(
trainer
,
batch_size
,
epoch_id
,
iter_id
)
tic
=
time
.
time
()
def
forward
(
trainer
,
batch
):
if
trainer
.
eval_mode
==
"classification"
:
return
trainer
.
model
(
batch
[
0
])
else
:
return
trainer
.
model
(
batch
[
0
],
batch
[
1
])
tools/eval.py
浏览文件 @
15f6f581
...
...
@@ -21,11 +21,11 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../'
)))
from
ppcls.utils
import
config
from
ppcls.engine.
core
import
Cor
e
from
ppcls.engine.
engine
import
Engin
e
if
__name__
==
"__main__"
:
args
=
config
.
parse_args
()
config
=
config
.
get_config
(
args
.
config
,
overrides
=
args
.
override
,
show
=
False
)
e
valer
=
Cor
e
(
config
,
mode
=
"eval"
)
e
valer
.
eval
()
e
ngine
=
Engin
e
(
config
,
mode
=
"eval"
)
e
ngine
.
eval
()
tools/export_model.py
浏览文件 @
15f6f581
...
...
@@ -24,11 +24,11 @@ import paddle
import
paddle.nn
as
nn
from
ppcls.utils
import
config
from
ppcls.engine.
core
import
Cor
e
from
ppcls.engine.
engine
import
Engin
e
if
__name__
==
"__main__"
:
args
=
config
.
parse_args
()
config
=
config
.
get_config
(
args
.
config
,
overrides
=
args
.
override
,
show
=
False
)
e
xporter
=
Cor
e
(
config
,
mode
=
"export"
)
e
xporter
.
export
()
e
ngine
=
Engin
e
(
config
,
mode
=
"export"
)
e
ngine
.
export
()
tools/infer.py
浏览文件 @
15f6f581
...
...
@@ -21,11 +21,11 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../'
)))
from
ppcls.utils
import
config
from
ppcls.engine.
core
import
Cor
e
from
ppcls.engine.
engine
import
Engin
e
if
__name__
==
"__main__"
:
args
=
config
.
parse_args
()
config
=
config
.
get_config
(
args
.
config
,
overrides
=
args
.
override
,
show
=
False
)
inferer
=
Cor
e
(
config
,
mode
=
"infer"
)
inferer
.
infer
()
engine
=
Engin
e
(
config
,
mode
=
"infer"
)
engine
.
infer
()
tools/train.py
浏览文件 @
15f6f581
...
...
@@ -21,11 +21,11 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../'
)))
from
ppcls.utils
import
config
from
ppcls.engine.
core
import
Cor
e
from
ppcls.engine.
engine
import
Engin
e
if
__name__
==
"__main__"
:
args
=
config
.
parse_args
()
config
=
config
.
get_config
(
args
.
config
,
overrides
=
args
.
override
,
show
=
False
)
trainer
=
Cor
e
(
config
,
mode
=
"train"
)
trainer
.
train
()
engine
=
Engin
e
(
config
,
mode
=
"train"
)
engine
.
train
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录