Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
3e0ecbcb
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
285
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3e0ecbcb
编写于
8月 26, 2020
作者:
W
wuzewu
提交者:
GitHub
8月 26, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #360 from wuyefeilin/dygraph
上级
96b1dfa1
f4ee7706
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
389 addition
and
352 deletion
+389
-352
dygraph/README.md
dygraph/README.md
+13
-5
dygraph/benchmark/deeplabv3p.py
dygraph/benchmark/deeplabv3p.py
+6
-9
dygraph/benchmark/hrnet.py
dygraph/benchmark/hrnet.py
+6
-8
dygraph/core/train.py
dygraph/core/train.py
+1
-4
dygraph/infer.py
dygraph/infer.py
+4
-7
dygraph/models/__init__.py
dygraph/models/__init__.py
+1
-33
dygraph/models/architectures/__init__.py
dygraph/models/architectures/__init__.py
+19
-0
dygraph/models/architectures/hrnet.py
dygraph/models/architectures/hrnet.py
+59
-271
dygraph/models/fcn.py
dygraph/models/fcn.py
+233
-0
dygraph/models/unet.py
dygraph/models/unet.py
+40
-4
dygraph/train.py
dygraph/train.py
+3
-4
dygraph/val.py
dygraph/val.py
+4
-7
未找到文件。
dygraph/README.md
浏览文件 @
3e0ecbcb
# 动态图执行
## 下载及添加路径
```
git clone https://github.com/PaddlePaddle/PaddleSeg
cd PaddleSeg
export PYTHONPATH=$PYTHONPATH:`pwd`
cd dygraph
```
## 训练
```
python3 train.py --model_name
UN
et \
python3 train.py --model_name
un
et \
--dataset OpticDiscSeg \
--input_size 192 192 \
--
num_epoch
s 10 \
--save_interval_
epoch
s 1 \
--
iter
s 10 \
--save_interval_
iter
s 1 \
--do_eval \
--save_dir output
```
## 评估
```
python3 val.py --model_name
UN
et \
python3 val.py --model_name
un
et \
--dataset OpticDiscSeg \
--input_size 192 192 \
--model_dir output/best_model
...
...
@@ -21,7 +29,7 @@ python3 val.py --model_name UNet \
## 预测
```
python3 infer.py --model_name
UN
et \
python3 infer.py --model_name
un
et \
--dataset OpticDiscSeg \
--model_dir output/best_model \
--input_size 192 192
...
...
dygraph/benchmark/deeplabv3p.py
浏览文件 @
3e0ecbcb
...
...
@@ -19,7 +19,8 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from
dygraph.datasets
import
DATASETS
import
dygraph.transforms
as
T
from
dygraph.models
import
MODELS
#from dygraph.models import MODELS
from
dygraph.cvlibs
import
manager
from
dygraph.utils
import
get_environ_info
from
dygraph.utils
import
logger
from
dygraph.core
import
train
...
...
@@ -33,7 +34,7 @@ def parse_args():
'--model_name'
,
dest
=
'model_name'
,
help
=
'Model type for training, which is one of {}'
.
format
(
str
(
list
(
MODELS
.
keys
()))),
str
(
list
(
manager
.
MODELS
.
components_dict
.
keys
()))),
type
=
str
,
default
=
'UNet'
)
...
...
@@ -161,18 +162,15 @@ def main(args):
eval_dataset
=
None
if
args
.
do_eval
:
eval_transforms
=
T
.
Compose
(
[
T
.
Resize
(
args
.
input_size
),
[
T
.
Padding
((
2049
,
1025
)
),
T
.
Normalize
()])
eval_dataset
=
dataset
(
dataset_root
=
args
.
dataset_root
,
transforms
=
eval_transforms
,
mode
=
'val'
)
if
args
.
model_name
not
in
MODELS
:
raise
Exception
(
'`--model_name` is invalid. it should be one of {}'
.
format
(
str
(
list
(
MODELS
.
keys
()))))
model
=
MODELS
[
args
.
model_name
](
num_classes
=
train_dataset
.
num_classes
)
model
=
manager
.
MODELS
[
args
.
model_name
](
num_classes
=
train_dataset
.
num_classes
)
# Creat optimizer
# todo, may less one than len(loader)
...
...
@@ -195,7 +193,6 @@ def main(args):
save_dir
=
args
.
save_dir
,
iters
=
args
.
iters
,
batch_size
=
args
.
batch_size
,
pretrained_model
=
args
.
pretrained_model
,
resume_model
=
args
.
resume_model
,
save_interval_iters
=
args
.
save_interval_iters
,
log_iters
=
args
.
log_iters
,
...
...
dygraph/benchmark/hrnet.py
浏览文件 @
3e0ecbcb
...
...
@@ -19,7 +19,8 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from
dygraph.datasets
import
DATASETS
import
dygraph.transforms
as
T
from
dygraph.models
import
MODELS
#from dygraph.models import MODELS
from
dygraph.cvlibs
import
manager
from
dygraph.utils
import
get_environ_info
from
dygraph.utils
import
logger
from
dygraph.core
import
train
...
...
@@ -33,7 +34,7 @@ def parse_args():
'--model_name'
,
dest
=
'model_name'
,
help
=
'Model type for training, which is one of {}'
.
format
(
str
(
list
(
MODELS
.
keys
()))),
str
(
list
(
manager
.
MODELS
.
components_dict
.
keys
()))),
type
=
str
,
default
=
'UNet'
)
...
...
@@ -166,11 +167,9 @@ def main(args):
transforms
=
eval_transforms
,
mode
=
'val'
)
if
args
.
model_name
not
in
MODELS
:
raise
Exception
(
'`--model_name` is invalid. it should be one of {}'
.
format
(
str
(
list
(
MODELS
.
keys
()))))
model
=
MODELS
[
args
.
model_name
](
num_classes
=
train_dataset
.
num_classes
)
model
=
manager
.
MODELS
[
args
.
model_name
](
num_classes
=
train_dataset
.
num_classes
,
pretrained_model
=
args
.
pretrained_model
)
# Creat optimizer
# todo, may less one than len(loader)
...
...
@@ -193,7 +192,6 @@ def main(args):
save_dir
=
args
.
save_dir
,
iters
=
args
.
iters
,
batch_size
=
args
.
batch_size
,
pretrained_model
=
args
.
pretrained_model
,
resume_model
=
args
.
resume_model
,
save_interval_iters
=
args
.
save_interval_iters
,
log_iters
=
args
.
log_iters
,
...
...
dygraph/core/train.py
浏览文件 @
3e0ecbcb
...
...
@@ -34,7 +34,6 @@ def train(model,
save_dir
=
'output'
,
iters
=
10000
,
batch_size
=
2
,
pretrained_model
=
None
,
resume_model
=
None
,
save_interval_iters
=
1000
,
log_iters
=
10
,
...
...
@@ -47,8 +46,6 @@ def train(model,
start_iter
=
0
if
resume_model
is
not
None
:
start_iter
=
resume
(
model
,
optimizer
,
resume_model
)
elif
pretrained_model
is
not
None
:
load_pretrained_model
(
model
,
pretrained_model
)
if
not
os
.
path
.
isdir
(
save_dir
):
if
os
.
path
.
exists
(
save_dir
):
...
...
@@ -126,7 +123,6 @@ def train(model,
log_writer
.
add_scalar
(
'Train/reader_cost'
,
avg_train_reader_cost
,
iter
)
avg_loss
=
0.0
timer
.
restart
()
if
(
iter
%
save_interval_iters
==
0
or
iter
==
iters
)
and
ParallelEnv
().
local_rank
==
0
:
...
...
@@ -162,5 +158,6 @@ def train(model,
log_writer
.
add_scalar
(
'Evaluate/mIoU'
,
mean_iou
,
iter
)
log_writer
.
add_scalar
(
'Evaluate/aAcc'
,
avg_acc
,
iter
)
model
.
train
()
timer
.
restart
()
if
use_vdl
:
log_writer
.
close
()
dygraph/infer.py
浏览文件 @
3e0ecbcb
...
...
@@ -19,7 +19,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from
dygraph.datasets
import
DATASETS
import
dygraph.transforms
as
T
from
dygraph.
models
import
MODELS
from
dygraph.
cvlibs
import
manager
from
dygraph.utils
import
get_environ_info
from
dygraph.core
import
infer
...
...
@@ -32,7 +32,7 @@ def parse_args():
'--model_name'
,
dest
=
'model_name'
,
help
=
'Model type for testing, which is one of {}'
.
format
(
str
(
list
(
MODELS
.
keys
()))),
str
(
list
(
manager
.
MODELS
.
components_dict
.
keys
()))),
type
=
str
,
default
=
'UNet'
)
...
...
@@ -99,11 +99,8 @@ def main(args):
transforms
=
test_transforms
,
mode
=
'test'
)
if
args
.
model_name
not
in
MODELS
:
raise
Exception
(
'`--model_name` is invalid. it should be one of {}'
.
format
(
str
(
list
(
MODELS
.
keys
()))))
model
=
MODELS
[
args
.
model_name
](
num_classes
=
test_dataset
.
num_classes
)
model
=
manager
.
MODELS
[
args
.
model_name
](
num_classes
=
test_dataset
.
num_classes
)
infer
(
model
,
...
...
dygraph/models/__init__.py
浏览文件 @
3e0ecbcb
...
...
@@ -14,37 +14,5 @@
from
.architectures
import
*
from
.unet
import
UNet
from
.hrnet
import
*
from
.deeplab
import
*
# MODELS = {
# "UNet": UNet,
# "HRNet_W18_Small_V1": HRNet_W18_Small_V1,
# "HRNet_W18_Small_V2": HRNet_W18_Small_V2,
# "HRNet_W18": HRNet_W18,
# "HRNet_W30": HRNet_W30,
# "HRNet_W32": HRNet_W32,
# "HRNet_W40": HRNet_W40,
# "HRNet_W44": HRNet_W44,
# "HRNet_W48": HRNet_W48,
# "HRNet_W60": HRNet_W48,
# "HRNet_W64": HRNet_W64,
# "SE_HRNet_W18_Small_V1": SE_HRNet_W18_Small_V1,
# "SE_HRNet_W18_Small_V2": SE_HRNet_W18_Small_V2,
# "SE_HRNet_W18": SE_HRNet_W18,
# "SE_HRNet_W30": SE_HRNet_W30,
# "SE_HRNet_W32": SE_HRNet_W30,
# "SE_HRNet_W40": SE_HRNet_W40,
# "SE_HRNet_W44": SE_HRNet_W44,
# "SE_HRNet_W48": SE_HRNet_W48,
# "SE_HRNet_W60": SE_HRNet_W60,
# "SE_HRNet_W64": SE_HRNet_W64,
# "DeepLabV3P": DeepLabV3P,
# "deeplabv3p_resnet101_vd": deeplabv3p_resnet101_vd,
# "deeplabv3p_resnet101_vd_os8": deeplabv3p_resnet101_vd_os8,
# "deeplabv3p_resnet50_vd": deeplabv3p_resnet50_vd,
# "deeplabv3p_resnet50_vd_os8": deeplabv3p_resnet50_vd_os8,
# "deeplabv3p_xception65_deeplab": deeplabv3p_xception65_deeplab,
# "deeplabv3p_mobilenetv3_large": deeplabv3p_mobilenetv3_large,
# "deeplabv3p_mobilenetv3_small": deeplabv3p_mobilenetv3_small
# }
from
.fcn
import
*
dygraph/models/architectures/__init__.py
0 → 100644
浏览文件 @
3e0ecbcb
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.
import
layer_utils
from
.hrnet
import
*
from
.resnet_vd
import
*
from
.xception_deeplab
import
*
from
.mobilenetv3
import
*
dygraph/models/hrnet.py
→
dygraph/models/
architectures/
hrnet.py
浏览文件 @
3e0ecbcb
...
...
@@ -20,20 +20,38 @@ from paddle.fluid.param_attr import ParamAttr
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
from
paddle.fluid.initializer
import
Normal
from
paddle.fluid.dygraph
import
SyncBatchNorm
as
BatchNorm
from
paddle.nn
import
SyncBatchNorm
as
BatchNorm
from
dygraph.cvlibs
import
manager
__all__
=
[
"HRNet_W18_Small_V1"
,
"HRNet_W18_Small_V2"
,
"HRNet_W18"
,
"HRNet_W30"
,
"HRNet_W32"
,
"HRNet_W40"
,
"HRNet_W44"
,
"HRNet_W48"
,
"HRNet_W60"
,
"HRNet_W64"
,
"SE_HRNet_W18_Small_V1"
,
"SE_HRNet_W18_Small_V2"
,
"SE_HRNet_W18"
,
"SE_HRNet_W30"
,
"SE_HRNet_W32"
,
"SE_HRNet_W40"
,
"SE_HRNet_W44"
,
"SE_HRNet_W48"
,
"SE_HRNet_W60"
,
"SE_HRNet_W64"
"HRNet_W32"
,
"HRNet_W40"
,
"HRNet_W44"
,
"HRNet_W48"
,
"HRNet_W60"
,
"HRNet_W64"
]
class
HRNet
(
fluid
.
dygraph
.
Layer
):
"""
HRNet:Deep High-Resolution Representation Learning for Visual Recognition
https://arxiv.org/pdf/1908.07919.pdf.
Args:
stage1_num_modules (int): number of modules for stage1. Default 1.
stage1_num_blocks (list): number of blocks per module for stage1. Default [4].
stage1_num_channels (list): number of channels per branch for stage1. Default [64].
stage2_num_modules (int): number of modules for stage2. Default 1.
stage2_num_blocks (list): number of blocks per module for stage2. Default [4, 4]
stage2_num_channels (list): number of channels per branch for stage2. Default [18, 36].
stage3_num_modules (int): number of modules for stage3. Default 4.
stage3_num_blocks (list): number of blocks per module for stage3. Default [4, 4, 4]
stage3_num_channels (list): number of channels per branch for stage3. Default [18, 36, 72].
stage4_num_modules (int): number of modules for stage4. Default 3.
stage4_num_blocks (list): number of blocks per module for stage4. Default [4, 4, 4, 4]
stage4_num_channels (list): number of channels per branch for stage4. Default [18, 36, 72. 144].
has_se (bool): whether to use Squeeze-and-Excitation module. Default False.
"""
def
__init__
(
self
,
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
4
],
stage1_num_channels
=
[
64
],
...
...
@@ -46,11 +64,9 @@ class HRNet(fluid.dygraph.Layer):
stage4_num_modules
=
3
,
stage4_num_blocks
=
[
4
,
4
,
4
,
4
],
stage4_num_channels
=
[
18
,
36
,
72
,
144
],
has_se
=
False
,
ignore_index
=
255
):
has_se
=
False
):
super
(
HRNet
,
self
).
__init__
()
self
.
num_classes
=
num_classes
self
.
stage1_num_modules
=
stage1_num_modules
self
.
stage1_num_blocks
=
stage1_num_blocks
self
.
stage1_num_channels
=
stage1_num_channels
...
...
@@ -64,8 +80,6 @@ class HRNet(fluid.dygraph.Layer):
self
.
stage4_num_blocks
=
stage4_num_blocks
self
.
stage4_num_channels
=
stage4_num_channels
self
.
has_se
=
has_se
self
.
ignore_index
=
ignore_index
self
.
EPS
=
1e-5
self
.
conv_layer1_1
=
ConvBNLayer
(
num_channels
=
3
,
...
...
@@ -112,6 +126,7 @@ class HRNet(fluid.dygraph.Layer):
num_modules
=
self
.
stage3_num_modules
,
num_blocks
=
self
.
stage3_num_blocks
,
num_filters
=
self
.
stage3_num_channels
,
has_se
=
self
.
has_se
,
name
=
"st3"
)
self
.
tr3
=
TransitionLayer
(
...
...
@@ -123,24 +138,9 @@ class HRNet(fluid.dygraph.Layer):
num_modules
=
self
.
stage4_num_modules
,
num_blocks
=
self
.
stage4_num_blocks
,
num_filters
=
self
.
stage4_num_channels
,
has_se
=
self
.
has_se
,
name
=
"st4"
)
last_inp_channels
=
sum
(
self
.
stage4_num_channels
)
self
.
conv_last_2
=
ConvBNLayer
(
num_channels
=
last_inp_channels
,
num_filters
=
last_inp_channels
,
filter_size
=
1
,
stride
=
1
,
name
=
'conv-2'
)
self
.
conv_last_1
=
Conv2D
(
num_channels
=
last_inp_channels
,
num_filters
=
self
.
num_classes
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
param_attr
=
ParamAttr
(
initializer
=
Normal
(
scale
=
0.001
),
name
=
'conv-1_weights'
))
def
forward
(
self
,
x
,
label
=
None
,
mode
=
'train'
):
input_shape
=
x
.
shape
[
2
:]
conv1
=
self
.
conv_layer1_1
(
x
)
...
...
@@ -162,40 +162,8 @@ class HRNet(fluid.dygraph.Layer):
x2
=
fluid
.
layers
.
resize_bilinear
(
st4
[
2
],
out_shape
=
(
x0_h
,
x0_w
))
x3
=
fluid
.
layers
.
resize_bilinear
(
st4
[
3
],
out_shape
=
(
x0_h
,
x0_w
))
x
=
fluid
.
layers
.
concat
([
st4
[
0
],
x1
,
x2
,
x3
],
axis
=
1
)
x
=
self
.
conv_last_2
(
x
)
logit
=
self
.
conv_last_1
(
x
)
logit
=
fluid
.
layers
.
resize_bilinear
(
logit
,
input_shape
)
if
self
.
training
:
if
label
is
None
:
raise
Exception
(
'Label is need during training'
)
return
self
.
_get_loss
(
logit
,
label
)
else
:
score_map
=
fluid
.
layers
.
softmax
(
logit
,
axis
=
1
)
score_map
=
fluid
.
layers
.
transpose
(
score_map
,
[
0
,
2
,
3
,
1
])
pred
=
fluid
.
layers
.
argmax
(
score_map
,
axis
=
3
)
pred
=
fluid
.
layers
.
unsqueeze
(
pred
,
axes
=
[
3
])
return
pred
,
score_map
def
_get_loss
(
self
,
logit
,
label
):
logit
=
fluid
.
layers
.
transpose
(
logit
,
[
0
,
2
,
3
,
1
])
label
=
fluid
.
layers
.
transpose
(
label
,
[
0
,
2
,
3
,
1
])
mask
=
label
!=
self
.
ignore_index
mask
=
fluid
.
layers
.
cast
(
mask
,
'float32'
)
loss
,
probs
=
fluid
.
layers
.
softmax_with_cross_entropy
(
logit
,
label
,
ignore_index
=
self
.
ignore_index
,
return_softmax
=
True
,
axis
=-
1
)
loss
=
loss
*
mask
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
/
(
fluid
.
layers
.
mean
(
mask
)
+
self
.
EPS
)
label
.
stop_gradient
=
True
mask
.
stop_gradient
=
True
return
avg_loss
return
x
class
ConvBNLayer
(
fluid
.
dygraph
.
Layer
):
...
...
@@ -698,189 +666,9 @@ class LastClsOut(fluid.dygraph.Layer):
return
outs
def
HRNet_W18_Small_V1
(
num_classes
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
1
],
stage1_num_channels
=
[
32
],
stage2_num_modules
=
1
,
stage2_num_blocks
=
[
2
,
2
],
stage2_num_channels
=
[
16
,
32
],
stage3_num_modules
=
1
,
stage3_num_blocks
=
[
2
,
2
,
2
],
stage3_num_channels
=
[
16
,
32
,
64
],
stage4_num_modules
=
1
,
stage4_num_blocks
=
[
2
,
2
,
2
,
2
],
stage4_num_channels
=
[
16
,
32
,
64
,
128
])
return
model
def
HRNet_W18_Small_V2
(
num_classes
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
2
],
stage1_num_channels
=
[
64
],
stage2_num_modules
=
1
,
stage2_num_blocks
=
[
2
,
2
],
stage2_num_channels
=
[
18
,
36
],
stage3_num_modules
=
1
,
stage3_num_blocks
=
[
2
,
2
,
2
],
stage3_num_channels
=
[
18
,
36
,
72
],
stage4_num_modules
=
1
,
stage4_num_blocks
=
[
2
,
2
,
2
,
2
],
stage4_num_channels
=
[
18
,
36
,
72
,
144
])
return
model
def
HRNet_W18
(
num_classes
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
4
],
stage1_num_channels
=
[
64
],
stage2_num_modules
=
1
,
stage2_num_blocks
=
[
4
,
4
],
stage2_num_channels
=
[
18
,
36
],
stage3_num_modules
=
4
,
stage3_num_blocks
=
[
4
,
4
,
4
],
stage3_num_channels
=
[
18
,
36
,
72
],
stage4_num_modules
=
3
,
stage4_num_blocks
=
[
4
,
4
,
4
,
4
],
stage4_num_channels
=
[
18
,
36
,
72
,
144
])
return
model
def
HRNet_W30
(
num_classes
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
4
],
stage1_num_channels
=
[
64
],
stage2_num_modules
=
1
,
stage2_num_blocks
=
[
4
,
4
],
stage2_num_channels
=
[
30
,
60
],
stage3_num_modules
=
4
,
stage3_num_blocks
=
[
4
,
4
,
4
],
stage3_num_channels
=
[
30
,
60
,
120
],
stage4_num_modules
=
3
,
stage4_num_blocks
=
[
4
,
4
,
4
,
4
],
stage4_num_channels
=
[
30
,
60
,
120
,
240
])
return
model
def
HRNet_W32
(
num_classes
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
4
],
stage1_num_channels
=
[
64
],
stage2_num_modules
=
1
,
stage2_num_blocks
=
[
4
,
4
],
stage2_num_channels
=
[
32
,
64
],
stage3_num_modules
=
4
,
stage3_num_blocks
=
[
4
,
4
,
4
],
stage3_num_channels
=
[
32
,
64
,
128
],
stage4_num_modules
=
3
,
stage4_num_blocks
=
[
4
,
4
,
4
,
4
],
stage4_num_channels
=
[
32
,
64
,
128
,
256
])
return
model
def
HRNet_W40
(
num_classes
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
4
],
stage1_num_channels
=
[
64
],
stage2_num_modules
=
1
,
stage2_num_blocks
=
[
4
,
4
],
stage2_num_channels
=
[
40
,
80
],
stage3_num_modules
=
4
,
stage3_num_blocks
=
[
4
,
4
,
4
],
stage3_num_channels
=
[
40
,
80
,
160
],
stage4_num_modules
=
3
,
stage4_num_blocks
=
[
4
,
4
,
4
,
4
],
stage4_num_channels
=
[
40
,
80
,
160
,
320
])
return
model
def
HRNet_W44
(
num_classes
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
4
],
stage1_num_channels
=
[
64
],
stage2_num_modules
=
1
,
stage2_num_blocks
=
[
4
,
4
],
stage2_num_channels
=
[
44
,
88
],
stage3_num_modules
=
4
,
stage3_num_blocks
=
[
4
,
4
,
4
],
stage3_num_channels
=
[
44
,
88
,
176
],
stage4_num_modules
=
3
,
stage4_num_blocks
=
[
4
,
4
,
4
,
4
],
stage4_num_channels
=
[
44
,
88
,
176
,
352
])
return
model
def
HRNet_W48
(
num_classes
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
4
],
stage1_num_channels
=
[
64
],
stage2_num_modules
=
1
,
stage2_num_blocks
=
[
4
,
4
],
stage2_num_channels
=
[
48
,
96
],
stage3_num_modules
=
4
,
stage3_num_blocks
=
[
4
,
4
,
4
],
stage3_num_channels
=
[
48
,
96
,
192
],
stage4_num_modules
=
3
,
stage4_num_blocks
=
[
4
,
4
,
4
,
4
],
stage4_num_channels
=
[
48
,
96
,
192
,
384
])
return
model
def
HRNet_W60
(
num_classes
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
4
],
stage1_num_channels
=
[
64
],
stage2_num_modules
=
1
,
stage2_num_blocks
=
[
4
,
4
],
stage2_num_channels
=
[
60
,
120
],
stage3_num_modules
=
4
,
stage3_num_blocks
=
[
4
,
4
,
4
],
stage3_num_channels
=
[
60
,
120
,
240
],
stage4_num_modules
=
3
,
stage4_num_blocks
=
[
4
,
4
,
4
,
4
],
stage4_num_channels
=
[
60
,
120
,
240
,
480
])
return
model
def
HRNet_W64
(
num_classes
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
4
],
stage1_num_channels
=
[
64
],
stage2_num_modules
=
1
,
stage2_num_blocks
=
[
4
,
4
],
stage2_num_channels
=
[
64
,
128
],
stage3_num_modules
=
4
,
stage3_num_blocks
=
[
4
,
4
,
4
],
stage3_num_channels
=
[
64
,
128
,
256
],
stage4_num_modules
=
3
,
stage4_num_blocks
=
[
4
,
4
,
4
,
4
],
stage4_num_channels
=
[
64
,
128
,
256
,
512
])
return
model
def
SE_HRNet_W18_Small_V1
(
num_classes
):
@
manager
.
BACKBONES
.
add_component
def
HRNet_W18_Small_V1
(
**
kwargs
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
1
],
stage1_num_channels
=
[
32
],
...
...
@@ -893,13 +681,13 @@ def SE_HRNet_W18_Small_V1(num_classes):
stage4_num_modules
=
1
,
stage4_num_blocks
=
[
2
,
2
,
2
,
2
],
stage4_num_channels
=
[
16
,
32
,
64
,
128
],
has_se
=
True
)
**
kwargs
)
return
model
def
SE_HRNet_W18_Small_V2
(
num_classes
):
@
manager
.
BACKBONES
.
add_component
def
HRNet_W18_Small_V2
(
**
kwargs
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
2
],
stage1_num_channels
=
[
64
],
...
...
@@ -912,13 +700,13 @@ def SE_HRNet_W18_Small_V2(num_classes):
stage4_num_modules
=
1
,
stage4_num_blocks
=
[
2
,
2
,
2
,
2
],
stage4_num_channels
=
[
18
,
36
,
72
,
144
],
has_se
=
True
)
**
kwargs
)
return
model
def
SE_HRNet_W18
(
num_classes
):
@
manager
.
BACKBONES
.
add_component
def
HRNet_W18
(
**
kwargs
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
4
],
stage1_num_channels
=
[
64
],
...
...
@@ -931,13 +719,13 @@ def SE_HRNet_W18(num_classes):
stage4_num_modules
=
3
,
stage4_num_blocks
=
[
4
,
4
,
4
,
4
],
stage4_num_channels
=
[
18
,
36
,
72
,
144
],
has_se
=
True
)
**
kwargs
)
return
model
def
SE_HRNet_W30
(
num_classes
):
@
manager
.
BACKBONES
.
add_component
def
HRNet_W30
(
**
kwargs
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
4
],
stage1_num_channels
=
[
64
],
...
...
@@ -950,13 +738,13 @@ def SE_HRNet_W30(num_classes):
stage4_num_modules
=
3
,
stage4_num_blocks
=
[
4
,
4
,
4
,
4
],
stage4_num_channels
=
[
30
,
60
,
120
,
240
],
has_se
=
True
)
**
kwargs
)
return
model
def
SE_HRNet_W32
(
num_classes
):
@
manager
.
BACKBONES
.
add_component
def
HRNet_W32
(
**
kwargs
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
4
],
stage1_num_channels
=
[
64
],
...
...
@@ -969,13 +757,13 @@ def SE_HRNet_W32(num_classes):
stage4_num_modules
=
3
,
stage4_num_blocks
=
[
4
,
4
,
4
,
4
],
stage4_num_channels
=
[
32
,
64
,
128
,
256
],
has_se
=
True
)
**
kwargs
)
return
model
def
SE_HRNet_W40
(
num_classes
):
@
manager
.
BACKBONES
.
add_component
def
HRNet_W40
(
**
kwargs
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
4
],
stage1_num_channels
=
[
64
],
...
...
@@ -988,13 +776,13 @@ def SE_HRNet_W40(num_classes):
stage4_num_modules
=
3
,
stage4_num_blocks
=
[
4
,
4
,
4
,
4
],
stage4_num_channels
=
[
40
,
80
,
160
,
320
],
has_se
=
True
)
**
kwargs
)
return
model
def
SE_HRNet_W44
(
num_classes
):
@
manager
.
BACKBONES
.
add_component
def
HRNet_W44
(
**
kwargs
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
4
],
stage1_num_channels
=
[
64
],
...
...
@@ -1007,13 +795,13 @@ def SE_HRNet_W44(num_classes):
stage4_num_modules
=
3
,
stage4_num_blocks
=
[
4
,
4
,
4
,
4
],
stage4_num_channels
=
[
44
,
88
,
176
,
352
],
has_se
=
True
)
**
kwargs
)
return
model
def
SE_HRNet_W48
(
num_classes
):
@
manager
.
BACKBONES
.
add_component
def
HRNet_W48
(
**
kwargs
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
4
],
stage1_num_channels
=
[
64
],
...
...
@@ -1026,13 +814,13 @@ def SE_HRNet_W48(num_classes):
stage4_num_modules
=
3
,
stage4_num_blocks
=
[
4
,
4
,
4
,
4
],
stage4_num_channels
=
[
48
,
96
,
192
,
384
],
has_se
=
True
)
**
kwargs
)
return
model
def
SE_HRNet_W60
(
num_classes
):
@
manager
.
BACKBONES
.
add_component
def
HRNet_W60
(
**
kwargs
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
4
],
stage1_num_channels
=
[
64
],
...
...
@@ -1045,13 +833,13 @@ def SE_HRNet_W60(num_classes):
stage4_num_modules
=
3
,
stage4_num_blocks
=
[
4
,
4
,
4
,
4
],
stage4_num_channels
=
[
60
,
120
,
240
,
480
],
has_se
=
True
)
**
kwargs
)
return
model
def
SE_HRNet_W64
(
num_classes
):
@
manager
.
BACKBONES
.
add_component
def
HRNet_W64
(
**
kwargs
):
model
=
HRNet
(
num_classes
=
num_classes
,
stage1_num_modules
=
1
,
stage1_num_blocks
=
[
4
],
stage1_num_channels
=
[
64
],
...
...
@@ -1064,5 +852,5 @@ def SE_HRNet_W64(num_classes):
stage4_num_modules
=
3
,
stage4_num_blocks
=
[
4
,
4
,
4
,
4
],
stage4_num_channels
=
[
64
,
128
,
256
,
512
],
has_se
=
True
)
**
kwargs
)
return
model
dygraph/models/fcn.py
0 → 100644
浏览文件 @
3e0ecbcb
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
os
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
from
paddle.fluid.initializer
import
Normal
from
paddle.nn
import
SyncBatchNorm
as
BatchNorm
from
dygraph.cvlibs
import
manager
from
dygraph
import
utils
__all__
=
[
"fcn_hrnet_w18_small_v1"
,
"fcn_hrnet_w18_small_v2"
,
"fcn_hrnet_w18"
,
"fcn_hrnet_w30"
,
"fcn_hrnet_w32"
,
"fcn_hrnet_w40"
,
"fcn_hrnet_w44"
,
"fcn_hrnet_w48"
,
"fcn_hrnet_w60"
,
"fcn_hrnet_w64"
]
class
FCN
(
fluid
.
dygraph
.
Layer
):
"""
Fully Convolutional Networks for Semantic Segmentation.
https://arxiv.org/abs/1411.4038
Args:
backbone (str): backbone name,
num_classes (int): the unique number of target classes.
in_channels (int): the channels of input feature maps.
channels (int): channels after conv layer before the last one.
pretrained_model (str): the path of pretrained model.
ignore_index (int): the value of ground-truth mask would be ignored while computing loss or doing evaluation. Default 255.
"""
def
__init__
(
self
,
backbone
,
num_classes
,
in_channels
,
channels
=
None
,
pretrained_model
=
None
,
ignore_index
=
255
,
**
kwargs
):
super
(
FCN
,
self
).
__init__
()
self
.
num_classes
=
num_classes
self
.
ignore_index
=
ignore_index
self
.
EPS
=
1e-5
if
channels
is
None
:
channels
=
in_channels
self
.
backbone
=
manager
.
BACKBONES
[
backbone
](
**
kwargs
)
self
.
conv_last_2
=
ConvBNLayer
(
num_channels
=
in_channels
,
num_filters
=
channels
,
filter_size
=
1
,
stride
=
1
,
name
=
'conv-2'
)
self
.
conv_last_1
=
Conv2D
(
num_channels
=
channels
,
num_filters
=
self
.
num_classes
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
param_attr
=
ParamAttr
(
initializer
=
Normal
(
scale
=
0.001
),
name
=
'conv-1_weights'
))
self
.
init_weight
(
pretrained_model
)
def
forward
(
self
,
x
,
label
=
None
,
mode
=
'train'
):
input_shape
=
x
.
shape
[
2
:]
x
=
self
.
backbone
(
x
)
x
=
self
.
conv_last_2
(
x
)
logit
=
self
.
conv_last_1
(
x
)
logit
=
fluid
.
layers
.
resize_bilinear
(
logit
,
input_shape
)
if
self
.
training
:
if
label
is
None
:
raise
Exception
(
'Label is need during training'
)
return
self
.
_get_loss
(
logit
,
label
)
else
:
score_map
=
fluid
.
layers
.
softmax
(
logit
,
axis
=
1
)
score_map
=
fluid
.
layers
.
transpose
(
score_map
,
[
0
,
2
,
3
,
1
])
pred
=
fluid
.
layers
.
argmax
(
score_map
,
axis
=
3
)
pred
=
fluid
.
layers
.
unsqueeze
(
pred
,
axes
=
[
3
])
return
pred
,
score_map
def
init_weight
(
self
,
pretrained_model
=
None
):
"""
Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None.
"""
if
pretrained_model
is
not
None
:
if
os
.
path
.
exists
(
pretrained_model
):
utils
.
load_pretrained_model
(
self
.
backbone
,
pretrained_model
)
utils
.
load_pretrained_model
(
self
,
pretrained_model
)
else
:
raise
Exception
(
'Pretrained model is not found: {}'
.
format
(
pretrained_model
))
def
_get_loss
(
self
,
logit
,
label
):
"""
compute forward loss of the model
Args:
logit (tensor): the logit of model output
label (tensor): ground truth
Returns:
avg_loss (tensor): forward loss
"""
logit
=
fluid
.
layers
.
transpose
(
logit
,
[
0
,
2
,
3
,
1
])
label
=
fluid
.
layers
.
transpose
(
label
,
[
0
,
2
,
3
,
1
])
mask
=
label
!=
self
.
ignore_index
mask
=
fluid
.
layers
.
cast
(
mask
,
'float32'
)
loss
,
probs
=
fluid
.
layers
.
softmax_with_cross_entropy
(
logit
,
label
,
ignore_index
=
self
.
ignore_index
,
return_softmax
=
True
,
axis
=-
1
)
loss
=
loss
*
mask
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
/
(
fluid
.
layers
.
mean
(
mask
)
+
self
.
EPS
)
label
.
stop_gradient
=
True
mask
.
stop_gradient
=
True
return
avg_loss
class
ConvBNLayer
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
num_channels
,
num_filters
,
filter_size
,
stride
=
1
,
groups
=
1
,
act
=
"relu"
,
name
=
None
):
super
(
ConvBNLayer
,
self
).
__init__
()
self
.
_conv
=
Conv2D
(
num_channels
=
num_channels
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
(
filter_size
-
1
)
//
2
,
groups
=
groups
,
param_attr
=
ParamAttr
(
initializer
=
Normal
(
scale
=
0.001
),
name
=
name
+
"_weights"
),
bias_attr
=
False
)
bn_name
=
name
+
'_bn'
self
.
_batch_norm
=
BatchNorm
(
num_filters
,
weight_attr
=
ParamAttr
(
name
=
bn_name
+
'_scale'
,
initializer
=
fluid
.
initializer
.
Constant
(
1.0
)),
bias_attr
=
ParamAttr
(
bn_name
+
'_offset'
,
initializer
=
fluid
.
initializer
.
Constant
(
0.0
)))
self
.
act
=
act
def
forward
(
self
,
input
):
y
=
self
.
_conv
(
input
)
y
=
self
.
_batch_norm
(
y
)
if
self
.
act
==
'relu'
:
y
=
fluid
.
layers
.
relu
(
y
)
return
y
@
manager
.
MODELS
.
add_component
def
fcn_hrnet_w18_small_v1
(
*
args
,
**
kwargs
):
return
FCN
(
backbone
=
'HRNet_W18_Small_V1'
,
in_channels
=
240
,
**
kwargs
)
@
manager
.
MODELS
.
add_component
def
fcn_hrnet_w18_small_v2
(
*
args
,
**
kwargs
):
return
FCN
(
backbone
=
'HRNet_W18_Small_V2'
,
in_channels
=
270
,
**
kwargs
)
@
manager
.
MODELS
.
add_component
def
fcn_hrnet_w18
(
*
args
,
**
kwargs
):
return
FCN
(
backbone
=
'HRNet_W18'
,
in_channels
=
270
,
**
kwargs
)
@
manager
.
MODELS
.
add_component
def
fcn_hrnet_w30
(
*
args
,
**
kwargs
):
return
FCN
(
backbone
=
'HRNet_W30'
,
in_channels
=
450
,
**
kwargs
)
@
manager
.
MODELS
.
add_component
def
fcn_hrnet_w32
(
*
args
,
**
kwargs
):
return
FCN
(
backbone
=
'HRNet_W32'
,
in_channels
=
480
,
**
kwargs
)
@
manager
.
MODELS
.
add_component
def
fcn_hrnet_w40
(
*
args
,
**
kwargs
):
return
FCN
(
backbone
=
'HRNet_W40'
,
in_channels
=
600
,
**
kwargs
)
@
manager
.
MODELS
.
add_component
def
fcn_hrnet_w44
(
*
args
,
**
kwargs
):
return
FCN
(
backbone
=
'HRNet_W44'
,
in_channels
=
660
,
**
kwargs
)
@
manager
.
MODELS
.
add_component
def
fcn_hrnet_w48
(
*
args
,
**
kwargs
):
return
FCN
(
backbone
=
'HRNet_W48'
,
in_channels
=
720
,
**
kwargs
)
@
manager
.
MODELS
.
add_component
def
fcn_hrnet_w60
(
*
args
,
**
kwargs
):
return
FCN
(
backbone
=
'HRNet_W60'
,
in_channels
=
900
,
**
kwargs
)
@
manager
.
MODELS
.
add_component
def
fcn_hrnet_w64
(
*
args
,
**
kwargs
):
return
FCN
(
backbone
=
'HRNet_W64'
,
in_channels
=
960
,
**
kwargs
)
dygraph/models/unet.py
浏览文件 @
3e0ecbcb
...
...
@@ -12,13 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph
import
Conv2D
,
Pool2D
from
paddle.fluid.dygraph
import
SyncBatchNorm
as
BatchNorm
from
paddle.nn
import
SyncBatchNorm
as
BatchNorm
from
dygraph.cvlibs
import
manager
from
dygraph
import
utils
class
UNet
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
num_classes
,
ignore_index
=
255
):
"""
U-Net: Convolutional Networks for Biomedical Image Segmentation.
https://arxiv.org/abs/1505.04597
Args:
num_classes (int): the unique number of target classes.
pretrained_model (str): the path of pretrained model.
ignore_index (int): the value of ground-truth mask would be ignored while computing loss or doing evaluation. Default 255.
"""
def
__init__
(
self
,
num_classes
,
pretrained_model
=
None
,
ignore_index
=
255
):
super
(
UNet
,
self
).
__init__
()
self
.
encode
=
UnetEncoder
()
self
.
decode
=
UnetDecode
()
...
...
@@ -26,6 +41,8 @@ class UNet(fluid.dygraph.Layer):
self
.
ignore_index
=
ignore_index
self
.
EPS
=
1e-5
self
.
init_weight
(
pretrained_model
)
def
forward
(
self
,
x
,
label
=
None
):
encode_data
,
short_cuts
=
self
.
encode
(
x
)
decode_data
=
self
.
decode
(
encode_data
,
short_cuts
)
...
...
@@ -39,6 +56,20 @@ class UNet(fluid.dygraph.Layer):
pred
=
fluid
.
layers
.
unsqueeze
(
pred
,
axes
=
[
3
])
return
pred
,
score_map
def
init_weight
(
self
,
pretrained_model
=
None
):
"""
Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None.
"""
if
pretrained_model
is
not
None
:
if
os
.
path
.
exists
(
pretrained_model
):
utils
.
load_pretrained_model
(
self
.
backbone
,
pretrained_model
)
utils
.
load_pretrained_model
(
self
,
pretrained_model
)
else
:
raise
Exception
(
'Pretrained model is not found: {}'
.
format
(
pretrained_model
))
def
_get_loss
(
self
,
logit
,
label
):
logit
=
fluid
.
layers
.
transpose
(
logit
,
[
0
,
2
,
3
,
1
])
label
=
fluid
.
layers
.
transpose
(
label
,
[
0
,
2
,
3
,
1
])
...
...
@@ -108,14 +139,14 @@ class DoubleConv(fluid.dygraph.Layer):
filter_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
bn0
=
BatchNorm
(
num_
channels
=
num_
filters
)
self
.
bn0
=
BatchNorm
(
num_filters
)
self
.
conv1
=
Conv2D
(
num_channels
=
num_filters
,
num_filters
=
num_filters
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
bn1
=
BatchNorm
(
num_
channels
=
num_
filters
)
self
.
bn1
=
BatchNorm
(
num_filters
)
def
forward
(
self
,
x
):
x
=
self
.
conv0
(
x
)
...
...
@@ -166,3 +197,8 @@ class GetLogit(fluid.dygraph.Layer):
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
return
x
@
manager
.
MODELS
.
add_component
def
unet
(
*
args
,
**
kwargs
):
return
UNet
(
*
args
,
**
kwargs
)
dygraph/train.py
浏览文件 @
3e0ecbcb
...
...
@@ -19,7 +19,6 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from
dygraph.datasets
import
DATASETS
import
dygraph.transforms
as
T
#from dygraph.models import MODELS
from
dygraph.cvlibs
import
manager
from
dygraph.utils
import
get_environ_info
from
dygraph.utils
import
logger
...
...
@@ -167,8 +166,9 @@ def main(args):
transforms
=
eval_transforms
,
mode
=
'val'
)
model
=
manager
.
MODELS
[
args
.
model_name
](
num_classes
=
train_dataset
.
num_classes
)
model
=
manager
.
MODELS
[
args
.
model_name
](
num_classes
=
train_dataset
.
num_classes
,
pretrained_model
=
args
.
pretrained_model
)
# Creat optimizer
# todo, may less one than len(loader)
...
...
@@ -191,7 +191,6 @@ def main(args):
save_dir
=
args
.
save_dir
,
iters
=
args
.
iters
,
batch_size
=
args
.
batch_size
,
pretrained_model
=
args
.
pretrained_model
,
resume_model
=
args
.
resume_model
,
save_interval_iters
=
args
.
save_interval_iters
,
log_iters
=
args
.
log_iters
,
...
...
dygraph/val.py
浏览文件 @
3e0ecbcb
...
...
@@ -19,7 +19,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from
dygraph.datasets
import
DATASETS
import
dygraph.transforms
as
T
from
dygraph.
models
import
MODELS
from
dygraph.
cvlibs
import
manager
from
dygraph.utils
import
get_environ_info
from
dygraph.core
import
evaluate
...
...
@@ -32,7 +32,7 @@ def parse_args():
'--model_name'
,
dest
=
'model_name'
,
help
=
'Model type for evaluation, which is one of {}'
.
format
(
str
(
list
(
MODELS
.
keys
()))),
str
(
list
(
manager
.
MODELS
.
components_dict
.
keys
()))),
type
=
str
,
default
=
'UNet'
)
...
...
@@ -87,11 +87,8 @@ def main(args):
transforms
=
eval_transforms
,
mode
=
'val'
)
if
args
.
model_name
not
in
MODELS
:
raise
Exception
(
'`--model_name` is invalid. it should be one of {}'
.
format
(
str
(
list
(
MODELS
.
keys
()))))
model
=
MODELS
[
args
.
model_name
](
num_classes
=
eval_dataset
.
num_classes
)
model
=
manager
.
MODELS
[
args
.
model_name
](
num_classes
=
eval_dataset
.
num_classes
)
evaluate
(
model
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录