提交 c4da736d 编写于 作者: M michaelowenliu

add load_entire_model

上级 4b5665d0
......@@ -12,13 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
import numpy as np
import math
import cv2
import tempfile
import paddle.fluid as fluid
from urllib.parse import urlparse, unquote
from . import logger
import filelock
import paddleseg.env as segenv
from paddleseg.utils import logger
from paddleseg.utils.download import download_file_and_uncompress
@contextlib.contextmanager
def generate_tempdir(directory: str = None, **kwargs):
'''Generate a temporary directory'''
directory = segenv.TMP_HOME if not directory else directory
with tempfile.TemporaryDirectory(dir=directory, **kwargs) as _dir:
yield _dir
def seconds_to_hms(seconds):
......@@ -38,12 +53,25 @@ def load_entire_model(model, pretrained):
raise Exception('Pretrained model is not found: {}'.format(
pretrained))
else:
logger.warning('Not all pretrained parameters of {} to load, '\
logger.warning('Not all pretrained params of {} to load, '\
'training from scratch or a pretrained backbone'.format(model.__class__.__name__))
def load_pretrained_model(model, pretrained_model):
if pretrained_model is not None:
logger.info('Load pretrained model from {}'.format(pretrained_model))
# download pretrained model from url
if urlparse(pretrained_model).netloc:
pretrained_model = unquote(pretrained_model)
savename = pretrained_model.split('/')[-1].split('.')[0]
with generate_tempdir() as _dir:
with filelock.FileLock(os.path.join(segenv.TMP_HOME, savename)):
pretrained_model = download_file_and_uncompress(
pretrained_model,
savepath=_dir,
extrapath=segenv.PRETRAINED_MODEL_HOME,
extraname=savename)
if os.path.exists(pretrained_model):
ckpt_path = os.path.join(pretrained_model, 'model')
try:
......@@ -67,7 +95,7 @@ def load_pretrained_model(model, pretrained_model):
model_state_dict[k] = para_state_dict[k]
num_params_loaded += 1
model.set_dict(model_state_dict)
logger.info("There are {}/{} varaibles are loaded.".format(
logger.info("There are {}/{} variables are loaded.".format(
num_params_loaded, len(model_state_dict)))
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册