Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
ae0b221d
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
285
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ae0b221d
编写于
7月 01, 2020
作者:
C
chenguowei01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add hrnet
上级
8949ec49
变更
5
展开全部
显示空白变更内容
内联
并排
Showing
5 changed file
with
1111 addition
and
14 deletion
+1111
-14
dygraph/infer.py
dygraph/infer.py
+12
-4
dygraph/models/__init__.py
dygraph/models/__init__.py
+25
-0
dygraph/models/hrnet.py
dygraph/models/hrnet.py
+1050
-0
dygraph/train.py
dygraph/train.py
+12
-6
dygraph/val.py
dygraph/val.py
+12
-4
未找到文件。
dygraph/infer.py
浏览文件 @
ae0b221d
...
...
@@ -24,7 +24,7 @@ import tqdm
from
datasets
import
OpticDiscSeg
,
Cityscapes
import
transforms
as
T
import
models
from
models
import
MODELS
import
utils
import
utils.logging
as
logging
from
utils
import
get_environ_info
...
...
@@ -37,7 +37,12 @@ def parse_args():
parser
.
add_argument
(
'--model_name'
,
dest
=
'model_name'
,
help
=
"Model type for traing, which is one of ('UNet')"
,
help
=
'Model type for testing, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", '
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")'
,
type
=
str
,
default
=
'UNet'
)
...
...
@@ -146,8 +151,11 @@ def main(args):
test_transforms
=
T
.
Compose
([
T
.
Resize
(
args
.
input_size
),
T
.
Normalize
()])
test_dataset
=
dataset
(
transforms
=
test_transforms
,
mode
=
'test'
)
if
args
.
model_name
==
'UNet'
:
model
=
models
.
UNet
(
num_classes
=
test_dataset
.
num_classes
)
if
args
.
model_name
not
in
MODELS
:
raise
Exception
(
'--model_name is invalid. it should be one of {}'
.
format
(
str
(
list
(
MODELS
.
keys
()))))
model
=
MODELS
[
args
.
model_name
](
num_classes
=
test_dataset
.
num_classes
)
infer
(
model
,
...
...
dygraph/models/__init__.py
浏览文件 @
ae0b221d
...
...
@@ -13,3 +13,28 @@
# limitations under the License.
from
.unet
import
UNet
from
.hrnet
import
*
MODELS
=
{
"UNet"
:
UNet
,
"HRNet_W18_Small_V1"
:
HRNet_W18_Small_V1
,
"HRNet_W18_Small_V2"
:
HRNet_W18_Small_V2
,
"HRNet_W18"
:
HRNet_W18
,
"HRNet_W30"
:
HRNet_W30
,
"HRNet_W32"
:
HRNet_W32
,
"HRNet_W40"
:
HRNet_W40
,
"HRNet_W44"
:
HRNet_W44
,
"HRNet_W48"
:
HRNet_W48
,
"HRNet_W60"
:
HRNet_W48
,
"HRNet_W64"
:
HRNet_W64
,
"SE_HRNet_W18_Small_V1"
:
SE_HRNet_W18_Small_V1
,
"SE_HRNet_W18_Small_V2"
:
SE_HRNet_W18_Small_V2
,
"SE_HRNet_W18"
:
SE_HRNet_W18
,
"SE_HRNet_W30"
:
SE_HRNet_W30
,
"SE_HRNet_W32"
:
SE_HRNet_W30
,
"SE_HRNet_W40"
:
SE_HRNet_W40
,
"SE_HRNet_W44"
:
SE_HRNet_W44
,
"SE_HRNet_W48"
:
SE_HRNet_W48
,
"SE_HRNet_W60"
:
SE_HRNet_W60
,
"SE_HRNet_W64"
:
SE_HRNet_W64
}
dygraph/models/hrnet.py
0 → 100644
浏览文件 @
ae0b221d
此差异已折叠。
点击以展开。
dygraph/train.py
浏览文件 @
ae0b221d
...
...
@@ -22,7 +22,7 @@ from paddle.incubate.hapi.distributed import DistributedBatchSampler
from
datasets
import
OpticDiscSeg
,
Cityscapes
import
transforms
as
T
import
models
from
models
import
MODELS
import
utils.logging
as
logging
from
utils
import
get_environ_info
from
utils
import
load_pretrained_model
...
...
@@ -38,7 +38,12 @@ def parse_args():
parser
.
add_argument
(
'--model_name'
,
dest
=
'model_name'
,
help
=
"Model type for traing, which is one of ('UNet')"
,
help
=
'Model type for training, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", '
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")'
,
type
=
str
,
default
=
'UNet'
)
...
...
@@ -181,7 +186,6 @@ def train(model,
total_steps
=
steps_per_epoch
*
(
num_epochs
-
start_epoch
)
num_steps
=
0
best_mean_iou
=
-
1.0
best_model_epoch
=
1
for
epoch
in
range
(
start_epoch
,
num_epochs
):
for
step
,
data
in
enumerate
(
loader
):
images
=
data
[
0
]
...
...
@@ -286,9 +290,11 @@ def main(args):
T
.
Normalize
()])
eval_dataset
=
dataset
(
transforms
=
eval_transforms
,
mode
=
'eval'
)
if
args
.
model_name
==
'UNet'
:
model
=
models
.
UNet
(
num_classes
=
train_dataset
.
num_classes
,
ignore_index
=
255
)
if
args
.
model_name
not
in
MODELS
:
raise
Exception
(
'--model_name is invalid. it should be one of {}'
.
format
(
str
(
list
(
MODELS
.
keys
()))))
model
=
MODELS
[
args
.
model_name
](
num_classes
=
train_dataset
.
num_classes
)
# Creat optimizer
# todo, may less one than len(loader)
...
...
dygraph/val.py
浏览文件 @
ae0b221d
...
...
@@ -25,7 +25,7 @@ from paddle.fluid.dataloader import BatchSampler
from
datasets
import
OpticDiscSeg
,
Cityscapes
import
transforms
as
T
import
models
from
models
import
MODELS
import
utils.logging
as
logging
from
utils
import
get_environ_info
from
utils
import
ConfusionMatrix
...
...
@@ -39,7 +39,12 @@ def parse_args():
parser
.
add_argument
(
'--model_name'
,
dest
=
'model_name'
,
help
=
"Model type for evaluation, which is one of ('UNet')"
,
help
=
'Model type for evaluation, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", '
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")'
,
type
=
str
,
default
=
'UNet'
)
...
...
@@ -153,8 +158,11 @@ def main(args):
eval_transforms
=
T
.
Compose
([
T
.
Resize
(
args
.
input_size
),
T
.
Normalize
()])
eval_dataset
=
dataset
(
transforms
=
eval_transforms
,
mode
=
'eval'
)
if
args
.
model_name
==
'UNet'
:
model
=
models
.
UNet
(
num_classes
=
eval_dataset
.
num_classes
)
if
args
.
model_name
not
in
MODELS
:
raise
Exception
(
'--model_name is invalid. it should be one of {}'
.
format
(
str
(
list
(
MODELS
.
keys
()))))
model
=
MODELS
[
args
.
model_name
](
num_classes
=
eval_dataset
.
num_classes
)
evaluate
(
model
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录