Commit 409785dc authored by Kaiyu Yue, committed by qingqing01

Add Human Pose Estimation for FluidCV (#1474)

* add human pose estimation
* add comments for args of pose net
* use skip_opt_set for memory_optimize
* fix the typo
* add pretrained link
* fix typo
* clean the code
Parent d8544499
# Simple Baselines for Human Pose Estimation in Fluid
## Introduction
This is a simple demonstration of a re-implementation in [PaddlePaddle.Fluid](http://www.paddlepaddle.org/en) of the paper [Simple Baselines for Human Pose Estimation and Tracking](https://arxiv.org/abs/1804.06208) (ECCV'18) from MSRA.
![demo](demo.gif)
> **Video in Demo**: *Bruno Mars - That’s What I Like [Official Video]*.
## Requirements
- Python == 2.7
- PaddlePaddle >= 1.0
- opencv-python >= 3.3
- tqdm >= 4.25
## Environment
The code is developed and tested on 4 Tesla K40 GPU cards on CentOS, with CUDA 9.2/8.0 and cuDNN 7.1 installed.
## Known Issues
- The model does not converge with a large batch\_size (e.g. 32) on Tesla P40 / V100 / P100 GPU cards, because PaddlePaddle uses the batch normalization function of cuDNN. Reducing batch\_size to one image per card during training eases this problem, though the effect on accuracy has not been verified; see the example command below. The issue is tracked [here](https://github.com/PaddlePaddle/Paddle/issues/14580).
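For example, with 4 visible GPU cards, a total `--batch_size` of 4 puts one image on each card (the flag is defined in `train.py`):
```bash
python2 train.py --dataset 'mpii' --batch_size 4
```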
## Results on MPII Val
| Arch | Head | Shoulder | Elbow | Wrist | Hip | Knee | Ankle | Mean | Mean@0.1| Models |
| ---- |:----:|:--------:|:-----:|:-----:|:---:|:----:|:-----:|:----:|:-------:|:------:|
| 384x384\_pose\_resnet\_50 in PyTorch | 96.658 | 95.754 | 89.790 | 84.614 | 88.523 | 84.666 | 79.287 | 89.066 | 38.046 | - |
| 384x384\_pose\_resnet\_50 in Fluid | 96.248 | 95.346 | 89.807 | 84.873 | 88.298 | 83.679 | 78.649 | 88.767 | 37.374 | [`link`](http://paddlemodels.bj.bcebos.com/pose/pose-resnet-50-384x384-mpii.tar.gz) |
### Notes:
- Flip test is used.
- We did not search hard for the best model; the last saved model is simply used for validation.
## Getting Started
### Prepare Datasets and Pretrained Models
- Follow the [instructions](https://github.com/Microsoft/human-pose-estimation.pytorch#data-preparation) to prepare the datasets.
- Download the ResNet-50 model pretrained on ImageNet in PaddlePaddle.Fluid from the [Model Zoo](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification#supported-models-and-performances).
```bash
wget http://paddle-imagenet-models.bj.bcebos.com/resnet_50_model.tar
```
Then extract it into the folder `pretrained` under the root directory of this repo, so that the layout looks like this:
```
${THIS REPO ROOT}
`-- pretrained
`-- resnet_50
|-- 115
`-- data
`-- coco
|-- annotations
|-- images
`-- mpii
|-- annot
|-- images
```
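For example, a minimal sketch of the extraction steps, assuming the tarball unpacks into the `resnet_50/115` layout shown above:
```bash
mkdir -p pretrained
tar -xf resnet_50_model.tar -C pretrained
```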
### Install [COCOAPI](https://github.com/cocodataset/cocoapi)
```bash
# COCOAPI=/path/to/clone/cocoapi
git clone https://github.com/cocodataset/cocoapi.git $COCOAPI
cd $COCOAPI/PythonAPI
# if cython is not installed
pip install Cython
# Install into global site-packages
make install
# Alternatively, if you do not have permissions or prefer
# not to install the COCO API into global site-packages
python2 setup.py install --user
```
### Perform Validation
Download the checkpoint of Pose-ResNet-50 trained on the MPII dataset from [here](http://paddlemodels.bj.bcebos.com/pose/pose-resnet-50-384x384-mpii.tar.gz). Extract it into the folder `checkpoints` under the root directory of this repo. Then run
```bash
python2 val.py --dataset 'mpii' --checkpoint 'checkpoints/pose-resnet-50-384x384-mpii'
```
### Perform Training
```bash
python2 train.py --dataset 'mpii' # or coco
```
**Note**: Training configurations are aggregated in `lib/mpii_reader.py` and `lib/coco_reader.py`.
### Perform Testing on Images
Put the images into the folder `test` under the root directory of this repo. Then run
```bash
python2 test.py --checkpoint 'checkpoints/pose-resnet-50-384x384-mpii'
```
If there are multiple persons in an image, a detector such as [Faster R-CNN](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/faster_rcnn), [SSD](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/object_detection) or others should be used first to crop each person out, because this simple baseline for human pose estimation is a top-down method.
## Reference
- Simple Baselines for Human Pose Estimation and Tracking in PyTorch [`code`](https://github.com/Microsoft/human-pose-estimation.pytorch#data-preparation)
## License
This code is released under the Apache License 2.0.
# Copyright (c) 2018-present, Baidu, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Libs for data reader."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
import cv2
import numpy as np
def visualize(cfg, filename, data_numpy, input, joints, target):
"""
:param cfg: global configurations for dataset
:param filename: the name of image file
:param data_numpy: original numpy image data
:param input: input tensor [b, c, h, w]
:param joints: [num_joints, 3]
:param target: target tensor [b, c, h, w]
"""
TMPDIR = cfg.TMPDIR
NUM_JOINTS = cfg.NUM_JOINTS
if os.path.exists(TMPDIR):
shutil.rmtree(TMPDIR)
os.mkdir(TMPDIR)
else:
os.mkdir(TMPDIR)
f = open(os.path.join(TMPDIR, filename), 'w')
f.close()
cv2.imwrite(os.path.join(TMPDIR, 'flip.jpg'), data_numpy)
cv2.imwrite(os.path.join(TMPDIR, 'input.jpg'), input)
for i in range(NUM_JOINTS):
cv2.imwrite(os.path.join(TMPDIR, 'target_{}.jpg'.format(i)), cv2.applyColorMap(
np.uint8(np.expand_dims(target[i], 2)*255.), cv2.COLORMAP_JET))
cv2.circle(input, (int(joints[i, 0]), int(joints[i, 1])), 5, [170, 255, 0], -1)
cv2.imwrite(os.path.join(TMPDIR, 'input_kps.jpg'), input)
def generate_target(cfg, joints, joints_vis):
"""
:param joints: [num_joints, 3]
:param joints_vis: [num_joints, 3]
:return: target, target_weight(1: visible, 0: invisible)
"""
NUM_JOINTS = cfg.NUM_JOINTS
TARGET_TYPE = cfg.TARGET_TYPE
HEATMAP_SIZE = cfg.HEATMAP_SIZE
IMAGE_SIZE = cfg.IMAGE_SIZE
SIGMA = cfg.SIGMA
target_weight = np.ones((NUM_JOINTS, 1), dtype=np.float32)
target_weight[:, 0] = joints_vis[:, 0]
assert TARGET_TYPE == 'gaussian', \
'Only support gaussian map now!'
if TARGET_TYPE == 'gaussian':
target = np.zeros((NUM_JOINTS,
HEATMAP_SIZE[1],
HEATMAP_SIZE[0]),
dtype=np.float32)
tmp_size = SIGMA * 3
for joint_id in range(NUM_JOINTS):
feat_stride = np.array(IMAGE_SIZE) / np.array(HEATMAP_SIZE)
mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
# Check that any part of the gaussian is in-bounds
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
if ul[0] >= HEATMAP_SIZE[0] or ul[1] >= HEATMAP_SIZE[1] \
or br[0] < 0 or br[1] < 0:
# If not, just return the image as is
target_weight[joint_id] = 0
continue
# Generate gaussian
size = 2 * tmp_size + 1
x = np.arange(0, size, 1, np.float32)
y = x[:, np.newaxis]
x0 = y0 = size // 2
# The gaussian is not normalized, we want the center value to equal 1
g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * SIGMA ** 2))
# Usable gaussian range
g_x = max(0, -ul[0]), min(br[0], HEATMAP_SIZE[0]) - ul[0]
g_y = max(0, -ul[1]), min(br[1], HEATMAP_SIZE[1]) - ul[1]
# Image range
img_x = max(0, ul[0]), min(br[0], HEATMAP_SIZE[0])
img_y = max(0, ul[1]), min(br[1], HEATMAP_SIZE[1])
v = target_weight[joint_id]
if v > 0.5:
target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
return target, target_weight
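# A minimal usage sketch of generate_target() (editor's addition). The
# _DemoCfg object below is a hypothetical stand-in: any object exposing the
# same attributes as the reader Config classes works.
if __name__ == '__main__':
    class _DemoCfg(object):
        NUM_JOINTS = 2
        TARGET_TYPE = 'gaussian'
        HEATMAP_SIZE = [96, 96]  # (width, height)
        IMAGE_SIZE = [384, 384]
        SIGMA = 3

    joints = np.array([[192., 192., 0.], [10., 20., 0.]])
    joints_vis = np.ones((2, 3), dtype=np.float32)
    target, target_weight = generate_target(_DemoCfg(), joints, joints_vis)
    print(target.shape)           # (2, 96, 96): one gaussian heatmap per joint
    print(target_weight.ravel())  # [1. 1.]: both joints visible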
# Copyright (c) 2018-present, Baidu, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Data reader for COCO dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import functools
import numpy as np
import cv2
import random
from utils.transforms import fliplr_joints
from utils.transforms import get_affine_transform
from utils.transforms import affine_transform
from lib.base_reader import visualize, generate_target
from pycocotools.coco import COCO
# NOTE
# -- COCO Dataset --
# "keypoints":
# {
# 0: "nose",
# 1: "left_eye",
# 2: "right_eye",
# 3: "left_ear",
# 4: "right_ear",
# 5: "left_shoulder",
# 6: "right_shoulder",
# 7: "left_elbow",
# 8: "right_elbow",
# 9: "left_wrist",
# 10: "right_wrist",
# 11: "left_hip",
# 12: "right_hip",
# 13: "left_knee",
# 14: "right_knee",
# 15: "left_ankle",
# 16: "right_ankle"
# },
#
# "skeleton":
# [
# [16,14],[14,12],[17,15],[15,13],[12,13],[6,12],[7,13], [6,7],[6,8],
# [7,9],[8,10],[9,11],[2,3],[1,2],[1,3],[2,4],[3,5],[4,6],[5,7]
# ]
class Config:
"""Configurations for COCO dataset.
"""
DEBUG = False
TMPDIR = 'tmp_fold_for_debug'
# For reader
BUF_SIZE = 102400
    THREAD = 1 if DEBUG else 8  # must be larger than 0
# Fixed infos of dataset
DATAROOT = 'data/coco'
IMAGEDIR = 'images'
NUM_JOINTS = 17
FLIP_PAIRS = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
PARENT_IDS = None
# CFGS
SCALE_FACTOR = 0.3
ROT_FACTOR = 40
FLIP = True
TARGET_TYPE = 'gaussian'
SIGMA = 3
IMAGE_SIZE = [288, 384]
HEATMAP_SIZE = [72, 96]
ASPECT_RATIO = IMAGE_SIZE[0] * 1.0 / IMAGE_SIZE[1]
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
PIXEL_STD = 200
cfg = Config()
def _box2cs(box):
x, y, w, h = box[:4]
return _xywh2cs(x, y, w, h)
def _xywh2cs(x, y, w, h):
center = np.zeros((2), dtype=np.float32)
center[0] = x + w * 0.5
center[1] = y + h * 0.5
if w > cfg.ASPECT_RATIO * h:
h = w * 1.0 / cfg.ASPECT_RATIO
elif w < cfg.ASPECT_RATIO * h:
w = h * cfg.ASPECT_RATIO
scale = np.array(
[w * 1.0 / cfg.PIXEL_STD, h * 1.0 / cfg.PIXEL_STD],
dtype=np.float32)
if center[0] != -1:
scale = scale * 1.25
return center, scale
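# Worked example (editor's note): for the box [x, y, w, h] = [0, 0, 100, 200],
# center = (50, 100); since w = 100 < ASPECT_RATIO * h = 0.75 * 200 = 150, the
# width is padded to 150, so scale = [150/200, 200/200] * 1.25 = [0.9375, 1.25].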
def _select_data(db):
db_selected = []
for rec in db:
num_vis = 0
joints_x = 0.0
joints_y = 0.0
for joint, joint_vis in zip(
rec['joints_3d'], rec['joints_3d_vis']):
if joint_vis[0] <= 0:
continue
num_vis += 1
joints_x += joint[0]
joints_y += joint[1]
if num_vis == 0:
continue
joints_x, joints_y = joints_x / num_vis, joints_y / num_vis
area = rec['scale'][0] * rec['scale'][1] * (cfg.PIXEL_STD**2)
joints_center = np.array([joints_x, joints_y])
bbox_center = np.array(rec['center'])
diff_norm2 = np.linalg.norm((joints_center-bbox_center), 2)
ks = np.exp(-1.0*(diff_norm2**2) / ((0.2)**2*2.0*area))
metric = (0.2 / 16) * num_vis + 0.45 - 0.2 / 16
if ks > metric:
db_selected.append(rec)
print('=> num db: {}'.format(len(db)))
print('=> num selected db: {}'.format(len(db_selected)))
return db_selected
def _load_coco_keypoint_annotation(image_set_index, coco, _coco_ind_to_class_ind, image_set):
"""Ground truth bbox and keypoints.
"""
print('generating coco gt_db...')
gt_db = []
for index in image_set_index:
im_ann = coco.loadImgs(index)[0]
width = im_ann['width']
height = im_ann['height']
annIds = coco.getAnnIds(imgIds=index, iscrowd=False)
objs = coco.loadAnns(annIds)
# Sanitize bboxes
valid_objs = []
for obj in objs:
x, y, w, h = obj['bbox']
x1 = np.max((0, x))
y1 = np.max((0, y))
x2 = np.min((width - 1, x1 + np.max((0, w - 1))))
y2 = np.min((height - 1, y1 + np.max((0, h - 1))))
if obj['area'] > 0 and x2 >= x1 and y2 >= y1:
obj['clean_bbox'] = [x1, y1, x2-x1, y2-y1]
valid_objs.append(obj)
objs = valid_objs
rec = []
for obj in objs:
cls = _coco_ind_to_class_ind[obj['category_id']]
if cls != 1:
continue
# Ignore objs without keypoints annotation
if max(obj['keypoints']) == 0:
continue
joints_3d = np.zeros((cfg.NUM_JOINTS, 3), dtype=np.float)
joints_3d_vis = np.zeros((cfg.NUM_JOINTS, 3), dtype=np.float)
for ipt in range(cfg.NUM_JOINTS):
joints_3d[ipt, 0] = obj['keypoints'][ipt * 3 + 0]
joints_3d[ipt, 1] = obj['keypoints'][ipt * 3 + 1]
joints_3d[ipt, 2] = 0
t_vis = obj['keypoints'][ipt * 3 + 2]
if t_vis > 1:
t_vis = 1
joints_3d_vis[ipt, 0] = t_vis
joints_3d_vis[ipt, 1] = t_vis
joints_3d_vis[ipt, 2] = 0
center, scale = _box2cs(obj['clean_bbox'][:4])
rec.append({
'image': os.path.join(cfg.DATAROOT, cfg.IMAGEDIR, image_set+'2017', '%012d.jpg' % index),
'center': center,
'scale': scale,
'joints_3d': joints_3d,
'joints_3d_vis': joints_3d_vis,
'filename': '%012d.jpg' % index,
'imgnum': 0,
})
gt_db.extend(rec)
return gt_db
def data_augmentation(sample, is_train):
image_file = sample['image']
filename = sample['filename'] if 'filename' in sample else ''
joints = sample['joints_3d']
joints_vis = sample['joints_3d_vis']
c = sample['center']
s = sample['scale']
# score = sample['score'] if 'score' in sample else 1
# imgnum = sample['imgnum'] if 'imgnum' in sample else ''
r = 0
data_numpy = cv2.imread(
image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
if is_train:
sf = cfg.SCALE_FACTOR
rf = cfg.ROT_FACTOR
s = s * np.clip(np.random.randn()*sf + 1, 1 - sf, 1 + sf)
r = np.clip(np.random.randn()*rf, -rf*2, rf*2) \
if random.random() <= 0.6 else 0
if cfg.FLIP and random.random() <= 0.5:
data_numpy = data_numpy[:, ::-1, :]
joints, joints_vis = fliplr_joints(
joints, joints_vis, data_numpy.shape[1], cfg.FLIP_PAIRS)
c[0] = data_numpy.shape[1] - c[0] - 1
trans = get_affine_transform(c, s, r, cfg.IMAGE_SIZE)
input = cv2.warpAffine(
data_numpy,
trans,
(int(cfg.IMAGE_SIZE[0]), int(cfg.IMAGE_SIZE[1])),
flags=cv2.INTER_LINEAR)
for i in range(cfg.NUM_JOINTS):
if joints_vis[i, 0] > 0.0:
joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
# Numpy target
target, target_weight = generate_target(cfg, joints, joints_vis)
if cfg.DEBUG:
visualize(cfg, filename, data_numpy, input.copy(), joints, target)
# Normalization
input = input.astype('float32').transpose((2, 0, 1)) / 255
input -= np.array(cfg.MEAN).reshape((3, 1, 1))
input /= np.array(cfg.STD).reshape((3, 1, 1))
if is_train:
return input, target, target_weight
else:
return input, target, target_weight, c, s
# Create a reader
def _reader_creator(root, image_set, shuffle=False, is_train=False, use_gt_bbox=False):
def reader():
if image_set in ['train', 'val']:
file_name = os.path.join(root, 'annotations', 'person_keypoints_'+image_set+'2017.json')
elif image_set in ['test', 'test-dev']:
file_name = os.path.join(root, 'annotations', 'image_info_'+image_set+'2017.json')
else:
raise ValueError("The dataset '{}' is not supported".format(image_set))
# Load annotations
coco = COCO(file_name)
# Deal with class names
cats = [cat['name']
for cat in coco.loadCats(coco.getCatIds())]
classes = ['__background__'] + cats
print('=> classes: {}'.format(classes))
num_classes = len(classes)
_class_to_ind = dict(zip(classes, range(num_classes)))
_class_to_coco_ind = dict(zip(cats, coco.getCatIds()))
_coco_ind_to_class_ind = dict([(_class_to_coco_ind[cls],
_class_to_ind[cls])
for cls in classes[1:]])
# Load image file names
image_set_index = coco.getImgIds()
num_images = len(image_set_index)
print('=> num_images: {}'.format(num_images))
if is_train or use_gt_bbox:
gt_db = _load_coco_keypoint_annotation(
image_set_index, coco, _coco_ind_to_class_ind, image_set)
gt_db = _select_data(gt_db)
if shuffle:
random.shuffle(gt_db)
for db in gt_db:
yield db
mapper = functools.partial(data_augmentation, is_train=is_train)
return reader, mapper
def train():
reader, mapper = _reader_creator(cfg.DATAROOT, 'train', shuffle=True, is_train=True)
def pop():
for i, x in enumerate(reader()):
yield mapper(x)
return pop
def valid():
reader, mapper = _reader_creator(cfg.DATAROOT, 'val', shuffle=False, is_train=False)
def pop():
for i, x in enumerate(reader()):
yield mapper(x)
return pop
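# A hedged usage sketch (editor's addition): draw one preprocessed training
# sample. Assumes the COCO annotations are prepared under data/coco as
# described in the README.
if __name__ == '__main__':
    pop = train()
    input, target, target_weight = next(pop())
    print(input.shape)   # (3, 384, 288): CHW, since IMAGE_SIZE is (width, height)
    print(target.shape)  # (17, 96, 72): one heatmap per COCO keypoint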
# Copyright (c) 2018-present, Baidu, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Data reader for MPII."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import functools
import json
import numpy as np
import cv2
from utils.transforms import fliplr_joints
from utils.transforms import get_affine_transform
from utils.transforms import affine_transform
from lib.base_reader import visualize, generate_target
class Config:
"""Configurations for MPII dataset.
"""
DEBUG = False
TMPDIR = 'tmp_fold_for_debug'
# For reader
BUF_SIZE = 102400
    THREAD = 1 if DEBUG else 8  # must be larger than 0
# Fixed infos of dataset
DATAROOT = 'data/mpii'
IMAGEDIR = 'images'
NUM_JOINTS = 16
FLIP_PAIRS = [[0, 5], [1, 4], [2, 3], [10, 15], [11, 14], [12, 13]]
PARENT_IDS = [1, 2, 6, 6, 3, 4, 6, 6, 7, 8, 11, 12, 7, 7, 13, 14]
# CFGS
SCALE_FACTOR = 0.3
ROT_FACTOR = 40
FLIP = True
TARGET_TYPE = 'gaussian'
SIGMA = 3
IMAGE_SIZE = [384, 384]
HEATMAP_SIZE = [96, 96]
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
cfg = Config()
def data_augmentation(sample, is_train):
image_file = sample['image']
filename = sample['filename'] if 'filename' in sample else ''
joints = sample['joints_3d']
joints_vis = sample['joints_3d_vis']
c = sample['center']
s = sample['scale']
score = sample['score'] if 'score' in sample else 1
# imgnum = sample['imgnum'] if 'imgnum' in sample else ''
r = 0
data_numpy = cv2.imread(
image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
if is_train:
sf = cfg.SCALE_FACTOR
rf = cfg.ROT_FACTOR
s = s * np.clip(np.random.randn()*sf + 1, 1 - sf, 1 + sf)
r = np.clip(np.random.randn()*rf, -rf*2, rf*2) \
if random.random() <= 0.6 else 0
if cfg.FLIP and random.random() <= 0.5:
data_numpy = data_numpy[:, ::-1, :]
joints, joints_vis = fliplr_joints(
joints, joints_vis, data_numpy.shape[1], cfg.FLIP_PAIRS)
c[0] = data_numpy.shape[1] - c[0] - 1
trans = get_affine_transform(c, s, r, cfg.IMAGE_SIZE)
input = cv2.warpAffine(
data_numpy,
trans,
(int(cfg.IMAGE_SIZE[0]), int(cfg.IMAGE_SIZE[1])),
flags=cv2.INTER_LINEAR)
for i in range(cfg.NUM_JOINTS):
if joints_vis[i, 0] > 0.0:
joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
# Numpy target
target, target_weight = generate_target(cfg, joints, joints_vis)
if cfg.DEBUG:
visualize(cfg, filename, data_numpy, input.copy(), joints, target)
# Normalization
input = input.astype('float32').transpose((2, 0, 1)) / 255
input -= np.array(cfg.MEAN).reshape((3, 1, 1))
input /= np.array(cfg.STD).reshape((3, 1, 1))
if is_train:
return input, target, target_weight
else:
return input, target, target_weight, c, s, score
def test_data_augmentation(sample):
image_file = sample['image']
filename = sample['filename'] if 'filename' in sample else ''
file_id = int(filename.split('.')[0].split('_')[1])
input = cv2.imread(
image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
input = cv2.resize(input, (int(cfg.IMAGE_SIZE[0]), int(cfg.IMAGE_SIZE[1])))
# Normalization
input = input.astype('float32').transpose((2, 0, 1)) / 255
input -= np.array(cfg.MEAN).reshape((3, 1, 1))
input /= np.array(cfg.STD).reshape((3, 1, 1))
return input, file_id
# Create a reader
def _reader_creator(root, image_set, shuffle=False, is_train=False):
def reader():
if image_set != 'test':
file_name = os.path.join(root, 'annot', image_set+'.json')
with open(file_name) as anno_file:
anno = json.load(anno_file)
print('=> load {} samples of {} dataset'.format(len(anno), image_set))
if shuffle:
random.shuffle(anno)
for a in anno:
image_name = a['image']
c = np.array(a['center'], dtype=np.float)
s = np.array([a['scale'], a['scale']], dtype=np.float)
# Adjust center/scale slightly to avoid cropping limbs
if c[0] != -1:
c[1] = c[1] + 15 * s[1]
s = s * 1.25
                # MPII uses MATLAB format: indices are 1-based,
                # so convert them to 0-based indices first
c = c - 1
joints_3d = np.zeros((cfg.NUM_JOINTS, 3), dtype=np.float)
joints_3d_vis = np.zeros((cfg.NUM_JOINTS, 3), dtype=np.float)
joints = np.array(a['joints'])
joints[:, 0:2] = joints[:, 0:2] - 1
joints_vis = np.array(a['joints_vis'])
assert len(joints) == cfg.NUM_JOINTS, \
'joint num diff: {} vs {}'.format(len(joints), cfg.NUM_JOINTS)
joints_3d[:, 0:2] = joints[:, 0:2]
joints_3d_vis[:, 0] = joints_vis[:]
joints_3d_vis[:, 1] = joints_vis[:]
yield dict(
image = os.path.join(cfg.DATAROOT, cfg.IMAGEDIR, image_name),
center = c,
scale = s,
joints_3d = joints_3d,
joints_3d_vis = joints_3d_vis,
filename = image_name,
test_mode = False,
imagenum = 0)
else:
fold = 'test'
for img_name in os.listdir(fold):
yield dict(image = os.path.join(fold, img_name),
filename = img_name)
if not image_set == 'test':
mapper = functools.partial(data_augmentation, is_train=is_train)
else:
mapper = functools.partial(test_data_augmentation)
return reader, mapper
def train():
reader, mapper = _reader_creator(cfg.DATAROOT, 'train', shuffle=True, is_train=True)
def pop():
for i, x in enumerate(reader()):
yield mapper(x)
return pop
def valid():
reader, mapper = _reader_creator(cfg.DATAROOT, 'valid', shuffle=False, is_train=False)
def pop():
for i, x in enumerate(reader()):
yield mapper(x)
return pop
def test():
reader, mapper = _reader_creator(cfg.DATAROOT, 'test')
def pop():
for i, x in enumerate(reader()):
yield mapper(x)
return pop
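# A hedged usage sketch (editor's addition): wrap the sample reader into
# mini-batches, mirroring the usage in train.py. Assumes the MPII annotations
# are prepared under data/mpii.
if __name__ == '__main__':
    import paddle
    train_reader = paddle.batch(train(), batch_size=4)
    batch = next(train_reader())
    input, target, target_weight = batch[0]  # first sample of the batch
    print(input.shape)  # (3, 384, 384)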
# Copyright (c) 2018-present, Baidu, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Functions for building network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
__all__ = ["ResNet", "ResNet50", "ResNet101", "ResNet152"]
# Global parameters
BN_MOMENTUM = 0.1
class ResNet():
def __init__(self, layers=50, kps_num=16, test_mode=False):
"""
        :param layers: int, the number of layers in the ResNet backbone
        :param kps_num: int, the number of keypoints, matching the dataset
        :param test_mode: bool, if True, net() returns only the output
            heatmaps; otherwise it returns (loss, output heatmaps)
"""
self.k = kps_num
self.layers = layers
self.test_mode = test_mode
def net(self, input, target=None, target_weight=None):
layers = self.layers
supported_layers = [50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_filters = [64, 128, 256, 512]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=7, stride=2, act='relu')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
for block in range(len(depth)):
for i in range(depth[block]):
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1)
conv = fluid.layers.conv2d_transpose(
input=conv, num_filters=256,
filter_size=4,
padding=1,
stride=2,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Normal(0., 0.001)),
act=None,
bias_attr=False)
conv = fluid.layers.batch_norm(input=conv, act='relu', momentum=BN_MOMENTUM)
conv = fluid.layers.conv2d_transpose(
input=conv, num_filters=256,
filter_size=4,
padding=1,
stride=2,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Normal(0., 0.001)),
act=None,
bias_attr=False)
conv = fluid.layers.batch_norm(input=conv, act='relu', momentum=BN_MOMENTUM)
conv = fluid.layers.conv2d_transpose(
input=conv, num_filters=256,
filter_size=4,
padding=1,
stride=2,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Normal(0., 0.001)),
act=None,
bias_attr=False)
conv = fluid.layers.batch_norm(input=conv, act='relu', momentum=BN_MOMENTUM)
out = fluid.layers.conv2d(
input=conv,
num_filters=self.k,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Normal(0., 0.001)))
if self.test_mode:
return out
else:
loss = self.calc_loss(out, target, target_weight)
return loss, out
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Normal(0., 0.001)),
act=None,
bias_attr=False)
return fluid.layers.batch_norm(input=conv, act=act, momentum=BN_MOMENTUM)
def shortcut(self, input, ch_out, stride):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
return self.conv_bn_layer(input, ch_out, 1, stride)
else:
return input
def calc_loss(self, heatmap, target, target_weight):
_, c, h, w = heatmap.shape
x = fluid.layers.reshape(heatmap, (-1, self.k, h*w))
y = fluid.layers.reshape(target, (-1, self.k, h*w))
w = fluid.layers.reshape(target_weight, (-1, self.k))
x = fluid.layers.split(x, num_or_sections=self.k, dim=1)
y = fluid.layers.split(y, num_or_sections=self.k, dim=1)
w = fluid.layers.split(w, num_or_sections=self.k, dim=1)
_list = []
for idx in range(self.k):
_tmp = fluid.layers.scale(x=x[idx] - y[idx], scale=1.)
_tmp = _tmp * _tmp
_tmp = fluid.layers.reduce_mean(_tmp, dim=2)
_list.append(_tmp * w[idx])
_loss = fluid.layers.concat(_list, axis=0)
_loss = fluid.layers.reduce_mean(_loss)
return 0.5 * _loss
def bottleneck_block(self, input, num_filters, stride):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu')
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 4, filter_size=1, act=None)
short = self.shortcut(input, num_filters * 4, stride)
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def ResNet50():
model = ResNet(layers=50)
return model
def ResNet101():
model = ResNet(layers=101)
return model
def ResNet152():
model = ResNet(layers=152)
return model
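# A hedged construction sketch (editor's addition), mirroring train.py: build
# the 16-keypoint MPII variant on a 384x384 input with 96x96 target heatmaps.
if __name__ == '__main__':
    import paddle.fluid.layers as layers
    image = layers.data(name='image', shape=[3, 384, 384], dtype='float32')
    target = layers.data(name='target', shape=[16, 96, 96], dtype='float32')
    target_weight = layers.data(name='target_weight', shape=[16, 1], dtype='float32')
    model = ResNet(layers=50, kps_num=16)
    loss, output = model.net(input=image, target=target, target_weight=target_weight)
    print(output.shape)  # heatmap tensor: (-1, 16, 96, 96)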
# Copyright (c) 2018-present, Baidu, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Functions for inference."""
import os
import argparse
import functools
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from tqdm import tqdm
from lib import pose_resnet
from utils.transforms import flip_back
from utils.utility import *
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('dataset', str, 'mpii', "Dataset")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('num_epochs', int, 140, "Number of epochs.")
add_arg('total_images', int, 144406, "Training image number.")
add_arg('kp_dim', int, 16, "Class number.")
add_arg('model_save_dir', str, "output", "Model save directory")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('checkpoint', str, None, "Whether to resume checkpoint.")
add_arg('lr', float, 0.001, "Set learning rate.")
add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.")
add_arg('flip_test', bool, True, "Flip test")
add_arg('shift_heatmap', bool, True, "Shift heatmap")
add_arg('post_process', bool, False, "post process")
# yapf: enable
FLIP_PAIRS = [[0, 5], [1, 4], [2, 3], [10, 15], [11, 14], [12, 13]]
def test(args):
if args.dataset == 'coco':
import lib.coco_reader as reader
IMAGE_SIZE = [288, 384]
# HEATMAP_SIZE = [72, 96]
FLIP_PAIRS = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
args.kp_dim = 17
args.total_images = 144406 # 149813
elif args.dataset == 'mpii':
import lib.mpii_reader as reader
IMAGE_SIZE = [384, 384]
# HEATMAP_SIZE = [96, 96]
FLIP_PAIRS = [[0, 5], [1, 4], [2, 3], [10, 15], [11, 14], [12, 13]]
args.kp_dim = 16
args.total_images = 2958 # validation
else:
raise ValueError('The dataset {} is not supported yet.'.format(args.dataset))
print_arguments(args)
# Image and target
image = layers.data(name='image', shape=[3, IMAGE_SIZE[1], IMAGE_SIZE[0]], dtype='float32')
file_id = layers.data(name='file_id', shape=[1,], dtype='int')
# Build model
model = pose_resnet.ResNet(layers=50, kps_num=args.kp_dim, test_mode=True)
# Output
output = model.net(input=image, target=None, target_weight=None)
# Parameters from model and arguments
params = {}
params["total_images"] = args.total_images
params["lr"] = args.lr
params["num_epochs"] = args.num_epochs
params["learning_strategy"] = {}
params["learning_strategy"]["batch_size"] = args.batch_size
params["learning_strategy"]["name"] = args.lr_strategy
if args.with_mem_opt:
fluid.memory_optimize(fluid.default_main_program(),
skip_opt_set=[output.name])
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
args.pretrained_model = './pretrained/resnet_50/115'
if args.pretrained_model:
def if_exist(var):
exist_flag = os.path.exists(os.path.join(args.pretrained_model, var.name))
return exist_flag
fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
if args.checkpoint is not None:
fluid.io.load_persistables(exe, args.checkpoint)
# Dataloader
test_reader = paddle.batch(reader.test(), batch_size=args.batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, file_id])
test_exe = fluid.ParallelExecutor(
use_cuda=True if args.use_gpu else False,
main_program=fluid.default_main_program().clone(for_test=False),
loss_name=None)
fetch_list = [image.name, output.name]
for batch_id, data in tqdm(enumerate(test_reader())):
num_images = len(data)
file_ids = []
for i in range(num_images):
file_ids.append(data[i][1])
input_image, out_heatmaps = test_exe.run(
fetch_list=fetch_list,
feed=feeder.feed(data))
if args.flip_test:
            # Flip all the images in the same batch
data_fliped = []
for i in range(num_images):
data_fliped.append((
data[i][0][:, :, ::-1],
data[i][1]))
# Inference again
_, output_flipped = test_exe.run(
fetch_list=fetch_list,
feed=feeder.feed(data_fliped))
# Flip back
output_flipped = flip_back(output_flipped, FLIP_PAIRS)
# Feature is not aligned, shift flipped heatmap for higher accuracy
if args.shift_heatmap:
output_flipped[:, :, :, 1:] = \
output_flipped.copy()[:, :, :, 0:-1]
# Aggregate
out_heatmaps = (out_heatmaps + output_flipped) * 0.5
save_predict_results(input_image, out_heatmaps, file_ids, fold_name='results')
if __name__ == '__main__':
args = parser.parse_args()
test(args)
# Copyright (c) 2018-present, Baidu, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Functions for training."""
import os
import numpy as np
import cv2
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import argparse
import functools
from lib import pose_resnet
from utils.utility import *
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('dataset', str, 'mpii', "Dataset")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('num_epochs', int, 140, "Number of epochs.")
add_arg('total_images', int, 144406, "Training image number.")
add_arg('kp_dim', int, 16, "Class number.")
add_arg('model_save_dir', str, "output", "Model save directory")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('checkpoint', str, None, "Whether to resume checkpoint.")
add_arg('lr', float, 0.001, "Set learning rate.")
add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.")
# yapf: enable
def optimizer_setting(args, params):
lr_drop_ratio = 0.1
ls = params["learning_strategy"]
if ls["name"] == "piecewise_decay":
total_images = params["total_images"]
batch_size = ls["batch_size"]
step = int(total_images / batch_size + 1)
ls['epochs'] = [91, 121]
print('=> LR will be dropped at the epoch of {}'.format(ls['epochs']))
bd = [step * e for e in ls["epochs"]]
base_lr = params["lr"]
lr = []
lr = [base_lr * (lr_drop_ratio**i) for i in range(len(bd) + 1)]
# AdamOptimizer
optimizer = paddle.fluid.optimizer.AdamOptimizer(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr))
else:
lr = params["lr"]
optimizer = fluid.optimizer.Momentum(
learning_rate=lr,
momentum=0.9,
regularization=fluid.regularizer.L2Decay(0.0005))
return optimizer
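# Worked example (editor's note): for MPII, total_images = 22246 and
# batch_size = 32 give step = int(22246 / 32 + 1) = 696 iterations per epoch,
# so bd = [696 * 91, 696 * 121] = [63336, 84216] and the piecewise LR values
# are [0.001, 0.0001, 0.00001].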
def train(args):
if args.dataset == 'coco':
import lib.coco_reader as reader
IMAGE_SIZE = [288, 384]
HEATMAP_SIZE = [72, 96]
args.kp_dim = 17
args.total_images = 144406 # 149813
elif args.dataset == 'mpii':
import lib.mpii_reader as reader
IMAGE_SIZE = [384, 384]
HEATMAP_SIZE = [96, 96]
args.kp_dim = 16
args.total_images = 22246
else:
raise ValueError('The dataset {} is not supported yet.'.format(args.dataset))
print_arguments(args)
# Image and target
image = layers.data(name='image', shape=[3, IMAGE_SIZE[1], IMAGE_SIZE[0]], dtype='float32')
target = layers.data(name='target', shape=[args.kp_dim, HEATMAP_SIZE[1], HEATMAP_SIZE[0]], dtype='float32')
target_weight = layers.data(name='target_weight', shape=[args.kp_dim, 1], dtype='float32')
# Build model
model = pose_resnet.ResNet(layers=50, kps_num=args.kp_dim)
# Output
loss, output = model.net(input=image, target=target, target_weight=target_weight)
# Parameters from model and arguments
params = {}
params["total_images"] = args.total_images
params["lr"] = args.lr
params["num_epochs"] = args.num_epochs
params["learning_strategy"] = {}
params["learning_strategy"]["batch_size"] = args.batch_size
params["learning_strategy"]["name"] = args.lr_strategy
# Initialize optimizer
optimizer = optimizer_setting(args, params)
optimizer.minimize(loss)
if args.with_mem_opt:
fluid.memory_optimize(fluid.default_main_program(),
skip_opt_set=[loss.name, output.name, target.name])
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
args.pretrained_model = './pretrained/resnet_50/115'
if args.pretrained_model:
def if_exist(var):
exist_flag = os.path.exists(os.path.join(args.pretrained_model, var.name))
return exist_flag
fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
if args.checkpoint is not None:
fluid.io.load_persistables(exe, args.checkpoint)
# Dataloader
train_reader = paddle.batch(reader.train(), batch_size=args.batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, target, target_weight])
train_exe = fluid.ParallelExecutor(
use_cuda=True if args.use_gpu else False, loss_name=loss.name)
fetch_list = [image.name, loss.name, output.name]
for pass_id in range(params["num_epochs"]):
for batch_id, data in enumerate(train_reader()):
current_lr = np.array(paddle.fluid.global_scope().find_var('learning_rate').get_tensor())
input_image, loss, out_heatmaps = train_exe.run(
fetch_list, feed=feeder.feed(data))
loss = np.mean(np.array(loss))
print('Epoch [{:4d}/{:3d}] LR: {:.10f} '
'Loss = {:.5f}'.format(
batch_id, pass_id, current_lr[0], loss))
if batch_id % 10 == 0:
save_batch_heatmaps(input_image, out_heatmaps, file_name='visualization@train.jpg', normalize=True)
model_path = os.path.join(args.model_save_dir + '/' + 'simplebase-{}'.format(args.dataset),
str(pass_id))
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path)
if __name__ == '__main__':
args = parser.parse_args()
train(args)
# Copyright (c) 2018-present, Baidu, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
#
# Based on
# ------------------------------------------------------------------------------
# https://github.com/Microsoft/human-pose-estimation.pytorch
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# ------------------------------------------------------------------------------
"""Transforms functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import cv2
def flip_back(output_flipped, matched_parts):
"""
    :param output_flipped: numpy.ndarray(batch_size, num_joints, height, width)
"""
assert output_flipped.ndim == 4,\
'output_flipped should be [batch_size, num_joints, height, width]'
output_flipped = output_flipped[:, :, :, ::-1]
for pair in matched_parts:
tmp = output_flipped[:, pair[0], :, :].copy()
output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
output_flipped[:, pair[1], :, :] = tmp
return output_flipped
def fliplr_joints(joints, joints_vis, width, matched_parts):
"""Flip coords.
"""
# Flip horizontal
joints[:, 0] = width - joints[:, 0] - 1
# Change left-right parts
for pair in matched_parts:
joints[pair[0], :], joints[pair[1], :] = \
joints[pair[1], :], joints[pair[0], :].copy()
joints_vis[pair[0], :], joints_vis[pair[1], :] = \
joints_vis[pair[1], :], joints_vis[pair[0], :].copy()
return joints*joints_vis, joints_vis
def transform_preds(coords, center, scale, output_size):
target_coords = np.zeros(coords.shape)
trans = get_affine_transform(center, scale, 0, output_size, inv=1)
for p in range(coords.shape[0]):
target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
return target_coords
def get_affine_transform(center,
scale,
rot,
output_size,
shift=np.array([0, 0], dtype=np.float32),
inv=0):
if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
print(scale)
scale = np.array([scale, scale])
scale_tmp = scale * 200.0
src_w = scale_tmp[0]
dst_w = output_size[0]
dst_h = output_size[1]
rot_rad = np.pi * rot / 180
src_dir = get_dir([0, src_w * -0.5], rot_rad)
dst_dir = np.array([0, dst_w * -0.5], np.float32)
src = np.zeros((3, 2), dtype=np.float32)
dst = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift
src[1, :] = center + src_dir + scale_tmp * shift
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
src[2:, :] = get_3rd_point(src[0, :], src[1, :])
dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
return trans
def affine_transform(pt, t):
new_pt = np.array([pt[0], pt[1], 1.]).T
new_pt = np.dot(t, new_pt)
return new_pt[:2]
def get_3rd_point(a, b):
direct = a - b
return b + np.array([-direct[1], direct[0]], dtype=np.float32)
def get_dir(src_point, rot_rad):
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
src_result = [0, 0]
src_result[0] = src_point[0] * cs - src_point[1] * sn
src_result[1] = src_point[0] * sn + src_point[1] * cs
return src_result
def crop(img, center, scale, output_size, rot=0):
trans = get_affine_transform(center, scale, rot, output_size)
dst_img = cv2.warpAffine(img,
trans,
(int(output_size[0]), int(output_size[1])),
flags=cv2.INTER_LINEAR)
return dst_img
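# A hedged round-trip sketch (editor's addition): a person box centered at
# (100, 100) with scale 1.0 (roughly 200x200 px) is warped into a 384x384
# network input and back with the inverse transform.
if __name__ == '__main__':
    center = np.array([100., 100.], dtype=np.float32)
    scale = np.array([1., 1.], dtype=np.float32)
    trans = get_affine_transform(center, scale, 0, [384, 384])
    inv_trans = get_affine_transform(center, scale, 0, [384, 384], inv=1)
    pt = affine_transform(np.array([100., 100.]), trans)
    print(pt)                               # the box center maps to (192, 192)
    print(affine_transform(pt, inv_trans))  # and back to ~(100, 100)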
# Copyright (c) 2018-present, Baidu, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Utility functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import distutils.util
import numpy as np
import cv2
from pathlib import Path
def print_arguments(args):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
print("----------- Configuration Arguments -----------")
for arg, value in sorted(vars(args).iteritems()):
print("%s: %s" % (arg, value))
print("------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
"""Add argparse's argument.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
add_argument("name", str, "Jonh", "User name.", parser)
args = parser.parse_args()
"""
type = distutils.util.strtobool if type == bool else type
argparser.add_argument(
"--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
def get_max_preds(batch_heatmaps):
"""Get predictions from score maps.
heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
"""
assert isinstance(batch_heatmaps, np.ndarray), \
'batch_heatmaps should be numpy.ndarray'
    assert batch_heatmaps.ndim == 4, 'batch_heatmaps should be 4-ndim'
batch_size = batch_heatmaps.shape[0]
num_joints = batch_heatmaps.shape[1]
width = batch_heatmaps.shape[3]
heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
idx = np.argmax(heatmaps_reshaped, 2)
maxvals = np.amax(heatmaps_reshaped, 2)
maxvals = maxvals.reshape((batch_size, num_joints, 1))
idx = idx.reshape((batch_size, num_joints, 1))
preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
preds[:, :, 0] = (preds[:, :, 0]) % width
preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)
pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
pred_mask = pred_mask.astype(np.float32)
preds *= pred_mask
return preds, maxvals
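# Worked example (editor's note): a heatmap whose single nonzero value sits at
# row 2, column 3 decodes to preds = (3., 2.) with maxval 1.0; the flat argmax
# index is unravelled as (idx % width, idx // width).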
def affine_transform(pt, t):
new_pt = np.array([pt[0], pt[1], 1.]).T
new_pt = np.dot(t, new_pt)
return new_pt[:2]
def get_3rd_point(a, b):
direct = a - b
return b + np.array([-direct[1], direct[0]], dtype=np.float32)
def get_dir(src_point, rot_rad):
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
src_result = [0, 0]
src_result[0] = src_point[0] * cs - src_point[1] * sn
src_result[1] = src_point[0] * sn + src_point[1] * cs
return src_result
def crop(img, center, scale, output_size, rot=0):
trans = get_affine_transform(center, scale, rot, output_size)
dst_img = cv2.warpAffine(img,
trans,
(int(output_size[0]), int(output_size[1])),
flags=cv2.INTER_LINEAR)
return dst_img
def get_affine_transform(center,
scale,
rot,
output_size,
shift=np.array([0, 0], dtype=np.float32),
inv=0):
if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
print(scale)
scale = np.array([scale, scale])
scale_tmp = scale * 200.0
src_w = scale_tmp[0]
dst_w = output_size[0]
dst_h = output_size[1]
rot_rad = np.pi * rot / 180
src_dir = get_dir([0, src_w * -0.5], rot_rad)
dst_dir = np.array([0, dst_w * -0.5], np.float32)
src = np.zeros((3, 2), dtype=np.float32)
dst = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift
src[1, :] = center + src_dir + scale_tmp * shift
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
src[2:, :] = get_3rd_point(src[0, :], src[1, :])
dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
return trans
def transform_preds(coords, center, scale, output_size):
target_coords = np.zeros(coords.shape)
trans = get_affine_transform(center, scale, 0, output_size, inv=1)
for p in range(coords.shape[0]):
target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
return target_coords
def get_final_preds(args, batch_heatmaps, center, scale):
coords, maxvals = get_max_preds(batch_heatmaps)
heatmap_height = batch_heatmaps.shape[2]
heatmap_width = batch_heatmaps.shape[3]
# Post-processing
if args.post_process:
for n in range(coords.shape[0]):
for p in range(coords.shape[1]):
hm = batch_heatmaps[n][p]
px = int(math.floor(coords[n][p][0] + 0.5))
py = int(math.floor(coords[n][p][1] + 0.5))
if 1 < px < heatmap_width-1 and 1 < py < heatmap_height-1:
diff = np.array([hm[py][px+1] - hm[py][px-1],
hm[py+1][px]-hm[py-1][px]])
coords[n][p] += np.sign(diff) * .25
preds = coords.copy()
# Transform back
for i in range(coords.shape[0]):
preds[i] = transform_preds(coords[i], center[i], scale[i],
[heatmap_width, heatmap_height])
return preds, maxvals
def calc_dists(preds, target, normalize):
preds = preds.astype(np.float32)
target = target.astype(np.float32)
dists = np.zeros((preds.shape[1], preds.shape[0]))
for n in range(preds.shape[0]):
for c in range(preds.shape[1]):
if target[n, c, 0] > 1 and target[n, c, 1] > 1:
normed_preds = preds[n, c, :] / normalize[n]
normed_targets = target[n, c, :] / normalize[n]
dists[c, n] = np.linalg.norm(normed_preds - normed_targets)
else:
dists[c, n] = -1
return dists
def dist_acc(dists, thr=0.5):
"""Return percentage below threshold while ignoring values with a -1.
"""
dist_cal = np.not_equal(dists, -1)
num_dist_cal = dist_cal.sum()
if num_dist_cal > 0:
return np.less(dists[dist_cal], thr).sum() * 1.0 / num_dist_cal
else:
return -1
def accuracy(output, target, hm_type='gaussian', thr=0.5):
"""
    Calculate accuracy according to PCK, using ground-truth heatmaps rather
    than x,y locations. The first value returned is the average accuracy
    across 'idxs'; individual accuracies follow.
"""
idx = list(range(output.shape[1]))
norm = 1.0
if hm_type == 'gaussian':
pred, _ = get_max_preds(output)
target, _ = get_max_preds(target)
h = output.shape[2]
w = output.shape[3]
norm = np.ones((pred.shape[0], 2)) * np.array([h, w]) / 10
dists = calc_dists(pred, target, norm)
acc = np.zeros((len(idx) + 1))
avg_acc = 0
cnt = 0
for i in range(len(idx)):
acc[i + 1] = dist_acc(dists[idx[i]])
if acc[i + 1] >= 0:
avg_acc = avg_acc + acc[i + 1]
cnt += 1
avg_acc = avg_acc / cnt if cnt != 0 else 0
if cnt != 0:
acc[0] = avg_acc
return acc, avg_acc, cnt, pred
def save_batch_heatmaps(batch_image, batch_heatmaps, file_name, normalize=True):
"""
:param batch_image: [batch_size, channel, height, width]
    :param batch_heatmaps: [batch_size, num_joints, height, width]
:param file_name: saved file name
"""
if normalize:
min = np.array(batch_image.min(), dtype=np.float)
max = np.array(batch_image.max(), dtype=np.float)
batch_image = np.add(batch_image, -min)
batch_image = np.divide(batch_image, max - min + 1e-5)
batch_size, num_joints, \
heatmap_height, heatmap_width = batch_heatmaps.shape
grid_image = np.zeros((batch_size*heatmap_height,
(num_joints+1)*heatmap_width,
3),
dtype=np.uint8)
preds, maxvals = get_max_preds(batch_heatmaps)
for i in range(batch_size):
image = batch_image[i] * 255
image = image.clip(0, 255).astype(np.uint8)
image = image.transpose(1, 2, 0)
heatmaps = batch_heatmaps[i] * 255
heatmaps = heatmaps.clip(0, 255).astype(np.uint8)
resized_image = cv2.resize(image,
(int(heatmap_width), int(heatmap_height)))
height_begin = heatmap_height * i
height_end = heatmap_height * (i + 1)
for j in range(num_joints):
cv2.circle(resized_image,
(int(preds[i][j][0]), int(preds[i][j][1])),
1, [0, 0, 255], 1)
heatmap = heatmaps[j, :, :]
colored_heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
masked_image = colored_heatmap*0.7 + resized_image*0.3
cv2.circle(masked_image,
(int(preds[i][j][0]), int(preds[i][j][1])),
1, [0, 0, 255], 1)
width_begin = heatmap_width * (j+1)
width_end = heatmap_width * (j+2)
grid_image[height_begin:height_end, width_begin:width_end, :] = \
masked_image
grid_image[height_begin:height_end, 0:heatmap_width, :] = resized_image
cv2.imwrite(file_name, grid_image)
def save_predict_results(batch_image, batch_heatmaps, file_ids, fold_name, normalize=True):
"""
:param batch_image: [batch_size, channel, height, width]
    :param batch_heatmaps: [batch_size, num_joints, height, width]
:param fold_name: saved files in this folder
"""
save_dir = Path('./{}'.format(fold_name))
try:
save_dir.mkdir()
except OSError:
pass
if normalize:
min = np.array(batch_image.min(), dtype=np.float)
max = np.array(batch_image.max(), dtype=np.float)
batch_image = np.add(batch_image, -min)
batch_image = np.divide(batch_image, max - min + 1e-5)
batch_size, num_joints, \
heatmap_height, heatmap_width = batch_heatmaps.shape
    # e.g. preds: (32, 16, 2), maxvals: (32, 16, 1)
preds, maxvals = get_max_preds(batch_heatmaps)
    # Keypoint colors (BGR): blue inner dot, green outer ring
icolor = (255, 137, 0)
ocolor = (138, 255, 0)
for i in range(batch_size):
image = batch_image[i] * 255
image = image.clip(0, 255).astype(np.uint8)
image = image.transpose(1, 2, 0)
image = cv2.resize(image, (384, 384))
file_id = file_ids[i]
imgname = save_dir.joinpath('rendered_{}.png'.format(str(file_id).zfill(7)))
for j in range(num_joints):
x, y = preds[i][j]
cv2.circle(image, (int(x * 4), int(y * 4)), 3, icolor, -1, 16)
cv2.circle(image, (int(x * 4), int(y * 4)), 6, ocolor, 1, 16)
cv2.imwrite(str(imgname), image)
# Clean format output
def print_name_value(name_value, full_arch_name):
names = name_value.keys()
values = name_value.values()
num_values = len(name_value)
results = []
for value in values:
results.append('| {:.3f}'.format(value))
print(
'| Arch ' +
' '.join(['| {}'.format(name) for name in names]) +
' |'
)
print('|---' * (num_values+1) + '|')
print('| ' + 'SIMPLEBASE RESNET50 ' + ' '.join(results) + ' |')
class AverageMeter(object):
"""Computes and stores the average and current value.
"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count if self.count != 0 else 0
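# A minimal usage sketch (editor's addition): AverageMeter keeps a
# count-weighted running mean, as used for PCK accuracy in val.py.
if __name__ == '__main__':
    meter = AverageMeter()
    meter.update(0.8, n=32)  # e.g. batch accuracy 0.8 over 32 samples
    meter.update(0.6, n=16)
    print(meter.avg)         # (0.8*32 + 0.6*16) / 48 ~= 0.733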
# Copyright (c) 2018-present, Baidu, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Functions for validation."""
import os
import argparse
import functools
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from collections import OrderedDict
from scipy.io import loadmat, savemat
from lib import pose_resnet
from utils.transforms import flip_back
from utils.utility import *
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('dataset', str, 'mpii', "Dataset")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('num_epochs', int, 140, "Number of epochs.")
add_arg('total_images', int, 144406, "Training image number.")
add_arg('kp_dim', int, 16, "Class number.")
add_arg('model_save_dir', str, "output", "Model save directory")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('checkpoint', str, None, "Whether to resume checkpoint.")
add_arg('lr', float, 0.001, "Set learning rate.")
add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.")
add_arg('flip_test', bool, True, "Flip test")
add_arg('shift_heatmap', bool, True, "Shift heatmap")
add_arg('post_process', bool, True, "Post process")
# yapf: enable
def valid(args):
if args.dataset == 'coco':
import lib.coco_reader as reader
IMAGE_SIZE = [288, 384]
HEATMAP_SIZE = [72, 96]
FLIP_PAIRS = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
args.kp_dim = 17
args.total_images = 144406 # 149813
elif args.dataset == 'mpii':
import lib.mpii_reader as reader
IMAGE_SIZE = [384, 384]
HEATMAP_SIZE = [96, 96]
FLIP_PAIRS = [[0, 5], [1, 4], [2, 3], [10, 15], [11, 14], [12, 13]]
args.kp_dim = 16
args.total_images = 2958 # validation
else:
raise ValueError('The dataset {} is not supported yet.'.format(args.dataset))
print_arguments(args)
# Image and target
image = layers.data(name='image', shape=[3, IMAGE_SIZE[1], IMAGE_SIZE[0]], dtype='float32')
target = layers.data(name='target', shape=[args.kp_dim, HEATMAP_SIZE[1], HEATMAP_SIZE[0]], dtype='float32')
target_weight = layers.data(name='target_weight', shape=[args.kp_dim, 1], dtype='float32')
center = layers.data(name='center', shape=[2,], dtype='float32')
scale = layers.data(name='scale', shape=[2,], dtype='float32')
score = layers.data(name='score', shape=[1,], dtype='float32')
# Build model
model = pose_resnet.ResNet(layers=50, kps_num=args.kp_dim)
# Output
loss, output = model.net(input=image, target=target, target_weight=target_weight)
# Parameters from model and arguments
params = {}
params["total_images"] = args.total_images
params["lr"] = args.lr
params["num_epochs"] = args.num_epochs
params["learning_strategy"] = {}
params["learning_strategy"]["batch_size"] = args.batch_size
params["learning_strategy"]["name"] = args.lr_strategy
if args.with_mem_opt:
fluid.memory_optimize(fluid.default_main_program(),
skip_opt_set=[loss.name, output.name, target.name])
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
args.pretrained_model = './pretrained/resnet_50/115'
if args.pretrained_model:
def if_exist(var):
exist_flag = os.path.exists(os.path.join(args.pretrained_model, var.name))
return exist_flag
fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
if args.checkpoint is not None:
fluid.io.load_persistables(exe, args.checkpoint)
# Dataloader
valid_reader = paddle.batch(reader.valid(), batch_size=args.batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, target, target_weight, center, scale, score])
valid_exe = fluid.ParallelExecutor(
use_cuda=True if args.use_gpu else False,
main_program=fluid.default_main_program().clone(for_test=False),
loss_name=loss.name)
fetch_list = [image.name, loss.name, output.name, target.name]
# For validation
acc = AverageMeter()
idx = 0
num_samples = args.total_images
all_preds = np.zeros((num_samples, args.kp_dim, 3),
dtype=np.float32)
all_boxes = np.zeros((num_samples, 6))
for batch_id, data in enumerate(valid_reader()):
num_images = len(data)
centers = []
scales = []
scores = []
for i in range(num_images):
centers.append(data[i][3])
scales.append(data[i][4])
scores.append(data[i][5])
input_image, loss, out_heatmaps, target_heatmaps = valid_exe.run(
fetch_list=fetch_list,
feed=feeder.feed(data))
if args.flip_test:
            # Flip all the images in the same batch
data_fliped = []
for i in range(num_images):
# Input, target, target_weight, c, s, score
data_fliped.append((
# np.flip(input_image, 3)[i],
data[i][0][:, :, ::-1],
data[i][1],
data[i][2],
data[i][3],
data[i][4],
data[i][5]))
# Inference again
_, _, output_flipped, _ = valid_exe.run(
fetch_list=fetch_list,
feed=feeder.feed(data_fliped))
# Flip back
output_flipped = flip_back(output_flipped, FLIP_PAIRS)
# Feature is not aligned, shift flipped heatmap for higher accuracy
if args.shift_heatmap:
output_flipped[:, :, :, 1:] = \
output_flipped.copy()[:, :, :, 0:-1]
# Aggregate
# out_heatmaps.shape: size[b, args.kp_dim, 96, 96]
out_heatmaps = (out_heatmaps + output_flipped) * 0.5
loss = np.mean(np.array(loss))
# Accuracy
_, avg_acc, cnt, pred = accuracy(out_heatmaps, target_heatmaps)
acc.update(avg_acc, cnt)
# Current center, scale, score
centers = np.array(centers)
scales = np.array(scales)
scores = np.array(scores)
preds, maxvals = get_final_preds(
args, out_heatmaps, centers, scales)
all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
all_preds[idx:idx + num_images, :, 2:3] = maxvals
# Double check this all_boxes parts
all_boxes[idx:idx + num_images, 0:2] = centers[:, 0:2]
all_boxes[idx:idx + num_images, 2:4] = scales[:, 0:2]
all_boxes[idx:idx + num_images, 4] = np.prod(scales*200, 1)
all_boxes[idx:idx + num_images, 5] = scores
# image_path.extend(meta['image'])
idx += num_images
print('Epoch [{:4d}] '
'Loss = {:.5f} '
'Acc = {:.5f}'.format(batch_id, loss, acc.avg))
if batch_id % 10 == 0:
save_batch_heatmaps(input_image, out_heatmaps, file_name='visualization@val.jpg', normalize=True)
# Evaluate
args.DATAROOT = 'data/mpii'
args.TEST_SET = 'valid'
output_dir = ''
filenames = []
imgnums = []
image_path = []
name_values, perf_indicator = mpii_evaluate(
args, all_preds, output_dir, all_boxes, image_path,
filenames, imgnums)
print_name_value(name_values, perf_indicator)
def mpii_evaluate(cfg, preds, output_dir, *args, **kwargs):
# Convert 0-based index to 1-based index
preds = preds[:, :, 0:2] + 1.0
if output_dir:
pred_file = os.path.join(output_dir, 'pred.mat')
savemat(pred_file, mdict={'preds': preds})
if 'test' in cfg.TEST_SET:
return {'Null': 0.0}, 0.0
SC_BIAS = 0.6
threshold = 0.5
gt_file = os.path.join(cfg.DATAROOT,
'annot',
'gt_{}.mat'.format(cfg.TEST_SET))
gt_dict = loadmat(gt_file)
dataset_joints = gt_dict['dataset_joints']
jnt_missing = gt_dict['jnt_missing']
pos_gt_src = gt_dict['pos_gt_src']
headboxes_src = gt_dict['headboxes_src']
pos_pred_src = np.transpose(preds, [1, 2, 0])
head = np.where(dataset_joints == 'head')[1][0]
lsho = np.where(dataset_joints == 'lsho')[1][0]
lelb = np.where(dataset_joints == 'lelb')[1][0]
lwri = np.where(dataset_joints == 'lwri')[1][0]
lhip = np.where(dataset_joints == 'lhip')[1][0]
lkne = np.where(dataset_joints == 'lkne')[1][0]
lank = np.where(dataset_joints == 'lank')[1][0]
rsho = np.where(dataset_joints == 'rsho')[1][0]
relb = np.where(dataset_joints == 'relb')[1][0]
rwri = np.where(dataset_joints == 'rwri')[1][0]
rkne = np.where(dataset_joints == 'rkne')[1][0]
rank = np.where(dataset_joints == 'rank')[1][0]
rhip = np.where(dataset_joints == 'rhip')[1][0]
jnt_visible = 1 - jnt_missing
uv_error = pos_pred_src - pos_gt_src
uv_err = np.linalg.norm(uv_error, axis=1)
headsizes = headboxes_src[1, :, :] - headboxes_src[0, :, :]
headsizes = np.linalg.norm(headsizes, axis=0)
headsizes *= SC_BIAS
scale = np.multiply(headsizes, np.ones((len(uv_err), 1)))
scaled_uv_err = np.divide(uv_err, scale)
scaled_uv_err = np.multiply(scaled_uv_err, jnt_visible)
jnt_count = np.sum(jnt_visible, axis=1)
less_than_threshold = np.multiply((scaled_uv_err <= threshold),
jnt_visible)
PCKh = np.divide(100.*np.sum(less_than_threshold, axis=1), jnt_count)
# Save
rng = np.arange(0, 0.5+0.01, 0.01)
pckAll = np.zeros((len(rng), cfg.kp_dim))
for r in range(len(rng)):
threshold = rng[r]
less_than_threshold = np.multiply(scaled_uv_err <= threshold,
jnt_visible)
pckAll[r, :] = np.divide(100.*np.sum(less_than_threshold, axis=1),
jnt_count)
PCKh = np.ma.array(PCKh, mask=False)
PCKh.mask[6:8] = True
jnt_count = np.ma.array(jnt_count, mask=False)
jnt_count.mask[6:8] = True
jnt_ratio = jnt_count / np.sum(jnt_count).astype(np.float64)
name_value = [
('Head', PCKh[head]),
('Shoulder', 0.5 * (PCKh[lsho] + PCKh[rsho])),
('Elbow', 0.5 * (PCKh[lelb] + PCKh[relb])),
('Wrist', 0.5 * (PCKh[lwri] + PCKh[rwri])),
('Hip', 0.5 * (PCKh[lhip] + PCKh[rhip])),
('Knee', 0.5 * (PCKh[lkne] + PCKh[rkne])),
('Ankle', 0.5 * (PCKh[lank] + PCKh[rank])),
('Mean', np.sum(PCKh * jnt_ratio)),
('Mean@0.1', np.sum(pckAll[11, :] * jnt_ratio))
]
name_value = OrderedDict(name_value)
return name_value, name_value['Mean']
# TODO: coco_evaluate()
if __name__ == '__main__':
args = parser.parse_args()
valid(args)