未验证 提交 9bce908f 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #1811 from heavengate/add_weight_info

add weights_info for lstm/tsn/stnet
......@@ -147,4 +147,5 @@ class AttentionLSTM(ModelBase):
]
def weights_info(self):
return (None, None)
return ('attention_lstm_youtube8m',
'https://paddlemodels.bj.bcebos.com/video_classification/attention_lstm_youtube8m.tar.gz')
......@@ -128,6 +128,10 @@ class STNET(ModelBase):
def pretrain_info(self):
return ('ResNet50_pretrained', 'https://paddlemodels.bj.bcebos.com/video_classification/ResNet50_pretrained.tar.gz')
def weights_info(self):
return ('stnet_kinetics',
'https://paddlemodels.bj.bcebos.com/video_classification/stnet_kinetics.tar.gz')
def load_pretrain_params(self, exe, pretrain, prog, place):
def is_parameter(var):
if isinstance(var, fluid.framework.Parameter):
......
......@@ -132,6 +132,10 @@ class TSN(ModelBase):
def pretrain_info(self):
return ('ResNet50_pretrained', 'https://paddlemodels.bj.bcebos.com/video_classification/ResNet50_pretrained.tar.gz')
def weights_info(self):
return ('tsn_kinetics',
'https://paddlemodels.bj.bcebos.com/video_classification/tsn_kinetics.tar.gz')
def load_pretrain_params(self, exe, pretrain, prog, place):
def is_parameter(var):
return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册