提交 c4da736d 编写于 作者: M michaelowenliu

add load_entire_model

上级 4b5665d0
...@@ -12,13 +12,28 @@ ...@@ -12,13 +12,28 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib
import os import os
import numpy as np import numpy as np
import math import math
import cv2 import cv2
import tempfile
import paddle.fluid as fluid import paddle.fluid as fluid
from urllib.parse import urlparse, unquote
from . import logger import filelock
import paddleseg.env as segenv
from paddleseg.utils import logger
from paddleseg.utils.download import download_file_and_uncompress
@contextlib.contextmanager
def generate_tempdir(directory: str = None, **kwargs):
'''Generate a temporary directory'''
directory = segenv.TMP_HOME if not directory else directory
with tempfile.TemporaryDirectory(dir=directory, **kwargs) as _dir:
yield _dir
def seconds_to_hms(seconds): def seconds_to_hms(seconds):
...@@ -38,12 +53,25 @@ def load_entire_model(model, pretrained): ...@@ -38,12 +53,25 @@ def load_entire_model(model, pretrained):
raise Exception('Pretrained model is not found: {}'.format( raise Exception('Pretrained model is not found: {}'.format(
pretrained)) pretrained))
else: else:
logger.warning('Not all pretrained parameters of {} to load, '\ logger.warning('Not all pretrained params of {} to load, '\
'training from scratch or a pretrained backbone'.format(model.__class__.__name__)) 'training from scratch or a pretrained backbone'.format(model.__class__.__name__))
def load_pretrained_model(model, pretrained_model): def load_pretrained_model(model, pretrained_model):
if pretrained_model is not None: if pretrained_model is not None:
logger.info('Load pretrained model from {}'.format(pretrained_model)) logger.info('Load pretrained model from {}'.format(pretrained_model))
# download pretrained model from url
if urlparse(pretrained_model).netloc:
pretrained_model = unquote(pretrained_model)
savename = pretrained_model.split('/')[-1].split('.')[0]
with generate_tempdir() as _dir:
with filelock.FileLock(os.path.join(segenv.TMP_HOME, savename)):
pretrained_model = download_file_and_uncompress(
pretrained_model,
savepath=_dir,
extrapath=segenv.PRETRAINED_MODEL_HOME,
extraname=savename)
if os.path.exists(pretrained_model): if os.path.exists(pretrained_model):
ckpt_path = os.path.join(pretrained_model, 'model') ckpt_path = os.path.join(pretrained_model, 'model')
try: try:
...@@ -67,7 +95,7 @@ def load_pretrained_model(model, pretrained_model): ...@@ -67,7 +95,7 @@ def load_pretrained_model(model, pretrained_model):
model_state_dict[k] = para_state_dict[k] model_state_dict[k] = para_state_dict[k]
num_params_loaded += 1 num_params_loaded += 1
model.set_dict(model_state_dict) model.set_dict(model_state_dict)
logger.info("There are {}/{} varaibles are loaded.".format( logger.info("There are {}/{} variables are loaded.".format(
num_params_loaded, len(model_state_dict))) num_params_loaded, len(model_state_dict)))
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册