未验证 提交 4d78bac7 编写于 作者: S SunGaofeng 提交者: GitHub

Merge pull request #1757 from heavengate/fix_tsn_pretrain

fix tsn pretrain load ResNet50
......@@ -143,6 +143,7 @@ class ModelBase(object):
return path
def load_pretrain_params(self, exe, pretrain, prog, place):
logger.info("Load pretrain weights from {}".format(pretrain))
fluid.io.load_params(exe, pretrain, main_program=prog)
def get_config_from_sec(self, sec, item, default=None):
......
......@@ -17,6 +17,9 @@ import paddle.fluid as fluid
from ..model import ModelBase
from .stnet_res_model import StNet_ResNet
import logging
logger = logging.getLogger(__name__)
__all__ = ["STNET"]
......@@ -131,6 +134,7 @@ class STNET(ModelBase):
return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) \
and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name))
logger.info("Load pretrain weights from {}, exclude fc, batch_norm, xception, conv3d layers.".format(pretrain))
vars = filter(is_parameter, prog.list_vars())
fluid.io.load_vars(exe, pretrain, vars=vars, main_program=prog)
......
......@@ -18,6 +18,9 @@ from paddle.fluid import ParamAttr
from ..model import ModelBase
from .tsn_res_model import TSN_ResNet
import logging
logger = logging.getLogger(__name__)
__all__ = ["TSN"]
......@@ -125,3 +128,15 @@ class TSN(ModelBase):
return self.feature_input if self.mode == 'infer' else self.feature_input + [
self.label_input
]
def pretrain_info(self):
return ('ResNet50_pretrained', 'https://paddlemodels.bj.bcebos.com/video_classification/ResNet50_pretrained.tar.gz')
def load_pretrain_params(self, exe, pretrain, prog, place):
def is_parameter(var):
return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name))
logger.info("Load pretrain weights from {}, exclude fc layer.".format(pretrain))
vars = filter(is_parameter, prog.list_vars())
fluid.io.load_vars(exe, pretrain, vars=vars, main_program=prog)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册