提交 8e0e4e39 编写于 作者: M michaelowenliu

add load_entire_model

上级 9efa0289
......@@ -12,28 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
import numpy as np
import math
import cv2
import tempfile
import paddle.fluid as fluid
from urllib.parse import urlparse, unquote
import filelock
import paddleseg.env as segenv
from paddleseg.utils import logger
from paddleseg.utils.download import download_file_and_uncompress
@contextlib.contextmanager
def generate_tempdir(directory: str = None, **kwargs):
'''Generate a temporary directory'''
directory = segenv.TMP_HOME if not directory else directory
with tempfile.TemporaryDirectory(dir=directory, **kwargs) as _dir:
yield _dir
from . import logger
def seconds_to_hms(seconds):
......@@ -44,21 +29,21 @@ def seconds_to_hms(seconds):
return hms_str
def load_entire_model(model, pretrained):
if pretrained is not None:
if os.path.exists(pretrained):
load_pretrained_model(model, pretrained)
else:
raise Exception('Pretrained model is not found: {}'.format(
pretrained))
else:
logger.warning('Not all pretrained parameters of {} to load, '\
'training from scratch or a pretrained backbone'.format(model.__class__.__name__))
def load_pretrained_model(model, pretrained_model):
if pretrained_model is not None:
logger.info('Load pretrained model from {}'.format(pretrained_model))
# download pretrained model from url
if urlparse(pretrained_model).netloc:
pretrained_model = unquote(pretrained_model)
savename = pretrained_model.split('/')[-1].split('.')[0]
with generate_tempdir() as _dir:
with filelock.FileLock(os.path.join(segenv.TMP_HOME, savename)):
pretrained_model = download_file_and_uncompress(
pretrained_model,
savepath=_dir,
extrapath=segenv.PRETRAINED_MODEL_HOME,
extraname=savename)
if os.path.exists(pretrained_model):
ckpt_path = os.path.join(pretrained_model, 'model')
try:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册