From 20aab067b97597ac77d2d9ec4f830299dd67206b Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Fri, 1 Mar 2019 10:23:03 +0000 Subject: [PATCH] add weights_info for lstm/tsn/stnet --- fluid/PaddleCV/video/models/attention_lstm/attention_lstm.py | 3 ++- fluid/PaddleCV/video/models/stnet/stnet.py | 4 ++++ fluid/PaddleCV/video/models/tsn/tsn.py | 4 ++++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/fluid/PaddleCV/video/models/attention_lstm/attention_lstm.py b/fluid/PaddleCV/video/models/attention_lstm/attention_lstm.py index 88bb6f33..5d28dc47 100755 --- a/fluid/PaddleCV/video/models/attention_lstm/attention_lstm.py +++ b/fluid/PaddleCV/video/models/attention_lstm/attention_lstm.py @@ -147,4 +147,5 @@ class AttentionLSTM(ModelBase): ] def weights_info(self): - return (None, None) + return ('attention_lstm_youtube8m', + 'https://paddlemodels.bj.bcebos.com/video_classification/attention_lstm_youtube8m.tar.gz') diff --git a/fluid/PaddleCV/video/models/stnet/stnet.py b/fluid/PaddleCV/video/models/stnet/stnet.py index e20ad0bd..c408aa08 100644 --- a/fluid/PaddleCV/video/models/stnet/stnet.py +++ b/fluid/PaddleCV/video/models/stnet/stnet.py @@ -128,6 +128,10 @@ class STNET(ModelBase): def pretrain_info(self): return ('ResNet50_pretrained', 'https://paddlemodels.bj.bcebos.com/video_classification/ResNet50_pretrained.tar.gz') + def weights_info(self): + return ('stnet_kinetics', + 'https://paddlemodels.bj.bcebos.com/video_classification/stnet_kinetics.tar.gz') + def load_pretrain_params(self, exe, pretrain, prog, place): def is_parameter(var): if isinstance(var, fluid.framework.Parameter): diff --git a/fluid/PaddleCV/video/models/tsn/tsn.py b/fluid/PaddleCV/video/models/tsn/tsn.py index 5bc8aba3..82fdb327 100644 --- a/fluid/PaddleCV/video/models/tsn/tsn.py +++ b/fluid/PaddleCV/video/models/tsn/tsn.py @@ -132,6 +132,10 @@ class TSN(ModelBase): def pretrain_info(self): return ('ResNet50_pretrained', 'https://paddlemodels.bj.bcebos.com/video_classification/ResNet50_pretrained.tar.gz') + def weights_info(self): + return ('tsn_kinetics', + 'https://paddlemodels.bj.bcebos.com/video_classification/tsn_kinetics.tar.gz') + def load_pretrain_params(self, exe, pretrain, prog, place): def is_parameter(var): return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) -- GitLab