Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
70f3fab4
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
285
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
70f3fab4
编写于
7月 07, 2020
作者:
C
chenguowei01
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'dygraph' of
https://github.com/wuyefeilin/PaddleSeg
into dygraph
上级
e1f3b7e5
b7f59d54
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
84 addition
and
71 deletion
+84
-71
dygraph/infer.py
dygraph/infer.py
+2
-6
dygraph/models/hrnet.py
dygraph/models/hrnet.py
+2
-1
dygraph/models/unet.py
dygraph/models/unet.py
+2
-1
dygraph/train.py
dygraph/train.py
+6
-9
dygraph/transforms/transforms.py
dygraph/transforms/transforms.py
+70
-48
dygraph/val.py
dygraph/val.py
+2
-6
未找到文件。
dygraph/infer.py
浏览文件 @
70f3fab4
...
...
@@ -37,12 +37,8 @@ def parse_args():
parser
.
add_argument
(
'--model_name'
,
dest
=
'model_name'
,
help
=
'Model type for testing, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", '
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")'
,
help
=
'Model type for testing, which is one of {}'
.
format
(
str
(
list
(
MODELS
.
keys
()))),
type
=
str
,
default
=
'UNet'
)
...
...
dygraph/models/hrnet.py
浏览文件 @
70f3fab4
...
...
@@ -18,7 +18,8 @@ import paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
BatchNorm
,
Linear
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
from
paddle.fluid.dygraph
import
SyncBatchNorm
as
BatchNorm
__all__
=
[
"HRNet_W18_Small_V1"
,
"HRNet_W18_Small_V2"
,
"HRNet_W18"
,
"HRNet_W30"
,
...
...
dygraph/models/unet.py
浏览文件 @
70f3fab4
...
...
@@ -13,7 +13,8 @@
# limitations under the License.
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph
import
Conv2D
,
BatchNorm
,
Pool2D
from
paddle.fluid.dygraph
import
Conv2D
,
Pool2D
from
paddle.fluid.dygraph
import
SyncBatchNorm
as
BatchNorm
class
UNet
(
fluid
.
dygraph
.
Layer
):
...
...
dygraph/train.py
浏览文件 @
70f3fab4
...
...
@@ -38,12 +38,8 @@ def parse_args():
parser
.
add_argument
(
'--model_name'
,
dest
=
'model_name'
,
help
=
'Model type for training, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", '
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")'
,
help
=
'Model type for training, which is one of {}'
.
format
(
str
(
list
(
MODELS
.
keys
()))),
type
=
str
,
default
=
'UNet'
)
...
...
@@ -186,6 +182,7 @@ def train(model,
total_steps
=
steps_per_epoch
*
(
num_epochs
-
start_epoch
)
num_steps
=
0
best_mean_iou
=
-
1.0
best_model_epoch
=
-
1
for
epoch
in
range
(
start_epoch
,
num_epochs
):
for
step
,
data
in
enumerate
(
loader
):
images
=
data
[
0
]
...
...
@@ -245,9 +242,9 @@ def train(model,
best_model_dir
=
os
.
path
.
join
(
save_dir
,
"best_model"
)
fluid
.
save_dygraph
(
model
.
state_dict
(),
os
.
path
.
join
(
best_model_dir
,
'model'
))
logging
.
info
(
'Current evaluated best model in eval_dataset is epoch_{}, miou={:4f}'
.
format
(
best_model_epoch
,
best_mean_iou
))
logging
.
info
(
'Current evaluated best model in eval_dataset is epoch_{}, miou={:4f}'
.
format
(
best_model_epoch
,
best_mean_iou
))
if
use_vdl
:
log_writer
.
add_scalar
(
'Evaluate/mean_iou'
,
mean_iou
,
...
...
dygraph/transforms/transforms.py
浏览文件 @
70f3fab4
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -33,6 +34,7 @@ class Compose:
ValueError: transforms元素个数小于1。
"""
def
__init__
(
self
,
transforms
,
to_rgb
=
True
):
if
not
isinstance
(
transforms
,
list
):
raise
TypeError
(
'The transforms must be a list!'
)
...
...
@@ -86,6 +88,7 @@ class RandomHorizontalFlip:
prob (float): 随机水平翻转的概率。默认值为0.5。
"""
def
__init__
(
self
,
prob
=
0.5
):
self
.
prob
=
prob
...
...
@@ -117,6 +120,7 @@ class RandomVerticalFlip:
Args:
prob (float): 随机垂直翻转的概率。默认值为0.1。
"""
def
__init__
(
self
,
prob
=
0.1
):
self
.
prob
=
prob
...
...
@@ -233,6 +237,7 @@ class ResizeByLong:
Args:
long_size (int): resize后图像的长边大小。
"""
def
__init__
(
self
,
long_size
):
self
.
long_size
=
long_size
...
...
@@ -274,6 +279,7 @@ class ResizeRangeScaling:
Raises:
ValueError: min_value大于max_value
"""
def
__init__
(
self
,
min_value
=
400
,
max_value
=
600
):
if
min_value
>
max_value
:
raise
ValueError
(
'min_value must be less than max_value, '
...
...
@@ -321,6 +327,7 @@ class ResizeStepScaling:
Raises:
ValueError: min_scale_factor大于max_scale_factor
"""
def
__init__
(
self
,
min_scale_factor
=
0.75
,
max_scale_factor
=
1.25
,
...
...
@@ -386,6 +393,7 @@ class Normalize:
Raises:
ValueError: mean或std不是list对象。std包含0。
"""
def
__init__
(
self
,
mean
=
[
0.5
,
0.5
,
0.5
],
std
=
[
0.5
,
0.5
,
0.5
]):
self
.
mean
=
mean
self
.
std
=
std
...
...
@@ -431,6 +439,7 @@ class Padding:
TypeError: target_size不是int|list|tuple。
ValueError: target_size为list|tuple时元素个数不等于2。
"""
def
__init__
(
self
,
target_size
,
im_padding_value
=
[
127.5
,
127.5
,
127.5
],
...
...
@@ -483,21 +492,23 @@ class Padding:
'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})'
.
format
(
im_width
,
im_height
,
target_width
,
target_height
))
else
:
im
=
cv2
.
copyMakeBorder
(
im
,
0
,
pad_height
,
0
,
pad_width
,
cv2
.
BORDER_CONSTANT
,
value
=
self
.
im_padding_value
)
im
=
cv2
.
copyMakeBorder
(
im
,
0
,
pad_height
,
0
,
pad_width
,
cv2
.
BORDER_CONSTANT
,
value
=
self
.
im_padding_value
)
if
label
is
not
None
:
label
=
cv2
.
copyMakeBorder
(
label
,
0
,
pad_height
,
0
,
pad_width
,
cv2
.
BORDER_CONSTANT
,
value
=
self
.
label_padding_value
)
label
=
cv2
.
copyMakeBorder
(
label
,
0
,
pad_height
,
0
,
pad_width
,
cv2
.
BORDER_CONSTANT
,
value
=
self
.
label_padding_value
)
if
label
is
None
:
return
(
im
,
im_info
)
else
:
...
...
@@ -516,6 +527,7 @@ class RandomPaddingCrop:
TypeError: crop_size不是int/list/tuple。
ValueError: target_size为list/tuple时元素个数不等于2。
"""
def
__init__
(
self
,
crop_size
=
512
,
im_padding_value
=
[
127.5
,
127.5
,
127.5
],
...
...
@@ -564,21 +576,23 @@ class RandomPaddingCrop:
pad_height
=
max
(
crop_height
-
img_height
,
0
)
pad_width
=
max
(
crop_width
-
img_width
,
0
)
if
(
pad_height
>
0
or
pad_width
>
0
):
im
=
cv2
.
copyMakeBorder
(
im
,
0
,
pad_height
,
0
,
pad_width
,
cv2
.
BORDER_CONSTANT
,
value
=
self
.
im_padding_value
)
im
=
cv2
.
copyMakeBorder
(
im
,
0
,
pad_height
,
0
,
pad_width
,
cv2
.
BORDER_CONSTANT
,
value
=
self
.
im_padding_value
)
if
label
is
not
None
:
label
=
cv2
.
copyMakeBorder
(
label
,
0
,
pad_height
,
0
,
pad_width
,
cv2
.
BORDER_CONSTANT
,
value
=
self
.
label_padding_value
)
label
=
cv2
.
copyMakeBorder
(
label
,
0
,
pad_height
,
0
,
pad_width
,
cv2
.
BORDER_CONSTANT
,
value
=
self
.
label_padding_value
)
img_height
=
im
.
shape
[
0
]
img_width
=
im
.
shape
[
1
]
...
...
@@ -586,11 +600,11 @@ class RandomPaddingCrop:
h_off
=
np
.
random
.
randint
(
img_height
-
crop_height
+
1
)
w_off
=
np
.
random
.
randint
(
img_width
-
crop_width
+
1
)
im
=
im
[
h_off
:(
crop_height
+
h_off
),
w_off
:(
w_off
+
crop_width
),
:]
im
=
im
[
h_off
:(
crop_height
+
h_off
),
w_off
:(
w_off
+
crop_width
),
:]
if
label
is
not
None
:
label
=
label
[
h_off
:(
crop_height
+
h_off
),
w_off
:(
w_off
+
crop_width
)]
label
=
label
[
h_off
:(
crop_height
+
h_off
),
w_off
:(
w_off
+
crop_width
)]
if
label
is
None
:
return
(
im
,
im_info
)
else
:
...
...
@@ -603,6 +617,7 @@ class RandomBlur:
Args:
prob (float): 图像模糊概率。默认为0.1。
"""
def
__init__
(
self
,
prob
=
0.1
):
self
.
prob
=
prob
...
...
@@ -650,6 +665,7 @@ class RandomRotation:
label_padding_value (int): 标注图像padding的值。默认为255。
"""
def
__init__
(
self
,
max_rotation
=
15
,
im_padding_value
=
[
127.5
,
127.5
,
127.5
],
...
...
@@ -686,18 +702,20 @@ class RandomRotation:
r
[
0
,
2
]
+=
(
nw
/
2
)
-
cx
r
[
1
,
2
]
+=
(
nh
/
2
)
-
cy
dsize
=
(
nw
,
nh
)
im
=
cv2
.
warpAffine
(
im
,
r
,
dsize
=
dsize
,
flags
=
cv2
.
INTER_LINEAR
,
borderMode
=
cv2
.
BORDER_CONSTANT
,
borderValue
=
self
.
im_padding_value
)
label
=
cv2
.
warpAffine
(
label
,
r
,
dsize
=
dsize
,
flags
=
cv2
.
INTER_NEAREST
,
borderMode
=
cv2
.
BORDER_CONSTANT
,
borderValue
=
self
.
label_padding_value
)
im
=
cv2
.
warpAffine
(
im
,
r
,
dsize
=
dsize
,
flags
=
cv2
.
INTER_LINEAR
,
borderMode
=
cv2
.
BORDER_CONSTANT
,
borderValue
=
self
.
im_padding_value
)
label
=
cv2
.
warpAffine
(
label
,
r
,
dsize
=
dsize
,
flags
=
cv2
.
INTER_NEAREST
,
borderMode
=
cv2
.
BORDER_CONSTANT
,
borderValue
=
self
.
label_padding_value
)
if
label
is
None
:
return
(
im
,
im_info
)
...
...
@@ -713,6 +731,7 @@ class RandomScaleAspect:
min_scale (float):裁取图像占原始图像的面积比,取值[0,1],为0时则返回原图。默认为0.5。
aspect_ratio (float): 裁取图像的宽高比范围,非负值,为0时返回原图。默认为0.33。
"""
def
__init__
(
self
,
min_scale
=
0.5
,
aspect_ratio
=
0.33
):
self
.
min_scale
=
min_scale
self
.
aspect_ratio
=
aspect_ratio
...
...
@@ -751,10 +770,12 @@ class RandomScaleAspect:
im
=
im
[
h1
:(
h1
+
dh
),
w1
:(
w1
+
dw
),
:]
label
=
label
[
h1
:(
h1
+
dh
),
w1
:(
w1
+
dw
)]
im
=
cv2
.
resize
(
im
,
(
img_width
,
img_height
),
interpolation
=
cv2
.
INTER_LINEAR
)
label
=
cv2
.
resize
(
label
,
(
img_width
,
img_height
),
interpolation
=
cv2
.
INTER_NEAREST
)
im
=
cv2
.
resize
(
im
,
(
img_width
,
img_height
),
interpolation
=
cv2
.
INTER_LINEAR
)
label
=
cv2
.
resize
(
label
,
(
img_width
,
img_height
),
interpolation
=
cv2
.
INTER_NEAREST
)
break
if
label
is
None
:
return
(
im
,
im_info
)
...
...
@@ -778,6 +799,7 @@ class RandomDistort:
hue_range (int): 色调因子的范围。默认为18。
hue_prob (float): 随机调整色调的概率。默认为0.5。
"""
def
__init__
(
self
,
brightness_range
=
0.5
,
brightness_prob
=
0.5
,
...
...
dygraph/val.py
浏览文件 @
70f3fab4
...
...
@@ -39,12 +39,8 @@ def parse_args():
parser
.
add_argument
(
'--model_name'
,
dest
=
'model_name'
,
help
=
'Model type for evaluation, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", '
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")'
,
help
=
'Model type for evaluation, which is one of {}'
.
format
(
str
(
list
(
MODELS
.
keys
()))),
type
=
str
,
default
=
'UNet'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录