kinetics_dataset.py 5.1 KB
Newer Older
D
dengkaipeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import six
import sys
import random
import numpy as np
from PIL import Image, ImageEnhance

try:
    import cPickle as pickle
    from cStringIO import StringIO
except ImportError:
    import pickle
    from io import BytesIO

D
dengkaipeng 已提交
29
from paddle.io import Dataset
D
dengkaipeng 已提交
30 31 32 33 34 35

import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Public names exported by this module.
__all__ = ['KineticsDataset']

D
dengkaipeng 已提交
36 37
# Class count reported by KineticsDataset.num_classes when no label_list
# is supplied.
KINETICS_CLASS_NUM = 400

D
dengkaipeng 已提交
38 39 40 41 42 43

class KineticsDataset(Dataset):
    """
    Kinetics dataset

    Each sample is stored in its own pickle file holding a
    ``(vid, label, frames)`` tuple, where ``frames`` is a sequence of
    encoded image buffers that are decoded with PIL.

    Args:
        file_list (str): path to file list, one pickle file path
            (relative to pickle_dir) per line. Required in 'train'
            and 'val' mode. Default None
        pickle_dir (str): path to pickle file directory. Required in
            'train' and 'val' mode. Default None
        pickle_file (str): path to a single pickle file. Required in
            'test' mode, where it is the only sample. Default None
        label_list (str): path to label_list file, if set None, the
            default class number 400 of kinetics dataset will be
            used. Default None
        mode (str): 'train', 'val' or 'test' mode (case-insensitive).
            Segments are sampled randomly in 'train' mode and from the
            segment center otherwise. Default 'train'
        seg_num (int): segment number to sample from each video.
            Default 8
        seg_len (int): frame number of each segment. Default 1
        transform (callable): transforms to perform on video samples,
            called as ``transform(imgs, label)`` and expected to return
            ``(imgs, label)``. None for no transforms. Default None.
    """

    def __init__(self,
                 file_list=None,
                 pickle_dir=None,
                 pickle_file=None,
                 label_list=None,
                 mode='train',
                 seg_num=8,
                 seg_len=1,
                 transform=None):
        super(KineticsDataset, self).__init__()

        assert mode.lower() in ['train', 'val', 'test'], \
            "mode can only be 'train' 'val' or 'test'"
        self.mode = mode.lower()

        # The path arguments default to None; check for None explicitly so
        # a missing argument fails with the assertion message instead of a
        # TypeError raised from inside os.path.
        if self.mode in ['train', 'val']:
            assert file_list is not None and os.path.isfile(file_list), \
                "file_list {} not a file".format(file_list)
            with open(file_list) as f:
                entries = (line.strip() for line in f)
                # Skip blank lines (e.g. a trailing newline) so they do
                # not become unloadable samples.
                self.pickle_paths = [path for path in entries if path]

            assert pickle_dir is not None and os.path.isdir(pickle_dir), \
                "pickle_dir {} not a directory".format(pickle_dir)
            self.pickle_dir = pickle_dir
        else:
            assert pickle_file is not None and os.path.isfile(pickle_file), \
                "pickle_file {} not a file".format(pickle_file)
            # An empty directory makes os.path.join return pickle_file
            # unchanged in __getitem__.
            self.pickle_dir = ''
            self.pickle_paths = [pickle_file]

        self.label_list = label_list
        if self.label_list is not None:
            assert os.path.isfile(self.label_list), \
                "label_list {} not a file".format(self.label_list)
            with open(self.label_list) as f:
                entries = (line.strip() for line in f)
                # Blank lines are skipped; int('') would raise otherwise.
                self.label_list = [int(item) for item in entries if item]

        self.seg_num = seg_num
        self.seg_len = seg_len
        self.transform = transform

    def __len__(self):
        """Return the number of samples (pickle files) in the dataset."""
        return len(self.pickle_paths)

    def __getitem__(self, idx):
        """
        Load the sample at position idx.

        Returns:
            tuple: ``(imgs, label)`` where imgs is the list of sampled
            RGB frames (or whatever ``transform`` returns for it) and
            label is an int64 numpy array built from ``[label]``.
        """
        pickle_path = os.path.join(self.pickle_dir, self.pickle_paths[idx])

        # NOTE: pickle.load can execute arbitrary code; only use pickle
        # files that come from a trusted source.
        # The context manager closes the file handle after loading.
        with open(pickle_path, 'rb') as f:
            if six.PY2:
                data = pickle.load(f)
            else:
                data = pickle.load(f, encoding='bytes')

        vid, label, frames = data

        if self.label_list is not None:
            # Remap the stored label to its position in label_list.
            label = self.label_list.index(label)
        imgs = self._video_loader(frames)

        if self.transform:
            imgs, label = self.transform(imgs, label)
        return imgs, np.array([label]).astype('int64')

    @property
    def num_classes(self):
        """Class number: length of label_list if given, else 400."""
        return KINETICS_CLASS_NUM if self.label_list is None \
            else len(self.label_list)

    def _video_loader(self, frames):
        """
        Sample seg_num segments of seg_len frames each and decode them.

        The video is divided into seg_num spans of equal length. When a
        span can hold a whole segment, the segment starts at a random
        offset inside the span in 'train' mode and is centered in the
        span otherwise. Frame indices wrap around the video length.

        Args:
            frames (sequence): encoded image buffers of one video.

        Returns:
            list: seg_num * seg_len decoded RGB images.
        """
        videolen = len(frames)
        average_dur = videolen // self.seg_num

        imgs = []
        for i in range(self.seg_num):
            if average_dur >= self.seg_len:
                if self.mode == 'train':
                    offset = random.randint(0, average_dur - self.seg_len)
                else:
                    offset = (average_dur - self.seg_len) // 2
                start = offset + i * average_dur
            elif average_dur >= 1:
                # Span is shorter than a segment: start at the span head.
                start = i * average_dur
            else:
                # Fewer frames than segments: one frame step per segment.
                start = i

            for frame_idx in range(start, start + self.seg_len):
                # Modulo wraps indices that run past the end of the video.
                imgbuf = frames[frame_idx % videolen]
                imgs.append(self._imageloader(imgbuf))

        return imgs

    def _imageloader(self, buf):
        """
        Decode one encoded image buffer into an RGB PIL image.

        Args:
            buf (bytes): encoded image data (``str`` on Python 2).

        Returns:
            PIL.Image.Image: the decoded image converted to 'RGB'.
        """
        # io.BytesIO exists on both Python 2 and 3. The module-level import
        # fallback only defines one of StringIO / BytesIO depending on the
        # interpreter, so import locally rather than rely on it.
        from io import BytesIO

        img = Image.open(BytesIO(buf))
        return img.convert('RGB')