提交 3b3a69b7 编写于 作者: W wuzewu

Download pretrained model from url

上级 8f77b383
......@@ -17,12 +17,12 @@ import os
import numpy as np
from PIL import Image
import paddleseg.env as segenv
from .dataset import Dataset
from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip"
......@@ -61,8 +61,8 @@ class ADE20K(Dataset):
"`dataset_root` not set and auto download disabled.")
self.dataset_root = download_file_and_uncompress(
url=URL,
savepath=DATA_HOME,
extrapath=DATA_HOME,
savepath=segenv.DATA_HOME,
extrapath=segenv.DATA_HOME,
extraname='ADEChallengeData2016')
elif not os.path.exists(self.dataset_root):
raise Exception('there is not `dataset_root`: {}.'.format(
......
......@@ -14,12 +14,12 @@
import os
import paddleseg.env as segenv
from .dataset import Dataset
from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"
......@@ -49,7 +49,7 @@ class OpticDiscSeg(Dataset):
raise Exception(
"`data_root` not set and auto download disabled.")
self.dataset_root = download_file_and_uncompress(
url=URL, savepath=DATA_HOME, extrapath=DATA_HOME)
url=URL, savepath=segenv.DATA_HOME, extrapath=segenv.DATA_HOME)
elif not os.path.exists(self.dataset_root):
raise Exception('there is not `dataset_root`: {}.'.format(
self.dataset_root))
......
......@@ -14,12 +14,12 @@
import os
import paddleseg.env as segenv
from .dataset import Dataset
from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
......@@ -59,8 +59,8 @@ class PascalVOC(Dataset):
"`dataset_root` not set and auto download disabled.")
self.dataset_root = download_file_and_uncompress(
url=URL,
savepath=DATA_HOME,
extrapath=DATA_HOME,
savepath=segenv.DATA_HOME,
extrapath=segenv.DATA_HOME,
extraname='VOCdevkit')
elif not os.path.exists(self.dataset_root):
raise Exception('there is not `dataset_root`: {}.'.format(
......
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from paddleseg.utils import logger
def _get_user_home():
return os.path.expanduser('~')
def _get_seg_home():
if 'SEG_HOME' in os.environ:
home_path = os.environ['SEG_HOME']
if os.path.exists(home_path):
if os.path.isdir(home_path):
return home_path
else:
logger.warning('SEG_HOME {} is a file!'.format(home_path))
else:
return home_path
return os.path.join(_get_user_home(), '.paddleseg')
def _get_sub_home(directory):
home = os.path.join(_get_seg_home(), directory)
if not os.path.exists(home):
os.makedirs(home)
return home
USER_HOME = _get_user_home()
SEG_HOME = _get_seg_home()
DATA_HOME = _get_sub_home('dataset')
TMP_HOME = _get_sub_home('tmp')
PRETRAINED_MODEL_HOME = _get_sub_home('pretrained_model')
......@@ -12,13 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
import numpy as np
import math
import cv2
import tempfile
import paddle.fluid as fluid
from urllib.parse import urlparse, unquote
from . import logger
import filelock
import paddleseg.env as segenv
from paddleseg.utils import logger
from paddleseg.utils.download import download_file_and_uncompress
@contextlib.contextmanager
def generate_tempdir(directory: str = None, **kwargs):
'''Generate a temporary directory'''
directory = segenv.TMP_HOME if not directory else directory
with tempfile.TemporaryDirectory(dir=directory, **kwargs) as _dir:
yield _dir
def seconds_to_hms(seconds):
......@@ -32,6 +47,18 @@ def seconds_to_hms(seconds):
def load_pretrained_model(model, pretrained_model):
if pretrained_model is not None:
logger.info('Load pretrained model from {}'.format(pretrained_model))
# download pretrained model from url
if urlparse(pretrained_model).netloc:
pretrained_model = unquote(pretrained_model)
savename = pretrained_model.split('/')[-1].split('.')[0]
with generate_tempdir() as _dir:
with filelock.FileLock(os.path.join(segenv.TMP_HOME, savename)):
pretrained_model = download_file_and_uncompress(
pretrained_model,
savepath=_dir,
extrapath=segenv.PRETRAINED_MODEL_HOME,
extraname=savename)
if os.path.exists(pretrained_model):
ckpt_path = os.path.join(pretrained_model, 'model')
try:
......
......@@ -112,9 +112,10 @@ def main(args):
val_dataset = cfg.val_dataset if args.do_eval else None
losses = cfg.loss
print('---------------Config Information---------------')
print(cfg)
print('------------------------------------------------')
msg = '\n---------------Config Information---------------\n'
msg += str(cfg)
msg += '------------------------------------------------'
logger.info(msg)
train(
cfg.model,
......
......@@ -19,7 +19,7 @@ from paddle.distributed import ParallelEnv
import paddleseg
from paddleseg.cvlibs import manager
from paddleseg.utils import get_environ_info, Config
from paddleseg.utils import get_environ_info, Config, logger
from paddleseg.core import evaluate
......@@ -56,9 +56,10 @@ def main(args):
'The verification dataset is not specified in the configuration file.'
)
print('---------------Config Information---------------')
print(cfg)
print('------------------------------------------------')
msg = '\n---------------Config Information---------------\n'
msg += str(cfg)
msg += '------------------------------------------------'
logger.info(msg)
evaluate(
cfg.model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册