提交 3ab90378 编写于 作者: D dengkaipeng

fix tsn pretrain load ResNet50

上级 921a0ef3
......@@ -125,3 +125,14 @@ class TSN(ModelBase):
return self.feature_input if self.mode == 'infer' else self.feature_input + [
self.label_input
]
def pretrain_info(self):
return ('ResNet50_pretrained', 'https://paddlemodels.bj.bcebos.com/video_classification/ResNet50_pretrained.tar.gz')
def load_pretrain_params(self, exe, pretrain, prog, place):
def is_parameter(var):
return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name))
vars = filter(is_parameter, prog.list_vars())
fluid.io.load_vars(exe, pretrain, vars=vars, main_program=prog)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册