提交 b30fd0df 编写于 作者: D dengkaipeng

add logger in load_pretrain_weights

上级 3ab90378
......@@ -143,6 +143,7 @@ class ModelBase(object):
return path
def load_pretrain_params(self, exe, pretrain, prog, place):
logger.info("Load pretrain weights from {}".format(pretrain))
fluid.io.load_params(exe, pretrain, main_program=prog)
def get_config_from_sec(self, sec, item, default=None):
......
......@@ -17,6 +17,9 @@ import paddle.fluid as fluid
from ..model import ModelBase
from .stnet_res_model import StNet_ResNet
import logging
logger = logging.getLogger(__name__)
__all__ = ["STNET"]
......@@ -131,6 +134,7 @@ class STNET(ModelBase):
return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) \
and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name))
logger.info("Load pretrain weights from {}, exclude fc, batch_norm, xception, conv3d layers.".format(pretrain))
vars = filter(is_parameter, prog.list_vars())
fluid.io.load_vars(exe, pretrain, vars=vars, main_program=prog)
......
......@@ -18,6 +18,9 @@ from paddle.fluid import ParamAttr
from ..model import ModelBase
from .tsn_res_model import TSN_ResNet
import logging
logger = logging.getLogger(__name__)
__all__ = ["TSN"]
......@@ -133,6 +136,7 @@ class TSN(ModelBase):
def is_parameter(var):
return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name))
logger.info("Load pretrain weights from {}, exclude fc layer.".format(pretrain))
vars = filter(is_parameter, prog.list_vars())
fluid.io.load_vars(exe, pretrain, vars=vars, main_program=prog)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册