提交 9da4cc8b 编写于 作者: D dengkaipeng

add tsm model

上级 6d9e77b9
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import os
import argparse
import numpy as np
from paddle import fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from model import Model, CrossEntropy, Input, set_device
from metrics import Accuracy
from tsm import *
NUM_CLASSES = 10
def make_optimizer(num_samples, parameter_list=None):
step = int(num_samples / FLAGS.batch_size)
boundaries = [e * step for e in [40, 60]]
values = [FLAGS.lr * (0.1 ** i) for i in range(len(boundaries) + 1)]
learning_rate = fluid.layers.piecewise_decay(
boundaries=boundaries,
values=values)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
regularization=fluid.regularizer.L2Decay(1e-4),
momentum=0.9,
parameter_list=parameter_list)
return optimizer
def main():
device = set_device(FLAGS.device)
fluid.enable_dygraph(device) if FLAGS.dynamic else None
train_transform = Compose([GroupScale(),
GroupMultiScaleCrop(),
GroupRandomCrop(),
GroupRandomFlip(),
NormalizeImage()])
train_dataset = KineticsDataset(
filelist=os.path.join(FLAGS.data, 'train_10.list'),
pickle_dir=os.path.join(FLAGS.data, 'train_10'),
transform=train_transform)
val_transform = Compose([GroupScale(),
GroupCenterCrop(),
NormalizeImage()])
val_dataset = KineticsDataset(
filelist=os.path.join(FLAGS.data, 'val_10.list'),
pickle_dir=os.path.join(FLAGS.data, 'val_10'),
mode='val',
transform=val_transform)
pretrained = FLAGS.eval_only and FLAGS.weights is None
model = tsm_resnet50(num_classes=NUM_CLASSES, pretrained=pretrained)
optim = make_optimizer(len(train_dataset), model.parameters())
inputs = [Input([None, 8, 3, 224, 224], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
model.prepare(
optim,
CrossEntropy(),
Accuracy(topk=(1, 5)),
inputs=inputs,
labels=labels,
device=FLAGS.device)
if FLAGS.eval_only:
if FLGAS.weights:
model.load(FLAGS.weights)
model.evaluate(
val_dataset,
batch_size=FLAGS.batch_size,
num_workers=FLAGS.num_workers)
return
if FLAGS.resume is not None:
model.load(FLAGS.resume)
model.fit(train_dataset,
val_dataset,
epochs=FLAGS.epoch,
batch_size=FLAGS.batch_size,
save_dir='tsm_checkpoint',
num_workers=4,
drop_last=True,
shuffle=True)
if __name__ == '__main__':
parser = argparse.ArgumentParser("CNN training on TSM")
parser.add_argument('data', metavar='DIR', help='path to kineteics dataset')
parser.add_argument(
"--device", type=str, default='gpu', help="device to use, gpu or cpu")
parser.add_argument(
"-d", "--dynamic", action='store_true', help="enable dygraph mode")
parser.add_argument(
"--eval_only", action='store_true', help="run evaluation only")
parser.add_argument(
"-e", "--epoch", default=70, type=int, help="number of epoch")
parser.add_argument(
"-j", "--num_workers", default=4, type=int, help="read worker number")
parser.add_argument(
'--lr',
'--learning-rate',
default=1e-2,
type=float,
metavar='LR',
help='initial learning rate')
parser.add_argument(
"-b", "--batch_size", default=16, type=int, help="batch size")
parser.add_argument(
"-n", "--num_devices", default=1, type=int, help="number of devices")
parser.add_argument(
"-r",
"--resume",
default=None,
type=str,
help="checkpoint path to resume")
parser.add_argument(
"-w",
"--weights",
default=None,
type=str,
help="weights path for evaluation")
FLAGS = parser.parse_args()
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import kinetics_dataset
from .kinetics_dataset import *
from . import modeling
from .modeling import *
from . import transforms
from .transforms import *
__all__ = kinetics_dataset.__all__ \
+ modeling.__all__ \
+ transforms.__all__
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import six
import sys
import random
import numpy as np
from PIL import Image, ImageEnhance
try:
import cPickle as pickle
from cStringIO import StringIO
except ImportError:
import pickle
from io import BytesIO
from paddle.fluid.io import Dataset
import logging
logger = logging.getLogger(__name__)
__all__ = ['KineticsDataset']
class KineticsDataset(Dataset):
"""
Kinetics dataset
Args:
filelist (str): path to file list, default None.
num_classes (int): class number
"""
def __init__(self,
filelist,
pickle_dir,
mode='train',
seg_num=8,
seg_len=1,
transform=None):
assert os.path.isfile(filelist), \
"filelist {} not a file".format(filelist)
with open(filelist) as f:
self.pickle_paths = [l.strip() for l in f]
assert os.path.isdir(pickle_dir), \
"pickle_dir {} not a directory".format(pickle_dir)
self.pickle_dir = pickle_dir
assert mode in ['train', 'val'], \
"mode can only be 'train' or 'val'"
self.mode = mode
self.seg_num = seg_num
self.seg_len = seg_len
self.transform = transform
def __len__(self):
return len(self.pickle_paths)
def __getitem__(self, idx):
pickle_path = os.path.join(self.pickle_dir, self.pickle_paths[idx])
try:
if six.PY2:
data = pickle.load(open(pickle_path, 'rb'))
else:
data = pickle.load(open(pickle_path, 'rb'), encoding='bytes')
vid, label, frames = data
if len(frames) < 1:
logger.error("{} contains no frame".format(pickle_path))
sys.exit(-1)
except Exception as e:
logger.error("Load {} failed: {}".format(pickle_path, e))
sys.exit(-1)
label_list = [0, 2, 3, 4, 6, 7, 9, 12, 14, 15]
label = label_list.index(label)
imgs = self._video_loader(frames)
if self.transform:
imgs, label = self.transform(imgs, label)
return imgs, np.array([label])
def _video_loader(self, frames):
videolen = len(frames)
average_dur = int(videolen / self.seg_num)
imgs = []
for i in range(self.seg_num):
idx = 0
if self.mode == 'train':
if average_dur >= self.seg_len:
idx = random.randint(0, average_dur - self.seg_len)
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
else:
if average_dur >= self.seg_len:
idx = (average_dur - self.seg_len) // 2
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
for jj in range(idx, idx + self.seg_len):
imgbuf = frames[int(jj % videolen)]
img = self._imageloader(imgbuf)
imgs.append(img)
return imgs
def _imageloader(self, buf):
if isinstance(buf, str):
img = Image.open(StringIO(buf))
else:
img = Image.open(BytesIO(buf))
return img.convert('RGB')
if __name__ == "__main__":
kd = KineticsDataset('/paddle/ssd3/kineteics_mini/val_10.list', '/paddle/ssd3/kineteics_mini/val_10')
print("KineticsDataset length", len(kd))
for d in kd:
print(len(d[0]), d[0][0].size, d[1])
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import math
import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from model import Model
from download import get_weights_path
__all__ = ["TSM_ResNet", "tsm_resnet50"]
# {num_layers: (url, md5)}
pretrain_infos = {
50: ('https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams',
'5755dc538e422589f417f7b38d7cc3c7')
}
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=None,
act=None,
param_attr=fluid.param_attr.ParamAttr(),
bias_attr=False)
self._batch_norm = BatchNorm(
num_filters,
act=act,
param_attr=fluid.param_attr.ParamAttr(),
bias_attr=fluid.param_attr.ParamAttr())
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class BottleneckBlock(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
seg_num=8):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act='relu')
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None)
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
stride=stride)
self.shortcut = shortcut
self.seg_num = seg_num
self._num_channels_out = int(num_filters * 4)
def forward(self, inputs):
shifts = fluid.layers.temporal_shift(inputs, self.seg_num, 1.0 / 8)
y = self.conv0(shifts)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = fluid.layers.elementwise_add(x=short, y=conv2, act="relu")
return y
class TSM_ResNet(Model):
def __init__(self, num_layers=50, seg_num=8, num_classes=400):
super(TSM_ResNet, self).__init__()
self.layers = num_layers
self.seg_num = seg_num
self.class_dim = num_classes
if self.layers == 50:
depth = [3, 4, 6, 3]
else:
raise NotImplementedError
num_filters = [64, 128, 256, 512]
self.conv = ConvBNLayer(
num_channels=3, num_filters=64, filter_size=7, stride=2, act='relu')
self.pool2d_max = Pool2D(
pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
self.bottleneck_block_list = []
num_channels = 64
for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
num_channels=num_channels,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
seg_num=self.seg_num))
num_channels = int(bottleneck_block._num_channels_out)
self.bottleneck_block_list.append(bottleneck_block)
shortcut = True
self.pool2d_avg = Pool2D(
pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(2048 * 1.0)
self.out = Linear(
2048,
self.class_dim,
act="softmax",
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)),
bias_attr=fluid.param_attr.ParamAttr(
learning_rate=2.0, regularizer=fluid.regularizer.L2Decay(0.)))
def forward(self, inputs):
y = fluid.layers.reshape(
inputs, [-1, inputs.shape[2], inputs.shape[3], inputs.shape[4]])
y = self.conv(y)
y = self.pool2d_max(y)
for bottleneck_block in self.bottleneck_block_list:
y = bottleneck_block(y)
y = self.pool2d_avg(y)
y = fluid.layers.dropout(y, dropout_prob=0.5)
y = fluid.layers.reshape(y, [-1, self.seg_num, y.shape[1]])
y = fluid.layers.reduce_mean(y, dim=1)
y = fluid.layers.reshape(y, shape=[-1, 2048])
y = self.out(y)
return y
def _tsm_resnet(num_layers, seg_num=8, num_classes=400, pretrained=True):
model = TSM_ResNet(num_layers, seg_num, num_classes)
if pretrained:
assert num_layers in pretrain_infos.keys(), \
"TSM_ResNet{} do not have pretrained weights now, " \
"pretrained should be set as False"
weight_path = get_weights_path(*(pretrain_infos[num_layers]))
assert weight_path.endswith('.pdparams'), \
"suffix of weight must be .pdparams"
model.load(weight_path[:-9])
return model
def tsm_resnet50(seg_num=8, num_classes=400, pretrained=True):
return _tsm_resnet(50, seg_num, num_classes, pretrained)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import traceback
import numpy as np
from PIL import Image
import logging
logger = logging.getLogger(__name__)
__all__ = ['GroupScale', 'GroupMultiScaleCrop', 'GroupRandomCrop',
'GroupRandomFlip', 'GroupCenterCrop', 'NormalizeImage',
'Compose']
class Compose(object):
def __init__(self, transforms=[]):
self.transforms = transforms
def __call__(self, *data):
for f in self.transforms:
try:
data = f(*data)
except Exception as e:
stack_info = traceback.format_exc()
logger.info("fail to perform transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
raise e
return data
class GroupScale(object):
"""
Group scale image
Args:
target_size (int): image resize target size
"""
def __init__(self, target_size=224):
self.target_size = target_size
def __call__(self, imgs, label):
resized_imgs = []
for i in range(len(imgs)):
img = imgs[i]
w, h = img.size
if (w <= h and w == self.target_size) or \
(h <= w and h == self.target_size):
resized_imgs.append(img)
continue
if w < h:
ow = self.target_size
oh = int(self.target_size * 4.0 / 3.0)
resized_imgs.append(img.resize((ow, oh), Image.BILINEAR))
else:
oh = self.target_size
ow = int(self.target_size * 4.0 / 3.0)
resized_imgs.append(img.resize((ow, oh), Image.BILINEAR))
return resized_imgs, label
class GroupMultiScaleCrop(object):
"""
FIXME: add comments
"""
def __init__(self,
short_size=256,
scales=None,
max_distort=1,
fix_crop=True,
more_fix_crop=True):
self.short_size = short_size
self.scales = scales if scales is not None \
else [1, .875, .75, .66]
self.max_distort = max_distort
self.fix_crop = fix_crop
self.more_fix_crop = more_fix_crop
def __call__(self, imgs, label):
input_size = [self.short_size, self.short_size]
im_size = imgs[0].size
# get random crop offset
def _sample_crop_size(im_size):
image_w, image_h = im_size[0], im_size[1]
base_size = min(image_w, image_h)
crop_sizes = [int(base_size * x) for x in self.scales]
crop_h = [
input_size[1] if abs(x - input_size[1]) < 3 else x
for x in crop_sizes
]
crop_w = [
input_size[0] if abs(x - input_size[0]) < 3 else x
for x in crop_sizes
]
pairs = []
for i, h in enumerate(crop_h):
for j, w in enumerate(crop_w):
if abs(i - j) <= self.max_distort:
pairs.append((w, h))
crop_pair = random.choice(pairs)
if not self.fix_crop:
w_offset = np.random.randint(0, image_w - crop_pair[0])
h_offset = np.random.randint(0, image_h - crop_pair[1])
else:
w_step = (image_w - crop_pair[0]) / 4
h_step = (image_h - crop_pair[1]) / 4
ret = list()
ret.append((0, 0)) # upper left
if w_step != 0:
ret.append((4 * w_step, 0)) # upper right
if h_step != 0:
ret.append((0, 4 * h_step)) # lower left
if h_step != 0 and w_step != 0:
ret.append((4 * w_step, 4 * h_step)) # lower right
if h_step != 0 or w_step != 0:
ret.append((2 * w_step, 2 * h_step)) # center
if self.more_fix_crop:
ret.append((0, 2 * h_step)) # center left
ret.append((4 * w_step, 2 * h_step)) # center right
ret.append((2 * w_step, 4 * h_step)) # lower center
ret.append((2 * w_step, 0 * h_step)) # upper center
ret.append((1 * w_step, 1 * h_step)) # upper left quarter
ret.append((3 * w_step, 1 * h_step)) # upper right quarter
ret.append((1 * w_step, 3 * h_step)) # lower left quarter
ret.append((3 * w_step, 3 * h_step)) # lower righ quarter
w_offset, h_offset = random.choice(ret)
return crop_pair[0], crop_pair[1], w_offset, h_offset
crop_w, crop_h, offset_w, offset_h = _sample_crop_size(im_size)
crop_imgs = [
img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h))
for img in imgs
]
ret_imgs = [
img.resize((input_size[0], input_size[1]), Image.BILINEAR)
for img in crop_imgs
]
return ret_imgs, label
class GroupRandomCrop(object):
def __init__(self, target_size=224):
self.target_size = target_size
def __call__(self, imgs, label):
w, h = imgs[0].size
th, tw = self.target_size, self.target_size
assert (w >= self.target_size) and (h >= self.target_size), \
"image width({}) and height({}) should be larger than " \
"crop size".format(w, h, self.target_size)
out_images = []
x1 = np.random.randint(0, w - tw)
y1 = np.random.randint(0, h - th)
for img in imgs:
if w == tw and h == th:
out_images.append(img)
else:
out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
return out_images, label
class GroupRandomFlip(object):
def __call__(self, imgs, label):
v = np.random.random()
if v < 0.5:
ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in imgs]
return ret, label
else:
return imgs, label
class GroupCenterCrop(object):
def __init__(self, target_size=224):
self.target_size = target_size
def __call__(self, imgs, label):
crop_imgs = []
for img in imgs:
w, h = img.size
th, tw = self.target_size, self.target_size
assert (w >= self.target_size) and (h >= self.target_size), \
"image width({}) and height({}) should be larger " \
"than crop size".format(w, h, self.target_size)
x1 = int(round((w - tw) / 2.))
y1 = int(round((h - th) / 2.))
crop_imgs.append(img.crop((x1, y1, x1 + tw, y1 + th)))
return crop_imgs, label
class NormalizeImage(object):
def __init__(self,
target_size=224,
img_mean=[0.485, 0.456, 0.406],
img_std=[0.229, 0.224, 0.225],
seg_num=8,
seg_len=1):
self.target_size = target_size
self.img_mean = np.array(img_mean).reshape((3, 1, 1)).astype('float32')
self.img_std = np.array(img_std).reshape((3, 1, 1)).astype('float32')
self.seg_num = seg_num
self.seg_len = seg_len
def __call__(self, imgs, label):
np_imgs = (np.array(imgs[0]).astype('float32').transpose(
(2, 0, 1))).reshape(1, 3, self.target_size,
self.target_size) / 255
for i in range(len(imgs) - 1):
img = (np.array(imgs[i + 1]).astype('float32').transpose(
(2, 0, 1))).reshape(1, 3, self.target_size,
self.target_size) / 255
np_imgs = np.concatenate((np_imgs, img))
np_imgs -= self.img_mean
np_imgs /= self.img_std
np_imgs = np.reshape(np_imgs, (self.seg_num, self.seg_len * 3,
self.target_size, self.target_size))
return np_imgs, label
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册