diff --git a/dygraph/paddleseg/utils/utils.py b/dygraph/paddleseg/utils/utils.py index 02f7d3b7f9b9f3c8f2674e8ee0dcf7f52730b236..8b4a731ae0312b9eb1748f51a6046db834519d3e 100644 --- a/dygraph/paddleseg/utils/utils.py +++ b/dygraph/paddleseg/utils/utils.py @@ -12,28 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import os import numpy as np import math import cv2 -import tempfile import paddle.fluid as fluid -from urllib.parse import urlparse, unquote -import filelock - -import paddleseg.env as segenv -from paddleseg.utils import logger -from paddleseg.utils.download import download_file_and_uncompress - - -@contextlib.contextmanager -def generate_tempdir(directory: str = None, **kwargs): - '''Generate a temporary directory''' - directory = segenv.TMP_HOME if not directory else directory - with tempfile.TemporaryDirectory(dir=directory, **kwargs) as _dir: - yield _dir +from . import logger def seconds_to_hms(seconds): @@ -44,21 +29,21 @@ def seconds_to_hms(seconds): return hms_str +def load_entire_model(model, pretrained): + + if pretrained is not None: + if os.path.exists(pretrained): + load_pretrained_model(model, pretrained) + else: + raise Exception('Pretrained model is not found: {}'.format( + pretrained)) + else: + logger.warning('Not all pretrained parameters of {} to load, '\ + 'training from scratch or a pretrained backbone'.format(model.__class__.__name__)) + def load_pretrained_model(model, pretrained_model): if pretrained_model is not None: logger.info('Load pretrained model from {}'.format(pretrained_model)) - # download pretrained model from url - if urlparse(pretrained_model).netloc: - pretrained_model = unquote(pretrained_model) - savename = pretrained_model.split('/')[-1].split('.')[0] - with generate_tempdir() as _dir: - with filelock.FileLock(os.path.join(segenv.TMP_HOME, savename)): - pretrained_model = download_file_and_uncompress( - pretrained_model, - savepath=_dir, - extrapath=segenv.PRETRAINED_MODEL_HOME, - extraname=savename) - if os.path.exists(pretrained_model): ckpt_path = os.path.join(pretrained_model, 'model') try: