Commit d3e3f733
Authored Aug 28, 2020 by wuzewu
Add config
Parent: b9c9ed27
Showing 15 changed files with 539 additions and 155 deletions (+539 -155).
dygraph/__init__.py                                                    +3    -1
dygraph/configs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_40k.yml    +43    -0
dygraph/cvlibs/manager.py                                             +14    -8
dygraph/datasets/ade.py                                                +4    -1
dygraph/datasets/cityscapes.py                                         +4    -1
dygraph/datasets/dataset.py                                            +5    -1
dygraph/datasets/optic_disc_seg.py                                     +4    -1
dygraph/datasets/voc.py                                                +5    -1
dygraph/models/__init__.py                                             +1    -0
dygraph/models/ocrnet.py                                             +196    -0
dygraph/train.py                                                      +19   -92
dygraph/transforms/transforms.py                                      +15    -0
dygraph/utils/__init__.py                                              +1    -0
dygraph/utils/config.py                                              +210    -0
dygraph/val.py                                                        +15   -49
dygraph/__init__.py
@@ -12,4 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import dygraph.models
\ No newline at end of file
+from . import models
+from . import datasets
+from . import transforms
dygraph/configs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_40k.yml (new file, mode 100644)

batch_size: 2
iters: 40000

train_dataset:
  type: Cityscapes
  dataset_root: datasets/cityscapes
  transforms:
    - type: RandomHorizontalFlip
    - type: ResizeStepScaling
      min_scale_factor: 0.5
      max_scale_factor: 2.0
      scale_step_size: 0.25
    - type: RandomPaddingCrop
      crop_size: [1024, 512]
    - type: Normalize
  mode: train

val_dataset:
  type: Cityscapes
  dataset_root: datasets/cityscapes
  transforms:
    - type: Normalize
  mode: val

model:
  type: ocrnet
  backbone:
    type: HRNet_W18
    pretrained: dygraph/pretrained_model/hrnet_w18_ssld/model
  num_classes: 19
  in_channels: 270

optimizer:
  type: sgd

learning_rate:
  value: 0.01
  decay:
    type: poly
    power: 0.9

loss:
  type: CrossEntropy
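This file is consumed through the new --config flag added to dygraph/train.py later in this commit. As a rough, hedged sketch of the intended flow (assuming the repository root is on PYTHONPATH, the YAML exists at the path given, and the Cityscapes data and pretrained weights are present locally), the Config class introduced in dygraph/utils/config.py exposes the top-level keys as properties:

import dygraph                      # importing the package runs the @add_component registrations
from dygraph.utils import Config

cfg = Config('dygraph/configs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_40k.yml')
print(cfg.batch_size, cfg.iters)    # 2 40000
train_dataset = cfg.train_dataset   # Cityscapes instance built from the 'train_dataset' block
model = cfg.model                   # model built from the 'model' block (backbone nested inside)
optimizer = cfg.optimizer           # Momentum optimizer with the polynomial-decay learning rate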
dygraph/cvlibs/manager.py
@@ -44,19 +44,20 @@ class ComponentManager:
     def __init__(self):
         self._components_dict = dict()

     def __len__(self):
         return len(self._components_dict)

     def __repr__(self):
         return "{}:{}".format(self.__class__.__name__,
                               list(self._components_dict.keys()))

     def __getitem__(self, item):
         if item not in self._components_dict.keys():
             raise KeyError("{} does not exist in the current {}".format(
                 item, self))
         return self._components_dict[item]

     @property
     def components_dict(self):
         return self._components_dict
@@ -74,7 +75,9 @@ class ComponentManager:
         # Currently only support class or function type
         if not (inspect.isclass(component) or inspect.isfunction(component)):
             raise TypeError(
                 "Expect class/function type, but received {}".format(
                     type(component)))

         # Obtain the internal name of the component
         component_name = component.__name__
@@ -92,7 +95,7 @@ class ComponentManager:
         Args:
             components (function | class | list | tuple): support three types of components

         Returns:
             None
         """
@@ -104,8 +107,11 @@ class ComponentManager:
         else:
             component = components
             self._add_single_component(component)

         return components


 MODELS = ComponentManager()
-BACKBONES = ComponentManager()
\ No newline at end of file
+BACKBONES = ComponentManager()
+DATASETS = ComponentManager()
+TRANSFORMS = ComponentManager()
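For orientation, here is a minimal, self-contained sketch of the registry pattern that ComponentManager implements (illustrative only; the real implementation lives in dygraph/cvlibs/manager.py and also accepts lists/tuples of components):

import inspect

class Registry:
    """Toy stand-in for ComponentManager: stores classes/functions keyed by __name__."""

    def __init__(self):
        self._components_dict = dict()

    def __getitem__(self, item):
        if item not in self._components_dict:
            raise KeyError("{} does not exist in the current {}".format(item, self))
        return self._components_dict[item]

    def add_component(self, component):
        # Only classes and functions may be registered.
        if not (inspect.isclass(component) or inspect.isfunction(component)):
            raise TypeError(
                "Expect class/function type, but received {}".format(type(component)))
        self._components_dict[component.__name__] = component
        return component

MODELS = Registry()

@MODELS.add_component
class DummyModel:
    pass

assert MODELS['DummyModel'] is DummyModel  # components are looked up by name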
dygraph/datasets/ade.py
@@ -19,11 +19,14 @@ from PIL import Image
 from .dataset import Dataset
 from dygraph.utils.download import download_file_and_uncompress
+from dygraph.cvlibs import manager
+from dygraph.transforms import Compose

 DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
 URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip"


+@manager.DATASETS.add_component
 class ADE20K(Dataset):
     """ADE20K dataset `http://sceneparsing.csail.mit.edu/`.

     Args:
@@ -39,7 +42,7 @@ class ADE20K(Dataset):
                  transforms=None,
                  download=True):
         self.dataset_root = dataset_root
-        self.transforms = transforms
+        self.transforms = Compose(transforms)
         self.mode = mode
         self.file_list = list()
         self.num_classes = 150
dygraph/datasets/cityscapes.py
@@ -16,8 +16,11 @@ import os
 import glob

 from .dataset import Dataset
+from dygraph.cvlibs import manager
+from dygraph.transforms import Compose


+@manager.DATASETS.add_component
 class Cityscapes(Dataset):
     """Cityscapes dataset `https://www.cityscapes-dataset.com/`.

     The folder structure is as follow:
@@ -42,7 +45,7 @@ class Cityscapes(Dataset):
     def __init__(self, dataset_root, transforms=None, mode='train'):
         self.dataset_root = dataset_root
-        self.transforms = transforms
+        self.transforms = Compose(transforms)
         self.file_list = list()
         self.mode = mode
         self.num_classes = 19
dygraph/datasets/dataset.py
@@ -17,8 +17,12 @@ import os
 import paddle.fluid as fluid
 import numpy as np
 from PIL import Image
+
+from dygraph.cvlibs import manager
+from dygraph.transforms import Compose


+@manager.DATASETS.add_component
 class Dataset(fluid.io.Dataset):
     """Pass in a custom dataset that conforms to the format.
@@ -52,7 +56,7 @@ class Dataset(fluid.io.Dataset):
                  separator=' ',
                  transforms=None):
         self.dataset_root = dataset_root
-        self.transforms = transforms
+        self.transforms = Compose(transforms)
         self.file_list = list()
         self.mode = mode
         self.num_classes = num_classes
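With the datasets now wrapping their transforms argument in Compose themselves, callers pass a plain list of transform objects rather than a pre-built pipeline. A hedged usage sketch (the dataset_root value is a placeholder and must point at a prepared Cityscapes directory):

import dygraph.transforms as T
from dygraph.datasets.cityscapes import Cityscapes

# The dataset wraps the list in Compose internally (see the diffs above).
train_dataset = Cityscapes(
    dataset_root='datasets/cityscapes',   # placeholder path
    transforms=[T.RandomHorizontalFlip(), T.Normalize()],
    mode='train')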
dygraph/datasets/optic_disc_seg.py
@@ -16,11 +16,14 @@ import os
 from .dataset import Dataset
 from dygraph.utils.download import download_file_and_uncompress
+from dygraph.cvlibs import manager
+from dygraph.transforms import Compose

 DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
 URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"


+@manager.DATASETS.add_component
 class OpticDiscSeg(Dataset):
     def __init__(self,
                  dataset_root=None,
@@ -28,7 +31,7 @@ class OpticDiscSeg(Dataset):
                  mode='train',
                  download=True):
         self.dataset_root = dataset_root
-        self.transforms = transforms
+        self.transforms = Compose(transforms)
         self.file_list = list()
         self.mode = mode
         self.num_classes = 2
dygraph/datasets/voc.py
@@ -13,13 +13,17 @@
 # limitations under the License.

 import os

 from .dataset import Dataset
 from dygraph.utils.download import download_file_and_uncompress
+
+from dygraph.cvlibs import manager
+from dygraph.transforms import Compose

 DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
 URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"


+@manager.DATASETS.add_component
 class PascalVOC(Dataset):
     """Pascal VOC dataset `http://host.robots.ox.ac.uk/pascal/VOC/`. If you want to augment the dataset,
     please run the voc_augment.py in tools.
@@ -36,7 +40,7 @@ class PascalVOC(Dataset):
                  transforms=None,
                  download=True):
         self.dataset_root = dataset_root
-        self.transforms = transforms
+        self.transforms = Compose(transforms)
         self.mode = mode
         self.file_list = list()
         self.num_classes = 21
dygraph/models/__init__.py
@@ -17,3 +17,4 @@ from .unet import UNet
 from .deeplab import *
 from .fcn import *
 from .pspnet import *
+from .ocrnet import *
dygraph/models/ocrnet.py (new file, mode 100644)

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.fluid as fluid
from paddle.fluid.dygraph import Sequential, Conv2D

from dygraph.cvlibs import manager
from dygraph.models.architectures.layer_utils import ConvBnRelu


class SpatialGatherBlock(fluid.dygraph.Layer):
    def forward(self, pixels, regions):
        n, c, h, w = pixels.shape
        _, k, _, _ = regions.shape

        # pixels: from (n, c, h, w) to (n, h*w, c)
        pixels = fluid.layers.reshape(pixels, (n, c, h * w))
        pixels = fluid.layers.transpose(pixels, (0, 2, 1))

        # regions: from (n, k, h, w) to (n, k, h*w)
        regions = fluid.layers.reshape(regions, (n, k, h * w))
        regions = fluid.layers.softmax(regions, axis=2)

        # feats: from (n, k, c) to (n, c, k, 1)
        feats = fluid.layers.matmul(regions, pixels)
        feats = fluid.layers.transpose(feats, (0, 2, 1))
        feats = fluid.layers.unsqueeze(feats, axes=[-1])

        return feats


class SpatialOCRModule(fluid.dygraph.Layer):
    def __init__(self, in_channels, key_channels, out_channels,
                 dropout_rate=0.1):
        super(SpatialOCRModule, self).__init__()

        self.attention_block = ObjectAttentionBlock(in_channels, key_channels)
        self.dropout_rate = dropout_rate
        self.conv1x1 = Conv2D(2 * in_channels, out_channels, 1)

    def forward(self, pixels, regions):
        context = self.attention_block(pixels, regions)
        feats = fluid.layers.concat([context, pixels], axis=1)

        feats = self.conv1x1(feats)
        feats = fluid.layers.dropout(feats, self.dropout_rate)

        return feats


class ObjectAttentionBlock(fluid.dygraph.Layer):
    def __init__(self, in_channels, key_channels):
        super(ObjectAttentionBlock, self).__init__()

        self.in_channels = in_channels
        self.key_channels = key_channels

        self.f_pixel = Sequential(
            ConvBnRelu(in_channels, key_channels, 1),
            ConvBnRelu(key_channels, key_channels, 1))

        self.f_object = Sequential(
            ConvBnRelu(in_channels, key_channels, 1),
            ConvBnRelu(key_channels, key_channels, 1))

        self.f_down = ConvBnRelu(in_channels, key_channels, 1)

        self.f_up = ConvBnRelu(key_channels, in_channels, 1)

    def forward(self, x, proxy):
        n, _, h, w = x.shape

        # query : from (n, c1, h1, w1) to (n, h1*w1, key_channels)
        query = self.f_pixel(x)
        query = fluid.layers.reshape(query, (n, self.key_channels, -1))
        query = fluid.layers.transpose(query, (0, 2, 1))

        # key : from (n, c2, h2, w2) to (n, key_channels, h2*w2)
        key = self.f_object(proxy)
        key = fluid.layers.reshape(key, (n, self.key_channels, -1))

        # value : from (n, c2, h2, w2) to (n, h2*w2, key_channels)
        value = self.f_down(proxy)
        value = fluid.layers.reshape(value, (n, self.key_channels, -1))
        value = fluid.layers.transpose(value, (0, 2, 1))

        # sim_map (n, h1*w1, h2*w2)
        sim_map = fluid.layers.matmul(query, key)
        sim_map = (self.key_channels**-.5) * sim_map
        sim_map = fluid.layers.softmax(sim_map, axis=-1)

        # context from (n, h1*w1, key_channels) to (n , out_channels, h1, w1)
        context = fluid.layers.matmul(sim_map, value)
        context = fluid.layers.transpose(context, (0, 2, 1))
        context = fluid.layers.reshape(context, (n, self.key_channels, h, w))
        context = self.f_up(context)

        return context


@manager.MODELS.add_component
class OCRNet(fluid.dygraph.Layer):
    def __init__(self,
                 num_classes,
                 in_channels,
                 backbone,
                 ocr_mid_channels=512,
                 ocr_key_channels=256,
                 ignore_index=255):
        super(OCRNet, self).__init__()

        self.ignore_index = ignore_index
        self.num_classes = num_classes
        self.EPS = 1e-5

        self.backbone = backbone
        self.spatial_gather = SpatialGatherBlock()
        self.spatial_ocr = SpatialOCRModule(ocr_mid_channels, ocr_key_channels,
                                            ocr_mid_channels)
        self.conv3x3_ocr = ConvBnRelu(
            in_channels, ocr_mid_channels, 3, padding=1)
        self.cls_head = Conv2D(ocr_mid_channels, self.num_classes, 1)

        self.aux_head = Sequential(
            ConvBnRelu(in_channels, in_channels, 3, padding=1),
            Conv2D(in_channels, self.num_classes, 1))

    def forward(self, x, label=None):
        feats = self.backbone(x)

        soft_regions = self.aux_head(feats)
        pixels = self.conv3x3_ocr(feats)

        object_regions = self.spatial_gather(pixels, soft_regions)
        ocr = self.spatial_ocr(pixels, object_regions)

        logit = self.cls_head(ocr)
        logit = fluid.layers.resize_bilinear(logit, x.shape[2:])

        if self.training:
            soft_regions = fluid.layers.resize_bilinear(soft_regions,
                                                        x.shape[2:])
            cls_loss = self._get_loss(logit, label)
            aux_loss = self._get_loss(soft_regions, label)
            return cls_loss + 0.4 * aux_loss

        score_map = fluid.layers.softmax(logit, axis=1)
        score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1])
        pred = fluid.layers.argmax(score_map, axis=3)
        pred = fluid.layers.unsqueeze(pred, axes=[3])
        return pred, score_map

    def _get_loss(self, logit, label):
        """
        compute forward loss of the model

        Args:
            logit (tensor): the logit of model output
            label (tensor): ground truth

        Returns:
            avg_loss (tensor): forward loss
        """
        logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
        label = fluid.layers.transpose(label, [0, 2, 3, 1])
        mask = label != self.ignore_index
        mask = fluid.layers.cast(mask, 'float32')
        loss, probs = fluid.layers.softmax_with_cross_entropy(
            logit,
            label,
            ignore_index=self.ignore_index,
            return_softmax=True,
            axis=-1)

        loss = loss * mask
        avg_loss = fluid.layers.mean(loss) / (
            fluid.layers.mean(mask) + self.EPS)

        label.stop_gradient = True
        mask.stop_gradient = True
        return avg_loss
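The SpatialGatherBlock above pools pixel features into K object-region descriptors, weighting each pixel by the soft region map produced by the auxiliary head. A NumPy-only shape walkthrough of that computation (illustrative, not Paddle code; the sizes are arbitrary):

import numpy as np

n, c, h, w, k = 2, 512, 16, 32, 19                      # batch, channels, height, width, regions
pixels = np.random.rand(n, c, h, w).astype('float32')
regions = np.random.rand(n, k, h, w).astype('float32')

p = pixels.reshape(n, c, h * w).transpose(0, 2, 1)      # (n, h*w, c)
r = regions.reshape(n, k, h * w)                        # (n, k, h*w)
r = np.exp(r) / np.exp(r).sum(axis=2, keepdims=True)    # softmax over spatial positions

feats = r @ p                                           # (n, k, c): one descriptor per region
feats = feats.transpose(0, 2, 1)[..., None]             # (n, c, k, 1), as in the Paddle code
assert feats.shape == (n, c, k, 1)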
dygraph/train.py
@@ -17,78 +17,36 @@ import argparse
 import paddle.fluid as fluid
 from paddle.fluid.dygraph.parallel import ParallelEnv

-from dygraph.datasets import DATASETS
-import dygraph.transforms as T
+import dygraph
 from dygraph.cvlibs import manager
 from dygraph.utils import get_environ_info
 from dygraph.utils import logger
+from dygraph.utils import Config
 from dygraph.core import train


 def parse_args():
     parser = argparse.ArgumentParser(description='Model training')

-    # params of model
-    parser.add_argument(
-        '--model_name',
-        dest='model_name',
-        help='Model type for training, which is one of {}'.format(
-            str(list(manager.MODELS.components_dict.keys()))),
-        type=str,
-        default='UNet')
-
-    # params of dataset
-    parser.add_argument(
-        '--dataset',
-        dest='dataset',
-        help="The dataset you want to train, which is one of {}".format(
-            str(list(DATASETS.keys()))),
-        type=str,
-        default='OpticDiscSeg')
-    parser.add_argument(
-        '--dataset_root',
-        dest='dataset_root',
-        help="dataset root directory",
-        type=str,
-        default=None)
-
     # params of training
     parser.add_argument(
-        "--input_size",
-        dest="input_size",
-        help="The image size for net inputs.",
-        nargs=2,
-        default=[512, 512],
-        type=int)
+        "--config", dest="cfg", help="The config file.", default=None, type=str)
     parser.add_argument(
         '--iters',
         dest='iters',
         help='iters for training',
         type=int,
-        default=10000)
+        default=None)
     parser.add_argument(
         '--batch_size',
         dest='batch_size',
         help='Mini batch size of one gpu or cpu',
         type=int,
-        default=2)
+        default=None)
     parser.add_argument(
         '--learning_rate',
         dest='learning_rate',
         help='Learning rate',
         type=float,
-        default=0.01)
-    parser.add_argument(
-        '--pretrained_model',
-        dest='pretrained_model',
-        help='The path of pretrained model',
-        type=str,
-        default=None)
-    parser.add_argument(
-        '--resume_model',
-        dest='resume_model',
-        help='The path of resume model',
-        type=str,
-        default=None)
+        default=None)
     parser.add_argument(
         '--save_interval_iters',
@@ -139,59 +97,28 @@ def main(args):
     if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] \
         else fluid.CPUPlace()

-    if args.dataset not in DATASETS:
-        raise Exception('`--dataset` is invalid. it should be one of {}'.format(
-            str(list(DATASETS.keys()))))
-    dataset = DATASETS[args.dataset]
-
     with fluid.dygraph.guard(places):
-        # Creat dataset reader
-        train_transforms = T.Compose([
-            T.Resize(args.input_size),
-            T.RandomHorizontalFlip(),
-            T.Normalize()
-        ])
-        train_dataset = dataset(
-            dataset_root=args.dataset_root,
-            transforms=train_transforms,
-            mode='train')
-
-        eval_dataset = None
-        if args.do_eval:
-            eval_transforms = T.Compose(
-                [T.Resize(args.input_size),
-                 T.Normalize()])
-            eval_dataset = dataset(
-                dataset_root=args.dataset_root,
-                transforms=eval_transforms,
-                mode='val')
-
-        model = manager.MODELS[args.model_name](
-            num_classes=train_dataset.num_classes,
-            pretrained_model=args.pretrained_model)
-
-        # Creat optimizer
-        # todo, may less one than len(loader)
-        num_iters_each_epoch = len(train_dataset) // (
-            args.batch_size * ParallelEnv().nranks)
-        lr_decay = fluid.layers.polynomial_decay(
-            args.learning_rate, args.iters, end_learning_rate=0, power=0.9)
-        optimizer = fluid.optimizer.Momentum(
-            lr_decay,
-            momentum=0.9,
-            parameter_list=model.parameters(),
-            regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5))
+        if not args.cfg:
+            raise RuntimeError('No configuration file specified.')
+
+        cfg = Config(args.cfg)
+        train_dataset = cfg.train_dataset
+        if not train_dataset:
+            raise RuntimeError(
+                'The training dataset is not specified in the configuration file.'
+            )
+
+        val_dataset = cfg.val_dataset if args.do_eval else None

         train(
-            model,
+            cfg.model,
             train_dataset,
             places=places,
-            eval_dataset=eval_dataset,
-            optimizer=optimizer,
+            eval_dataset=val_dataset,
+            optimizer=cfg.optimizer,
             save_dir=args.save_dir,
-            iters=args.iters,
-            batch_size=args.batch_size,
-            resume_model=args.resume_model,
+            iters=cfg.iters,
+            batch_size=cfg.batch_size,
             save_interval_iters=args.save_interval_iters,
             log_iters=args.log_iters,
             num_classes=train_dataset.num_classes,
浏览文件 @
d3e3f733
...
@@ -21,8 +21,10 @@ from PIL import Image
...
@@ -21,8 +21,10 @@ from PIL import Image
import
cv2
import
cv2
from
.functional
import
*
from
.functional
import
*
from
dygraph.cvlibs
import
manager
@
manager
.
TRANSFORMS
.
add_component
class
Compose
:
class
Compose
:
def
__init__
(
self
,
transforms
,
to_rgb
=
True
):
def
__init__
(
self
,
transforms
,
to_rgb
=
True
):
if
not
isinstance
(
transforms
,
list
):
if
not
isinstance
(
transforms
,
list
):
...
@@ -58,6 +60,7 @@ class Compose:
...
@@ -58,6 +60,7 @@ class Compose:
return
(
im
,
im_info
,
label
)
return
(
im
,
im_info
,
label
)
@
manager
.
TRANSFORMS
.
add_component
class
RandomHorizontalFlip
:
class
RandomHorizontalFlip
:
def
__init__
(
self
,
prob
=
0.5
):
def
__init__
(
self
,
prob
=
0.5
):
self
.
prob
=
prob
self
.
prob
=
prob
...
@@ -73,6 +76,7 @@ class RandomHorizontalFlip:
...
@@ -73,6 +76,7 @@ class RandomHorizontalFlip:
return
(
im
,
im_info
,
label
)
return
(
im
,
im_info
,
label
)
@
manager
.
TRANSFORMS
.
add_component
class
RandomVerticalFlip
:
class
RandomVerticalFlip
:
def
__init__
(
self
,
prob
=
0.1
):
def
__init__
(
self
,
prob
=
0.1
):
self
.
prob
=
prob
self
.
prob
=
prob
...
@@ -88,6 +92,7 @@ class RandomVerticalFlip:
...
@@ -88,6 +92,7 @@ class RandomVerticalFlip:
return
(
im
,
im_info
,
label
)
return
(
im
,
im_info
,
label
)
@
manager
.
TRANSFORMS
.
add_component
class
Resize
:
class
Resize
:
# The interpolation mode
# The interpolation mode
interp_dict
=
{
interp_dict
=
{
...
@@ -137,6 +142,7 @@ class Resize:
...
@@ -137,6 +142,7 @@ class Resize:
return
(
im
,
im_info
,
label
)
return
(
im
,
im_info
,
label
)
@
manager
.
TRANSFORMS
.
add_component
class
ResizeByLong
:
class
ResizeByLong
:
def
__init__
(
self
,
long_size
):
def
__init__
(
self
,
long_size
):
self
.
long_size
=
long_size
self
.
long_size
=
long_size
...
@@ -156,6 +162,7 @@ class ResizeByLong:
...
@@ -156,6 +162,7 @@ class ResizeByLong:
return
(
im
,
im_info
,
label
)
return
(
im
,
im_info
,
label
)
@
manager
.
TRANSFORMS
.
add_component
class
ResizeRangeScaling
:
class
ResizeRangeScaling
:
def
__init__
(
self
,
min_value
=
400
,
max_value
=
600
):
def
__init__
(
self
,
min_value
=
400
,
max_value
=
600
):
if
min_value
>
max_value
:
if
min_value
>
max_value
:
...
@@ -181,6 +188,7 @@ class ResizeRangeScaling:
...
@@ -181,6 +188,7 @@ class ResizeRangeScaling:
return
(
im
,
im_info
,
label
)
return
(
im
,
im_info
,
label
)
@
manager
.
TRANSFORMS
.
add_component
class
ResizeStepScaling
:
class
ResizeStepScaling
:
def
__init__
(
self
,
def
__init__
(
self
,
min_scale_factor
=
0.75
,
min_scale_factor
=
0.75
,
...
@@ -224,6 +232,7 @@ class ResizeStepScaling:
...
@@ -224,6 +232,7 @@ class ResizeStepScaling:
return
(
im
,
im_info
,
label
)
return
(
im
,
im_info
,
label
)
@
manager
.
TRANSFORMS
.
add_component
class
Normalize
:
class
Normalize
:
def
__init__
(
self
,
mean
=
[
0.5
,
0.5
,
0.5
],
std
=
[
0.5
,
0.5
,
0.5
]):
def
__init__
(
self
,
mean
=
[
0.5
,
0.5
,
0.5
],
std
=
[
0.5
,
0.5
,
0.5
]):
self
.
mean
=
mean
self
.
mean
=
mean
...
@@ -245,6 +254,7 @@ class Normalize:
...
@@ -245,6 +254,7 @@ class Normalize:
return
(
im
,
im_info
,
label
)
return
(
im
,
im_info
,
label
)
@
manager
.
TRANSFORMS
.
add_component
class
Padding
:
class
Padding
:
def
__init__
(
self
,
def
__init__
(
self
,
target_size
,
target_size
,
...
@@ -305,6 +315,7 @@ class Padding:
...
@@ -305,6 +315,7 @@ class Padding:
return
(
im
,
im_info
,
label
)
return
(
im
,
im_info
,
label
)
@
manager
.
TRANSFORMS
.
add_component
class
RandomPaddingCrop
:
class
RandomPaddingCrop
:
def
__init__
(
self
,
def
__init__
(
self
,
crop_size
=
512
,
crop_size
=
512
,
...
@@ -378,6 +389,7 @@ class RandomPaddingCrop:
...
@@ -378,6 +389,7 @@ class RandomPaddingCrop:
return
(
im
,
im_info
,
label
)
return
(
im
,
im_info
,
label
)
@
manager
.
TRANSFORMS
.
add_component
class
RandomBlur
:
class
RandomBlur
:
def
__init__
(
self
,
prob
=
0.1
):
def
__init__
(
self
,
prob
=
0.1
):
self
.
prob
=
prob
self
.
prob
=
prob
...
@@ -404,6 +416,7 @@ class RandomBlur:
...
@@ -404,6 +416,7 @@ class RandomBlur:
return
(
im
,
im_info
,
label
)
return
(
im
,
im_info
,
label
)
@
manager
.
TRANSFORMS
.
add_component
class
RandomRotation
:
class
RandomRotation
:
def
__init__
(
self
,
def
__init__
(
self
,
max_rotation
=
15
,
max_rotation
=
15
,
...
@@ -451,6 +464,7 @@ class RandomRotation:
...
@@ -451,6 +464,7 @@ class RandomRotation:
return
(
im
,
im_info
,
label
)
return
(
im
,
im_info
,
label
)
@
manager
.
TRANSFORMS
.
add_component
class
RandomScaleAspect
:
class
RandomScaleAspect
:
def
__init__
(
self
,
min_scale
=
0.5
,
aspect_ratio
=
0.33
):
def
__init__
(
self
,
min_scale
=
0.5
,
aspect_ratio
=
0.33
):
self
.
min_scale
=
min_scale
self
.
min_scale
=
min_scale
...
@@ -492,6 +506,7 @@ class RandomScaleAspect:
...
@@ -492,6 +506,7 @@ class RandomScaleAspect:
return
(
im
,
im_info
,
label
)
return
(
im
,
im_info
,
label
)
@
manager
.
TRANSFORMS
.
add_component
class
RandomDistort
:
class
RandomDistort
:
def
__init__
(
self
,
def
__init__
(
self
,
brightness_range
=
0.5
,
brightness_range
=
0.5
,
...
...
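Registering every transform under manager.TRANSFORMS is what lets the YAML config refer to them by name. A minimal sketch of that lookup (the transform settings below mirror the config file earlier in this commit; this is illustrative, not the Config implementation itself):

import dygraph.transforms as T          # importing the package runs the @add_component decorators
from dygraph.cvlibs import manager

t_cfgs = [
    {'type': 'RandomHorizontalFlip'},
    {'type': 'RandomPaddingCrop', 'crop_size': [1024, 512]},
    {'type': 'Normalize'},
]
# Look each entry up by its 'type' key and instantiate it with the remaining keys.
objs = [manager.TRANSFORMS[c.pop('type')](**c) for c in t_cfgs]
pipeline = T.Compose(objs)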
dygraph/utils/__init__.py
@@ -18,3 +18,4 @@ from .metrics import ConfusionMatrix
 from .utils import *
 from .timer import Timer, calculate_eta
 from .get_environ_info import get_environ_info
+from .config import Config
dygraph/utils/config.py (new file, mode 100644)

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import os
from typing import Any, Callable

import yaml
import paddle.fluid as fluid

import dygraph.cvlibs.manager as manager


class Config(object):
    '''
    Training config.

    Args:
        path(str) : the path of config file, supports yaml format only
    '''

    def __init__(self, path: str):
        if not os.path.exists(path):
            raise FileNotFoundError('File {} does not exist'.format(path))

        if path.endswith('yml') or path.endswith('yaml'):
            self._parse_from_yaml(path)
        else:
            raise RuntimeError('Config file should in yaml format!')

    def _parse_from_yaml(self, path: str):
        '''Parse a yaml file and build config'''
        with codecs.open(path, 'r', 'utf-8') as file:
            dic = yaml.load(file, Loader=yaml.FullLoader)
        self._build(dic)

    def _build(self, dic: dict):
        '''Build config from dictionary'''
        dic = dic.copy()

        self._batch_size = dic.get('batch_size', 1)
        self._iters = dic.get('iters')

        if 'model' not in dic:
            raise RuntimeError()
        self._model_cfg = dic['model']
        self._model = None

        self._train_dataset = dic.get('train_dataset')
        self._val_dataset = dic.get('val_dataset')

        self._learning_rate_cfg = dic.get('learning_rate', {})
        self._learning_rate = self._learning_rate_cfg.get('value')
        self._decay = self._learning_rate_cfg.get('decay', {
            'type': 'poly',
            'power': 0.9
        })

        self._loss_cfg = dic.get('loss', {})

        self._optimizer_cfg = dic.get('optimizer', {})

    def update(self,
               learning_rate: float = None,
               batch_size: int = None,
               iters: int = None):
        '''Update config'''
        if learning_rate:
            self._learning_rate = learning_rate

        if batch_size:
            self._batch_size = batch_size

        if iters:
            self._iters = iters

    @property
    def batch_size(self) -> int:
        return self._batch_size

    @property
    def iters(self) -> int:
        if not self._iters:
            raise RuntimeError('No iters specified in the configuration file.')
        return self._iters

    @property
    def learning_rate(self) -> float:
        if not self._learning_rate:
            raise RuntimeError(
                'No learning rate specified in the configuration file.')

        if self.decay_type == 'poly':
            lr = self._learning_rate
            args = self.decay_args
            args.setdefault('decay_steps', self.iters)
            return fluid.layers.polynomial_decay(lr, **args)
        else:
            raise RuntimeError('Only poly decay support.')

    @property
    def optimizer(self) -> fluid.optimizer.Optimizer:
        if self.optimizer_type == 'sgd':
            lr = self.learning_rate
            args = self.optimizer_args
            args.setdefault('momentum', 0.9)
            return fluid.optimizer.Momentum(
                lr, parameter_list=self.model.parameters(), **args)
        else:
            raise RuntimeError('Only sgd optimizer support.')

    @property
    def optimizer_type(self) -> str:
        otype = self._optimizer_cfg.get('type')
        if not otype:
            raise RuntimeError(
                'No optimizer type specified in the configuration file.')
        return otype

    @property
    def optimizer_args(self) -> dict:
        args = self._optimizer_cfg.copy()
        args.pop('type')
        return args

    @property
    def decay_type(self) -> str:
        return self._decay['type']

    @property
    def decay_args(self) -> dict:
        args = self._decay.copy()
        args.pop('type')
        return args

    @property
    def loss_type(self) -> str:
        ...

    @property
    def loss_args(self) -> dict:
        args = self._loss_cfg.copy()
        args.pop('type')
        return args

    @property
    def model(self) -> Callable:
        if not self._model:
            self._model = self._load_object(self._model_cfg)
        return self._model

    @property
    def train_dataset(self) -> Any:
        if not self._train_dataset:
            return None
        return self._load_object(self._train_dataset)

    @property
    def val_dataset(self) -> Any:
        if not self._val_dataset:
            return None
        return self._load_object(self._val_dataset)

    def _load_component(self, com_name: str) -> Any:
        com_list = [
            manager.MODELS, manager.BACKBONES, manager.DATASETS,
            manager.TRANSFORMS
        ]

        for com in com_list:
            if com_name in com.components_dict:
                return com[com_name]
        else:
            raise RuntimeError(
                'The specified component was not found {}.'.format(com_name))

    def _load_object(self, cfg: dict) -> Any:
        cfg = cfg.copy()
        if 'type' not in cfg:
            raise RuntimeError('No object information in {}.'.format(cfg))

        component = self._load_component(cfg.pop('type'))

        params = {}
        for key, val in cfg.items():
            if self._is_meta_type(val):
                params[key] = self._load_object(val)
            elif isinstance(val, list):
                params[key] = [
                    self._load_object(item)
                    if self._is_meta_type(item) else item for item in val
                ]
            else:
                params[key] = val

        return component(**params)

    def _is_meta_type(self, item: Any) -> bool:
        return isinstance(item, dict) and 'type' in item
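The heart of Config is _load_object, which turns any dict carrying a 'type' key into an instance of the registered component and recurses into nested dicts (the backbone inside the model, each transform inside a dataset). Below is a self-contained toy version of that mechanism, independent of Paddle; the names Registry-style REGISTRY, Backbone, Segmenter and load_object are invented purely for illustration and are not part of PaddleSeg:

REGISTRY = {}

def register(cls):
    REGISTRY[cls.__name__] = cls
    return cls

@register
class Backbone:
    def __init__(self, width):
        self.width = width

@register
class Segmenter:
    def __init__(self, backbone, num_classes):
        self.backbone, self.num_classes = backbone, num_classes

def load_object(cfg):
    cfg = dict(cfg)
    component = REGISTRY[cfg.pop('type')]
    # Any value that is itself a dict with a 'type' key is built recursively.
    params = {k: load_object(v) if isinstance(v, dict) and 'type' in v else v
              for k, v in cfg.items()}
    return component(**params)

model = load_object({'type': 'Segmenter',
                     'backbone': {'type': 'Backbone', 'width': 18},
                     'num_classes': 19})
assert isinstance(model.backbone, Backbone) and model.num_classes == 19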
dygraph/val.py
@@ -17,48 +17,19 @@ import argparse
 import paddle.fluid as fluid
 from paddle.fluid.dygraph.parallel import ParallelEnv

-from dygraph.datasets import DATASETS
-import dygraph.transforms as T
+import dygraph
 from dygraph.cvlibs import manager
 from dygraph.utils import get_environ_info
+from dygraph.utils import Config
 from dygraph.core import evaluate


 def parse_args():
     parser = argparse.ArgumentParser(description='Model evaluation')

-    # params of model
-    parser.add_argument(
-        '--model_name',
-        dest='model_name',
-        help='Model type for evaluation, which is one of {}'.format(
-            str(list(manager.MODELS.components_dict.keys()))),
-        type=str,
-        default='UNet')
-
-    # params of dataset
-    parser.add_argument(
-        '--dataset',
-        dest='dataset',
-        help="The dataset you want to evaluation, which is one of {}".format(
-            str(list(DATASETS.keys()))),
-        type=str,
-        default='OpticDiscSeg')
-    parser.add_argument(
-        '--dataset_root',
-        dest='dataset_root',
-        help="dataset root directory",
-        type=str,
-        default=None)
-
     # params of evaluate
     parser.add_argument(
-        "--input_size",
-        dest="input_size",
-        help="The image size for net inputs.",
-        nargs=2,
-        default=[512, 512],
-        type=int)
+        "--config", dest="cfg", help="The config file.", default=None, type=str)
     parser.add_argument(
         '--model_dir',
         dest='model_dir',
@@ -75,26 +46,21 @@ def main(args):
     if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] \
         else fluid.CPUPlace()

-    if args.dataset not in DATASETS:
-        raise Exception('`--dataset` is invalid. it should be one of {}'.format(
-            str(list(DATASETS.keys()))))
-    dataset = DATASETS[args.dataset]
-
     with fluid.dygraph.guard(places):
-        eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
-        eval_dataset = dataset(
-            dataset_root=args.dataset_root,
-            transforms=eval_transforms,
-            mode='val')
-
-        model = manager.MODELS[args.model_name](
-            num_classes=eval_dataset.num_classes)
+        if not args.cfg:
+            raise RuntimeError('No configuration file specified.')
+
+        cfg = Config(args.cfg)
+        val_dataset = cfg.val_dataset
+        if not val_dataset:
+            raise RuntimeError(
+                'The verification dataset is not specified in the configuration file.'
+            )

         evaluate(
-            model,
-            eval_dataset,
+            cfg.model,
+            val_dataset,
             model_dir=args.model_dir,
-            num_classes=eval_dataset.num_classes)
+            num_classes=val_dataset.num_classes)


 if __name__ == '__main__':
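Evaluation follows the same pattern as training: pass a config, build the model and validation dataset from it, and hand them to evaluate. A hedged sketch of that flow; the config path and model_dir below are placeholders, and the real script also handles device selection from the environment:

import dygraph                      # registers models, datasets and transforms
import paddle.fluid as fluid
from dygraph.utils import Config
from dygraph.core import evaluate

cfg = Config('dygraph/configs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_40k.yml')

with fluid.dygraph.guard(fluid.CPUPlace()):
    val_dataset = cfg.val_dataset
    evaluate(
        cfg.model,
        val_dataset,
        model_dir='output/iter_40000',   # placeholder checkpoint directory
        num_classes=val_dataset.num_classes)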