From 3b3a69b74daa91093780fa2214665460b70bd2e6 Mon Sep 17 00:00:00 2001 From: wuzewu Date: Tue, 22 Sep 2020 11:43:06 +0800 Subject: [PATCH] Download pretrained model from url --- dygraph/paddleseg/datasets/ade.py | 6 +-- dygraph/paddleseg/datasets/optic_disc_seg.py | 4 +- dygraph/paddleseg/datasets/voc.py | 6 +-- dygraph/paddleseg/env.py | 50 ++++++++++++++++++++ dygraph/paddleseg/utils/utils.py | 29 +++++++++++- dygraph/train.py | 7 +-- dygraph/val.py | 9 ++-- 7 files changed, 95 insertions(+), 16 deletions(-) create mode 100644 dygraph/paddleseg/env.py diff --git a/dygraph/paddleseg/datasets/ade.py b/dygraph/paddleseg/datasets/ade.py index 66147398..db560ac7 100644 --- a/dygraph/paddleseg/datasets/ade.py +++ b/dygraph/paddleseg/datasets/ade.py @@ -17,12 +17,12 @@ import os import numpy as np from PIL import Image +import paddleseg.env as segenv from .dataset import Dataset from paddleseg.utils.download import download_file_and_uncompress from paddleseg.cvlibs import manager from paddleseg.transforms import Compose -DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip" @@ -61,8 +61,8 @@ class ADE20K(Dataset): "`dataset_root` not set and auto download disabled.") self.dataset_root = download_file_and_uncompress( url=URL, - savepath=DATA_HOME, - extrapath=DATA_HOME, + savepath=segenv.DATA_HOME, + extrapath=segenv.DATA_HOME, extraname='ADEChallengeData2016') elif not os.path.exists(self.dataset_root): raise Exception('there is not `dataset_root`: {}.'.format( diff --git a/dygraph/paddleseg/datasets/optic_disc_seg.py b/dygraph/paddleseg/datasets/optic_disc_seg.py index 6c1dedde..d9e0a21e 100644 --- a/dygraph/paddleseg/datasets/optic_disc_seg.py +++ b/dygraph/paddleseg/datasets/optic_disc_seg.py @@ -14,12 +14,12 @@ import os +import paddleseg.env as segenv from .dataset import Dataset from paddleseg.utils.download import download_file_and_uncompress from paddleseg.cvlibs import manager from paddleseg.transforms import Compose -DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip" @@ -49,7 +49,7 @@ class OpticDiscSeg(Dataset): raise Exception( "`data_root` not set and auto download disabled.") self.dataset_root = download_file_and_uncompress( - url=URL, savepath=DATA_HOME, extrapath=DATA_HOME) + url=URL, savepath=segenv.DATA_HOME, extrapath=segenv.DATA_HOME) elif not os.path.exists(self.dataset_root): raise Exception('there is not `dataset_root`: {}.'.format( self.dataset_root)) diff --git a/dygraph/paddleseg/datasets/voc.py b/dygraph/paddleseg/datasets/voc.py index c6ac4b6a..6c1c488c 100644 --- a/dygraph/paddleseg/datasets/voc.py +++ b/dygraph/paddleseg/datasets/voc.py @@ -14,12 +14,12 @@ import os +import paddleseg.env as segenv from .dataset import Dataset from paddleseg.utils.download import download_file_and_uncompress from paddleseg.cvlibs import manager from paddleseg.transforms import Compose -DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar" @@ -59,8 +59,8 @@ class PascalVOC(Dataset): "`dataset_root` not set and auto download disabled.") self.dataset_root = download_file_and_uncompress( url=URL, - savepath=DATA_HOME, - extrapath=DATA_HOME, + savepath=segenv.DATA_HOME, + extrapath=segenv.DATA_HOME, extraname='VOCdevkit') elif not os.path.exists(self.dataset_root): raise Exception('there is not `dataset_root`: {}.'.format( diff --git a/dygraph/paddleseg/env.py b/dygraph/paddleseg/env.py new file mode 100644 index 00000000..3bb0930a --- /dev/null +++ b/dygraph/paddleseg/env.py @@ -0,0 +1,50 @@ +# coding:utf-8 +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil + +from paddleseg.utils import logger + + +def _get_user_home(): + return os.path.expanduser('~') + + +def _get_seg_home(): + if 'SEG_HOME' in os.environ: + home_path = os.environ['SEG_HOME'] + if os.path.exists(home_path): + if os.path.isdir(home_path): + return home_path + else: + logger.warning('SEG_HOME {} is a file!'.format(home_path)) + else: + return home_path + return os.path.join(_get_user_home(), '.paddleseg') + + +def _get_sub_home(directory): + home = os.path.join(_get_seg_home(), directory) + if not os.path.exists(home): + os.makedirs(home) + return home + + +USER_HOME = _get_user_home() +SEG_HOME = _get_seg_home() +DATA_HOME = _get_sub_home('dataset') +TMP_HOME = _get_sub_home('tmp') +PRETRAINED_MODEL_HOME = _get_sub_home('pretrained_model') diff --git a/dygraph/paddleseg/utils/utils.py b/dygraph/paddleseg/utils/utils.py index 0b7d8716..02f7d3b7 100644 --- a/dygraph/paddleseg/utils/utils.py +++ b/dygraph/paddleseg/utils/utils.py @@ -12,13 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import os import numpy as np import math import cv2 +import tempfile import paddle.fluid as fluid +from urllib.parse import urlparse, unquote -from . import logger +import filelock + +import paddleseg.env as segenv +from paddleseg.utils import logger +from paddleseg.utils.download import download_file_and_uncompress + + +@contextlib.contextmanager +def generate_tempdir(directory: str = None, **kwargs): + '''Generate a temporary directory''' + directory = segenv.TMP_HOME if not directory else directory + with tempfile.TemporaryDirectory(dir=directory, **kwargs) as _dir: + yield _dir def seconds_to_hms(seconds): @@ -32,6 +47,18 @@ def seconds_to_hms(seconds): def load_pretrained_model(model, pretrained_model): if pretrained_model is not None: logger.info('Load pretrained model from {}'.format(pretrained_model)) + # download pretrained model from url + if urlparse(pretrained_model).netloc: + pretrained_model = unquote(pretrained_model) + savename = pretrained_model.split('/')[-1].split('.')[0] + with generate_tempdir() as _dir: + with filelock.FileLock(os.path.join(segenv.TMP_HOME, savename)): + pretrained_model = download_file_and_uncompress( + pretrained_model, + savepath=_dir, + extrapath=segenv.PRETRAINED_MODEL_HOME, + extraname=savename) + if os.path.exists(pretrained_model): ckpt_path = os.path.join(pretrained_model, 'model') try: diff --git a/dygraph/train.py b/dygraph/train.py index caa95833..861c301f 100644 --- a/dygraph/train.py +++ b/dygraph/train.py @@ -112,9 +112,10 @@ def main(args): val_dataset = cfg.val_dataset if args.do_eval else None losses = cfg.loss - print('---------------Config Information---------------') - print(cfg) - print('------------------------------------------------') + msg = '\n---------------Config Information---------------\n' + msg += str(cfg) + msg += '------------------------------------------------' + logger.info(msg) train( cfg.model, diff --git a/dygraph/val.py b/dygraph/val.py index 46ae4ffb..4d168dbf 100644 --- a/dygraph/val.py +++ b/dygraph/val.py @@ -19,7 +19,7 @@ from paddle.distributed import ParallelEnv import paddleseg from paddleseg.cvlibs import manager -from paddleseg.utils import get_environ_info, Config +from paddleseg.utils import get_environ_info, Config, logger from paddleseg.core import evaluate @@ -56,9 +56,10 @@ def main(args): 'The verification dataset is not specified in the configuration file.' ) - print('---------------Config Information---------------') - print(cfg) - print('------------------------------------------------') + msg = '\n---------------Config Information---------------\n' + msg += str(cfg) + msg += '------------------------------------------------' + logger.info(msg) evaluate( cfg.model, -- GitLab