Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
ca0448fd
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
285
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ca0448fd
编写于
9月 21, 2020
作者:
W
wuzewu
提交者:
GitHub
9月 21, 2020
浏览文件
操作
浏览文件
下载
差异文件
update api to 2.0-beta
上级
0e666bfe
4231b3f7
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
186 addition
and
258 deletion
+186
-258
dygraph/configs/fcn_hrnet/fcn_hrnetw18_cityscapes_1024x512_100k.yml
...nfigs/fcn_hrnet/fcn_hrnetw18_cityscapes_1024x512_100k.yml
+3
-0
dygraph/configs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_40k.yml
...onfigs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_40k.yml
+4
-35
dygraph/paddleseg/core/train.py
dygraph/paddleseg/core/train.py
+25
-18
dygraph/paddleseg/models/backbones/resnet_vd.py
dygraph/paddleseg/models/backbones/resnet_vd.py
+65
-117
dygraph/paddleseg/models/fcn.py
dygraph/paddleseg/models/fcn.py
+26
-28
dygraph/paddleseg/utils/config.py
dygraph/paddleseg/utils/config.py
+10
-6
dygraph/train.py
dygraph/train.py
+35
-36
dygraph/val.py
dygraph/val.py
+18
-18
未找到文件。
dygraph/configs/fcn_hrnet/fcn_hrnetw18_cityscapes_1024x512_100k.yml
浏览文件 @
ca0448fd
...
@@ -7,3 +7,6 @@ model:
...
@@ -7,3 +7,6 @@ model:
num_classes
:
19
num_classes
:
19
backbone_channels
:
[
270
]
backbone_channels
:
[
270
]
backbone_pretrained
:
pretrained_model/hrnet_w18_imagenet
backbone_pretrained
:
pretrained_model/hrnet_w18_imagenet
optimizer
:
weight_decay
:
0.0005
dygraph/configs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_40k.yml
浏览文件 @
ca0448fd
_base_
:
'
../_base_/cityscapes.yml'
batch_size
:
2
batch_size
:
2
iters
:
40000
iters
:
40000
train_dataset
:
type
:
Cityscapes
dataset_root
:
data/cityscapes
transforms
:
-
type
:
ResizeStepScaling
min_scale_factor
:
0.5
max_scale_factor
:
2.0
scale_step_size
:
0.25
-
type
:
RandomPaddingCrop
crop_size
:
[
1024
,
512
]
-
type
:
RandomHorizontalFlip
-
type
:
Normalize
mode
:
train
val_dataset
:
type
:
Cityscapes
dataset_root
:
data/cityscapes
transforms
:
-
type
:
Normalize
mode
:
val
model
:
model
:
type
:
OCRNet
type
:
OCRNet
backbone
:
backbone
:
type
:
HRNet_W18
type
:
HRNet_W18
backbone_pretrianed
:
None
num_classes
:
19
num_classes
:
19
in_channels
:
270
backbone_channels
:
[
270
]
backbone_pretrained
:
pretrained_model/hrnet_w18_imagenet
model_pretrained
:
None
model_pretrained
:
None
optimizer
:
type
:
sgd
learning_rate
:
value
:
0.01
decay
:
type
:
poly
power
:
0.9
loss
:
type
:
CrossEntropy
dygraph/paddleseg/core/train.py
浏览文件 @
ca0448fd
...
@@ -15,11 +15,10 @@
...
@@ -15,11 +15,10 @@
import
os
import
os
import
paddle
import
paddle
import
paddle.fluid
as
fluid
from
paddle.distributed
import
ParallelEnv
from
paddle.fluid.dygraph.parallel
import
ParallelEnv
from
paddle.distributed
import
init_parallel_env
from
paddle.fluid.io
import
DataLoader
# from paddle.incubate.hapi.distributed import DistributedBatchSampler
from
paddle.io
import
DistributedBatchSampler
from
paddle.io
import
DistributedBatchSampler
from
paddle.io
import
DataLoader
import
paddle.nn.functional
as
F
import
paddle.nn.functional
as
F
import
paddleseg.utils.logger
as
logger
import
paddleseg.utils.logger
as
logger
...
@@ -79,11 +78,14 @@ def train(model,
...
@@ -79,11 +78,14 @@ def train(model,
os
.
makedirs
(
save_dir
)
os
.
makedirs
(
save_dir
)
if
nranks
>
1
:
if
nranks
>
1
:
strategy
=
fluid
.
dygraph
.
prepare_context
()
# Initialize parallel training environment.
ddp_model
=
fluid
.
dygraph
.
DataParallel
(
model
,
strategy
)
init_parallel_env
()
strategy
=
paddle
.
distributed
.
prepare_context
()
ddp_model
=
paddle
.
DataParallel
(
model
,
strategy
)
batch_sampler
=
DistributedBatchSampler
(
batch_sampler
=
DistributedBatchSampler
(
train_dataset
,
batch_size
=
batch_size
,
shuffle
=
True
,
drop_last
=
True
)
train_dataset
,
batch_size
=
batch_size
,
shuffle
=
True
,
drop_last
=
True
)
loader
=
DataLoader
(
loader
=
DataLoader
(
train_dataset
,
train_dataset
,
batch_sampler
=
batch_sampler
,
batch_sampler
=
batch_sampler
,
...
@@ -117,7 +119,6 @@ def train(model,
...
@@ -117,7 +119,6 @@ def train(model,
if
nranks
>
1
:
if
nranks
>
1
:
logits
=
ddp_model
(
images
)
logits
=
ddp_model
(
images
)
loss
=
loss_computation
(
logits
,
labels
,
losses
)
loss
=
loss_computation
(
logits
,
labels
,
losses
)
# loss = ddp_model(images, labels)
# apply_collective_grads sum grads over multiple gpus.
# apply_collective_grads sum grads over multiple gpus.
loss
=
ddp_model
.
scale_loss
(
loss
)
loss
=
ddp_model
.
scale_loss
(
loss
)
loss
.
backward
()
loss
.
backward
()
...
@@ -127,10 +128,17 @@ def train(model,
...
@@ -127,10 +128,17 @@ def train(model,
loss
=
loss_computation
(
logits
,
labels
,
losses
)
loss
=
loss_computation
(
logits
,
labels
,
losses
)
# loss = model(images, labels)
# loss = model(images, labels)
loss
.
backward
()
loss
.
backward
()
optimizer
.
minimize
(
loss
)
# optimizer.minimize(loss)
optimizer
.
step
()
if
isinstance
(
optimizer
.
_learning_rate
,
paddle
.
optimizer
.
_LRScheduler
):
optimizer
.
_learning_rate
.
step
()
model
.
clear_gradients
()
model
.
clear_gradients
()
# Sum loss over all ranks
if
nranks
>
1
:
paddle
.
distributed
.
all_reduce
(
loss
)
avg_loss
+=
loss
.
numpy
()[
0
]
avg_loss
+=
loss
.
numpy
()[
0
]
lr
=
optimizer
.
current_step
_lr
()
lr
=
optimizer
.
get
_lr
()
train_batch_cost
+=
timer
.
elapsed_time
()
train_batch_cost
+=
timer
.
elapsed_time
()
if
(
iter
)
%
log_iters
==
0
and
ParallelEnv
().
local_rank
==
0
:
if
(
iter
)
%
log_iters
==
0
and
ParallelEnv
().
local_rank
==
0
:
avg_loss
/=
log_iters
avg_loss
/=
log_iters
...
@@ -143,10 +151,10 @@ def train(model,
...
@@ -143,10 +151,10 @@ def train(model,
logger
.
info
(
logger
.
info
(
"[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}"
"[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}"
.
format
((
iter
-
1
)
//
iters_per_epoch
+
1
,
iter
,
iters
,
.
format
((
iter
-
1
)
//
iters_per_epoch
+
1
,
iter
,
iters
,
avg_loss
*
nranks
,
lr
,
avg_train_batch_cost
,
avg_loss
,
lr
,
avg_train_batch_cost
,
avg_train_reader_cost
,
eta
))
avg_train_reader_cost
,
eta
))
if
use_vdl
:
if
use_vdl
:
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
*
nranks
,
iter
)
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
iter
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
,
iter
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
,
iter
)
log_writer
.
add_scalar
(
'Train/batch_cost'
,
log_writer
.
add_scalar
(
'Train/batch_cost'
,
avg_train_batch_cost
,
iter
)
avg_train_batch_cost
,
iter
)
...
@@ -160,10 +168,10 @@ def train(model,
...
@@ -160,10 +168,10 @@ def train(model,
"iter_{}"
.
format
(
iter
))
"iter_{}"
.
format
(
iter
))
if
not
os
.
path
.
isdir
(
current_save_dir
):
if
not
os
.
path
.
isdir
(
current_save_dir
):
os
.
makedirs
(
current_save_dir
)
os
.
makedirs
(
current_save_dir
)
fluid
.
save_dygraph
(
model
.
state_dict
(),
paddle
.
save
(
model
.
state_dict
(),
os
.
path
.
join
(
current_save_dir
,
'model'
))
os
.
path
.
join
(
current_save_dir
,
'model'
))
fluid
.
save_dygraph
(
optimizer
.
state_dict
(),
paddle
.
save
(
optimizer
.
state_dict
(),
os
.
path
.
join
(
current_save_dir
,
'model'
))
os
.
path
.
join
(
current_save_dir
,
'model'
))
if
eval_dataset
is
not
None
:
if
eval_dataset
is
not
None
:
mean_iou
,
avg_acc
=
evaluate
(
mean_iou
,
avg_acc
=
evaluate
(
...
@@ -177,9 +185,8 @@ def train(model,
...
@@ -177,9 +185,8 @@ def train(model,
best_mean_iou
=
mean_iou
best_mean_iou
=
mean_iou
best_model_iter
=
iter
best_model_iter
=
iter
best_model_dir
=
os
.
path
.
join
(
save_dir
,
"best_model"
)
best_model_dir
=
os
.
path
.
join
(
save_dir
,
"best_model"
)
fluid
.
save_dygraph
(
paddle
.
save
(
model
.
state_dict
(),
model
.
state_dict
(),
os
.
path
.
join
(
best_model_dir
,
'model'
))
os
.
path
.
join
(
best_model_dir
,
'model'
))
logger
.
info
(
logger
.
info
(
'Current evaluated best model in eval_dataset is iter_{}, miou={:4f}'
'Current evaluated best model in eval_dataset is iter_{}, miou={:4f}'
.
format
(
best_model_iter
,
best_mean_iou
))
.
format
(
best_model_iter
,
best_mean_iou
))
...
...
dygraph/paddleseg/models/backbones/resnet_vd.py
浏览文件 @
ca0448fd
...
@@ -21,11 +21,11 @@ import math
...
@@ -21,11 +21,11 @@ import math
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.nn
as
nn
from
paddle.fluid.param_attr
import
ParamAttr
import
paddle.nn.functional
as
F
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
,
Dropout
from
paddle.nn
import
SyncBatchNorm
as
BatchNorm
from
paddle.nn
import
SyncBatchNorm
as
BatchNorm
from
paddle.nn
import
Conv2d
,
Linear
,
Dropout
from
paddle.nn
import
AdaptiveAvgPool2d
,
MaxPool2d
,
AvgPool2d
from
paddleseg.utils
import
utils
from
paddleseg.utils
import
utils
from
paddleseg.models.common
import
layer_libs
,
activation
from
paddleseg.models.common
import
layer_libs
,
activation
...
@@ -36,12 +36,12 @@ __all__ = [
...
@@ -36,12 +36,12 @@ __all__ = [
]
]
class
ConvBNLayer
(
fluid
.
dygraph
.
Layer
):
class
ConvBNLayer
(
nn
.
Layer
):
def
__init__
(
def
__init__
(
self
,
self
,
num
_channels
,
in
_channels
,
num_filter
s
,
out_channel
s
,
filter
_size
,
kernel
_size
,
stride
=
1
,
stride
=
1
,
dilation
=
1
,
dilation
=
1
,
groups
=
1
,
groups
=
1
,
...
@@ -52,31 +52,22 @@ class ConvBNLayer(fluid.dygraph.Layer):
...
@@ -52,31 +52,22 @@ class ConvBNLayer(fluid.dygraph.Layer):
super
(
ConvBNLayer
,
self
).
__init__
()
super
(
ConvBNLayer
,
self
).
__init__
()
self
.
is_vd_mode
=
is_vd_mode
self
.
is_vd_mode
=
is_vd_mode
self
.
_pool2d_avg
=
Pool2D
(
self
.
_pool2d_avg
=
AvgPool2d
(
pool_size
=
2
,
kernel_size
=
2
,
stride
=
2
,
padding
=
0
,
ceil_mode
=
True
)
pool_stride
=
2
,
self
.
_conv
=
Conv2d
(
pool_padding
=
0
,
in_channels
=
in_channels
,
pool_type
=
'avg'
,
out_channels
=
out_channels
,
ceil_mode
=
True
)
kernel_size
=
kernel_size
,
self
.
_conv
=
Conv2D
(
num_channels
=
num_channels
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
stride
,
stride
=
stride
,
padding
=
(
filter
_size
-
1
)
//
2
if
dilation
==
1
else
0
,
padding
=
(
kernel
_size
-
1
)
//
2
if
dilation
==
1
else
0
,
dilation
=
dilation
,
dilation
=
dilation
,
groups
=
groups
,
groups
=
groups
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
name
+
"_weights"
),
bias_attr
=
False
)
bias_attr
=
False
)
if
name
==
"conv1"
:
if
name
==
"conv1"
:
bn_name
=
"bn_"
+
name
bn_name
=
"bn_"
+
name
else
:
else
:
bn_name
=
"bn"
+
name
[
3
:]
bn_name
=
"bn"
+
name
[
3
:]
self
.
_batch_norm
=
BatchNorm
(
self
.
_batch_norm
=
BatchNorm
(
out_channels
)
num_filters
,
weight_attr
=
ParamAttr
(
name
=
bn_name
+
'_scale'
),
bias_attr
=
ParamAttr
(
bn_name
+
'_offset'
))
self
.
_act_op
=
activation
.
Activation
(
act
=
act
)
self
.
_act_op
=
activation
.
Activation
(
act
=
act
)
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
):
...
@@ -89,10 +80,10 @@ class ConvBNLayer(fluid.dygraph.Layer):
...
@@ -89,10 +80,10 @@ class ConvBNLayer(fluid.dygraph.Layer):
return
y
return
y
class
BottleneckBlock
(
fluid
.
dygraph
.
Layer
):
class
BottleneckBlock
(
nn
.
Layer
):
def
__init__
(
self
,
def
__init__
(
self
,
num
_channels
,
in
_channels
,
num_filter
s
,
out_channel
s
,
stride
,
stride
,
shortcut
=
True
,
shortcut
=
True
,
if_first
=
False
,
if_first
=
False
,
...
@@ -101,34 +92,34 @@ class BottleneckBlock(fluid.dygraph.Layer):
...
@@ -101,34 +92,34 @@ class BottleneckBlock(fluid.dygraph.Layer):
super
(
BottleneckBlock
,
self
).
__init__
()
super
(
BottleneckBlock
,
self
).
__init__
()
self
.
conv0
=
ConvBNLayer
(
self
.
conv0
=
ConvBNLayer
(
num_channels
=
num
_channels
,
in_channels
=
in
_channels
,
num_filters
=
num_filter
s
,
out_channels
=
out_channel
s
,
filter
_size
=
1
,
kernel
_size
=
1
,
act
=
'relu'
,
act
=
'relu'
,
name
=
name
+
"_branch2a"
)
name
=
name
+
"_branch2a"
)
self
.
dilation
=
dilation
self
.
dilation
=
dilation
self
.
conv1
=
ConvBNLayer
(
self
.
conv1
=
ConvBNLayer
(
num_channels
=
num_filter
s
,
in_channels
=
out_channel
s
,
num_filters
=
num_filter
s
,
out_channels
=
out_channel
s
,
filter
_size
=
3
,
kernel
_size
=
3
,
stride
=
stride
,
stride
=
stride
,
act
=
'relu'
,
act
=
'relu'
,
dilation
=
dilation
,
dilation
=
dilation
,
name
=
name
+
"_branch2b"
)
name
=
name
+
"_branch2b"
)
self
.
conv2
=
ConvBNLayer
(
self
.
conv2
=
ConvBNLayer
(
num_channels
=
num_filter
s
,
in_channels
=
out_channel
s
,
num_filters
=
num_filter
s
*
4
,
out_channels
=
out_channel
s
*
4
,
filter
_size
=
1
,
kernel
_size
=
1
,
act
=
None
,
act
=
None
,
name
=
name
+
"_branch2c"
)
name
=
name
+
"_branch2c"
)
if
not
shortcut
:
if
not
shortcut
:
self
.
short
=
ConvBNLayer
(
self
.
short
=
ConvBNLayer
(
num_channels
=
num
_channels
,
in_channels
=
in
_channels
,
num_filters
=
num_filter
s
*
4
,
out_channels
=
out_channel
s
*
4
,
filter
_size
=
1
,
kernel
_size
=
1
,
stride
=
1
,
stride
=
1
,
is_vd_mode
=
False
if
if_first
or
stride
==
1
else
True
,
is_vd_mode
=
False
if
if_first
or
stride
==
1
else
True
,
name
=
name
+
"_branch1"
)
name
=
name
+
"_branch1"
)
...
@@ -142,8 +133,7 @@ class BottleneckBlock(fluid.dygraph.Layer):
...
@@ -142,8 +133,7 @@ class BottleneckBlock(fluid.dygraph.Layer):
# If given dilation rate > 1, using corresponding padding
# If given dilation rate > 1, using corresponding padding
if
self
.
dilation
>
1
:
if
self
.
dilation
>
1
:
padding
=
self
.
dilation
padding
=
self
.
dilation
y
=
fluid
.
layers
.
pad
(
y
=
F
.
pad
(
y
,
[
0
,
0
,
0
,
0
,
padding
,
padding
,
padding
,
padding
])
y
,
[
0
,
0
,
0
,
0
,
padding
,
padding
,
padding
,
padding
])
#####################################################################
#####################################################################
conv1
=
self
.
conv1
(
y
)
conv1
=
self
.
conv1
(
y
)
conv2
=
self
.
conv2
(
conv1
)
conv2
=
self
.
conv2
(
conv1
)
...
@@ -153,15 +143,14 @@ class BottleneckBlock(fluid.dygraph.Layer):
...
@@ -153,15 +143,14 @@ class BottleneckBlock(fluid.dygraph.Layer):
else
:
else
:
short
=
self
.
short
(
inputs
)
short
=
self
.
short
(
inputs
)
y
=
fluid
.
layers
.
elementwise_add
(
x
=
short
,
y
=
conv2
)
y
=
paddle
.
elementwise_add
(
x
=
short
,
y
=
conv2
,
act
=
'relu'
)
layer_helper
=
LayerHelper
(
self
.
full_name
(),
act
=
'relu'
)
return
y
return
layer_helper
.
append_activation
(
y
)
class
BasicBlock
(
fluid
.
dygraph
.
Layer
):
class
BasicBlock
(
nn
.
Layer
):
def
__init__
(
self
,
def
__init__
(
self
,
num
_channels
,
in
_channels
,
num_filter
s
,
out_channel
s
,
stride
,
stride
,
shortcut
=
True
,
shortcut
=
True
,
if_first
=
False
,
if_first
=
False
,
...
@@ -169,24 +158,24 @@ class BasicBlock(fluid.dygraph.Layer):
...
@@ -169,24 +158,24 @@ class BasicBlock(fluid.dygraph.Layer):
super
(
BasicBlock
,
self
).
__init__
()
super
(
BasicBlock
,
self
).
__init__
()
self
.
stride
=
stride
self
.
stride
=
stride
self
.
conv0
=
ConvBNLayer
(
self
.
conv0
=
ConvBNLayer
(
num_channels
=
num
_channels
,
in_channels
=
in
_channels
,
num_filters
=
num_filter
s
,
out_channels
=
out_channel
s
,
filter
_size
=
3
,
kernel
_size
=
3
,
stride
=
stride
,
stride
=
stride
,
act
=
'relu'
,
act
=
'relu'
,
name
=
name
+
"_branch2a"
)
name
=
name
+
"_branch2a"
)
self
.
conv1
=
ConvBNLayer
(
self
.
conv1
=
ConvBNLayer
(
num_channels
=
num_filter
s
,
in_channels
=
out_channel
s
,
num_filters
=
num_filter
s
,
out_channels
=
out_channel
s
,
filter
_size
=
3
,
kernel
_size
=
3
,
act
=
None
,
act
=
None
,
name
=
name
+
"_branch2b"
)
name
=
name
+
"_branch2b"
)
if
not
shortcut
:
if
not
shortcut
:
self
.
short
=
ConvBNLayer
(
self
.
short
=
ConvBNLayer
(
num_channels
=
num
_channels
,
in_channels
=
in
_channels
,
num_filters
=
num_filter
s
,
out_channels
=
out_channel
s
,
filter
_size
=
1
,
kernel
_size
=
1
,
stride
=
1
,
stride
=
1
,
is_vd_mode
=
False
if
if_first
else
True
,
is_vd_mode
=
False
if
if_first
else
True
,
name
=
name
+
"_branch1"
)
name
=
name
+
"_branch1"
)
...
@@ -201,13 +190,12 @@ class BasicBlock(fluid.dygraph.Layer):
...
@@ -201,13 +190,12 @@ class BasicBlock(fluid.dygraph.Layer):
short
=
inputs
short
=
inputs
else
:
else
:
short
=
self
.
short
(
inputs
)
short
=
self
.
short
(
inputs
)
y
=
fluid
.
layers
.
elementwise_add
(
x
=
short
,
y
=
conv1
)
y
=
paddle
.
elementwise_add
(
x
=
short
,
y
=
conv1
,
act
=
'relu'
)
layer_helper
=
LayerHelper
(
self
.
full_name
(),
act
=
'relu'
)
return
y
return
layer_helper
.
append_activation
(
y
)
class
ResNet_vd
(
fluid
.
dygraph
.
Layer
):
class
ResNet_vd
(
nn
.
Layer
):
def
__init__
(
self
,
def
__init__
(
self
,
backbone_pretrained
=
None
,
backbone_pretrained
=
None
,
layers
=
50
,
layers
=
50
,
...
@@ -243,28 +231,27 @@ class ResNet_vd(fluid.dygraph.Layer):
...
@@ -243,28 +231,27 @@ class ResNet_vd(fluid.dygraph.Layer):
dilation_dict
=
{
3
:
2
}
dilation_dict
=
{
3
:
2
}
self
.
conv1_1
=
ConvBNLayer
(
self
.
conv1_1
=
ConvBNLayer
(
num
_channels
=
3
,
in
_channels
=
3
,
num_filter
s
=
32
,
out_channel
s
=
32
,
filter
_size
=
3
,
kernel
_size
=
3
,
stride
=
2
,
stride
=
2
,
act
=
'relu'
,
act
=
'relu'
,
name
=
"conv1_1"
)
name
=
"conv1_1"
)
self
.
conv1_2
=
ConvBNLayer
(
self
.
conv1_2
=
ConvBNLayer
(
num
_channels
=
32
,
in
_channels
=
32
,
num_filter
s
=
32
,
out_channel
s
=
32
,
filter
_size
=
3
,
kernel
_size
=
3
,
stride
=
1
,
stride
=
1
,
act
=
'relu'
,
act
=
'relu'
,
name
=
"conv1_2"
)
name
=
"conv1_2"
)
self
.
conv1_3
=
ConvBNLayer
(
self
.
conv1_3
=
ConvBNLayer
(
num
_channels
=
32
,
in
_channels
=
32
,
num_filter
s
=
64
,
out_channel
s
=
64
,
filter
_size
=
3
,
kernel
_size
=
3
,
stride
=
1
,
stride
=
1
,
act
=
'relu'
,
act
=
'relu'
,
name
=
"conv1_3"
)
name
=
"conv1_3"
)
self
.
pool2d_max
=
Pool2D
(
self
.
pool2d_max
=
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
pool_size
=
3
,
pool_stride
=
2
,
pool_padding
=
1
,
pool_type
=
'max'
)
# self.block_list = []
# self.block_list = []
self
.
stage_list
=
[]
self
.
stage_list
=
[]
...
@@ -296,9 +283,9 @@ class ResNet_vd(fluid.dygraph.Layer):
...
@@ -296,9 +283,9 @@ class ResNet_vd(fluid.dygraph.Layer):
bottleneck_block
=
self
.
add_sublayer
(
bottleneck_block
=
self
.
add_sublayer
(
'bb_%d_%d'
%
(
block
,
i
),
'bb_%d_%d'
%
(
block
,
i
),
BottleneckBlock
(
BottleneckBlock
(
num
_channels
=
num_channels
[
block
]
in
_channels
=
num_channels
[
block
]
if
i
==
0
else
num_filters
[
block
]
*
4
,
if
i
==
0
else
num_filters
[
block
]
*
4
,
num_filter
s
=
num_filters
[
block
],
out_channel
s
=
num_filters
[
block
],
stride
=
2
if
i
==
0
and
block
!=
0
stride
=
2
if
i
==
0
and
block
!=
0
and
dilation_rate
==
1
else
1
,
and
dilation_rate
==
1
else
1
,
shortcut
=
shortcut
,
shortcut
=
shortcut
,
...
@@ -318,9 +305,9 @@ class ResNet_vd(fluid.dygraph.Layer):
...
@@ -318,9 +305,9 @@ class ResNet_vd(fluid.dygraph.Layer):
basic_block
=
self
.
add_sublayer
(
basic_block
=
self
.
add_sublayer
(
'bb_%d_%d'
%
(
block
,
i
),
'bb_%d_%d'
%
(
block
,
i
),
BasicBlock
(
BasicBlock
(
num
_channels
=
num_channels
[
block
]
in
_channels
=
num_channels
[
block
]
if
i
==
0
else
num_filters
[
block
],
if
i
==
0
else
num_filters
[
block
],
num_filter
s
=
num_filters
[
block
],
out_channel
s
=
num_filters
[
block
],
stride
=
2
if
i
==
0
and
block
!=
0
else
1
,
stride
=
2
if
i
==
0
and
block
!=
0
else
1
,
shortcut
=
shortcut
,
shortcut
=
shortcut
,
if_first
=
block
==
i
==
0
,
if_first
=
block
==
i
==
0
,
...
@@ -329,23 +316,6 @@ class ResNet_vd(fluid.dygraph.Layer):
...
@@ -329,23 +316,6 @@ class ResNet_vd(fluid.dygraph.Layer):
shortcut
=
True
shortcut
=
True
self
.
stage_list
.
append
(
block_list
)
self
.
stage_list
.
append
(
block_list
)
self
.
pool2d_avg
=
Pool2D
(
pool_size
=
7
,
pool_type
=
'avg'
,
global_pooling
=
True
)
self
.
pool2d_avg_channels
=
num_channels
[
-
1
]
*
2
stdv
=
1.0
/
math
.
sqrt
(
self
.
pool2d_avg_channels
*
1.0
)
self
.
out
=
Linear
(
self
.
pool2d_avg_channels
,
class_dim
,
param_attr
=
ParamAttr
(
initializer
=
fluid
.
initializer
.
Uniform
(
-
stdv
,
stdv
),
name
=
"fc_0.w_0"
),
bias_attr
=
ParamAttr
(
name
=
"fc_0.b_0"
))
self
.
init_weight
(
backbone_pretrained
)
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
):
y
=
self
.
conv1_1
(
inputs
)
y
=
self
.
conv1_1
(
inputs
)
y
=
self
.
conv1_2
(
y
)
y
=
self
.
conv1_2
(
y
)
...
@@ -357,34 +327,12 @@ class ResNet_vd(fluid.dygraph.Layer):
...
@@ -357,34 +327,12 @@ class ResNet_vd(fluid.dygraph.Layer):
for
i
,
stage
in
enumerate
(
self
.
stage_list
):
for
i
,
stage
in
enumerate
(
self
.
stage_list
):
for
j
,
block
in
enumerate
(
stage
):
for
j
,
block
in
enumerate
(
stage
):
y
=
block
(
y
)
y
=
block
(
y
)
#print("stage {} block {}".format(i+1, j+1), y.shape)
feat_list
.
append
(
y
)
feat_list
.
append
(
y
)
y
=
self
.
pool2d_avg
(
y
)
return
feat_list
y
=
fluid
.
layers
.
reshape
(
y
,
shape
=
[
-
1
,
self
.
pool2d_avg_channels
])
y
=
self
.
out
(
y
)
return
y
,
feat_list
# def init_weight(self, pretrained_model=None):
# if pretrained_model is not None:
# if os.path.exists(pretrained_model):
# utils.load_pretrained_model(self, pretrained_model)
def
init_weight
(
self
,
pretrained_model
=
None
):
"""
Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
"""
if
pretrained_model
is
not
None
:
if
os
.
path
.
exists
(
pretrained_model
):
utils
.
load_pretrained_model
(
self
,
pretrained_model
)
else
:
raise
Exception
(
'Pretrained model is not found: {}'
.
format
(
pretrained_model
))
@
manager
.
BACKBONES
.
add_component
def
ResNet18_vd
(
**
args
):
def
ResNet18_vd
(
**
args
):
model
=
ResNet_vd
(
layers
=
18
,
**
args
)
model
=
ResNet_vd
(
layers
=
18
,
**
args
)
return
model
return
model
...
...
dygraph/paddleseg/models/fcn.py
浏览文件 @
ca0448fd
...
@@ -16,17 +16,16 @@ import math
...
@@ -16,17 +16,16 @@ import math
import
os
import
os
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.nn
as
nn
from
paddle.fluid.param_attr
import
ParamAttr
import
paddle.nn.functional
as
F
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.nn
import
Conv2d
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
from
paddle.fluid.initializer
import
Normal
from
paddle.nn
import
SyncBatchNorm
as
BatchNorm
from
paddle.nn
import
SyncBatchNorm
as
BatchNorm
from
paddleseg.cvlibs
import
manager
from
paddleseg.cvlibs
import
manager
from
paddleseg
import
utils
from
paddleseg
import
utils
from
paddleseg.cvlibs
import
param_init
from
paddleseg.cvlibs
import
param_init
from
paddleseg.utils
import
logger
from
paddleseg.utils
import
logger
from
paddleseg.models.common
import
layer_libs
,
activation
__all__
=
[
__all__
=
[
"fcn_hrnet_w18_small_v1"
,
"fcn_hrnet_w18_small_v2"
,
"fcn_hrnet_w18"
,
"fcn_hrnet_w18_small_v1"
,
"fcn_hrnet_w18_small_v2"
,
"fcn_hrnet_w18"
,
...
@@ -36,7 +35,7 @@ __all__ = [
...
@@ -36,7 +35,7 @@ __all__ = [
@
manager
.
MODELS
.
add_component
@
manager
.
MODELS
.
add_component
class
FCN
(
fluid
.
dygraph
.
Layer
):
class
FCN
(
nn
.
Layer
):
"""
"""
Fully Convolutional Networks for Semantic Segmentation.
Fully Convolutional Networks for Semantic Segmentation.
https://arxiv.org/abs/1411.4038
https://arxiv.org/abs/1411.4038
...
@@ -70,18 +69,18 @@ class FCN(fluid.dygraph.Layer):
...
@@ -70,18 +69,18 @@ class FCN(fluid.dygraph.Layer):
self
.
model_pretrained
=
model_pretrained
self
.
model_pretrained
=
model_pretrained
self
.
backbone_indices
=
backbone_indices
self
.
backbone_indices
=
backbone_indices
if
channels
is
None
:
if
channels
is
None
:
channels
=
backbone_channels
[
backbone_indices
[
0
]
]
channels
=
backbone_channels
[
0
]
self
.
backbone
=
backbone
self
.
backbone
=
backbone
self
.
conv_last_2
=
ConvBNLayer
(
self
.
conv_last_2
=
ConvBNLayer
(
num_channels
=
backbone_channels
[
backbone_indices
[
0
]
],
in_channels
=
backbone_channels
[
0
],
num_filter
s
=
channels
,
out_channel
s
=
channels
,
filter
_size
=
1
,
kernel
_size
=
1
,
stride
=
1
)
stride
=
1
)
self
.
conv_last_1
=
Conv2
D
(
self
.
conv_last_1
=
Conv2
d
(
num
_channels
=
channels
,
in
_channels
=
channels
,
num_filter
s
=
self
.
num_classes
,
out_channel
s
=
self
.
num_classes
,
filter
_size
=
1
,
kernel
_size
=
1
,
stride
=
1
,
stride
=
1
,
padding
=
0
)
padding
=
0
)
if
self
.
training
:
if
self
.
training
:
...
@@ -93,7 +92,7 @@ class FCN(fluid.dygraph.Layer):
...
@@ -93,7 +92,7 @@ class FCN(fluid.dygraph.Layer):
x
=
fea_list
[
self
.
backbone_indices
[
0
]]
x
=
fea_list
[
self
.
backbone_indices
[
0
]]
x
=
self
.
conv_last_2
(
x
)
x
=
self
.
conv_last_2
(
x
)
logit
=
self
.
conv_last_1
(
x
)
logit
=
self
.
conv_last_1
(
x
)
logit
=
fluid
.
layers
.
resize_bilinear
(
logit
,
input_shape
)
logit
=
F
.
resize_bilinear
(
logit
,
input_shape
)
return
[
logit
]
return
[
logit
]
def
init_weight
(
self
):
def
init_weight
(
self
):
...
@@ -125,32 +124,31 @@ class FCN(fluid.dygraph.Layer):
...
@@ -125,32 +124,31 @@ class FCN(fluid.dygraph.Layer):
logger
.
warning
(
'No pretrained model to load, train from scratch'
)
logger
.
warning
(
'No pretrained model to load, train from scratch'
)
class
ConvBNLayer
(
fluid
.
dygraph
.
Layer
):
class
ConvBNLayer
(
nn
.
Layer
):
def
__init__
(
self
,
def
__init__
(
self
,
num
_channels
,
in
_channels
,
num_filter
s
,
out_channel
s
,
filter
_size
,
kernel
_size
,
stride
=
1
,
stride
=
1
,
groups
=
1
,
groups
=
1
,
act
=
"relu"
):
act
=
"relu"
):
super
(
ConvBNLayer
,
self
).
__init__
()
super
(
ConvBNLayer
,
self
).
__init__
()
self
.
_conv
=
Conv2
D
(
self
.
_conv
=
Conv2
d
(
num_channels
=
num
_channels
,
in_channels
=
in
_channels
,
num_filters
=
num_filter
s
,
out_channels
=
out_channel
s
,
filter_size
=
filter
_size
,
kernel_size
=
kernel
_size
,
stride
=
stride
,
stride
=
stride
,
padding
=
(
filter
_size
-
1
)
//
2
,
padding
=
(
kernel
_size
-
1
)
//
2
,
groups
=
groups
,
groups
=
groups
,
bias_attr
=
False
)
bias_attr
=
False
)
self
.
_batch_norm
=
BatchNorm
(
num_filter
s
)
self
.
_batch_norm
=
BatchNorm
(
out_channel
s
)
self
.
act
=
act
self
.
act
=
act
ivation
.
Activation
(
act
=
act
)
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
y
=
self
.
_conv
(
input
)
y
=
self
.
_conv
(
input
)
y
=
self
.
_batch_norm
(
y
)
y
=
self
.
_batch_norm
(
y
)
if
self
.
act
==
'relu'
:
y
=
self
.
act
(
y
)
y
=
fluid
.
layers
.
relu
(
y
)
return
y
return
y
...
...
dygraph/paddleseg/utils/config.py
浏览文件 @
ca0448fd
...
@@ -15,11 +15,14 @@
...
@@ -15,11 +15,14 @@
import
codecs
import
codecs
import
os
import
os
from
typing
import
Any
,
Callable
from
typing
import
Any
,
Callable
import
pprint
import
yaml
import
yaml
import
paddle.fluid
as
fluid
import
paddle
import
paddle.nn.functional
as
F
import
paddleseg.cvlibs.manager
as
manager
import
paddleseg.cvlibs.manager
as
manager
from
paddleseg.utils
import
logger
class
Config
(
object
):
class
Config
(
object
):
...
@@ -36,7 +39,7 @@ class Config(object):
...
@@ -36,7 +39,7 @@ class Config(object):
if
path
.
endswith
(
'yml'
)
or
path
.
endswith
(
'yaml'
):
if
path
.
endswith
(
'yml'
)
or
path
.
endswith
(
'yaml'
):
dic
=
self
.
_parse_from_yaml
(
path
)
dic
=
self
.
_parse_from_yaml
(
path
)
print
(
dic
)
logger
.
info
(
'
\n
'
+
pprint
.
pformat
(
dic
)
)
self
.
_build
(
dic
)
self
.
_build
(
dic
)
else
:
else
:
raise
RuntimeError
(
'Config file should in yaml format!'
)
raise
RuntimeError
(
'Config file should in yaml format!'
)
...
@@ -127,18 +130,19 @@ class Config(object):
...
@@ -127,18 +130,19 @@ class Config(object):
lr
=
self
.
_learning_rate
lr
=
self
.
_learning_rate
args
=
self
.
decay_args
args
=
self
.
decay_args
args
.
setdefault
(
'decay_steps'
,
self
.
iters
)
args
.
setdefault
(
'decay_steps'
,
self
.
iters
)
return
fluid
.
layers
.
polynomial_decay
(
lr
,
**
args
)
args
.
setdefault
(
'end_lr'
,
0
)
return
paddle
.
optimizer
.
PolynomialLR
(
lr
,
**
args
)
else
:
else
:
raise
RuntimeError
(
'Only poly decay support.'
)
raise
RuntimeError
(
'Only poly decay support.'
)
@
property
@
property
def
optimizer
(
self
)
->
fluid
.
optimizer
.
Optimizer
:
def
optimizer
(
self
)
->
paddle
.
optimizer
.
Optimizer
:
if
self
.
optimizer_type
==
'sgd'
:
if
self
.
optimizer_type
==
'sgd'
:
lr
=
self
.
learning_rate
lr
=
self
.
learning_rate
args
=
self
.
optimizer_args
args
=
self
.
optimizer_args
args
.
setdefault
(
'momentum'
,
0.9
)
args
.
setdefault
(
'momentum'
,
0.9
)
return
fluid
.
optimizer
.
Momentum
(
return
paddle
.
optimizer
.
Momentum
(
lr
,
parameter
_list
=
self
.
model
.
parameters
(),
**
args
)
lr
,
parameter
s
=
self
.
model
.
parameters
(),
**
args
)
else
:
else
:
raise
RuntimeError
(
'Only sgd optimizer support.'
)
raise
RuntimeError
(
'Only sgd optimizer support.'
)
...
...
dygraph/train.py
浏览文件 @
ca0448fd
...
@@ -14,8 +14,8 @@
...
@@ -14,8 +14,8 @@
import
argparse
import
argparse
import
paddle
.fluid
as
fluid
import
paddle
from
paddle.
fluid.dygraph.parallel
import
ParallelEnv
from
paddle.
distributed
import
ParallelEnv
import
paddleseg
import
paddleseg
from
paddleseg.cvlibs
import
manager
from
paddleseg.cvlibs
import
manager
...
@@ -91,41 +91,40 @@ def main(args):
...
@@ -91,41 +91,40 @@ def main(args):
[
'-'
*
48
])
[
'-'
*
48
])
logger
.
info
(
info
)
logger
.
info
(
info
)
places
=
fluid
.
CUDAPlace
(
ParallelEnv
().
dev_id
)
\
places
=
paddle
.
CUDAPlace
(
ParallelEnv
().
dev_id
)
\
if
env_info
[
'Paddle compiled with cuda'
]
and
env_info
[
'GPUs used'
]
\
if
env_info
[
'Paddle compiled with cuda'
]
and
env_info
[
'GPUs used'
]
\
else
fluid
.
CPUPlace
()
else
paddle
.
CPUPlace
()
with
fluid
.
dygraph
.
guard
(
places
):
paddle
.
disable_static
(
places
)
if
not
args
.
cfg
:
if
not
args
.
cfg
:
raise
RuntimeError
(
'No configuration file specified.'
)
raise
RuntimeError
(
'No configuration file specified.'
)
cfg
=
Config
(
args
.
cfg
)
cfg
=
Config
(
args
.
cfg
)
train_dataset
=
cfg
.
train_dataset
train_dataset
=
cfg
.
train_dataset
if
not
train_dataset
:
if
not
train_dataset
:
raise
RuntimeError
(
raise
RuntimeError
(
'The training dataset is not specified in the configuration file.'
'The training dataset is not specified in the configuration file.'
)
)
val_dataset
=
cfg
.
val_dataset
if
args
.
do_eval
else
None
val_dataset
=
cfg
.
val_dataset
if
args
.
do_eval
else
None
losses
=
cfg
.
loss
losses
=
cfg
.
loss
train
(
train
(
cfg
.
model
,
cfg
.
model
,
train_dataset
,
train_dataset
,
places
=
places
,
places
=
places
,
eval_dataset
=
val_dataset
,
eval_dataset
=
val_dataset
,
optimizer
=
cfg
.
optimizer
,
optimizer
=
cfg
.
optimizer
,
save_dir
=
args
.
save_dir
,
save_dir
=
args
.
save_dir
,
iters
=
cfg
.
iters
,
iters
=
cfg
.
iters
,
batch_size
=
cfg
.
batch_size
,
batch_size
=
cfg
.
batch_size
,
save_interval_iters
=
args
.
save_interval_iters
,
save_interval_iters
=
args
.
save_interval_iters
,
log_iters
=
args
.
log_iters
,
log_iters
=
args
.
log_iters
,
num_classes
=
train_dataset
.
num_classes
,
num_classes
=
train_dataset
.
num_classes
,
num_workers
=
args
.
num_workers
,
num_workers
=
args
.
num_workers
,
use_vdl
=
args
.
use_vdl
,
use_vdl
=
args
.
use_vdl
,
losses
=
losses
,
losses
=
losses
,
ignore_index
=
losses
[
'types'
][
0
].
ignore_index
)
ignore_index
=
losses
[
'types'
][
0
].
ignore_index
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
dygraph/val.py
浏览文件 @
ca0448fd
...
@@ -14,8 +14,8 @@
...
@@ -14,8 +14,8 @@
import
argparse
import
argparse
import
paddle
.fluid
as
fluid
import
paddle
from
paddle.
fluid.dygraph.parallel
import
ParallelEnv
from
paddle.
distributed
import
ParallelEnv
import
paddleseg
import
paddleseg
from
paddleseg.cvlibs
import
manager
from
paddleseg.cvlibs
import
manager
...
@@ -41,25 +41,25 @@ def parse_args():
...
@@ -41,25 +41,25 @@ def parse_args():
def
main
(
args
):
def
main
(
args
):
env_info
=
get_environ_info
()
env_info
=
get_environ_info
()
places
=
fluid
.
CUDAPlace
(
ParallelEnv
().
dev_id
)
\
places
=
paddle
.
CUDAPlace
(
ParallelEnv
().
dev_id
)
\
if
env_info
[
'Paddle compiled with cuda'
]
and
env_info
[
'GPUs used'
]
\
if
env_info
[
'Paddle compiled with cuda'
]
and
env_info
[
'GPUs used'
]
\
else
fluid
.
CPUPlace
()
else
paddle
.
CPUPlace
()
with
fluid
.
dygraph
.
guard
(
places
):
paddle
.
disable_static
(
places
)
if
not
args
.
cfg
:
if
not
args
.
cfg
:
raise
RuntimeError
(
'No configuration file specified.'
)
raise
RuntimeError
(
'No configuration file specified.'
)
cfg
=
Config
(
args
.
cfg
)
cfg
=
Config
(
args
.
cfg
)
val_dataset
=
cfg
.
val_dataset
val_dataset
=
cfg
.
val_dataset
if
not
val_dataset
:
if
not
val_dataset
:
raise
RuntimeError
(
raise
RuntimeError
(
'The verification dataset is not specified in the configuration file.'
'The verification dataset is not specified in the configuration file.'
)
)
evaluate
(
evaluate
(
cfg
.
model
,
cfg
.
model
,
val_dataset
,
val_dataset
,
model_dir
=
args
.
model_dir
,
model_dir
=
args
.
model_dir
,
num_classes
=
val_dataset
.
num_classes
)
num_classes
=
val_dataset
.
num_classes
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录