Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
3b3a69b7
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
285
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3b3a69b7
编写于
9月 22, 2020
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Download pretrained model from url
上级
8f77b383
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
95 addition
and
16 deletion
+95
-16
dygraph/paddleseg/datasets/ade.py
dygraph/paddleseg/datasets/ade.py
+3
-3
dygraph/paddleseg/datasets/optic_disc_seg.py
dygraph/paddleseg/datasets/optic_disc_seg.py
+2
-2
dygraph/paddleseg/datasets/voc.py
dygraph/paddleseg/datasets/voc.py
+3
-3
dygraph/paddleseg/env.py
dygraph/paddleseg/env.py
+50
-0
dygraph/paddleseg/utils/utils.py
dygraph/paddleseg/utils/utils.py
+28
-1
dygraph/train.py
dygraph/train.py
+4
-3
dygraph/val.py
dygraph/val.py
+5
-4
未找到文件。
dygraph/paddleseg/datasets/ade.py
浏览文件 @
3b3a69b7
...
@@ -17,12 +17,12 @@ import os
...
@@ -17,12 +17,12 @@ import os
import
numpy
as
np
import
numpy
as
np
from
PIL
import
Image
from
PIL
import
Image
import
paddleseg.env
as
segenv
from
.dataset
import
Dataset
from
.dataset
import
Dataset
from
paddleseg.utils.download
import
download_file_and_uncompress
from
paddleseg.utils.download
import
download_file_and_uncompress
from
paddleseg.cvlibs
import
manager
from
paddleseg.cvlibs
import
manager
from
paddleseg.transforms
import
Compose
from
paddleseg.transforms
import
Compose
DATA_HOME
=
os
.
path
.
expanduser
(
'~/.cache/paddle/dataset'
)
URL
=
"http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip"
URL
=
"http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip"
...
@@ -61,8 +61,8 @@ class ADE20K(Dataset):
...
@@ -61,8 +61,8 @@ class ADE20K(Dataset):
"`dataset_root` not set and auto download disabled."
)
"`dataset_root` not set and auto download disabled."
)
self
.
dataset_root
=
download_file_and_uncompress
(
self
.
dataset_root
=
download_file_and_uncompress
(
url
=
URL
,
url
=
URL
,
savepath
=
DATA_HOME
,
savepath
=
segenv
.
DATA_HOME
,
extrapath
=
DATA_HOME
,
extrapath
=
segenv
.
DATA_HOME
,
extraname
=
'ADEChallengeData2016'
)
extraname
=
'ADEChallengeData2016'
)
elif
not
os
.
path
.
exists
(
self
.
dataset_root
):
elif
not
os
.
path
.
exists
(
self
.
dataset_root
):
raise
Exception
(
'there is not `dataset_root`: {}.'
.
format
(
raise
Exception
(
'there is not `dataset_root`: {}.'
.
format
(
...
...
dygraph/paddleseg/datasets/optic_disc_seg.py
浏览文件 @
3b3a69b7
...
@@ -14,12 +14,12 @@
...
@@ -14,12 +14,12 @@
import
os
import
os
import
paddleseg.env
as
segenv
from
.dataset
import
Dataset
from
.dataset
import
Dataset
from
paddleseg.utils.download
import
download_file_and_uncompress
from
paddleseg.utils.download
import
download_file_and_uncompress
from
paddleseg.cvlibs
import
manager
from
paddleseg.cvlibs
import
manager
from
paddleseg.transforms
import
Compose
from
paddleseg.transforms
import
Compose
DATA_HOME
=
os
.
path
.
expanduser
(
'~/.cache/paddle/dataset'
)
URL
=
"https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"
URL
=
"https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"
...
@@ -49,7 +49,7 @@ class OpticDiscSeg(Dataset):
...
@@ -49,7 +49,7 @@ class OpticDiscSeg(Dataset):
raise
Exception
(
raise
Exception
(
"`data_root` not set and auto download disabled."
)
"`data_root` not set and auto download disabled."
)
self
.
dataset_root
=
download_file_and_uncompress
(
self
.
dataset_root
=
download_file_and_uncompress
(
url
=
URL
,
savepath
=
DATA_HOME
,
extrapath
=
DATA_HOME
)
url
=
URL
,
savepath
=
segenv
.
DATA_HOME
,
extrapath
=
segenv
.
DATA_HOME
)
elif
not
os
.
path
.
exists
(
self
.
dataset_root
):
elif
not
os
.
path
.
exists
(
self
.
dataset_root
):
raise
Exception
(
'there is not `dataset_root`: {}.'
.
format
(
raise
Exception
(
'there is not `dataset_root`: {}.'
.
format
(
self
.
dataset_root
))
self
.
dataset_root
))
...
...
dygraph/paddleseg/datasets/voc.py
浏览文件 @
3b3a69b7
...
@@ -14,12 +14,12 @@
...
@@ -14,12 +14,12 @@
import
os
import
os
import
paddleseg.env
as
segenv
from
.dataset
import
Dataset
from
.dataset
import
Dataset
from
paddleseg.utils.download
import
download_file_and_uncompress
from
paddleseg.utils.download
import
download_file_and_uncompress
from
paddleseg.cvlibs
import
manager
from
paddleseg.cvlibs
import
manager
from
paddleseg.transforms
import
Compose
from
paddleseg.transforms
import
Compose
DATA_HOME
=
os
.
path
.
expanduser
(
'~/.cache/paddle/dataset'
)
URL
=
"http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
URL
=
"http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
...
@@ -59,8 +59,8 @@ class PascalVOC(Dataset):
...
@@ -59,8 +59,8 @@ class PascalVOC(Dataset):
"`dataset_root` not set and auto download disabled."
)
"`dataset_root` not set and auto download disabled."
)
self
.
dataset_root
=
download_file_and_uncompress
(
self
.
dataset_root
=
download_file_and_uncompress
(
url
=
URL
,
url
=
URL
,
savepath
=
DATA_HOME
,
savepath
=
segenv
.
DATA_HOME
,
extrapath
=
DATA_HOME
,
extrapath
=
segenv
.
DATA_HOME
,
extraname
=
'VOCdevkit'
)
extraname
=
'VOCdevkit'
)
elif
not
os
.
path
.
exists
(
self
.
dataset_root
):
elif
not
os
.
path
.
exists
(
self
.
dataset_root
):
raise
Exception
(
'there is not `dataset_root`: {}.'
.
format
(
raise
Exception
(
'there is not `dataset_root`: {}.'
.
format
(
...
...
dygraph/paddleseg/env.py
0 → 100644
浏览文件 @
3b3a69b7
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
shutil
from
paddleseg.utils
import
logger
def
_get_user_home
():
return
os
.
path
.
expanduser
(
'~'
)
def
_get_seg_home
():
if
'SEG_HOME'
in
os
.
environ
:
home_path
=
os
.
environ
[
'SEG_HOME'
]
if
os
.
path
.
exists
(
home_path
):
if
os
.
path
.
isdir
(
home_path
):
return
home_path
else
:
logger
.
warning
(
'SEG_HOME {} is a file!'
.
format
(
home_path
))
else
:
return
home_path
return
os
.
path
.
join
(
_get_user_home
(),
'.paddleseg'
)
def
_get_sub_home
(
directory
):
home
=
os
.
path
.
join
(
_get_seg_home
(),
directory
)
if
not
os
.
path
.
exists
(
home
):
os
.
makedirs
(
home
)
return
home
USER_HOME
=
_get_user_home
()
SEG_HOME
=
_get_seg_home
()
DATA_HOME
=
_get_sub_home
(
'dataset'
)
TMP_HOME
=
_get_sub_home
(
'tmp'
)
PRETRAINED_MODEL_HOME
=
_get_sub_home
(
'pretrained_model'
)
dygraph/paddleseg/utils/utils.py
浏览文件 @
3b3a69b7
...
@@ -12,13 +12,28 @@
...
@@ -12,13 +12,28 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
contextlib
import
os
import
os
import
numpy
as
np
import
numpy
as
np
import
math
import
math
import
cv2
import
cv2
import
tempfile
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
urllib.parse
import
urlparse
,
unquote
from
.
import
logger
import
filelock
import
paddleseg.env
as
segenv
from
paddleseg.utils
import
logger
from
paddleseg.utils.download
import
download_file_and_uncompress
@
contextlib
.
contextmanager
def
generate_tempdir
(
directory
:
str
=
None
,
**
kwargs
):
'''Generate a temporary directory'''
directory
=
segenv
.
TMP_HOME
if
not
directory
else
directory
with
tempfile
.
TemporaryDirectory
(
dir
=
directory
,
**
kwargs
)
as
_dir
:
yield
_dir
def
seconds_to_hms
(
seconds
):
def
seconds_to_hms
(
seconds
):
...
@@ -32,6 +47,18 @@ def seconds_to_hms(seconds):
...
@@ -32,6 +47,18 @@ def seconds_to_hms(seconds):
def
load_pretrained_model
(
model
,
pretrained_model
):
def
load_pretrained_model
(
model
,
pretrained_model
):
if
pretrained_model
is
not
None
:
if
pretrained_model
is
not
None
:
logger
.
info
(
'Load pretrained model from {}'
.
format
(
pretrained_model
))
logger
.
info
(
'Load pretrained model from {}'
.
format
(
pretrained_model
))
# download pretrained model from url
if
urlparse
(
pretrained_model
).
netloc
:
pretrained_model
=
unquote
(
pretrained_model
)
savename
=
pretrained_model
.
split
(
'/'
)[
-
1
].
split
(
'.'
)[
0
]
with
generate_tempdir
()
as
_dir
:
with
filelock
.
FileLock
(
os
.
path
.
join
(
segenv
.
TMP_HOME
,
savename
)):
pretrained_model
=
download_file_and_uncompress
(
pretrained_model
,
savepath
=
_dir
,
extrapath
=
segenv
.
PRETRAINED_MODEL_HOME
,
extraname
=
savename
)
if
os
.
path
.
exists
(
pretrained_model
):
if
os
.
path
.
exists
(
pretrained_model
):
ckpt_path
=
os
.
path
.
join
(
pretrained_model
,
'model'
)
ckpt_path
=
os
.
path
.
join
(
pretrained_model
,
'model'
)
try
:
try
:
...
...
dygraph/train.py
浏览文件 @
3b3a69b7
...
@@ -112,9 +112,10 @@ def main(args):
...
@@ -112,9 +112,10 @@ def main(args):
val_dataset
=
cfg
.
val_dataset
if
args
.
do_eval
else
None
val_dataset
=
cfg
.
val_dataset
if
args
.
do_eval
else
None
losses
=
cfg
.
loss
losses
=
cfg
.
loss
print
(
'---------------Config Information---------------'
)
msg
=
'
\n
---------------Config Information---------------
\n
'
print
(
cfg
)
msg
+=
str
(
cfg
)
print
(
'------------------------------------------------'
)
msg
+=
'------------------------------------------------'
logger
.
info
(
msg
)
train
(
train
(
cfg
.
model
,
cfg
.
model
,
...
...
dygraph/val.py
浏览文件 @
3b3a69b7
...
@@ -19,7 +19,7 @@ from paddle.distributed import ParallelEnv
...
@@ -19,7 +19,7 @@ from paddle.distributed import ParallelEnv
import
paddleseg
import
paddleseg
from
paddleseg.cvlibs
import
manager
from
paddleseg.cvlibs
import
manager
from
paddleseg.utils
import
get_environ_info
,
Config
from
paddleseg.utils
import
get_environ_info
,
Config
,
logger
from
paddleseg.core
import
evaluate
from
paddleseg.core
import
evaluate
...
@@ -56,9 +56,10 @@ def main(args):
...
@@ -56,9 +56,10 @@ def main(args):
'The verification dataset is not specified in the configuration file.'
'The verification dataset is not specified in the configuration file.'
)
)
print
(
'---------------Config Information---------------'
)
msg
=
'
\n
---------------Config Information---------------
\n
'
print
(
cfg
)
msg
+=
str
(
cfg
)
print
(
'------------------------------------------------'
)
msg
+=
'------------------------------------------------'
logger
.
info
(
msg
)
evaluate
(
evaluate
(
cfg
.
model
,
cfg
.
model
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录