Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
29a5e832
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
289
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
29a5e832
编写于
6月 01, 2020
作者:
C
chenguowei01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update structure
上级
3e90faaa
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
459 addition
and
531 deletion
+459
-531
dygraph/models/__init__.py
dygraph/models/__init__.py
+15
-15
dygraph/models/unet.py
dygraph/models/unet.py
+216
-267
dygraph/nets/__init__.py
dygraph/nets/__init__.py
+0
-15
dygraph/nets/unet.py
dygraph/nets/unet.py
+0
-234
dygraph/train.py
dygraph/train.py
+228
-0
未找到文件。
dygraph/models/__init__.py
浏览文件 @
29a5e832
dygraph/models/unet.py
浏览文件 @
29a5e832
...
@@ -13,101 +13,48 @@
...
@@ -13,101 +13,48 @@
# limitations under the License.
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
absolute_import
import
paddle.fluid
as
fluid
from
__future__
import
division
import
os
from
__future__
import
print_function
from
os
import
path
as
osp
import
numpy
as
np
from
collections
import
OrderedDict
from
collections
import
OrderedDict
import
copy
import
math
import
time
import
tqdm
import
cv2
import
yaml
import
shutil
from
paddle.fluid.dygraph.base
import
to_variable
import
utils
import
utils.logging
as
logging
from
utils
import
seconds_to_hms
from
utils
import
ConfusionMatrix
from
utils
import
get_environ_info
import
nets
import
transforms
as
T
def
dict2str
(
dict_input
):
out
=
''
for
k
,
v
in
dict_input
.
items
():
try
:
v
=
round
(
float
(
v
),
6
)
except
:
pass
out
=
out
+
'{}={}, '
.
format
(
k
,
v
)
return
out
.
strip
(
', '
)
class
UNet
(
object
):
# DeepLab mobilenet
def
__init__
(
self
,
num_classes
=
2
,
upsample_mode
=
'bilinear'
,
ignore_index
=
255
):
self
.
num_classes
=
num_classes
self
.
upsample_mode
=
upsample_mode
self
.
ignore_index
=
ignore_index
self
.
labels
=
None
import
paddle.fluid
as
fluid
self
.
env_info
=
get_environ_info
()
from
paddle.fluid.dygraph
import
Conv2D
,
BatchNorm
,
Pool2D
if
self
.
env_info
[
'place'
]
==
'cpu'
:
import
contextlib
self
.
places
=
fluid
.
CPUPlace
()
else
:
regularizer
=
fluid
.
regularizer
.
L2DecayRegularizer
(
regularization_coeff
=
0.0
)
self
.
places
=
fluid
.
CUDAPlace
(
0
)
name_scope
=
""
def
build_model
(
self
):
@
contextlib
.
contextmanager
self
.
model
=
nets
.
UNet
(
self
.
num_classes
,
self
.
upsample_mode
)
def
scope
(
name
):
global
name_scope
bk
=
name_scope
name_scope
=
name_scope
+
name
+
'/'
yield
name_scope
=
bk
def
arrange_transform
(
self
,
transforms
,
mode
=
'train'
):
arrange_transform
=
T
.
ArrangeSegmenter
class
UNet
(
fluid
.
dygraph
.
Layer
):
if
type
(
transforms
.
transforms
[
-
1
]).
__name__
.
startswith
(
'Arrange'
):
def
__init__
(
self
,
num_classes
,
upsample_mode
=
'bilinear'
,
ignore_index
=
255
):
transforms
.
transforms
[
-
1
]
=
arrange_transform
(
mode
=
mode
)
super
().
__init__
()
self
.
encode
=
Encoder
()
self
.
decode
=
Decode
(
upsample_mode
=
upsample_mode
)
self
.
get_logit
=
GetLogit
(
64
,
num_classes
)
self
.
ignore_index
=
ignore_index
def
forward
(
self
,
x
,
label
,
mode
=
'train'
):
encode_data
,
short_cuts
=
self
.
encode
(
x
)
decode_data
=
self
.
decode
(
encode_data
,
short_cuts
)
logit
=
self
.
get_logit
(
decode_data
)
if
mode
==
'train'
:
return
self
.
_get_loss
(
logit
,
label
)
else
:
else
:
transforms
.
transforms
.
append
(
arrange_transform
(
mode
=
mode
))
logit
=
fluid
.
layers
.
softmax
(
logit
,
axis
=
1
)
logit
=
fluid
.
layers
.
transpose
(
logit
,
[
0
,
2
,
3
,
1
])
def
load_model
(
self
,
model_dir
):
pred
=
fluid
.
layers
.
argmax
(
logit
,
axis
=
3
)
ckpt_path
=
osp
.
join
(
model_dir
,
'model'
)
pred
=
fluid
.
layers
.
unsqueeze
(
pred
,
axes
=
[
3
])
para_state_dict
,
opti_state_dict
=
fluid
.
load_dygraph
(
ckpt_path
)
return
pred
,
logit
self
.
model
.
set_dict
(
para_state_dict
)
def
save_model
(
self
,
state_dict
,
save_dir
):
if
not
osp
.
isdir
(
save_dir
):
if
osp
.
exists
(
save_dir
):
os
.
remove
(
save_dir
)
os
.
makedirs
(
save_dir
)
fluid
.
save_dygraph
(
state_dict
,
osp
.
join
(
save_dir
,
'model'
))
def
default_optimizer
(
self
,
learning_rate
,
num_epochs
,
num_steps_each_epoch
,
parameter_list
=
None
,
lr_decay_power
=
0.9
,
regularization_coeff
=
4e-5
):
decay_step
=
num_epochs
*
num_steps_each_epoch
lr_decay
=
fluid
.
layers
.
polynomial_decay
(
learning_rate
,
decay_step
,
end_learning_rate
=
0
,
power
=
lr_decay_power
)
optimizer
=
fluid
.
optimizer
.
Momentum
(
lr_decay
,
momentum
=
0.9
,
parameter_list
=
parameter_list
,
regularization
=
fluid
.
regularizer
.
L2Decay
(
regularization_coeff
=
regularization_coeff
))
return
optimizer
def
_get_loss
(
self
,
logit
,
label
):
def
_get_loss
(
self
,
logit
,
label
):
mask
=
label
!=
self
.
ignore_index
mask
=
label
!=
self
.
ignore_index
...
@@ -126,181 +73,183 @@ class UNet(object):
...
@@ -126,181 +73,183 @@ class UNet(object):
mask
.
stop_gradient
=
True
mask
.
stop_gradient
=
True
return
avg_loss
return
avg_loss
def
train
(
self
,
num_epochs
,
class
Encoder
(
fluid
.
dygraph
.
Layer
):
train_dataset
,
def
__init__
(
self
):
train_batch_size
=
2
,
super
().
__init__
()
eval_dataset
=
None
,
with
scope
(
'encode'
):
save_interval_epochs
=
1
,
with
scope
(
'block1'
):
log_interval_steps
=
2
,
self
.
double_conv
=
DoubleConv
(
3
,
64
)
save_dir
=
'output'
,
with
scope
(
'block1'
):
pretrained_weights
=
None
,
self
.
down1
=
Down
(
64
,
128
)
resume_weights
=
None
,
with
scope
(
'block2'
):
optimizer
=
None
,
self
.
down2
=
Down
(
128
,
256
)
learning_rate
=
0.01
,
with
scope
(
'block3'
):
lr_decay_power
=
0.9
,
self
.
down3
=
Down
(
256
,
512
)
regularization_coeff
=
4e-5
,
with
scope
(
'block4'
):
use_vdl
=
False
):
self
.
down4
=
Down
(
512
,
512
)
self
.
labels
=
train_dataset
.
labels
self
.
train_transforms
=
train_dataset
.
transforms
def
forward
(
self
,
x
):
self
.
train_init
=
locals
()
short_cuts
=
[]
self
.
begin_epoch
=
0
x
=
self
.
double_conv
(
x
)
if
optimizer
is
None
:
short_cuts
.
append
(
x
)
num_steps_each_epoch
=
train_dataset
.
num_samples
//
train_batch_size
x
=
self
.
down1
(
x
)
optimizer
=
self
.
default_optimizer
(
short_cuts
.
append
(
x
)
learning_rate
=
learning_rate
,
x
=
self
.
down2
(
x
)
num_epochs
=
num_epochs
,
short_cuts
.
append
(
x
)
num_steps_each_epoch
=
num_steps_each_epoch
,
x
=
self
.
down3
(
x
)
parameter_list
=
self
.
model
.
parameters
(),
short_cuts
.
append
(
x
)
lr_decay_power
=
lr_decay_power
,
x
=
self
.
down4
(
x
)
regularization_coeff
=
regularization_coeff
)
return
x
,
short_cuts
# to do: 预训练模型加载, resume
class
Decode
(
fluid
.
dygraph
.
Layer
):
if
self
.
begin_epoch
>=
num_epochs
:
def
__init__
(
self
,
upsample_mode
=
'bilinear'
):
raise
ValueError
(
super
().
__init__
()
(
"begin epoch[{}] is larger than num_epochs[{}]"
).
format
(
with
scope
(
'decode'
):
self
.
begin_epoch
,
num_epochs
))
with
scope
(
'decode1'
):
self
.
up1
=
Up
(
512
,
256
,
upsample_mode
)
if
not
osp
.
isdir
(
save_dir
):
with
scope
(
'decode2'
):
if
osp
.
exists
(
save_dir
):
self
.
up2
=
Up
(
256
,
128
,
upsample_mode
)
os
.
remove
(
save_dir
)
with
scope
(
'decode3'
):
os
.
makedirs
(
save_dir
)
self
.
up3
=
Up
(
128
,
64
,
upsample_mode
)
with
scope
(
'decode4'
):
# add arrange op to transforms
self
.
up4
=
Up
(
64
,
64
,
upsample_mode
)
self
.
arrange_transform
(
transforms
=
train_dataset
.
transforms
,
mode
=
'train'
)
def
forward
(
self
,
x
,
short_cuts
):
x
=
self
.
up1
(
x
,
short_cuts
[
3
])
if
eval_dataset
is
not
None
:
x
=
self
.
up2
(
x
,
short_cuts
[
2
])
self
.
eval_transforms
=
eval_dataset
.
transforms
x
=
self
.
up3
(
x
,
short_cuts
[
1
])
self
.
test_transforms
=
copy
.
deepcopy
(
eval_dataset
.
transforms
)
x
=
self
.
up4
(
x
,
short_cuts
[
0
])
return
x
data_generator
=
train_dataset
.
generator
(
batch_size
=
train_batch_size
,
drop_last
=
True
)
total_num_steps
=
math
.
floor
(
class
GetLogit
(
fluid
.
dygraph
.
Layer
):
train_dataset
.
num_samples
/
train_batch_size
)
def
__init__
(
self
):
super
().
__init__
()
for
i
in
range
(
self
.
begin_epoch
,
num_epochs
):
for
step
,
data
in
enumerate
(
data_generator
()):
images
=
np
.
array
([
d
[
0
]
for
d
in
data
])
class
DoubleConv
(
fluid
.
dygraph
.
Layer
):
labels
=
np
.
array
([
d
[
1
]
for
d
in
data
]).
astype
(
'int64'
)
def
__init__
(
self
,
num_channels
,
num_filters
):
images
=
to_variable
(
images
)
super
().
__init__
()
labels
=
to_variable
(
labels
)
with
scope
(
'conv0'
):
logit
=
self
.
model
(
images
)
param_attr
=
fluid
.
ParamAttr
(
loss
=
self
.
_get_loss
(
logit
,
labels
)
name
=
name_scope
+
'weights'
,
loss
.
backward
()
regularizer
=
regularizer
,
optimizer
.
minimize
(
loss
)
initializer
=
fluid
.
initializer
.
TruncatedNormal
(
print
(
"[TRAIN] Epoch={}/{}, Step={}/{}, loss={}"
.
format
(
loc
=
0.0
,
scale
=
0.33
))
i
+
1
,
num_epochs
,
step
+
1
,
total_num_steps
,
loss
.
numpy
()))
self
.
conv0
=
Conv2D
(
num_channels
=
num_channels
,
if
(
i
+
1
)
%
save_interval_epochs
==
0
or
i
==
num_epochs
-
1
:
num_filters
=
num_filters
,
current_save_dir
=
osp
.
join
(
save_dir
,
"epoch_{}"
.
format
(
i
+
1
))
filter_size
=
3
,
if
not
osp
.
isdir
(
current_save_dir
):
stride
=
1
,
os
.
makedirs
(
current_save_dir
)
padding
=
1
,
self
.
save_model
(
self
.
model
.
state_dict
(),
current_save_dir
)
param_attr
=
param_attr
)
if
eval_dataset
is
not
None
:
self
.
bn0
=
BatchNorm
(
self
.
model
.
eval
()
num_channels
=
num_filters
,
self
.
evaluate
(
eval_dataset
,
batch_size
=
train_batch_size
)
param_attr
=
fluid
.
ParamAttr
(
self
.
model
.
train
()
name
=
name_scope
+
'gamma'
,
regularizer
=
regularizer
),
bias_attr
=
fluid
.
ParamAttr
(
def
evaluate
(
self
,
eval_dataset
,
batch_size
=
1
,
epoch_id
=
None
):
name
=
name_scope
+
'beta'
,
regularizer
=
regularizer
),
"""评估。
moving_mean_name
=
name_scope
+
'moving_mean'
,
moving_variance_name
=
name_scope
+
'moving_variance'
)
Args:
with
scope
(
'conv1'
):
eval_dataset (paddlex.datasets): 评估数据读取器。
param_attr
=
fluid
.
ParamAttr
(
batch_size (int): 评估时的batch大小。默认1。
name
=
name_scope
+
'weights'
,
epoch_id (int): 当前评估模型所在的训练轮数。
regularizer
=
regularizer
,
return_details (bool): 是否返回详细信息。默认False。
initializer
=
fluid
.
initializer
.
TruncatedNormal
(
loc
=
0.0
,
scale
=
0.33
))
Returns:
self
.
conv1
=
Conv2D
(
dict: 当return_details为False时,返回dict。包含关键字:'miou'、'category_iou'、'macc'、
num_channels
=
num_filters
,
'category_acc'和'kappa',分别表示平均iou、各类别iou、平均准确率、各类别准确率和kappa系数。
num_filters
=
num_filters
,
tuple (metrics, eval_details):当return_details为True时,增加返回dict (eval_details),
filter_size
=
3
,
包含关键字:'confusion_matrix',表示评估的混淆矩阵。
stride
=
1
,
"""
padding
=
1
,
self
.
model
.
eval
()
param_attr
=
param_attr
)
self
.
arrange_transform
(
transforms
=
eval_dataset
.
transforms
,
mode
=
'train'
)
self
.
bn1
=
BatchNorm
(
total_steps
=
math
.
ceil
(
eval_dataset
.
num_samples
*
1.0
/
batch_size
)
num_channels
=
num_filters
,
conf_mat
=
ConfusionMatrix
(
self
.
num_classes
,
streaming
=
True
)
param_attr
=
fluid
.
ParamAttr
(
data_generator
=
eval_dataset
.
generator
(
name
=
name_scope
+
'gamma'
,
regularizer
=
regularizer
),
batch_size
=
batch_size
,
drop_last
=
False
)
bias_attr
=
fluid
.
ParamAttr
(
logging
.
info
(
name
=
name_scope
+
'beta'
,
regularizer
=
regularizer
),
"Start to evaluating(total_samples={}, total_steps={})..."
.
format
(
moving_mean_name
=
name_scope
+
'moving_mean'
,
eval_dataset
.
num_samples
,
total_steps
))
moving_variance_name
=
name_scope
+
'moving_variance'
)
for
step
,
data
in
tqdm
.
tqdm
(
enumerate
(
data_generator
()),
total
=
total_steps
):
def
forward
(
self
,
x
):
images
=
np
.
array
([
d
[
0
]
for
d
in
data
])
x
=
self
.
conv0
(
x
)
labels
=
np
.
array
([
d
[
1
]
for
d
in
data
])
x
=
self
.
bn0
(
x
)
images
=
to_variable
(
images
)
x
=
fluid
.
layers
.
relu
(
x
)
x
=
self
.
conv1
(
x
)
logit
=
self
.
model
(
images
)
x
=
self
.
bn1
(
x
)
pred
=
fluid
.
layers
.
argmax
(
logit
,
axis
=
1
)
x
=
fluid
.
layers
.
relu
(
x
)
pred
=
fluid
.
layers
.
unsqueeze
(
pred
,
axes
=
[
3
])
return
x
pred
=
pred
.
numpy
()
mask
=
labels
!=
self
.
ignore_index
class
Down
(
fluid
.
dygraph
.
Layer
):
conf_mat
.
calculate
(
pred
=
pred
,
label
=
labels
,
ignore
=
mask
)
def
__init__
(
self
,
num_channels
,
num_filters
):
_
,
iou
=
conf_mat
.
mean_iou
()
super
().
__init__
()
with
scope
(
"down"
):
logging
.
debug
(
"[EVAL] Epoch={}, Step={}/{}, iou={}"
.
format
(
self
.
max_pool
=
Pool2D
(
epoch_id
,
step
+
1
,
total_steps
,
iou
))
pool_size
=
2
,
pool_type
=
'max'
,
pool_stride
=
2
,
pool_padding
=
0
)
self
.
double_conv
=
DoubleConv
(
num_channels
,
num_filters
)
category_iou
,
miou
=
conf_mat
.
mean_iou
()
category_acc
,
macc
=
conf_mat
.
accuracy
()
def
forward
(
self
,
x
):
x
=
self
.
max_pool
(
x
)
metrics
=
OrderedDict
(
x
=
self
.
double_conv
(
x
)
zip
([
'miou'
,
'category_iou'
,
'macc'
,
'category_acc'
,
'kappa'
],
return
x
[
miou
,
category_iou
,
macc
,
category_acc
,
conf_mat
.
kappa
()]))
class
Up
(
fluid
.
dygraph
.
Layer
):
logging
.
info
(
'[EVAL] Finished, Epoch={}, {} .'
.
format
(
def
__init__
(
self
,
num_channels
,
num_filters
,
upsample_mode
):
epoch_id
,
dict2str
(
metrics
)))
super
().
__init__
()
return
metrics
self
.
upsample_mode
=
upsample_mode
with
scope
(
'up'
):
def
predict
(
self
,
im_file
,
transforms
=
None
):
if
upsample_mode
==
'bilinear'
:
"""预测。
self
.
double_conv
=
DoubleConv
(
2
*
num_channels
,
num_filters
)
Args:
if
not
upsample_mode
==
'bilinear'
:
img_file(str|np.ndarray): 预测图像。
param_attr
=
fluid
.
ParamAttr
(
transforms(paddlex.cv.transforms): 数据预处理操作。
name
=
name_scope
+
'weights'
,
regularizer
=
regularizer
,
Returns:
initializer
=
fluid
.
initializer
.
XavierInitializer
(),
dict: 包含关键字'label_map'和'score_map', 'label_map'存储预测结果灰度图,
)
像素值表示对应的类别,'score_map'存储各类别的概率,shape=(h, w, num_classes)
self
.
deconv
=
fluid
.
dygraph
.
Conv2DTranspose
(
"""
num_channels
=
num_channels
,
if
isinstance
(
im_file
,
str
):
num_filters
=
num_filters
//
2
,
if
not
osp
.
exists
(
im_file
):
filter_size
=
2
,
raise
ValueError
(
stride
=
2
,
'The Image file does not exist: {}'
.
format
(
im_file
))
padding
=
0
,
param_attr
=
param_attr
)
if
transforms
is
None
and
not
hasattr
(
self
,
'test_transforms'
):
self
.
double_conv
=
DoubleConv
(
num_channels
+
num_filters
//
2
,
raise
Exception
(
"transforms need to be defined, now is None."
)
num_filters
)
if
transforms
is
not
None
:
self
.
arrange_transform
(
transforms
=
transforms
,
mode
=
'test'
)
def
forward
(
self
,
x
,
short_cut
):
im
,
im_info
=
transforms
(
im_file
)
if
self
.
upsample_mode
==
'bilinear'
:
short_cut_shape
=
fluid
.
layers
.
shape
(
short_cut
)
x
=
fluid
.
layers
.
resize_bilinear
(
x
,
short_cut_shape
[
2
:])
else
:
else
:
self
.
arrange_transform
(
transforms
=
self
.
test_transforms
,
mode
=
'test'
)
x
=
self
.
deconv
(
x
)
im
,
im_info
=
self
.
test_transforms
(
im_file
)
x
=
fluid
.
layers
.
concat
([
x
,
short_cut
],
axis
=
1
)
im
=
np
.
expand_dims
(
im
,
axis
=
0
)
x
=
self
.
double_conv
(
x
)
im
=
to_variable
(
im
)
return
x
logit
=
self
.
model
(
im
)
logit
=
fluid
.
layers
.
softmax
(
logit
)
pred
=
fluid
.
layers
.
argmax
(
logit
,
axis
=
1
)
class
GetLogit
(
fluid
.
dygraph
.
Layer
):
logit
=
logit
.
numpy
()
def
__init__
(
self
,
num_channels
,
num_classes
):
pred
=
pred
.
numpy
()
super
().
__init__
()
with
scope
(
'logit'
):
logit
=
np
.
squeeze
(
logit
)
param_attr
=
fluid
.
ParamAttr
(
logit
=
np
.
transpose
(
logit
,
(
1
,
2
,
0
))
name
=
name_scope
+
'weights'
,
pred
=
np
.
squeeze
(
pred
).
astype
(
'uint8'
)
regularizer
=
regularizer
,
keys
=
list
(
im_info
.
keys
())
initializer
=
fluid
.
initializer
.
TruncatedNormal
(
print
(
pred
.
shape
,
logit
.
shape
)
loc
=
0.0
,
scale
=
0.01
)
)
for
k
in
keys
[::
-
1
]:
self
.
conv
=
Conv2D
(
if
k
==
'shape_before_resize'
:
num_channels
=
num_channels
,
h
,
w
=
im_info
[
k
][
0
],
im_info
[
k
][
1
]
num_filters
=
num_classes
,
pred
=
cv2
.
resize
(
pred
,
(
w
,
h
),
cv2
.
INTER_NEAREST
)
filter_size
=
3
,
logit
=
cv2
.
resize
(
logit
,
(
w
,
h
),
cv2
.
INTER_LINEAR
)
stride
=
1
,
elif
k
==
'shape_before_padding'
:
padding
=
1
,
h
,
w
=
im_info
[
k
][
0
],
im_info
[
k
][
1
]
param_attr
=
param_attr
)
pred
=
pred
[
0
:
h
,
0
:
w
]
logit
=
logit
[
0
:
h
,
0
:
w
,
:]
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
return
{
'label_map'
:
pred
,
'score_map'
:
logit
}
return
x
dygraph/nets/__init__.py
已删除
100644 → 0
浏览文件 @
3e90faaa
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.unet
import
UNet
dygraph/nets/unet.py
已删除
100644 → 0
浏览文件 @
3e90faaa
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
collections
import
OrderedDict
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph
import
Conv2D
,
BatchNorm
,
Pool2D
import
contextlib
regularizer
=
fluid
.
regularizer
.
L2DecayRegularizer
(
regularization_coeff
=
0.0
)
name_scope
=
""
@
contextlib
.
contextmanager
def
scope
(
name
):
global
name_scope
bk
=
name_scope
name_scope
=
name_scope
+
name
+
'/'
yield
name_scope
=
bk
class
UNet
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
num_classes
,
upsample_mode
=
'bilinear'
,
):
super
().
__init__
()
self
.
encode
=
Encoder
()
self
.
decode
=
Decode
(
upsample_mode
=
upsample_mode
)
self
.
get_logit
=
GetLogit
(
64
,
num_classes
)
def
forward
(
self
,
x
):
encode_data
,
short_cuts
=
self
.
encode
(
x
)
decode_data
=
self
.
decode
(
encode_data
,
short_cuts
)
logit
=
self
.
get_logit
(
decode_data
)
return
logit
class
Encoder
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
):
super
().
__init__
()
with
scope
(
'encode'
):
with
scope
(
'block1'
):
self
.
double_conv
=
DoubleConv
(
3
,
64
)
with
scope
(
'block1'
):
self
.
down1
=
Down
(
64
,
128
)
with
scope
(
'block2'
):
self
.
down2
=
Down
(
128
,
256
)
with
scope
(
'block3'
):
self
.
down3
=
Down
(
256
,
512
)
with
scope
(
'block4'
):
self
.
down4
=
Down
(
512
,
512
)
def
forward
(
self
,
x
):
short_cuts
=
[]
x
=
self
.
double_conv
(
x
)
short_cuts
.
append
(
x
)
x
=
self
.
down1
(
x
)
short_cuts
.
append
(
x
)
x
=
self
.
down2
(
x
)
short_cuts
.
append
(
x
)
x
=
self
.
down3
(
x
)
short_cuts
.
append
(
x
)
x
=
self
.
down4
(
x
)
return
x
,
short_cuts
class
Decode
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
upsample_mode
=
'bilinear'
):
super
().
__init__
()
with
scope
(
'decode'
):
with
scope
(
'decode1'
):
self
.
up1
=
Up
(
512
,
256
,
upsample_mode
)
with
scope
(
'decode2'
):
self
.
up2
=
Up
(
256
,
128
,
upsample_mode
)
with
scope
(
'decode3'
):
self
.
up3
=
Up
(
128
,
64
,
upsample_mode
)
with
scope
(
'decode4'
):
self
.
up4
=
Up
(
64
,
64
,
upsample_mode
)
def
forward
(
self
,
x
,
short_cuts
):
x
=
self
.
up1
(
x
,
short_cuts
[
3
])
x
=
self
.
up2
(
x
,
short_cuts
[
2
])
x
=
self
.
up3
(
x
,
short_cuts
[
1
])
x
=
self
.
up4
(
x
,
short_cuts
[
0
])
return
x
class
GetLogit
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
):
super
().
__init__
()
class
DoubleConv
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
num_channels
,
num_filters
):
super
().
__init__
()
with
scope
(
'conv0'
):
param_attr
=
fluid
.
ParamAttr
(
name
=
name_scope
+
'weights'
,
regularizer
=
regularizer
,
initializer
=
fluid
.
initializer
.
TruncatedNormal
(
loc
=
0.0
,
scale
=
0.33
))
self
.
conv0
=
Conv2D
(
num_channels
=
num_channels
,
num_filters
=
num_filters
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
param_attr
=
param_attr
)
self
.
bn0
=
BatchNorm
(
num_channels
=
num_filters
,
param_attr
=
fluid
.
ParamAttr
(
name
=
name_scope
+
'gamma'
,
regularizer
=
regularizer
),
bias_attr
=
fluid
.
ParamAttr
(
name
=
name_scope
+
'beta'
,
regularizer
=
regularizer
),
moving_mean_name
=
name_scope
+
'moving_mean'
,
moving_variance_name
=
name_scope
+
'moving_variance'
)
with
scope
(
'conv1'
):
param_attr
=
fluid
.
ParamAttr
(
name
=
name_scope
+
'weights'
,
regularizer
=
regularizer
,
initializer
=
fluid
.
initializer
.
TruncatedNormal
(
loc
=
0.0
,
scale
=
0.33
))
self
.
conv1
=
Conv2D
(
num_channels
=
num_filters
,
num_filters
=
num_filters
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
param_attr
=
param_attr
)
self
.
bn1
=
BatchNorm
(
num_channels
=
num_filters
,
param_attr
=
fluid
.
ParamAttr
(
name
=
name_scope
+
'gamma'
,
regularizer
=
regularizer
),
bias_attr
=
fluid
.
ParamAttr
(
name
=
name_scope
+
'beta'
,
regularizer
=
regularizer
),
moving_mean_name
=
name_scope
+
'moving_mean'
,
moving_variance_name
=
name_scope
+
'moving_variance'
)
def
forward
(
self
,
x
):
x
=
self
.
conv0
(
x
)
x
=
self
.
bn0
(
x
)
x
=
fluid
.
layers
.
relu
(
x
)
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
fluid
.
layers
.
relu
(
x
)
return
x
class
Down
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
num_channels
,
num_filters
):
super
().
__init__
()
with
scope
(
"down"
):
self
.
max_pool
=
Pool2D
(
pool_size
=
2
,
pool_type
=
'max'
,
pool_stride
=
2
,
pool_padding
=
0
)
self
.
double_conv
=
DoubleConv
(
num_channels
,
num_filters
)
def
forward
(
self
,
x
):
x
=
self
.
max_pool
(
x
)
x
=
self
.
double_conv
(
x
)
return
x
class
Up
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
num_channels
,
num_filters
,
upsample_mode
):
super
().
__init__
()
self
.
upsample_mode
=
upsample_mode
with
scope
(
'up'
):
if
upsample_mode
==
'bilinear'
:
self
.
double_conv
=
DoubleConv
(
2
*
num_channels
,
num_filters
)
if
not
upsample_mode
==
'bilinear'
:
param_attr
=
fluid
.
ParamAttr
(
name
=
name_scope
+
'weights'
,
regularizer
=
regularizer
,
initializer
=
fluid
.
initializer
.
XavierInitializer
(),
)
self
.
deconv
=
fluid
.
dygraph
.
Conv2DTranspose
(
num_channels
=
num_channels
,
num_filters
=
num_filters
//
2
,
filter_size
=
2
,
stride
=
2
,
padding
=
0
,
param_attr
=
param_attr
)
self
.
double_conv
=
DoubleConv
(
num_channels
+
num_filters
//
2
,
num_filters
)
def
forward
(
self
,
x
,
short_cut
):
if
self
.
upsample_mode
==
'bilinear'
:
short_cut_shape
=
fluid
.
layers
.
shape
(
short_cut
)
x
=
fluid
.
layers
.
resize_bilinear
(
x
,
short_cut_shape
[
2
:])
else
:
x
=
self
.
deconv
(
x
)
x
=
fluid
.
layers
.
concat
([
x
,
short_cut
],
axis
=
1
)
x
=
self
.
double_conv
(
x
)
return
x
class
GetLogit
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
num_channels
,
num_classes
):
super
().
__init__
()
with
scope
(
'logit'
):
param_attr
=
fluid
.
ParamAttr
(
name
=
name_scope
+
'weights'
,
regularizer
=
regularizer
,
initializer
=
fluid
.
initializer
.
TruncatedNormal
(
loc
=
0.0
,
scale
=
0.01
))
self
.
conv
=
Conv2D
(
num_channels
=
num_channels
,
num_filters
=
num_classes
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
param_attr
=
param_attr
)
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
return
x
dygraph/train.py
0 → 100644
浏览文件 @
29a5e832
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
os
import
os.path
as
osp
from
paddle.fluid.dygraph.base
import
to_variable
import
numpy
as
np
import
paddle.fluid
as
fluid
from
datasets.dataset
import
Dataset
import
transforms
as
T
import
models
import
utils.logging
as
logging
from
utils
import
get_environ_info
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Model training'
)
# params of model
parser
.
add_argument
(
'--model_name'
,
dest
=
'model_name'
,
help
=
"Model type for traing, which is one of ('UNet')"
,
type
=
str
,
default
=
'UNet'
)
# params of dataset
parser
.
add_argument
(
'--data_dir'
,
dest
=
'data_dir'
,
help
=
'The root directory of dataset'
,
type
=
str
)
parser
.
add_argument
(
'--train_list'
,
dest
=
'train_list'
,
help
=
'Train list file of dataset'
,
type
=
str
)
parser
.
add_argument
(
'--val_list'
,
dest
=
'val_list'
,
help
=
'Val list file of dataset'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--num_classes'
,
dest
=
'num_classes'
,
help
=
'Number of classes'
,
type
=
int
,
default
=
2
)
# params of training
parser
.
add_argument
(
"--input_size"
,
dest
=
"input_size"
,
help
=
"The image size for net inputs."
,
nargs
=
2
,
default
=
[
512
,
512
],
type
=
int
)
parser
.
add_argument
(
'--num_epochs'
,
dest
=
'num_epochs'
,
help
=
'Number epochs for training'
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
'--batch_size'
,
dest
=
'batch_size'
,
help
=
'Mini batch size'
,
type
=
int
,
default
=
2
)
parser
.
add_argument
(
'--learning_rate'
,
dest
=
'learning_rate'
,
help
=
'Learning rate'
,
type
=
float
,
default
=
0.01
)
parser
.
add_argument
(
'--pretrained_model'
,
dest
=
'pretrained_model'
,
help
=
'The path of pretrianed weight'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--save_interval_epochs'
,
dest
=
'save_interval_epochs'
,
help
=
'The interval epochs for save a model snapshot'
,
type
=
int
,
default
=
5
)
parser
.
add_argument
(
'--save_dir'
,
dest
=
'save_dir'
,
help
=
'The directory for saving the model snapshot'
,
type
=
str
,
default
=
'./output'
)
return
parser
.
parse_args
()
def
train
(
model
,
train_dataset
,
eval_dataset
=
None
,
optimizer
=
None
,
save_dir
=
'output'
,
num_epochs
=
100
,
batch_size
=
2
,
pretrained_model
=
None
,
save_interval_epochs
=
1
):
if
not
osp
.
isdir
(
save_dir
):
if
osp
.
exists
(
save_dir
):
os
.
remove
(
save_dir
)
os
.
makedirs
(
save_dir
)
data_generator
=
train_dataset
.
generator
(
batch_size
=
batch_size
,
drop_last
=
True
)
num_steps_each_epoch
=
train_dataset
.
num_samples
//
args
.
batch_size
for
epoch
in
range
(
num_epochs
):
for
step
,
data
in
enumerate
(
data_generator
()):
images
=
np
.
array
([
d
[
0
]
for
d
in
data
])
labels
=
np
.
array
([
d
[
1
]
for
d
in
data
]).
astype
(
'int64'
)
images
=
to_variable
(
images
)
labels
=
to_variable
(
labels
)
loss
=
model
(
images
,
labels
,
mode
=
'train'
)
loss
.
backward
()
optimizer
.
minimize
(
loss
)
logging
.
info
(
"[TRAIN] Epoch={}/{}, Step={}/{}, loss={}"
.
format
(
epoch
+
1
,
num_epochs
,
step
+
1
,
num_steps_each_epoch
,
loss
.
numpy
()))
if
(
epoch
+
1
)
%
save_interval_epochs
==
0
or
num_steps_each_epoch
==
num_epochs
-
1
:
current_save_dir
=
osp
.
join
(
save_dir
,
"epoch_{}"
.
format
(
epoch
+
1
))
if
not
osp
.
isdir
(
current_save_dir
):
os
.
makedirs
(
current_save_dir
)
fluid
.
save_dygraph
(
model
.
state_dict
(),
osp
.
join
(
current_save_dir
,
'model'
))
# if eval_dataset is not None:
# model.eval()
# evaluate(eval_dataset, batch_size=train_batch_size)
# model.train()
def
arrange_transform
(
transforms
,
mode
=
'train'
):
arrange_transform
=
T
.
ArrangeSegmenter
if
type
(
transforms
.
transforms
[
-
1
]).
__name__
.
startswith
(
'Arrange'
):
transforms
.
transforms
[
-
1
]
=
arrange_transform
(
mode
=
mode
)
else
:
transforms
.
transforms
.
append
(
arrange_transform
(
mode
=
mode
))
def
main
(
args
):
# Creat dataset reader
train_transforms
=
T
.
Compose
(
[
T
.
Resize
(
args
.
input_size
),
T
.
RandomHorizontalFlip
(),
T
.
Normalize
()])
arrange_transform
(
train_transforms
,
mode
=
'train'
)
train_dataset
=
Dataset
(
data_dir
=
args
.
data_dir
,
file_list
=
args
.
train_list
,
transforms
=
train_transforms
,
num_workers
=
'auto'
,
buffer_size
=
100
,
parallel_method
=
'thread'
,
shuffle
=
True
)
if
args
.
val_list
is
not
None
:
eval_transforms
=
T
.
Compose
([
T
.
Resize
(
args
.
input_size
),
T
.
Normalize
()])
arrange_transform
(
train_transforms
,
mode
=
'eval'
)
eval_dataset
=
Dataset
(
data_dir
=
args
.
data_dir
,
file_list
=
args
.
val_list
,
transforms
=
eval_transforms
,
num_workers
=
'auto'
,
buffer_size
=
100
,
parallel_method
=
'thread'
,
shuffle
=
False
)
if
args
.
model_name
==
'UNet'
:
model
=
models
.
UNet
(
num_classes
=
args
.
num_classes
)
# Creat optimizer
num_steps_each_epoch
=
train_dataset
.
num_samples
//
args
.
batch_size
decay_step
=
args
.
num_epochs
*
num_steps_each_epoch
lr_decay
=
fluid
.
layers
.
polynomial_decay
(
args
.
learning_rate
,
decay_step
,
end_learning_rate
=
0
,
power
=
0.9
)
optimizer
=
fluid
.
optimizer
.
Momentum
(
lr_decay
,
momentum
=
0.9
,
parameter_list
=
model
.
parameters
(),
regularization
=
fluid
.
regularizer
.
L2Decay
(
regularization_coeff
=
4e-5
))
train
(
model
,
train_dataset
,
eval_dataset
,
optimizer
,
save_dir
=
args
.
save_dir
,
num_epochs
=
args
.
num_epochs
,
batch_size
=
args
.
batch_size
,
pretrained_model
=
args
.
pretrained_model
,
save_interval_epochs
=
args
.
save_interval_epochs
)
if
__name__
==
'__main__'
:
args
=
parse_args
()
env_info
=
get_environ_info
()
if
env_info
[
'place'
]
==
'cpu'
:
places
=
fluid
.
CPUPlace
()
else
:
places
=
fluid
.
CUDAPlace
(
0
)
with
fluid
.
dygraph
.
guard
(
places
):
main
(
args
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录