Unverified commit 2293e337, authored by xbsu, committed by GitHub

Update VOT code: add SiamRPN and SiamMask (#4734)

* Add SiamRPN and SiamMask
* delete useless code
* fix bug
Parent 09f7796c
......@@ -48,15 +48,20 @@ pytracking contains the tracking code
The main training datasets are:
- [VID](http://bvisionweb1.cs.unc.edu/ilsvrc2015/ILSVRC2015_VID.tar.gz)
- [DET](http://image-net.org/challenges/LSVRC/2015/)
- [Microsoft COCO 2014](http://cocodataset.org/#download)
- [Microsoft COCO 2017](http://cocodataset.org/#download)
- [Youtube-VOS](https://youtube-vos.org/)
- [LaSOT](https://drive.google.com/file/d/1O2DLxPP8M4Pn4-XCttCJUW3A29tDIeNa/view)
- [GOT-10K](http://got-10k.aitestunion.com/downloads_dataset/full_data)
After downloading and extracting, organize the datasets as follows:
```
/Datasets/
└─ ILSVRC2015_VID/
└─ train2014/
└─ ILSVRC2015/
└─ ILSVRC2015_DET/
└─ COCO/
└─ YoutubeVOS/
└─ GOT-10K/
└─ LaSOTBenchmark/
......@@ -71,16 +76,16 @@ Datasets is the path where the datasets are stored.
Working environment for tracking:
- Linux
- python3
- PaddlePaddle1.7
- PaddlePaddle1.8
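> Note: if you cannot import cmath, try switching Python versions; python3.6.8 or python3.7.0 is recommended. Also,
> tracking does not currently support Windows. If you need to run tracking on Windows, please open an issue to request it.
### Install dependencies
1. Install paddle. Paddle 1.7 is required; if your version is lower, upgrade to Paddle 1.7.
1. Install paddle. Paddle 1.8 is required; if your version is lower, upgrade to Paddle 1.8.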
```bash
pip install paddlepaddle-gpu==1.7.0
pip install paddlepaddle-gpu==1.8.0
```
2. Install the third-party libraries; Anaconda is recommended
......@@ -114,10 +119,12 @@ pip install python-prctl
└─ atom_resnet18.pdparams
└─ atom_resnet50.pdparams
└─ backbone
└─ AlexNet.pdparams
└─ ResNet18.pdparams
└─ ResNet50.pdparams
└─ ResNet50_dilated.pdparams
```
The /pretrained_models/backbone/ folder contains ResNet18 and ResNet50 models pretrained on ImageNet.
The /pretrained_models/backbone/ folder contains AlexNet, ResNet18 and ResNet50 models pretrained on ImageNet.
### Set training parameters
......@@ -154,7 +161,7 @@ python -c "from ltr.admin.environment import create_default_local_file; create_d
```bash
self.workspace_dir = './checkpoints'
self.lasot_dir = '/Datasets/LaSOTBenchmark/'
self.coco_dir = '/Datasets/train2014/'
self.coco_dir = '/Datasets/COCO/'
self.got10k_dir = '/Datasets/GOT-10k/train'
self.imagenet_dir = '/Datasets/ILSVRC2015/'
```
......@@ -164,6 +171,16 @@ cd ltr/data_specs/
wget https://paddlemodels.cdn.bcebos.com/paddle_track/vot/got10k_lasot_split.tar
tar xvf got10k_lasot_split.tar
```
To train SiamRPN or SiamMask, configure workspace_dir and the dataset paths for imagenet, coco, imagenetdet, youtubevos, lasot and got10k, as follows:
```bash
self.workspace_dir = './checkpoints'
self.imagenet_dir = '/Datasets/ILSVRC2015/'
self.coco_dir = '/Datasets/COCO/'
self.imagenetdet_dir = '/Datasets/ILSVRC2015_DET/'
self.youtubevos_dir = '/Datasets/YoutubeVOS/'
self.lasot_dir = '/Datasets/LaSOTBenchmark/'
self.got10k_dir = '/Datasets/GOT-10k/train'
```
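Before launching training, it can help to confirm that every configured directory actually exists. The following is a minimal sketch and not part of the repository: `missing_dirs` is a hypothetical helper, and the dictionary simply mirrors the example values in the settings above.

```python
import os


def missing_dirs(settings):
    """Return the names of settings whose directory does not exist (hypothetical helper)."""
    return [name for name, path in settings.items() if not os.path.isdir(path)]


# Example values copied from the configuration above; adjust them to your machine.
settings = {
    'imagenet_dir': '/Datasets/ILSVRC2015/',
    'coco_dir': '/Datasets/COCO/',
    'imagenetdet_dir': '/Datasets/ILSVRC2015_DET/',
    'youtubevos_dir': '/Datasets/YoutubeVOS/',
    'lasot_dir': '/Datasets/LaSOTBenchmark/',
    'got10k_dir': '/Datasets/GOT-10k/train',
}

for name in missing_dirs(settings):
    print('Missing dataset directory for', name)
```

Note that directory names are case-sensitive on Linux, so the spelling in the settings has to match the folder on disk exactly.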
### Start training
......@@ -180,6 +197,15 @@ python run_training.py bbreg atom_res50_vid_lasot_coco
# Train SiamFC
python run_training.py siamfc siamfc_alexnet_vid
# Train SiamRPN AlexNet
python run_training.py siamrpn siamrpn_alexnet
# Train SiamMask-Base ResNet50
python run_training.py siammask siammask_res50_base
# Train SiamMask-Refine ResNet50; settings.base_model must be set to the best SiamMask-Base model
python run_training.py siammask siammask_res50_sharp
```
......@@ -242,6 +268,19 @@ python eval_benchmark.py -d VOT2018 -tr bbreg.atom_res18_vid_lasot_coco -te atom
python eval_benchmark.py -d VOT2018 -tr siamfc.siamfc_alexnet_vid -te siamfc.default -e 'range(1, 50, 1)'
```
Test SiamRPN
```
python eval_benchmark.py -d OTB100 -tr siamrpn.siamrpn_alexnet -te siamrpn.default_otb -e 'range(1, 40, 1)'
```
Test SiamMask
```bash
# Test SiamMask-Base on VOT2018
python eval_benchmark.py -d VOT2018 -tr siammask.siammask_res50_base -te siammask.base_default -e 'range(1, 20, 1)'
# Test SiamMask-Sharp on VOT2018
python eval_benchmark.py -d VOT2018 -tr siammask.siammask_res50_sharp -te siammask.sharp_default_vot -e 'range(1, 20, 1)'
```
## Visualizing tracking results
......@@ -265,7 +304,9 @@ jupyter notebook --ip 0.0.0.0 --port 8888
| Dataset | Model | Backbone | Paper result | Training result | Download |
| :-------: | :-------: | :---: | :---: | :---------: |:---------: |
|VOT2018| ATOM | Res18 | EAO: 0.401 | 0.399 | [model](https://paddlemodels.cdn.bcebos.com/paddle_track/vot/ATOM.tar) |
|VOT2018| SiamMask | Res50 | EAO: 0.380 | 0.379 | [model](https://paddlemodels.cdn.bcebos.com/paddle_track/vot/SiamMask.tar) |
|VOT2018| SiamFC | AlexNet | EAO: 0.188 | 0.211 | [model](https://paddlemodels.cdn.bcebos.com/paddle_track/vot/SiamFC.tar) |
|OTB100| SiamRPN | AlexNet | Succ: 0.637, Prcn: 0.851 | Succ: 0.644, Prcn: 0.848 | [model](https://paddlemodels.cdn.bcebos.com/paddle_track/vot/SiamRPN.tar) |
## Citations and references
......@@ -280,6 +321,26 @@ SiamFC **[[Paper]](https://arxiv.org/pdf/1811.07628.pdf) [[Code]](https://www.ro
organization={Springer}
}
SiamRPN **[[Paper]](http://openaccess.thecvf.com/content_cvpr_2018/papers/Li_High_Performance_Visual_CVPR_2018_paper.pdf) [[Code]](https://github.com/STVIR/pysot)**
@inproceedings{li2018high,
title={High performance visual tracking with siamese region proposal network},
author={Li, Bo and Yan, Junjie and Wu, Wei and Zhu, Zheng and Hu, Xiaolin},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
pages={8971--8980},
year={2018}
}
SiamMask **[[Paper]](https://arxiv.org/pdf/1812.05050.pdf) [[Code]](https://github.com/foolwood/SiamMask)**
@inproceedings{wang2019fast,
title={Fast online object tracking and segmentation: A unifying approach},
author={Wang, Qiang and Zhang, Li and Bertinetto, Luca and Hu, Weiming and Torr, Philip HS},
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
pages={1328--1338},
year={2019}
}
ATOM **[[Paper]](https://arxiv.org/pdf/1811.07628.pdf) [[Raw results]](https://drive.google.com/drive/folders/1MdJtsgr34iJesAgL7Y_VelP8RvQm_IG_) [[Models]](https://drive.google.com/open?id=1EsNSQr25qfXHYLqjZaVZElbGdUg-nyzd) [[Training Code]](https://github.com/visionml/pytracking/blob/master/ltr/README.md#ATOM) [[Tracker Code]](https://github.com/visionml/pytracking/blob/master/pytracking/README.md#ATOM)**
@inproceedings{danelljan2019atom,
......
from .base_actor import BaseActor
from .bbreg import AtomActor
from .siamfc import SiamFCActor
from .siam import SiamActor
from . import BaseActor
import paddle.fluid as fluid
import numpy as np


class SiamActor(BaseActor):
    """ Actor for training the SiamRPN/SiamMask"""

    def __call__(self, data):
        # Run network to obtain prediction
        pred = self.net(data['train_images'], data['test_images'])

        # Compute loss
        label_cls = fluid.layers.cast(x=data['label_cls'], dtype=np.int64)
        cls_loss = self.objective['cls'](pred['cls'], label_cls)
        loc_loss = self.objective['loc'](pred['loc'], data['label_loc'], data['label_loc_weight'])

        loss = {}
        loss['cls'] = cls_loss
        loss['loc'] = loc_loss

        # Return training stats
        stats = {}
        stats['Loss/cls'] = cls_loss.numpy()
        stats['Loss/loc'] = loc_loss.numpy()

        # Compute mask loss if necessary
        if 'mask' in pred:
            mask_loss, iou_m, iou_5, iou_7 = self.objective['mask'](
                pred['mask'],
                data['label_mask'],
                data['label_mask_weight'])
            loss['mask'] = mask_loss

            stats['Loss/mask'] = mask_loss.numpy()
            stats['Accuracy/mask_iou_mean'] = iou_m.numpy()
            stats['Accuracy/mask_at_5'] = iou_5.numpy()
            stats['Accuracy/mask_at_7'] = iou_7.numpy()

        # Use the network's scale_loss if it exists, otherwise sum the terms
        scale_loss = getattr(self.net, "scale_loss", None)
        if callable(scale_loss):
            total_loss = scale_loss(loss)
        else:
            total_loss = 0
            for k, v in loss.items():
                total_loss += v
        stats['Loss/total'] = total_loss.numpy()

        return total_loss, stats
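The fallback at the end of `SiamActor.__call__` can be illustrated in isolation. This is only a sketch, with plain floats standing in for loss tensors; `combine_losses`, `PlainNet`, `WeightedNet` and the weights are hypothetical names and values, not part of the repository.

```python
def combine_losses(net, loss):
    """Use net.scale_loss when it is callable, otherwise sum the loss terms."""
    scale_loss = getattr(net, "scale_loss", None)
    if callable(scale_loss):
        return scale_loss(loss)
    return sum(loss.values())


class PlainNet:
    """A stand-in network without a scale_loss method."""


class WeightedNet:
    """A stand-in network that weights each loss term (the weights are made up)."""
    weights = {'cls': 1.0, 'loc': 2.0}

    def scale_loss(self, loss):
        return sum(self.weights[k] * v for k, v in loss.items())


loss = {'cls': 0.5, 'loc': 0.25}
print(combine_losses(PlainNet(), loss))     # plain sum of the terms
print(combine_losses(WeightedNet(), loss))  # weighted sum defined by the network
```

Looking the method up with `getattr` lets one actor serve networks that define their own loss weighting as well as networks that do not.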
......@@ -16,7 +16,8 @@ def create_default_local_file():
'trackingnet_dir': empty_str,
'coco_dir': empty_str,
'imagenet_dir': empty_str,
'imagenetdet_dir': empty_str
'imagenetdet_dir': empty_str,
'youtubevos_dir': empty_str
})
comment = {
......
......@@ -9,3 +9,4 @@ class EnvironmentSettings:
self.coco_dir = ''
self.imagenet_dir = ''
self.imagenetdet_dir = ''
self.youtubevos_dir = ''
import math
import numpy as np
from collections import namedtuple
Corner = namedtuple('Corner', 'x1 y1 x2 y2')
# alias
BBox = Corner
Center = namedtuple('Center', 'x y w h')


def topleft2corner(topleft):
    """ convert (x, y, w, h) to (x1, y1, x2, y2)
    Args:
        topleft: np.array (4 * N)
    Return:
        np.array (4 * N)
    """
    x, y, w, h = topleft[0], topleft[1], topleft[2], topleft[3]
    x1 = x
    y1 = y
    x2 = x + w
    y2 = y + h
    return x1, y1, x2, y2


def corner2center(corner):
    """ convert (x1, y1, x2, y2) to (cx, cy, w, h)
    Args:
        corner: Corner or np.array (4 * N)
    Return:
        Center or np.array (4 * N)
    """
    if isinstance(corner, Corner):
        x1, y1, x2, y2 = corner
        return Center((x1 + x2) * 0.5, (y1 + y2) * 0.5, (x2 - x1), (y2 - y1))
    else:
        x1, y1, x2, y2 = corner[0], corner[1], corner[2], corner[3]
        x = (x1 + x2) * 0.5
        y = (y1 + y2) * 0.5
        w = x2 - x1
        h = y2 - y1
        return x, y, w, h


def center2corner(center):
    """ convert (cx, cy, w, h) to (x1, y1, x2, y2)
    Args:
        center: Center or np.array (4 * N)
    Return:
        Corner or np.array (4 * N)
    """
    if isinstance(center, Center):
        x, y, w, h = center
        return Corner(x - w * 0.5, y - h * 0.5, x + w * 0.5, y + h * 0.5)
    else:
        x, y, w, h = center[0], center[1], center[2], center[3]
        x1 = x - w * 0.5
        y1 = y - h * 0.5
        x2 = x + w * 0.5
        y2 = y + h * 0.5
        return x1, y1, x2, y2


def IoU(rect1, rect2):
    """ calculate intersection over union
    Args:
        rect1: (x1, y1, x2, y2)
        rect2: (x1, y1, x2, y2)
    Returns:
        iou
    """
    # overlap
    x1, y1, x2, y2 = rect1[0], rect1[1], rect1[2], rect1[3]
    tx1, ty1, tx2, ty2 = rect2[0], rect2[1], rect2[2], rect2[3]

    xx1 = np.maximum(tx1, x1)
    yy1 = np.maximum(ty1, y1)
    xx2 = np.minimum(tx2, x2)
    yy2 = np.minimum(ty2, y2)

    ww = np.maximum(0, xx2 - xx1)
    hh = np.maximum(0, yy2 - yy1)

    area = (x2 - x1) * (y2 - y1)
    target_a = (tx2 - tx1) * (ty2 - ty1)
    inter = ww * hh
    iou = inter / (area + target_a - inter)
    return iou
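A small worked example may make the `IoU` computation above concrete. This is a sketch: `iou` below is a compact re-implementation written only for illustration, following the same steps (clamped overlap width and height, then intersection divided by union).

```python
import numpy as np


def iou(rect1, rect2):
    """IoU of two (x1, y1, x2, y2) boxes (illustrative re-implementation)."""
    x1, y1, x2, y2 = rect1
    tx1, ty1, tx2, ty2 = rect2
    # Overlap width and height, clamped at zero for disjoint boxes
    ww = np.maximum(0, np.minimum(tx2, x2) - np.maximum(tx1, x1))
    hh = np.maximum(0, np.minimum(ty2, y2) - np.maximum(ty1, y1))
    inter = ww * hh
    union = (x2 - x1) * (y2 - y1) + (tx2 - tx1) * (ty2 - ty1) - inter
    return inter / union


# Two 2x2 boxes that share a 1x1 square: IoU = 1 / (4 + 4 - 1) = 1/7
print(iou((0, 0, 2, 2), (1, 1, 3, 3)))
# Disjoint boxes: the clamped overlap is zero, so the IoU is zero
print(iou((0, 0, 1, 1), (2, 2, 3, 3)))
```

Because only NumPy element-wise operations are used, the same code also works when the coordinates are arrays, which is how the anchor grid is compared against a single target box.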


class Anchors:
    """
    This class generates anchors.
    """

    def __init__(self, stride, ratios, scales, image_center=0, size=0):
        self.stride = stride
        self.ratios = ratios
        self.scales = scales
        self.image_center = 0
        self.size = 0

        self.anchor_num = len(self.scales) * len(self.ratios)

        self.anchors = None

        self.generate_anchors()

    def generate_anchors(self):
        """
        generate anchors based on predefined configuration
        """
        self.anchors = np.zeros((self.anchor_num, 4), dtype=np.float32)
        size = self.stride * self.stride
        count = 0
        for r in self.ratios:
            ws = int(math.sqrt(size * 1. / r))
            hs = int(ws * r)

            for s in self.scales:
                w = ws * s
                h = hs * s
                self.anchors[count][:] = [-w * 0.5, -h * 0.5, w * 0.5, h * 0.5][:]
                count += 1

    def generate_all_anchors(self, im_c, size):
        """
        im_c: image center
        size: image size
        """
        if self.image_center == im_c and self.size == size:
            return False
        self.image_center = im_c
        self.size = size

        a0x = im_c - size // 2 * self.stride
        ori = np.array([a0x] * 4, dtype=np.float32)
        zero_anchors = self.anchors + ori

        x1 = zero_anchors[:, 0]
        y1 = zero_anchors[:, 1]
        x2 = zero_anchors[:, 2]
        y2 = zero_anchors[:, 3]

        x1, y1, x2, y2 = map(lambda x: x.reshape(self.anchor_num, 1, 1),
                             [x1, y1, x2, y2])
        cx, cy, w, h = corner2center([x1, y1, x2, y2])

        disp_x = np.arange(0, size).reshape(1, 1, -1) * self.stride
        disp_y = np.arange(0, size).reshape(1, -1, 1) * self.stride

        cx = cx + disp_x
        cy = cy + disp_y

        # broadcast
        zero = np.zeros((self.anchor_num, size, size), dtype=np.float32)
        cx, cy, w, h = map(lambda x: x + zero, [cx, cy, w, h])
        x1, y1, x2, y2 = center2corner([cx, cy, w, h])

        self.all_anchors = (np.stack([x1, y1, x2, y2]).astype(np.float32),
                            np.stack([cx, cy, w, h]).astype(np.float32))
        return True
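To see which anchor shapes `generate_anchors` produces, the width and height arithmetic can be run on its own. This is a sketch: `base_anchor_sizes` is a hypothetical helper that copies only the size computation, and the stride, ratios and scale used here are assumed values that are not stated in this commit.

```python
import math


def base_anchor_sizes(stride, ratios, scales):
    """Return (w, h) for each ratio/scale pair, using the same arithmetic as generate_anchors."""
    size = stride * stride
    sizes = []
    for r in ratios:
        # Width chosen so that w * h is roughly stride**2 before scaling; int() truncates
        ws = int(math.sqrt(size * 1. / r))
        hs = int(ws * r)
        for s in scales:
            sizes.append((ws * s, hs * s))
    return sizes


# Assumed configuration: stride 8, five aspect ratios, a single scale of 8
print(base_anchor_sizes(8, [0.33, 0.5, 1, 2, 3], [8]))
# → [(104, 32), (88, 40), (64, 64), (40, 80), (32, 96)]
```

Each pair is then stored as corners centred on the origin, `[-w/2, -h/2, w/2, h/2]`, and `generate_all_anchors` shifts those corners across the output grid.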


class AnchorTarget:
    def __init__(self,
                 search_size,
                 output_size,
                 stride,
                 ratios,
                 scales,
                 num_pos,
                 num_neg,
                 num_total,
                 thr_high,
                 thr_low):
        self.search_size = search_size
        self.output_size = output_size
        self.anchor_stride = stride
        self.anchor_ratios = ratios
        self.anchor_scales = scales
        self.num_pos = num_pos
        self.num_neg = num_neg
        self.num_total = num_total
        self.thr_high = thr_high
        self.thr_low = thr_low

        self.anchors = Anchors(stride,
                               ratios,
                               scales)

        self.anchors.generate_all_anchors(im_c=search_size // 2,
                                          size=output_size)

    def __call__(self, target, size, neg=False):
        anchor_num = len(self.anchor_ratios) * len(self.anchor_scales)

        # -1 ignore 0 negative 1 positive
        cls = -1 * np.ones((anchor_num, size, size), dtype=np.int64)
        delta = np.zeros((4, anchor_num, size, size), dtype=np.float32)
        delta_weight = np.zeros((anchor_num, size, size), dtype=np.float32)

        def select(position, keep_num=16):
            num = position[0].shape[0]
            if num <= keep_num:
                return position, num
            slt = np.arange(num)
            np.random.shuffle(slt)
            slt = slt[:keep_num]
            return tuple(p[slt] for p in position), keep_num

        tcx, tcy, tw, th = corner2center(target)

        if neg:
            # l = size // 2 - 3
            # r = size // 2 + 3 + 1
            # cls[:, l:r, l:r] = 0

            cx = size // 2
            cy = size // 2
            cx += int(np.ceil((tcx - self.search_size // 2) /
                              self.anchor_stride + 0.5))
            cy += int(np.ceil((tcy - self.search_size // 2) /
                              self.anchor_stride + 0.5))
            l = max(0, cx - 3)
            r = min(size, cx + 4)
            u = max(0, cy - 3)
            d = min(size, cy + 4)
            cls[:, u:d, l:r] = 0

            neg, neg_num = select(np.where(cls == 0), self.num_neg)
            cls[:] = -1
            cls[neg] = 0

            overlap = np.zeros((anchor_num, size, size), dtype=np.float32)
            return cls, delta, delta_weight, overlap

        anchor_box = self.anchors.all_anchors[0]
        anchor_center = self.anchors.all_anchors[1]
        x1, y1, x2, y2 = anchor_box[0], anchor_box[1], \
            anchor_box[2], anchor_box[3]
        cx, cy, w, h = anchor_center[0], anchor_center[1], \
            anchor_center[2], anchor_center[3]

        delta[0] = (tcx - cx) / w
        delta[1] = (tcy - cy) / h
        delta[2] = np.log(tw / w)
        delta[3] = np.log(th / h)

        overlap = IoU([x1, y1, x2, y2], target)

        pos = np.where(overlap > self.thr_high)
        neg = np.where(overlap < self.thr_low)

        pos, pos_num = select(pos, self.num_pos)
        neg, neg_num = select(neg, self.num_total - self.num_pos)

        cls[pos] = 1
        delta_weight[pos] = 1. / (pos_num + 1e-6)

        cls[neg] = 0
        return cls, delta, delta_weight, overlap
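The regression targets computed in `AnchorTarget.__call__` encode the target box relative to each anchor. The sketch below reproduces that encoding for a single anchor. The `decode` function is the usual inverse and is an assumption here, since the tracker-side decoding is not shown in this part of the commit; both function names are hypothetical.

```python
import numpy as np


def encode(target_center, anchor_center):
    """Encode a target (cx, cy, w, h) relative to an anchor (cx, cy, w, h)."""
    tcx, tcy, tw, th = target_center
    cx, cy, w, h = anchor_center
    return np.array([(tcx - cx) / w, (tcy - cy) / h, np.log(tw / w), np.log(th / h)])


def decode(delta, anchor_center):
    """Invert encode and recover (cx, cy, w, h) from the regression targets."""
    cx, cy, w, h = anchor_center
    return np.array([delta[0] * w + cx, delta[1] * h + cy,
                     np.exp(delta[2]) * w, np.exp(delta[3]) * h])


anchor = (100.0, 100.0, 64.0, 64.0)
target = (116.0, 92.0, 128.0, 32.0)
delta = encode(target, anchor)
print(delta)                  # centre offsets in anchor units, then log size ratios
print(decode(delta, anchor))  # recovers the target box
```

Dividing the centre offset by the anchor size and taking the log of the size ratio keeps the targets on a similar scale for anchors of different shapes.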
......@@ -12,7 +12,7 @@ def default_image_loader(path):
im = jpeg4py_loader(path)
if im is None:
default_image_loader.use_jpeg4py = False
print('Using opencv_loader instead.')
print('Jpeg4py is not available. Using OpenCV instead.')
else:
default_image_loader.use_jpeg4py = True
return im
......@@ -29,9 +29,9 @@ def jpeg4py_loader(path):
try:
return jpeg4py.JPEG(path).decode()
except Exception as e:
print('ERROR: Could not read image "{}"'.format(path))
print('ERROR: Jpeg4py could not read image "{}". Using OpenCV instead.'.format(path))
print(e)
return None
return opencv_loader(path)
def opencv_loader(path):
......@@ -41,7 +41,7 @@ def opencv_loader(path):
# convert to rgb and return
return cv.cvtColor(im, cv.COLOR_BGR2RGB)
except Exception as e:
print('ERROR: Could not read image "{}"'.format(path))
print('ERROR: OpenCV could not read image "{}"'.format(path))
print(e)
return None
......@@ -55,7 +55,7 @@ def lmdb_loader(path, lmdb_path=None):
img_buffer = np.frombuffer(img_buffer, np.uint8)
return cv.imdecode(img_buffer, cv.IMREAD_COLOR)
except Exception as e:
print('ERROR: Could not read image "{}"'.format(path))
print('ERROR: Lmdb could not read image "{}"'.format(path))
print(e)
return None
......
import os
import signal
import sys
import dataflow as df
import numpy as np
# handle terminate reader process, do not print stack frame
def _reader_quit(signum, frame):
    print("Reader process exit.")
    sys.exit()


def _term_group(sig_num, frame):
    print('pid {} terminated, terminate group '
          '{}...'.format(os.getpid(), os.getpgrp()))
    os.killpg(os.getpgid(os.getpid()), signal.SIGKILL)


signal.signal(signal.SIGTERM, _reader_quit)
signal.signal(signal.SIGINT, _term_group)
class LTRLoader(df.DataFlow):
"""
Data loader. Combines a dataset and a sampler, and provides
......
......@@ -2,6 +2,7 @@ import numpy as np
from ltr.data import transforms
import ltr.data.processing_utils as prutils
from ltr.data.anchor import AnchorTarget
from pytracking.libs import TensorDict
......@@ -113,6 +114,148 @@ class SiamFCProcessing(BaseProcessing):
return data
class SiamProcessing(BaseProcessing):
    def __init__(self,
                 search_area_factor,
                 output_sz,
                 center_jitter_factor,
                 scale_jitter_factor,
                 label_params,
                 mode='pair',
                 scale_type='context',
                 border_type='meanpad',
                 *args,
                 **kwargs):
        self._init_transform(*args, **kwargs)
        self.search_area_factor = search_area_factor
        self.output_sz = output_sz
        self.center_jitter_factor = center_jitter_factor
        self.scale_jitter_factor = scale_jitter_factor
        self.mode = mode
        self.scale_type = scale_type
        self.border_type = border_type
        self.label_params = label_params
        self.anchor_target = AnchorTarget(
            label_params['search_size'],
            label_params['output_size'],
            label_params['anchor_stride'],
            label_params['anchor_ratios'],
            label_params['anchor_scales'],
            label_params['num_pos'],
            label_params['num_neg'],
            label_params['num_total'],
            label_params['thr_high'],
            label_params['thr_low'])

    def _init_transform(self,
                        transform=transforms.ToArray(),
                        train_transform=None,
                        test_transform=None,
                        train_mask_transform=None,
                        test_mask_transform=None,
                        joint_transform=None):
        self.transform = {'train': transform if train_transform is None else train_transform,
                          'test': transform if test_transform is None else test_transform,
                          'joint': joint_transform}
        super().__init__(
            transform=transform,
            train_transform=train_transform,
            test_transform=test_transform,
            joint_transform=joint_transform)
        self.transform['train_mask'] = self.transform['train'] if train_mask_transform is None \
            else train_mask_transform
        self.transform['test_mask'] = self.transform['test'] if test_mask_transform is None \
            else test_mask_transform

    def _get_jittered_box(self, box, mode, rng):
        jittered_size = box[2:4] * (1 + (2 * rng.rand(2) - 1) * self.scale_jitter_factor[mode])
        max_offset = (np.sqrt(jittered_size.prod()) * self.center_jitter_factor[mode])
        jittered_center = box[0:2] + 0.5 * box[2:4] + max_offset * (rng.rand(2) - 0.5)

        return np.concatenate((jittered_center - 0.5 * jittered_size, jittered_size), axis=0)

    def _get_label(self, target_bb, neg):
        return self.anchor_target(target_bb, self.label_params['output_size'], neg)

    def __call__(self, data: TensorDict, rng=None):
        neg = data['neg']

        # Apply joint transforms
        if self.transform['joint'] is not None:
            num_train_images = len(data['train_images'])
            all_images = data['train_images'] + data['test_images']
            all_images_trans = self.transform['joint'](*all_images)

            data['train_images'] = all_images_trans[:num_train_images]
            data['test_images'] = all_images_trans[num_train_images:]

        for s in ['train', 'test']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add a uniform noise to the center pos
            jittered_anno = [self._get_jittered_box(a, s, rng) for a in data[s + '_anno']]

            # Crop image region centered at jittered_anno box
            try:
                crops, boxes = prutils.jittered_center_crop(
                    data[s + '_images'],
                    jittered_anno,
                    data[s + '_anno'],
                    self.search_area_factor[s],
                    self.output_sz[s],
                    scale_type=self.scale_type,
                    border_type=self.border_type)
                mask_crops, _ = prutils.jittered_center_crop(
                    data[s + '_masks'],
                    jittered_anno,
                    data[s + '_anno'],
                    self.search_area_factor[s],
                    self.output_sz[s],
                    scale_type=self.scale_type,
                    border_type='zeropad')
            except Exception as e:
                print('{}, anno: {}'.format(data['dataset'], data[s + '_anno']))
                raise e

            # Apply transforms
            data[s + '_images'] = [self.transform[s](x) for x in crops]
            data[s + '_anno'] = boxes
            data[s + '_masks'] = [self.transform[s + '_mask'](x) for x in mask_crops]

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(prutils.stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        # Get labels
        if self.label_params is not None:
            assert data['test_anno'].shape[0] == 1
            gt_box = data['test_anno'][0]
            gt_box[2:] += gt_box[:2]
            cls, delta, delta_weight, overlap = self._get_label(gt_box, neg)

            mask = data['test_masks'][0]
            if np.sum(mask) > 0:
                mask_weight = cls.max(axis=0, keepdims=True)
            else:
                mask_weight = np.zeros([1, cls.shape[1], cls.shape[2]], dtype=np.float32)
            mask = (mask > 0.5) * 2. - 1.

            data['label_cls'] = cls
            data['label_loc'] = delta
            data['label_loc_weight'] = delta_weight
            data['label_mask'] = mask
            data['label_mask_weight'] = mask_weight
            data.pop('train_anno')
            data.pop('test_anno')
            data.pop('train_masks')
            data.pop('test_masks')

        return data
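Two small conversions in the label step of `SiamProcessing.__call__` are easy to misread: the in-place change of the box from (x, y, w, h) to corner form, and the mapping of the mask to -1/+1. The sketch below shows both on toy arrays; the values are made up for illustration.

```python
import numpy as np

# Box given as (x, y, w, h): adding the top-left corner to the size yields (x1, y1, x2, y2).
gt_box = np.array([10., 20., 30., 40.])
gt_box[2:] += gt_box[:2]
print(gt_box)  # → [10. 20. 40. 60.]

# Mask values are thresholded at 0.5 and then mapped from {False, True} to {-1, +1}.
mask = np.array([[0., 0.2], [0.7, 1.]], dtype=np.float32)
label_mask = (mask > 0.5) * 2. - 1.
print(label_mask)
```

Since `gt_box` in the processing code is taken by indexing `data['test_anno']`, the `+=` there modifies that array in place as well.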
class ATOMProcessing(BaseProcessing):
""" The processing class used for training ATOM. The images are processed in the following way.
First, the target bounding box is jittered by adding some noise. Next, a square region (called search region )
......
import random
import numpy as np
import dataflow as df
from pytracking.libs import TensorDict
......@@ -178,3 +179,267 @@ class ATOMSampler(df.RNGDataFlow):
# Send for processing
yield self.processing(data, rng=self.rng)
class MaskSampler(df.RNGDataFlow):
    """ Class responsible for sampling frames from training sequences to form batches. Each training sample is a
    tuple consisting of i) a train frame, used to obtain the modulation vector, and ii) a set of test frames on which
    the IoU prediction loss is calculated.

    The sampling is done in the following ways. First a dataset is selected at random. Next, a sequence is selected
    from that dataset. A 'train frame' is then sampled randomly from the sequence. Next, depending on the
    frame_sample_mode, the required number of test frames are sampled randomly, either from the range
    [train_frame_id - max_gap, train_frame_id + max_gap] in the 'default' mode, or from [train_frame_id, train_frame_id + max_gap]
    in the 'causal' mode. Only the frames in which the target is visible are sampled, and if enough visible frames are
    not found, the 'max_gap' is incremented.

    The sampled frames are then passed through the input 'processing' function for the necessary processing.
    """

    def __init__(self,
                 datasets,
                 p_datasets,
                 samples_per_epoch,
                 max_gap,
                 num_test_frames=1,
                 processing=no_processing,
                 frame_sample_mode='default',
                 neg=0):
        """
        args:
            datasets - List of datasets to be used for training
            p_datasets - List containing the probabilities by which each dataset will be sampled
            samples_per_epoch - Number of training samples per epoch
            max_gap - Maximum gap, in frame numbers, between the train (reference) frame and the test frames.
            num_test_frames - Number of test frames used for calculating the rpn/mask prediction loss.
            processing - An instance of Processing class which performs the necessary processing of the data.
            frame_sample_mode - Either 'default' or 'causal'. If 'causal', then the test frames are sampled in a causal
                                manner.
            neg - Probability of sampling a negative sample pair.
        """
        self.datasets = datasets

        # If p not provided, sample uniformly from all videos
        if p_datasets is None:
            p_datasets = [1 for d in self.datasets]

        # Normalize
        p_total = sum(p_datasets)
        self.p_datasets = [x / p_total for x in p_datasets]

        self.samples_per_epoch = samples_per_epoch
        self.max_gap = max_gap
        self.num_test_frames = num_test_frames
        self.num_train_frames = 1  # Only a single train frame allowed
        self.processing = processing
        self.frame_sample_mode = frame_sample_mode
        self.neg = neg

    def __len__(self):
        return self.samples_per_epoch

    def _sample_visible_ids(self, visible, num_ids=1, min_id=None, max_id=None):
        """ Samples num_ids frames between min_id and max_id for which target is visible

        args:
            visible - 1d Tensor indicating whether target is visible for each frame
            num_ids - number of frames to be sampled
            min_id - Minimum allowed frame number
            max_id - Maximum allowed frame number

        returns:
            list - List of sampled frame numbers. None if not sufficient visible frames could be found.
        """
        if min_id is None or min_id < 0:
            min_id = 0
        if max_id is None or max_id > len(visible):
            max_id = len(visible)

        valid_ids = [i for i in range(min_id, max_id) if visible[i]]

        # No visible ids
        if len(valid_ids) == 0:
            return None

        inds = self.rng.choice(range(len(valid_ids)), size=num_ids, replace=True)
        ids = [valid_ids[ii] for ii in inds]
        # return random.choices(valid_ids, k=num_ids)
        return ids

    def has_mask(self, dataset):
        return dataset.get_name() in ['coco', 'youtubevos']

    def _get_positive_pair(self, dataset):
        is_video_dataset = dataset.is_video_sequence()

        min_visible_frames = 2 * (self.num_test_frames + self.num_train_frames)
        enough_visible_frames = False

        # Sample a sequence with enough visible frames and get anno for the same
        while not enough_visible_frames:
            seq_id = self.rng.randint(0, dataset.get_num_sequences() - 1)
            anno, visible = dataset.get_sequence_info(seq_id)
            num_visible = np.sum(visible.astype('int64'))
            enough_visible_frames = not is_video_dataset or (
                num_visible > min_visible_frames and len(visible) >= 20)

        if is_video_dataset:
            train_frame_ids = None
            test_frame_ids = None
            gap_increase = 0
            if self.frame_sample_mode == 'default':
                # Sample frame numbers
                while test_frame_ids is None:
                    train_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_train_frames)
                    test_frame_ids = self._sample_visible_ids(
                        visible,
                        min_id=train_frame_ids[0] - self.max_gap - gap_increase,
                        max_id=train_frame_ids[0] + self.max_gap + gap_increase,
                        num_ids=self.num_test_frames)
                    gap_increase += 5  # Increase gap until a frame is found
            elif self.frame_sample_mode == 'causal':
                # Sample frame numbers in a causal manner, i.e. test_frame_ids > train_frame_ids
                while test_frame_ids is None:
                    base_frame_id = self._sample_visible_ids(
                        visible,
                        num_ids=1,
                        min_id=self.num_train_frames - 1,
                        max_id=len(visible) - self.num_test_frames)
                    prev_frame_ids = self._sample_visible_ids(
                        visible, num_ids=self.num_train_frames - 1,
                        min_id=base_frame_id[0] - self.max_gap - gap_increase,
                        max_id=base_frame_id[0])
                    if prev_frame_ids is None:
                        gap_increase += 5
                        continue
                    train_frame_ids = base_frame_id + prev_frame_ids
                    test_frame_ids = self._sample_visible_ids(
                        visible, min_id=train_frame_ids[0] + 1,
                        max_id=train_frame_ids[0] + self.max_gap + gap_increase,
                        num_ids=self.num_test_frames)
                    gap_increase += 5  # Increase gap until a frame is found
            else:
                raise ValueError('Unknown frame_sample_mode.')
        else:
            train_frame_ids = [1] * self.num_train_frames
            test_frame_ids = [1] * self.num_test_frames

        return seq_id, train_frame_ids, test_frame_ids, anno

    def _get_random_pair(self, train_dataset, test_dataset):
        is_video_dataset = train_dataset.is_video_sequence()

        min_visible_frames = self.num_train_frames
        enough_visible_frames = False

        # Sample a sequence with enough visible frames and get anno for the same
        while not enough_visible_frames:
            train_seq_id = self.rng.randint(0, train_dataset.get_num_sequences() - 1)
            train_anno, visible = train_dataset.get_sequence_info(train_seq_id)
            num_visible = np.sum(visible.astype('int64'))
            enough_visible_frames = not is_video_dataset or (
                num_visible > min_visible_frames and len(visible) >= 20)

        if is_video_dataset:
            # Sample frame numbers
            train_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_train_frames)
        else:
            train_frame_ids = [1] * self.num_train_frames

        is_video_dataset = test_dataset.is_video_sequence()

        min_visible_frames = self.num_test_frames
        enough_visible_frames = False

        # Sample a sequence with enough visible frames and get anno for the same
        while not enough_visible_frames:
            test_seq_id = self.rng.randint(0, test_dataset.get_num_sequences() - 1)
            test_anno, visible = test_dataset.get_sequence_info(test_seq_id)
            num_visible = np.sum(visible.astype('int64'))
            enough_visible_frames = not is_video_dataset or (
                num_visible > min_visible_frames and len(visible) >= 20)

        if is_video_dataset:
            # Sample frame numbers
            test_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_test_frames)
        else:
            test_frame_ids = [1] * self.num_test_frames

        return train_seq_id, test_seq_id, train_frame_ids, test_frame_ids, train_anno, test_anno

    def __iter__(self):
        """
        args:
            index (int): Index (Ignored since we sample randomly)

        returns:
            TensorDict - dict containing all the data blocks
        """
        neg = self.neg and self.neg > random.random()

        # Select a dataset
        if neg:
            dataset_idx = self.rng.choice(
                range(len(self.datasets)),
                p=self.p_datasets,
                replace=False)
            train_dataset = self.datasets[dataset_idx]

            dataset_idx = self.rng.choice(
                range(len(self.datasets)),
                p=self.p_datasets,
                replace=False)
            test_dataset = self.datasets[dataset_idx]

            train_seq_id, test_seq_id, train_frame_ids, test_frame_ids, train_anno, test_anno = \
                self._get_random_pair(train_dataset, test_dataset)

            # Get frames
            train_frames, train_anno, _ = train_dataset.get_frames(
                train_seq_id,
                train_frame_ids,
                train_anno)
            train_masks = [np.zeros([frame.shape[0], frame.shape[1], 1], dtype=np.float32)
                           for frame in train_frames]
            test_frames, test_anno, _ = test_dataset.get_frames(
                test_seq_id,
                test_frame_ids,
                test_anno)
            test_masks = [np.zeros([frame.shape[0], frame.shape[1], 1], dtype=np.float32)
                          for frame in test_frames]
        else:
            dataset_idx = self.rng.choice(
                range(len(self.datasets)),
                p=self.p_datasets,
                replace=False)
            dataset = self.datasets[dataset_idx]

            seq_id, train_frame_ids, test_frame_ids, anno = self._get_positive_pair(dataset)

            # Get frames
            if self.has_mask(dataset):
                train_frames, train_anno, train_masks, _ = dataset.get_frames_mask(
                    seq_id, train_frame_ids, anno)
                test_frames, test_anno, test_masks, _ = dataset.get_frames_mask(
                    seq_id, test_frame_ids, anno)
            else:
                train_frames, train_anno, _ = dataset.get_frames(
                    seq_id, train_frame_ids, anno)
                train_masks = [np.zeros([frame.shape[0], frame.shape[1], 1], dtype=np.float32)
                               for frame in train_frames]
                test_frames, test_anno, _ = dataset.get_frames(seq_id, test_frame_ids, anno)
                test_masks = [np.zeros([frame.shape[0], frame.shape[1], 1], dtype=np.float32)
                              for frame in test_frames]

        # Prepare data
        data = TensorDict({
            'train_images': train_frames,
            'train_anno': train_anno,
            'train_masks': train_masks,
            'test_images': test_frames,
            'test_anno': test_anno,
            'test_masks': test_masks,
            'neg': neg
        })

        # Send for processing
        yield self.processing(data, rng=self.rng)
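The visible-frame sampling used by `MaskSampler` can be tried outside the class. This is a sketch: `sample_visible_ids` is a free-function copy of `_sample_visible_ids` that takes the random generator as an argument, and the `visible` array is made up for illustration.

```python
import numpy as np


def sample_visible_ids(rng, visible, num_ids=1, min_id=None, max_id=None):
    """Sample num_ids frame indices in [min_id, max_id) where the target is visible."""
    if min_id is None or min_id < 0:
        min_id = 0
    if max_id is None or max_id > len(visible):
        max_id = len(visible)
    valid_ids = [i for i in range(min_id, max_id) if visible[i]]
    if len(valid_ids) == 0:
        return None
    # Sampling with replacement, so the same frame may be drawn more than once
    inds = rng.choice(range(len(valid_ids)), size=num_ids, replace=True)
    return [valid_ids[ii] for ii in inds]


rng = np.random.RandomState(0)
visible = np.array([1, 0, 0, 1, 1, 0, 1], dtype=bool)
print(sample_visible_ids(rng, visible, num_ids=3))           # drawn from frames 0, 3, 4, 6
print(sample_visible_ids(rng, visible, min_id=1, max_id=3))  # None: frames 1 and 2 are hidden
```

Returning `None` when the window contains no visible frame is what lets the caller widen the gap and retry.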
......@@ -80,6 +80,19 @@ class Normalize(object):
return (tensor - self.mean) / self.std
class Transpose(Transform):
    """ Transpose image."""

    def __call__(self, img):
        if len(img.shape) == 3:
            img = img.transpose((2, 0, 1))
        elif len(img.shape) == 2:
            img = np.expand_dims(img, axis=0)
        else:
            raise NotImplementedError
        return img.astype('float32')
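The effect of `Transpose` on array shapes is shown by the sketch below. `to_chw` is a hypothetical free-function copy of the `__call__` body, and the 127x127 input size is only an example.

```python
import numpy as np


def to_chw(img):
    """HWC image -> CHW float32; a 2-D mask gains a leading channel axis."""
    if img.ndim == 3:
        img = img.transpose((2, 0, 1))
    elif img.ndim == 2:
        img = np.expand_dims(img, axis=0)
    else:
        raise NotImplementedError
    return img.astype('float32')


print(to_chw(np.zeros((127, 127, 3), dtype=np.uint8)).shape)  # → (3, 127, 127)
print(to_chw(np.zeros((127, 127), dtype=np.uint8)).shape)     # → (1, 127, 127)
```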
class ToArray(Transform):
""" Transpose image and jitter brightness"""
......@@ -146,3 +159,53 @@ class RandomHorizontalFlip(Transform):
return layers.reverse(img, 2)
return np.fliplr(img).copy()
return img
class Blur(Transform):
    """ Blur the image by applying a random kernel."""

    def __init__(self, probability=0.5):
        self.probability = probability

    def roll(self):
        return random.random() < self.probability

    def transform(self, img, do_blur):
        def rand_kernel():
            sizes = np.arange(5, 46, 2)
            size = np.random.choice(sizes)
            kernel = np.zeros((size, size))
            c = int(size / 2)
            wx = np.random.random()
            kernel[:, c] += 1. / size * wx
            kernel[c, :] += 1. / size * (1 - wx)
            return kernel

        if do_blur:
            kernel = rand_kernel()
            img = cv.filter2D(img, -1, kernel)
        return img


class Color(Transform):
    """ Jitter the image color by subtracting a random per-channel offset."""

    def __init__(self, probability=1):
        self.probability = probability
        self.rgbVar = np.array(
            [
                [-0.55919361, 0.98062831, -0.41940627],
                [1.72091413, 0.19879334, -1.82968581],
                [4.64467907, 4.73710203, 4.88324118]
            ],
            dtype=np.float32)

    def roll(self):
        return random.random() < self.probability

    def transform(self, img, do_color_aug):
        if do_color_aug:
            offset = np.dot(self.rgbVar, np.random.randn(3, 1))
            offset = offset.reshape(3)
            img = img - offset
        return img
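The random kernel built inside `Blur.transform` is a cross-shaped kernel: one column and one row through the centre, weighted by `wx` and `1 - wx`. The sketch below rebuilds it with an explicit random generator so the result is reproducible; passing `rng` as a parameter is a change made for illustration only.

```python
import numpy as np


def rand_kernel(rng):
    """Build the cross-shaped blur kernel used by Blur, with an explicit RandomState."""
    sizes = np.arange(5, 46, 2)       # odd kernel sizes from 5 to 45
    size = rng.choice(sizes)
    kernel = np.zeros((size, size))
    c = int(size / 2)                 # centre row/column index
    wx = rng.random_sample()          # share of the weight given to the vertical line
    kernel[:, c] += 1. / size * wx
    kernel[c, :] += 1. / size * (1 - wx)
    return kernel


kernel = rand_kernel(np.random.RandomState(0))
print(kernel.shape, kernel.sum())  # square, odd-sized; weights sum to 1 up to rounding
```

Since the column contributes `wx` in total and the row contributes `1 - wx`, the kernel weights sum to one, so filtering with it roughly preserves overall image brightness.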
......@@ -2,7 +2,8 @@ from .lasot import Lasot
from .got10k import Got10k
from .tracking_net import TrackingNet
from .imagenetvid import ImagenetVID
from .imagenetdet import ImagenetDET
from .coco_seq import MSCOCOSeq
from .vot import VOT
from .youtube_vos import VOS
from .youtube_vos import YoutubeVOS
from .youtube_bb import YoutubeBB
......@@ -85,12 +85,19 @@ class MSCOCOSeq(BaseDataset):
anno = self.coco_set.anns[self.sequence_list[seq_id]]['bbox']
return np.reshape(np.array(anno), (1, 4))
def _get_frames(self, seq_id):
def _get_frames(self, seq_id, mask=False):
path = self.coco_set.loadImgs(
[self.coco_set.anns[self.sequence_list[seq_id]]['image_id']])[0][
'file_name']
img = self.image_loader(os.path.join(self.img_pth, path))
return img
if mask:
ann = self.coco_set.anns[self.sequence_list[seq_id]]
im_mask = (self.coco_set.annToMask(ann).astype(np.float32) > 0.5).astype(np.float32)
im_mask = np.expand_dims(im_mask, axis=2)
return img, im_mask
else:
return img
def get_meta_info(self, seq_id):
try:
......@@ -128,3 +135,21 @@ class MSCOCOSeq(BaseDataset):
object_meta = self.get_meta_info(seq_id)
return frame_list, anno_frames, object_meta
def get_frames_mask(self, seq_id=None, frame_ids=None, anno=None):
# COCO is an image dataset. Thus we replicate the image denoted by seq_id len(frame_ids) times, and return a
# list containing these replicated images.
frame, mask = self._get_frames(seq_id, mask=True)
frame_list = [frame.copy() for _ in frame_ids]
mask_list = [mask.copy() for _ in frame_ids]
if anno is None:
anno = self._get_anno(seq_id)
anno_frames = [anno.copy()[0, :] for _ in frame_ids]
object_meta = self.get_meta_info(seq_id)
return frame_list, anno_frames, mask_list, object_meta
import os
import numpy as np
from .base_dataset import BaseDataset
from ltr.data.image_loader import default_image_loader
import xml.etree.ElementTree as ET
import glob
import json
from collections import OrderedDict
import nltk
from nltk.corpus import wordnet
from ltr.admin.environment import env_settings
class ImagenetDET(BaseDataset):
""" Imagenet DET dataset.
Publication:
ImageNet Large Scale Visual Recognition Challenge
Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy,
Aditya Khosla, Michael Bernstein, Alexander C. Berg and Li Fei-Fei
IJCV, 2015
https://arxiv.org/pdf/1409.0575.pdf
Download the dataset from http://image-net.org/
"""
def __init__(self, root=None, filter=None, image_loader=default_image_loader):
"""
args:
root - path to the imagenet det dataset.
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
is used by default.
"""
root = env_settings().imagenetdet_dir if root is None else root
super().__init__(root, image_loader)
self.filter = filter
self.set_list = ['ILSVRC2013_train', 'ILSVRC2014_train_0000',
'ILSVRC2014_train_0001', 'ILSVRC2014_train_0002',
'ILSVRC2014_train_0003', 'ILSVRC2014_train_0004',
'ILSVRC2014_train_0005', 'ILSVRC2014_train_0006']
cache_file = os.path.join(root, 'cache.json')
if os.path.isfile(cache_file):
# If available, load the pre-processed cache file containing meta-info for each sequence
with open(cache_file, 'r') as f:
sequence_list_dict = json.load(f)
self.sequence_list = sequence_list_dict
else:
# Else process the imagenet annotations and generate the cache file
self.sequence_list = self._process_anno(root)
with open(cache_file, 'w') as f:
json.dump(self.sequence_list, f)
def is_video_sequence(self):
return False
def get_name(self):
return 'imagenetdet'
def get_num_sequences(self):
return len(self.sequence_list)
def get_sequence_info(self, seq_id):
anno = self._get_anno(seq_id)
target_visible = (anno[:, 2] > 0) & (anno[:, 3] > 0)
if self.filter:
target_large = (anno[:, 2] * anno[:, 3] > 30 * 30)
ratio = anno[:, 2] / anno[:, 3]
target_reasonable_ratio = (10 > ratio) & (ratio > 0.1)
target_visible = target_visible & target_reasonable_ratio & target_large
return anno, target_visible
def _get_anno(self, seq_id):
anno = self.sequence_list[seq_id]['anno']
return np.reshape(np.array(anno), (1, 4))
def _get_frames(self, seq_id):
set_name = self.set_list[self.sequence_list[seq_id]['set_id']]
folder = self.sequence_list[seq_id]['folder']
if folder == set_name:
folder = ''
filename = self.sequence_list[seq_id]['filename']
frame_path = os.path.join(self.root, 'Data', 'DET', 'train', set_name, folder,
'{:s}.JPEG'.format(filename))
return self.image_loader(frame_path)
def get_frames(self, seq_id, frame_ids, anno=None):
# ImageNet DET is an image dataset. Thus we replicate the image denoted by seq_id len(frame_ids) times, and return a
# list containing these replicated images.
frame = self._get_frames(seq_id)
frame_list = [frame.copy() for _ in frame_ids]
if anno is None:
anno = self._get_anno(seq_id)
anno_frames = [anno.copy()[0, :] for _ in frame_ids]
object_meta = OrderedDict({'object_class': self.sequence_list[seq_id]['class_name'],
'motion_class': None,
'major_class': None,
'root_class': None,
'motion_adverb': None})
return frame_list, anno_frames, object_meta
def _process_anno(self, root):
# Builds individual tracklets
base_det_anno_path = os.path.join(root, 'Annotations', 'DET', 'train')
all_sequences = []
for set_id, set in enumerate(self.set_list):
if set_id == 0:
xmls = sorted(glob.glob(os.path.join(base_det_anno_path, set, '*', '*.xml')))
else:
xmls = sorted(glob.glob(os.path.join(base_det_anno_path, set, '*.xml')))
for xml in xmls:
xmltree = ET.parse(xml)
folder = xmltree.find('folder').text
filename = xmltree.find('filename').text
image_size = [int(xmltree.find('size/width').text), int(xmltree.find('size/height').text)]
objects = xmltree.findall('object')
# Find all objects
for id, object_iter in enumerate(objects):
bndbox = object_iter.find('bndbox')
x1 = int(bndbox.find('xmin').text)
y1 = int(bndbox.find('ymin').text)
x2 = int(bndbox.find('xmax').text)
y2 = int(bndbox.find('ymax').text)
object_anno = [x1, y1, x2 - x1, y2 - y1]
class_name = None
if x2 <= x1 or y2 <= y1:
continue
new_sequence = {'set_id': set_id, 'folder': folder, 'filename': filename,
'class_name': class_name, 'anno': object_anno, 'image_size': image_size}
all_sequences.append(new_sequence)
return all_sequences
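The box handling in `ImagenetDET` can be sketched without NumPy. This is an illustration only, assuming the same thresholds as `get_sequence_info` (area above 30 x 30, aspect ratio strictly between 0.1 and 10); the function names `corners_to_xywh` and `passes_filter` are hypothetical.

```python
def corners_to_xywh(x1, y1, x2, y2):
    """Convert corner coordinates to [x, y, w, h] as _process_anno does (w = x2 - x1)."""
    return [x1, y1, x2 - x1, y2 - y1]


def passes_filter(box):
    """Mirror the checks applied in get_sequence_info when self.filter is set."""
    _, _, w, h = box
    if w <= 0 or h <= 0:
        return False             # target not visible
    large = w * h > 30 * 30      # area threshold
    ratio = w / h
    reasonable = 0.1 < ratio < 10
    return large and reasonable


boxes = [
    corners_to_xywh(10, 20, 110, 80),  # 100 x 60
    corners_to_xywh(0, 0, 20, 20),     # 20 x 20, below the area threshold
    corners_to_xywh(0, 0, 400, 30),    # 400 x 30, aspect ratio above 10
]
kept = [passes_filter(b) for b in boxes]
print(kept)
```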
......@@ -2,151 +2,244 @@ import os
from .base_dataset import BaseDataset
from ltr.data.image_loader import default_image_loader
import numpy as np
import cv2 as cv
import json
import cv2
from collections import OrderedDict
from ltr.admin.environment import env_settings
def get_axis_aligned_bbox(region):
region = np.array(region)
if len(region.shape) == 3:
# region (1,4,2)
region = np.array([
region[0][0][0], region[0][0][1], region[0][1][0], region[0][1][1],
region[0][2][0], region[0][2][1], region[0][3][0], region[0][3][1]
])
cx = np.mean(region[0::2])
cy = np.mean(region[1::2])
x1 = min(region[0::2])
x2 = max(region[0::2])
y1 = min(region[1::2])
y2 = max(region[1::2])
A1 = np.linalg.norm(region[0:2] - region[2:4]) * np.linalg.norm(region[
2:4] - region[4:6])
A2 = (x2 - x1) * (y2 - y1)
s = np.sqrt(A1 / A2)
if np.isnan(s):
x11, y11, w, h = 0, 0, 0, 0
else:
w = s * (x2 - x1) + 1
h = s * (y2 - y1) + 1
def get_target_to_image_ratio(seq):
anno = np.array(seq['anno'])
img_sz = np.array(seq['image_size'])
return np.sqrt(anno[0, 2:4].prod() / (img_sz.prod()))
class Instance(object):
instID = 0
pixelCount = 0
def __init__(self, imgNp, instID):
if instID == 0:
return
self.instID = int(instID)
self.pixelCount = int(self.getInstancePixels(imgNp, instID))
def getInstancePixels(self, imgNp, instLabel):
return (imgNp == instLabel).sum()
x11 = cx - w // 2
y11 = cy - h // 2
return x11, y11, w, h
def toDict(self):
buildDict = {}
buildDict["instID"] = self.instID
buildDict["pixelCount"] = self.pixelCount
return buildDict
def __str__(self):
return "("+str(self.instID)+")"
class VOS(BaseDataset):
def __init__(self, root=None, image_loader=default_image_loader):
# root = env_settings().vot_dir if root is None else root
assert root is not None
def xyxy_to_xywh(xyxy):
"""Convert [x1 y1 x2 y2] box format to [x1 y1 w h] format."""
if isinstance(xyxy, (list, tuple)):
# Single box given as a list of coordinates
assert len(xyxy) == 4
x1, y1 = xyxy[0], xyxy[1]
w = xyxy[2] - x1 + 1
h = xyxy[3] - y1 + 1
return (x1, y1, w, h)
elif isinstance(xyxy, np.ndarray):
# Multiple boxes given as a 2D ndarray
return np.hstack((xyxy[:, 0:2], xyxy[:, 2:4] - xyxy[:, 0:2] + 1))
else:
raise TypeError('Argument xyxy must be a list, tuple, or numpy array.')
def polys_to_boxes(polys):
"""Convert a list of polygons into an array of tight bounding boxes."""
boxes_from_polys = np.zeros((len(polys), 4), dtype=np.float32)
for i in range(len(polys)):
poly = polys[i]
x0 = min(min(p[::2]) for p in poly)
x1 = max(max(p[::2]) for p in poly)
y0 = min(min(p[1::2]) for p in poly)
y1 = max(max(p[1::2]) for p in poly)
boxes_from_polys[i, :] = [x0, y0, x1, y1]
return boxes_from_polys
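The two helpers above are combined in `_process_anno` to turn mask contours into an `[x, y, w, h]` annotation. A pure-Python sketch of that path for a single object follows; the names `poly_to_box` and `box_to_xywh` are hypothetical, and the contour values are made up for illustration.

```python
def poly_to_box(poly):
    """Tight [x0, y0, x1, y1] box around a list of flat contours.

    Each contour is a flat list [x0, y0, x1, y1, ...], the layout obtained
    by flattening an OpenCV contour with reshape(-1).
    """
    x0 = min(min(p[0::2]) for p in poly)
    x1 = max(max(p[0::2]) for p in poly)
    y0 = min(min(p[1::2]) for p in poly)
    y1 = max(max(p[1::2]) for p in poly)
    return [x0, y0, x1, y1]


def box_to_xywh(box):
    """Inclusive-pixel convention: a box from x=2 to x=5 is 4 pixels wide."""
    x0, y0, x1, y1 = box
    return [x0, y0, x1 - x0 + 1, y1 - y0 + 1]


contours = [
    [2, 3, 5, 3, 5, 7, 2, 7],  # first contour of the object
    [6, 1, 9, 1, 9, 4],        # second, disconnected contour
]
box = poly_to_box(contours)
xywh = box_to_xywh(box)
print(box, xywh)
```

Note the `+ 1` in the width and height: unlike the ImageNet DET loader, which uses `x2 - x1`, this conversion treats both corners as inclusive pixel indices.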
class YoutubeVOS(BaseDataset):
""" Youtube-VOS dataset.
Publication:
https://arxiv.org/pdf/
Download the dataset from https://youtube-vos.org/dataset/download
"""
def __init__(self, root=None, filter=None, image_loader=default_image_loader, min_length=1, max_target_area=1):
"""
args:
root - path to the youtube-vos dataset.
image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
is used by default.
min_length - Minimum allowed sequence length.
max_target_area - max allowed ratio between target area and image area. Can be used to filter out targets
which cover complete image.
"""
root = env_settings().youtubevos_dir if root is None else root
super().__init__(root, image_loader)
with open(os.path.join(self.root, 'meta.json')) as f:
self.meta = json.load(f)['videos']
self.sequence_list = self._get_sequence_list()
self.ann = self._get_annotations()
def _get_sequence_list(self):
seq_list = []
videos = self.meta.keys()
for v in videos:
objs = self.meta[v]['objects'].keys()
for o in objs:
if "rotate_box" in self.meta[v]['objects'][o]:
seq_list.append((v, o))
assert len(seq_list) > 0
return seq_list
def _get_annotations(self):
ann = {}
for seq in self.sequence_list:
ann[seq] = {'bbox': [], 'rbb': []}
polygons = self.meta[seq[0]]['objects'][seq[1]]['rotate_box']
for vs in polygons:
if len(vs) == 4:
polys = [
vs[0], vs[1] + vs[3] - 1, vs[0], vs[1],
vs[0] + vs[2] - 1, vs[1], vs[0] + vs[2] - 1,
vs[1] + vs[3] - 1
]
else:
polys = vs
if not np.all(polys == 0):
box = get_axis_aligned_bbox(polys)
rbb = cv.minAreaRect(
np.int0(np.array(polys).reshape((-1, 2))))
else:
box = np.array([0, 0, 0, 0])
rbb = ((0, 0), (0, 0), 0)
if box[2] * box[3] > 500 * 500:
print(box)
# assume small rotation angle, switch height, width
if rbb[2] < -45:
angle = rbb[2] + 90
height = rbb[1][0]
width = rbb[1][1]
else:
angle = rbb[2]
height = rbb[1][1]
width = rbb[1][0]
rbb = [rbb[0][0], rbb[0][1], width, height, angle]
ann[seq]['bbox'].append(box)
ann[seq]['rbb'].append(rbb)
return ann
def is_video_sequence(self):
return True
cache_file = os.path.join(root, 'cache.json')
if os.path.isfile(cache_file):
# If available, load the pre-processed cache file containing meta-info for each sequence
with open(cache_file, 'r') as f:
sequence_list_dict = json.load(f)
self.sequence_list = sequence_list_dict
else:
# Else process the youtube-vos annotations and generate the cache file
print('processing the youtube-vos annotations...')
self.sequence_list = self._process_anno(root)
with open(cache_file, 'w') as f:
json.dump(self.sequence_list, f)
print('cache file generated!')
# Filter the sequences based on min_length and max_target_area in the first frame
self.sequence_list = [x for x in self.sequence_list if len(x['anno']) >= min_length and
get_target_to_image_ratio(x) < max_target_area]
self.filter = filter
def get_name(self):
return 'vot'
return 'youtubevos'
def get_num_sequences(self):
return len(self.sequence_list)
def get_sequence_info(self, seq_id):
anno = self._get_anno(seq_id)
anno = np.array(self.sequence_list[seq_id]['anno'])
target_visible = (anno[:, 2] > 0) & (anno[:, 3] > 0)
target_large = (anno[:, 2] * anno[:, 3] > 30 * 30)
target_reasonable = (anno[:, 2] * anno[:, 3] < 500 * 500)
return anno, target_visible & target_large & target_reasonable
def _get_anno(self, seq_id):
anno = self.ann[self.sequence_list[seq_id]]['bbox']
return np.reshape(np.array(anno), (-1, 4))
def get_meta_info(self, seq_id):
object_meta = OrderedDict({
'object_class': None,
'motion_class': None,
'major_class': None,
'root_class': None,
'motion_adverb': None
})
return object_meta
def _get_frame_path(self, seq_id, frame_id):
v, o = self.sequence_list[seq_id]
frame_name = self.meta[v]['objects'][o]['frames'][frame_id]
return os.path.join(self.root, 'JPEGImages', v,
'{}.jpg'.format(frame_name)) # frames start from 1
def _get_frame(self, seq_id, frame_id):
return self.image_loader(self._get_frame_path(seq_id, frame_id))
def get_frames(self, seq_id=None, frame_ids=None, anno=None):
frame_list = [self._get_frame(seq_id, f_id) for f_id in frame_ids]
if self.filter is not None:
target_large = (anno[:, 2] * anno[:, 3] > 30 * 30)
ratio = anno[:, 2] / anno[:, 3]
target_reasonable_ratio = (10 > ratio) & (ratio > 0.1)
target_visible = target_visible & target_reasonable_ratio & target_large
return anno, target_visible
def _get_frame(self, sequence, frame_id):
vid_name = sequence['video']
frame_number = sequence['frames'][frame_id]
frame_path = os.path.join(self.root, 'train', 'JPEGImages', vid_name,
'{:05d}.jpg'.format(frame_number))
return self.image_loader(frame_path)
def _get_mask(self, sequence, frame_id):
vid_name = sequence['video']
frame_number = sequence['frames'][frame_id]
id = sequence['id']
mask_path = os.path.join(self.root, 'train', 'Annotations', vid_name,
'{:05d}.png'.format(frame_number))
mask = cv2.imread(mask_path, 0)
mask = (mask == id).astype(np.float32)
mask = np.expand_dims(mask, axis=2)
return mask
def get_frames(self, seq_id, frame_ids, anno=None):
sequence = self.sequence_list[seq_id]
frame_list = [self._get_frame(sequence, f) for f in frame_ids]
if anno is None:
anno = self._get_anno(seq_id)
anno = sequence['anno']
# Return as list of tensors
anno_frames = [anno[f_id, :] for f_id in frame_ids]
object_meta = self.get_meta_info(seq_id)
# added the class info to the meta info
object_meta = OrderedDict({'object_class': sequence['class_name'],
'motion_class': None,
'major_class': None,
'root_class': None,
'motion_adverb': None})
return frame_list, anno_frames, object_meta
def get_frames_mask(self, seq_id, frame_ids, anno=None):
sequence = self.sequence_list[seq_id]
frame_list = [self._get_frame(sequence, f) for f in frame_ids]
mask_list = [self._get_mask(sequence, f) for f in frame_ids]
if anno is None:
anno = sequence['anno']
# Return as list of tensors
anno_frames = [anno[f_id, :] for f_id in frame_ids]
# added the class info to the meta info
object_meta = OrderedDict({'object_class': sequence['class_name'],
'motion_class': None,
'major_class': None,
'root_class': None,
'motion_adverb': None})
return frame_list, anno_frames, mask_list, object_meta
def _process_anno(self, root):
# Builds individual tracklets
base_anno_path = os.path.join(root, 'train', 'Annotations')
num_obj = 0
num_ann = 0
all_sequences = []
with open(os.path.join(base_anno_path, '../meta.json')) as f:
meta = json.load(f)
for vid_id, video in enumerate(meta['videos']):
v = meta['videos'][video]
frames = []
objects = dict()
for obj in v['objects']:
o = v['objects'][obj]
frames.extend(o['frames'])
frames = sorted(set(frames))
for frame in frames:
file_name = os.path.join(video, frame)
img = cv2.imread(os.path.join(base_anno_path, file_name+'.png'), 0)
h, w = img.shape[:2]
image_size = [w, h]
for instanceId in np.unique(img):
if instanceId == 0:
continue
instanceObj = Instance(img, instanceId)
instanceObj_dict = instanceObj.toDict()
mask = (img == instanceId).astype(np.uint8)
if cv2.__version__[0] == '3':
_, contour, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
else:
contour, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
polygons = [c.reshape(-1).tolist() for c in contour]
instanceObj_dict['contours'] = [p for p in polygons if len(p) > 4]
if len(instanceObj_dict['contours']) and instanceObj_dict['pixelCount'] > 1000:
len_p = [len(p) for p in instanceObj_dict['contours']]
if min(len_p) <= 4:
print('Warning: invalid contours.')
continue # skip non-instance categories
bbox = xyxy_to_xywh(
polys_to_boxes([instanceObj_dict['contours']])).tolist()[0]
if instanceId not in objects:
objects[instanceId] = \
{'anno': [], 'frames': [], 'image_size': image_size}
objects[instanceId]['anno'].append(bbox)
objects[instanceId]['frames'].append(int(frame))
for obj in objects:
new_sequence = {'video': video, 'id': int(obj), 'class_name': None,
'frames': objects[obj]['frames'], 'anno': objects[obj]['anno'],
'image_size': image_size}
all_sequences.append(new_sequence)
print('Youtube-VOS: ', len(all_sequences))
return all_sequences
import os
from paddle import fluid
from paddle.fluid.dygraph import nn
from ltr.admin.environment import env_settings
CURRENT_DIR = os.path.dirname(__file__)
class alexnet(fluid.dygraph.Layer):
def __init__(self, name, is_test, output_layers):
super(alexnet, self).__init__()
self.is_test = is_test
self.layer_init()
self.output_layers = output_layers
def layer_init(self):
# for conv1
self.conv1 = nn.Conv2D(
num_channels=3,
num_filters=96,
filter_size=11,
stride=2,
padding=0,
groups=1,
param_attr=self.weight_init(),
bias_attr=self.bias_init())
self.bn1 = nn.BatchNorm(
num_channels=96,
is_test=self.is_test,
param_attr=self.norm_weight_init(),
bias_attr=self.bias_init(),
use_global_stats=self.is_test)
self.pool1 = nn.Pool2D(
pool_size=3, pool_type="max", pool_stride=2, pool_padding=0)
# for conv2
self.conv2 = nn.Conv2D(
num_channels=96,
num_filters=256,
filter_size=5,
stride=1,
padding=0,
groups=1,
param_attr=self.weight_init(),
bias_attr=self.bias_init())
self.bn2 = nn.BatchNorm(
num_channels=256,
is_test=self.is_test,
param_attr=self.norm_weight_init(),
bias_attr=self.bias_init(),
use_global_stats=self.is_test)
self.pool2 = nn.Pool2D(
pool_size=3, pool_type="max", pool_stride=2, pool_padding=0)
# for conv3
self.conv3 = nn.Conv2D(
num_channels=256,
num_filters=384,
filter_size=3,
stride=1,
padding=0,
groups=1,
param_attr=self.weight_init(),
bias_attr=self.bias_init())
self.bn3 = nn.BatchNorm(
num_channels=384,
is_test=self.is_test,
param_attr=self.norm_weight_init(),
bias_attr=self.bias_init(),
use_global_stats=self.is_test)
# for conv4
self.conv4 = nn.Conv2D(
num_channels=384,
num_filters=384,
filter_size=3,
stride=1,
padding=0,
groups=1,
param_attr=self.weight_init(),
bias_attr=self.bias_init())
self.bn4 = nn.BatchNorm(
num_channels=384,
is_test=self.is_test,
param_attr=self.norm_weight_init(),
bias_attr=self.bias_init(),
use_global_stats=self.is_test)
# for conv5
self.conv5 = nn.Conv2D(
num_channels=384,
num_filters=256,
filter_size=3,
stride=1,
padding=0,
groups=1,
param_attr=self.weight_init(),
bias_attr=self.bias_init())
self.bn5 = nn.BatchNorm(
num_channels=256,
is_test=self.is_test,
param_attr=self.norm_weight_init(),
bias_attr=self.bias_init(),
use_global_stats=self.is_test)
def _add_output_and_check(self, name, x, outputs):
if name in self.output_layers:
outputs.append(x)
return len(self.output_layers) == len(outputs)
@fluid.dygraph.no_grad
def forward(self, inputs):
outputs = []
out1 = self.conv1(inputs)
out1 = self.bn1(out1)
out1 = fluid.layers.relu(out1)
if self._add_output_and_check('conv1', out1, outputs):
outputs[-1].stop_gradient = True if self.is_test else False
return outputs[0] if len(outputs) == 1 else outputs
out1 = self.pool1(out1)
out2 = self.conv2(out1)
out2 = self.bn2(out2)
out2 = fluid.layers.relu(out2)
if self._add_output_and_check('conv2', out2, outputs):
outputs[-1].stop_gradient = True if self.is_test else False
return outputs[0] if len(outputs) == 1 else outputs
out2 = self.pool2(out2)
out3 = self.conv3(out2)
out3 = self.bn3(out3)
out3 = fluid.layers.relu(out3)
if self._add_output_and_check('conv3', out3, outputs):
outputs[-1].stop_gradient = True if self.is_test else False
return outputs[0] if len(outputs) == 1 else outputs
out4 = self.conv4(out3)
out4 = self.bn4(out4)
out4 = fluid.layers.relu(out4)
if self._add_output_and_check('conv4', out4, outputs):
outputs[-1].stop_gradient = True if self.is_test else False
return outputs[0] if len(outputs) == 1 else outputs
out5 = self.conv5(out4)
out5 = self.bn5(out5)
if self._add_output_and_check('conv5', out5, outputs):
outputs[-1].stop_gradient = True if self.is_test else False
return outputs[0] if len(outputs) == 1 else outputs
outputs[-1].stop_gradient = True if self.is_test else False
return outputs[0] if len(outputs) == 1 else outputs
def norm_weight_init(self):
init = fluid.initializer.ConstantInitializer(1.0)
param = fluid.ParamAttr(initializer=init)
return param
def weight_init(self):
init = fluid.initializer.MSRAInitializer(uniform=False)
param = fluid.ParamAttr(initializer=init)
return param
def bias_init(self):
init = fluid.initializer.ConstantInitializer(value=0.)
param = fluid.ParamAttr(initializer=init)
return param
def AlexNet(name, is_test, output_layers, pretrained=False):
net = alexnet(name, is_test=is_test, output_layers=output_layers)
if pretrained:
params_path = os.path.join(env_settings().backbone_dir, 'AlexNet')
print("=> loading backbone model from '{}'".format(params_path))
params, _ = fluid.load_dygraph(params_path)
net.load_dict(params)
print("Done")
return net
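Every convolution and pooling layer in the `alexnet` class uses zero padding, so the feature map shrinks at each step. The sketch below traces the spatial size through the stack, assuming floor rounding and square inputs of 127 and 255 pixels (sizes commonly used for the template and search crops in Siamese trackers; they are assumptions here, not values read from this file).

```python
def out_size(size, kernel, stride, padding=0):
    """Spatial size after a conv or pool layer (floor rounding assumed)."""
    return (size + 2 * padding - kernel) // stride + 1


# (name, kernel, stride) of each spatial layer in the alexnet class, in order.
LAYERS = [
    ("conv1", 11, 2),
    ("pool1", 3, 2),
    ("conv2", 5, 1),
    ("pool2", 3, 2),
    ("conv3", 3, 1),
    ("conv4", 3, 1),
    ("conv5", 3, 1),
]


def feature_size(input_size):
    """Apply every layer in turn and return the final spatial size."""
    size = input_size
    for _name, kernel, stride in LAYERS:
        size = out_size(size, kernel, stride)
    return size


template = feature_size(127)
search = feature_size(255)
print(template, search)
```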
import os
import paddle.fluid as fluid
import paddle.fluid.dygraph.nn as nn
from ltr.admin.environment import env_settings
CURRENT_DIR = os.path.dirname(__file__)
def weight_init():
init = fluid.initializer.MSRAInitializer(uniform=False)
param = fluid.ParamAttr(initializer=init)
return param
def norm_weight_init(constant=1.0):
init = fluid.initializer.ConstantInitializer(constant)
param = fluid.ParamAttr(initializer=init)
return param
def norm_bias_init():
init = fluid.initializer.ConstantInitializer(value=0.)
param = fluid.ParamAttr(initializer=init)
return param
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
in_channels,
out_channels,
filter_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bn_init_constant=1.0):
super(ConvBNLayer, self).__init__()
self.conv = nn.Conv2D(
num_channels=in_channels,
filter_size=filter_size,
num_filters=out_channels,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
param_attr=weight_init(),
bias_attr=False)
self.bn = nn.BatchNorm(
out_channels,
param_attr=norm_weight_init(bn_init_constant),
bias_attr=norm_bias_init(),
act=None,
momentum=0.1,
use_global_stats=True)
def forward(self, inputs):
res = self.conv(inputs)
self.conv_res = res
res = self.bn(res)
return res
class BasicBlock(fluid.dygraph.Layer):
expansion = 1
def __init__(self,
in_channels,
out_channels,
stride=1,
is_downsample=None):
super(BasicBlock, self).__init__()
self.expansion = 1
self.conv_bn1 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
filter_size=3,
stride=stride,
groups=1)
self.conv_bn2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
filter_size=3,
stride=1,
groups=1)
self.is_downsample = is_downsample
if self.is_downsample:
self.downsample = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
filter_size=1,
stride=stride)
self.stride = stride
def forward(self, inputs):
identity = inputs
res = self.conv_bn1(inputs)
res = fluid.layers.relu(res)
res = self.conv_bn2(res)
if self.is_downsample:
identity = self.downsample(identity)
res += identity
res = fluid.layers.relu(res)
return res
class Bottleneck(fluid.dygraph.Layer):
expansion = 4
def __init__(self,
in_channels,
out_channels,
stride=1,
downsample=None,
base_width=64,
dilation=1,
groups=1):
super(Bottleneck, self).__init__()
width = int(out_channels*(base_width / 64.))*groups
self.conv_bn1 = ConvBNLayer(
in_channels=in_channels,
filter_size=1,
out_channels=width,
groups=1)
padding = 2 - stride
if downsample is not None and dilation > 1:
dilation = dilation // 2
padding = dilation
assert stride == 1 or dilation == 1, \
"at least one of stride and dilation must equal 1"
if dilation > 1:
padding = dilation
self.conv_bn2 = ConvBNLayer(
in_channels=width,
filter_size=3,
out_channels=width,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups)
self.conv_bn3 = ConvBNLayer(
in_channels=width,
filter_size=1,
out_channels=out_channels*self.expansion,
bn_init_constant=0.)
self.downsample = downsample
self.stride = stride
def forward(self, inputs):
identity = inputs
out = self.conv_bn1(inputs)
out = fluid.layers.relu(out)
out = self.conv_bn2(out)
out = fluid.layers.relu(out)
out = self.conv_bn3(out)
if self.downsample is not None:
identity = self.downsample(inputs)
out += identity
out = fluid.layers.relu(out)
return out
class ResNet(fluid.dygraph.Layer):
def __init__(self, name, Block, layers, output_layers, is_test=False):
"""
:param name: str, namescope
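:param Block: class, the residual block type (e.g. Bottleneck)
:param layers: int, depth of the network (only 50 is supported)
:param output_layers: list of int, indices of the blocks whose features are returned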
"""
super(ResNet, self).__init__(name_scope=name)
support_layers = [50]
assert layers in support_layers, \
"layers must be one of {}".format(support_layers)
self.layers = layers
self.feat_layers = ['block{}'.format(i) for i in output_layers]
output_depth = max(output_layers) + 1
self.is_test = is_test
if layers == 18:
depths = [2, 2, 2, 2]
elif layers == 50 or layers == 34:
depths = [3, 4, 6, 3]
elif layers == 101:
depths = [3, 4, 23, 3]
elif layers == 152:
depths = [3, 8, 36, 3]
strides = [1, 2, 1, 1]
num_filters = [64, 128, 256, 512]
dilations = [1, 1, 2, 4]
self.in_channels = 64
self.dilation = 1
self.conv_bn_init = ConvBNLayer(
in_channels=3,
out_channels=self.in_channels,
filter_size=7,
stride=2)
self.maxpool = nn.Pool2D(
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type="max")
block_collect = []
downsample = None
for i in range(min(len(depths), output_depth)):
# collect layers in each block
_block = []
stride = strides[i]
out_channel = num_filters[i]
dilation = dilations[i]
if stride != 1 or self.in_channels != out_channel*Block.expansion:
if stride == 1 and dilation == 1:
downsample = ConvBNLayer(
in_channels=self.in_channels,
out_channels=out_channel*Block.expansion,
filter_size=1,
stride=stride)
else:
if dilation > 1:
dd = dilation // 2
padding = dd
else:
dd = 1
padding = 0
downsample = ConvBNLayer(
in_channels=self.in_channels,
out_channels=out_channel*Block.expansion,
filter_size=3,
stride=stride,
padding=padding,
dilation=dd)
bottleneck_block = self.add_sublayer(
"block{}_0".format(i),
Block(
in_channels=self.in_channels,
out_channels=out_channel,
stride=stride,
dilation=dilation,
downsample=downsample))
_block.append(bottleneck_block)
self.in_channels = num_filters[i]*Block.expansion
for j in range(1, depths[i]):
bottleneck_block = self.add_sublayer(
"block{}_{}".format(i, j),
Block(self.in_channels, out_channel, dilation=dilation))
_block.append(bottleneck_block)
# collect blocks
block_collect.append(_block)
self.block_collect = block_collect
@fluid.dygraph.no_grad
def forward(self, inputs):
out = []
res = self.conv_bn_init(inputs)
res = fluid.layers.relu(res)
out.append(res)
res = self.maxpool(res)
for i in range(len(self.block_collect)):
for layer in self.block_collect[i]:
res = layer(res)
name = 'block{}'.format(i)
if name in self.feat_layers:
out.append(res)
if (len(out) - 1) == len(self.feat_layers):
out[-1].stop_gradient = True if self.is_test else False
if len(out) == 1:
return out[0]
else:
return out
out[-1].stop_gradient = True if self.is_test else False
return out
def resnet50(name, pretrained=False, **kwargs):
net = ResNet(name, Block=Bottleneck, layers=50, **kwargs)
if pretrained:
params_path = os.path.join(env_settings().backbone_dir, 'ResNet50_dilated')
print("=> loading backbone model from '{}'".format(params_path))
params, _ = fluid.load_dygraph(params_path)
net.load_dict(params)
print("Done")
return net
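The `strides` and `dilations` lists in `ResNet.__init__` keep the last two stages at the resolution of `block1` and enlarge the receptive field through dilation instead. The sketch below tabulates the cumulative stride and output channels per block, using the constants from the class. It ignores padding effects on the exact spatial size, and it lists the nominal dilation (the first `Bottleneck` of a dilated stage halves it internally).

```python
EXPANSION = 4                      # Bottleneck.expansion
STRIDES = [1, 2, 1, 1]             # per-block strides from ResNet.__init__
NUM_FILTERS = [64, 128, 256, 512]
DILATIONS = [1, 1, 2, 4]
STEM_STRIDE = 2 * 2                # 7x7 conv (stride 2) followed by max pool (stride 2)


def block_summary():
    """Return (cumulative stride, output channels, nominal dilation) per block."""
    summary = []
    total = STEM_STRIDE
    for stride, filters, dilation in zip(STRIDES, NUM_FILTERS, DILATIONS):
        total *= stride
        summary.append((total, filters * EXPANSION, dilation))
    return summary


summary = block_summary()
for i, (stride, channels, dilation) in enumerate(summary):
    print("block{}: stride={} channels={} dilation={}".format(
        i, stride, channels, dilation))
```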
import paddle.fluid as fluid
import numpy as np
def get_cls_loss(pred, label, select):
if select.shape[0] == 0:
return fluid.layers.reduce_sum(pred) * 0
pred = fluid.layers.gather(pred, select)
label = fluid.layers.gather(label, select)
label = fluid.layers.reshape(label, [-1, 1])
loss = fluid.layers.softmax_with_cross_entropy(
logits = pred,
label = label)
return fluid.layers.mean(loss)
def select_softmax_with_cross_entropy_loss(pred, label):
b, c, h, w = pred.shape
pred = fluid.layers.reshape(pred, [b, 2, -1, h, w])
pred = fluid.layers.transpose(pred, [0, 2, 3, 4, 1])
pred = fluid.layers.reshape(pred, [-1, 2])
label = fluid.layers.reshape(label, [-1])
pos = fluid.layers.where(label == 1)
neg = fluid.layers.where(label == 0)
loss_pos = get_cls_loss(pred, label, pos)
loss_neg = get_cls_loss(pred, label, neg)
return loss_pos * 0.5 + loss_neg * 0.5
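`select_softmax_with_cross_entropy_loss` averages the cross-entropy over positive and negative anchors separately and then weights the two means equally, so the far more numerous negatives do not dominate. A pure-Python sketch of that idea follows. It assumes two logits per anchor and that any label other than 0 or 1 (here -1) marks an ignored anchor; the function names are hypothetical.

```python
import math


def cross_entropy(logits, target):
    """Softmax cross-entropy for one anchor, computed in a numerically stable way."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum - logits[target]


def balanced_cls_loss(logits, labels):
    """Average the loss over positives (1) and negatives (0) separately.

    Anchors with any other label are ignored; an empty group contributes 0.
    """
    def group_mean(value):
        losses = [cross_entropy(row, y)
                  for row, y in zip(logits, labels) if y == value]
        return sum(losses) / len(losses) if losses else 0.0

    return 0.5 * group_mean(1) + 0.5 * group_mean(0)


logits = [[0.0, 0.0], [2.0, -2.0], [-1.0, 3.0], [0.5, 0.5]]
labels = [1, 0, 1, -1]
loss = balanced_cls_loss(logits, labels)
print(loss)
```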
def weight_l1_loss(pred_loc, label_loc, loss_weight):
b, c, h, w = pred_loc.shape
pred_loc = fluid.layers.reshape(pred_loc, [b, 4, -1, h, w])
loss = fluid.layers.abs(pred_loc - label_loc)
loss = fluid.layers.reduce_sum(loss, dim=1)
loss = loss * loss_weight
return fluid.layers.reduce_sum(loss) / b
def soft_margin_loss(pred, label):
#loss = fluid.layers.elementwise_mul(pred, label)
loss = fluid.layers.exp(-1 * pred * label)
loss = fluid.layers.log(1 + loss)
return fluid.layers.reduce_mean(loss)
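`soft_margin_loss` computes the mean of `log(1 + exp(-pred * label))`. The scalar sketch below, which assumes labels of +1 and -1, illustrates how the loss behaves for undecided, correct, and wrong predictions; the function name `soft_margin` is hypothetical.

```python
import math


def soft_margin(preds, labels):
    """Mean of log(1 + exp(-pred * label)) over paired scores and +/-1 labels."""
    terms = [math.log(1.0 + math.exp(-p * y)) for p, y in zip(preds, labels)]
    return sum(terms) / len(terms)


undecided = soft_margin([0.0, 0.0], [1, -1])   # zero score: log(2) per element
confident = soft_margin([5.0, -5.0], [1, -1])  # correct signs: small loss
wrong = soft_margin([-5.0, 5.0], [1, -1])      # wrong signs: large loss
print(undecided, confident, wrong)
```

For large magnitudes the loss grows roughly linearly in `-pred * label`, which is why confidently wrong pixels are penalized much more than uncertain ones.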
def iou_measure(pred, label):
pred = fluid.layers.cast(pred >= 0, 'float32')
pred = fluid.layers.cast(pred == 1, 'float32')
label = fluid.layers.cast(label == 1, 'float32')
mask_sum = pred + label
intxn = fluid.layers.reduce_sum(
fluid.layers.cast(mask_sum == 2, 'float32'), dim=1)
union = fluid.layers.reduce_sum(
fluid.layers.cast(mask_sum > 0, 'float32'), dim=1)
iou = intxn / union
iou_m = fluid.layers.reduce_mean(iou)
iou_5 = fluid.layers.cast(iou > 0.5, 'float32')
iou_5 = fluid.layers.reduce_sum(iou_5) / iou.shape[0]
iou_7 = fluid.layers.cast(iou > 0.7, 'float32')
iou_7 = fluid.layers.reduce_sum(iou_7) / iou.shape[0]
return iou_m, iou_5, iou_7
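`iou_measure` thresholds the predicted scores at 0, compares them with the positive label pixels, and reports the mean IoU together with the fraction of samples above 0.5 and 0.7. A pure-Python sketch of the same bookkeeping follows, assuming +1 / -1 label masks flattened to lists and a non-empty union for every sample; the function names are hypothetical.

```python
def mask_iou(pred_scores, label):
    """IoU between a thresholded prediction (score >= 0) and a +/-1 label mask."""
    pred = [s >= 0 for s in pred_scores]
    gt = [v == 1 for v in label]
    inter = sum(1 for p, g in zip(pred, gt) if p and g)
    union = sum(1 for p, g in zip(pred, gt) if p or g)
    return inter / union


def iou_stats(batch_scores, batch_labels):
    """Mean IoU plus the fraction of samples with IoU above 0.5 and above 0.7."""
    ious = [mask_iou(s, m) for s, m in zip(batch_scores, batch_labels)]
    n = len(ious)
    mean = sum(ious) / n
    above_5 = sum(1 for v in ious if v > 0.5) / n
    above_7 = sum(1 for v in ious if v > 0.7) / n
    return mean, above_5, above_7


scores = [[1.0, 2.0, -1.0, 0.5], [0.3, -0.2, -0.7, -1.0]]
labels = [[1, 1, -1, 1], [1, 1, 1, -1]]
mean_iou, frac_5, frac_7 = iou_stats(scores, labels)
print(mean_iou, frac_5, frac_7)
```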
def select_mask_logistic_loss(pred_mask, label_mask, loss_weight, out_size=63, gt_size=127):
loss_weight = fluid.layers.reshape(loss_weight, [-1])
pos = loss_weight == 1
if np.sum(pos.numpy()) == 0:
zero = fluid.layers.reduce_sum(pred_mask) * 0
return zero, zero, zero, zero
pos = fluid.layers.where(pos)
if len(pred_mask.shape) == 4:
pred_mask = fluid.layers.transpose(pred_mask, [0, 2, 3, 1])
pred_mask = fluid.layers.reshape(pred_mask, [-1, 1, out_size, out_size])
pred_mask = fluid.layers.gather(pred_mask, pos)
pred_mask = fluid.layers.resize_bilinear(pred_mask, out_shape=[gt_size, gt_size])
pred_mask = fluid.layers.reshape(pred_mask, [-1, gt_size * gt_size])
label_mask_uf = fluid.layers.unfold(label_mask, [gt_size, gt_size], 8, 32)
else:
pred_mask = fluid.layers.gather(pred_mask, pos)
label_mask_uf = fluid.layers.unfold(label_mask, [gt_size, gt_size], 8, 0)
label_mask_uf = fluid.layers.transpose(label_mask_uf, [0, 2, 1])
label_mask_uf = fluid.layers.reshape(label_mask_uf, [-1, gt_size * gt_size])
label_mask_uf = fluid.layers.gather(label_mask_uf, pos)
loss = soft_margin_loss(pred_mask, label_mask_uf)
if np.isnan(loss.numpy()):
zero = fluid.layers.reduce_sum(pred_mask) * 0
return zero, zero, zero, zero
iou_m, iou_5, iou_7 = iou_measure(pred_mask, label_mask_uf)
return loss, iou_m, iou_5, iou_7
if __name__ == "__main__":
import numpy as np
pred_mask = np.random.randn(4, 63*63, 25, 25)
weight_mask = np.random.randn(4, 1, 25, 25) > 0.9
label_mask = np.random.randint(-1, 1, (4, 1, 255, 255))
pred_loc = np.random.randn(3, 32, 17, 17)
weight_loc = np.random.randn(3, 8, 17, 17)
label_loc = np.random.randn(3, 4, 8, 17, 17)
pred_cls = np.random.randn(3, 16, 17, 17)
label_cls = np.random.randint(0, 2, (3, 8, 17, 17))
with fluid.dygraph.guard():
pred_mask = fluid.dygraph.to_variable(pred_mask)
weight_mask = fluid.dygraph.to_variable(weight_mask.astype('float32'))
label_mask = fluid.dygraph.to_variable(label_mask.astype('float32'))
loss = select_mask_logistic_loss(pred_mask, label_mask, weight_mask)
print("loss_mask = ", loss)
pred_loc = fluid.dygraph.to_variable(pred_loc)
weight_loc = fluid.dygraph.to_variable(weight_loc)
label_loc = fluid.dygraph.to_variable(label_loc)
loss = weight_l1_loss(pred_loc, label_loc, weight_loc)
print("loss_loc = ", loss)
pred_cls = fluid.dygraph.to_variable(pred_cls)
label_cls = fluid.dygraph.to_variable(label_cls)
loss = select_softmax_with_cross_entropy_loss(pred_cls, label_cls)
print("loss_cls = ", loss)
import paddle.fluid as fluid
import paddle.fluid.dygraph.nn as nn
import os.path as osp
import sys
from ltr.models.siam.xcorr import xcorr, xcorr_depthwise
CURRENT_DIR = osp.dirname(__file__)
sys.path.append(osp.join(CURRENT_DIR, '..', '..', '..'))
def weight_init():
init = fluid.initializer.MSRAInitializer(uniform=False)
param = fluid.ParamAttr(initializer=init)
return param
def bias_init():
init = fluid.initializer.ConstantInitializer(value=0.)
param = fluid.ParamAttr(initializer=init)
return param
def norm_weight_init():
init = fluid.initializer.Uniform(low=0., high=1.)
param = fluid.ParamAttr(initializer=init)
return param
def norm_bias_init():
init = fluid.initializer.ConstantInitializer(value=0.)
param = fluid.ParamAttr(initializer=init)
return param
class RPN(fluid.dygraph.Layer):
def __init__(self):
super(RPN, self).__init__()
def forward(self, z_f, x_f):
raise NotImplementedError
class DepthwiseXCorr(fluid.dygraph.Layer):
def __init__(self,
in_channels,
hidden,
out_channels,
filter_size=3,
is_test=False):
super(DepthwiseXCorr, self).__init__()
self.kernel_conv1 = nn.Conv2D(
num_channels=in_channels,
num_filters=hidden,
filter_size=filter_size,
stride=1,
padding=0,
groups=1,
param_attr=weight_init(),
bias_attr=False)
self.kernel_bn1 = nn.BatchNorm(
num_channels=hidden,
act='relu',
param_attr=norm_weight_init(),
bias_attr=norm_bias_init(),
momentum=0.9,
use_global_stats=is_test)
self.search_conv1 = nn.Conv2D(
num_channels=in_channels,
num_filters=hidden,
filter_size=filter_size,
stride=1,
padding=0,
groups=1,
param_attr=weight_init(),
bias_attr=False)
self.search_bn1 = nn.BatchNorm(
num_channels=hidden,
act='relu',
param_attr=norm_weight_init(),
bias_attr=norm_bias_init(),
momentum=0.9,
use_global_stats=is_test)
self.head_conv1 = nn.Conv2D(
num_channels=hidden,
num_filters=hidden,
filter_size=1,
stride=1,
padding=0,
groups=1,
param_attr=weight_init(),
bias_attr=False)
self.head_bn1 = nn.BatchNorm(
num_channels=hidden,
act='relu',
param_attr=norm_weight_init(),
bias_attr=norm_bias_init(),
momentum=0.9,
use_global_stats=is_test)
self.head_conv2 = nn.Conv2D(
num_channels=hidden,
num_filters=out_channels,
filter_size=1,
stride=1,
padding=0,
groups=1,
param_attr=weight_init())
def forward(self, kernel, search):
kernel = self.kernel_conv1(kernel)
kernel = self.kernel_bn1(kernel)
search = self.search_conv1(search)
search = self.search_bn1(search)
feature = xcorr_depthwise(search, kernel)
out = self.head_conv1(feature)
out = self.head_bn1(out)
out = self.head_conv2(out)
return out
class DepthwiseRPN(RPN):
def __init__(self, anchor_num=5, in_channels=256, out_channels=256, is_test=False):
super(DepthwiseRPN, self).__init__()
self.cls = DepthwiseXCorr(in_channels, out_channels, 2 * anchor_num, is_test=is_test)
self.loc = DepthwiseXCorr(in_channels, out_channels, 4 * anchor_num, is_test=is_test)
def forward(self, z_f, x_f):
cls = self.cls(z_f, x_f)
loc = self.loc(z_f, x_f)
return cls, loc
class MaskCorr(DepthwiseXCorr):
def __init__(self,
in_channels,
hidden,
out_channels,
filter_size=3,
hidden_filter_size=5,
is_test=False):
super(MaskCorr, self).__init__(
in_channels,
hidden,
out_channels,
filter_size,
is_test)
def forward(self, kernel, search):
kernel = self.kernel_conv1(kernel)
kernel = self.kernel_bn1(kernel)
search = self.search_conv1(search)
search = self.search_bn1(search)
feature = xcorr_depthwise(search, kernel)
out = self.head_conv1(feature)
out = self.head_bn1(out)
out = self.head_conv2(out)
return out, feature
class RefineModule(fluid.dygraph.Layer):
def __init__(self,
in_channels,
hidden1,
hidden2,
out_channels,
out_shape,
filter_size=3,
padding=1):
super(RefineModule, self).__init__()
self.v_conv0 = nn.Conv2D(
num_channels=in_channels,
num_filters=hidden1,
filter_size=filter_size,
stride=1,
padding=padding,
groups=1,
param_attr=weight_init())
self.v_conv1 = nn.Conv2D(
num_channels=hidden1,
num_filters=hidden2,
filter_size=filter_size,
stride=1,
padding=padding,
groups=1,
param_attr=weight_init())
self.h_conv0 = nn.Conv2D(
num_channels=hidden2,
num_filters=hidden2,
filter_size=filter_size,
stride=1,
padding=padding,
groups=1,
param_attr=weight_init())
self.h_conv1 = nn.Conv2D(
num_channels=hidden2,
num_filters=hidden2,
filter_size=filter_size,
stride=1,
padding=padding,
groups=1,
param_attr=weight_init())
self.out_shape = out_shape
self.post = nn.Conv2D(
num_channels=hidden2,
num_filters=out_channels,
filter_size=filter_size,
stride=1,
padding=padding,
groups=1,
param_attr=weight_init())
def forward(self, xh, xv):
yh = self.h_conv0(xh)
yh = fluid.layers.relu(yh)
yh = self.h_conv1(yh)
yh = fluid.layers.relu(yh)
yv = self.v_conv0(xv)
yv = fluid.layers.relu(yv)
yv = self.v_conv1(yv)
yv = fluid.layers.relu(yv)
out = yh + yv
out = fluid.layers.resize_nearest(out, out_shape=self.out_shape, align_corners=False)
out = self.post(out)
return out
class Refine(fluid.dygraph.Layer):
def __init__(self):
super(Refine, self).__init__()
self.U4 = RefineModule(
in_channels=64,
hidden1=16,
hidden2=4,
out_channels=1,
filter_size=3,
padding=1,
out_shape=[127, 127])
self.U3 = RefineModule(
in_channels=256,
hidden1=64,
hidden2=16,
out_channels=4,
filter_size=3,
padding=1,
out_shape=[61, 61])
self.U2 = RefineModule(
in_channels=512,
hidden1=128,
hidden2=32,
out_channels=16,
filter_size=3,
padding=1,
out_shape=[31, 31])
self.deconv = nn.Conv2DTranspose(
num_channels=256,
num_filters=32,
filter_size=15,
padding=0,
stride=15)
def forward(self, xf, corr_feature, pos=None, test=False):
if test:
p0 = fluid.layers.pad2d(xf[0], [16, 16, 16, 16])
p0 = p0[:, :, 4*pos[0]:4*pos[0]+61, 4*pos[1]:4*pos[1]+61]
p1 = fluid.layers.pad2d(xf[1], [8, 8, 8, 8])
p1 = p1[:, :, 2*pos[0]:2*pos[0]+31, 2*pos[1]:2*pos[1]+31]
p2 = fluid.layers.pad2d(xf[2], [4, 4, 4, 4])
p2 = p2[:, :, pos[0]:pos[0]+15, pos[1]:pos[1]+15]
p3 = corr_feature[:, :, pos[0], pos[1]]
p3 = fluid.layers.reshape(p3, [-1, 256, 1, 1])
else:
p0 = fluid.layers.unfold(xf[0], [61, 61], 4, 0)
p0 = fluid.layers.transpose(p0, [0, 2, 1])
p0 = fluid.layers.reshape(p0, [-1, 64, 61, 61])
p1 = fluid.layers.unfold(xf[1], [31, 31], 2, 0)
p1 = fluid.layers.transpose(p1, [0, 2, 1])
p1 = fluid.layers.reshape(p1, [-1, 256, 31, 31])
p2 = fluid.layers.unfold(xf[2], [15, 15], 1, 0)
p2 = fluid.layers.transpose(p2, [0, 2, 1])
p2 = fluid.layers.reshape(p2, [-1, 512, 15, 15])
p3 = fluid.layers.transpose(corr_feature, [0, 2, 3, 1])
p3 = fluid.layers.reshape(p3, [-1, 256, 1, 1])
out = self.deconv(p3)
out = self.U2(out, p2)
out = self.U3(out, p1)
out = self.U4(out, p0)
out = fluid.layers.reshape(out, [-1, 127*127])
return out
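The training branch of `Refine.forward` cuts each backbone feature map into sliding patches with `fluid.layers.unfold`. Below is a minimal sketch of the window-count arithmetic; the helper `num_windows` is hypothetical, and the feature-map widths 69, 35 and 17 are assumed for illustration only (they are not stated in this file).

```python
def num_windows(size, kernel, stride, padding=0):
    """Number of sliding-window positions along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


# (name, assumed feature width, patch size, stride) as used by unfold above
levels = [("xf[0]", 69, 61, 4), ("xf[1]", 35, 31, 2), ("xf[2]", 17, 15, 1)]
for name, size, kernel, stride in levels:
    print(name, num_windows(size, kernel, stride))
```

Under these assumed widths every level yields the same number of positions per axis, which is what lets the reshaped patches line up one-to-one with the rows of `p3`.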
import paddle.fluid as fluid
import paddle.fluid.dygraph.nn as nn
import os.path as osp
import sys
CURRENT_DIR = osp.dirname(__file__)
sys.path.append(osp.join(CURRENT_DIR, '..', '..', '..'))
def weight_init():
init = fluid.initializer.MSRAInitializer(uniform=False)
param = fluid.ParamAttr(initializer=init)
return param
def bias_init():
init = fluid.initializer.ConstantInitializer(value=0.)
param = fluid.ParamAttr(initializer=init)
return param
def norm_weight_init():
init = fluid.initializer.Uniform(low=0., high=1.)
param = fluid.ParamAttr(initializer=init)
return param
def norm_bias_init():
init = fluid.initializer.ConstantInitializer(value=0.)
param = fluid.ParamAttr(initializer=init)
return param
class AdjustLayer(fluid.dygraph.Layer):
def __init__(self, num_channels, num_filters, is_test=False):
super(AdjustLayer, self).__init__()
self.conv = nn.Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
param_attr=weight_init(),
bias_attr=False)
self.bn = nn.BatchNorm(
num_channels=num_filters,
param_attr=norm_weight_init(),
bias_attr=norm_bias_init(),
momentum=0.9,
act=None,
use_global_stats=is_test)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if x.shape[3] < 20:
l = 4
r = -4
x = x[:, :, l:r, l:r]
return x
class AdjustAllLayer(fluid.dygraph.Layer):
def __init__(self, in_channels, out_channels, is_test=False):
super(AdjustAllLayer, self).__init__('')
self.num = len(out_channels)
self.sub_layer_list = []
if self.num == 1:
self.downsample = AdjustLayer(in_channels[0], out_channels[0], is_test)
else:
for i in range(self.num):
Build_Adjust_Layer = self.add_sublayer(
'downsample'+str(i+2),
AdjustLayer(in_channels[i], out_channels[i], is_test))
self.sub_layer_list.append(Build_Adjust_Layer)
def forward(self, features):
if self.num == 1:
return self.downsample(features)
else:
out = []
for i in range(self.num):
build_adjust_layer_i = self.sub_layer_list[i]
out.append(build_adjust_layer_i(features[i]))
return out
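`AdjustLayer.forward` crops four pixels from every border when the feature map is narrower than 20, which in practice affects only the small template features. A minimal NumPy sketch of that crop; the function name `center_crop_small` and its parameters are hypothetical.

```python
import numpy as np


def center_crop_small(x, threshold=20, margin=4):
    """Crop `margin` pixels from each spatial border if width < threshold."""
    if x.shape[3] < threshold:
        x = x[:, :, margin:-margin, margin:-margin]
    return x


template_feat = np.zeros((1, 256, 15, 15), dtype=np.float32)
search_feat = np.zeros((1, 256, 31, 31), dtype=np.float32)
print(center_crop_small(template_feat).shape)  # (1, 256, 7, 7)
print(center_crop_small(search_feat).shape)    # (1, 256, 31, 31)
```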
import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
import os.path as osp
import sys
CURRENT_DIR = osp.dirname(__file__)
sys.path.append(osp.join(CURRENT_DIR, '..', '..', '..'))
from ltr.models.backbone.resnet_dilated import resnet50
from ltr.models.backbone.alexnet import AlexNet
from ltr.models.siam.head import DepthwiseRPN, MaskCorr, Refine
from ltr.models.siam.neck import AdjustAllLayer
class Siamnet(dygraph.layers.Layer):
def __init__(self,
feature_extractor,
rpn_head,
neck=None,
mask_head=None,
refine_head=None,
scale_loss=None):
super(Siamnet, self).__init__()
self.feature_extractor = feature_extractor
self.rpn_head = rpn_head
self.neck = neck
self.mask_head = mask_head
self.refine_head = refine_head
self.scale_loss = scale_loss
def forward(self, template, search):
# get feature
if len(template.shape) == 5:
template = fluid.layers.reshape(template, [-1, *list(template.shape)[-3:]])
search = fluid.layers.reshape(search, [-1, *list(search.shape)[-3:]])
zf = self.feature_extractor(template)
xf = self.feature_extractor(search)
if self.mask_head is not None:
zf = zf[-1]
xf_refine = xf[:-1]
xf = xf[-1]
if isinstance(zf, list):
zf = zf[-1]
if isinstance(xf, list):
xf = xf[-1]
if self.neck is not None:
zf = self.neck(zf)
xf = self.neck(xf)
cls, loc = self.rpn_head(zf, xf)
if self.mask_head is not None:
if self.refine_head is not None:
_, mask_corr_feature = self.mask_head(zf, xf)
mask = self.refine_head(xf_refine, mask_corr_feature)
else:
mask, mask_corr_feature = self.mask_head(zf, xf)
return {'cls': cls,
'loc': loc,
'mask': mask}
else:
return {'cls': cls,
'loc': loc}
def extract_backbone_features(self, im):
return self.feature_extractor(im)
def template(self, template):
zf = self.feature_extractor(template)
if self.mask_head is not None:
zf = zf[-1]
if isinstance(zf, list):
zf = zf[-1]
if self.neck is not None:
zf = self.neck(zf)
self.zf = zf
def track(self, search):
xf = self.feature_extractor(search)
if self.mask_head is not None:
self.xf = xf[:-1]
xf = xf[-1]
if isinstance(xf, list):
xf = xf[-1]
if self.neck is not None:
xf = self.neck(xf)
cls, loc = self.rpn_head(self.zf, xf)
if self.mask_head is not None:
mask, self.mask_corr_feature = self.mask_head(self.zf, xf)
return {'cls': cls,
'loc': loc,
'mask': mask}
else:
return {'cls': cls,
'loc': loc}
def mask_refine(self, pos):
return self.refine_head(self.xf, self.mask_corr_feature, pos, test=True)
def SiamRPN_AlexNet(backbone_pretrained=True,
backbone_is_test=True,
is_test=False,
scale_loss=None):
backbone = AlexNet(
'AlexNet',
is_test=backbone_is_test,
output_layers=['conv5'],
pretrained=backbone_pretrained)
rpn_head = DepthwiseRPN(anchor_num=5, in_channels=256, out_channels=256, is_test=is_test)
model = Siamnet(
feature_extractor=backbone,
rpn_head=rpn_head,
scale_loss=scale_loss)
return model
def SiamRPN_ResNet50(backbone_pretrained=True,
backbone_is_test=True,
is_test=False,
scale_loss=None):
backbone = resnet50(
'ResNet50',
pretrained=backbone_pretrained,
output_layers=[2],
is_test=backbone_is_test)
neck = AdjustAllLayer(in_channels=[1024], out_channels=[256], is_test=is_test)
rpn_head = DepthwiseRPN(anchor_num=5, in_channels=256, out_channels=256, is_test=is_test)
model = Siamnet(
feature_extractor=backbone,
neck=neck,
rpn_head=rpn_head,
scale_loss=scale_loss)
return model
def SiamMask_ResNet50_base(backbone_pretrained=True,
backbone_is_test=True,
is_test=False,
scale_loss=None):
backbone = resnet50(
'ResNet50',
pretrained=backbone_pretrained,
output_layers=[0,1,2],
is_test=backbone_is_test)
neck = AdjustAllLayer(in_channels=[1024], out_channels=[256], is_test=is_test)
rpn_head = DepthwiseRPN(anchor_num=5, in_channels=256, out_channels=256, is_test=is_test)
mask_head = MaskCorr(in_channels=256, hidden=256, out_channels=3969, is_test=is_test)
model = Siamnet(
feature_extractor=backbone,
neck=neck,
rpn_head=rpn_head,
mask_head=mask_head,
scale_loss=scale_loss)
return model
def SiamMask_ResNet50_sharp(backbone_pretrained=False,
backbone_is_test=True,
is_test=False,
scale_loss=None):
backbone = resnet50(
'ResNet50',
pretrained=backbone_pretrained,
output_layers=[0,1,2],
is_test=backbone_is_test)
neck = AdjustAllLayer(in_channels=[1024], out_channels=[256], is_test=True)
rpn_head = DepthwiseRPN(anchor_num=5, in_channels=256, out_channels=256, is_test=True)
mask_head = MaskCorr(in_channels=256, hidden=256, out_channels=3969, is_test=is_test)
refine_head = Refine()
model = Siamnet(
feature_extractor=backbone,
neck=neck,
rpn_head=rpn_head,
mask_head=mask_head,
refine_head=refine_head,
scale_loss=scale_loss)
return model
if __name__ == '__main__':
import numpy as np
search = np.random.uniform(-1, 1, [1, 3, 255, 255]).astype(np.float32)
template = np.random.uniform(-1, 1, [1, 3, 127, 127]).astype(np.float32)
with fluid.dygraph.guard():
search = fluid.dygraph.to_variable(search)
template = fluid.dygraph.to_variable(template)
model = SiamMask_ResNet50_base(False)
res = model(template, search)
params = model.state_dict()
for v in params:
print(v)
import paddle.fluid as fluid
import paddle.fluid.dygraph.nn as nn
from pytracking.libs.Fconv2d import FConv2D
def xcorr(x, kernel):
"""group conv2d to calculate cross correlation
"""
batch = kernel.shape[0]
px = fluid.layers.reshape(x, [1, -1, x.shape[2], x.shape[3]])
pk = fluid.layers.reshape(kernel, [-1, x.shape[1], kernel.shape[2], kernel.shape[3]])
scores_map = FConv2D(px, pk, stride=1, padding=0, dilation=1, groups=batch)
scores_map = fluid.layers.reshape(
scores_map, [batch, -1, scores_map.shape[2], scores_map.shape[3]])
return scores_map
def xcorr_depthwise(x, kernel):
"""depthwise cross correlation
"""
batch = kernel.shape[0]
channel = kernel.shape[1]
px = fluid.layers.reshape(x, [1, -1, x.shape[2], x.shape[3]])
pk = fluid.layers.reshape(kernel, [-1, 1, kernel.shape[2], kernel.shape[3]])
scores_map = FConv2D(px, pk, stride=1, padding=0, dilation=1, groups=batch*channel)
scores_map = fluid.layers.reshape(
scores_map,[batch, -1, scores_map.shape[2], scores_map.shape[3]])
return scores_map
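To make the grouped-convolution trick in `xcorr_depthwise` easier to follow, here is a plain NumPy reference of the same operation: each channel of each sample is correlated with its own kernel, without kernel flipping. The name `xcorr_depthwise_np` is hypothetical, and the loop is written for clarity rather than speed.

```python
import numpy as np


def xcorr_depthwise_np(x, kernel):
    """Depthwise cross-correlation.

    x:      [B, C, Hx, Wx] search features
    kernel: [B, C, Hk, Wk] template features
    returns [B, C, Hx - Hk + 1, Wx - Wk + 1]
    """
    b, c, hx, wx = x.shape
    _, _, hk, wk = kernel.shape
    ho, wo = hx - hk + 1, wx - wk + 1
    out = np.zeros((b, c, ho, wo), dtype=x.dtype)
    for i in range(ho):
        for j in range(wo):
            # Elementwise product of the window with the per-channel kernel
            window = x[:, :, i:i + hk, j:j + wk]
            out[:, :, i, j] = (window * kernel).sum(axis=(2, 3))
    return out


search = np.random.randn(2, 4, 9, 9).astype(np.float32)
template = np.random.randn(2, 4, 3, 3).astype(np.float32)
scores = xcorr_depthwise_np(search, template)
print(scores.shape)  # (2, 4, 7, 7)
```

The Paddle version obtains the same result in one call by folding batch and channel into a single axis and setting `groups=batch*channel`, so that no kernel ever sees another channel's data.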
......@@ -2,7 +2,6 @@ from paddle import fluid
from paddle.fluid import dygraph
from paddle.fluid.dygraph import nn
from pytracking.libs.Fconv2d import Conv2D
from pytracking.libs.Fconv2d import FConv2D
......
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
import ltr.actors as actors
import ltr.data.transforms as dltransforms
from ltr.data import processing, sampler, loader
from ltr.dataset import ImagenetVID, ImagenetDET, MSCOCOSeq, YoutubeVOS, Lasot, Got10k
from ltr.models.siam.siam import SiamMask_ResNet50_base
from ltr.models.loss import select_softmax_with_cross_entropy_loss, weight_l1_loss, select_mask_logistic_loss
from ltr.trainers import LTRTrainer
from ltr.trainers.learning_rate_scheduler import LinearLrWarmup
import numpy as np
import cv2 as cv
from PIL import Image, ImageEnhance
def run(settings):
# Most common settings are assigned in the settings struct
settings.description = 'SiamMask_base with ResNet-50 backbone.'
settings.print_interval = 100 # How often to print loss and other info
settings.batch_size = 64 # Batch size
settings.samples_per_epoch = 600000 # Number of training pairs per epoch
settings.num_workers = 4 # Number of workers for image loading
settings.search_area_factor = {'train': 1.0, 'test': 2.0}
settings.output_sz = {'train': 127, 'test': 255}
settings.scale_type = 'context'
settings.border_type = 'meanpad'
# Settings for the image sample and label generation
settings.center_jitter_factor = {'train': 0.125, 'test': 2.0}
settings.scale_jitter_factor = {'train': 0.05, 'test': 0.18}
settings.label_params = {
'search_size': 255,
'output_size': 25,
'anchor_stride': 8,
'anchor_ratios': [0.33, 0.5, 1, 2, 3],
'anchor_scales': [8],
'num_pos': 16,
'num_neg': 16,
'num_total': 64,
'thr_high': 0.6,
'thr_low': 0.3
}
settings.loss_weights = {'cls': 1., 'loc': 1.2, 'mask':36}
settings.neg = 0.2
# Train datasets
vos_train = YoutubeVOS()
vid_train = ImagenetVID()
coco_train = MSCOCOSeq()
det_train = ImagenetDET()
lasot_train = Lasot(split='train')
got10k_train = Got10k(split='train')
# Validation datasets
vid_val = ImagenetVID()
# The joint augmentation transform, that is applied to the pairs jointly
transform_joint = dltransforms.ToGrayscale(probability=0.25)
# The augmentation transform applied to the training set (individually to each image in the pair)
transform_exemplar = dltransforms.Transpose()
transform_instance = dltransforms.Compose(
[
dltransforms.Color(probability=1.0),
dltransforms.Blur(probability=0.18),
dltransforms.Transpose()
])
transform_instance_mask = dltransforms.Transpose()
# Data processing to do on the training pairs
data_processing_train = processing.SiamProcessing(
search_area_factor=settings.search_area_factor,
output_sz=settings.output_sz,
center_jitter_factor=settings.center_jitter_factor,
scale_jitter_factor=settings.scale_jitter_factor,
scale_type=settings.scale_type,
border_type=settings.border_type,
mode='sequence',
label_params=settings.label_params,
train_transform=transform_exemplar,
test_transform=transform_instance,
test_mask_transform=transform_instance_mask,
joint_transform=transform_joint)
# Data processing to do on the validation pairs
data_processing_val = processing.SiamProcessing(
search_area_factor=settings.search_area_factor,
output_sz=settings.output_sz,
center_jitter_factor=settings.center_jitter_factor,
scale_jitter_factor=settings.scale_jitter_factor,
scale_type=settings.scale_type,
border_type=settings.border_type,
mode='sequence',
label_params=settings.label_params,
transform=transform_exemplar,
joint_transform=transform_joint)
nums_per_epoch = settings.samples_per_epoch // settings.batch_size
# The sampler for training
dataset_train = sampler.MaskSampler(
[vid_train, coco_train, det_train, vos_train, lasot_train, got10k_train],
[2, 1, 1, 2, 1, 1],
samples_per_epoch=nums_per_epoch * settings.batch_size,
max_gap=100,
processing=data_processing_train,
neg=settings.neg)
# The loader for training
train_loader = loader.LTRLoader(
'train',
dataset_train,
training=True,
batch_size=settings.batch_size,
num_workers=settings.num_workers,
stack_dim=0)
# The sampler for validation
dataset_val = sampler.MaskSampler(
[vid_val],
[1, ],
samples_per_epoch=100 * settings.batch_size,
max_gap=100,
processing=data_processing_val)
# The loader for validation
val_loader = loader.LTRLoader(
'val',
dataset_val,
training=False,
batch_size=settings.batch_size,
num_workers=settings.num_workers,
stack_dim=0)
# Create network, set objective, create optimizer, learning rate scheduler, trainer
with dygraph.guard():
# Create network
def scale_loss(loss):
total_loss = 0
for k in settings.loss_weights:
total_loss += loss[k] * settings.loss_weights[k]
return total_loss
net = SiamMask_ResNet50_base(scale_loss=scale_loss)
# Define objective
objective = {
'cls': select_softmax_with_cross_entropy_loss,
'loc': weight_l1_loss,
'mask': select_mask_logistic_loss
}
# Create actor, which wraps network and objective
actor = actors.SiamActor(net=net, objective=objective)
# Set to training mode
actor.train()
# Define optimizer and learning rate
decayed_lr = fluid.layers.exponential_decay(
learning_rate=0.005,
decay_steps=nums_per_epoch,
decay_rate=0.9642,
staircase=True)
lr_scheduler = LinearLrWarmup(
learning_rate=decayed_lr,
warmup_steps=5*nums_per_epoch,
start_lr=0.001,
end_lr=0.005)
optimizer = fluid.optimizer.Adam(
parameter_list=net.rpn_head.parameters()
+ net.neck.parameters()
+ net.mask_head.parameters(),
learning_rate=lr_scheduler)
trainer = LTRTrainer(actor, [train_loader, val_loader], optimizer, settings, lr_scheduler)
trainer.train(20, load_latest=False, fail_safe=False)
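The schedule configured above combines a staircase exponential decay with a linear warmup. The sketch below reproduces that shape in plain Python so the values can be inspected without Paddle. It assumes the repository's `LinearLrWarmup` behaves like a conventional warmup (linear ramp from `start_lr` to `end_lr`, then hand-over to the decayed rate), which is not confirmed by this excerpt; `lr_at` is a hypothetical helper.

```python
def lr_at(step, steps_per_epoch, base_lr=0.005, decay_rate=0.9642,
          warmup_epochs=5, start_lr=0.001, end_lr=0.005):
    """Learning rate at a global step (assumed warmup + staircase decay)."""
    warmup_steps = warmup_epochs * steps_per_epoch
    if step < warmup_steps:
        # Linear ramp during warmup
        return start_lr + (end_lr - start_lr) * step / warmup_steps
    # Staircase decay: the exponent only changes once per epoch
    return base_lr * decay_rate ** (step // steps_per_epoch)


steps_per_epoch = 600000 // 64  # samples_per_epoch // batch_size
for epoch in (0, 5, 10, 19):
    print(epoch, lr_at(epoch * steps_per_epoch, steps_per_epoch))
```

Because the decay exponent keeps counting during warmup, the rate right after warmup is already `base_lr * decay_rate ** 5` rather than `base_lr` under this assumption.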
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
import ltr.actors as actors
import ltr.data.transforms as dltransforms
from ltr.data import processing, sampler, loader
from ltr.dataset import ImagenetVID, ImagenetDET, MSCOCOSeq, YoutubeVOS
from ltr.models.siam.siam import SiamMask_ResNet50_sharp
from ltr.models.loss import select_softmax_with_cross_entropy_loss, weight_l1_loss, select_mask_logistic_loss
from ltr.trainers import LTRTrainer
from ltr.trainers.learning_rate_scheduler import LinearLrWarmup
import numpy as np
import cv2 as cv
from PIL import Image, ImageEnhance
def run(settings):
# Most common settings are assigned in the settings struct
settings.base_model = ''
settings.description = 'SiamMask_sharp with ResNet-50 backbone.'
settings.print_interval = 100 # How often to print loss and other info
settings.batch_size = 64 # Batch size
settings.samples_per_epoch = 600000 # Number of training pairs per epoch
settings.num_workers = 8 # Number of workers for image loading
settings.search_area_factor = {'train': 1.0, 'test': 143./127.}
settings.output_sz = {'train': 127, 'test': 143}
settings.scale_type = 'context'
settings.border_type = 'meanpad'
# Settings for the image sample and label generation
settings.center_jitter_factor = {'train': 0.2, 'test': 0.4}
settings.scale_jitter_factor = {'train': 0.05, 'test': 0.18}
settings.label_params = {
'search_size': 143,
'output_size': 3,
'anchor_stride': 8,
'anchor_ratios': [0.33, 0.5, 1, 2, 3],
'anchor_scales': [8],
'num_pos': 16,
'num_neg': 16,
'num_total': 64,
'thr_high': 0.6,
'thr_low': 0.3
}
settings.loss_weights = {'cls': 0., 'loc': 0., 'mask':1}
settings.neg = 0
# Train datasets
vos_train = YoutubeVOS()
coco_train = MSCOCOSeq()
# Validation datasets
vos_val = vos_train
# The joint augmentation transform, that is applied to the pairs jointly
transform_joint = dltransforms.ToGrayscale(probability=0.25)
# The augmentation transform applied to the training set (individually to each image in the pair)
transform_exemplar = dltransforms.Transpose()
transform_instance = dltransforms.Compose(
[
dltransforms.Color(probability=1.0),
dltransforms.Blur(probability=0.18),
dltransforms.Transpose()
])
transform_instance_mask = dltransforms.Transpose()
# Data processing to do on the training pairs
data_processing_train = processing.SiamProcessing(
search_area_factor=settings.search_area_factor,
output_sz=settings.output_sz,
center_jitter_factor=settings.center_jitter_factor,
scale_jitter_factor=settings.scale_jitter_factor,
scale_type=settings.scale_type,
border_type=settings.border_type,
mode='sequence',
label_params=settings.label_params,
train_transform=transform_exemplar,
test_transform=transform_instance,
test_mask_transform=transform_instance_mask,
joint_transform=transform_joint)
# Data processing to do on the validation pairs
data_processing_val = processing.SiamProcessing(
search_area_factor=settings.search_area_factor,
output_sz=settings.output_sz,
center_jitter_factor=settings.center_jitter_factor,
scale_jitter_factor=settings.scale_jitter_factor,
scale_type=settings.scale_type,
border_type=settings.border_type,
mode='sequence',
label_params=settings.label_params,
transform=transform_exemplar,
joint_transform=transform_joint)
nums_per_epoch = settings.samples_per_epoch // settings.batch_size
# The sampler for training
dataset_train = sampler.MaskSampler(
[coco_train, vos_train],
[1 ,1],
samples_per_epoch=nums_per_epoch * settings.batch_size,
max_gap=100,
processing=data_processing_train,
neg=settings.neg)
# The loader for training
train_loader = loader.LTRLoader(
'train',
dataset_train,
training=True,
batch_size=settings.batch_size,
num_workers=settings.num_workers,
stack_dim=0)
# The sampler for validation
dataset_val = sampler.MaskSampler(
[vos_val],
[1, ],
samples_per_epoch=100 * settings.batch_size,
max_gap=100,
processing=data_processing_val)
# The loader for validation
val_loader = loader.LTRLoader(
'val',
dataset_val,
training=False,
batch_size=settings.batch_size,
num_workers=settings.num_workers,
stack_dim=0)
# Create network, set objective, create optimizer, learning rate scheduler, trainer
with dygraph.guard():
# Create network
def scale_loss(loss):
total_loss = 0
for k in settings.loss_weights:
total_loss += loss[k] * settings.loss_weights[k]
return total_loss
net = SiamMask_ResNet50_sharp(scale_loss=scale_loss)
# Load parameters from the best_base_model
if settings.base_model == '':
raise Exception(
'The base_model path is not setup. Check settings.base_model in "ltr/train_settings/siammask/siammask_res50_sharp.py".'
)
para_dict, _ = fluid.load_dygraph(settings.base_model)
model_dict = net.state_dict()
for key in model_dict.keys():
if key in para_dict.keys():
model_dict[key] = para_dict[key]
net.set_dict(model_dict)
# Define objective
objective = {
'cls': select_softmax_with_cross_entropy_loss,
'loc': weight_l1_loss,
'mask': select_mask_logistic_loss
}
# Create actor, which wraps network and objective
actor = actors.SiamActor(net=net, objective=objective)
# Set to training mode
actor.train()
# Define optimizer and learning rate
decayed_lr = fluid.layers.exponential_decay(
learning_rate=0.0005,
decay_steps=nums_per_epoch,
decay_rate=0.9,
staircase=True)
lr_scheduler = LinearLrWarmup(
learning_rate=decayed_lr,
warmup_steps=5*nums_per_epoch,
start_lr=0.0001,
end_lr=0.0005)
optimizer = fluid.optimizer.Adam(
parameter_list=net.mask_head.parameters()
+ net.refine_head.parameters(),
learning_rate=lr_scheduler)
trainer = LTRTrainer(actor, [train_loader, val_loader], optimizer, settings, lr_scheduler)
trainer.train(20, load_latest=False, fail_safe=False)
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
import ltr.actors as actors
import ltr.data.transforms as dltransforms
from ltr.data import processing, sampler, loader
from ltr.dataset import ImagenetVID, ImagenetDET, MSCOCOSeq, YoutubeVOS, Lasot, Got10k
from ltr.models.siam.siam import SiamRPN_AlexNet
from ltr.models.loss import select_softmax_with_cross_entropy_loss, weight_l1_loss
from ltr.trainers import LTRTrainer
from ltr.trainers.learning_rate_scheduler import LinearLrWarmup
import numpy as np
import cv2 as cv
from PIL import Image, ImageEnhance
def run(settings):
# Most common settings are assigned in the settings struct
settings.description = 'SiamRPN with AlexNet backbone.'
settings.print_interval = 100 # How often to print loss and other info
settings.batch_size = 512 # Batch size
settings.samples_per_epoch = 600000 # Number of training pairs per epoch
settings.num_workers = 8 # Number of workers for image loading
settings.search_area_factor = {'train': 1.0, 'test': 2.0}
settings.output_sz = {'train': 127, 'test': 255}
settings.scale_type = 'context'
settings.border_type = 'meanpad'
# Settings for the image sample and label generation
settings.center_jitter_factor = {'train': 0.125, 'test': 2.0}
settings.scale_jitter_factor = {'train': 0.05, 'test': 0.18}
settings.label_params = {
'search_size': 255,
'output_size': 17,
'anchor_stride': 8,
'anchor_ratios': [0.33, 0.5, 1, 2, 3],
'anchor_scales': [8],
'num_pos': 16,
'num_neg': 16,
'num_total': 64,
'thr_high': 0.6,
'thr_low': 0.3
}
settings.loss_weights = {'cls': 1., 'loc': 1.2}
settings.neg = 0.2
# Train datasets
vos_train = YoutubeVOS()
vid_train = ImagenetVID()
coco_train = MSCOCOSeq()
det_train = ImagenetDET()
#lasot_train = Lasot(split='train')
#got10k_train = Got10k(split='train')
# Validation datasets
vid_val = ImagenetVID()
# The joint augmentation transform, that is applied to the pairs jointly
transform_joint = dltransforms.ToGrayscale(probability=0.25)
# The augmentation transform applied to the training set (individually to each image in the pair)
transform_exemplar = dltransforms.Transpose()
transform_instance = dltransforms.Compose(
[
dltransforms.Color(probability=1.0),
dltransforms.Blur(probability=0.18),
dltransforms.Transpose()
])
transform_instance_mask = dltransforms.Transpose()
# Data processing to do on the training pairs
data_processing_train = processing.SiamProcessing(
search_area_factor=settings.search_area_factor,
output_sz=settings.output_sz,
center_jitter_factor=settings.center_jitter_factor,
scale_jitter_factor=settings.scale_jitter_factor,
scale_type=settings.scale_type,
border_type=settings.border_type,
mode='sequence',
label_params=settings.label_params,
train_transform=transform_exemplar,
test_transform=transform_instance,
test_mask_transform=transform_instance_mask,
joint_transform=transform_joint)
# Data processing to do on the validation pairs
data_processing_val = processing.SiamProcessing(
search_area_factor=settings.search_area_factor,
output_sz=settings.output_sz,
center_jitter_factor=settings.center_jitter_factor,
scale_jitter_factor=settings.scale_jitter_factor,
scale_type=settings.scale_type,
border_type=settings.border_type,
mode='sequence',
label_params=settings.label_params,
transform=transform_exemplar,
joint_transform=transform_joint)
nums_per_epoch = settings.samples_per_epoch // settings.batch_size
# The sampler for training
dataset_train = sampler.MaskSampler(
[vid_train, coco_train, det_train, vos_train],
[2, 1, 1, 2],
samples_per_epoch=nums_per_epoch * settings.batch_size,
max_gap=100,
processing=data_processing_train,
neg=settings.neg)
# The loader for training
train_loader = loader.LTRLoader(
'train',
dataset_train,
training=True,
batch_size=settings.batch_size,
num_workers=settings.num_workers,
stack_dim=0)
# The sampler for validation
dataset_val = sampler.MaskSampler(
[vid_val],
[1, ],
samples_per_epoch=100 * settings.batch_size,
max_gap=100,
processing=data_processing_val)
# The loader for validation
val_loader = loader.LTRLoader(
'val',
dataset_val,
training=False,
batch_size=settings.batch_size,
num_workers=settings.num_workers,
stack_dim=0)
# Create network, set objective, create optimizer, learning rate scheduler, trainer
with dygraph.guard():
# Create network
def scale_loss(loss):
total_loss = 0
for k in settings.loss_weights:
total_loss += loss[k] * settings.loss_weights[k]
return total_loss
net = SiamRPN_AlexNet(scale_loss=scale_loss)
# Define objective
objective = {
'cls': select_softmax_with_cross_entropy_loss,
'loc': weight_l1_loss,
}
# Create actor, which wraps network and objective
actor = actors.SiamActor(net=net, objective=objective)
# Define optimizer and learning rate
decayed_lr = fluid.layers.exponential_decay(
learning_rate=0.01,
decay_steps=nums_per_epoch,
decay_rate=0.9407,
staircase=True)
lr_scheduler = LinearLrWarmup(
learning_rate=decayed_lr,
warmup_steps=5*nums_per_epoch,
start_lr=0.005,
end_lr=0.01)
optimizer = fluid.optimizer.Adam(
parameter_list=net.rpn_head.parameters(),
learning_rate=lr_scheduler)
trainer = LTRTrainer(actor, [train_loader, val_loader], optimizer, settings, lr_scheduler)
trainer.train(50, load_latest=False, fail_safe=False)
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
import ltr.actors as actors
import ltr.data.transforms as dltransforms
from ltr.data import processing, sampler, loader
from ltr.dataset import ImagenetVID, ImagenetDET, MSCOCOSeq, YoutubeVOS
from ltr.models.siam.siam import SiamRPN_ResNet50
from ltr.models.loss import select_softmax_with_cross_entropy_loss, weight_l1_loss
from ltr.trainers import LTRTrainer
import numpy as np
import cv2 as cv
from PIL import Image, ImageEnhance
def run(settings):
    # Most common settings are assigned in the settings struct
    settings.description = 'SiamRPN with ResNet-50 backbone.'
    settings.print_interval = 100  # How often to print loss and other info
    settings.batch_size = 32  # Batch size
    settings.num_workers = 4  # Number of workers for image loading
    settings.search_area_factor = {'train': 1.0, 'test': 255./127.}
    settings.output_sz = {'train': 127, 'test': 255}
    settings.scale_type = 'context'
    settings.border_type = 'meanpad'

    # Settings for the image sample and label generation
    settings.center_jitter_factor = {'train': 0.1, 'test': 1.5}
    settings.scale_jitter_factor = {'train': 0.05, 'test': 0.18}
    settings.label_params = {
        'search_size': 255,
        'output_size': 25,
        'anchor_stride': 8,
        'anchor_ratios': [0.33, 0.5, 1, 2, 3],
        'anchor_scales': [8],
        'num_pos': 16,
        'num_neg': 16,
        'num_total': 64,
        'thr_high': 0.6,
        'thr_low': 0.3
    }
    settings.loss_weights = {'cls': 1., 'loc': 1.2}
    settings.neg = 0.2

    # Train datasets
    vos_train = YoutubeVOS()
    vid_train = ImagenetVID()
    coco_train = MSCOCOSeq()
    det_train = ImagenetDET()

    # Validation datasets
    #vid_val = ImagenetVID()
    vid_val = coco_train

    # The joint augmentation transform, which is applied to both images of a pair together
    transform_joint = dltransforms.ToGrayscale(probability=0.25)

    # The augmentation transforms applied individually to each image in the pair
    transform_exemplar = dltransforms.Transpose()
    transform_instance = dltransforms.Transpose()

    # Data processing to do on the training pairs
    data_processing_train = processing.SiamProcessing(
        search_area_factor=settings.search_area_factor,
        output_sz=settings.output_sz,
        center_jitter_factor=settings.center_jitter_factor,
        scale_jitter_factor=settings.scale_jitter_factor,
        scale_type=settings.scale_type,
        border_type=settings.border_type,
        mode='sequence',
        label_params=settings.label_params,
        train_transform=transform_exemplar,
        test_transform=transform_instance,
        joint_transform=transform_joint)

    # Data processing to do on the validation pairs
    data_processing_val = processing.SiamProcessing(
        search_area_factor=settings.search_area_factor,
        output_sz=settings.output_sz,
        center_jitter_factor=settings.center_jitter_factor,
        scale_jitter_factor=settings.scale_jitter_factor,
        scale_type=settings.scale_type,
        border_type=settings.border_type,
        mode='sequence',
        label_params=settings.label_params,
        transform=transform_exemplar,
        joint_transform=transform_joint)

    # The sampler for training
    dataset_train = sampler.MaskSampler(
        [vid_train, coco_train, det_train, vos_train],
        [2, 1, 1, 2],
        samples_per_epoch=5000 * settings.batch_size,
        max_gap=100,
        processing=data_processing_train,
        neg=settings.neg)

    # The loader for training
    train_loader = loader.LTRLoader(
        'train',
        dataset_train,
        training=True,
        batch_size=settings.batch_size,
        num_workers=settings.num_workers,
        stack_dim=0)

    # The sampler for validation
    dataset_val = sampler.MaskSampler(
        [vid_val],
        [1, ],
        samples_per_epoch=100 * settings.batch_size,
        max_gap=100,
        processing=data_processing_val)

    # The loader for validation
    val_loader = loader.LTRLoader(
        'val',
        dataset_val,
        training=False,
        batch_size=settings.batch_size,
        num_workers=settings.num_workers,
        stack_dim=0)

    # Create the network, objective, optimizer, learning rate scheduler and trainer
    with dygraph.guard():
        # Create network
        def scale_loss(loss):
            total_loss = 0
            for k in settings.loss_weights:
                total_loss += loss[k] * settings.loss_weights[k]
            return total_loss

        net = SiamRPN_ResNet50(scale_loss=scale_loss)

        # Define objective
        objective = {
            'cls': select_softmax_with_cross_entropy_loss,
            'loc': weight_l1_loss,
        }

        # Create actor, which wraps network and objective
        actor = actors.SiamActor(net=net, objective=objective)

        # Set to training mode
        actor.train()

        # Define optimizer and learning rate
        lr_scheduler = fluid.layers.exponential_decay(
            learning_rate=0.005,
            decay_steps=5000,
            decay_rate=0.9659,
            staircase=True)
        optimizer = fluid.optimizer.Adam(
            parameter_list=net.rpn_head.parameters() + net.neck.parameters(),
            learning_rate=lr_scheduler)

        trainer = LTRTrainer(actor, [train_loader, val_loader], optimizer, settings, lr_scheduler)
        trainer.train(50, load_latest=False, fail_safe=False)
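The `label_params` in the script above describe a fixed anchor grid: a 25 x 25 score map with five aspect ratios and a single scale. The sketch below counts the anchors and derives their sizes. It assumes the convention common to SiamRPN implementations (base area `stride * stride`, width `sqrt(area / ratio)`, truncation to integers); that convention and the helper name `anchor_shapes` are illustrative assumptions, not code from this commit.

```python
import math


def anchor_shapes(stride, ratios, scales):
    """Return (width, height) for every ratio/scale pair (hypothetical helper)."""
    area = stride * stride
    shapes = []
    for ratio in ratios:
        # Assumed convention: keep the area roughly constant across ratios.
        base_w = int(math.sqrt(area / ratio))
        base_h = int(base_w * ratio)
        for scale in scales:
            shapes.append((base_w * scale, base_h * scale))
    return shapes


# Values copied from settings.label_params in the training script.
label_params = {
    'output_size': 25,
    'anchor_stride': 8,
    'anchor_ratios': [0.33, 0.5, 1, 2, 3],
    'anchor_scales': [8],
}

shapes = anchor_shapes(label_params['anchor_stride'],
                       label_params['anchor_ratios'],
                       label_params['anchor_scales'])
anchors_per_cell = len(shapes)
num_anchors = label_params['output_size'] ** 2 * anchors_per_cell

print(shapes)
print(num_anchors)
```

Under these assumptions each score-map cell carries five anchors, so `num_pos`, `num_neg` and `num_total` select a small subset of a few thousand candidates per training pair.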
......@@ -123,7 +123,7 @@ class BaseTrainer:
                 self.settings.project_path,
                 net_type)))
         if checkpoint_list:
-            checkpoint_path = checkpoint_list[-1].split('.')[0]
+            checkpoint_path = os.path.splitext(checkpoint_list[-1])[0]
         else:
             print('No matching checkpoint file found')
             return
......@@ -144,13 +144,13 @@ class BaseTrainer:
         self.optimizer.set_dict(opt_params)
         # paddle load state
         state_path = '{}/{}/custom_state.pickle'.format(
             self._checkpoint_dir, self.settings.project_path)
         current_state = pickle.load(
-            open(os.path.join(state_path, 'custom_state.pickle'), 'rb'))
+            open(os.path.join(checkpoint_path, '_custom_state.pickle'), 'rb'))
         print("\nload checkpoint done !! Current states are as follows:")
-        for key, value in enumerate(current_state):
+        for key, value in current_state.items():
             print(key, value)
         self.epoch = current_state['epoch']
         self.stats = current_state['stats']
         return True
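The two `BaseTrainer` hunks above replace `split('.')[0]` with `os.path.splitext` and `enumerate(...)` with `.items()`. A minimal sketch of why both changes matter follows; the checkpoint path and the state dictionary are hypothetical values chosen for illustration.

```python
import os

# Hypothetical checkpoint path; any relative path starting with './' shows the issue.
checkpoint = './checkpoints/ltr/siamrpn/SiamRPN_ep0010.pdparams'

# Splitting on the first dot cuts at the leading './' and returns an empty string.
naive = checkpoint.split('.')[0]

# splitext only removes the final extension.
fixed = os.path.splitext(checkpoint)[0]

# Hypothetical saved state with the two keys the trainer reads back.
current_state = {'epoch': 10, 'stats': None}

# enumerate() over a dict yields (index, key) pairs, not (key, value) pairs.
enumerated = list(enumerate(current_state))
items = list(current_state.items())

print(repr(naive), fixed)
print(enumerated, items)
```

With `enumerate`, the printed "states" would have been positions and key names rather than the stored epoch and statistics.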
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import math
from paddle.fluid.dygraph.learning_rate_scheduler import LearningRateDecay
class LinearLrWarmup(LearningRateDecay):
    """
    This operator uses a linear learning rate warm-up strategy to adjust the learning rate before the normal learning rate scheduling starts.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_

    When global_step < warmup_steps, the learning rate is updated as:

    .. code-block:: text

        linear_step = end_lr - start_lr
        lr = start_lr + linear_step * (global_step / warmup_steps)

    where start_lr is the initial learning rate, and end_lr is the final learning rate;

    When global_step >= warmup_steps, the learning rate is updated as:

    .. code-block:: text

        lr = learning_rate

    where lr is the learning_rate after warm-up.

    Args:
        learning_rate (LearningRateDecay|int|float): Learning rate after warm-up. It can be a number or another
            LearningRateDecay scheduler.
        warmup_steps (int): Steps for warm up.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 1.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.

    Returns:
        Variable: Warm-up learning rate with the same data type as learning_rate.

    Examples:

    .. code-block:: python

        import paddle.fluid as fluid

        learning_rate = 0.1
        warmup_steps = 50
        start_lr = 1. / 3.
        end_lr = 0.1

        with fluid.dygraph.guard():
            lr_decay = fluid.dygraph.LinearLrWarmup( learning_rate, warmup_steps, start_lr, end_lr)
    """

    def __init__(self,
                 learning_rate,
                 warmup_steps,
                 start_lr,
                 end_lr,
                 begin=1,
                 step=1,
                 dtype='float32'):
        super(LinearLrWarmup, self).__init__(begin, step, dtype)
        type_check = isinstance(learning_rate, float) or isinstance(
            learning_rate, int) or isinstance(learning_rate, LearningRateDecay)
        if not type_check:
            # Report the offending type, not the value itself
            raise TypeError(
                "the type of learning_rate should be [int, float or LearningRateDecay], the current type is {}".
                format(type(learning_rate)))
        self.learning_rate = learning_rate
        self.warmup_steps = warmup_steps
        assert end_lr > start_lr, "end_lr {} must be greater than start_lr {}".format(
            end_lr, start_lr)
        # Learning rate increment per step during warm-up
        self.lr_ratio_before_warmup = (
            float(end_lr) - float(start_lr)) / float(warmup_steps)
        self.start_lr = start_lr

    def step(self):
        base_lr = self.learning_rate
        if isinstance(self.learning_rate, LearningRateDecay):
            # Advance the wrapped scheduler on every call, including during warm-up
            base_lr = base_lr()

        if self.step_num < self.warmup_steps:
            return self.start_lr + self.lr_ratio_before_warmup * self.step_num
        else:
            return base_lr
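To make the interaction between the warm-up and a wrapped decay schedule concrete, here is a plain-Python sketch that needs no Paddle install. It reuses the values from the training script that wraps `exponential_decay` in `LinearLrWarmup` (warm-up from 0.005 to 0.01 over five epochs, decay rate 0.9407 per epoch). The function names, `steps_per_epoch = 1000`, and the simplification that both step counters start at 0 are assumptions for illustration.

```python
def exponential_decay(step, base_lr, decay_steps, decay_rate, staircase=True):
    """Exponentially decayed learning rate at a given step."""
    exponent = step // decay_steps if staircase else step / decay_steps
    return base_lr * decay_rate ** exponent


def warmup_lr(step, warmup_steps, start_lr, end_lr, after_warmup_lr):
    """Linear ramp during warm-up, then hand over to the wrapped schedule."""
    if step < warmup_steps:
        return start_lr + (end_lr - start_lr) / warmup_steps * step
    return after_warmup_lr


steps_per_epoch = 1000  # assumed; the script derives nums_per_epoch from the data


def lr_at(step):
    # The wrapped schedule is evaluated on every step, so it keeps decaying
    # while the warm-up ramp is still active.
    decayed = exponential_decay(step, 0.01, steps_per_epoch, 0.9407)
    return warmup_lr(step, 5 * steps_per_epoch, 0.005, 0.01, decayed)


for s in (0, 2500, 4999, 5000, 10000):
    print(s, lr_at(s))
```

One consequence visible in this sketch: because the decay has already advanced five epochs by the end of the warm-up, the learning rate drops below `end_lr` at the hand-over instead of continuing from 0.01.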
......@@ -162,6 +162,8 @@ class LTRTrainer(BaseTrainer):
print_str += '%s: %.5f , ' % (name, val.avg)
print_str += '%s: %.5f , ' % ("time", batch_size / batch_fps *
self.settings.print_interval)
if loader.training:
print_str += '%s: %f , ' % ("lr", self.optimizer.current_step_lr())
print(print_str[:-5])
def _stats_new_epoch(self):
......
......@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import paddle.fluid
import argparse
import importlib
import os
......@@ -172,8 +173,12 @@ def run_one_sequence(video, params, tracker=None):
                 if isinstance(res, int):
                     outputs.append('{}'.format(res))
                 else:
-                    outputs.append('{},{},{},{}'.format(res[0], res[1], res[
-                        2], res[3]))
+                    if len(res) == 8:
+                        outputs.append('{},{},{},{},{},{},{},{}'.format(
+                            res[0], res[1], res[2], res[3], res[4], res[5], res[6], res[7]))
+                    else:
+                        outputs.append('{},{},{},{}'.format(
+                            res[0], res[1], res[2], res[3]))
             f.write('\n'.join(outputs))
     else:
         os.makedirs(save_dir, exist_ok=True)
......
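The hunk above writes either a 4-value box or an 8-value result per frame (presumably the rotated box corners produced by SiamMask). A compact alternative is sketched below; `format_result` is a hypothetical helper, and unlike the committed code it accepts any number of values instead of exactly four or eight.

```python
def format_result(res):
    """Format one tracking result as a line of comma-separated values."""
    if isinstance(res, int):
        # Integer results are written as-is, as in the original code.
        return '{}'.format(res)
    return ','.join('{}'.format(v) for v in res)


lines = [
    format_result(1),
    format_result([10, 20, 30, 40]),
    format_result([1.5, 2, 3, 4, 5, 6, 7, 8]),
]
print('\n'.join(lines))
```

Comparing lengths with `==` rather than `is` also avoids relying on CPython's caching of small integers.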
import numpy as np
import math
from paddle.fluid import layers
import cv2 as cv
from pytracking.features.preprocessing import numpy_to_paddle, paddle_to_numpy
from pytracking.libs.Fconv2d import FConv2D
from pytracking.libs.paddle_utils import PTensor, _padding, n2p
class Transform:
"""Base data augmentation transform class."""
def __init__(self, output_sz=None, shift=None):
self.output_sz = output_sz
self.shift = (0, 0) if shift is None else shift
def __call__(self, image):
raise NotImplementedError
def crop_to_output(self, image, shift=None):
if isinstance(image, PTensor):
imsz = image.shape[2:]
else:
imsz = image.shape[:2]
if self.output_sz is None:
pad_h = 0
pad_w = 0
else:
pad_h = (self.output_sz[0] - imsz[0]) / 2
pad_w = (self.output_sz[1] - imsz[1]) / 2
if shift is None:
shift = self.shift
pad_left = math.floor(pad_w) + shift[1]
pad_right = math.ceil(pad_w) - shift[1]
pad_top = math.floor(pad_h) + shift[0]
pad_bottom = math.ceil(pad_h) - shift[0]
if isinstance(image, PTensor):
return _padding(
image, (pad_left, pad_right, pad_top, pad_bottom),
mode='replicate')
else:
return _padding(
image, (0, 0, pad_left, pad_right, pad_top, pad_bottom),
mode='replicate')
class Identity(Transform):
"""Identity transformation."""
def __call__(self, image):
return self.crop_to_output(image)
class FlipHorizontal(Transform):
"""Flip along horizontal axis."""
def __call__(self, image):
if isinstance(image, PTensor):
return self.crop_to_output(layers.reverse(image, 3))
else:
return self.crop_to_output(np.fliplr(image))
class FlipVertical(Transform):
"""Flip along vertical axis."""
def __call__(self, image: PTensor):
if isinstance(image, PTensor):
return self.crop_to_output(layers.reverse(image, 2))
else:
return self.crop_to_output(np.flipud(image))
class Translation(Transform):
"""Translate."""
def __init__(self, translation, output_sz=None, shift=None):
super().__init__(output_sz, shift)
self.shift = (self.shift[0] + translation[0],
self.shift[1] + translation[1])
def __call__(self, image):
return self.crop_to_output(image)
class Scale(Transform):
"""Scale."""
def __init__(self, scale_factor, output_sz=None, shift=None):
super().__init__(output_sz, shift)
self.scale_factor = scale_factor
def __call__(self, image):
# Calculate new size. Ensure that it is even so that crop/pad becomes easier
h_orig, w_orig = image.shape[2:]
if h_orig != w_orig:
raise NotImplementedError
h_new = round(h_orig / self.scale_factor)
h_new += (h_new - h_orig) % 2
w_new = round(w_orig / self.scale_factor)
w_new += (w_new - w_orig) % 2
if isinstance(image, PTensor):
image_resized = layers.resize_bilinear(
image, [h_new, w_new], align_corners=False)
else:
image_resized = cv.resize(
image, (w_new, h_new), interpolation=cv.INTER_LINEAR)
return self.crop_to_output(image_resized)
class Affine(Transform):
"""Affine transformation."""
def __init__(self, transform_matrix, output_sz=None, shift=None):
super().__init__(output_sz, shift)
self.transform_matrix = transform_matrix
def __call__(self, image, crop=True):
if isinstance(image, PTensor):
return self.crop_to_output(
numpy_to_paddle(self(
paddle_to_numpy(image), crop=False)))
else:
warp = cv.warpAffine(
image,
self.transform_matrix,
image.shape[1::-1],
borderMode=cv.BORDER_REPLICATE)
if crop:
return self.crop_to_output(warp)
else:
return warp
class Rotate(Transform):
"""Rotate with given angle."""
def __init__(self, angle, output_sz=None, shift=None):
super().__init__(output_sz, shift)
self.angle = math.pi * angle / 180
def __call__(self, image, crop=True):
if isinstance(image, PTensor):
return self.crop_to_output(
numpy_to_paddle(self(
paddle_to_numpy(image), crop=False)))
else:
c = (np.expand_dims(np.array(image.shape[:2]), 1) - 1) / 2
R = np.array([[math.cos(self.angle), math.sin(self.angle)],
[-math.sin(self.angle), math.cos(self.angle)]])
H = np.concatenate([R, c - R @c], 1)
warp = cv.warpAffine(
image, H, image.shape[1::-1], borderMode=cv.BORDER_REPLICATE)
if crop:
return self.crop_to_output(warp)
else:
return warp
class Blur(Transform):
"""Blur with given sigma (can be axis dependent)."""
def __init__(self, sigma, output_sz=None, shift=None):
super().__init__(output_sz, shift)
if isinstance(sigma, (float, int)):
sigma = (sigma, sigma)
self.sigma = sigma
self.filter_size = [math.ceil(2 * s) for s in self.sigma]
x_coord = [
np.arange(
-sz, sz + 1, 1, dtype='float32') for sz in self.filter_size
]
self.filter_np = [
np.exp(0 - (x * x) / (2 * s**2))
for x, s in zip(x_coord, self.sigma)
]
self.filter_np[0] = np.reshape(
self.filter_np[0], [1, 1, -1, 1]) / np.sum(self.filter_np[0])
self.filter_np[1] = np.reshape(
self.filter_np[1], [1, 1, 1, -1]) / np.sum(self.filter_np[1])
def __call__(self, image):
if isinstance(image, PTensor):
sz = image.shape[2:]
filter = [n2p(f) for f in self.filter_np]
im1 = FConv2D(
layers.reshape(image, [-1, 1, sz[0], sz[1]]),
filter[0],
padding=(self.filter_size[0], 0))
return self.crop_to_output(
layers.reshape(
FConv2D(
im1, filter[1], padding=(0, self.filter_size[1])),
[1, -1, sz[0], sz[1]]))
else:
return paddle_to_numpy(self(numpy_to_paddle(image)))
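The `Blur` transform above builds one normalized 1-D Gaussian per axis, with a radius of `ceil(2 * sigma)`, and applies the two filters one after the other. The sketch below reproduces only the kernel construction in plain Python; `gaussian_kernel_1d` is a hypothetical name, and the convolution itself is left out.

```python
import math


def gaussian_kernel_1d(sigma):
    """Normalized 1-D Gaussian with radius ceil(2 * sigma), as in Blur.__init__."""
    radius = math.ceil(2 * sigma)
    weights = [
        math.exp(-(x * x) / (2 * sigma ** 2))
        for x in range(-radius, radius + 1)
    ]
    total = sum(weights)
    return [w / total for w in weights]


kernel = gaussian_kernel_1d(1.5)
print(len(kernel), kernel)
```

Because the kernel has `2 * radius + 1` taps, padding each side by `radius` (the `filter_size` used in `Blur.__call__`) keeps the output the same size as the input.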
import numpy as np
import math
from paddle.fluid import layers
import cv2 as cv
from pytracking.features.preprocessing import numpy_to_paddle, paddle_to_numpy
from pytracking.libs.Fconv2d import FConv2D
from pytracking.libs.paddle_utils import PTensor, _padding, n2p
class Transform:
"""Base data augmentation transform class."""
def __init__(self, output_sz=None, shift=None):
self.output_sz = output_sz
self.shift = (0, 0) if shift is None else shift
def __call__(self, image):
raise NotImplementedError
def crop_to_output(self, image, shift=None):
if isinstance(image, PTensor):
imsz = image.shape[2:]
else:
imsz = image.shape[:2]
if self.output_sz is None:
pad_h = 0
pad_w = 0
else:
pad_h = (self.output_sz[0] - imsz[0]) / 2
pad_w = (self.output_sz[1] - imsz[1]) / 2
if shift is None:
shift = self.shift
pad_left = math.floor(pad_w) + shift[1]
pad_right = math.ceil(pad_w) - shift[1]
pad_top = math.floor(pad_h) + shift[0]
pad_bottom = math.ceil(pad_h) - shift[0]
if isinstance(image, PTensor):
return _padding(
image, (pad_left, pad_right, pad_top, pad_bottom),
mode='replicate')
else:
return _padding(
image, (0, 0, pad_left, pad_right, pad_top, pad_bottom),
mode='replicate')
class Identity(Transform):
"""Identity transformation."""
def __call__(self, image):
return self.crop_to_output(image)
class FlipHorizontal(Transform):
"""Flip along horizontal axis."""
def __call__(self, image):
if isinstance(image, PTensor):
return self.crop_to_output(layers.reverse(image, 3))
else:
return self.crop_to_output(np.fliplr(image))
class FlipVertical(Transform):
"""Flip along vertical axis."""
def __call__(self, image: PTensor):
if isinstance(image, PTensor):
return self.crop_to_output(layers.reverse(image, 2))
else:
return self.crop_to_output(np.flipud(image))
class Translation(Transform):
"""Translate."""
def __init__(self, translation, output_sz=None, shift=None):
super().__init__(output_sz, shift)
self.shift = (self.shift[0] + translation[0],
self.shift[1] + translation[1])
def __call__(self, image):
return self.crop_to_output(image)
class Scale(Transform):
"""Scale."""
def __init__(self, scale_factor, output_sz=None, shift=None):
super().__init__(output_sz, shift)
self.scale_factor = scale_factor
def __call__(self, image):
# Calculate new size. Ensure that it is even so that crop/pad becomes easier
h_orig, w_orig = image.shape[2:]
if h_orig != w_orig:
raise NotImplementedError
h_new = round(h_orig / self.scale_factor)
h_new += (h_new - h_orig) % 2
w_new = round(w_orig / self.scale_factor)
w_new += (w_new - w_orig) % 2
if isinstance(image, PTensor):
image_resized = layers.resize_bilinear(
image, [h_new, w_new], align_corners=False)
else:
image_resized = cv.resize(
image, (w_new, h_new), interpolation=cv.INTER_LINEAR)
return self.crop_to_output(image_resized)
class Affine(Transform):
"""Affine transformation."""
def __init__(self, transform_matrix, output_sz=None, shift=None):
super().__init__(output_sz, shift)
self.transform_matrix = transform_matrix
def __call__(self, image, crop=True):
if isinstance(image, PTensor):
return self.crop_to_output(
numpy_to_paddle(self(
paddle_to_numpy(image), crop=False)))
else:
warp = cv.warpAffine(
image,
self.transform_matrix,
image.shape[1::-1],
borderMode=cv.BORDER_REPLICATE)
if crop:
return self.crop_to_output(warp)
else:
return warp
class Rotate(Transform):
"""Rotate with given angle."""
def __init__(self, angle, output_sz=None, shift=None):
super().__init__(output_sz, shift)
self.angle = math.pi * angle / 180
def __call__(self, image, crop=True):
if isinstance(image, PTensor):
return self.crop_to_output(
numpy_to_paddle(self(
paddle_to_numpy(image), crop=False)))
else:
c = (np.expand_dims(np.array(image.shape[:2]), 1) - 1) / 2
R = np.array([[math.cos(self.angle), math.sin(self.angle)],
[-math.sin(self.angle), math.cos(self.angle)]])
H = np.concatenate([R, c - R @c], 1)
warp = cv.warpAffine(
image, H, image.shape[1::-1], borderMode=cv.BORDER_REPLICATE)
if crop:
return self.crop_to_output(warp)
else:
return warp
class Blur(Transform):
"""Blur with given sigma (can be axis dependent)."""
def __init__(self, sigma, output_sz=None, shift=None):
super().__init__(output_sz, shift)
if isinstance(sigma, (float, int)):
sigma = (sigma, sigma)
self.sigma = sigma
self.filter_size = [math.ceil(2 * s) for s in self.sigma]
x_coord = [
np.arange(
-sz, sz + 1, 1, dtype='float32') for sz in self.filter_size
]
self.filter_np = [
np.exp(0 - (x * x) / (2 * s**2))
for x, s in zip(x_coord, self.sigma)
]
self.filter_np[0] = np.reshape(
self.filter_np[0], [1, 1, -1, 1]) / np.sum(self.filter_np[0])
self.filter_np[1] = np.reshape(
self.filter_np[1], [1, 1, 1, -1]) / np.sum(self.filter_np[1])
def __call__(self, image):
if isinstance(image, PTensor):
sz = image.shape[2:]
filter = [n2p(f) for f in self.filter_np]
im1 = FConv2D(
layers.reshape(image, [-1, 1, sz[0], sz[1]]),
filter[0],
padding=(self.filter_size[0], 0))
return self.crop_to_output(
layers.reshape(
FConv2D(
im1, filter[1], padding=(0, self.filter_size[1])),
[1, -1, sz[0], sz[1]]))
else:
return paddle_to_numpy(self(numpy_to_paddle(image)))
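`Transform.crop_to_output` above reduces to a small piece of arithmetic: split the size difference between the two sides and move the split by `shift`. The sketch below isolates that computation; `crop_padding` is a hypothetical helper and the sizes are made-up examples.

```python
import math


def crop_padding(output_sz, image_sz, shift=(0, 0)):
    """Return (left, right, top, bottom) padding as computed in crop_to_output."""
    pad_h = (output_sz[0] - image_sz[0]) / 2
    pad_w = (output_sz[1] - image_sz[1]) / 2
    left = math.floor(pad_w) + shift[1]
    right = math.ceil(pad_w) - shift[1]
    top = math.floor(pad_h) + shift[0]
    bottom = math.ceil(pad_h) - shift[0]
    return left, right, top, bottom


# A 250 x 251 (height x width) image brought to 255 x 255 with a (2, -3) shift.
left, right, top, bottom = crop_padding((255, 255), (250, 251), shift=(2, -3))
print(left, right, top, bottom)
```

A negative value means that side has to be cropped rather than padded for the sizes to add up. The shift moves content without changing the final size, since it is added on one side and subtracted on the other.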
......@@ -5,6 +5,7 @@ from paddle import fluid
from ltr.models.bbreg.atom import atom_resnet50, atom_resnet18
from ltr.models.siamese.siam import siamfc_alexnet
from ltr.models.siam.siam import SiamRPN_AlexNet, SiamMask_ResNet50_sharp, SiamMask_ResNet50_base
from pytracking.admin.environment import env_settings
from pytracking.features.featurebase import MultiFeatureBase
from pytracking.libs import TensorList
......@@ -347,3 +348,147 @@ class SFCAlexnet(MultiFeatureBase):
output_features[layer].numpy() for layer in self.output_layers
])
return output
class SRPNAlexNet(MultiFeatureBase):
"""Alexnet feature.
args:
output_layers: List of layers to output.
net_path: Relative or absolute net path (default should be fine).
use_gpu: Use GPU or CPU.
"""
def __init__(self,
net_path='estimator',
use_gpu=True,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.use_gpu = use_gpu
self.net_path = net_path
def initialize(self):
with fluid.dygraph.guard():
if os.path.isabs(self.net_path):
net_path_full = self.net_path
else:
net_path_full = os.path.join(env_settings().network_path, self.net_path)
self.net = SiamRPN_AlexNet(backbone_pretrained=False, is_test=True)
state_dict, _ = fluid.load_dygraph(net_path_full)
self.net.load_dict(state_dict)
self.net.eval()
def free_memory(self):
if hasattr(self, 'net'):
del self.net
def extract(self, im: np.ndarray, debug_save_name=None):
with fluid.dygraph.guard():
if debug_save_name is not None:
np.savez(debug_save_name, im)
im = n2p(im)
output_features = self.net.extract_backbone_features(im)
# Store the raw backbone features which are input to estimator
output = TensorList([layer.numpy() for layer in output_features])
return output
class SMaskResNet50_base(MultiFeatureBase):
"""Resnet50-dilated feature.
args:
output_layers: List of layers to output.
net_path: Relative or absolute net path (default should be fine).
use_gpu: Use GPU or CPU.
"""
def __init__(self,
net_path='estimator',
use_gpu=True,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.use_gpu = use_gpu
self.net_path = net_path
def initialize(self):
with fluid.dygraph.guard():
if os.path.isabs(self.net_path):
net_path_full = self.net_path
else:
net_path_full = os.path.join(env_settings().network_path, self.net_path)
self.net = SiamMask_ResNet50_base(backbone_pretrained=False, is_test=True)
state_dict, _ = fluid.load_dygraph(net_path_full)
self.net.load_dict(state_dict)
self.net.eval()
def free_memory(self):
if hasattr(self, 'net'):
del self.net
def extract(self, im: np.ndarray, debug_save_name=None):
with fluid.dygraph.guard():
if debug_save_name is not None:
np.savez(debug_save_name, im)
im = n2p(im)
output_features = self.net.extract_backbone_features(im)
# Store the raw backbone features which are input to estimator
output = TensorList([layer.numpy() for layer in output_features])
return output
class SMaskResNet50_sharp(MultiFeatureBase):
"""Resnet50-dilated feature.
args:
output_layers: List of layers to output.
net_path: Relative or absolute net path (default should be fine).
use_gpu: Use GPU or CPU.
"""
def __init__(self,
net_path='estimator',
use_gpu=True,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.use_gpu = use_gpu
self.net_path = net_path
def initialize(self):
with fluid.dygraph.guard():
if os.path.isabs(self.net_path):
net_path_full = self.net_path
else:
net_path_full = os.path.join(env_settings().network_path, self.net_path)
self.net = SiamMask_ResNet50_sharp(backbone_pretrained=False, is_test=True)
state_dict, _ = fluid.load_dygraph(net_path_full)
self.net.load_dict(state_dict)
self.net.eval()
def free_memory(self):
if hasattr(self, 'net'):
del self.net
def extract(self, im: np.ndarray, debug_save_name=None):
with fluid.dygraph.guard():
if debug_save_name is not None:
np.savez(debug_save_name, im)
im = n2p(im)
output_features = self.net.extract_backbone_features(im)
# Store the raw backbone features which are input to estimator
output = TensorList([layer.numpy() for layer in output_features])
return output
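The three feature classes added above share the same `initialize` logic: an absolute `net_path` is used as-is, and a relative one is resolved against `env_settings().network_path`. A minimal sketch of that resolution step follows. `resolve_net_path` and the directory names are hypothetical, and POSIX-style paths are assumed.

```python
import os


def resolve_net_path(net_path, network_root):
    """Mirror the path handling in initialize(): absolute paths win."""
    if os.path.isabs(net_path):
        return net_path
    return os.path.join(network_root, net_path)


absolute = resolve_net_path('/models/siamrpn/SiamRPN_AlexNet', '/pretrained_models')
relative = resolve_net_path('estimator', '/pretrained_models')
print(absolute)
print(relative)
```

Factoring this into one shared helper would be one way to reduce the duplication between `SRPNAlexNet`, `SMaskResNet50_base` and `SMaskResNet50_sharp`.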
from __future__ import print_function
import numpy as np
from paddle.fluid.framework import Variable, in_dygraph_mode
from paddle.fluid import core, dygraph_utils
from paddle.fluid.layers import nn, utils
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
def _is_list_or_tuple(input):
return isinstance(input, (list, tuple))
def _zero_padding_in_batch_and_channel(padding, channel_last):
if channel_last:
return list(padding[0]) == [0, 0] and list(padding[-1]) == [0, 0]
else:
return list(padding[0]) == [0, 0] and list(padding[1]) == [0, 0]
def _exclude_padding_in_batch_and_channel(padding, channel_last):
padding_ = padding[1:-1] if channel_last else padding[2:]
padding_ = [elem for pad_a_dim in padding_ for elem in pad_a_dim]
return padding_
def _update_padding_nd(padding, channel_last, num_dims):
if isinstance(padding, str):
padding = padding.upper()
if padding not in ["SAME", "VALID"]:
raise ValueError(
"Unknown padding: '{}'. It can only be 'SAME' or 'VALID'.".
format(padding))
if padding == "VALID":
padding_algorithm = "VALID"
padding = [0] * num_dims
else:
padding_algorithm = "SAME"
padding = [0] * num_dims
elif _is_list_or_tuple(padding):
# for padding like
# [(pad_before, pad_after), (pad_before, pad_after), ...]
# padding for batch_dim and channel_dim included
if len(padding) == 2 + num_dims and _is_list_or_tuple(padding[0]):
if not _zero_padding_in_batch_and_channel(padding, channel_last):
raise ValueError(
"Non-zero padding({}) in the batch or channel dimensions "
"is not supported.".format(padding))
padding_algorithm = "EXPLICIT"
padding = _exclude_padding_in_batch_and_channel(padding,
channel_last)
if utils._is_symmetric_padding(padding, num_dims):
padding = padding[0::2]
# for padding like [pad_before, pad_after, pad_before, pad_after, ...]
elif len(padding) == 2 * num_dims and isinstance(padding[0], int):
padding_algorithm = "EXPLICIT"
padding = utils.convert_to_list(padding, 2 * num_dims, 'padding')
if utils._is_symmetric_padding(padding, num_dims):
padding = padding[0::2]
# for padding like [pad_d1, pad_d2, ...]
elif len(padding) == num_dims and isinstance(padding[0], int):
padding_algorithm = "EXPLICIT"
padding = utils.convert_to_list(padding, num_dims, 'padding')
else:
raise ValueError("Invalid padding: {}".format(padding))
# for integer padding
else:
padding_algorithm = "EXPLICIT"
padding = utils.convert_to_list(padding, num_dims, 'padding')
return padding, padding_algorithm
def FConv2D(input,
weight,
bias=None,
padding=0,
stride=1,
dilation=1,
groups=1,
use_cudnn=True,
act=None,
data_format="NCHW",
name=None):
# entry checks
if not isinstance(use_cudnn, bool):
raise ValueError("Attr(use_cudnn) should be True or False. "
"Received Attr(use_cudnn): {}.".format(use_cudnn))
if data_format not in ["NCHW", "NHWC"]:
raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'. "
"Received Attr(data_format): {}.".format(data_format))
channel_last = (data_format == "NHWC")
channel_dim = -1 if channel_last else 1
num_channels = input.shape[channel_dim]
num_filters = weight.shape[0]
if num_channels < 0:
raise ValueError("The channel dimension of the input({}) "
"should be defined. Received: {}.".format(
input.shape, num_channels))
if num_channels % groups != 0:
raise ValueError(
"the channel of input must be divisible by groups,"
"received: the channel of input is {}, the shape of input is {}"
", the groups is {}".format(num_channels, input.shape, groups))
if num_filters % groups != 0:
raise ValueError(
"the number of filters must be divisible by groups,"
"received: the number of filters is {}, the shape of weight is {}"
", the groups is {}".format(num_filters, weight.shape, groups))
# update attrs
padding, padding_algorithm = _update_padding_nd(padding, channel_last, 2)
stride = utils.convert_to_list(stride, 2, 'stride')
dilation = utils.convert_to_list(dilation, 2, 'dilation')
l_type = "conv2d"
if (num_channels == groups and num_filters % num_channels == 0 and
not use_cudnn):
l_type = 'depthwise_conv2d'
inputs = {'Input': [input], 'Filter': [weight]}
attrs = {
'strides': stride,
'paddings': padding,
'dilations': dilation,
'groups': groups,
'use_cudnn': use_cudnn,
'use_mkldnn': False,
'fuse_relu_before_depthwise_conv': False,
"padding_algorithm": padding_algorithm,
"data_format": data_format
}
if in_dygraph_mode():
attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation,
'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', False,
'fuse_relu_before_depthwise_conv', False, "padding_algorithm",
padding_algorithm, "data_format", data_format)
pre_bias = getattr(core.ops, l_type)(input, weight, *attrs)
if bias is not None:
pre_act = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
else:
pre_act = pre_bias
out = dygraph_utils._append_activation_in_dygraph(
pre_act, act, use_cudnn=use_cudnn)
else:
inputs = {'Input': [input], 'Filter': [weight]}
attrs = {
'strides': stride,
'paddings': padding,
'dilations': dilation,
'groups': groups,
'use_cudnn': use_cudnn,
'use_mkldnn': False,
'fuse_relu_before_depthwise_conv': False,
"padding_algorithm": padding_algorithm,
"data_format": data_format
}
check_variable_and_dtype(input, 'input',
['float16', 'float32', 'float64'], 'conv2d')
helper = LayerHelper(l_type, **locals())
dtype = helper.input_dtype()
pre_bias = helper.create_variable_for_type_inference(dtype)
outputs = {"Output": [pre_bias]}
helper.append_op(
type=l_type, inputs=inputs, outputs=outputs, attrs=attrs)
if bias is not None:
pre_act = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
else:
pre_act = pre_bias
out = helper.append_activation(pre_act)
return out
def test_conv2d_with_filter():
import paddle.fluid.dygraph as dygraph
import numpy as np
exemplar = np.random.random((8, 4, 6, 6)).astype(np.float32)
instance = np.random.random((8, 4, 22, 22)).astype(np.float32)
with dygraph.guard():
exem = dygraph.to_variable(exemplar)
inst = dygraph.to_variable(instance)
res = FConv2D(inst, exem, groups=1)
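`test_conv2d_with_filter` above runs `FConv2D` but does not check the result. The expected output shape follows from the standard convolution size formula, sketched here without Paddle; `conv2d_output_shape` is a hypothetical helper that assumes NCHW layout and symmetric integer padding.

```python
def conv2d_output_shape(input_shape, weight_shape, padding=0, stride=1,
                        dilation=1, groups=1):
    """Output shape of a 2-D convolution in NCHW layout."""
    n, c_in, h, w = input_shape
    c_out, c_in_per_group, k_h, k_w = weight_shape
    if c_in != c_in_per_group * groups:
        raise ValueError('input channels do not match weight channels * groups')

    def out_size(size, kernel):
        effective = dilation * (kernel - 1) + 1
        return (size + 2 * padding - effective) // stride + 1

    return (n, c_out, out_size(h, k_h), out_size(w, k_w))


# Shapes from test_conv2d_with_filter: instance as input, exemplar as filter.
shape = conv2d_output_shape((8, 4, 22, 22), (8, 4, 6, 6), groups=1)
print(shape)
```

With `groups=1` the eight exemplars act as eight ordinary filters over every instance, so each of the eight instances is correlated with all eight exemplars.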
from __future__ import print_function
import numpy as np
from paddle.fluid.framework import Variable, in_dygraph_mode
from paddle.fluid import core, dygraph_utils
from paddle.fluid.layers import nn, utils
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
def _is_list_or_tuple(input):
return isinstance(input, (list, tuple))
def _zero_padding_in_batch_and_channel(padding, channel_last):
if channel_last:
return list(padding[0]) == [0, 0] and list(padding[-1]) == [0, 0]
else:
return list(padding[0]) == [0, 0] and list(padding[1]) == [0, 0]
def _exclude_padding_in_batch_and_channel(padding, channel_last):
padding_ = padding[1:-1] if channel_last else padding[2:]
padding_ = [elem for pad_a_dim in padding_ for elem in pad_a_dim]
return padding_
def _update_padding_nd(padding, channel_last, num_dims):
if isinstance(padding, str):
padding = padding.upper()
if padding not in ["SAME", "VALID"]:
raise ValueError(
"Unknown padding: '{}'. It can only be 'SAME' or 'VALID'.".
format(padding))
if padding == "VALID":
padding_algorithm = "VALID"
padding = [0] * num_dims
else:
padding_algorithm = "SAME"
padding = [0] * num_dims
elif _is_list_or_tuple(padding):
# for padding like
# [(pad_before, pad_after), (pad_before, pad_after), ...]
# padding for batch_dim and channel_dim included
if len(padding) == 2 + num_dims and _is_list_or_tuple(padding[0]):
if not _zero_padding_in_batch_and_channel(padding, channel_last):
raise ValueError(
"Non-zero padding({}) in the batch or channel dimensions "
"is not supported.".format(padding))
padding_algorithm = "EXPLICIT"
padding = _exclude_padding_in_batch_and_channel(padding,
channel_last)
if utils._is_symmetric_padding(padding, num_dims):
padding = padding[0::2]
# for padding like [pad_before, pad_after, pad_before, pad_after, ...]
elif len(padding) == 2 * num_dims and isinstance(padding[0], int):
padding_algorithm = "EXPLICIT"
padding = utils.convert_to_list(padding, 2 * num_dims, 'padding')
if utils._is_symmetric_padding(padding, num_dims):
padding = padding[0::2]
# for padding like [pad_d1, pad_d2, ...]
elif len(padding) == num_dims and isinstance(padding[0], int):
padding_algorithm = "EXPLICIT"
padding = utils.convert_to_list(padding, num_dims, 'padding')
else:
raise ValueError("Invalid padding: {}".format(padding))
# for integer padding
else:
padding_algorithm = "EXPLICIT"
padding = utils.convert_to_list(padding, num_dims, 'padding')
return padding, padding_algorithm
def FConv2D(input,
weight,
bias=None,
padding=0,
stride=1,
dilation=1,
groups=1,
use_cudnn=True,
act=None,
data_format="NCHW",
name=None):
# entry checks
if not isinstance(use_cudnn, bool):
raise ValueError("Attr(use_cudnn) should be True or False. "
"Received Attr(use_cudnn): {}.".format(use_cudnn))
if data_format not in ["NCHW", "NHWC"]:
raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'. "
"Received Attr(data_format): {}.".format(data_format))
channel_last = (data_format == "NHWC")
channel_dim = -1 if channel_last else 1
num_channels = input.shape[channel_dim]
num_filters = weight.shape[0]
if num_channels < 0:
raise ValueError("The channel dimension of the input({}) "
"should be defined. Received: {}.".format(
input.shape, num_channels))
if num_channels % groups != 0:
raise ValueError(
"the channel of input must be divisible by groups,"
"received: the channel of input is {}, the shape of input is {}"
", the groups is {}".format(num_channels, input.shape, groups))
if num_filters % groups != 0:
raise ValueError(
"the number of filters must be divisible by groups,"
"received: the number of filters is {}, the shape of weight is {}"
", the groups is {}".format(num_filters, weight.shape, groups))
# update attrs
padding, padding_algorithm = _update_padding_nd(padding, channel_last, 2)
stride = utils.convert_to_list(stride, 2, 'stride')
dilation = utils.convert_to_list(dilation, 2, 'dilation')
l_type = "conv2d"
if (num_channels == groups and num_filters % num_channels == 0 and
not use_cudnn):
l_type = 'depthwise_conv2d'
inputs = {'Input': [input], 'Filter': [weight]}
attrs = {
'strides': stride,
'paddings': padding,
'dilations': dilation,
'groups': groups,
'use_cudnn': use_cudnn,
'use_mkldnn': False,
'fuse_relu_before_depthwise_conv': False,
"padding_algorithm": padding_algorithm,
"data_format": data_format
}
if in_dygraph_mode():
attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation,
'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', False,
'fuse_relu_before_depthwise_conv', False, "padding_algorithm",
padding_algorithm, "data_format", data_format)
pre_bias = getattr(core.ops, l_type)(input, weight, *attrs)
if bias is not None:
pre_act = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
else:
pre_act = pre_bias
out = dygraph_utils._append_activation_in_dygraph(
pre_act, act, use_cudnn=use_cudnn)
else:
inputs = {'Input': [input], 'Filter': [weight]}
attrs = {
'strides': stride,
'paddings': padding,
'dilations': dilation,
'groups': groups,
'use_cudnn': use_cudnn,
'use_mkldnn': False,
'fuse_relu_before_depthwise_conv': False,
"padding_algorithm": padding_algorithm,
"data_format": data_format
}
check_variable_and_dtype(input, 'input',
['float16', 'float32', 'float64'], 'conv2d')
helper = LayerHelper(l_type, **locals())
dtype = helper.input_dtype()
pre_bias = helper.create_variable_for_type_inference(dtype)
outputs = {"Output": [pre_bias]}
helper.append_op(
type=l_type, inputs=inputs, outputs=outputs, attrs=attrs)
if bias is not None:
pre_act = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
else:
pre_act = pre_bias
out = helper.append_activation(pre_act)
return out


def test_conv2d_with_filter():
    import paddle.fluid.dygraph as dygraph
    import numpy as np

    exemplar = np.random.random((8, 4, 6, 6)).astype(np.float32)
    instance = np.random.random((8, 4, 22, 22)).astype(np.float32)

    with dygraph.guard():
        exem = dygraph.to_variable(exemplar)
        inst = dygraph.to_variable(instance)
        res = FConv2D(inst, exem, groups=1)
        print(res.shape)
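The test above passes the exemplar tensor as the convolution filter, which is how Siamese trackers compute a cross-correlation response. Assuming the standard convolution output-size formula, the sketch below predicts the shape of the result for those inputs. `conv_out_size` is a hypothetical helper, not a Paddle API.

```python
def conv_out_size(in_size, kernel, padding=0, stride=1, dilation=1):
    """Spatial output size of a convolution (standard formula)."""
    effective_kernel = dilation * (kernel - 1) + 1
    return (in_size + 2 * padding - effective_kernel) // stride + 1


# shapes taken from test_conv2d_with_filter
instance_shape = (8, 4, 22, 22)  # N, C, H, W
exemplar_shape = (8, 4, 6, 6)    # as a filter: out_channels, in_channels, kH, kW

out_shape = (instance_shape[0], exemplar_shape[0],
             conv_out_size(instance_shape[2], exemplar_shape[2]),
             conv_out_size(instance_shape[3], exemplar_shape[3]))
print(out_shape)  # (8, 8, 17, 17)
```

With `groups=1`, each of the 8 exemplars acts as one output filter over all 4 input channels, so under these assumptions the response has 8 channels.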
from paddle import fluid
from paddle.fluid import layers
from pytracking.libs.Fconv2d import FConv2D
from pytracking.libs.tensorlist import tensor_operation, TensorList
from paddle.fluid.framework import Variable as PTensor


@tensor_operation
def conv2d(input: PTensor,
           weight: PTensor,
           bias: PTensor=None,
           stride=1,
           padding=0,
           dilation=1,
           groups=1,
           mode=None):
    """Standard conv2d. Returns the input if weight=None."""

    if weight is None:
        return input

    ind = None
    if mode is not None:
        if padding != 0:
            raise ValueError('Cannot input both padding and mode.')
        if mode == 'same':
            padding = (weight.shape[2] // 2, weight.shape[3] // 2)
            if weight.shape[2] % 2 == 0 or weight.shape[3] % 2 == 0:
                # an even kernel yields one extra row/column, trimmed below
                ind = (slice(-1) if weight.shape[2] % 2 == 0 else slice(None),
                       slice(-1) if weight.shape[3] % 2 == 0 else slice(None))
        elif mode == 'valid':
            padding = (0, 0)
        elif mode == 'full':
            padding = (weight.shape[2] - 1, weight.shape[3] - 1)
        else:
            raise ValueError('Unknown mode for padding.')

    # bias is not supported by this wrapper
    assert bias is None
    out = FConv2D(
        input,
        weight,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups)
    if ind is None:
        return out

    return out[:, :, ind[0], ind[1]]


@tensor_operation
def conv1x1(input: PTensor, weight: PTensor):
    """Do a convolution with a 1x1 kernel. Delegates to FConv2D; returns the input if weight=None."""

    if weight is None:
        return input

    return FConv2D(input, weight)
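The `mode` argument of `conv2d` above selects the padding and, for `'same'` with an even kernel, trims the last row or column. A minimal sketch of the resulting output length along one dimension, assuming stride 1 and dilation 1; `mode_padding` and `mode_out_size` are hypothetical names used only for this illustration.

```python
def mode_padding(kernel, mode):
    """Per-dimension padding that conv2d picks for a given mode."""
    if mode == 'same':
        return kernel // 2
    if mode == 'valid':
        return 0
    if mode == 'full':
        return kernel - 1
    raise ValueError('Unknown mode for padding.')


def mode_out_size(in_size, kernel, mode):
    """Output length along one dimension (stride 1, dilation 1), after trimming."""
    out = in_size + 2 * mode_padding(kernel, mode) - kernel + 1
    # 'same' with an even kernel produces one extra element,
    # which conv2d drops with slice(-1)
    if mode == 'same' and kernel % 2 == 0:
        out -= 1
    return out


for mode in ('same', 'valid', 'full'):
    print(mode, mode_out_size(22, 6, mode))
```

For an input of length 22 and a kernel of length 6 this gives 22 for `'same'`, 17 for `'valid'`, and 27 for `'full'`.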
from paddle import fluid
from paddle.fluid import layers
from pytracking.libs.Fconv2d import FConv2D
from pytracking.libs.tensorlist import tensor_operation, TensorList
from paddle.fluid.framework import Variable as PTensor


@tensor_operation
def conv2d(input: PTensor,
           weight: PTensor,
           bias: PTensor=None,
           stride=1,
           padding=0,
           dilation=1,
           groups=1,
           mode=None):
    """Standard conv2d. Returns the input if weight=None."""

    if weight is None:
        return input

    ind = None
    if mode is not None:
        if padding != 0:
            raise ValueError('Cannot input both padding and mode.')
        if mode == 'same':
            padding = (weight.shape[2] // 2, weight.shape[3] // 2)
            if weight.shape[2] % 2 == 0 or weight.shape[3] % 2 == 0:
                # an even kernel yields one extra row/column, trimmed below
                ind = (slice(-1) if weight.shape[2] % 2 == 0 else slice(None),
                       slice(-1) if weight.shape[3] % 2 == 0 else slice(None))
        elif mode == 'valid':
            padding = (0, 0)
        elif mode == 'full':
            padding = (weight.shape[2] - 1, weight.shape[3] - 1)
        else:
            raise ValueError('Unknown mode for padding.')

    # bias is not supported by this wrapper
    assert bias is None
    out = FConv2D(
        input,
        weight,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups)
    if ind is None:
        return out

    return out[:, :, ind[0], ind[1]]


@tensor_operation
def conv1x1(input: PTensor, weight: PTensor):
    """Do a convolution with a 1x1 kernel. Delegates to FConv2D; returns the input if weight=None."""

    if weight is None:
        return input

    return FConv2D(input, weight)
import numpy as np

from pytracking.features import deep
from pytracking.features.extractor import MultiResolutionExtractor
from pytracking.utils import TrackerParams, FeatureParams


def parameters():
    params = TrackerParams()

    # These are usually set from outside
    params.debug = 0  # Debug level
    params.visualization = False  # Do visualization

    # Use GPU or not (IoUNet requires this to be True)
    params.use_gpu = True

    # Feature specific parameters
    deep_params = TrackerParams()

    # Patch sampling parameters
    params.exemplar_size = 127
    params.instance_size = 255
    params.base_size = 8
    params.context_amount = 0.5

    # Anchor parameters
    params.anchor_stride = 8
    params.anchor_ratios = [0.33, 0.5, 1, 2, 3]
    params.anchor_scales = [8]

    # Tracking parameters
    params.penalty_k = 0.1
    params.window_influence = 0.41
    params.lr = 0.32
    params.mask_threshold = 0.15

    # Setup the feature extractor
    deep_fparams = FeatureParams(feature_params=[deep_params])
    deep_feat = deep.SMaskResNet50_base(fparams=deep_fparams)
    params.features = MultiResolutionExtractor([deep_feat])

    return params
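The tracking parameters `penalty_k`, `window_influence`, and `lr` are not explained in this file. The sketch below shows how parameters with these names are conventionally used in SiamRPN-style trackers: a penalty on scale and aspect-ratio change, a blend with a spatial window, and a linear interpolation of the box size. It is an assumption that the tracker in this commit follows the same convention; `rerank` and `smooth_size` are hypothetical helpers, and the per-anchor inputs are made-up values.

```python
import math


def rerank(scores, size_change, ratio_change, window, penalty_k,
           window_influence):
    """Return the index of the best candidate after penalty and window blending."""
    best_idx, best_val = -1, -float('inf')
    for i, score in enumerate(scores):
        # penalize candidates whose scale or aspect ratio changed a lot;
        # size_change and ratio_change are >= 1, with 1 meaning "no change"
        penalty = math.exp(-(size_change[i] * ratio_change[i] - 1.0) * penalty_k)
        pscore = penalty * score
        # blend with a spatial window that favours small displacements
        pscore = pscore * (1.0 - window_influence) + window[i] * window_influence
        if pscore > best_val:
            best_idx, best_val = i, pscore
    return best_idx


def smooth_size(previous, predicted, lr):
    """Linear interpolation between the previous and the predicted box size."""
    return previous * (1.0 - lr) + predicted * lr


# toy inputs: candidate 0 has the higher raw score but lies far from the centre
best = rerank(scores=[0.9, 0.8], size_change=[1.0, 1.0],
              ratio_change=[1.0, 1.0], window=[0.0, 1.0],
              penalty_k=0.1, window_influence=0.41)
print(best)
print(smooth_size(100.0, 200.0, lr=0.32))
```

In this toy case the window term outweighs the small difference in raw score, so the candidate nearer the centre is selected.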
import numpy as np

from pytracking.features import deep
from pytracking.features.extractor import MultiResolutionExtractor
from pytracking.utils import TrackerParams, FeatureParams


def parameters():
    params = TrackerParams()

    # These are usually set from outside
    params.debug = 0  # Debug level
    params.visualization = False  # Do visualization

    # Use GPU or not (IoUNet requires this to be True)
    params.use_gpu = True

    # Feature specific parameters
    deep_params = TrackerParams()

    # Patch sampling parameters
    params.exemplar_size = 127
    params.instance_size = 255
    params.base_size = 8
    params.context_amount = 0.5
    params.mask_output_size = 127

    # Anchor parameters
    params.anchor_stride = 8
    params.anchor_ratios = [0.33, 0.5, 1, 2, 3]
    params.anchor_scales = [8]

    # Tracking parameters
    params.penalty_k = 0.04
    params.window_influence = 0.42
    params.lr = 0.25
    params.mask_threshold = 0.30

    # output rect result
    params.polygon = False

    # Setup the feature extractor
    deep_fparams = FeatureParams(feature_params=[deep_params])
    deep_feat = deep.SMaskResNet50_sharp(fparams=deep_fparams)
    params.features = MultiResolutionExtractor([deep_feat])

    return params
import numpy as np

from pytracking.features import deep
from pytracking.features.extractor import MultiResolutionExtractor
from pytracking.utils import TrackerParams, FeatureParams


def parameters():
    params = TrackerParams()

    # These are usually set from outside
    params.debug = 0  # Debug level
    params.visualization = False  # Do visualization

    # Use GPU or not (IoUNet requires this to be True)
    params.use_gpu = True

    # Feature specific parameters
    deep_params = TrackerParams()

    # Patch sampling parameters
    params.exemplar_size = 127
    params.instance_size = 255
    params.base_size = 8
    params.context_amount = 0.5
    params.mask_output_size = 127

    # Anchor parameters
    params.anchor_stride = 8
    params.anchor_ratios = [0.33, 0.5, 1, 2, 3]
    params.anchor_scales = [8]

    # Tracking parameters
    params.penalty_k = 0.20
    params.window_influence = 0.41
    params.lr = 0.30
    params.mask_threshold = 0.30

    # output polygon result
    params.polygon = True

    # Setup the feature extractor
    deep_fparams = FeatureParams(feature_params=[deep_params])
    deep_feat = deep.SMaskResNet50_sharp(fparams=deep_fparams)
    params.features = MultiResolutionExtractor([deep_feat])

    return params
import numpy as np

from pytracking.features import deep
from pytracking.features.extractor import MultiResolutionExtractor
from pytracking.utils import TrackerParams, FeatureParams


def parameters():
    params = TrackerParams()

    # These are usually set from outside
    params.debug = 0  # Debug level
    params.visualization = False  # Do visualization

    # Use GPU or not (IoUNet requires this to be True)
    params.use_gpu = True

    # Feature specific parameters
    deep_params = TrackerParams()

    # Patch sampling parameters
    params.exemplar_size = 127
    params.instance_size = 287
    params.base_size = 0
    params.context_amount = 0.5

    # Anchor parameters
    params.anchor_stride = 8
    params.anchor_ratios = [0.33, 0.5, 1, 2, 3]
    params.anchor_scales = [8]

    # Tracking parameters
    params.penalty_k = 0.18
    params.window_influence = 0.41
    params.lr = 0.05

    # Setup the feature extractor
    deep_fparams = FeatureParams(feature_params=[deep_params])
    deep_feat = deep.SRPNAlexNet(fparams=deep_fparams)
    params.features = MultiResolutionExtractor([deep_feat])

    return params
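The patch-sampling and anchor parameters together determine the size of the response map and the anchor shapes. The sketch below follows the convention used by common SiamRPN implementations; whether the tracker in this commit computes these values the same way is an assumption, and `score_size` and `anchor_shapes` are hypothetical helper names.

```python
import math


def score_size(instance_size, exemplar_size, anchor_stride, base_size):
    """Side length of the response map (assumed convention)."""
    return (instance_size - exemplar_size) // anchor_stride + 1 + base_size


def anchor_shapes(stride, ratios, scales):
    """(width, height) of each anchor; one anchor per ratio/scale pair."""
    shapes = []
    area = stride * stride
    for ratio in ratios:
        # keep the area roughly constant while varying height/width = ratio
        base_w = int(math.sqrt(area / ratio))
        base_h = int(base_w * ratio)
        for scale in scales:
            shapes.append((base_w * scale, base_h * scale))
    return shapes


# values from the SiamMask and SiamRPN parameter files
print(score_size(255, 127, 8, 8))   # SiamMask settings
print(score_size(287, 127, 8, 0))   # SiamRPN AlexNet settings
print(anchor_shapes(8, [0.33, 0.5, 1, 2, 3], [8]))
```

Under this convention the SiamMask settings give a 25x25 response map and the SiamRPN AlexNet settings a 21x21 map, each with 5 anchors per location.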
from .siammask import SiamMask


def get_tracker_class():
    return SiamMask
from .siamrpn import SiamRPN


def get_tracker_class():
    return SiamRPN