Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
10ee66c7
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
285
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
10ee66c7
编写于
6月 23, 2020
作者:
C
chenguowei01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add vdl
上级
11e1c767
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
30 addition
and
8 deletion
+30
-8
dygraph/train.py
dygraph/train.py
+29
-8
dygraph/val.py
dygraph/val.py
+1
-0
未找到文件。
dygraph/train.py
浏览文件 @
10ee66c7
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
import
argparse
import
argparse
import
os
import
os
import
time
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.parallel
import
ParallelEnv
from
paddle.fluid.dygraph.parallel
import
ParallelEnv
...
@@ -119,6 +118,11 @@ def parse_args():
...
@@ -119,6 +118,11 @@ def parse_args():
help
=
'Display logging information at every log_steps'
,
help
=
'Display logging information at every log_steps'
,
default
=
10
,
default
=
10
,
type
=
int
)
type
=
int
)
parser
.
add_argument
(
'--use_vdl'
,
dest
=
'use_vdl'
,
help
=
'Whether to record the data during training to VisualDL'
,
action
=
'store_true'
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -136,7 +140,8 @@ def train(model,
...
@@ -136,7 +140,8 @@ def train(model,
save_interval_epochs
=
1
,
save_interval_epochs
=
1
,
log_steps
=
10
,
log_steps
=
10
,
num_classes
=
None
,
num_classes
=
None
,
num_workers
=
8
):
num_workers
=
8
,
use_vdl
=
False
):
ignore_index
=
model
.
ignore_index
ignore_index
=
model
.
ignore_index
nranks
=
ParallelEnv
().
nranks
nranks
=
ParallelEnv
().
nranks
...
@@ -165,10 +170,16 @@ def train(model,
...
@@ -165,10 +170,16 @@ def train(model,
return_list
=
True
,
return_list
=
True
,
)
)
if
use_vdl
:
from
visualdl
import
LogWriter
log_writer
=
LogWriter
(
save_dir
)
timer
=
Timer
()
timer
=
Timer
()
timer
.
start
()
timer
.
start
()
steps_per_epoch
=
len
(
batch_sampler
)
avg_loss
=
0.0
avg_loss
=
0.0
steps_per_epoch
=
len
(
batch_sampler
)
total_steps
=
steps_per_epoch
*
(
num_epochs
-
start_epoch
)
num_steps
=
0
for
epoch
in
range
(
start_epoch
,
num_epochs
):
for
epoch
in
range
(
start_epoch
,
num_epochs
):
for
step
,
data
in
enumerate
(
loader
):
for
step
,
data
in
enumerate
(
loader
):
images
=
data
[
0
]
images
=
data
[
0
]
...
@@ -185,17 +196,21 @@ def train(model,
...
@@ -185,17 +196,21 @@ def train(model,
model
.
clear_gradients
()
model
.
clear_gradients
()
avg_loss
+=
loss
.
numpy
()[
0
]
avg_loss
+=
loss
.
numpy
()[
0
]
lr
=
optimizer
.
current_step_lr
()
lr
=
optimizer
.
current_step_lr
()
if
step
%
log_steps
==
0
:
num_steps
+=
1
if
num_steps
%
log_steps
==
0
:
avg_loss
/=
log_steps
avg_loss
/=
log_steps
time_step
=
timer
.
elapsed_time
()
/
log_steps
time_step
=
timer
.
elapsed_time
()
/
log_steps
remain_step
=
(
num_epochs
-
epoch
)
*
steps_per_epoch
-
step
-
1
remain_step
s
=
total_steps
-
num_steps
logging
.
info
(
logging
.
info
(
"[TRAIN] Epoch={}/{}, Step={}/{}, loss={:.4f}, lr={:.6f}, sec/step={:.4f} | ETA {}"
"[TRAIN] Epoch={}/{}, Step={}/{}, loss={:.4f}, lr={:.6f}, sec/step={:.4f} | ETA {}"
.
format
(
epoch
+
1
,
num_epochs
,
step
+
1
,
steps_per_epoch
,
.
format
(
epoch
+
1
,
num_epochs
,
step
+
1
,
steps_per_epoch
,
avg_loss
,
lr
,
time_step
,
avg_loss
,
lr
,
time_step
,
calculate_eta
(
remain_step
,
time_step
)))
calculate_eta
(
remain_step
s
,
time_step
)))
avg_loss
=
0.0
avg_loss
=
0.0
timer
.
restart
()
timer
.
restart
()
if
use_vdl
:
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
num_steps
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
,
num_steps
)
if
((
epoch
+
1
)
%
save_interval_epochs
==
0
if
((
epoch
+
1
)
%
save_interval_epochs
==
0
or
epoch
==
num_epochs
-
1
)
and
ParallelEnv
().
local_rank
==
0
:
or
epoch
==
num_epochs
-
1
)
and
ParallelEnv
().
local_rank
==
0
:
...
@@ -209,7 +224,7 @@ def train(model,
...
@@ -209,7 +224,7 @@ def train(model,
os
.
path
.
join
(
current_save_dir
,
'model'
))
os
.
path
.
join
(
current_save_dir
,
'model'
))
if
eval_dataset
is
not
None
:
if
eval_dataset
is
not
None
:
evaluate
(
mean_iou
,
mean_acc
=
evaluate
(
model
,
model
,
eval_dataset
,
eval_dataset
,
places
=
places
,
places
=
places
,
...
@@ -218,6 +233,11 @@ def train(model,
...
@@ -218,6 +233,11 @@ def train(model,
batch_size
=
batch_size
,
batch_size
=
batch_size
,
ignore_index
=
ignore_index
,
ignore_index
=
ignore_index
,
epoch_id
=
epoch
+
1
)
epoch_id
=
epoch
+
1
)
if
use_vdl
:
log_writer
.
add_scalar
(
'Evaluate/mean_iou'
,
mean_iou
,
num_steps
)
log_writer
.
add_scalar
(
'Evaluate/mean_acc'
,
mean_acc
,
num_steps
)
model
.
train
()
model
.
train
()
...
@@ -283,7 +303,8 @@ def main(args):
...
@@ -283,7 +303,8 @@ def main(args):
save_interval_epochs
=
args
.
save_interval_epochs
,
save_interval_epochs
=
args
.
save_interval_epochs
,
log_steps
=
args
.
log_steps
,
log_steps
=
args
.
log_steps
,
num_classes
=
train_dataset
.
num_classes
,
num_classes
=
train_dataset
.
num_classes
,
num_workers
=
args
.
num_workers
)
num_workers
=
args
.
num_workers
,
use_vdl
=
args
.
use_vdl
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
dygraph/val.py
浏览文件 @
10ee66c7
...
@@ -131,6 +131,7 @@ def evaluate(model,
...
@@ -131,6 +131,7 @@ def evaluate(model,
logging
.
info
(
"[EVAL] Category IoU: "
+
str
(
category_iou
))
logging
.
info
(
"[EVAL] Category IoU: "
+
str
(
category_iou
))
logging
.
info
(
"[EVAL] Category Acc: "
+
str
(
category_acc
))
logging
.
info
(
"[EVAL] Category Acc: "
+
str
(
category_acc
))
logging
.
info
(
"[EVAL] Kappa:{:.4f} "
.
format
(
conf_mat
.
kappa
()))
logging
.
info
(
"[EVAL] Kappa:{:.4f} "
.
format
(
conf_mat
.
kappa
()))
return
miou
,
macc
def
main
(
args
):
def
main
(
args
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录