未验证 提交 fe44893b 编写于 作者: K kinghuin 提交者: GitHub

add videotag_tsn_lstm (#576)

* add videotag_tsn_lstm
上级 b7e8230f
```shell
$ hub install videotag_tsn_lstm==1.0.0
```
<p align="center">
<img src="https://paddlehub.bj.bcebos.com/model/video/video_classifcation/VideoTag_TSN_AttentionLSTM.png" hspace='10'/> <br />
</p>
具体网络结构可参考论文[TSN](https://arxiv.org/abs/1608.00859)[AttentionLSTM](https://arxiv.org/abs/1503.08909)
## 命令行预测示例
```shell
hub run videotag_tsn_lstm --input_path 1.mp4 --use_gpu False
```
示例文件下载:
* [1.mp4](https://paddlehub.bj.bcebos.com/model/video/video_classifcation/1.mp4)
* [2.mp4](https://paddlehub.bj.bcebos.com/model/video/video_classifcation/2.mp4)
## API
```python
def classification(paths,
use_gpu=False,
threshold=0.5,
top_k=10)
```
用于视频分类预测
**参数**
* paths(list\[str\]):mp4文件路径
* use_gpu(bool):是否使用GPU预测,默认为False
* threshold(float):预测结果阈值,只有预测概率大于阈值的类别会被返回,默认为0.5
* top_k(int): 返回预测结果的前k个,默认为10
**返回**
* results(list\[dict\]): result中的每个元素为对应输入的预测结果,预测单个mp4文件时仅有1个元素。每个预测结果为dict,包含mp4文件路径path及其分类概率。例:
```shell
[{'path': '1.mp4', 'prediction': {'训练': 0.9771281480789185, '蹲': 0.9389840960502625, '杠铃': 0.8554490804672241, '健身房': 0.8479971885681152}}, {'path': '2.mp4', 'prediction': {'舞蹈': 0.8504238724708557}}]
```
**代码示例**
```python
import paddlehub as hub
videotag = hub.Module(name="videotag_tsn_lstm")
# execute predict and print the result
results = videotag.classification(paths=["1.mp4","2.mp4"], use_gpu=True)
for result in results:
print(result)
```
## 依赖
paddlepaddle >= 1.6.2
paddlehub >= 1.6.0
## 更新历史
* 1.0.0
初始发布
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import ast
import os
import paddle.fluid as fluid
import paddlehub as hub
from paddlehub.module.module import moduleinfo, runnable
from paddlehub.common.logger import logger
from videotag_tsn_lstm.resource.utils.config_utils import *
import videotag_tsn_lstm.resource.models as models
from videotag_tsn_lstm.resource.reader import get_reader
from videotag_tsn_lstm.resource.metrics import get_metrics
from videotag_tsn_lstm.resource.utils.utility import check_cuda
from videotag_tsn_lstm.resource.utils.utility import check_version
@moduleinfo(
name="videotag_tsn_lstm",
version="1.0.0",
summary=
"videotag_tsn_lstm is a video classification model, using TSN for feature extraction and AttentionLSTM for classification",
author="paddlepaddle",
author_email="paddle-dev@baidu.com",
type="video/classification",
)
class VideoTag(hub.Module):
def _initialize(self):
# add arg parser
self.parser = argparse.ArgumentParser(
description="Run the videotag_tsn_lstm module.",
prog='hub run videotag_tsn_lstm',
usage='%(prog)s',
add_help=True)
self.parser.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=False,
help='default use gpu.')
self.parser.add_argument(
'--input_path',
type=str,
default=None,
help='path of video data, single video')
self._has_load = False
def _extractor(self, args, exe, place):
extractor_scope = fluid.Scope()
with fluid.scope_guard(extractor_scope):
extractor_startup_prog = fluid.Program()
extractor_main_prog = fluid.Program()
with fluid.program_guard(extractor_main_prog,
extractor_startup_prog):
extractor_config = parse_config(args.extractor_config)
extractor_infer_config = merge_configs(extractor_config,
'infer', vars(args))
# build model
extractor_model = models.get_model(
"TSN", extractor_infer_config, mode='infer')
extractor_model.build_input(use_dataloader=False)
extractor_model.build_model()
extractor_feeds = extractor_model.feeds()
extractor_fetch_list = extractor_model.fetches()
exe.run(extractor_startup_prog)
logger.info('load extractor weights from {}'.format(
args.extractor_weights))
extractor_model.load_test_weights(exe, args.extractor_weights,
extractor_main_prog)
# get reader and metrics
extractor_reader = get_reader("TSN", 'infer',
extractor_infer_config)
extractor_feeder = fluid.DataFeeder(
place=place, feed_list=extractor_feeds)
return extractor_reader, extractor_main_prog, extractor_fetch_list, extractor_feeder, extractor_scope
def _predictor(self, args, exe, place):
predictor_scope = fluid.Scope()
with fluid.scope_guard(predictor_scope):
predictor_startup_prog = fluid.default_startup_program()
predictor_main_prog = fluid.default_main_program()
with fluid.program_guard(predictor_main_prog,
predictor_startup_prog):
# parse config
predictor_config = parse_config(args.predictor_config)
predictor_infer_config = merge_configs(predictor_config,
'infer', vars(args))
predictor_model = models.get_model(
"AttentionLSTM", predictor_infer_config, mode='infer')
predictor_model.build_input(use_dataloader=False)
predictor_model.build_model()
predictor_feeds = predictor_model.feeds()
predictor_outputs = predictor_model.outputs()
exe.run(predictor_startup_prog)
logger.info('load lstm weights from {}'.format(
args.predictor_weights))
predictor_model.load_test_weights(exe, args.predictor_weights,
predictor_main_prog)
predictor_feeder = fluid.DataFeeder(
place=place, feed_list=predictor_feeds)
predictor_fetch_list = predictor_model.fetches()
return predictor_main_prog, predictor_fetch_list, predictor_feeder, predictor_scope
@runnable
def run_cmd(self, argsv):
args = self.parser.parse_args(argsv)
results = self.classification(
paths=[args.input_path], use_gpu=args.use_gpu)
return results
def classification(self, paths, use_gpu=False, threshold=0.5, top_k=10):
"""
API of Classification.
Args:
paths (list[str]): the path of mp4s.
use_gpu (bool): whether to use gpu or not.
threshold (float): the result value >= threshold will be returned.
top_k (int): the top k result will be returned.
Returns:
results (list[dict]): every dict includes the mp4 file path and prediction.
"""
args = self.parser.parse_args([])
# config the args in videotag_tsn_lstm
args.use_gpu = use_gpu
args.filelist = paths
args.topk = top_k
args.threshold = threshold
args.extractor_config = os.path.join(self.directory, 'resource',
'configs', 'tsn.yaml')
args.predictor_config = os.path.join(self.directory, 'resource',
'configs', 'attention_lstm.yaml')
args.extractor_weights = os.path.join(self.directory, 'weights', 'tsn')
args.predictor_weights = os.path.join(self.directory, 'weights',
'attention_lstm')
args.label_file = os.path.join(self.directory, 'resource',
'label_3396.txt')
check_cuda(args.use_gpu)
check_version()
if not self._has_load:
self.place = fluid.CUDAPlace(
0) if args.use_gpu else fluid.CPUPlace()
self.exe = fluid.Executor(self.place)
self.extractor_reader, self.extractor_main_prog, self.extractor_fetch_list, self.extractor_feeder, self.extractor_scope = self._extractor(
args, self.exe, self.place)
self.predictor_main_prog, self.predictor_fetch_list, self.predictor_feeder, self.predictor_scope = self._predictor(
args, self.exe, self.place)
self._has_load = True
feature_list = []
file_list = []
for idx, data in enumerate(self.extractor_reader()):
file_id = [item[-1] for item in data]
feed_data = [item[:-1] for item in data]
feature_out = self.exe.run(
program=self.extractor_main_prog,
fetch_list=self.extractor_fetch_list,
feed=self.extractor_feeder.feed(feed_data),
scope=self.extractor_scope)
feature_list.append(feature_out)
file_list.append(file_id)
logger.info(
'========[Stage 1 Sample {} ] Tsn feature extractor finished======'
.format(idx))
# get AttentionLSTM input from Tsn output
num_frames = 300
predictor_feed_list = []
for i in range(len(feature_list)):
feature_out = feature_list[i]
extractor_feature = feature_out[0]
predictor_feed_data = [[
extractor_feature[0].astype(float)[0:num_frames, :]
]]
predictor_feed_list.append((predictor_feed_data, file_list[i]))
metrics_config = parse_config(args.predictor_config)
metrics_config['MODEL']['topk'] = args.topk
metrics_config['MODEL']['threshold'] = args.threshold
predictor_metrics = get_metrics("AttentionLSTM".upper(), 'infer',
metrics_config)
predictor_metrics.reset()
for idx, data in enumerate(predictor_feed_list):
file_id = data[1]
predictor_feed_data = data[0]
final_outs = self.exe.run(
program=self.predictor_main_prog,
fetch_list=self.predictor_fetch_list,
feed=self.predictor_feeder.feed(predictor_feed_data, ),
scope=self.predictor_scope)
logger.info(
'=======[Stage 2 Sample {} ] AttentionLSTM predict finished========'
.format(idx))
final_result_list = [item for item in final_outs] + [file_id]
predictor_metrics.accumulate(final_result_list)
results = predictor_metrics.finalize_and_log_out(
label_file=args.label_file)
return results
if __name__ == '__main__':
test_module = VideoTag()
print(
test_module.run_cmd(
argsv=['--input_path', "1.mp4", '--use_gpu',
str(False)]))
MODEL:
name: "AttentionLSTM"
dataset: "YouTube-8M"
bone_nework: None
drop_rate: 0.5
feature_num: 2
feature_names: ['rgb']
feature_dims: [2048]
embedding_size: 1024
lstm_size: 512
num_classes: 3396
topk: 10
INFER:
batch_size: 1
MODEL:
name: "TSN"
format: "mp4"
num_classes: 400
seg_num: 3
seglen: 1
image_mean: [0.485, 0.456, 0.406]
image_std: [0.229, 0.224, 0.225]
num_layers: 50
topk: 5
INFER:
seg_num: 300
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 1
kinetics_labels: "./data/kinetics_labels.json"
filelist: "./data/tsn.list"
video_path: "./data/mp4/1.mp4"
single_file: True
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
import os
import io
import logging
import numpy as np
import json
from videotag_tsn_lstm.resource.metrics.youtube8m import eval_util as youtube8m_metrics
logger = logging.getLogger(__name__)
class Metrics(object):
def __init__(self, name, mode, metrics_args):
"""Not implemented"""
pass
def calculate_and_log_out(self, fetch_list, info=''):
"""Not implemented"""
pass
def accumulate(self, fetch_list, info=''):
"""Not implemented"""
pass
def finalize_and_log_out(self, info='', savedir='./'):
"""Not implemented"""
pass
def reset(self):
"""Not implemented"""
pass
class Youtube8mMetrics(Metrics):
def __init__(self, name, mode, metrics_args):
self.name = name
self.mode = mode
self.num_classes = metrics_args['MODEL']['num_classes']
self.topk = metrics_args['MODEL']['topk']
self.threshold = metrics_args['MODEL']['threshold']
self.calculator = youtube8m_metrics.EvaluationMetrics(
self.num_classes, self.topk)
if self.mode == 'infer':
self.infer_results = []
def calculate_and_log_out(self, fetch_list, info=''):
loss = np.mean(np.array(fetch_list[0]))
pred = np.array(fetch_list[1])
label = np.array(fetch_list[2])
hit_at_one = youtube8m_metrics.calculate_hit_at_one(pred, label)
perr = youtube8m_metrics.calculate_precision_at_equal_recall_rate(
pred, label)
gap = youtube8m_metrics.calculate_gap(pred, label)
logger.info(info + ' , loss = {0}, Hit@1 = {1}, PERR = {2}, GAP = {3}'.format(\
'%.6f' % loss, '%.2f' % hit_at_one, '%.2f' % perr, '%.2f' % gap))
def accumulate(self, fetch_list, info=''):
if self.mode == 'infer':
predictions = np.array(fetch_list[0])
video_id = fetch_list[1]
for i in range(len(predictions)):
topk_inds = predictions[i].argsort()[0 - self.topk:]
topk_inds = topk_inds[::-1]
preds = predictions[i][topk_inds]
self.infer_results.append((video_id[i], topk_inds.tolist(),
preds.tolist()))
else:
loss = np.array(fetch_list[0])
pred = np.array(fetch_list[1])
label = np.array(fetch_list[2])
self.calculator.accumulate(loss, pred, label)
def finalize_and_log_out(self, info='', label_file='./label_3396.txt'):
if self.mode == 'infer':
all_res_list = []
for index, item in enumerate(self.infer_results):
video_id = item[0]
f = io.open(label_file, "r", encoding="utf-8")
fl = f.readlines()
res = {}
res["path"] = video_id
res["prediction"] = {}
for i in range(len(item[1])):
class_id = item[1][i]
class_prob = item[2][i]
if class_prob < self.threshold:
continue
class_name = fl[class_id].split('\n')[0]
res["prediction"][class_name] = class_prob
if not res["prediction"]:
logger.warning(
"%s: No prediction exceeds the threshold = %s." %
(video_id, self.threshold))
all_res_list.append(res)
return all_res_list
else:
epoch_info_dict = self.calculator.get()
logger.info(info + '\tavg_hit_at_one: {0},\tavg_perr: {1},\tavg_loss :{2},\taps: {3},\tgap:{4}'\
.format(epoch_info_dict['avg_hit_at_one'], epoch_info_dict['avg_perr'], \
epoch_info_dict['avg_loss'], epoch_info_dict['aps'], epoch_info_dict['gap']))
def reset(self):
self.calculator.clear()
if self.mode == 'infer':
self.infer_results = []
class MetricsZoo(object):
def __init__(self):
self.metrics_zoo = {}
def regist(self, name, metrics):
assert metrics.__base__ == Metrics, "Unknow model type {}".format(
type(metrics))
self.metrics_zoo[name] = metrics
def get(self, name, mode, cfg):
for k, v in self.metrics_zoo.items():
if k == name:
return v(name, mode, cfg)
raise KeyError(name, self.metrics_zoo.keys())
# singleton metrics_zoo
metrics_zoo = MetricsZoo()
def regist_metrics(name, metrics):
metrics_zoo.regist(name, metrics)
def get_metrics(name, mode, cfg):
return metrics_zoo.get(name, mode, cfg)
# sort by alphabet
regist_metrics("ATTENTIONLSTM", Youtube8mMetrics)
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Calculate or keep track of the interpolated average precision.
It provides an interface for calculating interpolated average precision for an
entire list or the top-n ranked items. For the definition of the
(non-)interpolated average precision:
http://trec.nist.gov/pubs/trec15/appendices/CE.MEASURES06.pdf
Example usages:
1) Use it as a static function call to directly calculate average precision for
a short ranked list in the memory.
```
import random
p = np.array([random.random() for _ in xrange(10)])
a = np.array([random.choice([0, 1]) for _ in xrange(10)])
ap = average_precision_calculator.AveragePrecisionCalculator.ap(p, a)
```
2) Use it as an object for long ranked list that cannot be stored in memory or
the case where partial predictions can be observed at a time (Tensorflow
predictions). In this case, we first call the function accumulate many times
to process parts of the ranked list. After processing all the parts, we call
peek_interpolated_ap_at_n.
```
p1 = np.array([random.random() for _ in xrange(5)])
a1 = np.array([random.choice([0, 1]) for _ in xrange(5)])
p2 = np.array([random.random() for _ in xrange(5)])
a2 = np.array([random.choice([0, 1]) for _ in xrange(5)])
# interpolated average precision at 10 using 1000 break points
calculator = average_precision_calculator.AveragePrecisionCalculator(10)
calculator.accumulate(p1, a1)
calculator.accumulate(p2, a2)
ap3 = calculator.peek_ap_at_n()
```
"""
import heapq
import random
import numbers
import numpy
class AveragePrecisionCalculator(object):
"""Calculate the average precision and average precision at n."""
def __init__(self, top_n=None):
"""Construct an AveragePrecisionCalculator to calculate average precision.
This class is used to calculate the average precision for a single label.
Args:
top_n: A positive Integer specifying the average precision at n, or
None to use all provided data points.
Raises:
ValueError: An error occurred when the top_n is not a positive integer.
"""
if not ((isinstance(top_n, int) and top_n >= 0) or top_n is None):
raise ValueError("top_n must be a positive integer or None.")
self._top_n = top_n # average precision at n
self._total_positives = 0 # total number of positives have seen
self._heap = [] # max heap of (prediction, actual)
@property
def heap_size(self):
"""Gets the heap size maintained in the class."""
return len(self._heap)
@property
def num_accumulated_positives(self):
"""Gets the number of positive samples that have been accumulated."""
return self._total_positives
def accumulate(self, predictions, actuals, num_positives=None):
"""Accumulate the predictions and their ground truth labels.
After the function call, we may call peek_ap_at_n to actually calculate
the average precision.
Note predictions and actuals must have the same shape.
Args:
predictions: a list storing the prediction scores.
actuals: a list storing the ground truth labels. Any value
larger than 0 will be treated as positives, otherwise as negatives.
num_positives = If the 'predictions' and 'actuals' inputs aren't complete,
then it's possible some true positives were missed in them. In that case,
you can provide 'num_positives' in order to accurately track recall.
Raises:
ValueError: An error occurred when the format of the input is not the
numpy 1-D array or the shape of predictions and actuals does not match.
"""
if len(predictions) != len(actuals):
raise ValueError(
"the shape of predictions and actuals does not match.")
if not num_positives is None:
if not isinstance(num_positives,
numbers.Number) or num_positives < 0:
raise ValueError(
"'num_positives' was provided but it wan't a nonzero number."
)
if not num_positives is None:
self._total_positives += num_positives
else:
self._total_positives += numpy.size(numpy.where(actuals > 0))
topk = self._top_n
heap = self._heap
for i in range(numpy.size(predictions)):
if topk is None or len(heap) < topk:
heapq.heappush(heap, (predictions[i], actuals[i]))
else:
if predictions[i] > heap[0][0]: # heap[0] is the smallest
heapq.heappop(heap)
heapq.heappush(heap, (predictions[i], actuals[i]))
def clear(self):
"""Clear the accumulated predictions."""
self._heap = []
self._total_positives = 0
def peek_ap_at_n(self):
"""Peek the non-interpolated average precision at n.
Returns:
The non-interpolated average precision at n (default 0).
If n is larger than the length of the ranked list,
the average precision will be returned.
"""
if self.heap_size <= 0:
return 0
predlists = numpy.array(list(zip(*self._heap)))
ap = self.ap_at_n(
predlists[0],
predlists[1],
n=self._top_n,
total_num_positives=self._total_positives)
return ap
@staticmethod
def ap(predictions, actuals):
"""Calculate the non-interpolated average precision.
Args:
predictions: a numpy 1-D array storing the sparse prediction scores.
actuals: a numpy 1-D array storing the ground truth labels. Any value
larger than 0 will be treated as positives, otherwise as negatives.
Returns:
The non-interpolated average precision at n.
If n is larger than the length of the ranked list,
the average precision will be returned.
Raises:
ValueError: An error occurred when the format of the input is not the
numpy 1-D array or the shape of predictions and actuals does not match.
"""
return AveragePrecisionCalculator.ap_at_n(predictions, actuals, n=None)
@staticmethod
def ap_at_n(predictions, actuals, n=20, total_num_positives=None):
"""Calculate the non-interpolated average precision.
Args:
predictions: a numpy 1-D array storing the sparse prediction scores.
actuals: a numpy 1-D array storing the ground truth labels. Any value
larger than 0 will be treated as positives, otherwise as negatives.
n: the top n items to be considered in ap@n.
total_num_positives : (optionally) you can specify the number of total
positive
in the list. If specified, it will be used in calculation.
Returns:
The non-interpolated average precision at n.
If n is larger than the length of the ranked list,
the average precision will be returned.
Raises:
ValueError: An error occurred when
1) the format of the input is not the numpy 1-D array;
2) the shape of predictions and actuals does not match;
3) the input n is not a positive integer.
"""
if len(predictions) != len(actuals):
raise ValueError(
"the shape of predictions and actuals does not match.")
if n is not None:
if not isinstance(n, int) or n <= 0:
raise ValueError("n must be 'None' or a positive integer."
" It was '%s'." % n)
ap = 0.0
predictions = numpy.array(predictions)
actuals = numpy.array(actuals)
# add a shuffler to avoid overestimating the ap
predictions, actuals = AveragePrecisionCalculator._shuffle(
predictions, actuals)
sortidx = sorted(
range(len(predictions)), key=lambda k: predictions[k], reverse=True)
if total_num_positives is None:
numpos = numpy.size(numpy.where(actuals > 0))
else:
numpos = total_num_positives
if numpos == 0:
return 0
if n is not None:
numpos = min(numpos, n)
delta_recall = 1.0 / numpos
poscount = 0.0
# calculate the ap
r = len(sortidx)
if n is not None:
r = min(r, n)
for i in range(r):
if actuals[sortidx[i]] > 0:
poscount += 1
ap += poscount / (i + 1) * delta_recall
return ap
@staticmethod
def _shuffle(predictions, actuals):
random.seed(0)
suffidx = random.sample(range(len(predictions)), len(predictions))
predictions = predictions[suffidx]
actuals = actuals[suffidx]
return predictions, actuals
@staticmethod
def _zero_one_normalize(predictions, epsilon=1e-7):
"""Normalize the predictions to the range between 0.0 and 1.0.
For some predictions like SVM predictions, we need to normalize them before
calculate the interpolated average precision. The normalization will not
change the rank in the original list and thus won't change the average
precision.
Args:
predictions: a numpy 1-D array storing the sparse prediction scores.
epsilon: a small constant to avoid denominator being zero.
Returns:
The normalized prediction.
"""
denominator = numpy.max(predictions) - numpy.min(predictions)
ret = (predictions - numpy.min(predictions)) / numpy.max(
denominator, epsilon)
return ret
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides functions to help with evaluating models."""
import datetime
import numpy
from . import mean_average_precision_calculator as map_calculator
from . import average_precision_calculator as ap_calculator
def flatten(l):
""" Merges a list of lists into a single list. """
return [item for sublist in l for item in sublist]
def calculate_hit_at_one(predictions, actuals):
"""Performs a local (numpy) calculation of the hit at one.
Args:
predictions: Matrix containing the outputs of the model.
Dimensions are 'batch' x 'num_classes'.
actuals: Matrix containing the ground truth labels.
Dimensions are 'batch' x 'num_classes'.
Returns:
float: The average hit at one across the entire batch.
"""
top_prediction = numpy.argmax(predictions, 1)
hits = actuals[numpy.arange(actuals.shape[0]), top_prediction]
return numpy.average(hits)
def calculate_precision_at_equal_recall_rate(predictions, actuals):
"""Performs a local (numpy) calculation of the PERR.
Args:
predictions: Matrix containing the outputs of the model.
Dimensions are 'batch' x 'num_classes'.
actuals: Matrix containing the ground truth labels.
Dimensions are 'batch' x 'num_classes'.
Returns:
float: The average precision at equal recall rate across the entire batch.
"""
aggregated_precision = 0.0
num_videos = actuals.shape[0]
for row in numpy.arange(num_videos):
num_labels = int(numpy.sum(actuals[row]))
top_indices = numpy.argpartition(predictions[row],
-num_labels)[-num_labels:]
item_precision = 0.0
for label_index in top_indices:
if predictions[row][label_index] > 0:
item_precision += actuals[row][label_index]
item_precision /= top_indices.size
aggregated_precision += item_precision
aggregated_precision /= num_videos
return aggregated_precision
def calculate_gap(predictions, actuals, top_k=20):
"""Performs a local (numpy) calculation of the global average precision.
Only the top_k predictions are taken for each of the videos.
Args:
predictions: Matrix containing the outputs of the model.
Dimensions are 'batch' x 'num_classes'.
actuals: Matrix containing the ground truth labels.
Dimensions are 'batch' x 'num_classes'.
top_k: How many predictions to use per video.
Returns:
float: The global average precision.
"""
gap_calculator = ap_calculator.AveragePrecisionCalculator()
sparse_predictions, sparse_labels, num_positives = top_k_by_class(
predictions, actuals, top_k)
gap_calculator.accumulate(
flatten(sparse_predictions), flatten(sparse_labels), sum(num_positives))
return gap_calculator.peek_ap_at_n()
def top_k_by_class(predictions, labels, k=20):
"""Extracts the top k predictions for each video, sorted by class.
Args:
predictions: A numpy matrix containing the outputs of the model.
Dimensions are 'batch' x 'num_classes'.
k: the top k non-zero entries to preserve in each prediction.
Returns:
A tuple (predictions,labels, true_positives). 'predictions' and 'labels'
are lists of lists of floats. 'true_positives' is a list of scalars. The
length of the lists are equal to the number of classes. The entries in the
predictions variable are probability predictions, and
the corresponding entries in the labels variable are the ground truth for
those predictions. The entries in 'true_positives' are the number of true
positives for each class in the ground truth.
Raises:
ValueError: An error occurred when the k is not a positive integer.
"""
if k <= 0:
raise ValueError("k must be a positive integer.")
k = min(k, predictions.shape[1])
num_classes = predictions.shape[1]
prediction_triplets = []
for video_index in range(predictions.shape[0]):
prediction_triplets.extend(
top_k_triplets(predictions[video_index], labels[video_index], k))
out_predictions = [[] for v in range(num_classes)]
out_labels = [[] for v in range(num_classes)]
for triplet in prediction_triplets:
out_predictions[triplet[0]].append(triplet[1])
out_labels[triplet[0]].append(triplet[2])
out_true_positives = [numpy.sum(labels[:, i]) for i in range(num_classes)]
return out_predictions, out_labels, out_true_positives
def top_k_triplets(predictions, labels, k=20):
"""Get the top_k for a 1-d numpy array. Returns a sparse list of tuples in
(prediction, class) format"""
m = len(predictions)
k = min(k, m)
indices = numpy.argpartition(predictions, -k)[-k:]
return [(index, predictions[index], labels[index]) for index in indices]
class EvaluationMetrics(object):
"""A class to store the evaluation metrics."""
def __init__(self, num_class, top_k):
"""Construct an EvaluationMetrics object to store the evaluation metrics.
Args:
num_class: A positive integer specifying the number of classes.
top_k: A positive integer specifying how many predictions are considered per video.
Raises:
ValueError: An error occurred when MeanAveragePrecisionCalculator cannot
not be constructed.
"""
self.sum_hit_at_one = 0.0
self.sum_perr = 0.0
self.sum_loss = 0.0
self.map_calculator = map_calculator.MeanAveragePrecisionCalculator(
num_class)
self.global_ap_calculator = ap_calculator.AveragePrecisionCalculator()
self.top_k = top_k
self.num_examples = 0
#def accumulate(self, predictions, labels, loss):
def accumulate(self, loss, predictions, labels):
"""Accumulate the metrics calculated locally for this mini-batch.
Args:
predictions: A numpy matrix containing the outputs of the model.
Dimensions are 'batch' x 'num_classes'.
labels: A numpy matrix containing the ground truth labels.
Dimensions are 'batch' x 'num_classes'.
loss: A numpy array containing the loss for each sample.
Returns:
dictionary: A dictionary storing the metrics for the mini-batch.
Raises:
ValueError: An error occurred when the shape of predictions and actuals
does not match.
"""
batch_size = labels.shape[0]
mean_hit_at_one = calculate_hit_at_one(predictions, labels)
mean_perr = calculate_precision_at_equal_recall_rate(
predictions, labels)
mean_loss = numpy.mean(loss)
# Take the top 20 predictions.
sparse_predictions, sparse_labels, num_positives = top_k_by_class(
predictions, labels, self.top_k)
self.map_calculator.accumulate(sparse_predictions, sparse_labels,
num_positives)
self.global_ap_calculator.accumulate(
flatten(sparse_predictions), flatten(sparse_labels),
sum(num_positives))
self.num_examples += batch_size
self.sum_hit_at_one += mean_hit_at_one * batch_size
self.sum_perr += mean_perr * batch_size
self.sum_loss += mean_loss * batch_size
return {
"hit_at_one": mean_hit_at_one,
"perr": mean_perr,
"loss": mean_loss
}
def get(self):
"""Calculate the evaluation metrics for the whole epoch.
Raises:
ValueError: If no examples were accumulated.
Returns:
dictionary: a dictionary storing the evaluation metrics for the epoch. The
dictionary has the fields: avg_hit_at_one, avg_perr, avg_loss, and
aps (default nan).
"""
if self.num_examples <= 0:
raise ValueError("total_sample must be positive.")
avg_hit_at_one = self.sum_hit_at_one / self.num_examples
avg_perr = self.sum_perr / self.num_examples
avg_loss = self.sum_loss / self.num_examples
aps = self.map_calculator.peek_map_at_n()
gap = self.global_ap_calculator.peek_ap_at_n()
epoch_info_dict = {}
return {
"avg_hit_at_one": avg_hit_at_one,
"avg_perr": avg_perr,
"avg_loss": avg_loss,
"aps": aps,
"gap": gap
}
def clear(self):
"""Clear the evaluation metrics and reset the EvaluationMetrics object."""
self.sum_hit_at_one = 0.0
self.sum_perr = 0.0
self.sum_loss = 0.0
self.map_calculator.clear()
self.global_ap_calculator.clear()
self.num_examples = 0
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Calculate the mean average precision.
It provides an interface for calculating mean average precision
for an entire list or the top-n ranked items.
Example usages:
We first call the function accumulate many times to process parts of the ranked
list. After processing all the parts, we call peek_map_at_n
to calculate the mean average precision.
```
import random
p = np.array([[random.random() for _ in xrange(50)] for _ in xrange(1000)])
a = np.array([[random.choice([0, 1]) for _ in xrange(50)]
for _ in xrange(1000)])
# mean average precision for 50 classes.
calculator = mean_average_precision_calculator.MeanAveragePrecisionCalculator(
num_class=50)
calculator.accumulate(p, a)
aps = calculator.peek_map_at_n()
```
"""
import numpy
from . import average_precision_calculator
class MeanAveragePrecisionCalculator(object):
"""This class is to calculate mean average precision.
"""
def __init__(self, num_class):
"""Construct a calculator to calculate the (macro) average precision.
Args:
num_class: A positive Integer specifying the number of classes.
top_n_array: A list of positive integers specifying the top n for each
class. The top n in each class will be used to calculate its average
precision at n.
The size of the array must be num_class.
Raises:
ValueError: An error occurred when num_class is not a positive integer;
or the top_n_array is not a list of positive integers.
"""
if not isinstance(num_class, int) or num_class <= 1:
raise ValueError("num_class must be a positive integer.")
self._ap_calculators = [] # member of AveragePrecisionCalculator
self._num_class = num_class # total number of classes
for i in range(num_class):
self._ap_calculators.append(
average_precision_calculator.AveragePrecisionCalculator())
def accumulate(self, predictions, actuals, num_positives=None):
"""Accumulate the predictions and their ground truth labels.
Args:
predictions: A list of lists storing the prediction scores. The outer
dimension corresponds to classes.
actuals: A list of lists storing the ground truth labels. The dimensions
should correspond to the predictions input. Any value
larger than 0 will be treated as positives, otherwise as negatives.
num_positives: If provided, it is a list of numbers representing the
number of true positives for each class. If not provided, the number of
true positives will be inferred from the 'actuals' array.
Raises:
ValueError: An error occurred when the shape of predictions and actuals
does not match.
"""
if not num_positives:
num_positives = [None for i in predictions.shape[1]]
calculators = self._ap_calculators
for i in range(len(predictions)):
calculators[i].accumulate(predictions[i], actuals[i],
num_positives[i])
def clear(self):
for calculator in self._ap_calculators:
calculator.clear()
def is_empty(self):
return ([calculator.heap_size for calculator in self._ap_calculators
] == [0 for _ in range(self._num_class)])
def peek_map_at_n(self):
"""Peek the non-interpolated mean average precision at n.
Returns:
An array of non-interpolated average precision at n (default 0) for each
class.
"""
aps = [
self._ap_calculators[i].peek_ap_at_n()
for i in range(self._num_class)
]
return aps
from .model import regist_model, get_model
from .attention_lstm import AttentionLSTM
from .tsn import TSN
# regist models, sort by alphabet
regist_model("AttentionLSTM", AttentionLSTM)
regist_model("TSN", TSN)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
from ..model import ModelBase
from .lstm_attention import LSTMAttentionModel
import logging
logger = logging.getLogger(__name__)
__all__ = ["AttentionLSTM"]
class AttentionLSTM(ModelBase):
def __init__(self, name, cfg, mode='train'):
super(AttentionLSTM, self).__init__(name, cfg, mode)
self.get_config()
def get_config(self):
# get model configs
self.feature_num = self.cfg.MODEL.feature_num
self.feature_names = self.cfg.MODEL.feature_names
self.feature_dims = self.cfg.MODEL.feature_dims
self.num_classes = self.cfg.MODEL.num_classes
self.embedding_size = self.cfg.MODEL.embedding_size
self.lstm_size = self.cfg.MODEL.lstm_size
self.drop_rate = self.cfg.MODEL.drop_rate
# get mode configs
self.batch_size = self.get_config_from_sec(self.mode, 'batch_size', 1)
self.num_gpus = self.get_config_from_sec(self.mode, 'num_gpus', 1)
def build_input(self, use_dataloader):
self.feature_input = []
for name, dim in zip(self.feature_names, self.feature_dims):
self.feature_input.append(
fluid.data(
shape=[None, dim], lod_level=1, dtype='float32', name=name))
# self.label_input = None
if use_dataloader:
assert self.mode != 'infer', \
'dataloader is not recommendated when infer, please set use_dataloader to be false.'
self.dataloader = fluid.io.DataLoader.from_generator(
feed_list=self.feature_input, #+ [self.label_input],
capacity=8,
iterable=True)
def build_model(self):
att_outs = []
for i, (input_dim, feature) in enumerate(
zip(self.feature_dims, self.feature_input)):
att = LSTMAttentionModel(input_dim, self.embedding_size,
self.lstm_size, self.drop_rate)
att_out = att.forward(feature, is_training=(self.mode == 'train'))
att_outs.append(att_out)
if len(att_outs) > 1:
out = fluid.layers.concat(att_outs, axis=1)
else:
out = att_outs[0]
fc1 = fluid.layers.fc(
input=out,
size=8192,
act='relu',
bias_attr=ParamAttr(
regularizer=fluid.regularizer.L2Decay(0.0),
initializer=fluid.initializer.NormalInitializer(scale=0.0)),
name='fc1')
fc2 = fluid.layers.fc(
input=fc1,
size=4096,
act='tanh',
bias_attr=ParamAttr(
regularizer=fluid.regularizer.L2Decay(0.0),
initializer=fluid.initializer.NormalInitializer(scale=0.0)),
name='fc2')
self.logit = fluid.layers.fc(input=fc2, size=self.num_classes, act=None, \
bias_attr=ParamAttr(regularizer=fluid.regularizer.L2Decay(0.0),
initializer=fluid.initializer.NormalInitializer(scale=0.0)),
name = 'output')
self.output = fluid.layers.sigmoid(self.logit)
def optimizer(self):
assert self.mode == 'train', "optimizer only can be get in train mode"
values = [
self.learning_rate * (self.decay_gamma**i)
for i in range(len(self.decay_epochs) + 1)
]
iter_per_epoch = self.num_samples / self.batch_size
boundaries = [e * iter_per_epoch for e in self.decay_epochs]
return fluid.optimizer.RMSProp(
learning_rate=fluid.layers.piecewise_decay(
values=values, boundaries=boundaries),
centered=True,
regularization=fluid.regularizer.L2Decay(self.weight_decay))
def loss(self):
assert self.mode != 'infer', "invalid loss calculationg in infer mode"
cost = fluid.layers.sigmoid_cross_entropy_with_logits(
x=self.logit, label=self.label_input)
cost = fluid.layers.reduce_sum(cost, dim=-1)
sum_cost = fluid.layers.reduce_sum(cost)
self.loss_ = fluid.layers.scale(
sum_cost, scale=self.num_gpus, bias_after_scale=False)
return self.loss_
def outputs(self):
return [self.output, self.logit]
def feeds(self):
return self.feature_input
def fetches(self):
fetch_list = [self.output]
return fetch_list
def weights_info(self):
return (
'AttentionLSTM.pdparams',
'https://paddlemodels.bj.bcebos.com/video_classification/AttentionLSTM.pdparams'
)
def load_pretrain_params(self, exe, pretrain, prog, place):
#def is_parameter(var):
# return isinstance(var, fluid.framework.Parameter)
#params_list = list(filter(is_parameter, prog.list_vars()))
#for param in params_list:
# print(param.name)
#assert False, "stop here"
logger.info(
"Load pretrain weights from {}, exclude fc layer.".format(pretrain))
state_dict = fluid.load_program_state(pretrain)
dict_keys = list(state_dict.keys())
for name in dict_keys:
if "fc_0" in name:
del state_dict[name]
logger.info(
'Delete {} from pretrained parameters. Do not load it'.
format(name))
fluid.set_program_state(prog, state_dict)
# def load_test_weights(self, exe, weights, prog):
# def is_parameter(var):
# return isinstance(var, fluid.framework.Parameter)
# params_list = list(filter(is_parameter, prog.list_vars()))
# state_dict = np.load(weights)
# for p in params_list:
# if p.name in state_dict.keys():
# logger.info('########### load param {} from file'.format(p.name))
# else:
# logger.info('----------- param {} not in file'.format(p.name))
# fluid.set_program_state(prog, state_dict)
# fluid.save(prog, './weights/attention_lstm')
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
import numpy as np
class LSTMAttentionModel(object):
"""LSTM Attention Model"""
def __init__(self,
bias_attr,
embedding_size=512,
lstm_size=1024,
drop_rate=0.5):
self.lstm_size = lstm_size
self.embedding_size = embedding_size
self.drop_rate = drop_rate
def forward(self, input, is_training):
input_fc = fluid.layers.fc(
input=input,
size=self.embedding_size,
act='tanh',
bias_attr=ParamAttr(
regularizer=fluid.regularizer.L2Decay(0.0),
initializer=fluid.initializer.NormalInitializer(scale=0.0)),
name='rgb_fc')
#lstm_forward_fc = fluid.layers.fc(
# input=input_fc,
# size=self.lstm_size * 4,
# act=None,
# bias_attr=ParamAttr(
# regularizer=fluid.regularizer.L2Decay(0.0),
# initializer=fluid.initializer.NormalInitializer(scale=0.0)),
# name='rgb_fc_forward')
lstm_forward_fc = fluid.layers.fc(
input=input_fc,
size=self.lstm_size * 4,
act=None,
bias_attr=False,
name='rgb_fc_forward')
lstm_forward, _ = fluid.layers.dynamic_lstm(
input=lstm_forward_fc,
size=self.lstm_size * 4,
is_reverse=False,
name='rgb_lstm_forward')
#lsmt_backward_fc = fluid.layers.fc(
# input=input_fc,
# size=self.lstm_size * 4,
# act=None,
# bias_attr=ParamAttr(
# regularizer=fluid.regularizer.L2Decay(0.0),
# initializer=fluid.initializer.NormalInitializer(scale=0.0)),
# name='rgb_fc_backward')
lsmt_backward_fc = fluid.layers.fc(
input=input_fc,
size=self.lstm_size * 4,
act=None,
bias_attr=False,
name='rgb_fc_backward')
lstm_backward, _ = fluid.layers.dynamic_lstm(
input=lsmt_backward_fc,
size=self.lstm_size * 4,
is_reverse=True,
name='rgb_lstm_backward')
lstm_concat = fluid.layers.concat(
input=[lstm_forward, lstm_backward], axis=1)
lstm_dropout = fluid.layers.dropout(
x=lstm_concat,
dropout_prob=self.drop_rate,
is_test=(not is_training))
#lstm_weight = fluid.layers.fc(
# input=lstm_dropout,
# size=1,
# act='sequence_softmax',
# bias_attr=ParamAttr(
# regularizer=fluid.regularizer.L2Decay(0.0),
# initializer=fluid.initializer.NormalInitializer(scale=0.0)),
# name='rgb_weight')
lstm_weight = fluid.layers.fc(
input=lstm_dropout,
size=1,
act='sequence_softmax',
bias_attr=False,
name='rgb_weight')
scaled = fluid.layers.elementwise_mul(
x=lstm_dropout, y=lstm_weight, axis=0)
lstm_pool = fluid.layers.sequence_pool(input=scaled, pool_type='sum')
return lstm_pool
import json
depth = [3, 4, 23, 3]
num_filters = [64, 128, 256, 512]
layer_index = 1
caffe_param_list = []
name_list = ['conv1']
params_list = []
name = name_list[0]
conv_w = name + '_weights'
caffe_conv_w = 'ConvNdBackward' + str(layer_index) + '_weights'
params_list.append(conv_w)
caffe_param_list.append(caffe_conv_w)
layer_index += 1
bn_name = "bn_" + name
caffe_bn_name = 'BatchNormBackward' + str(layer_index) + '_bn'
params_list.append(bn_name + '_scale')
params_list.append(bn_name + '_offset')
params_list.append(bn_name + '_mean')
params_list.append(bn_name + '_variance')
caffe_param_list.append(caffe_bn_name + '_scale')
caffe_param_list.append(caffe_bn_name + '_offset')
caffe_param_list.append(caffe_bn_name + '_mean')
caffe_param_list.append(caffe_bn_name + '_variance')
filter_input = 64
layer_index += 3
for block in range(len(depth)):
for i in range(depth[block]):
if block == 2:
if i == 0:
name = "res" + str(block + 2) + "a"
else:
name = "res" + str(block + 2) + "b" + str(i)
else:
name = "res" + str(block + 2) + chr(97 + i)
name_list.append(name)
for item in ['a', 'b', 'c']:
name_branch = name + '_branch2' + item
bn_name = 'bn' + name_branch[3:]
params_list.append(name_branch + '_weights')
params_list.append(bn_name + '_scale')
params_list.append(bn_name + '_offset')
params_list.append(bn_name + '_mean')
params_list.append(bn_name + '_variance')
caffe_name_branch = 'ConvNdBackward' + str(layer_index)
caffe_param_list.append(caffe_name_branch + '_weights')
layer_index += 1
caffe_bn_name = 'BatchNormBackward' + str(layer_index) + '_bn'
caffe_param_list.append(caffe_bn_name + '_scale')
caffe_param_list.append(caffe_bn_name + '_offset')
caffe_param_list.append(caffe_bn_name + '_mean')
caffe_param_list.append(caffe_bn_name + '_variance')
layer_index += 2
stride = 2 if i == 0 and block != 0 else 1
filter_num = num_filters[block]
filter_output = filter_num * 4
if (filter_output != filter_input) or (stride != 1):
name_branch = name + '_branch1'
print(
'filter_input {}, filter_output {}, stride {}, branch name {}'.
format(filter_input, filter_output, stride, name_branch))
bn_name = 'bn' + name_branch[3:]
params_list.append(name_branch + '_weights')
params_list.append(bn_name + '_scale')
params_list.append(bn_name + '_offset')
params_list.append(bn_name + '_mean')
params_list.append(bn_name + '_variance')
caffe_name_branch = 'ConvNdBackward' + str(layer_index)
caffe_param_list.append(caffe_name_branch + '_weights')
layer_index += 1
caffe_bn_name = 'BatchNormBackward' + str(layer_index) + '_bn'
caffe_param_list.append(caffe_bn_name + '_scale')
caffe_param_list.append(caffe_bn_name + '_offset')
caffe_param_list.append(caffe_bn_name + '_mean')
caffe_param_list.append(caffe_bn_name + '_variance')
layer_index += 3
else:
layer_index += 2
filter_input = filter_output
map_dict = {}
for i in range(len(params_list)):
print(params_list[i], caffe_param_list[i])
map_dict[params_list[i]] = caffe_param_list[i]
json.dump(map_dict, open('name_map.json', 'w'))
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import wget
import tarfile
__all__ = ['decompress', 'download', 'AttrDict']
def decompress(path):
t = tarfile.open(path)
t.extractall(path=os.path.split(path)[0])
t.close()
os.remove(path)
def download(url, path):
weight_dir = os.path.split(path)[0]
if not os.path.exists(weight_dir):
os.makedirs(weight_dir)
path = path + ".tar.gz"
wget.download(url, path)
decompress(path)
class AttrDict(dict):
def __getattr__(self, key):
return self[key]
def __setattr__(self, key, value):
if key in self.__dict__:
self.__dict__[key] = value
else:
self[key] = value
from .reader_utils import regist_reader, get_reader
from .kinetics_reader import KineticsReader
# regist reader, sort by alphabet
regist_reader("TSN", KineticsReader)
name: videotag_tsn_lstm
dir: "modules/video/classification/videotag_tsn_lstm"
exclude:
- README.md
resources:
-
url: https://paddlehub.bj.bcebos.com/model/video/video_classifcation/videotag_tsn_lstm.tar.gz
dest: weights
uncompress: True
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册