提交 3b3a69b7 编写于 作者: W wuzewu

Download pretrained model from url

上级 8f77b383
...@@ -17,12 +17,12 @@ import os ...@@ -17,12 +17,12 @@ import os
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import paddleseg.env as segenv
from .dataset import Dataset from .dataset import Dataset
from paddleseg.utils.download import download_file_and_uncompress from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.cvlibs import manager from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose from paddleseg.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip" URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip"
...@@ -61,8 +61,8 @@ class ADE20K(Dataset): ...@@ -61,8 +61,8 @@ class ADE20K(Dataset):
"`dataset_root` not set and auto download disabled.") "`dataset_root` not set and auto download disabled.")
self.dataset_root = download_file_and_uncompress( self.dataset_root = download_file_and_uncompress(
url=URL, url=URL,
savepath=DATA_HOME, savepath=segenv.DATA_HOME,
extrapath=DATA_HOME, extrapath=segenv.DATA_HOME,
extraname='ADEChallengeData2016') extraname='ADEChallengeData2016')
elif not os.path.exists(self.dataset_root): elif not os.path.exists(self.dataset_root):
raise Exception('there is not `dataset_root`: {}.'.format( raise Exception('there is not `dataset_root`: {}.'.format(
......
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
import os import os
import paddleseg.env as segenv
from .dataset import Dataset from .dataset import Dataset
from paddleseg.utils.download import download_file_and_uncompress from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.cvlibs import manager from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose from paddleseg.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip" URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"
...@@ -49,7 +49,7 @@ class OpticDiscSeg(Dataset): ...@@ -49,7 +49,7 @@ class OpticDiscSeg(Dataset):
raise Exception( raise Exception(
"`data_root` not set and auto download disabled.") "`data_root` not set and auto download disabled.")
self.dataset_root = download_file_and_uncompress( self.dataset_root = download_file_and_uncompress(
url=URL, savepath=DATA_HOME, extrapath=DATA_HOME) url=URL, savepath=segenv.DATA_HOME, extrapath=segenv.DATA_HOME)
elif not os.path.exists(self.dataset_root): elif not os.path.exists(self.dataset_root):
raise Exception('there is not `dataset_root`: {}.'.format( raise Exception('there is not `dataset_root`: {}.'.format(
self.dataset_root)) self.dataset_root))
......
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
import os import os
import paddleseg.env as segenv
from .dataset import Dataset from .dataset import Dataset
from paddleseg.utils.download import download_file_and_uncompress from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.cvlibs import manager from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose from paddleseg.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar" URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
...@@ -59,8 +59,8 @@ class PascalVOC(Dataset): ...@@ -59,8 +59,8 @@ class PascalVOC(Dataset):
"`dataset_root` not set and auto download disabled.") "`dataset_root` not set and auto download disabled.")
self.dataset_root = download_file_and_uncompress( self.dataset_root = download_file_and_uncompress(
url=URL, url=URL,
savepath=DATA_HOME, savepath=segenv.DATA_HOME,
extrapath=DATA_HOME, extrapath=segenv.DATA_HOME,
extraname='VOCdevkit') extraname='VOCdevkit')
elif not os.path.exists(self.dataset_root): elif not os.path.exists(self.dataset_root):
raise Exception('there is not `dataset_root`: {}.'.format( raise Exception('there is not `dataset_root`: {}.'.format(
......
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from paddleseg.utils import logger
def _get_user_home():
return os.path.expanduser('~')
def _get_seg_home():
if 'SEG_HOME' in os.environ:
home_path = os.environ['SEG_HOME']
if os.path.exists(home_path):
if os.path.isdir(home_path):
return home_path
else:
logger.warning('SEG_HOME {} is a file!'.format(home_path))
else:
return home_path
return os.path.join(_get_user_home(), '.paddleseg')
def _get_sub_home(directory):
home = os.path.join(_get_seg_home(), directory)
if not os.path.exists(home):
os.makedirs(home)
return home
USER_HOME = _get_user_home()
SEG_HOME = _get_seg_home()
DATA_HOME = _get_sub_home('dataset')
TMP_HOME = _get_sub_home('tmp')
PRETRAINED_MODEL_HOME = _get_sub_home('pretrained_model')
...@@ -12,13 +12,28 @@ ...@@ -12,13 +12,28 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib
import os import os
import numpy as np import numpy as np
import math import math
import cv2 import cv2
import tempfile
import paddle.fluid as fluid import paddle.fluid as fluid
from urllib.parse import urlparse, unquote
from . import logger import filelock
import paddleseg.env as segenv
from paddleseg.utils import logger
from paddleseg.utils.download import download_file_and_uncompress
@contextlib.contextmanager
def generate_tempdir(directory: str = None, **kwargs):
'''Generate a temporary directory'''
directory = segenv.TMP_HOME if not directory else directory
with tempfile.TemporaryDirectory(dir=directory, **kwargs) as _dir:
yield _dir
def seconds_to_hms(seconds): def seconds_to_hms(seconds):
...@@ -32,6 +47,18 @@ def seconds_to_hms(seconds): ...@@ -32,6 +47,18 @@ def seconds_to_hms(seconds):
def load_pretrained_model(model, pretrained_model): def load_pretrained_model(model, pretrained_model):
if pretrained_model is not None: if pretrained_model is not None:
logger.info('Load pretrained model from {}'.format(pretrained_model)) logger.info('Load pretrained model from {}'.format(pretrained_model))
# download pretrained model from url
if urlparse(pretrained_model).netloc:
pretrained_model = unquote(pretrained_model)
savename = pretrained_model.split('/')[-1].split('.')[0]
with generate_tempdir() as _dir:
with filelock.FileLock(os.path.join(segenv.TMP_HOME, savename)):
pretrained_model = download_file_and_uncompress(
pretrained_model,
savepath=_dir,
extrapath=segenv.PRETRAINED_MODEL_HOME,
extraname=savename)
if os.path.exists(pretrained_model): if os.path.exists(pretrained_model):
ckpt_path = os.path.join(pretrained_model, 'model') ckpt_path = os.path.join(pretrained_model, 'model')
try: try:
......
...@@ -112,9 +112,10 @@ def main(args): ...@@ -112,9 +112,10 @@ def main(args):
val_dataset = cfg.val_dataset if args.do_eval else None val_dataset = cfg.val_dataset if args.do_eval else None
losses = cfg.loss losses = cfg.loss
print('---------------Config Information---------------') msg = '\n---------------Config Information---------------\n'
print(cfg) msg += str(cfg)
print('------------------------------------------------') msg += '------------------------------------------------'
logger.info(msg)
train( train(
cfg.model, cfg.model,
......
...@@ -19,7 +19,7 @@ from paddle.distributed import ParallelEnv ...@@ -19,7 +19,7 @@ from paddle.distributed import ParallelEnv
import paddleseg import paddleseg
from paddleseg.cvlibs import manager from paddleseg.cvlibs import manager
from paddleseg.utils import get_environ_info, Config from paddleseg.utils import get_environ_info, Config, logger
from paddleseg.core import evaluate from paddleseg.core import evaluate
...@@ -56,9 +56,10 @@ def main(args): ...@@ -56,9 +56,10 @@ def main(args):
'The verification dataset is not specified in the configuration file.' 'The verification dataset is not specified in the configuration file.'
) )
print('---------------Config Information---------------') msg = '\n---------------Config Information---------------\n'
print(cfg) msg += str(cfg)
print('------------------------------------------------') msg += '------------------------------------------------'
logger.info(msg)
evaluate( evaluate(
cfg.model, cfg.model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册