Commit 8ecb5b10 authored by zlyi1225, committed by qingqing01

updates for human pose estimation (#1912)

* fix bugs and remove redundant code
* minor changes and add README in Chinese
Parent 63b904cb
......@@ -10,23 +10,28 @@ This is a simple demonstration of re-implementation in [PaddlePaddle.Fluid](http
## Requirements
- Python == 2.7
- PaddlePaddle >= 1.0
- PaddlePaddle >= 1.1.0
- opencv-python >= 3.3
- tqdm >= 4.25
## Environment
The code is developed and tested on 4 Tesla K40 GPU cards on CentOS, with CUDA-9.2/8.0 and cuDNN-7.1 installed.
## Known Issues
- The model does not converge with a large batch\_size (e.g. 32) on Tesla P40 / V100 / P100 GPU cards, because PaddlePaddle uses cuDNN's batch normalization function. Reducing batch\_size to 1 image per card during training eases the problem, but the resulting performance has not been verified. The issue is tracked [here](https://github.com/PaddlePaddle/Paddle/issues/14580).
The code is developed and tested on 4 Tesla K40/P40 GPU cards on CentOS, with CUDA-9.2/8.0 and cuDNN-7.1 installed.
## Results on MPII Val
| Arch | Head | Shoulder | Elbow | Wrist | Hip | Knee | Ankle | Mean | Mean@0.1| Models |
| ---- |:----:|:--------:|:-----:|:-----:|:---:|:----:|:-----:|:----:|:-------:|:------:|
| 384x384\_pose\_resnet\_50 in PyTorch | 96.658 | 95.754 | 89.790 | 84.614 | 88.523 | 84.666 | 79.287 | 89.066 | 38.046 | - |
| 384x384\_pose\_resnet\_50 in Fluid | 96.248 | 95.346 | 89.807 | 84.873 | 88.298 | 83.679 | 78.649 | 88.767 | 37.374 | [`link`](http://paddlemodels.bj.bcebos.com/pose/pose-resnet-50-384x384-mpii.tar.gz) |
| 256x256\_pose\_resnet\_50 in PyTorch | 96.351 | 95.329 | 88.989 | 83.176 | 88.420 | 83.960 | 79.594 | 88.532 | 33.911 | - |
| 256x256\_pose\_resnet\_50 in Fluid | 96.385 | 95.363 | 89.211 | 84.084 | 88.454 | 84.182 | 79.546 | 88.748 | 33.750 | [`link`](https://paddlemodels.bj.bcebos.com/pose/pose-resnet50-mpii-256x256.tar.gz) |
| 384x384\_pose\_resnet\_50 in PyTorch | 96.658 | 95.754 | 89.790 | 84.614 | 88.523 | 84.666 | 79.287 | 89.066 | 38.046 | - |
| 384x384\_pose\_resnet\_50 in Fluid | 96.862 | 95.635 | 90.046 | 85.557 | 88.818 | 84.948 | 78.484 | 89.235 | 38.093 | [`link`](https://paddlemodels.bj.bcebos.com/pose/pose-resnet50-mpii-384x384.tar.gz) |
## Results on COCO val2017 (using a detector with human AP of 56.4 on COCO val2017)
| Arch | AP | Ap .5 | AP .75 | AP (M) | AP (L) | AR | AR .5 | AR .75 | AR (M) | AR (L) | Models |
| ---- |:--:|:-----:|:------:|:------:|:------:|:--:|:-----:|:------:|:------:|:------:|:------:|
| 256x192\_pose\_resnet\_50 in PyTorch | 0.704 | 0.886 | 0.783 | 0.671 | 0.772 | 0.763 | 0.929 | 0.834 | 0.721 | 0.824 | - |
| 256x192\_pose\_resnet\_50 in Fluid | 0.712 | 0.897 | 0.786 | 0.683 | 0.756 | 0.741 | 0.906 | 0.806 | 0.709 | 0.790 | [`link`](https://paddlemodels.bj.bcebos.com/pose/pose-resnet50-coco-256x192.tar.gz) |
| 384x288\_pose\_resnet\_50 in PyTorch | 0.722 | 0.893 | 0.789 | 0.681 | 0.797 | 0.776 | 0.932 | 0.838 | 0.728 | 0.846 | - |
| 384x288\_pose\_resnet\_50 in Fluid | 0.727 | 0.897 | 0.796 | 0.690 | 0.783 | 0.754 | 0.907 | 0.813 | 0.714 | 0.814 | [`link`](https://paddlemodels.bj.bcebos.com/pose/pose-resnet50-coco-384x288.tar.gz) |
### Notes:
......@@ -77,16 +82,16 @@ python2 setup.py install --user
### Perform Validating
Download the checkpoint of Pose-ResNet-50 trained on the MPII dataset from [here](http://paddlemodels.bj.bcebos.com/pose/pose-resnet-50-384x384-mpii.tar.gz). Extract it into the folder `checkpoints` under the root directory of this repo. Then run
Download the checkpoint of Pose-ResNet-50 trained on the MPII dataset from [here](https://paddlemodels.bj.bcebos.com/pose/pose-resnet50-mpii-384x384.tar.gz). Extract it into the folder `checkpoints` under the root directory of this repo. Then run
```bash
python2 val.py --dataset 'mpii' --checkpoint 'checkpoints/pose-resnet-50-384x384-mpii'
python val.py --dataset 'mpii' --checkpoint 'checkpoints/pose-resnet50-mpii-384x384'
```
### Perform Training
```bash
python2 train.py --dataset 'mpii' # or coco
python train.py --dataset 'mpii' # or coco
```
**Note**: Configurations for training are aggregated in `lib/mpii_reader.py` and `lib/coco_reader.py`.
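For instance (an illustrative invocation; every flag below comes from the argument table in `train.py`, and the checkpoint path is a hypothetical example):
```bash
python train.py --dataset 'coco' \
    --batch_size 128 \
    --lr 0.001 \
    --pretrained_model 'pretrained/resnet_50/115' \
    --checkpoint 'output/my-run'   # optional: resume from a saved checkpoint
```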
......@@ -96,10 +101,10 @@ python2 train.py --dataset 'mpii' # or coco
Put the images into the folder `test` under the directory root of this repo. Then run
```bash
python2 test.py --checkpoint 'checkpoints/pose-resnet-50-384x384-mpii'
python test.py --checkpoint 'checkpoints/pose-resnet50-mpii-384x384'
```
If there are multiple persons in the images, detectors such as [Faster R-CNN](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/faster_rcnn), [SSD](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/object_detection) or others should be used first to crop them out, since the simple baseline for human pose estimation is a top-down method.
If there are multiple persons in the images, detectors such as [Faster R-CNN](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/rcnn), [SSD](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/object_detection) or others should be used first to crop them out, since the simple baseline for human pose estimation is a top-down method.
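For illustration, a minimal sketch of that top-down flow, assuming an input image at `test/demo.jpg`; `detect_people` is a hypothetical stand-in for any person detector:
```python
import cv2

def detect_people(img):
    # Hypothetical detector stub: replace with Faster R-CNN / SSD inference.
    # Here it falls back to a single box covering the whole frame.
    h, w = img.shape[:2]
    return [(0, 0, w, h)]

img = cv2.imread('test/demo.jpg', cv2.IMREAD_COLOR)
for i, (x1, y1, x2, y2) in enumerate(detect_people(img)):
    crop = cv2.resize(img[y1:y2, x1:x2], (384, 384))  # MPII input size used above
    cv2.imwrite('test/person_%d.jpg' % i, crop)       # then run test.py on the crops
```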
## Reference
......
# Keypoint Detection (Simple Baselines for Human Pose Estimation)
## Introduction
This directory contains a re-implementation of the paper [Simple Baselines for Human Pose Estimation and Tracking](https://arxiv.org/abs/1804.06208) (ECCV'18).
![demo](demo.gif)
> **Demo video**: *Bruno Mars - That's What I Like [Official Video]*.
## Requirements
The code in this directory has been developed and tested on 4 Tesla K40/P40 GPU cards on CentOS, with CUDA-9.2/8.0 and cuDNN-7.1 installed.
- Python == 2.7
- PaddlePaddle >= 1.1.0
- opencv-python >= 3.3
## Results on MPII Val
| Arch | Head | Shoulder | Elbow | Wrist | Hip | Knee | Ankle | Mean | Mean@0.1| Models |
| ---- |:----:|:--------:|:-----:|:-----:|:---:|:----:|:-----:|:----:|:-------:|:------:|
| 256x256\_pose\_resnet\_50 in PyTorch | 96.351 | 95.329 | 88.989 | 83.176 | 88.420 | 83.960 | 79.594 | 88.532 | 33.911 | - |
| 256x256\_pose\_resnet\_50 in Fluid | 96.385 | 95.363 | 89.211 | 84.084 | 88.454 | 84.182 | 79.546 | 88.748 | 33.750 | [`link`](https://paddlemodels.bj.bcebos.com/pose/pose-resnet50-mpii-256x256.tar.gz) |
| 384x384\_pose\_resnet\_50 in PyTorch | 96.658 | 95.754 | 89.790 | 84.614 | 88.523 | 84.666 | 79.287 | 89.066 | 38.046 | - |
| 384x384\_pose\_resnet\_50 in Fluid | 96.862 | 95.635 | 90.046 | 85.557 | 88.818 | 84.948 | 78.484 | 89.235 | 38.093 | [`link`](https://paddlemodels.bj.bcebos.com/pose/pose-resnet50-mpii-384x384.tar.gz) |
## Results on COCO val2017 (using a detector with human AP of 56.4 on COCO val2017)
| Arch | AP | Ap .5 | AP .75 | AP (M) | AP (L) | AR | AR .5 | AR .75 | AR (M) | AR (L) | Models |
| ---- |:--:|:-----:|:------:|:------:|:------:|:--:|:-----:|:------:|:------:|:------:|:------:|
| 256x192\_pose\_resnet\_50 in PyTorch | 0.704 | 0.886 | 0.783 | 0.671 | 0.772 | 0.763 | 0.929 | 0.834 | 0.721 | 0.824 | - |
| 256x192\_pose\_resnet\_50 in Fluid | 0.712 | 0.897 | 0.786 | 0.683 | 0.756 | 0.741 | 0.906 | 0.806 | 0.709 | 0.790 | [`link`](https://paddlemodels.bj.bcebos.com/pose/pose-resnet50-coco-256x192.tar.gz) |
| 384x288\_pose\_resnet\_50 in PyTorch | 0.722 | 0.893 | 0.789 | 0.681 | 0.797 | 0.776 | 0.932 | 0.838 | 0.728 | 0.846 | - |
| 384x288\_pose\_resnet\_50 in Fluid | 0.727 | 0.897 | 0.796 | 0.690 | 0.783 | 0.754 | 0.907 | 0.813 | 0.714 | 0.814 | [`link`](https://paddlemodels.bj.bcebos.com/pose/pose-resnet50-coco-384x288.tar.gz) |
### Notes
- Flip test is used.
- No hyper-parameter tuning was performed for these results: train with the configurations below and take the model from the last epoch as the final model to reproduce the numbers above.
## Getting Started
### Data Preparation and Pretrained Models
- Prepare the data following the [instructions](https://github.com/Microsoft/human-pose-estimation.pytorch#data-preparation).
- Download the pretrained ResNet-50 model:
```bash
wget http://paddle-imagenet-models.bj.bcebos.com/resnet_50_model.tar
```
After downloading, extract the model into the folder `pretrained` under the root directory of this repo. The default directory tree is:
```
${REPO_ROOT}
`-- pretrained
`-- resnet_50
|-- 115
`-- data
`-- coco
|-- annotations
|-- images
`-- mpii
|-- annot
|-- images
```
### Install [COCOAPI](https://github.com/cocodataset/cocoapi)
```bash
# COCOAPI=/path/to/clone/cocoapi
git clone https://github.com/cocodataset/cocoapi.git $COCOAPI
cd $COCOAPI/PythonAPI
# if cython is not installed
pip install Cython
# Install into global site-packages
make install
# Alternatively, if you do not have permissions or prefer
# not to install the COCO API into global site-packages
python2 setup.py install --user
```
### Validation (COCO or MPII)
Download the COCO/MPII pretrained models (see the links in the last column of the tables above), save them into the folder `checkpoints` under the root directory of this repo, and run:
```bash
python val.py --dataset 'mpii' --checkpoint 'checkpoints/pose-resnet50-mpii-384x384'
```
### Training
```bash
python train.py --dataset 'mpii' # or coco
```
**Note**: Detailed configurations are kept in `lib/mpii_reader.py` and `lib/coco_reader.py`; setting `--dataset` selects which configuration is used.
### Testing (arbitrary images, using the COCO or MPII pretrained models above)
Put the test images into the folder `test` under the root directory of this repo, then run:
```bash
python test.py --checkpoint 'checkpoints/pose-resnet50-mpii-384x384'
```
## Reference
- Simple Baselines for Human Pose Estimation and Tracking in PyTorch [`code`](https://github.com/Microsoft/human-pose-estimation.pytorch#data-preparation)
......@@ -216,7 +216,7 @@ def data_augmentation(sample, is_train):
joints_vis = sample['joints_3d_vis']
c = sample['center']
s = sample['scale']
# score = sample['score'] if 'score' in sample else 1
score = sample['score'] if 'score' in sample else 1
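# Box score from the detector, defaulting to 1 (e.g. when ground-truth boxes are used).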
# imgnum = sample['imgnum'] if 'imgnum' in sample else ''
r = 0
......@@ -261,7 +261,7 @@ def data_augmentation(sample, is_train):
if is_train:
return input, target, target_weight
else:
return input, target, target_weight, c, s
return input, target, target_weight, c, s, score, image_file
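# In eval mode the reader also returns center/scale (to map heatmap peaks back to
# image coordinates), the box score (used in COCO rescoring) and the image path.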
# Create a reader
def _reader_creator(root, image_set, shuffle=False, is_train=False, use_gt_bbox=False):
......@@ -316,7 +316,14 @@ def train():
return pop
def valid():
reader, mapper = _reader_creator(cfg.DATAROOT, 'val', shuffle=False, is_train=False)
reader, mapper = _reader_creator(cfg.DATAROOT, 'val', shuffle=False, is_train=False, use_gt_bbox=True)
def pop():
for i, x in enumerate(reader()):
yield mapper(x)
return pop
def test():
reader, mapper = _reader_creator(cfg.DATAROOT, 'test', shuffle=False, is_train=False, use_gt_bbox=True)
def pop():
for i, x in enumerate(reader()):
yield mapper(x)
......
......@@ -119,7 +119,7 @@ def test_data_augmentation(sample):
image_file = sample['image']
filename = sample['filename'] if 'filename' in sample else ''
file_id = int(filename.split('.')[0].split('_')[1])
file_id = int(filename.split('.')[0])
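# Assumes the test file stem is a bare numeric id, e.g. 001163.jpg.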
input = cv2.imread(
image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
......@@ -174,19 +174,19 @@ def _reader_creator(root, image_set, shuffle=False, is_train=False):
joints_3d_vis[:, 1] = joints_vis[:]
yield dict(
image = os.path.join(cfg.DATAROOT, cfg.IMAGEDIR, image_name),
center = c,
scale = s,
joints_3d = joints_3d,
joints_3d_vis = joints_3d_vis,
filename = image_name,
test_mode = False,
imagenum = 0)
image=os.path.join(cfg.DATAROOT, cfg.IMAGEDIR, image_name),
center=c,
scale=s,
joints_3d=joints_3d,
joints_3d_vis=joints_3d_vis,
filename=image_name,
test_mode=False,
imagenum=0)
else:
fold = 'test'
for img_name in os.listdir(fold):
yield dict(image = os.path.join(fold, img_name),
filename = img_name)
yield dict(image=os.path.join(fold, img_name),
filename=img_name)
if not image_set == 'test':
mapper = functools.partial(data_augmentation, is_train=is_train)
......
......@@ -23,7 +23,7 @@ import paddle.fluid as fluid
__all__ = ["ResNet", "ResNet50", "ResNet101", "ResNet152"]
# Global parameters
BN_MOMENTUM = 0.1
BN_MOMENTUM = 0.9
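# Note: Paddle's batch_norm momentum is the moving-average decay, the complement of PyTorch's momentum (1 - 0.1 = 0.9).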
class ResNet():
def __init__(self, layers=50, kps_num=16, test_mode=False):
......
......@@ -22,7 +22,6 @@ import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from tqdm import tqdm
from lib import pose_resnet
from utils.transforms import flip_back
from utils.utility import *
......@@ -34,37 +33,24 @@ add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('dataset', str, 'mpii', "Dataset")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('num_epochs', int, 140, "Number of epochs.")
add_arg('total_images', int, 144406, "Training image number.")
add_arg('kp_dim', int, 16, "Class number.")
add_arg('model_save_dir', str, "output", "Model save directory")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('checkpoint', str, None, "Whether to resume checkpoint.")
add_arg('lr', float, 0.001, "Set learning rate.")
add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.")
add_arg('flip_test', bool, True, "Flip test")
add_arg('shift_heatmap', bool, True, "Shift heatmap")
add_arg('post_process', bool, False, "post process")
# yapf: enable
FLIP_PAIRS = [[0, 5], [1, 4], [2, 3], [10, 15], [11, 14], [12, 13]]
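# Pairs of left/right joint indices swapped when the input is mirrored for flip test.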
def test(args):
import lib.mpii_reader as reader
if args.dataset == 'coco':
import lib.coco_reader as reader
IMAGE_SIZE = [288, 384]
# HEATMAP_SIZE = [72, 96]
FLIP_PAIRS = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
args.kp_dim = 17
args.total_images = 144406 # 149813
elif args.dataset == 'mpii':
import lib.mpii_reader as reader
IMAGE_SIZE = [384, 384]
# HEATMAP_SIZE = [96, 96]
FLIP_PAIRS = [[0, 5], [1, 4], [2, 3], [10, 15], [11, 14], [12, 13]]
args.kp_dim = 16
args.total_images = 2958 # validation
else:
raise ValueError('The dataset {} is not supported yet.'.format(args.dataset))
......@@ -80,15 +66,6 @@ def test(args):
# Output
output = model.net(input=image, target=None, target_weight=None)
# Parameters from model and arguments
params = {}
params["total_images"] = args.total_images
params["lr"] = args.lr
params["num_epochs"] = args.num_epochs
params["learning_strategy"] = {}
params["learning_strategy"]["batch_size"] = args.batch_size
params["learning_strategy"]["name"] = args.lr_strategy
if args.with_mem_opt:
fluid.memory_optimize(fluid.default_main_program(),
skip_opt_set=[output.name])
......@@ -97,13 +74,6 @@ def test(args):
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
args.pretrained_model = './pretrained/resnet_50/115'
if args.pretrained_model:
def if_exist(var):
exist_flag = os.path.exists(os.path.join(args.pretrained_model, var.name))
return exist_flag
fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
if args.checkpoint is not None:
fluid.io.load_persistables(exe, args.checkpoint)
......@@ -113,12 +83,12 @@ def test(args):
test_exe = fluid.ParallelExecutor(
use_cuda=True if args.use_gpu else False,
main_program=fluid.default_main_program().clone(for_test=False),
main_program=fluid.default_main_program().clone(for_test=True),
loss_name=None)
fetch_list = [image.name, output.name]
for batch_id, data in tqdm(enumerate(test_reader())):
for batch_id, data in enumerate(test_reader()):
num_images = len(data)
file_ids = []
......
......@@ -30,18 +30,18 @@ from utils.utility import *
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('dataset', str, 'mpii', "Dataset")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('num_epochs', int, 140, "Number of epochs.")
add_arg('total_images', int, 144406, "Training image number.")
add_arg('kp_dim', int, 16, "Class number.")
add_arg('model_save_dir', str, "output", "Model save directory")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('checkpoint', str, None, "Whether to resume checkpoint.")
add_arg('lr', float, 0.001, "Set learning rate.")
add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.")
add_arg('batch_size', int, 128, "Total minibatch size across all cards.")
add_arg('dataset', str, 'mpii', "Dataset; valid values: mpii, coco.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('num_epochs', int, 140, "Number of epochs.")
add_arg('total_images', int, 144406, "Training image number.")
add_arg('kp_dim', int, 16, "Class number.")
add_arg('model_save_dir', str, "output", "Model save directory")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('pretrained_model', str, "pretrained/resnet_50/115", "Whether to use pretrained model.")
add_arg('checkpoint', str, None, "Whether to resume checkpoint.")
add_arg('lr', float, 0.001, "Set learning rate.")
add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.")
# yapf: enable
def optimizer_setting(args, params):
......@@ -54,7 +54,7 @@ def optimizer_setting(args, params):
batch_size = ls["batch_size"]
step = int(total_images / batch_size + 1)
ls['epochs'] = [91, 121]
ls['epochs'] = [90, 120]
print('=> LR will be dropped at the epoch of {}'.format(ls['epochs']))
bd = [step * e for e in ls["epochs"]]
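In effect, the schedule drops the learning rate 10x at those epochs. A sketch of the resulting optimizer setup (the Momentum optimizer here is an assumption for illustration; the actual choice lives in `train.py`):
```python
import paddle.fluid as fluid

base_lr = 0.001                     # --lr
step = 22246 // 128 + 1             # iterations per epoch (MPII images / batch size)
bd = [step * e for e in [90, 120]]  # boundaries in iterations
lr = [base_lr * (0.1 ** i) for i in range(len(bd) + 1)]  # 1e-3, 1e-4, 1e-5

optimizer = fluid.optimizer.Momentum(
    learning_rate=fluid.layers.piecewise_decay(boundaries=bd, values=lr),
    momentum=0.9)
```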
......@@ -85,7 +85,7 @@ def train(args):
elif args.dataset == 'mpii':
import lib.mpii_reader as reader
IMAGE_SIZE = [384, 384]
HEATMAP_SIZE = [96, 96]
HEATMAP_SIZE = [96, 96]
args.kp_dim = 16
args.total_images = 22246
else:
......@@ -125,7 +125,7 @@ def train(args):
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
args.pretrained_model = './pretrained/resnet_50/115'
if args.pretrained_model:
def if_exist(var):
exist_flag = os.path.exists(os.path.join(args.pretrained_model, var.name))
......
# Copyright (c) 2019-present, Baidu, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Interface for evaluation."""
class BaseEvaluator(object):
def __init__(self, root, kp_dim):
"""
:param root: the root dir of dataset
:param kp_dim: the dimension of keypoints
"""
self.root = root
self.kp_dim = kp_dim
def evaluate(self, *args, **kwargs):
"""
Need Implementation for specific task / dataset
"""
raise NotImplementedError
# Copyright (c) 2019-present, Baidu, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Interface for COCO evaluation."""
import os
import json
import numpy as np
from collections import defaultdict
from collections import OrderedDict
import pickle
from utils.base_evaluator import BaseEvaluator
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from nms.nms import oks_nms
class COCOEvaluator(BaseEvaluator):
def __init__(self, root, kp_dim=17):
"""
:param root: the root dir of dataset
:param kp_dim: the dimension of keypoints
"""
super(COCOEvaluator, self).__init__(root, kp_dim)
self.kp_dim = kp_dim
self.in_vis_thre = 0.2
self.oks_thre = 0.9
self.coco = COCO(os.path.join(root, 'annotations', 'person_keypoints_val2017.json'))
cats = [cat['name']
for cat in self.coco.loadCats(self.coco.getCatIds())]
self.classes = ['__background__'] + cats
self.num_classes = len(self.classes)
self._class_to_ind = dict(zip(self.classes, range(self.num_classes)))
self._class_to_coco_ind = dict(zip(cats, self.coco.getCatIds()))
self._coco_ind_to_class_ind = dict([(self._class_to_coco_ind[cls],
self._class_to_ind[cls])
for cls in self.classes[1:]])
def evaluate(self, preds, output_dir, all_boxes, img_path, *args, **kwargs):
"""
:param preds: the predictions to be evaluated
:param output_dir: target directory to save evaluation results
:param all_boxes: per-prediction box info (center, scale, area, score)
:param img_path: paths of the original image
:return:
"""
res_folder = os.path.join(output_dir, 'results')
if not os.path.exists(res_folder):
os.makedirs(res_folder)
res_file = os.path.join(res_folder, 'keypoints_coco_results.json')
# person x (keypoints)
_kpts = []
for idx, kpt in enumerate(preds):
_kpts.append({
'keypoints': kpt,
'center': all_boxes[idx][0:2],
'scale': all_boxes[idx][2:4],
'area': all_boxes[idx][4],
'score': all_boxes[idx][5],
'image': int(img_path[idx][-16:-4])
})
# image x person x (keypoints)
kpts = defaultdict(list)
for kpt in _kpts:
kpts[kpt['image']].append(kpt)
# rescoring and oks nms
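# Final score per person = detector box score * mean confidence of the keypoints
# above in_vis_thre; OKS-NMS then drops duplicate skeletons within each image.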
kp_dim = self.kp_dim
in_vis_thre = self.in_vis_thre
oks_thre = self.oks_thre
oks_nmsed_kpts = []
for img in kpts.keys():
img_kpts = kpts[img]
for n_p in img_kpts:
box_score = n_p['score']
kpt_score = 0
valid_num = 0
for n_jt in range(0, kp_dim):
t_s = n_p['keypoints'][n_jt][2]
if t_s > in_vis_thre:
kpt_score = kpt_score + t_s
valid_num = valid_num + 1
if valid_num != 0:
kpt_score = kpt_score / valid_num
# rescoring
n_p['score'] = kpt_score * box_score
keep = oks_nms([img_kpts[i] for i in range(len(img_kpts))], oks_thre)
if len(keep) == 0:
oks_nmsed_kpts.append(img_kpts)
else:
oks_nmsed_kpts.append([img_kpts[_keep] for _keep in keep])
self._write_coco_keypoint_results(oks_nmsed_kpts, res_file)
info_str = self._do_python_keypoint_eval(res_file, res_folder)
name_value = OrderedDict(info_str)
return name_value, name_value['AP']
def _write_coco_keypoint_results(self, keypoints, res_file):
data_pack = [{'cat_id': self._class_to_coco_ind[cls],
'cls_ind': cls_ind,
'cls': cls,
'ann_type': 'keypoints',
'keypoints': keypoints
}
for cls_ind, cls in enumerate(self.classes) if not cls == '__background__']
results = self._coco_keypoint_results_one_category_kernel(data_pack[0])
with open(res_file, 'w') as f:
json.dump(results, f, sort_keys=True, indent=4)
try:
json.load(open(res_file))
except Exception:
content = []
with open(res_file, 'r') as f:
for line in f:
content.append(line)
content[-1] = ']'
with open(res_file, 'w') as f:
for c in content:
f.write(c)
def _coco_keypoint_results_one_category_kernel(self, data_pack):
cat_id = data_pack['cat_id']
keypoints = data_pack['keypoints']
cat_results = []
for img_kpts in keypoints:
if len(img_kpts) == 0:
continue
_key_points = np.array([img_kpts[k]['keypoints']
for k in range(len(img_kpts))])
key_points = np.zeros(
(_key_points.shape[0], self.kp_dim * 3), dtype=np.float)
for ipt in range(self.kp_dim):
key_points[:, ipt * 3 + 0] = _key_points[:, ipt, 0]
key_points[:, ipt * 3 + 1] = _key_points[:, ipt, 1]
key_points[:, ipt * 3 + 2] = _key_points[:, ipt, 2] # keypoints score.
result = [{'image_id': img_kpts[k]['image'],
'category_id': cat_id,
'keypoints': list(key_points[k]),
'score': img_kpts[k]['score'],
'center': list(img_kpts[k]['center']),
'scale': list(img_kpts[k]['scale'])
} for k in range(len(img_kpts))]
cat_results.extend(result)
return cat_results
def _do_python_keypoint_eval(self, res_file, res_folder):
coco_dt = self.coco.loadRes(res_file)
coco_eval = COCOeval(self.coco, coco_dt, 'keypoints')
coco_eval.params.useSegm = None
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
stats_names = ['AP', 'Ap .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5', 'AR .75', 'AR (M)', 'AR (L)']
info_str = []
for ind, name in enumerate(stats_names):
info_str.append((name, coco_eval.stats[ind]))
eval_file = os.path.join(res_folder, 'keypoints_val_results.pkl')
with open(eval_file, 'wb') as f:
pickle.dump(coco_eval, f, pickle.HIGHEST_PROTOCOL)
print('=> coco eval results saved to %s' % eval_file)
return info_str
# Copyright (c) 2019-present, Baidu, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Interface for building evaluator."""
from utils.coco_evaluator import COCOEvaluator
from utils.mpii_evaluator import MPIIEvaluator
evaluator_map = {
'coco': COCOEvaluator,
'mpii': MPIIEvaluator
}
def create_evaluator(dataset):
"""
:param dataset: specific dataset to be evaluated
:return:
"""
return evaluator_map[dataset]
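A minimal usage sketch of the builder (the data path, shapes and dummy predictions are assumptions for illustration; it expects the MPII annotations to be in place):
```python
import numpy as np
from utils.evaluator_builder import create_evaluator

# preds has shape (num_samples, kp_dim, 3): (x, y, confidence) per joint.
evaluator = create_evaluator('mpii')('data/mpii', kp_dim=16)
preds = np.zeros((2958, 16, 3), dtype=np.float32)  # dummy predictions
name_values, mean_pckh = evaluator.evaluate(preds, './')
```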
# Copyright (c) 2019-present, Baidu, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Interface for MPII evaluation."""
import os
import numpy as np
from collections import OrderedDict
from scipy.io import loadmat, savemat
from utils.base_evaluator import BaseEvaluator
class MPIIEvaluator(BaseEvaluator):
def __init__(self, root, kp_dim=16):
"""
:param root: the root dir of dataset
:param kp_dim: the dimension of keypoints
"""
super(MPIIEvaluator, self).__init__(root, kp_dim)
self.root = root
self.kp_dim = kp_dim
self.sc_bias = 0.6
self.threshold = 0.5
def evaluate(self, preds, output_dir, *args, **kwargs):
"""
:param preds: the predictions to be evaluated
:param output_dir: target directory to save evaluation results
:return:
"""
# Convert 0-based index to 1-based index
preds = preds[:, :, 0:2] + 1.0
if output_dir:
pred_file = os.path.join(output_dir, 'pred.mat')
savemat(pred_file, mdict={'preds': preds})
gt_file = os.path.join(self.root, 'annot', 'gt_valid.mat')
gt_dict = loadmat(gt_file)
dataset_joints = gt_dict['dataset_joints']
jnt_missing = gt_dict['jnt_missing']
pos_gt_src = gt_dict['pos_gt_src']
headboxes_src = gt_dict['headboxes_src']
pos_pred_src = np.transpose(preds, [1, 2, 0])
head = np.where(dataset_joints == 'head')[1][0]
lsho = np.where(dataset_joints == 'lsho')[1][0]
lelb = np.where(dataset_joints == 'lelb')[1][0]
lwri = np.where(dataset_joints == 'lwri')[1][0]
lhip = np.where(dataset_joints == 'lhip')[1][0]
lkne = np.where(dataset_joints == 'lkne')[1][0]
lank = np.where(dataset_joints == 'lank')[1][0]
rsho = np.where(dataset_joints == 'rsho')[1][0]
relb = np.where(dataset_joints == 'relb')[1][0]
rwri = np.where(dataset_joints == 'rwri')[1][0]
rkne = np.where(dataset_joints == 'rkne')[1][0]
rank = np.where(dataset_joints == 'rank')[1][0]
rhip = np.where(dataset_joints == 'rhip')[1][0]
jnt_visible = 1 - jnt_missing
uv_error = pos_pred_src - pos_gt_src
uv_err = np.linalg.norm(uv_error, axis=1)
headsizes = headboxes_src[1, :, :] - headboxes_src[0, :, :]
headsizes = np.linalg.norm(headsizes, axis=0)
headsizes *= self.sc_bias
scale = np.multiply(headsizes, np.ones((len(uv_err), 1)))
scaled_uv_err = np.divide(uv_err, scale)
scaled_uv_err = np.multiply(scaled_uv_err, jnt_visible)
jnt_count = np.sum(jnt_visible, axis=1)
less_than_threshold = np.multiply((scaled_uv_err <= self.threshold), jnt_visible)
PCKh = np.divide(100. * np.sum(less_than_threshold, axis=1), jnt_count)
# PCK at a range of thresholds (0 to 0.5); used below for Mean@0.1
rng = np.arange(0, 0.5 + 0.01, 0.01)
pckAll = np.zeros((len(rng), self.kp_dim))
for r in range(len(rng)):
thresh = rng[r]
less_than_threshold = np.multiply(scaled_uv_err <= thresh, jnt_visible)
pckAll[r, :] = np.divide(100. * np.sum(less_than_threshold, axis=1), jnt_count)
PCKh = np.ma.array(PCKh, mask=False)
PCKh.mask[6:8] = True
jnt_count = np.ma.array(jnt_count, mask=False)
jnt_count.mask[6:8] = True
jnt_ratio = jnt_count / np.sum(jnt_count).astype(np.float64)
name_value = [
('Head', PCKh[head]),
('Shoulder', 0.5 * (PCKh[lsho] + PCKh[rsho])),
('Elbow', 0.5 * (PCKh[lelb] + PCKh[relb])),
('Wrist', 0.5 * (PCKh[lwri] + PCKh[rwri])),
('Hip', 0.5 * (PCKh[lhip] + PCKh[rhip])),
('Knee', 0.5 * (PCKh[lkne] + PCKh[rkne])),
('Ankle', 0.5 * (PCKh[lank] + PCKh[rank])),
('Mean', np.sum(PCKh * jnt_ratio)),
('Mean@0.1', np.sum(pckAll[11, :] * jnt_ratio))
]
name_value = OrderedDict(name_value)
return name_value, name_value['Mean']
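For reference, the metric computed above is PCKh: joint $j$ of sample $n$ counts as correct when its pixel error, normalized by 0.6 times the head-box diagonal $d_n$, falls within the threshold $\tau$ ($\tau = 0.5$ for the headline numbers, $\tau \approx 0.1$ for Mean@0.1):

$$\mathrm{PCKh}_j(\tau) = 100 \cdot \frac{\sum_n v_{n,j}\,\mathbf{1}\!\left[\lVert p_{n,j} - g_{n,j}\rVert_2 \le \tau \cdot 0.6\, d_n\right]}{\sum_n v_{n,j}}$$

where $p$ and $g$ are the predicted and ground-truth joint positions and $v_{n,j}$ is the visibility mask.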
......@@ -18,36 +18,35 @@
import os
import argparse
import functools
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from collections import OrderedDict
from scipy.io import loadmat, savemat
from lib import pose_resnet
from utils.transforms import flip_back
from utils.utility import *
from utils.evaluator_builder import create_evaluator
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('dataset', str, 'mpii', "Dataset")
add_arg('batch_size', int, 128, "Minibatch size.")
add_arg('dataset', str, 'coco', "Dataset")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('num_epochs', int, 140, "Number of epochs.")
add_arg('total_images', int, 144406, "Training image number.")
add_arg('kp_dim', int, 16, "Class number.")
add_arg('model_save_dir', str, "output", "Model save directory")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('checkpoint', str, None, "Whether to resume checkpoint.")
add_arg('lr', float, 0.001, "Set learning rate.")
add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.")
add_arg('flip_test', bool, True, "Flip test")
add_arg('shift_heatmap', bool, True, "Shift heatmap")
add_arg('post_process', bool, True, "Post process")
add_arg('post_process', bool, True, "Post process")
add_arg('data_root', str, "data/coco", "Root directory of dataset")
# yapf: enable
def valid(args):
......@@ -57,14 +56,14 @@ def valid(args):
HEATMAP_SIZE = [72, 96]
FLIP_PAIRS = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
args.kp_dim = 17
args.total_images = 144406 # 149813
args.total_images = 6108
elif args.dataset == 'mpii':
import lib.mpii_reader as reader
IMAGE_SIZE = [384, 384]
HEATMAP_SIZE = [96, 96]
FLIP_PAIRS = [[0, 5], [1, 4], [2, 3], [10, 15], [11, 14], [12, 13]]
args.kp_dim = 16
args.total_images = 2958 # validation
args.total_images = 2958
else:
raise ValueError('The dataset {} is not supported yet.'.format(args.dataset))
......@@ -101,10 +100,11 @@ def valid(args):
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
args.pretrained_model = './pretrained/resnet_50/115'
if args.pretrained_model:
def if_exist(var):
exist_flag = os.path.exists(os.path.join(args.pretrained_model, var.name))
if exist_flag:
print("Copy pretrianed weights from: %s" % var.name)
return exist_flag
fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
......@@ -117,7 +117,7 @@ def valid(args):
valid_exe = fluid.ParallelExecutor(
use_cuda=True if args.use_gpu else False,
main_program=fluid.default_main_program().clone(for_test=False),
main_program=fluid.default_main_program().clone(for_test=True),
loss_name=loss.name)
fetch_list = [image.name, loss.name, output.name, target.name]
......@@ -127,11 +127,18 @@ def valid(args):
idx = 0
num_samples = args.total_images
all_preds = np.zeros((num_samples, args.kp_dim, 3),
dtype=np.float32)
all_preds = np.zeros((num_samples, args.kp_dim, 3), dtype=np.float32)
all_boxes = np.zeros((num_samples, 6))
for batch_id, data in enumerate(valid_reader()):
image_path = []
for batch_id, meta in enumerate(valid_reader()):
num_images = len(meta)
data = meta
if args.dataset == 'coco':
for i in range(num_images):
image_path.append(meta[i][-1])
data[i] = data[i][:-1]
num_images = len(data)
centers = []
......@@ -198,7 +205,6 @@ def valid(args):
all_boxes[idx:idx + num_images, 2:4] = scales[:, 0:2]
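# MPII/COCO convention: scale is in units of 200 px, so box area = prod(scale * 200).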
all_boxes[idx:idx + num_images, 4] = np.prod(scales*200, 1)
all_boxes[idx:idx + num_images, 5] = scores
# image_path.extend(meta['image'])
idx += num_images
......@@ -210,106 +216,11 @@ def valid(args):
save_batch_heatmaps(input_image, out_heatmaps, file_name='visualization@val.jpg', normalize=True)
# Evaluate
args.DATAROOT = 'data/mpii'
args.TEST_SET = 'valid'
output_dir = ''
filenames = []
imgnums = []
image_path = []
name_values, perf_indicator = mpii_evaluate(
args, all_preds, output_dir, all_boxes, image_path,
filenames, imgnums)
output_dir = './'
evaluator = create_evaluator(args.dataset)(args.data_root, args.kp_dim)
name_values, perf_indicator = evaluator.evaluate(all_preds, output_dir, all_boxes, image_path)
print_name_value(name_values, perf_indicator)
def mpii_evaluate(cfg, preds, output_dir, *args, **kwargs):
# Convert 0-based index to 1-based index
preds = preds[:, :, 0:2] + 1.0
if output_dir:
pred_file = os.path.join(output_dir, 'pred.mat')
savemat(pred_file, mdict={'preds': preds})
if 'test' in cfg.TEST_SET:
return {'Null': 0.0}, 0.0
SC_BIAS = 0.6
threshold = 0.5
gt_file = os.path.join(cfg.DATAROOT,
'annot',
'gt_{}.mat'.format(cfg.TEST_SET))
gt_dict = loadmat(gt_file)
dataset_joints = gt_dict['dataset_joints']
jnt_missing = gt_dict['jnt_missing']
pos_gt_src = gt_dict['pos_gt_src']
headboxes_src = gt_dict['headboxes_src']
pos_pred_src = np.transpose(preds, [1, 2, 0])
head = np.where(dataset_joints == 'head')[1][0]
lsho = np.where(dataset_joints == 'lsho')[1][0]
lelb = np.where(dataset_joints == 'lelb')[1][0]
lwri = np.where(dataset_joints == 'lwri')[1][0]
lhip = np.where(dataset_joints == 'lhip')[1][0]
lkne = np.where(dataset_joints == 'lkne')[1][0]
lank = np.where(dataset_joints == 'lank')[1][0]
rsho = np.where(dataset_joints == 'rsho')[1][0]
relb = np.where(dataset_joints == 'relb')[1][0]
rwri = np.where(dataset_joints == 'rwri')[1][0]
rkne = np.where(dataset_joints == 'rkne')[1][0]
rank = np.where(dataset_joints == 'rank')[1][0]
rhip = np.where(dataset_joints == 'rhip')[1][0]
jnt_visible = 1 - jnt_missing
uv_error = pos_pred_src - pos_gt_src
uv_err = np.linalg.norm(uv_error, axis=1)
headsizes = headboxes_src[1, :, :] - headboxes_src[0, :, :]
headsizes = np.linalg.norm(headsizes, axis=0)
headsizes *= SC_BIAS
scale = np.multiply(headsizes, np.ones((len(uv_err), 1)))
scaled_uv_err = np.divide(uv_err, scale)
scaled_uv_err = np.multiply(scaled_uv_err, jnt_visible)
jnt_count = np.sum(jnt_visible, axis=1)
less_than_threshold = np.multiply((scaled_uv_err <= threshold),
jnt_visible)
PCKh = np.divide(100.*np.sum(less_than_threshold, axis=1), jnt_count)
# Save
rng = np.arange(0, 0.5+0.01, 0.01)
pckAll = np.zeros((len(rng), cfg.kp_dim))
for r in range(len(rng)):
threshold = rng[r]
less_than_threshold = np.multiply(scaled_uv_err <= threshold,
jnt_visible)
pckAll[r, :] = np.divide(100.*np.sum(less_than_threshold, axis=1),
jnt_count)
PCKh = np.ma.array(PCKh, mask=False)
PCKh.mask[6:8] = True
jnt_count = np.ma.array(jnt_count, mask=False)
jnt_count.mask[6:8] = True
jnt_ratio = jnt_count / np.sum(jnt_count).astype(np.float64)
name_value = [
('Head', PCKh[head]),
('Shoulder', 0.5 * (PCKh[lsho] + PCKh[rsho])),
('Elbow', 0.5 * (PCKh[lelb] + PCKh[relb])),
('Wrist', 0.5 * (PCKh[lwri] + PCKh[rwri])),
('Hip', 0.5 * (PCKh[lhip] + PCKh[rhip])),
('Knee', 0.5 * (PCKh[lkne] + PCKh[rkne])),
('Ankle', 0.5 * (PCKh[lank] + PCKh[rank])),
('Mean', np.sum(PCKh * jnt_ratio)),
('Mean@0.1', np.sum(pckAll[11, :] * jnt_ratio))
]
name_value = OrderedDict(name_value)
return name_value, name_value['Mean']
# TODO: coco_evaluate()
if __name__ == '__main__':
args = parser.parse_args()
......