提交 0d66e9f2 编写于 作者: L linjintao

Merge branch 'tjq/sth_scripts' into 'master'

Add sth scripts.

See merge request open-mmlab/mmaction-lite!289
......@@ -3,7 +3,9 @@ import glob
import os.path as osp
import random
from tools.data.parse_file_list import parse_directory, parse_ucf101_splits
from tools.data.parse_file_list import (parse_directory, parse_sthv1_splits,
parse_sthv2_splits,
parse_ucf101_splits)
def parse_args():
......@@ -153,9 +155,13 @@ def main():
if args.dataset == 'ucf101':
splits = parse_ucf101_splits(args.level)
elif args.dataset == 'sthv1':
splits = parse_sthv1_splits(args.level)
elif args.dataset == 'sthv2':
splits = parse_sthv2_splits(args.level)
else:
raise ValueError(
f"Supported datasets are 'ucf101', but got {args.dataset}")
raise ValueError(f"Supported datasets are 'ucf101, sthv1, sthv2',"
f'but got {args.dataset}')
assert len(splits) == args.num_split
out_path = args.out_root_path + args.dataset
......
import fnmatch
import glob
import json
import os
......@@ -110,3 +111,92 @@ def parse_ucf101_splits(level):
splits.append((train_list, test_list))
return splits
def parse_sthv1_splits(level):
"""Parse Something-Something dataset V1 into "train", "val" splits.
Args:
level: directory level of data.
Returns:
list: "train", "val", "test" splits of Something-Something dataset V1.
"""
# Read the annotations
# yapf: disable
class_index_file = 'data/sthv1/annotations/something-something-v1-labels.csv' # noqa
# yapf: enable
train_file = 'data/sthv1/annotations/something-something-v1-train.csv'
val_file = 'data/sthv1/annotations/something-something-v1-validation.csv'
test_file = 'data/sthv1/annotations/something-something-v1-test.csv'
with open(class_index_file, 'r') as fin:
class_index = [x.strip() for x in fin]
class_mapping = {class_index[idx]: idx for idx in range(len(class_index))}
def line_to_map(line, test_mode=False):
items = line.strip().split(';')
vid = items[0]
vid = '/'.join(vid.split('/')[-level:])
if test_mode:
return vid
else:
label = class_mapping[items[1]]
return vid, label
with open(train_file, 'r') as fin:
train_list = [line_to_map(x) for x in fin]
with open(val_file, 'r') as fin:
val_list = [line_to_map(x) for x in fin]
with open(test_file, 'r') as fin:
test_list = [line_to_map(x, test_mode=True) for x in fin]
return ((train_list, val_list, test_list), )
def parse_sthv2_splits(level):
"""Parse Something-Something dataset V2 into "train", "val" splits.
Args:
level: directory level of data.
Returns:
list: "train", "val", "test" splits of Something-Something dataset V2.
"""
# Read the annotations
# yapf: disable
class_index_file = 'data/sthv2/annotations/something-something-v2-labels.json' # noqa
# yapf: enable
train_file = 'data/sthv2/annotations/something-something-v2-train.json'
val_file = 'data/sthv2/annotations/something-something-v2-validation.json'
test_file = 'data/sthv2/annotations/something-something-v2-test.json'
with open(class_index_file, 'r') as fin:
class_mapping = json.loads(fin.read())
def line_to_map(item, test_mode=False):
vid = item['id']
vid = '/'.join(vid.split('/')[-level:])
if test_mode:
return vid
else:
template = item['template'].replace('[', '')
template = template.replace(']', '')
label = class_mapping[template]
return vid, label
with open(train_file, 'r') as fin:
items = json.loads(fin.read())
train_list = [line_to_map(item) for item in items]
with open(val_file, 'r') as fin:
items = json.loads(fin.read())
val_list = [line_to_map(item) for item in items]
with open(test_file, 'r') as fin:
items = json.loads(fin.read())
test_list = [line_to_map(item, test_mode=True) for item in items]
return ((train_list, val_list, test_list), )
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册