Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
2e6dfa44
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
接近 2 年 前同步成功
通知
116
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2e6dfa44
编写于
6月 15, 2021
作者:
L
littletomatodonkey
提交者:
GitHub
6月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix logger (#840)
* fix logger * fix trainer for int64 on windows
上级
e4c4ec76
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
109 addition
and
105 deletion
+109
-105
ppcls/data/__init__.py
ppcls/data/__init__.py
+5
-10
ppcls/engine/trainer.py
ppcls/engine/trainer.py
+21
-11
ppcls/loss/__init__.py
ppcls/loss/__init__.py
+1
-1
ppcls/optimizer/__init__.py
ppcls/optimizer/__init__.py
+3
-3
ppcls/utils/config.py
ppcls/utils/config.py
+4
-8
ppcls/utils/logger.py
ppcls/utils/logger.py
+65
-50
ppcls/utils/save_load.py
ppcls/utils/save_load.py
+4
-19
tools/eval.py
tools/eval.py
+2
-1
tools/infer.py
tools/infer.py
+2
-1
tools/train.py
tools/train.py
+2
-1
未找到文件。
ppcls/data/__init__.py
浏览文件 @
2e6dfa44
...
@@ -54,12 +54,7 @@ def create_operators(params):
...
@@ -54,12 +54,7 @@ def create_operators(params):
def
build_dataloader
(
config
,
mode
,
device
,
seed
=
None
):
def
build_dataloader
(
config
,
mode
,
device
,
seed
=
None
):
assert
mode
in
[
assert
mode
in
[
'Train'
,
'Eval'
,
'Test'
,
'Gallery'
,
'Query'
'Train'
,
'Eval'
,
'Test'
,
'Gallery'
,
'Query'
],
"Mode should be Train, Eval, Test, Gallery, Query"
],
"Mode should be Train, Eval, Test, Gallery, Query"
# build dataset
# build dataset
config_dataset
=
config
[
mode
][
'dataset'
]
config_dataset
=
config
[
mode
][
'dataset'
]
...
@@ -72,7 +67,7 @@ def build_dataloader(config, mode, device, seed=None):
...
@@ -72,7 +67,7 @@ def build_dataloader(config, mode, device, seed=None):
dataset
=
eval
(
dataset_name
)(
**
config_dataset
)
dataset
=
eval
(
dataset_name
)(
**
config_dataset
)
logger
.
info
(
"build dataset({}) success..."
.
format
(
dataset
))
logger
.
debug
(
"build dataset({}) success..."
.
format
(
dataset
))
# build sampler
# build sampler
config_sampler
=
config
[
mode
][
'sampler'
]
config_sampler
=
config
[
mode
][
'sampler'
]
...
@@ -85,7 +80,7 @@ def build_dataloader(config, mode, device, seed=None):
...
@@ -85,7 +80,7 @@ def build_dataloader(config, mode, device, seed=None):
sampler_name
=
config_sampler
.
pop
(
"name"
)
sampler_name
=
config_sampler
.
pop
(
"name"
)
batch_sampler
=
eval
(
sampler_name
)(
dataset
,
**
config_sampler
)
batch_sampler
=
eval
(
sampler_name
)(
dataset
,
**
config_sampler
)
logger
.
info
(
"build batch_sampler({}) success..."
.
format
(
batch_sampler
))
logger
.
debug
(
"build batch_sampler({}) success..."
.
format
(
batch_sampler
))
# build batch operator
# build batch operator
def
mix_collate_fn
(
batch
):
def
mix_collate_fn
(
batch
):
...
@@ -132,5 +127,5 @@ def build_dataloader(config, mode, device, seed=None):
...
@@ -132,5 +127,5 @@ def build_dataloader(config, mode, device, seed=None):
batch_sampler
=
batch_sampler
,
batch_sampler
=
batch_sampler
,
collate_fn
=
batch_collate_fn
)
collate_fn
=
batch_collate_fn
)
logger
.
info
(
"build data_loader({}) success..."
.
format
(
data_loader
))
logger
.
debug
(
"build data_loader({}) success..."
.
format
(
data_loader
))
return
data_loader
return
data_loader
ppcls/engine/trainer.py
浏览文件 @
2e6dfa44
...
@@ -30,6 +30,8 @@ import paddle.distributed as dist
...
@@ -30,6 +30,8 @@ import paddle.distributed as dist
from
ppcls.utils.check
import
check_gpu
from
ppcls.utils.check
import
check_gpu
from
ppcls.utils.misc
import
AverageMeter
from
ppcls.utils.misc
import
AverageMeter
from
ppcls.utils
import
logger
from
ppcls.utils
import
logger
from
ppcls.utils.logger
import
init_logger
from
ppcls.utils.config
import
print_config
from
ppcls.data
import
build_dataloader
from
ppcls.data
import
build_dataloader
from
ppcls.arch
import
build_model
from
ppcls.arch
import
build_model
from
ppcls.loss
import
build_loss
from
ppcls.loss
import
build_loss
...
@@ -49,6 +51,11 @@ class Trainer(object):
...
@@ -49,6 +51,11 @@ class Trainer(object):
self
.
mode
=
mode
self
.
mode
=
mode
self
.
config
=
config
self
.
config
=
config
self
.
output_dir
=
self
.
config
[
'Global'
][
'output_dir'
]
self
.
output_dir
=
self
.
config
[
'Global'
][
'output_dir'
]
log_file
=
os
.
path
.
join
(
self
.
output_dir
,
self
.
config
[
"Arch"
][
"name"
],
f
"
{
mode
}
.log"
)
init_logger
(
name
=
'root'
,
log_file
=
log_file
)
print_config
(
config
)
# set device
# set device
assert
self
.
config
[
"Global"
][
"device"
]
in
[
"cpu"
,
"gpu"
,
"xpu"
]
assert
self
.
config
[
"Global"
][
"device"
]
in
[
"cpu"
,
"gpu"
,
"xpu"
]
self
.
device
=
paddle
.
set_device
(
self
.
config
[
"Global"
][
"device"
])
self
.
device
=
paddle
.
set_device
(
self
.
config
[
"Global"
][
"device"
])
...
@@ -153,8 +160,8 @@ class Trainer(object):
...
@@ -153,8 +160,8 @@ class Trainer(object):
time_info
[
key
].
reset
()
time_info
[
key
].
reset
()
time_info
[
"reader_cost"
].
update
(
time
.
time
()
-
tic
)
time_info
[
"reader_cost"
].
update
(
time
.
time
()
-
tic
)
batch_size
=
batch
[
0
].
shape
[
0
]
batch_size
=
batch
[
0
].
shape
[
0
]
batch
[
1
]
=
paddle
.
to_tensor
(
batch
[
1
].
numpy
(
).
astype
(
"int64"
)
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
]
).
astype
(
"int64"
)
.
reshape
([
-
1
,
1
]))
global_step
+=
1
global_step
+=
1
# image input
# image input
if
not
self
.
is_rec
:
if
not
self
.
is_rec
:
...
@@ -206,8 +213,9 @@ class Trainer(object):
...
@@ -206,8 +213,9 @@ class Trainer(object):
eta_msg
=
"eta: {:s}"
.
format
(
eta_msg
=
"eta: {:s}"
.
format
(
str
(
datetime
.
timedelta
(
seconds
=
int
(
eta_sec
))))
str
(
datetime
.
timedelta
(
seconds
=
int
(
eta_sec
))))
logger
.
info
(
logger
.
info
(
"[Train][Epoch {}][Iter: {}/{}]{}, {}, {}, {}, {}"
.
"[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}"
.
format
(
epoch_id
,
iter_id
,
format
(
epoch_id
,
self
.
config
[
"Global"
][
"epochs"
],
iter_id
,
len
(
self
.
train_dataloader
),
lr_msg
,
metric_msg
,
len
(
self
.
train_dataloader
),
lr_msg
,
metric_msg
,
time_msg
,
ips_msg
,
eta_msg
))
time_msg
,
ips_msg
,
eta_msg
))
tic
=
time
.
time
()
tic
=
time
.
time
()
...
@@ -216,8 +224,8 @@ class Trainer(object):
...
@@ -216,8 +224,8 @@ class Trainer(object):
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
].
avg
)
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
].
avg
)
for
key
in
output_info
for
key
in
output_info
])
])
logger
.
info
(
"[Train][Epoch {}
][Avg]{}"
.
format
(
epoch_id
,
logger
.
info
(
"[Train][Epoch {}
/{}][Avg]{}"
.
format
(
metric_msg
))
epoch_id
,
self
.
config
[
"Global"
][
"epochs"
],
metric_msg
))
output_info
.
clear
()
output_info
.
clear
()
# eval model and save model if possible
# eval model and save model if possible
...
@@ -327,7 +335,7 @@ class Trainer(object):
...
@@ -327,7 +335,7 @@ class Trainer(object):
time_info
[
"reader_cost"
].
update
(
time
.
time
()
-
tic
)
time_info
[
"reader_cost"
].
update
(
time
.
time
()
-
tic
)
batch_size
=
batch
[
0
].
shape
[
0
]
batch_size
=
batch
[
0
].
shape
[
0
]
batch
[
0
]
=
paddle
.
to_tensor
(
batch
[
0
]).
astype
(
"float32"
)
batch
[
0
]
=
paddle
.
to_tensor
(
batch
[
0
]).
astype
(
"float32"
)
batch
[
1
]
=
paddle
.
to_tensor
(
batch
[
1
]).
reshape
([
-
1
,
1
]
)
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
]).
astype
(
"int64"
)
# image input
# image input
if
self
.
is_rec
:
if
self
.
is_rec
:
out
=
self
.
model
(
batch
[
0
],
batch
[
1
])
out
=
self
.
model
(
batch
[
0
],
batch
[
1
])
...
@@ -438,9 +446,11 @@ class Trainer(object):
...
@@ -438,9 +446,11 @@ class Trainer(object):
for
key
in
metric_tmp
:
for
key
in
metric_tmp
:
if
key
not
in
metric_dict
:
if
key
not
in
metric_dict
:
metric_dict
[
key
]
=
metric_tmp
[
key
]
*
block_fea
.
shape
[
0
]
/
len
(
query_feas
)
metric_dict
[
key
]
=
metric_tmp
[
key
]
*
block_fea
.
shape
[
0
]
/
len
(
query_feas
)
else
:
else
:
metric_dict
[
key
]
+=
metric_tmp
[
key
]
*
block_fea
.
shape
[
0
]
/
len
(
query_feas
)
metric_dict
[
key
]
+=
metric_tmp
[
key
]
*
block_fea
.
shape
[
0
]
/
len
(
query_feas
)
metric_info_list
=
[]
metric_info_list
=
[]
for
key
in
metric_dict
:
for
key
in
metric_dict
:
...
@@ -467,10 +477,10 @@ class Trainer(object):
...
@@ -467,10 +477,10 @@ class Trainer(object):
for
idx
,
batch
in
enumerate
(
dataloader
(
for
idx
,
batch
in
enumerate
(
dataloader
(
)):
# load is very time-consuming
)):
# load is very time-consuming
batch
=
[
paddle
.
to_tensor
(
x
)
for
x
in
batch
]
batch
=
[
paddle
.
to_tensor
(
x
)
for
x
in
batch
]
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
])
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
])
.
astype
(
"int64"
)
if
len
(
batch
)
==
3
:
if
len
(
batch
)
==
3
:
has_unique_id
=
True
has_unique_id
=
True
batch
[
2
]
=
batch
[
2
].
reshape
([
-
1
,
1
])
batch
[
2
]
=
batch
[
2
].
reshape
([
-
1
,
1
])
.
astype
(
"int64"
)
out
=
self
.
model
(
batch
[
0
],
batch
[
1
])
out
=
self
.
model
(
batch
[
0
],
batch
[
1
])
batch_feas
=
out
[
"features"
]
batch_feas
=
out
[
"features"
]
...
...
ppcls/loss/__init__.py
浏览文件 @
2e6dfa44
...
@@ -52,5 +52,5 @@ class CombinedLoss(nn.Layer):
...
@@ -52,5 +52,5 @@ class CombinedLoss(nn.Layer):
def
build_loss
(
config
):
def
build_loss
(
config
):
module_class
=
CombinedLoss
(
copy
.
deepcopy
(
config
))
module_class
=
CombinedLoss
(
copy
.
deepcopy
(
config
))
logger
.
info
(
"build loss {} success."
.
format
(
module_class
))
logger
.
debug
(
"build loss {} success."
.
format
(
module_class
))
return
module_class
return
module_class
ppcls/optimizer/__init__.py
浏览文件 @
2e6dfa44
...
@@ -45,7 +45,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
...
@@ -45,7 +45,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
config
=
copy
.
deepcopy
(
config
)
config
=
copy
.
deepcopy
(
config
)
# step1 build lr
# step1 build lr
lr
=
build_lr_scheduler
(
config
.
pop
(
'lr'
),
epochs
,
step_each_epoch
)
lr
=
build_lr_scheduler
(
config
.
pop
(
'lr'
),
epochs
,
step_each_epoch
)
logger
.
info
(
"build lr ({}) success.."
.
format
(
lr
))
logger
.
debug
(
"build lr ({}) success.."
.
format
(
lr
))
# step2 build regularization
# step2 build regularization
if
'regularizer'
in
config
and
config
[
'regularizer'
]
is
not
None
:
if
'regularizer'
in
config
and
config
[
'regularizer'
]
is
not
None
:
reg_config
=
config
.
pop
(
'regularizer'
)
reg_config
=
config
.
pop
(
'regularizer'
)
...
@@ -53,7 +53,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
...
@@ -53,7 +53,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
reg
=
getattr
(
paddle
.
regularizer
,
reg_name
)(
**
reg_config
)
reg
=
getattr
(
paddle
.
regularizer
,
reg_name
)(
**
reg_config
)
else
:
else
:
reg
=
None
reg
=
None
logger
.
info
(
"build regularizer ({}) success.."
.
format
(
reg
))
logger
.
debug
(
"build regularizer ({}) success.."
.
format
(
reg
))
# step3 build optimizer
# step3 build optimizer
optim_name
=
config
.
pop
(
'name'
)
optim_name
=
config
.
pop
(
'name'
)
if
'clip_norm'
in
config
:
if
'clip_norm'
in
config
:
...
@@ -65,5 +65,5 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
...
@@ -65,5 +65,5 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
weight_decay
=
reg
,
weight_decay
=
reg
,
grad_clip
=
grad_clip
,
grad_clip
=
grad_clip
,
**
config
)(
parameters
=
parameters
)
**
config
)(
parameters
=
parameters
)
logger
.
info
(
"build optimizer ({}) success.."
.
format
(
optim
))
logger
.
debug
(
"build optimizer ({}) success.."
.
format
(
optim
))
return
optim
,
lr
return
optim
,
lr
ppcls/utils/config.py
浏览文件 @
2e6dfa44
...
@@ -67,18 +67,14 @@ def print_dict(d, delimiter=0):
...
@@ -67,18 +67,14 @@ def print_dict(d, delimiter=0):
placeholder
=
"-"
*
60
placeholder
=
"-"
*
60
for
k
,
v
in
sorted
(
d
.
items
()):
for
k
,
v
in
sorted
(
d
.
items
()):
if
isinstance
(
v
,
dict
):
if
isinstance
(
v
,
dict
):
logger
.
info
(
"{}{} : "
.
format
(
delimiter
*
" "
,
logger
.
info
(
"{}{} : "
.
format
(
delimiter
*
" "
,
k
))
logger
.
coloring
(
k
,
"HEADER"
)))
print_dict
(
v
,
delimiter
+
4
)
print_dict
(
v
,
delimiter
+
4
)
elif
isinstance
(
v
,
list
)
and
len
(
v
)
>=
1
and
isinstance
(
v
[
0
],
dict
):
elif
isinstance
(
v
,
list
)
and
len
(
v
)
>=
1
and
isinstance
(
v
[
0
],
dict
):
logger
.
info
(
"{}{} : "
.
format
(
delimiter
*
" "
,
logger
.
info
(
"{}{} : "
.
format
(
delimiter
*
" "
,
k
))
logger
.
coloring
(
str
(
k
),
"HEADER"
)))
for
value
in
v
:
for
value
in
v
:
print_dict
(
value
,
delimiter
+
4
)
print_dict
(
value
,
delimiter
+
4
)
else
:
else
:
logger
.
info
(
"{}{} : {}"
.
format
(
delimiter
*
" "
,
logger
.
info
(
"{}{} : {}"
.
format
(
delimiter
*
" "
,
k
,
v
))
logger
.
coloring
(
k
,
"HEADER"
),
logger
.
coloring
(
v
,
"OKGREEN"
)))
if
k
.
isupper
():
if
k
.
isupper
():
logger
.
info
(
placeholder
)
logger
.
info
(
placeholder
)
...
@@ -175,7 +171,7 @@ def override_config(config, options=None):
...
@@ -175,7 +171,7 @@ def override_config(config, options=None):
return
config
return
config
def
get_config
(
fname
,
overrides
=
None
,
show
=
Tru
e
):
def
get_config
(
fname
,
overrides
=
None
,
show
=
Fals
e
):
"""
"""
Read config from file
Read config from file
"""
"""
...
...
ppcls/utils/logger.py
浏览文件 @
2e6dfa44
...
@@ -12,70 +12,86 @@
...
@@ -12,70 +12,86 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
logging
import
os
import
os
import
datetime
import
sys
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
"%(asctime)s %(levelname)s: %(message)s"
,
datefmt
=
"%Y-%m-%d %H:%M:%S"
)
def
time_zone
(
sec
,
fmt
):
real_time
=
datetime
.
datetime
.
now
()
return
real_time
.
timetuple
()
logging
.
Formatter
.
converter
=
time_zone
_logger
=
logging
.
getLogger
(
__name__
)
Color
=
{
import
logging
'RED'
:
'
\033
[31m'
,
import
datetime
'HEADER'
:
'
\033
[35m'
,
# deep purple
import
paddle.distributed
as
dist
'PURPLE'
:
'
\033
[95m'
,
# purple
'OKBLUE'
:
'
\033
[94m'
,
_logger
=
None
'OKGREEN'
:
'
\033
[92m'
,
'WARNING'
:
'
\033
[93m'
,
'FAIL'
:
'
\033
[91m'
,
def
init_logger
(
name
=
'root'
,
log_file
=
None
,
log_level
=
logging
.
INFO
):
'ENDC'
:
'
\033
[0m'
"""Initialize and get a logger by name.
}
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
def
coloring
(
message
,
color
=
"OKGREEN"
):
added. If `log_file` is specified a FileHandler will also be added.
assert
color
in
Color
.
keys
()
Args:
if
os
.
environ
.
get
(
'PADDLECLAS_COLORING'
,
False
):
name (str): Logger name.
return
Color
[
color
]
+
str
(
message
)
+
Color
[
"ENDC"
]
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to
"Error" thus be silent most of the time.
Returns:
logging.Logger: The expected logger.
"""
global
_logger
assert
_logger
is
None
,
"logger should not be initialized twice or more."
_logger
=
logging
.
getLogger
(
name
)
formatter
=
logging
.
Formatter
(
'[%(asctime)s] %(name)s %(levelname)s: %(message)s'
,
datefmt
=
"%Y/%m/%d %H:%M:%S"
)
stream_handler
=
logging
.
StreamHandler
(
stream
=
sys
.
stdout
)
stream_handler
.
setFormatter
(
formatter
)
_logger
.
addHandler
(
stream_handler
)
if
log_file
is
not
None
and
dist
.
get_rank
()
==
0
:
log_file_folder
=
os
.
path
.
split
(
log_file
)[
0
]
os
.
makedirs
(
log_file_folder
,
exist_ok
=
True
)
file_handler
=
logging
.
FileHandler
(
log_file
,
'a'
)
file_handler
.
setFormatter
(
formatter
)
_logger
.
addHandler
(
file_handler
)
if
dist
.
get_rank
()
==
0
:
_logger
.
setLevel
(
log_level
)
else
:
else
:
return
message
_logger
.
setLevel
(
logging
.
ERROR
)
def
anti_fleet
(
log
):
def
log_at_trainer0
(
log
):
"""
"""
logs will print multi-times when calling Fleet API.
logs will print multi-times when calling Fleet API.
Only display single log and ignore the others.
Only display single log and ignore the others.
"""
"""
def
wrapper
(
fmt
,
*
args
):
def
wrapper
(
fmt
,
*
args
):
if
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
,
0
)
)
==
0
:
if
dist
.
get_rank
(
)
==
0
:
log
(
fmt
,
*
args
)
log
(
fmt
,
*
args
)
return
wrapper
return
wrapper
@
anti_fleet
@
log_at_trainer0
def
info
(
fmt
,
*
args
):
def
info
(
fmt
,
*
args
):
_logger
.
info
(
fmt
,
*
args
)
_logger
.
info
(
fmt
,
*
args
)
@
anti_fleet
@
log_at_trainer0
def
debug
(
fmt
,
*
args
):
_logger
.
debug
(
fmt
,
*
args
)
@
log_at_trainer0
def
warning
(
fmt
,
*
args
):
def
warning
(
fmt
,
*
args
):
_logger
.
warning
(
coloring
(
fmt
,
"RED"
)
,
*
args
)
_logger
.
warning
(
fmt
,
*
args
)
@
anti_fleet
@
log_at_trainer0
def
error
(
fmt
,
*
args
):
def
error
(
fmt
,
*
args
):
_logger
.
error
(
coloring
(
fmt
,
"FAIL"
)
,
*
args
)
_logger
.
error
(
fmt
,
*
args
)
def
scaler
(
name
,
value
,
step
,
writer
):
def
scaler
(
name
,
value
,
step
,
writer
):
...
@@ -108,8 +124,7 @@ def advertise():
...
@@ -108,8 +124,7 @@ def advertise():
website
=
"https://github.com/PaddlePaddle/PaddleClas"
website
=
"https://github.com/PaddlePaddle/PaddleClas"
AD_LEN
=
6
+
len
(
max
([
copyright
,
ad
,
website
],
key
=
len
))
AD_LEN
=
6
+
len
(
max
([
copyright
,
ad
,
website
],
key
=
len
))
info
(
info
(
"
\n
{0}
\n
{1}
\n
{2}
\n
{3}
\n
{4}
\n
{5}
\n
{6}
\n
{7}
\n
"
.
format
(
coloring
(
"
\n
{0}
\n
{1}
\n
{2}
\n
{3}
\n
{4}
\n
{5}
\n
{6}
\n
{7}
\n
"
.
format
(
"="
*
(
AD_LEN
+
4
),
"="
*
(
AD_LEN
+
4
),
"=={}=="
.
format
(
copyright
.
center
(
AD_LEN
)),
"=={}=="
.
format
(
copyright
.
center
(
AD_LEN
)),
"="
*
(
AD_LEN
+
4
),
"="
*
(
AD_LEN
+
4
),
...
@@ -117,4 +132,4 @@ def advertise():
...
@@ -117,4 +132,4 @@ def advertise():
"=={}=="
.
format
(
ad
.
center
(
AD_LEN
)),
"=={}=="
.
format
(
ad
.
center
(
AD_LEN
)),
"=={}=="
.
format
(
' '
*
AD_LEN
),
"=={}=="
.
format
(
' '
*
AD_LEN
),
"=={}=="
.
format
(
website
.
center
(
AD_LEN
)),
"=={}=="
.
format
(
website
.
center
(
AD_LEN
)),
"="
*
(
AD_LEN
+
4
),
),
"RED"
))
"="
*
(
AD_LEN
+
4
),
))
ppcls/utils/save_load.py
浏览文件 @
2e6dfa44
...
@@ -115,19 +115,6 @@ def init_model(config, net, optimizer=None):
...
@@ -115,19 +115,6 @@ def init_model(config, net, optimizer=None):
pretrained_model
),
"HEADER"
))
pretrained_model
),
"HEADER"
))
def
_save_student_model
(
net
,
model_prefix
):
"""
save student model if the net is the network contains student
"""
student_model_prefix
=
model_prefix
+
"_student.pdparams"
if
hasattr
(
net
,
"_layers"
):
net
=
net
.
_layers
if
hasattr
(
net
,
"student"
):
paddle
.
save
(
net
.
student
.
state_dict
(),
student_model_prefix
)
logger
.
info
(
"Already save student model in {}"
.
format
(
student_model_prefix
))
def
save_model
(
net
,
def
save_model
(
net
,
optimizer
,
optimizer
,
metric_info
,
metric_info
,
...
@@ -141,11 +128,9 @@ def save_model(net,
...
@@ -141,11 +128,9 @@ def save_model(net,
return
return
model_path
=
os
.
path
.
join
(
model_path
,
model_name
)
model_path
=
os
.
path
.
join
(
model_path
,
model_name
)
_mkdir_if_not_exist
(
model_path
)
_mkdir_if_not_exist
(
model_path
)
model_prefix
=
os
.
path
.
join
(
model_path
,
prefix
)
model_path
=
os
.
path
.
join
(
model_path
,
prefix
)
_save_student_model
(
net
,
model_prefix
)
paddle
.
save
(
net
.
state_dict
(),
model_p
refix
+
".pdparams"
)
paddle
.
save
(
net
.
state_dict
(),
model_p
ath
+
".pdparams"
)
paddle
.
save
(
optimizer
.
state_dict
(),
model_p
refix
+
".pdopt"
)
paddle
.
save
(
optimizer
.
state_dict
(),
model_p
ath
+
".pdopt"
)
paddle
.
save
(
metric_info
,
model_p
refix
+
".pdstates"
)
paddle
.
save
(
metric_info
,
model_p
ath
+
".pdstates"
)
logger
.
info
(
"Already save model in {}"
.
format
(
model_path
))
logger
.
info
(
"Already save model in {}"
.
format
(
model_path
))
tools/eval.py
浏览文件 @
2e6dfa44
...
@@ -25,6 +25,7 @@ from ppcls.engine.trainer import Trainer
...
@@ -25,6 +25,7 @@ from ppcls.engine.trainer import Trainer
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
args
=
config
.
parse_args
()
args
=
config
.
parse_args
()
config
=
config
.
get_config
(
args
.
config
,
overrides
=
args
.
override
,
show
=
True
)
config
=
config
.
get_config
(
args
.
config
,
overrides
=
args
.
override
,
show
=
False
)
trainer
=
Trainer
(
config
,
mode
=
"eval"
)
trainer
=
Trainer
(
config
,
mode
=
"eval"
)
trainer
.
eval
()
trainer
.
eval
()
tools/infer.py
浏览文件 @
2e6dfa44
...
@@ -25,7 +25,8 @@ from ppcls.engine.trainer import Trainer
...
@@ -25,7 +25,8 @@ from ppcls.engine.trainer import Trainer
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
args
=
config
.
parse_args
()
args
=
config
.
parse_args
()
config
=
config
.
get_config
(
args
.
config
,
overrides
=
args
.
override
,
show
=
True
)
config
=
config
.
get_config
(
args
.
config
,
overrides
=
args
.
override
,
show
=
False
)
trainer
=
Trainer
(
config
,
mode
=
"infer"
)
trainer
=
Trainer
(
config
,
mode
=
"infer"
)
trainer
.
infer
()
trainer
.
infer
()
tools/train.py
浏览文件 @
2e6dfa44
...
@@ -25,6 +25,7 @@ from ppcls.engine.trainer import Trainer
...
@@ -25,6 +25,7 @@ from ppcls.engine.trainer import Trainer
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
args
=
config
.
parse_args
()
args
=
config
.
parse_args
()
config
=
config
.
get_config
(
args
.
config
,
overrides
=
args
.
override
,
show
=
True
)
config
=
config
.
get_config
(
args
.
config
,
overrides
=
args
.
override
,
show
=
False
)
trainer
=
Trainer
(
config
,
mode
=
"train"
)
trainer
=
Trainer
(
config
,
mode
=
"train"
)
trainer
.
train
()
trainer
.
train
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录