未验证 提交 4d78bac7 编写于 作者: S SunGaofeng 提交者: GitHub

Merge pull request #1757 from heavengate/fix_tsn_pretrain

fix tsn pretrain load ResNet50
...@@ -143,6 +143,7 @@ class ModelBase(object): ...@@ -143,6 +143,7 @@ class ModelBase(object):
return path return path
def load_pretrain_params(self, exe, pretrain, prog, place): def load_pretrain_params(self, exe, pretrain, prog, place):
logger.info("Load pretrain weights from {}".format(pretrain))
fluid.io.load_params(exe, pretrain, main_program=prog) fluid.io.load_params(exe, pretrain, main_program=prog)
def get_config_from_sec(self, sec, item, default=None): def get_config_from_sec(self, sec, item, default=None):
......
...@@ -17,6 +17,9 @@ import paddle.fluid as fluid ...@@ -17,6 +17,9 @@ import paddle.fluid as fluid
from ..model import ModelBase from ..model import ModelBase
from .stnet_res_model import StNet_ResNet from .stnet_res_model import StNet_ResNet
import logging
logger = logging.getLogger(__name__)
__all__ = ["STNET"] __all__ = ["STNET"]
...@@ -131,6 +134,7 @@ class STNET(ModelBase): ...@@ -131,6 +134,7 @@ class STNET(ModelBase):
return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) \ return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) \
and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name)) and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name))
logger.info("Load pretrain weights from {}, exclude fc, batch_norm, xception, conv3d layers.".format(pretrain))
vars = filter(is_parameter, prog.list_vars()) vars = filter(is_parameter, prog.list_vars())
fluid.io.load_vars(exe, pretrain, vars=vars, main_program=prog) fluid.io.load_vars(exe, pretrain, vars=vars, main_program=prog)
......
...@@ -18,6 +18,9 @@ from paddle.fluid import ParamAttr ...@@ -18,6 +18,9 @@ from paddle.fluid import ParamAttr
from ..model import ModelBase from ..model import ModelBase
from .tsn_res_model import TSN_ResNet from .tsn_res_model import TSN_ResNet
import logging
logger = logging.getLogger(__name__)
__all__ = ["TSN"] __all__ = ["TSN"]
...@@ -125,3 +128,15 @@ class TSN(ModelBase): ...@@ -125,3 +128,15 @@ class TSN(ModelBase):
return self.feature_input if self.mode == 'infer' else self.feature_input + [ return self.feature_input if self.mode == 'infer' else self.feature_input + [
self.label_input self.label_input
] ]
def pretrain_info(self):
return ('ResNet50_pretrained', 'https://paddlemodels.bj.bcebos.com/video_classification/ResNet50_pretrained.tar.gz')
def load_pretrain_params(self, exe, pretrain, prog, place):
def is_parameter(var):
return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name))
logger.info("Load pretrain weights from {}, exclude fc layer.".format(pretrain))
vars = filter(is_parameter, prog.list_vars())
fluid.io.load_vars(exe, pretrain, vars=vars, main_program=prog)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册