Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
40ed988d
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
285
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
40ed988d
编写于
9月 21, 2020
作者:
M
michaelowenliu
提交者:
GitHub
9月 21, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #389 from michaelowenliu/develop
add callbacks in core/train
上级
ca0448fd
73289709
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
661 addition
and
112 deletion
+661
-112
dygraph/paddleseg/core/seg_train.py
dygraph/paddleseg/core/seg_train.py
+162
-0
dygraph/paddleseg/cvlibs/callbacks.py
dygraph/paddleseg/cvlibs/callbacks.py
+277
-0
dygraph/paddleseg/models/ann.py
dygraph/paddleseg/models/ann.py
+5
-41
dygraph/paddleseg/models/common/layer_libs.py
dygraph/paddleseg/models/common/layer_libs.py
+0
-4
dygraph/paddleseg/models/common/pyramid_pool.py
dygraph/paddleseg/models/common/pyramid_pool.py
+0
-9
dygraph/paddleseg/models/deeplab.py
dygraph/paddleseg/models/deeplab.py
+8
-16
dygraph/paddleseg/models/fast_scnn.py
dygraph/paddleseg/models/fast_scnn.py
+4
-20
dygraph/paddleseg/models/gcnet.py
dygraph/paddleseg/models/gcnet.py
+5
-14
dygraph/paddleseg/models/pspnet.py
dygraph/paddleseg/models/pspnet.py
+0
-8
dygraph/paddleseg/utils/progbar.py
dygraph/paddleseg/utils/progbar.py
+200
-0
未找到文件。
dygraph/paddleseg/core/seg_train.py
0 → 100644
浏览文件 @
40ed988d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.parallel
import
ParallelEnv
from
paddle.fluid.io
import
DataLoader
from
paddle.io
import
DistributedBatchSampler
import
paddle.nn.functional
as
F
import
paddleseg.utils.logger
as
logger
from
paddleseg.utils
import
load_pretrained_model
from
paddleseg.utils
import
resume
from
paddleseg.utils
import
Timer
,
calculate_eta
from
paddleseg.core.val
import
evaluate
from
paddleseg.cvlibs
import
callbacks
def
check_logits_losses
(
logits
,
losses
):
len_logits
=
len
(
logits
)
len_losses
=
len
(
losses
[
'types'
])
if
len_logits
!=
len_losses
:
raise
RuntimeError
(
'The length of logits should equal to the types of loss config: {} != {}.'
.
format
(
len_logits
,
len_losses
))
def
loss_computation
(
logits
,
label
,
losses
):
check_logits_losses
(
logits
,
losses
)
loss
=
0
for
i
in
range
(
len
(
logits
)):
logit
=
logits
[
i
]
if
logit
.
shape
[
-
2
:]
!=
label
.
shape
[
-
2
:]:
logit
=
F
.
resize_bilinear
(
logit
,
label
.
shape
[
-
2
:])
loss_i
=
losses
[
'types'
][
i
](
logit
,
label
)
loss
+=
losses
[
'coef'
][
i
]
*
loss_i
return
loss
def
seg_train
(
model
,
train_dataset
,
places
=
None
,
val_dataset
=
None
,
losses
=
None
,
optimizer
=
None
,
save_dir
=
'output'
,
iters
=
10000
,
batch_size
=
2
,
resume_model
=
None
,
save_interval_iters
=
1000
,
log_iters
=
10
,
num_workers
=
8
):
nranks
=
ParallelEnv
().
nranks
start_iter
=
0
if
resume_model
is
not
None
:
start_iter
=
resume
(
model
,
optimizer
,
resume_model
)
if
nranks
>
1
:
strategy
=
fluid
.
dygraph
.
prepare_context
()
ddp_model
=
fluid
.
dygraph
.
DataParallel
(
model
,
strategy
)
batch_sampler
=
DistributedBatchSampler
(
train_dataset
,
batch_size
=
batch_size
,
shuffle
=
True
,
drop_last
=
True
)
loader
=
DataLoader
(
train_dataset
,
batch_sampler
=
batch_sampler
,
places
=
places
,
num_workers
=
num_workers
,
return_list
=
True
,
)
out_labels
=
[
"loss"
,
"reader_cost"
,
"batch_cost"
]
base_logger
=
callbacks
.
BaseLogger
(
period
=
log_iters
)
train_logger
=
callbacks
.
TrainLogger
(
log_freq
=
log_iters
)
model_ckpt
=
callbacks
.
ModelCheckpoint
(
save_dir
,
save_params_only
=
False
,
period
=
save_interval_iters
)
vdl
=
callbacks
.
VisualDL
(
log_dir
=
os
.
path
.
join
(
save_dir
,
"log"
))
cbks_list
=
[
base_logger
,
train_logger
,
model_ckpt
,
vdl
]
cbks
=
callbacks
.
CallbackList
(
cbks_list
)
cbks
.
set_model
(
model
)
cbks
.
set_optimizer
(
optimizer
)
cbks
.
set_params
({
"batch_size"
:
batch_size
,
"total_iters"
:
iters
,
"log_iters"
:
log_iters
,
"verbose"
:
1
,
"do_validation"
:
True
,
"metrics"
:
out_labels
,
"iters_per_epoch"
:
len
(
batch_sampler
)
})
logs
=
{}
logs
=
{
key
:
0.0
for
key
in
out_labels
}
timer
=
Timer
()
timer
.
start
()
############## 1 ################
cbks
.
on_train_begin
(
logs
)
#################################
iter
=
start_iter
while
iter
<
iters
:
for
data
in
loader
:
iter
+=
1
if
iter
>
iters
:
break
logs
[
"reader_cost"
]
=
timer
.
elapsed_time
()
############## 2 ################
cbks
.
on_iter_begin
(
iter
,
logs
)
#################################
images
=
data
[
0
]
labels
=
data
[
1
].
astype
(
'int64'
)
if
nranks
>
1
:
logits
=
ddp_model
(
images
)
loss
=
loss_computation
(
logits
,
labels
,
losses
)
# apply_collective_grads sum grads over multiple gpus.
loss
=
ddp_model
.
scale_loss
(
loss
)
loss
.
backward
()
ddp_model
.
apply_collective_grads
()
else
:
logits
=
model
(
images
)
loss
=
loss_computation
(
logits
,
labels
,
losses
)
loss
.
backward
()
optimizer
.
step
()
optimizer
.
_learning_rate
.
step
()
model
.
clear_gradients
()
logs
[
'loss'
]
=
loss
.
numpy
()[
0
]
logs
[
"batch_cost"
]
=
timer
.
elapsed_time
()
############## 3 ################
cbks
.
on_iter_end
(
iter
,
logs
)
#################################
timer
.
restart
()
############### 4 ###############
cbks
.
on_train_end
(
logs
)
#################################
\ No newline at end of file
dygraph/paddleseg/cvlibs/callbacks.py
0 → 100644
浏览文件 @
40ed988d
# -*- encoding: utf-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
time
import
numpy
as
np
import
paddle
from
paddle.distributed.parallel
import
ParallelEnv
from
visualdl
import
LogWriter
from
paddleseg.utils.progbar
import
Progbar
import
paddleseg.utils.logger
as
logger
class
CallbackList
(
object
):
"""Container abstracting a list of callbacks.
# Arguments
callbacks: List of `Callback` instances.
"""
def
__init__
(
self
,
callbacks
=
None
):
callbacks
=
callbacks
or
[]
self
.
callbacks
=
[
c
for
c
in
callbacks
]
def
append
(
self
,
callback
):
self
.
callbacks
.
append
(
callback
)
def
set_params
(
self
,
params
):
for
callback
in
self
.
callbacks
:
callback
.
set_params
(
params
)
def
set_model
(
self
,
model
):
for
callback
in
self
.
callbacks
:
callback
.
set_model
(
model
)
def
set_optimizer
(
self
,
optimizer
):
for
callback
in
self
.
callbacks
:
callback
.
set_optimizer
(
optimizer
)
def
on_iter_begin
(
self
,
iter
,
logs
=
None
):
"""Called right before processing a batch.
"""
logs
=
logs
or
{}
for
callback
in
self
.
callbacks
:
callback
.
on_iter_begin
(
iter
,
logs
)
self
.
_t_enter_iter
=
time
.
time
()
def
on_iter_end
(
self
,
iter
,
logs
=
None
):
"""Called at the end of a batch.
"""
logs
=
logs
or
{}
for
callback
in
self
.
callbacks
:
callback
.
on_iter_end
(
iter
,
logs
)
self
.
_t_exit_iter
=
time
.
time
()
def
on_train_begin
(
self
,
logs
=
None
):
"""Called at the beginning of training.
"""
logs
=
logs
or
{}
for
callback
in
self
.
callbacks
:
callback
.
on_train_begin
(
logs
)
def
on_train_end
(
self
,
logs
=
None
):
"""Called at the end of training.
"""
logs
=
logs
or
{}
for
callback
in
self
.
callbacks
:
callback
.
on_train_end
(
logs
)
def
__iter__
(
self
):
return
iter
(
self
.
callbacks
)
class
Callback
(
object
):
"""Abstract base class used to build new callbacks.
"""
def
__init__
(
self
):
self
.
validation_data
=
None
def
set_params
(
self
,
params
):
self
.
params
=
params
def
set_model
(
self
,
model
):
self
.
model
=
model
def
set_optimizer
(
self
,
optimizer
):
self
.
optimizer
=
optimizer
def
on_iter_begin
(
self
,
iter
,
logs
=
None
):
pass
def
on_iter_end
(
self
,
iter
,
logs
=
None
):
pass
def
on_train_begin
(
self
,
logs
=
None
):
pass
def
on_train_end
(
self
,
logs
=
None
):
pass
class
BaseLogger
(
Callback
):
def
__init__
(
self
,
period
=
10
):
super
(
BaseLogger
,
self
).
__init__
()
self
.
period
=
period
def
_reset
(
self
):
self
.
totals
=
{}
def
on_train_begin
(
self
,
logs
=
None
):
self
.
totals
=
{}
def
on_iter_end
(
self
,
iter
,
logs
=
None
):
logs
=
logs
or
{}
#(iter - 1) // iters_per_epoch + 1
for
k
,
v
in
logs
.
items
():
if
k
in
self
.
totals
.
keys
():
self
.
totals
[
k
]
+=
v
else
:
self
.
totals
[
k
]
=
v
if
iter
%
self
.
period
==
0
and
ParallelEnv
().
local_rank
==
0
:
for
k
in
self
.
totals
:
logs
[
k
]
=
self
.
totals
[
k
]
/
self
.
period
self
.
_reset
()
class
TrainLogger
(
Callback
):
def
__init__
(
self
,
log_freq
=
10
):
self
.
log_freq
=
log_freq
def
_calculate_eta
(
self
,
remaining_iters
,
speed
):
if
remaining_iters
<
0
:
remaining_iters
=
0
remaining_time
=
int
(
remaining_iters
*
speed
)
result
=
"{:0>2}:{:0>2}:{:0>2}"
arr
=
[]
for
i
in
range
(
2
,
-
1
,
-
1
):
arr
.
append
(
int
(
remaining_time
/
60
**
i
))
remaining_time
%=
60
**
i
return
result
.
format
(
*
arr
)
def
on_iter_end
(
self
,
iter
,
logs
=
None
):
if
iter
%
self
.
log_freq
==
0
and
ParallelEnv
().
local_rank
==
0
:
total_iters
=
self
.
params
[
"total_iters"
]
iters_per_epoch
=
self
.
params
[
"iters_per_epoch"
]
remaining_iters
=
total_iters
-
iter
eta
=
self
.
_calculate_eta
(
remaining_iters
,
logs
[
"batch_cost"
])
current_epoch
=
(
iter
-
1
)
//
self
.
params
[
"iters_per_epoch"
]
+
1
loss
=
logs
[
"loss"
]
lr
=
self
.
optimizer
.
get_lr
()
batch_cost
=
logs
[
"batch_cost"
]
reader_cost
=
logs
[
"reader_cost"
]
logger
.
info
(
"[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}"
.
format
(
current_epoch
,
iter
,
total_iters
,
loss
,
lr
,
batch_cost
,
reader_cost
,
eta
))
class
ProgbarLogger
(
Callback
):
def
__init__
(
self
):
super
(
ProgbarLogger
,
self
).
__init__
()
def
on_train_begin
(
self
,
logs
=
None
):
self
.
verbose
=
self
.
params
[
"verbose"
]
self
.
total_iters
=
self
.
params
[
"total_iters"
]
self
.
target
=
self
.
params
[
"total_iters"
]
self
.
progbar
=
Progbar
(
target
=
self
.
target
,
verbose
=
self
.
verbose
)
self
.
seen
=
0
self
.
log_values
=
[]
def
on_iter_begin
(
self
,
iter
,
logs
=
None
):
#self.seen = 0
if
self
.
seen
<
self
.
target
:
self
.
log_values
=
[]
def
on_iter_end
(
self
,
iter
,
logs
=
None
):
logs
=
logs
or
{}
self
.
seen
+=
1
for
k
in
self
.
params
[
'metrics'
]:
if
k
in
logs
:
self
.
log_values
.
append
((
k
,
logs
[
k
]))
#if self.verbose and self.seen < self.target and ParallelEnv.local_rank == 0:
#print(self.log_values)
if
self
.
seen
<
self
.
target
:
self
.
progbar
.
update
(
self
.
seen
,
self
.
log_values
)
class
ModelCheckpoint
(
Callback
):
def
__init__
(
self
,
save_dir
,
monitor
=
"miou"
,
save_best_only
=
False
,
save_params_only
=
True
,
mode
=
"max"
,
period
=
1
):
super
(
ModelCheckpoint
,
self
).
__init__
()
self
.
monitor
=
monitor
self
.
save_dir
=
save_dir
self
.
save_best_only
=
save_best_only
self
.
save_params_only
=
save_params_only
self
.
period
=
period
self
.
iters_since_last_save
=
0
if
mode
==
"min"
:
self
.
monitor_op
=
np
.
less
self
.
best
=
np
.
Inf
elif
mode
==
"max"
:
self
.
monitor_op
=
np
.
greater
self
.
best
=
-
np
.
Inf
else
:
raise
RuntimeError
(
"mode is not either
\"
min
\"
or
\"
max
\"
!"
)
def
on_train_begin
(
self
,
logs
=
None
):
self
.
verbose
=
self
.
params
[
"verbose"
]
save_dir
=
self
.
save_dir
if
not
os
.
path
.
isdir
(
save_dir
):
if
os
.
path
.
exists
(
save_dir
):
os
.
remove
(
save_dir
)
os
.
makedirs
(
save_dir
)
def
on_iter_end
(
self
,
iter
,
logs
=
None
):
logs
=
logs
or
{}
self
.
iters_since_last_save
+=
1
current_save_dir
=
os
.
path
.
join
(
self
.
save_dir
,
"iter_{}"
.
format
(
iter
))
current_save_dir
=
os
.
path
.
abspath
(
current_save_dir
)
#if self.iters_since_last_save % self.period and ParallelEnv().local_rank == 0:
#self.iters_since_last_save = 0
if
iter
%
self
.
period
==
0
and
ParallelEnv
().
local_rank
==
0
:
if
self
.
verbose
>
0
:
print
(
"iter {iter_num}: saving model to {path}"
.
format
(
iter_num
=
iter
,
path
=
current_save_dir
))
filepath
=
os
.
path
.
join
(
current_save_dir
,
'model'
)
paddle
.
save
(
self
.
model
.
state_dict
(),
filepath
)
if
not
self
.
save_params_only
:
paddle
.
save
(
self
.
optimizer
.
state_dict
(),
filepath
)
class
VisualDL
(
Callback
):
def
__init__
(
self
,
log_dir
=
"./log"
,
freq
=
1
):
super
(
VisualDL
,
self
).
__init__
()
self
.
log_dir
=
log_dir
self
.
freq
=
freq
def
on_train_begin
(
self
,
logs
=
None
):
self
.
writer
=
LogWriter
(
self
.
log_dir
)
def
on_iter_end
(
self
,
iter
,
logs
=
None
):
logs
=
logs
or
{}
if
iter
%
self
.
freq
==
0
and
ParallelEnv
().
local_rank
==
0
:
for
k
,
v
in
logs
.
items
():
self
.
writer
.
add_scalar
(
"Train/{}"
.
format
(
k
),
v
,
iter
)
self
.
writer
.
flush
()
def
on_train_end
(
self
,
logs
=
None
):
self
.
writer
.
close
()
\ No newline at end of file
dygraph/paddleseg/models/ann.py
浏览文件 @
40ed988d
...
...
@@ -35,30 +35,20 @@ class ANN(nn.Layer):
It mainly consists of AFNB and APNB modules.
Args:
num_classes (int): the unique number of target classes.
backbone (Paddle.nn.Layer): backbone network, currently support Resnet50/101.
model_pretrained (str): the path of pretrained model. Defaullt to None.
backbone_indices (tuple): two values in the tuple indicte the indices of output of backbone.
the first index will be taken as low-level features; the second one will be
taken as high-level features in AFNB module. Usually backbone consists of four
downsampling stage, and return an output of each stage, so we set default (2, 3),
which means taking feature map of the third stage and the fourth stage in backbone.
the first index will be taken as low-level features; the second one will be
taken as high-level features in AFNB module. Usually backbone consists of four
downsampling stage, and return an output of each stage, so we set default (2, 3),
which means taking feature map of the third stage and the fourth stage in backbone.
backbone_channels (tuple): the same length with "backbone_indices". It indicates the channels of corresponding index.
key_value_channels (int): the key and value channels of self-attention map in both AFNB and APNB modules.
Default to 256.
Default to 256.
inter_channels (int): both input and output channels of APNB modules.
psp_size (tuple): the out size of pooled feature maps. Default to (1, 3, 6, 8).
enable_auxiliary_loss (bool): a bool values indictes whether adding auxiliary loss. Default to True.
"""
def
__init__
(
self
,
...
...
@@ -156,21 +146,13 @@ class AFNB(nn.Layer):
Args:
low_in_channels (int): low-level-feature channels.
high_in_channels (int): high-level-feature channels.
out_channels (int): out channels of AFNB module.
key_channels (int): the key channels in self-attention block.
value_channels (int): the value channels in self-attention block.
dropout_prob (float): the dropout rate of output.
sizes (tuple): the number of AFNB modules. Default to ([1]).
psp_size (tuple): the out size of pooled feature maps. Default to (1, 3, 6, 8).
"""
def
__init__
(
self
,
...
...
@@ -214,19 +196,12 @@ class APNB(nn.Layer):
Args:
in_channels (int): the input channels of APNB module.
out_channels (int): out channels of APNB module.
key_channels (int): the key channels in self-attention block.
value_channels (int): the value channels in self-attention block.
dropout_prob (float): the dropout rate of output.
sizes (tuple): the number of AFNB modules. Default to ([1]).
psp_size (tuple): the out size of pooled feature maps. Default to (1, 3, 6, 8).
"""
def
__init__
(
self
,
...
...
@@ -279,17 +254,11 @@ class SelfAttentionBlock_AFNB(nn.Layer):
Args:
low_in_channels (int): low-level-feature channels.
high_in_channels (int): high-level-feature channels.
key_channels (int): the key channels in self-attention block.
value_channels (int): the value channels in self-attention block.
out_channels (int): out channels of AFNB module.
scale (int): pooling size. Defaut to 1.
psp_size (tuple): the out size of pooled feature maps. Default to (1, 3, 6, 8).
"""
...
...
@@ -366,15 +335,10 @@ class SelfAttentionBlock_APNB(nn.Layer):
Args:
in_channels (int): the input channels of APNB module.
out_channels (int): out channels of APNB module.
key_channels (int): the key channels in self-attention block.
value_channels (int): the value channels in self-attention block.
scale (int): pooling size. Defaut to 1.
psp_size (tuple): the out size of pooled feature maps. Default to (1, 3, 6, 8).
"""
...
...
dygraph/paddleseg/models/common/layer_libs.py
浏览文件 @
40ed988d
...
...
@@ -18,7 +18,6 @@ from paddle import nn
import
paddle.nn.functional
as
F
from
paddle.nn
import
Conv2d
from
paddle.nn
import
SyncBatchNorm
as
BatchNorm
from
paddle.nn.layer
import
activation
class
ConvBnRelu
(
nn
.
Layer
):
...
...
@@ -94,11 +93,8 @@ class AuxLayer(nn.Layer):
Args:
in_channels (int): the number of input channels.
inter_channels (int): intermediate channels.
out_channels (int): the number of output channels, which is usually num_classes.
dropout_prob (float): the droput rate. Default to 0.1.
"""
...
...
dygraph/paddleseg/models/common/pyramid_pool.py
浏览文件 @
40ed988d
...
...
@@ -28,15 +28,10 @@ class ASPPModule(nn.Layer):
Args:
aspp_ratios (tuple): the dilation rate using in ASSP module.
in_channels (int): the number of input channels.
out_channels (int): the number of output channels.
sep_conv (bool): if using separable conv in ASPP module.
image_pooling: if augmented with image-level features.
"""
def
__init__
(
self
,
...
...
@@ -106,11 +101,8 @@ class PPModule(nn.Layer):
Args:
in_channels (int): the number of intput channels to pyramid pooling module.
out_channels (int): the number of output channels after pyramid pooling module.
bin_sizes (tuple): the out size of pooled feature maps. Default to (1,2,3,6).
dim_reduction (bool): a bool value represent if reduing dimention after pooling. Default to True.
"""
...
...
@@ -152,7 +144,6 @@ class PPModule(nn.Layer):
Args:
in_channels (int): the number of intput channels to pyramid pooling module.
size (int): the out size of the pooled layer.
Returns:
...
...
dygraph/paddleseg/models/deeplab.py
浏览文件 @
40ed988d
...
...
@@ -38,25 +38,19 @@ class DeepLabV3P(nn.Layer):
Args:
num_classes (int): the unique number of target classes.
backbone (paddle.nn.Layer): backbone network, currently support Xception65, Resnet101_vd.
model_pretrained (str): the path of pretrained model.
aspp_ratios (tuple): the dilation rate using in ASSP module.
if output_stride=16, aspp_ratios should be set as (1, 6, 12, 18).
if output_stride=8, aspp_ratios is (1, 12, 24, 36).
if output_stride=16, aspp_ratios should be set as (1, 6, 12, 18).
if output_stride=8, aspp_ratios is (1, 12, 24, 36).
backbone_indices (tuple): two values in the tuple indicte the indices of output of backbone.
the first index will be taken as a low-level feature in Deconder component;
the second one will be taken as input of ASPP component.
Usually backbone consists of four downsampling stage, and return an output of
each stage, so we set default (0, 3), which means taking feature map of the first
stage in backbone as low-level feature used in Decoder, and feature map of the fourth
stage as input of ASPP.
the first index will be taken as a low-level feature in Deconder component;
the second one will be taken as input of ASPP component.
Usually backbone consists of four downsampling stage, and return an output of
each stage, so we set default (0, 3), which means taking feature map of the first
stage in backbone as low-level feature used in Decoder, and feature map of the fourth
stage as input of ASPP.
backbone_channels (tuple): the same length with "backbone_indices". It indicates the channels of corresponding index.
"""
def
__init__
(
self
,
...
...
@@ -118,7 +112,6 @@ class DeepLabV3(nn.Layer):
Args:
Refer to DeepLabV3P above
"""
def
__init__
(
self
,
...
...
@@ -178,7 +171,6 @@ class Decoder(nn.Layer):
Args:
num_classes (int): the number of classes.
in_channels (int): the number of input channels in decoder module.
"""
...
...
dygraph/paddleseg/models/fast_scnn.py
浏览文件 @
40ed988d
...
...
@@ -15,7 +15,7 @@
import
paddle.nn.functional
as
F
from
paddle
import
nn
from
paddleseg.cvlibs
import
manager
from
paddleseg.models.common
import
layer_libs
from
paddleseg.models.common
import
layer_libs
,
pyramid_pool
@
manager
.
MODELS
.
add_component
...
...
@@ -33,12 +33,9 @@ class FastSCNN(nn.Layer):
Args:
num_classes (int): the unique number of target classes. Default to 2.
model_pretrained (str): the path of pretrained model. Defaullt to None.
enable_auxiliary_loss (bool): a bool values indictes whether adding auxiliary loss.
if true, auxiliary loss will be added after LearningToDownsample module, where the weight is 0.4. Default to False.
if true, auxiliary loss will be added after LearningToDownsample module, where the weight is 0.4. Default to False.
"""
def
__init__
(
self
,
...
...
@@ -55,7 +52,7 @@ class FastSCNN(nn.Layer):
self
.
classifier
=
Classifier
(
128
,
num_classes
)
if
enable_auxiliary_loss
:
self
.
auxlayer
=
model_util
s
.
AuxLayer
(
64
,
32
,
num_classes
)
self
.
auxlayer
=
layer_lib
s
.
AuxLayer
(
64
,
32
,
num_classes
)
self
.
enable_auxiliary_loss
=
enable_auxiliary_loss
...
...
@@ -101,9 +98,7 @@ class LearningToDownsample(nn.Layer):
Args:
dw_channels1 (int): the input channels of the first sep conv. Default to 32.
dw_channels2 (int): the input channels of the second sep conv. Default to 48.
out_channels (int): the output channels of LearningToDownsample module. Default to 64.
"""
...
...
@@ -141,13 +136,9 @@ class GlobalFeatureExtractor(nn.Layer):
Args:
in_channels (int): the number of input channels to the module. Default to 64.
block_channels (tuple): a tuple represents output channels of each bottleneck block. Default to (64, 96, 128).
out_channels (int): the number of output channels of the module. Default to 128.
expansion (int): the expansion factor in bottleneck. Default to 6.
num_blocks (tuple): it indicates the repeat time of each bottleneck. Default to (3, 3, 3).
"""
...
...
@@ -169,7 +160,7 @@ class GlobalFeatureExtractor(nn.Layer):
block_channels
[
2
],
num_blocks
[
2
],
expansion
,
1
)
self
.
ppm
=
model_utils
.
PPModule
(
self
.
ppm
=
pyramid_pool
.
PPModule
(
block_channels
[
2
],
out_channels
,
dim_reduction
=
True
)
def
_make_layer
(
self
,
...
...
@@ -199,11 +190,8 @@ class LinearBottleneck(nn.Layer):
Args:
in_channels (int): the number of input channels to bottleneck block.
out_channels (int): the number of output channels of bottleneck block.
expansion (int). the expansion factor in bottleneck. Default to 6.
stride (int). the stride used in depth-wise conv.
"""
...
...
@@ -257,9 +245,7 @@ class FeatureFusionModule(nn.Layer):
Args:
high_in_channels (int): the channels of high-resolution feature (output of LearningToDownsample).
low_in_channels (int). the channels of low-resolution feature (output of GlobalFeatureExtractor).
out_channels (int). the output channels of this module.
"""
...
...
@@ -309,9 +295,7 @@ class Classifier(nn.Layer):
Args:
input_channels (int): the input channels to this module.
num_classes (int). the unique number of target classes.
"""
def
__init__
(
self
,
input_channels
,
num_classes
):
...
...
dygraph/paddleseg/models/gcnet.py
浏览文件 @
40ed988d
...
...
@@ -32,28 +32,19 @@ class GCNet(nn.Layer):
(https://arxiv.org/pdf/1904.11492.pdf)
Args:
num_classes (int): the unique number of target classes.
backbone (Paddle.nn.Layer): backbone network, currently support Resnet50/101.
model_pretrained (str): the path of pretrained model. Defaullt to None.
backbone_indices (tuple): two values in the tuple indicte the indices of output of backbone.
the first index will be taken as a deep-supervision feature in auxiliary layer;
the second one will be taken as input of GlobalContextBlock. Usually backbone
consists of four downsampling stage, and return an output of each stage, so we
set default (2, 3), which means taking feature map of the third stage (res4b22)
and the fourth stage (res5c) in backbone.
the first index will be taken as a deep-supervision feature in auxiliary layer;
the second one will be taken as input of GlobalContextBlock. Usually backbone
consists of four downsampling stage, and return an output of each stage, so we
set default (2, 3), which means taking feature map of the third stage (res4b22)
and the fourth stage (res5c) in backbone.
backbone_channels (tuple): the same length with "backbone_indices". It indicates the channels of corresponding index.
gc_channels (int): input channels to Global Context Block. Default to 512.
ratio (float): it indictes the ratio of attention channels and gc_channels. Default to 1/4.
enable_auxiliary_loss (bool): a bool values indictes whether adding auxiliary loss. Default to True.
"""
def
__init__
(
self
,
...
...
dygraph/paddleseg/models/pspnet.py
浏览文件 @
40ed988d
...
...
@@ -33,26 +33,18 @@ class PSPNet(nn.Layer):
Args:
num_classes (int): the unique number of target classes.
backbone (Paddle.nn.Layer): backbone network, currently support Resnet50/101.
model_pretrained (str): the path of pretrained model. Defaullt to None.
backbone_indices (tuple): two values in the tuple indicte the indices of output of backbone.
the first index will be taken as a deep-supervision feature in auxiliary layer;
the second one will be taken as input of Pyramid Pooling Module (PPModule).
Usually backbone consists of four downsampling stage, and return an output of
each stage, so we set default (2, 3), which means taking feature map of the third
stage (res4b22) in backbone, and feature map of the fourth stage (res5c) as input of PPModule.
backbone_channels (tuple): the same length with "backbone_indices". It indicates the channels of corresponding index.
pp_out_channels (int): output channels after Pyramid Pooling Module. Default to 1024.
bin_sizes (tuple): the out size of pooled feature maps. Default to (1,2,3,6).
enable_auxiliary_loss (bool): a bool values indictes whether adding auxiliary loss. Default to True.
"""
def
__init__
(
self
,
...
...
dygraph/paddleseg/utils/progbar.py
0 → 100644
浏览文件 @
40ed988d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
import
time
import
numpy
as
np
class
Progbar
(
object
):
"""Displays a progress bar.
refers to https://github.com/keras-team/keras/blob/keras-2/keras/utils/generic_utils.py
Arguments:
target: Total number of steps expected, None if unknown.
width: Progress bar width on screen.
verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
stateful_metrics: Iterable of string names of metrics that should *not* be
averaged over time. Metrics in this list will be displayed as-is. All
others will be averaged by the progbar before display.
interval: Minimum visual progress update interval (in seconds).
unit_name: Display name for step counts (usually "step" or "sample").
"""
def
__init__
(
self
,
target
,
width
=
30
,
verbose
=
1
,
interval
=
0.05
,
stateful_metrics
=
None
,
unit_name
=
'step'
):
self
.
target
=
target
self
.
width
=
width
self
.
verbose
=
verbose
self
.
interval
=
interval
self
.
unit_name
=
unit_name
if
stateful_metrics
:
self
.
stateful_metrics
=
set
(
stateful_metrics
)
else
:
self
.
stateful_metrics
=
set
()
self
.
_dynamic_display
=
((
hasattr
(
sys
.
stdout
,
'isatty'
)
and
sys
.
stdout
.
isatty
())
or
'ipykernel'
in
sys
.
modules
or
'posix'
in
sys
.
modules
or
'PYCHARM_HOSTED'
in
os
.
environ
)
self
.
_total_width
=
0
self
.
_seen_so_far
=
0
# We use a dict + list to avoid garbage collection
# issues found in OrderedDict
self
.
_values
=
{}
self
.
_values_order
=
[]
self
.
_start
=
time
.
time
()
self
.
_last_update
=
0
def
update
(
self
,
current
,
values
=
None
,
finalize
=
None
):
"""Updates the progress bar.
Arguments:
current: Index of current step.
values: List of tuples: `(name, value_for_last_step)`. If `name` is in
`stateful_metrics`, `value_for_last_step` will be displayed as-is.
Else, an average of the metric over time will be displayed.
finalize: Whether this is the last update for the progress bar. If
`None`, defaults to `current >= self.target`.
"""
if
finalize
is
None
:
if
self
.
target
is
None
:
finalize
=
False
else
:
finalize
=
current
>=
self
.
target
values
=
values
or
[]
for
k
,
v
in
values
:
if
k
not
in
self
.
_values_order
:
self
.
_values_order
.
append
(
k
)
if
k
not
in
self
.
stateful_metrics
:
# In the case that progress bar doesn't have a target value in the first
# epoch, both on_batch_end and on_epoch_end will be called, which will
# cause 'current' and 'self._seen_so_far' to have the same value. Force
# the minimal value to 1 here, otherwise stateful_metric will be 0s.
value_base
=
max
(
current
-
self
.
_seen_so_far
,
1
)
if
k
not
in
self
.
_values
:
self
.
_values
[
k
]
=
[
v
*
value_base
,
value_base
]
else
:
self
.
_values
[
k
][
0
]
+=
v
*
value_base
self
.
_values
[
k
][
1
]
+=
value_base
else
:
# Stateful metrics output a numeric value. This representation
# means "take an average from a single value" but keeps the
# numeric formatting.
self
.
_values
[
k
]
=
[
v
,
1
]
self
.
_seen_so_far
=
current
now
=
time
.
time
()
info
=
' - %.0fs'
%
(
now
-
self
.
_start
)
if
self
.
verbose
==
1
:
if
now
-
self
.
_last_update
<
self
.
interval
and
not
finalize
:
return
prev_total_width
=
self
.
_total_width
if
self
.
_dynamic_display
:
sys
.
stdout
.
write
(
'
\b
'
*
prev_total_width
)
sys
.
stdout
.
write
(
'
\r
'
)
else
:
sys
.
stdout
.
write
(
'
\n
'
)
if
self
.
target
is
not
None
:
numdigits
=
int
(
np
.
log10
(
self
.
target
))
+
1
bar
=
(
'%'
+
str
(
numdigits
)
+
'd/%d ['
)
%
(
current
,
self
.
target
)
prog
=
float
(
current
)
/
self
.
target
prog_width
=
int
(
self
.
width
*
prog
)
if
prog_width
>
0
:
bar
+=
(
'='
*
(
prog_width
-
1
))
if
current
<
self
.
target
:
bar
+=
'>'
else
:
bar
+=
'='
bar
+=
(
'.'
*
(
self
.
width
-
prog_width
))
bar
+=
']'
else
:
bar
=
'%7d/Unknown'
%
current
self
.
_total_width
=
len
(
bar
)
sys
.
stdout
.
write
(
bar
)
if
current
:
time_per_unit
=
(
now
-
self
.
_start
)
/
current
else
:
time_per_unit
=
0
if
self
.
target
is
None
or
finalize
:
if
time_per_unit
>=
1
or
time_per_unit
==
0
:
info
+=
' %.0fs/%s'
%
(
time_per_unit
,
self
.
unit_name
)
elif
time_per_unit
>=
1e-3
:
info
+=
' %.0fms/%s'
%
(
time_per_unit
*
1e3
,
self
.
unit_name
)
else
:
info
+=
' %.0fus/%s'
%
(
time_per_unit
*
1e6
,
self
.
unit_name
)
else
:
eta
=
time_per_unit
*
(
self
.
target
-
current
)
if
eta
>
3600
:
eta_format
=
'%d:%02d:%02d'
%
(
eta
//
3600
,
(
eta
%
3600
)
//
60
,
eta
%
60
)
elif
eta
>
60
:
eta_format
=
'%d:%02d'
%
(
eta
//
60
,
eta
%
60
)
else
:
eta_format
=
'%ds'
%
eta
info
=
' - ETA: %s'
%
eta_format
for
k
in
self
.
_values_order
:
info
+=
' - %s:'
%
k
if
isinstance
(
self
.
_values
[
k
],
list
):
avg
=
np
.
mean
(
self
.
_values
[
k
][
0
]
/
max
(
1
,
self
.
_values
[
k
][
1
]))
if
abs
(
avg
)
>
1e-3
:
info
+=
' %.4f'
%
avg
else
:
info
+=
' %.4e'
%
avg
else
:
info
+=
' %s'
%
self
.
_values
[
k
]
self
.
_total_width
+=
len
(
info
)
if
prev_total_width
>
self
.
_total_width
:
info
+=
(
' '
*
(
prev_total_width
-
self
.
_total_width
))
if
finalize
:
info
+=
'
\n
'
sys
.
stdout
.
write
(
info
)
sys
.
stdout
.
flush
()
elif
self
.
verbose
==
2
:
if
finalize
:
numdigits
=
int
(
np
.
log10
(
self
.
target
))
+
1
count
=
(
'%'
+
str
(
numdigits
)
+
'd/%d'
)
%
(
current
,
self
.
target
)
info
=
count
+
info
for
k
in
self
.
_values_order
:
info
+=
' - %s:'
%
k
avg
=
np
.
mean
(
self
.
_values
[
k
][
0
]
/
max
(
1
,
self
.
_values
[
k
][
1
]))
if
avg
>
1e-3
:
info
+=
' %.4f'
%
avg
else
:
info
+=
' %.4e'
%
avg
info
+=
'
\n
'
sys
.
stdout
.
write
(
info
)
sys
.
stdout
.
flush
()
self
.
_last_update
=
now
def
add
(
self
,
n
,
values
=
None
):
self
.
update
(
self
.
_seen_so_far
+
n
,
values
)
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录