未验证 提交 ffe11953 编写于 作者: H huangjun12 提交者: GitHub

update DataParallel in slowfast (#5298)

* update DP in slowfast and refine nextvlad doc

* refine doc
上级 643fd690
......@@ -12,7 +12,7 @@
## 算法介绍
NeXtVLAD模型是第二届Youtube-8M视频理解竞赛中效果最好的单模型,在参数量小于80M的情况下,能得到高于0.87的GAP指标。该模型提供了一种将级别的视频特征转化并压缩成特征向量,以适用于大尺寸视频文件的分类的方法。其基本出发点是在NetVLAD模型的基础上,将高维度的特征先进行分组,通过引入attention机制聚合提取时间维度的信息,这样既可以获得较高的准确率,又可以使用更少的参数量。详细内容请参考[NeXtVLAD: An Efficient Neural Network to Aggregate Frame-level Features for Large-scale Video Classification](https://arxiv.org/abs/1811.05014)
NeXtVLAD模型是第二届Youtube-8M视频理解竞赛中效果最好的单模型,在参数量小于80M的情况下,能得到高于0.87的GAP指标。该模型提供了一种将级别的视频特征转化并压缩成特征向量,以适用于大尺寸视频文件的分类的方法。其基本出发点是在NetVLAD模型的基础上,将高维度的特征先进行分组,通过引入attention机制聚合提取时间维度的信息,这样既可以获得较高的准确率,又可以使用更少的参数量。详细内容请参考[NeXtVLAD: An Efficient Neural Network to Aggregate Frame-level Features for Large-scale Video Classification](https://arxiv.org/abs/1811.05014)
这里实现了论文中的单模型结构,使用2nd-Youtube-8M的train数据集作为训练集,在val数据集上做测试。
......
......@@ -68,7 +68,7 @@ SlowFast Overview
## 数据准备
SlowFast模型的训练数据采用Kinetics400数据集,数据下载及准备请参考[数据说明](../PaddleCV/video/data/dataset/README.md)
SlowFast模型的训练数据采用Kinetics400数据集,数据下载及准备请参考[数据说明](../PaddleCV/video/data/dataset/README.md)
## 模型训练
......
......@@ -101,7 +101,8 @@ def test_slowfast(args):
if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
slowfast = fluid.dygraph.parallel.DataParallel(slowfast, strategy)
slowfast = fluid.dygraph.parallel.DataParallel(
slowfast, strategy, find_unused_parameters=False)
#create reader
test_data = KineticsDataset(mode="test", cfg=test_config)
......
......@@ -110,7 +110,8 @@ def infer_slowfast(args):
if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
slowfast = fluid.dygraph.parallel.DataParallel(slowfast, strategy)
slowfast = fluid.dygraph.parallel.DataParallel(
slowfast, strategy, find_unused_parameters=False)
#create reader
infer_data = KineticsDataset(mode="infer", cfg=infer_config)
......
......@@ -277,8 +277,8 @@ def train(args):
video_model = SlowFast(cfg=train_config, num_classes=400)
if args.use_data_parallel:
video_model = fluid.dygraph.parallel.DataParallel(video_model,
strategy)
video_model = fluid.dygraph.parallel.DataParallel(
video_model, strategy, find_unused_parameters=False)
bs_denominator = 1
if args.use_gpu:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册