Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleX
提交
b885a6d1
P
PaddleX
项目概览
PaddlePaddle
/
PaddleX
通知
138
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
43
列表
看板
标记
里程碑
合并请求
5
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleX
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
43
Issue
43
列表
看板
标记
里程碑
合并请求
5
合并请求
5
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b885a6d1
编写于
5月 19, 2020
作者:
J
jiangjiajun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
revert base.py
上级
a532ecc2
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
22 addition
and
28 deletion
+22
-28
paddlex/cv/models/base.py
paddlex/cv/models/base.py
+22
-28
未找到文件。
paddlex/cv/models/base.py
浏览文件 @
b885a6d1
...
...
@@ -79,9 +79,9 @@ class BaseAPI:
return
int
(
batch_size
//
len
(
self
.
places
))
else
:
raise
Exception
(
"Please support correct batch_size,
\
which can be divided by available cards({}) in {}"
.
format
(
paddlex
.
env_info
[
'num'
],
paddlex
.
env_info
[
'place'
]))
which can be divided by available cards({}) in {}"
.
format
(
paddlex
.
env_info
[
'num'
],
paddlex
.
env_info
[
'place'
]))
def
build_program
(
self
):
# 构建训练网络
...
...
@@ -141,7 +141,7 @@ class BaseAPI:
from
.slim.post_quantization
import
PaddleXPostTrainingQuantization
except
:
raise
Exception
(
"Model Quantization is not available, try to upgrade your paddlepaddle>=1.
8
.0"
"Model Quantization is not available, try to upgrade your paddlepaddle>=1.
7
.0"
)
is_use_cache_file
=
True
if
cache_dir
is
None
:
...
...
@@ -209,8 +209,8 @@ class BaseAPI:
paddlex
.
utils
.
utils
.
load_pretrain_weights
(
self
.
exe
,
self
.
train_prog
,
resume_checkpoint
,
resume
=
True
)
if
not
osp
.
exists
(
osp
.
join
(
resume_checkpoint
,
"model.yml"
)):
raise
Exception
(
"There's not model.yml in {}"
.
format
(
resume_checkpoint
))
raise
Exception
(
"There's not model.yml in {}"
.
format
(
resume_checkpoint
))
with
open
(
osp
.
join
(
resume_checkpoint
,
"model.yml"
))
as
f
:
info
=
yaml
.
load
(
f
.
read
(),
Loader
=
yaml
.
Loader
)
self
.
completed_epochs
=
info
[
'completed_epochs'
]
...
...
@@ -361,8 +361,8 @@ class BaseAPI:
# 模型保存成功的标志
open
(
osp
.
join
(
save_dir
,
'.success'
),
'w'
).
close
()
logging
.
info
(
"Model for inference deploy saved in {}."
.
format
(
save_dir
))
logging
.
info
(
"Model for inference deploy saved in {}."
.
format
(
save_dir
))
def
train_loop
(
self
,
num_epochs
,
...
...
@@ -376,8 +376,7 @@ class BaseAPI:
early_stop
=
False
,
early_stop_patience
=
5
):
if
train_dataset
.
num_samples
<
train_batch_size
:
raise
Exception
(
'The amount of training datset must be larger than batch size.'
)
raise
Exception
(
'The amount of training datset must be larger than batch size.'
)
if
not
osp
.
isdir
(
save_dir
):
if
osp
.
exists
(
save_dir
):
os
.
remove
(
save_dir
)
...
...
@@ -415,8 +414,8 @@ class BaseAPI:
build_strategy
=
build_strategy
,
exec_strategy
=
exec_strategy
)
total_num_steps
=
math
.
floor
(
train_dataset
.
num_samples
/
train_batch_size
)
total_num_steps
=
math
.
floor
(
train_dataset
.
num_samples
/
train_batch_size
)
num_steps
=
0
time_stat
=
list
()
time_train_one_epoch
=
None
...
...
@@ -430,8 +429,8 @@ class BaseAPI:
if
self
.
model_type
==
'detector'
:
eval_batch_size
=
self
.
_get_single_card_bs
(
train_batch_size
)
if
eval_dataset
is
not
None
:
total_num_steps_eval
=
math
.
ceil
(
eval_dataset
.
num_samples
/
eval_batch_size
)
total_num_steps_eval
=
math
.
ceil
(
eval_dataset
.
num_samples
/
eval_batch_size
)
if
use_vdl
:
# VisualDL component
...
...
@@ -473,9 +472,7 @@ class BaseAPI:
if
use_vdl
:
for
k
,
v
in
step_metrics
.
items
():
log_writer
.
add_scalar
(
'Metrics/Training(Step): {}'
.
format
(
k
),
v
,
num_steps
)
log_writer
.
add_scalar
(
'Metrics/Training(Step): {}'
.
format
(
k
),
v
,
num_steps
)
# 估算剩余时间
avg_step_time
=
np
.
mean
(
time_stat
)
...
...
@@ -483,12 +480,11 @@ class BaseAPI:
eta
=
(
num_epochs
-
i
-
1
)
*
time_train_one_epoch
+
(
total_num_steps
-
step
-
1
)
*
avg_step_time
else
:
eta
=
((
num_epochs
-
i
)
*
total_num_steps
-
step
-
1
)
*
avg_step_time
eta
=
((
num_epochs
-
i
)
*
total_num_steps
-
step
-
1
)
*
avg_step_time
if
time_eval_one_epoch
is
not
None
:
eval_eta
=
(
total_eval_times
-
i
//
save_interval_epochs
)
*
time_eval_one_epoch
eval_eta
=
(
total_eval_times
-
i
//
save_interval_epochs
)
*
time_eval_one_epoch
else
:
eval_eta
=
(
total_eval_times
-
i
//
save_interval_epochs
...
...
@@ -498,11 +494,10 @@ class BaseAPI:
logging
.
info
(
"[TRAIN] Epoch={}/{}, Step={}/{}, {}, time_each_step={}s, eta={}"
.
format
(
i
+
1
,
num_epochs
,
step
+
1
,
total_num_steps
,
dict2str
(
step_metrics
),
round
(
avg_step_time
,
2
),
eta_str
))
dict2str
(
step_metrics
),
round
(
avg_step_time
,
2
),
eta_str
))
train_metrics
=
OrderedDict
(
zip
(
list
(
self
.
train_outputs
.
keys
()),
np
.
mean
(
records
,
axis
=
0
)))
zip
(
list
(
self
.
train_outputs
.
keys
()),
np
.
mean
(
records
,
axis
=
0
)))
logging
.
info
(
'[TRAIN] Epoch {} finished, {} .'
.
format
(
i
+
1
,
dict2str
(
train_metrics
)))
time_train_one_epoch
=
time
.
time
()
-
epoch_start_time
...
...
@@ -538,8 +533,7 @@ class BaseAPI:
if
isinstance
(
v
,
np
.
ndarray
):
if
v
.
size
>
1
:
continue
log_writer
.
add_scalar
(
"Metrics/Eval(Epoch): {}"
.
format
(
k
),
v
,
i
+
1
)
log_writer
.
add_scalar
(
"Metrics/Eval(Epoch): {}"
.
format
(
k
),
v
,
i
+
1
)
self
.
save_model
(
save_dir
=
current_save_dir
)
time_eval_one_epoch
=
time
.
time
()
-
eval_epoch_start_time
eval_epoch_start_time
=
time
.
time
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录