PaddlePaddle / PaddleSeg · Commit 29a5e832
Commit 29a5e832, authored on Jun 01, 2020 by chenguowei01

update structure

Parent: 3e90faaa

Showing 5 changed files with 459 additions and 531 deletions (+459, -531).
- dygraph/models/__init__.py (+15, -15)
- dygraph/models/unet.py (+216, -267)
- dygraph/nets/__init__.py (+0, -15)
- dygraph/nets/unet.py (+0, -234)
- dygraph/train.py (+228, -0)
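Taken together, the commit moves the network definition out of the dygraph/nets package and into dygraph/models/unet.py, replaces the old trainer-style UNet wrapper there with a fluid.dygraph.Layer that computes its own loss, and adds a standalone training script, dygraph/train.py, that owns the training loop, optimizer setup and checkpointing.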
dygraph/models/__init__.py (view file @ 29a5e832)

Both sides of this +15/-15 rewrite show the same visible content; the package keeps exposing its single entry point:

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .unet import UNet
```
dygraph/models/unet.py (view file @ 29a5e832)

The trainer-style class UNet(object) is removed, and the file now holds the fluid.dygraph.Layer network definition moved in from dygraph/nets/unet.py:

```diff
@@ -13,101 +13,48 @@
 # limitations under the License.
 from __future__ import absolute_import
-import paddle.fluid as fluid
-import os
-from os import path as osp
-import numpy as np
 from __future__ import division
 from __future__ import print_function
 from collections import OrderedDict
-import copy
-import math
-import time
-import tqdm
-import cv2
-import yaml
-import shutil
-from paddle.fluid.dygraph.base import to_variable
-import utils
-import utils.logging as logging
-from utils import seconds_to_hms
-from utils import ConfusionMatrix
-from utils import get_environ_info
-import nets
-import transforms as T
-
-
-def dict2str(dict_input):
-    out = ''
-    for k, v in dict_input.items():
-        try:
-            v = round(float(v), 6)
-        except:
-            pass
-        out = out + '{}={}, '.format(k, v)
-    return out.strip(', ')
-
-
-class UNet(object):
-    # DeepLab mobilenet
-    def __init__(self,
-                 num_classes=2,
-                 upsample_mode='bilinear',
-                 ignore_index=255):
-        self.num_classes = num_classes
-        self.upsample_mode = upsample_mode
-        self.ignore_index = ignore_index
-        self.labels = None
-        self.env_info = get_environ_info()
-        if self.env_info['place'] == 'cpu':
-            self.places = fluid.CPUPlace()
-        else:
-            self.places = fluid.CUDAPlace(0)
-
-    def build_model(self):
-        self.model = nets.UNet(self.num_classes, self.upsample_mode)
-
-    def arrange_transform(self, transforms, mode='train'):
-        arrange_transform = T.ArrangeSegmenter
-        if type(transforms.transforms[-1]).__name__.startswith('Arrange'):
-            transforms.transforms[-1] = arrange_transform(mode=mode)
-        else:
-            transforms.transforms.append(arrange_transform(mode=mode))
-
-    def load_model(self, model_dir):
-        ckpt_path = osp.join(model_dir, 'model')
-        para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path)
-        self.model.set_dict(para_state_dict)
-
-    def save_model(self, state_dict, save_dir):
-        if not osp.isdir(save_dir):
-            if osp.exists(save_dir):
-                os.remove(save_dir)
-            os.makedirs(save_dir)
-        fluid.save_dygraph(state_dict, osp.join(save_dir, 'model'))
-
-    def default_optimizer(self,
-                          learning_rate,
-                          num_epochs,
-                          num_steps_each_epoch,
-                          parameter_list=None,
-                          lr_decay_power=0.9,
-                          regularization_coeff=4e-5):
-        decay_step = num_epochs * num_steps_each_epoch
-        lr_decay = fluid.layers.polynomial_decay(
-            learning_rate,
-            decay_step,
-            end_learning_rate=0,
-            power=lr_decay_power)
-        optimizer = fluid.optimizer.Momentum(
-            lr_decay,
-            momentum=0.9,
-            parameter_list=parameter_list,
-            regularization=fluid.regularizer.L2Decay(
-                regularization_coeff=regularization_coeff))
-        return optimizer
+import paddle.fluid as fluid
+from paddle.fluid.dygraph import Conv2D, BatchNorm, Pool2D
+import contextlib
+
+regularizer = fluid.regularizer.L2DecayRegularizer(regularization_coeff=0.0)
+name_scope = ""
+
+
+@contextlib.contextmanager
+def scope(name):
+    global name_scope
+    bk = name_scope
+    name_scope = name_scope + name + '/'
+    yield
+    name_scope = bk
+
+
+class UNet(fluid.dygraph.Layer):
+    def __init__(self, num_classes, upsample_mode='bilinear', ignore_index=255):
+        super().__init__()
+        self.encode = Encoder()
+        self.decode = Decode(upsample_mode=upsample_mode)
+        self.get_logit = GetLogit(64, num_classes)
+        self.ignore_index = ignore_index
+
+    def forward(self, x, label, mode='train'):
+        encode_data, short_cuts = self.encode(x)
+        decode_data = self.decode(encode_data, short_cuts)
+        logit = self.get_logit(decode_data)
+        if mode == 'train':
+            return self._get_loss(logit, label)
+        else:
+            logit = fluid.layers.softmax(logit, axis=1)
+            logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
+            pred = fluid.layers.argmax(logit, axis=3)
+            pred = fluid.layers.unsqueeze(pred, axes=[3])
+            return pred, logit
 
     def _get_loss(self, logit, label):
         mask = label != self.ignore_index
@@ -126,181 +73,183 @@ class UNet(object):
         mask.stop_gradient = True
         return avg_loss
 
-    def train(self,
-              num_epochs,
-              train_dataset,
-              train_batch_size=2,
-              eval_dataset=None,
-              save_interval_epochs=1,
-              log_interval_steps=2,
-              save_dir='output',
-              pretrained_weights=None,
-              resume_weights=None,
-              optimizer=None,
-              learning_rate=0.01,
-              lr_decay_power=0.9,
-              regularization_coeff=4e-5,
-              use_vdl=False):
-        self.labels = train_dataset.labels
-        self.train_transforms = train_dataset.transforms
-        self.train_init = locals()
-        self.begin_epoch = 0
-        if optimizer is None:
-            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
-            optimizer = self.default_optimizer(
-                learning_rate=learning_rate,
-                num_epochs=num_epochs,
-                num_steps_each_epoch=num_steps_each_epoch,
-                parameter_list=self.model.parameters(),
-                lr_decay_power=lr_decay_power,
-                regularization_coeff=regularization_coeff)
-        # to do: load pretrained weights, resume training
-        if self.begin_epoch >= num_epochs:
-            raise ValueError(
-                ("begin epoch[{}] is larger than num_epochs[{}]").format(
-                    self.begin_epoch, num_epochs))
-        if not osp.isdir(save_dir):
-            if osp.exists(save_dir):
-                os.remove(save_dir)
-            os.makedirs(save_dir)
-        # add arrange op to transforms
-        self.arrange_transform(
-            transforms=train_dataset.transforms, mode='train')
-        if eval_dataset is not None:
-            self.eval_transforms = eval_dataset.transforms
-            self.test_transforms = copy.deepcopy(eval_dataset.transforms)
-        data_generator = train_dataset.generator(
-            batch_size=train_batch_size, drop_last=True)
-        total_num_steps = math.floor(
-            train_dataset.num_samples / train_batch_size)
-        for i in range(self.begin_epoch, num_epochs):
-            for step, data in enumerate(data_generator()):
-                images = np.array([d[0] for d in data])
-                labels = np.array([d[1] for d in data]).astype('int64')
-                images = to_variable(images)
-                labels = to_variable(labels)
-                logit = self.model(images)
-                loss = self._get_loss(logit, labels)
-                loss.backward()
-                optimizer.minimize(loss)
-                print("[TRAIN] Epoch={}/{}, Step={}/{}, loss={}".format(
-                    i + 1, num_epochs, step + 1, total_num_steps,
-                    loss.numpy()))
-            if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1:
-                current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
-                if not osp.isdir(current_save_dir):
-                    os.makedirs(current_save_dir)
-                self.save_model(self.model.state_dict(), current_save_dir)
-                if eval_dataset is not None:
-                    self.model.eval()
-                    self.evaluate(eval_dataset, batch_size=train_batch_size)
-                    self.model.train()
-
-    def evaluate(self, eval_dataset, batch_size=1, epoch_id=None):
-        """Evaluate the model.
-
-        Args:
-            eval_dataset (paddlex.datasets): reader for the evaluation data.
-            batch_size (int): batch size used during evaluation. Default: 1.
-            epoch_id (int): training epoch the evaluated model comes from.
-            return_details (bool): whether to return detailed information.
-                Default: False.
-
-        Returns:
-            dict: when return_details is False, a dict with the keys 'miou',
-                'category_iou', 'macc', 'category_acc' and 'kappa', i.e. the
-                mean IoU, per-category IoU, mean accuracy, per-category
-                accuracy and the kappa coefficient.
-            tuple (metrics, eval_details): when return_details is True, an
-                additional dict (eval_details) is returned, with the key
-                'confusion_matrix' holding the confusion matrix of the
-                evaluation.
-        """
-        self.model.eval()
-        self.arrange_transform(
-            transforms=eval_dataset.transforms, mode='train')
-        total_steps = math.ceil(eval_dataset.num_samples * 1.0 / batch_size)
-        conf_mat = ConfusionMatrix(self.num_classes, streaming=True)
-        data_generator = eval_dataset.generator(
-            batch_size=batch_size, drop_last=False)
-        logging.info(
-            "Start to evaluating(total_samples={}, total_steps={})...".format(
-                eval_dataset.num_samples, total_steps))
-        for step, data in tqdm.tqdm(
-                enumerate(data_generator()), total=total_steps):
-            images = np.array([d[0] for d in data])
-            labels = np.array([d[1] for d in data])
-            images = to_variable(images)
-            logit = self.model(images)
-            pred = fluid.layers.argmax(logit, axis=1)
-            pred = fluid.layers.unsqueeze(pred, axes=[3])
-            pred = pred.numpy()
-            mask = labels != self.ignore_index
-            conf_mat.calculate(pred=pred, label=labels, ignore=mask)
-            _, iou = conf_mat.mean_iou()
-            logging.debug("[EVAL] Epoch={}, Step={}/{}, iou={}".format(
-                epoch_id, step + 1, total_steps, iou))
-        category_iou, miou = conf_mat.mean_iou()
-        category_acc, macc = conf_mat.accuracy()
-        metrics = OrderedDict(
-            zip(['miou', 'category_iou', 'macc', 'category_acc', 'kappa'],
-                [miou, category_iou, macc, category_acc, conf_mat.kappa()]))
-        logging.info('[EVAL] Finished, Epoch={}, {} .'.format(
-            epoch_id, dict2str(metrics)))
-        return metrics
-
-    def predict(self, im_file, transforms=None):
-        """Predict.
-
-        Args:
-            img_file (str|np.ndarray): image to run prediction on.
-            transforms (paddlex.cv.transforms): preprocessing transforms.
-
-        Returns:
-            dict: a dict with the keys 'label_map' and 'score_map';
-                'label_map' holds the predicted label map, whose pixel values
-                are the predicted categories, and 'score_map' holds the
-                per-category probabilities with shape (h, w, num_classes).
-        """
-        if isinstance(im_file, str):
-            if not osp.exists(im_file):
-                raise ValueError(
-                    'The Image file does not exist: {}'.format(im_file))
-        if transforms is None and not hasattr(self, 'test_transforms'):
-            raise Exception("transforms need to be defined, now is None.")
-        if transforms is not None:
-            self.arrange_transform(transforms=transforms, mode='test')
-            im, im_info = transforms(im_file)
+class Encoder(fluid.dygraph.Layer):
+    def __init__(self):
+        super().__init__()
+        with scope('encode'):
+            with scope('block1'):
+                self.double_conv = DoubleConv(3, 64)
+            with scope('block1'):
+                self.down1 = Down(64, 128)
+            with scope('block2'):
+                self.down2 = Down(128, 256)
+            with scope('block3'):
+                self.down3 = Down(256, 512)
+            with scope('block4'):
+                self.down4 = Down(512, 512)
+
+    def forward(self, x):
+        short_cuts = []
+        x = self.double_conv(x)
+        short_cuts.append(x)
+        x = self.down1(x)
+        short_cuts.append(x)
+        x = self.down2(x)
+        short_cuts.append(x)
+        x = self.down3(x)
+        short_cuts.append(x)
+        x = self.down4(x)
+        return x, short_cuts
+
+
+class Decode(fluid.dygraph.Layer):
+    def __init__(self, upsample_mode='bilinear'):
+        super().__init__()
+        with scope('decode'):
+            with scope('decode1'):
+                self.up1 = Up(512, 256, upsample_mode)
+            with scope('decode2'):
+                self.up2 = Up(256, 128, upsample_mode)
+            with scope('decode3'):
+                self.up3 = Up(128, 64, upsample_mode)
+            with scope('decode4'):
+                self.up4 = Up(64, 64, upsample_mode)
+
+    def forward(self, x, short_cuts):
+        x = self.up1(x, short_cuts[3])
+        x = self.up2(x, short_cuts[2])
+        x = self.up3(x, short_cuts[1])
+        x = self.up4(x, short_cuts[0])
+        return x
+
+
+class GetLogit(fluid.dygraph.Layer):
+    def __init__(self):
+        super().__init__()
+
+
+class DoubleConv(fluid.dygraph.Layer):
+    def __init__(self, num_channels, num_filters):
+        super().__init__()
+        with scope('conv0'):
+            param_attr = fluid.ParamAttr(
+                name=name_scope + 'weights',
+                regularizer=regularizer,
+                initializer=fluid.initializer.TruncatedNormal(
+                    loc=0.0, scale=0.33))
+            self.conv0 = Conv2D(
+                num_channels=num_channels,
+                num_filters=num_filters,
+                filter_size=3,
+                stride=1,
+                padding=1,
+                param_attr=param_attr)
+            self.bn0 = BatchNorm(
+                num_channels=num_filters,
+                param_attr=fluid.ParamAttr(
+                    name=name_scope + 'gamma', regularizer=regularizer),
+                bias_attr=fluid.ParamAttr(
+                    name=name_scope + 'beta', regularizer=regularizer),
+                moving_mean_name=name_scope + 'moving_mean',
+                moving_variance_name=name_scope + 'moving_variance')
+        with scope('conv1'):
+            param_attr = fluid.ParamAttr(
+                name=name_scope + 'weights',
+                regularizer=regularizer,
+                initializer=fluid.initializer.TruncatedNormal(
+                    loc=0.0, scale=0.33))
+            self.conv1 = Conv2D(
+                num_channels=num_filters,
+                num_filters=num_filters,
+                filter_size=3,
+                stride=1,
+                padding=1,
+                param_attr=param_attr)
+            self.bn1 = BatchNorm(
+                num_channels=num_filters,
+                param_attr=fluid.ParamAttr(
+                    name=name_scope + 'gamma', regularizer=regularizer),
+                bias_attr=fluid.ParamAttr(
+                    name=name_scope + 'beta', regularizer=regularizer),
+                moving_mean_name=name_scope + 'moving_mean',
+                moving_variance_name=name_scope + 'moving_variance')
+
+    def forward(self, x):
+        x = self.conv0(x)
+        x = self.bn0(x)
+        x = fluid.layers.relu(x)
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = fluid.layers.relu(x)
+        return x
+
+
+class Down(fluid.dygraph.Layer):
+    def __init__(self, num_channels, num_filters):
+        super().__init__()
+        with scope("down"):
+            self.max_pool = Pool2D(
+                pool_size=2, pool_type='max', pool_stride=2, pool_padding=0)
+            self.double_conv = DoubleConv(num_channels, num_filters)
+
+    def forward(self, x):
+        x = self.max_pool(x)
+        x = self.double_conv(x)
+        return x
+
+
+class Up(fluid.dygraph.Layer):
+    def __init__(self, num_channels, num_filters, upsample_mode):
+        super().__init__()
+        self.upsample_mode = upsample_mode
+        with scope('up'):
+            if upsample_mode == 'bilinear':
+                self.double_conv = DoubleConv(2 * num_channels, num_filters)
+            if not upsample_mode == 'bilinear':
+                param_attr = fluid.ParamAttr(
+                    name=name_scope + 'weights',
+                    regularizer=regularizer,
+                    initializer=fluid.initializer.XavierInitializer(),
+                )
+                self.deconv = fluid.dygraph.Conv2DTranspose(
+                    num_channels=num_channels,
+                    num_filters=num_filters // 2,
+                    filter_size=2,
+                    stride=2,
+                    padding=0,
+                    param_attr=param_attr)
+                self.double_conv = DoubleConv(
+                    num_channels + num_filters // 2, num_filters)
+
+    def forward(self, x, short_cut):
+        if self.upsample_mode == 'bilinear':
+            short_cut_shape = fluid.layers.shape(short_cut)
+            x = fluid.layers.resize_bilinear(x, short_cut_shape[2:])
+        else:
-        else:
-            self.arrange_transform(
-                transforms=self.test_transforms, mode='test')
-            im, im_info = self.test_transforms(im_file)
-        im = np.expand_dims(im, axis=0)
-        im = to_variable(im)
-        logit = self.model(im)
-        logit = fluid.layers.softmax(logit)
-        pred = fluid.layers.argmax(logit, axis=1)
-        logit = logit.numpy()
-        pred = pred.numpy()
-        logit = np.squeeze(logit)
-        logit = np.transpose(logit, (1, 2, 0))
-        pred = np.squeeze(pred).astype('uint8')
-        keys = list(im_info.keys())
-        print(pred.shape, logit.shape)
-        for k in keys[::-1]:
-            if k == 'shape_before_resize':
-                h, w = im_info[k][0], im_info[k][1]
-                pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
-                logit = cv2.resize(logit, (w, h), cv2.INTER_LINEAR)
-            elif k == 'shape_before_padding':
-                h, w = im_info[k][0], im_info[k][1]
-                pred = pred[0:h, 0:w]
-                logit = logit[0:h, 0:w, :]
-        return {'label_map': pred, 'score_map': logit}
+            x = self.deconv(x)
+        x = fluid.layers.concat([x, short_cut], axis=1)
+        x = self.double_conv(x)
+        return x
+
+
+class GetLogit(fluid.dygraph.Layer):
+    def __init__(self, num_channels, num_classes):
+        super().__init__()
+        with scope('logit'):
+            param_attr = fluid.ParamAttr(
+                name=name_scope + 'weights',
+                regularizer=regularizer,
+                initializer=fluid.initializer.TruncatedNormal(
+                    loc=0.0, scale=0.01))
+            self.conv = Conv2D(
+                num_channels=num_channels,
+                num_filters=num_classes,
+                filter_size=3,
+                stride=1,
+                padding=1,
+                param_attr=param_attr)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
```
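After this change, models.UNet is an ordinary dygraph Layer: forward(x, label, mode) returns the loss when mode is 'train' and a (pred, logit) pair otherwise. Below is a minimal sketch of driving it; this is not part of the commit, it assumes the 1.8-era paddle.fluid dygraph API this code targets, and the input and label shapes are illustrative:

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable

import models  # the package updated above


with fluid.dygraph.guard(fluid.CPUPlace()):
    model = models.UNet(num_classes=2)
    # NCHW image batch and int64 label map; shapes here are assumptions.
    images = to_variable(np.random.rand(2, 3, 512, 512).astype('float32'))
    labels = to_variable(np.zeros((2, 1, 512, 512), dtype='int64'))

    loss = model(images, labels, mode='train')  # 'train': returns the loss
    loss.backward()

    model.eval()
    pred, logit = model(images, labels, mode='eval')  # otherwise: (pred, logit)
```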
dygraph/nets/__init__.py (deleted, 100644 → 0; view file @ 3e90faaa)

```python
# (Apache License 2.0 header identical to dygraph/models/__init__.py above)

from .unet import UNet
```
dygraph/nets/unet.py (deleted, 100644 → 0; view file @ 3e90faaa)

```python
# (Apache License 2.0 header identical to dygraph/models/__init__.py above)

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict

import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D, BatchNorm, Pool2D
import contextlib

regularizer = fluid.regularizer.L2DecayRegularizer(regularization_coeff=0.0)
name_scope = ""


@contextlib.contextmanager
def scope(name):
    global name_scope
    bk = name_scope
    name_scope = name_scope + name + '/'
    yield
    name_scope = bk


class UNet(fluid.dygraph.Layer):
    def __init__(self, num_classes, upsample_mode='bilinear'):
        super().__init__()
        self.encode = Encoder()
        self.decode = Decode(upsample_mode=upsample_mode)
        self.get_logit = GetLogit(64, num_classes)

    def forward(self, x):
        encode_data, short_cuts = self.encode(x)
        decode_data = self.decode(encode_data, short_cuts)
        logit = self.get_logit(decode_data)
        return logit
```

The remaining definitions of this deleted file (Encoder, Decode, the empty GetLogit stub, DoubleConv, Down, Up, and the final GetLogit) are identical, line for line, to the copies added to dygraph/models/unet.py above. Only the UNet class shown here differs: it has no ignore_index, and its forward returns the raw logit instead of computing a loss.
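For orientation, the encoder records a skip tensor before each downsampling step and the decoder consumes them deepest-first. The following small standalone sketch, assuming a square input whose side is divisible by 16, traces the channel and resolution bookkeeping implied by the classes above:

```python
def unet_feature_plan(side=512):
    """Trace the (channels, side) of each feature map in the network above."""
    channels = [64, 128, 256, 512, 512]           # DoubleConv, then four Downs
    sides = [side // (2 ** i) for i in range(5)]  # each Down halves H and W
    short_cuts = list(zip(channels[:4], sides[:4]))  # saved before the last Down
    bottom = (channels[4], sides[4])
    # Decode: up1..up4 read short_cuts[3], [2], [1], [0], i.e. deepest first,
    # upsampling back to each skip tensor's resolution before concatenation.
    decode_inputs = list(reversed(short_cuts))
    return short_cuts, bottom, decode_inputs


print(unet_feature_plan(512))
# ([(64, 512), (128, 256), (256, 128), (512, 64)], (512, 32),
#  [(512, 64), (256, 128), (128, 256), (64, 512)])
```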
dygraph/train.py (new file, 0 → 100644; view file @ 29a5e832)

```python
# (Apache License 2.0 header identical to dygraph/models/__init__.py above)

import argparse
import os
import os.path as osp

from paddle.fluid.dygraph.base import to_variable
import numpy as np
import paddle.fluid as fluid

from datasets.dataset import Dataset
import transforms as T
import models
import utils.logging as logging
from utils import get_environ_info


def parse_args():
    parser = argparse.ArgumentParser(description='Model training')

    # params of model
    parser.add_argument(
        '--model_name',
        dest='model_name',
        help="Model type for traing, which is one of ('UNet')",
        type=str,
        default='UNet')

    # params of dataset
    parser.add_argument(
        '--data_dir',
        dest='data_dir',
        help='The root directory of dataset',
        type=str)
    parser.add_argument(
        '--train_list',
        dest='train_list',
        help='Train list file of dataset',
        type=str)
    parser.add_argument(
        '--val_list',
        dest='val_list',
        help='Val list file of dataset',
        type=str,
        default=None)
    parser.add_argument(
        '--num_classes',
        dest='num_classes',
        help='Number of classes',
        type=int,
        default=2)

    # params of training
    parser.add_argument(
        "--input_size",
        dest="input_size",
        help="The image size for net inputs.",
        nargs=2,
        default=[512, 512],
        type=int)
    parser.add_argument(
        '--num_epochs',
        dest='num_epochs',
        help='Number epochs for training',
        type=int,
        default=100)
    parser.add_argument(
        '--batch_size',
        dest='batch_size',
        help='Mini batch size',
        type=int,
        default=2)
    parser.add_argument(
        '--learning_rate',
        dest='learning_rate',
        help='Learning rate',
        type=float,
        default=0.01)
    parser.add_argument(
        '--pretrained_model',
        dest='pretrained_model',
        help='The path of pretrianed weight',
        type=str,
        default=None)
    parser.add_argument(
        '--save_interval_epochs',
        dest='save_interval_epochs',
        help='The interval epochs for save a model snapshot',
        type=int,
        default=5)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the model snapshot',
        type=str,
        default='./output')

    return parser.parse_args()


def train(model,
          train_dataset,
          eval_dataset=None,
          optimizer=None,
          save_dir='output',
          num_epochs=100,
          batch_size=2,
          pretrained_model=None,
          save_interval_epochs=1):
    if not osp.isdir(save_dir):
        if osp.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir)
    data_generator = train_dataset.generator(
        batch_size=batch_size, drop_last=True)
    num_steps_each_epoch = train_dataset.num_samples // args.batch_size
    for epoch in range(num_epochs):
        for step, data in enumerate(data_generator()):
            images = np.array([d[0] for d in data])
            labels = np.array([d[1] for d in data]).astype('int64')
            images = to_variable(images)
            labels = to_variable(labels)
            loss = model(images, labels, mode='train')
            loss.backward()
            optimizer.minimize(loss)
            logging.info("[TRAIN] Epoch={}/{}, Step={}/{}, loss={}".format(
                epoch + 1, num_epochs, step + 1, num_steps_each_epoch,
                loss.numpy()))
        if (epoch + 1) % save_interval_epochs == 0 \
                or num_steps_each_epoch == num_epochs - 1:
            current_save_dir = osp.join(save_dir, "epoch_{}".format(epoch + 1))
            if not osp.isdir(current_save_dir):
                os.makedirs(current_save_dir)
            fluid.save_dygraph(model.state_dict(),
                               osp.join(current_save_dir, 'model'))
            # if eval_dataset is not None:
            #     model.eval()
            #     evaluate(eval_dataset, batch_size=train_batch_size)
            #     model.train()


def arrange_transform(transforms, mode='train'):
    arrange_transform = T.ArrangeSegmenter
    if type(transforms.transforms[-1]).__name__.startswith('Arrange'):
        transforms.transforms[-1] = arrange_transform(mode=mode)
    else:
        transforms.transforms.append(arrange_transform(mode=mode))


def main(args):
    # Creat dataset reader
    train_transforms = T.Compose(
        [T.Resize(args.input_size),
         T.RandomHorizontalFlip(),
         T.Normalize()])
    arrange_transform(train_transforms, mode='train')
    train_dataset = Dataset(
        data_dir=args.data_dir,
        file_list=args.train_list,
        transforms=train_transforms,
        num_workers='auto',
        buffer_size=100,
        parallel_method='thread',
        shuffle=True)
    if args.val_list is not None:
        eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
        arrange_transform(train_transforms, mode='eval')
        eval_dataset = Dataset(
            data_dir=args.data_dir,
            file_list=args.val_list,
            transforms=eval_transforms,
            num_workers='auto',
            buffer_size=100,
            parallel_method='thread',
            shuffle=False)
    if args.model_name == 'UNet':
        model = models.UNet(num_classes=args.num_classes)

    # Creat optimizer
    num_steps_each_epoch = train_dataset.num_samples // args.batch_size
    decay_step = args.num_epochs * num_steps_each_epoch
    lr_decay = fluid.layers.polynomial_decay(
        args.learning_rate, decay_step, end_learning_rate=0, power=0.9)
    optimizer = fluid.optimizer.Momentum(
        lr_decay,
        momentum=0.9,
        parameter_list=model.parameters(),
        regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5))

    train(
        model,
        train_dataset,
        eval_dataset,
        optimizer,
        save_dir=args.save_dir,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        pretrained_model=args.pretrained_model,
        save_interval_epochs=args.save_interval_epochs)


if __name__ == '__main__':
    args = parse_args()
    env_info = get_environ_info()
    if env_info['place'] == 'cpu':
        places = fluid.CPUPlace()
    else:
        places = fluid.CUDAPlace(0)
    with fluid.dygraph.guard(places):
        main(args)
```
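For reference, with the arguments defined in parse_args() above, a typical invocation of the new script looks like this (the dataset paths are hypothetical):

    python train.py --model_name UNet --data_dir data/my_dataset --train_list train_list.txt --val_list val_list.txt --num_classes 2 --input_size 512 512 --batch_size 2 --num_epochs 100 --learning_rate 0.01 --save_dir ./output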