Unverified commit 9fdbfa49, authored by K kinghuin, committed by GitHub

rm wget dependency, del commented code, add unittest (#613)

* rm wget dependency, del commented code
Parent 5aa513e3
......@@ -51,7 +51,7 @@ print(results)
## Dependencies
paddlepaddle >= 1.6.2
paddlepaddle >= 1.7.2
paddlehub >= 1.6.0
......
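The dependency bump above means the module now expects paddlepaddle >= 1.7.2 alongside paddlehub >= 1.6.0. A minimal sketch for confirming the installed versions; nothing here is part of the diff, it is only an illustration:

```python
import paddle
import pkg_resources

# The README above now requires paddlepaddle >= 1.7.2 and paddlehub >= 1.6.0;
# printing the installed versions makes a mismatch easy to spot.
print("paddlepaddle:", paddle.__version__)
print("paddlehub:", pkg_resources.get_distribution("paddlehub").version)
```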
......@@ -17,12 +17,11 @@ from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
import os
import io
import logging
import numpy as np
import json
from videotag_tsn_lstm.resource.metrics.youtube8m import eval_util as youtube8m_metrics
logger = logging.getLogger(__name__)
......
......@@ -12,7 +12,7 @@
#See the License for the specific language governing permissions and
#limitations under the License.
import numpy as np
import logging
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
......@@ -20,10 +20,8 @@ from paddle.fluid import ParamAttr
from ..model import ModelBase
from .lstm_attention import LSTMAttentionModel
import logging
logger = logging.getLogger(__name__)
__all__ = ["AttentionLSTM"]
logger = logging.getLogger(__name__)
class AttentionLSTM(ModelBase):
......@@ -51,7 +49,6 @@ class AttentionLSTM(ModelBase):
self.feature_input.append(
fluid.data(
shape=[None, dim], lod_level=1, dtype='float32', name=name))
# self.label_input = None
if use_dataloader:
assert self.mode != 'infer', \
'dataloader is not recommended for inference, please set use_dataloader to False.'
......@@ -138,15 +135,6 @@ class AttentionLSTM(ModelBase):
)
def load_pretrain_params(self, exe, pretrain, prog, place):
#def is_parameter(var):
# return isinstance(var, fluid.framework.Parameter)
#params_list = list(filter(is_parameter, prog.list_vars()))
#for param in params_list:
# print(param.name)
#assert False, "stop here"
logger.info(
"Load pretrain weights from {}, exclude fc layer.".format(pretrain))
......@@ -159,18 +147,3 @@ class AttentionLSTM(ModelBase):
'Delete {} from pretrained parameters. Do not load it'.
format(name))
fluid.set_program_state(prog, state_dict)
# def load_test_weights(self, exe, weights, prog):
# def is_parameter(var):
# return isinstance(var, fluid.framework.Parameter)
# params_list = list(filter(is_parameter, prog.list_vars()))
# state_dict = np.load(weights)
# for p in params_list:
# if p.name in state_dict.keys():
# logger.info('########### load param {} from file'.format(p.name))
# else:
# logger.info('----------- param {} not in file'.format(p.name))
# fluid.set_program_state(prog, state_dict)
# fluid.save(prog, './weights/attention_lstm')
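For readers following the `load_pretrain_params` change above: the kept code loads a saved program state and drops selected entries before applying it. A rough sketch of that pattern, assuming a hypothetical `pretrain_dir` and an illustrative name filter rather than the module's exact exclusion rule:

```python
import paddle.fluid as fluid

def load_pretrain_excluding_fc(prog, pretrain_dir):
    # Read every saved parameter from the pretrained checkpoint directory.
    state_dict = fluid.load_program_state(pretrain_dir)
    # Drop parameters that should not be restored, e.g. a final fc layer
    # whose shape depends on the downstream label set.
    for name in list(state_dict.keys()):
        if name.startswith("fc"):
            del state_dict[name]
    # Assign the remaining parameters into the target program.
    fluid.set_program_state(prog, state_dict)
```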
......@@ -11,10 +11,8 @@
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
import numpy as np
class LSTMAttentionModel(object):
......@@ -39,15 +37,6 @@ class LSTMAttentionModel(object):
initializer=fluid.initializer.NormalInitializer(scale=0.0)),
name='rgb_fc')
#lstm_forward_fc = fluid.layers.fc(
# input=input_fc,
# size=self.lstm_size * 4,
# act=None,
# bias_attr=ParamAttr(
# regularizer=fluid.regularizer.L2Decay(0.0),
# initializer=fluid.initializer.NormalInitializer(scale=0.0)),
# name='rgb_fc_forward')
lstm_forward_fc = fluid.layers.fc(
input=input_fc,
size=self.lstm_size * 4,
......@@ -61,15 +50,6 @@ class LSTMAttentionModel(object):
is_reverse=False,
name='rgb_lstm_forward')
#lsmt_backward_fc = fluid.layers.fc(
# input=input_fc,
# size=self.lstm_size * 4,
# act=None,
# bias_attr=ParamAttr(
# regularizer=fluid.regularizer.L2Decay(0.0),
# initializer=fluid.initializer.NormalInitializer(scale=0.0)),
# name='rgb_fc_backward')
lsmt_backward_fc = fluid.layers.fc(
input=input_fc,
size=self.lstm_size * 4,
......@@ -91,15 +71,6 @@ class LSTMAttentionModel(object):
dropout_prob=self.drop_rate,
is_test=(not is_training))
#lstm_weight = fluid.layers.fc(
# input=lstm_dropout,
# size=1,
# act='sequence_softmax',
# bias_attr=ParamAttr(
# regularizer=fluid.regularizer.L2Decay(0.0),
# initializer=fluid.initializer.NormalInitializer(scale=0.0)),
# name='rgb_weight')
lstm_weight = fluid.layers.fc(
input=lstm_dropout,
size=1,
......
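The `lstm_weight` fc kept above emits one `sequence_softmax` score per timestep. A minimal sketch of how such scores are commonly used to pool the LSTM outputs into a clip-level feature; the variable names are illustrative, not taken from the module:

```python
import paddle.fluid as fluid

def attention_pool(lstm_out, lstm_weight):
    # Scale every timestep of the LSTM output by its attention score
    # (lstm_weight has one scalar per timestep), then sum over the
    # sequence to obtain one fixed-size vector per clip.
    scaled = fluid.layers.elementwise_mul(x=lstm_out, y=lstm_weight, axis=0)
    return fluid.layers.sequence_pool(input=scaled, pool_type='sum')
```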
......@@ -13,7 +13,6 @@
#limitations under the License.
import os
import wget
import logging
try:
from configparser import ConfigParser
......@@ -21,7 +20,6 @@ except:
from ConfigParser import ConfigParser
import paddle.fluid as fluid
from .utils import download, AttrDict
WEIGHT_DIR = os.path.join(os.path.expanduser('~'), '.paddle', 'weights')
......@@ -103,21 +101,6 @@ class ModelBase(object):
"get model weight default path and download url"
raise NotImplementedError(self, self.weights_info)
def get_weights(self):
"get model weight file path, download weight from Paddle if not exist"
path, url = self.weights_info()
path = os.path.join(WEIGHT_DIR, path)
if not os.path.isdir(WEIGHT_DIR):
logger.info('{} not exists, will be created automatically.'.format(
WEIGHT_DIR))
os.makedirs(WEIGHT_DIR)
if os.path.exists(path):
return path
logger.info("Download weights of {} from {}".format(self.name, url))
wget.download(url, path)
return path
def dataloader(self):
return self.dataloader
......@@ -129,25 +112,6 @@ class ModelBase(object):
"get pretrain base model directory"
return (None, None)
def get_pretrain_weights(self):
"get model weight file path, download weight from Paddle if not exist"
path, url = self.pretrain_info()
if not path:
return None
path = os.path.join(WEIGHT_DIR, path)
if not os.path.isdir(WEIGHT_DIR):
logger.info('{} not exists, will be created automatically.'.format(
WEIGHT_DIR))
os.makedirs(WEIGHT_DIR)
if os.path.exists(path):
return path
logger.info("Download pretrain weights of {} from {}".format(
self.name, url))
download(url, path)
return path
def load_pretrain_params(self, exe, pretrain, prog, place):
logger.info("Load pretrain weights from {}".format(pretrain))
state_dict = fluid.load_program_state(pretrain)
......
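The deleted `get_weights`/`get_pretrain_weights` helpers above both follow the same cache-then-download pattern. A generic sketch of that pattern; `WEIGHT_DIR` mirrors the constant in the diff, the other names are illustrative:

```python
import os

WEIGHT_DIR = os.path.join(os.path.expanduser('~'), '.paddle', 'weights')

def cached_weights(filename, url, download_fn, logger=None):
    # Reuse a previously downloaded weight file if it is already on disk;
    # otherwise create the cache directory and fetch it once.
    path = os.path.join(WEIGHT_DIR, filename)
    if os.path.exists(path):
        return path
    os.makedirs(WEIGHT_DIR, exist_ok=True)
    if logger:
        logger.info("Downloading weights from {}".format(url))
    download_fn(url, path)
    return path
```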
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import wget
import tarfile
__all__ = ['decompress', 'download', 'AttrDict']
def decompress(path):
t = tarfile.open(path)
t.extractall(path=os.path.split(path)[0])
t.close()
os.remove(path)
def download(url, path):
weight_dir = os.path.split(path)[0]
if not os.path.exists(weight_dir):
os.makedirs(weight_dir)
path = path + ".tar.gz"
wget.download(url, path)
decompress(path)
class AttrDict(dict):
def __getattr__(self, key):
return self[key]
def __setattr__(self, key, value):
if key in self.__dict__:
self.__dict__[key] = value
else:
self[key] = value
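Since the point of the commit is to drop the `wget` package, the download-and-decompress step in `utils.py` can be done with the standard library instead. A minimal sketch of one way to do it, not necessarily the replacement the module actually uses:

```python
import os
import tarfile
import urllib.request

def download_and_extract(url, path):
    # Fetch the .tar.gz archive with urllib instead of wget, unpack it next
    # to the target path, then delete the archive, mirroring decompress().
    archive = path + ".tar.gz"
    os.makedirs(os.path.dirname(archive), exist_ok=True)
    urllib.request.urlretrieve(url, archive)
    with tarfile.open(archive) as t:
        t.extractall(path=os.path.dirname(archive))
    os.remove(archive)
```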
......@@ -11,23 +11,22 @@
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import cv2
import math
import random
import functools
import logging
try:
import cPickle as pickle
from cStringIO import StringIO
except ImportError:
import pickle
from io import BytesIO
import numpy as np
import paddle
from PIL import Image, ImageEnhance
import logging
import cv2
import numpy as np
from PIL import Image
from .reader_utils import DataReader
......
......@@ -12,11 +12,6 @@
#See the License for the specific language governing permissions and
#limitations under the License.
import pickle
import cv2
import numpy as np
import random
class ReaderNotFoundError(Exception):
"Error: reader not found"
......
......@@ -12,9 +12,10 @@
#See the License for the specific language governing permissions and
#limitations under the License.
import yaml
from .utility import AttrDict
import logging
from .utility import AttrDict
logger = logging.getLogger(__name__)
CONFIG_SECS = [
......
......@@ -14,13 +14,12 @@
import os
import sys
import logging
import time
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import profiler
import logging
import shutil
logger = logging.getLogger(__name__)
......
......@@ -16,7 +16,7 @@ import os
import sys
import signal
import logging
import paddle
import paddle.fluid as fluid
__all__ = ['AttrDict']
......
# coding=utf-8
import unittest
import paddlehub as hub
class TestVideoTag(unittest.TestCase):
def setUp(self):
"Call setUp() to prepare environment\n"
self.module = hub.Module(name='videotag_tsn_lstm')
self.test_video = [
"../video_dataset/classification/1.mp4",
"../video_dataset/classification/2.mp4"
]
def test_classification(self):
default_expect1 = {
'训练': 0.9771281480789185,
'蹲': 0.9389840960502625,
'杠铃': 0.8554490804672241,
'健身房': 0.8479971885681152
}
default_expect2 = {'舞蹈': 0.8504238724708557}
for use_gpu in [True, False]:
for threshold in [0.5, 0.9]:
for top_k in [10, 1]:
expect1 = {}
expect2 = {}
for key, value in default_expect1.items():
if value >= threshold:
expect1[key] = value
if len(expect1.keys()) >= top_k:
break
for key, value in default_expect2.items():
if value >= threshold:
expect2[key] = value
if len(expect2.keys()) >= top_k:
break
results = self.module.classify(
paths=self.test_video,
use_gpu=use_gpu,
threshold=threshold,
top_k=top_k)
for result in results:
if '1.mp4' in result['path']:
self.assertEqual(result['prediction'], expect1)
else:
self.assertEqual(result['prediction'], expect2)
if __name__ == "__main__":
unittest.main()
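The new test exercises the module's public API, which can also be called directly outside the unittest harness. A minimal sketch; the video path is a placeholder:

```python
import paddlehub as hub

videotag = hub.Module(name="videotag_tsn_lstm")
# classify() takes a list of local video paths and returns one dict per
# video, with 'path' and a 'prediction' mapping of label -> score.
results = videotag.classify(
    paths=["1.mp4"], use_gpu=False, threshold=0.5, top_k=10)
print(results)
```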