Commit bef71145, authored by zlyi1225, committed by qingqing01

fix bugs and add python3 supports (#2019)

Parent 811fe5c5
...@@ -9,13 +9,13 @@ This is a simple demonstration of re-implementation in [PaddlePaddle.Fluid](http ...@@ -9,13 +9,13 @@ This is a simple demonstration of re-implementation in [PaddlePaddle.Fluid](http
## Requirements ## Requirements
- Python == 2.7 - Python == 2.7 or 3.6
- PaddlePaddle >= 1.1.0 - PaddlePaddle >= 1.1.0 (<= 1.3.0)
- opencv-python >= 3.3 - opencv-python >= 3.3
## Environment ## Environment
The code is developed and tested under 4 Tesla K40/P40 GPUS cards on CentOS with installed CUDA-9.2/8.0 and cuDNN-7.1. The code is developed and tested under 4 Tesla K40/P40 GPUS cards on CentOS with installed CUDA-9.0/8.0 and cuDNN-7.0.
## Results on MPII Val ## Results on MPII Val
| Arch | Head | Shoulder | Elbow | Wrist | Hip | Knee | Ankle | Mean | Mean@0.1| Models | | Arch | Head | Shoulder | Elbow | Wrist | Hip | Knee | Ankle | Mean | Mean@0.1| Models |
...@@ -85,19 +85,21 @@ python2 setup.py install --user ...@@ -85,19 +85,21 @@ python2 setup.py install --user
Downloading the checkpoints of Pose-ResNet-50 trained on MPII dataset from [here](https://paddlemodels.bj.bcebos.com/pose/pose-resnet50-mpii-384x384.tar.gz). Extract it into the folder `checkpoints` under the directory root of this repo. Then run Downloading the checkpoints of Pose-ResNet-50 trained on MPII dataset from [here](https://paddlemodels.bj.bcebos.com/pose/pose-resnet50-mpii-384x384.tar.gz). Extract it into the folder `checkpoints` under the directory root of this repo. Then run
```bash ```bash
python val.py --dataset 'mpii' --checkpoint 'checkpoints/pose-resnet50-mpii-384x384' python val.py --dataset 'mpii' --checkpoint 'checkpoints/pose-resnet50-mpii-384x384' --data_root 'data/mpii'
``` ```
### Perform Training ### Perform Training
```bash ```bash
python train.py --dataset 'mpii' # or coco python train.py --dataset 'mpii' --data_root 'data/mpii'
``` ```
**Note**: Configurations for training are aggregated in the `lib/mpii_reader.py` and `lib/coco_reader.py`. **Note**: Configurations for training are aggregated in the `lib/mpii_reader.py` and `lib/coco_reader.py`.
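The updated commands pass `--data_root` explicitly. The `add_arg` lines shown later in this diff give `dataset` a default of `'mpii'` and `data_root` a default of `"data/coco"`, so omitting the flag for MPII would presumably point the reader at the COCO directory. Below is a minimal sketch of that behaviour using plain `argparse`. It is a hypothetical stand-in, not the repo's actual `add_arg` helper, and it assumes the training script shares those defaults.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring two flags from the diff (not the repo's add_arg helper)."""
    parser = argparse.ArgumentParser(description="Sketch of the dataset flags")
    # Defaults copied from the add_arg lines visible in this diff.
    parser.add_argument('--dataset', type=str, default='mpii', help="Dataset")
    parser.add_argument('--data_root', type=str, default='data/coco',
                        help="Root directory of dataset")
    return parser


# Without --data_root, the MPII run still carries the COCO default path.
implicit = build_parser().parse_args(['--dataset', 'mpii'])
# With the flag, as in the updated README commands.
explicit = build_parser().parse_args(
    ['--dataset', 'mpii', '--data_root', 'data/mpii'])

print(implicit.dataset, implicit.data_root)
print(explicit.dataset, explicit.data_root)
```

Passing the flag on every invocation, as the README now does, avoids relying on a default that matches only one of the two datasets.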
### Perform Test on Images ### Perform Test on Images
We also support to apply pre-trained models on customized images.
Put the images into the folder `test` under the directory root of this repo. Then run Put the images into the folder `test` under the directory root of this repo. Then run
```bash ```bash
......
...@@ -9,10 +9,10 @@ ...@@ -9,10 +9,10 @@
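## Requirements ## Requirements
The code in this directory has been tested on 4 Tesla K40/P40 GPUs under CentOS with CUDA-9.2/8.0 and cuDNN-7.1 The code in this directory has been tested on 4 Tesla K40/P40 GPUs under CentOS with CUDA-9.0/8.0 and cuDNN-7.0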
- Python == 2.7 - Python == 2.7 / 3.6
- PaddlePaddle >= 1.1.0 - PaddlePaddle >= 1.1.0 (<= 1.3.0)
- opencv-python >= 3.3 - opencv-python >= 3.3
## Results on MPII Val
...@@ -83,19 +83,21 @@ python2 setup.py install --user ...@@ -83,19 +83,21 @@ python2 setup.py install --user
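Download the COCO/MPII pre-trained models (see the links in the last column of the table above), save them to the 'checkpoints' folder under the repo root, and run: Download the COCO/MPII pre-trained models (see the links in the last column of the table above), save them to the 'checkpoints' folder under the repo root, and run: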
```bash ```bash
python val.py --dataset 'mpii' --checkpoint 'checkpoints/pose-resnet50-mpii-384x384' python val.py --dataset 'mpii' --checkpoint 'checkpoints/pose-resnet50-mpii-384x384' --data_root 'data/mpii'
``` ```
### Model Training ### Model Training
```bash ```bash
python train.py --dataset 'mpii' # or coco python train.py --dataset 'mpii' --data_root 'data/mpii'
``` ```
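**Note**: Detailed parameter configurations are stored in `lib/mpii_reader.py` and `lib/coco_reader.py`; set `dataset` to select which configuration is used **Note**: Detailed parameter configurations are stored in `lib/mpii_reader.py` and `lib/coco_reader.py`; set `dataset` to select which configuration is used
### Model Testing (arbitrary images, using the COCO or MPII pre-trained models above) ### Model Testing (arbitrary images, using the COCO or MPII pre-trained models above)
We also support running the pre-trained keypoint detection models on arbitrary images
Put the test images into the 'test' folder under the repo root, then run Put the test images into the 'test' folder under the repo root, then run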
```bash ```bash
...@@ -104,4 +106,4 @@ python test.py --checkpoint 'checkpoints/pose-resnet-50-384x384-mpii' ...@@ -104,4 +106,4 @@ python test.py --checkpoint 'checkpoints/pose-resnet-50-384x384-mpii'
## Reference ## Reference
- Simple Baselines for Human Pose Estimation and Tracking in PyTorch [`code`](https://github.com/Microsoft/human-pose-estimation.pytorch#data-preparation) - Simple Baselines for Human Pose Estimation and Tracking in PyTorch [`code`](https://github.com/Microsoft/human-pose-estimation.pytorch#data-preparation)
\ No newline at end of file
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Functions for inference.""" """Functions for inference."""
import os import sys
import argparse import argparse
import functools import functools
import paddle import paddle
...@@ -34,13 +34,18 @@ add_arg('batch_size', int, 32, "Minibatch size.") ...@@ -34,13 +34,18 @@ add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('dataset', str, 'mpii', "Dataset") add_arg('dataset', str, 'mpii', "Dataset")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.") add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('kp_dim', int, 16, "Class number.") add_arg('kp_dim', int, 16, "Class number.")
add_arg('model_save_dir', str, "output", "Model save directory")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.") add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('checkpoint', str, None, "Whether to resume checkpoint.") add_arg('checkpoint', str, None, "Whether to resume checkpoint.")
add_arg('flip_test', bool, True, "Flip test") add_arg('flip_test', bool, True, "Flip test")
add_arg('shift_heatmap', bool, True, "Shift heatmap") add_arg('shift_heatmap', bool, True, "Shift heatmap")
# yapf: enable # yapf: enable
def print_immediately(s):
print(s)
sys.stdout.flush()
def test(args): def test(args):
import lib.mpii_reader as reader import lib.mpii_reader as reader
if args.dataset == 'coco': if args.dataset == 'coco':
...@@ -89,6 +94,7 @@ def test(args): ...@@ -89,6 +94,7 @@ def test(args):
fetch_list = [image.name, output.name] fetch_list = [image.name, output.name]
for batch_id, data in enumerate(test_reader()): for batch_id, data in enumerate(test_reader()):
print_immediately("Processing batch #%d" % batch_id)
num_images = len(data) num_images = len(data)
file_ids = [] file_ids = []
...@@ -124,6 +130,7 @@ def test(args): ...@@ -124,6 +130,7 @@ def test(args):
out_heatmaps = (out_heatmaps + output_flipped) * 0.5 out_heatmaps = (out_heatmaps + output_flipped) * 0.5
save_predict_results(input_image, out_heatmaps, file_ids, fold_name='results') save_predict_results(input_image, out_heatmaps, file_ids, fold_name='results')
if __name__ == '__main__': if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
test(args) test(args)
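The inference hunk above averages the normal prediction with the prediction on the mirrored image (`(out_heatmaps + output_flipped) * 0.5`). The step that maps the mirrored output back is not visible in this diff, so the following is only a sketch of how such a flip test is commonly done: mirror the heatmaps along the width axis and swap left/right joint channels. `flip_back`, `average_with_flip` and the `flip_pairs` values are hypothetical names and data, and the `shift_heatmap` option is not modeled here.

```python
import numpy as np


def flip_back(heatmaps, flip_pairs):
    """Mirror heatmaps of shape (N, K, H, W) along W and swap paired channels."""
    flipped = heatmaps[:, :, :, ::-1].copy()
    for left, right in flip_pairs:
        tmp = flipped[:, left].copy()
        flipped[:, left] = flipped[:, right]
        flipped[:, right] = tmp
    return flipped


def average_with_flip(out_heatmaps, raw_flipped_output, flip_pairs):
    """Average the normal output with the un-mirrored output of the flipped image."""
    output_flipped = flip_back(raw_flipped_output, flip_pairs)
    return (out_heatmaps + output_flipped) * 0.5


# Toy data: 2 images, 4 joints, 3x5 heatmaps; pairs are made up for illustration.
rng = np.random.RandomState(0)
heat = rng.rand(2, 4, 3, 5)
pairs = [(0, 1), (2, 3)]

# A perfectly symmetric model would return the mirrored, channel-swapped maps
# for the flipped image, so averaging should reproduce the original heatmaps.
raw_flipped = flip_back(heat, pairs)
averaged = average_with_flip(heat, raw_flipped, pairs)
print(np.allclose(averaged, heat))
```

In practice the two predictions differ slightly, and averaging them tends to smooth out left/right asymmetries in the model.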
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
"""Functions for training.""" """Functions for training."""
import os import os
import sys
import numpy as np import numpy as np
import cv2 import cv2
import paddle import paddle
...@@ -75,6 +76,12 @@ def optimizer_setting(args, params): ...@@ -75,6 +76,12 @@ def optimizer_setting(args, params):
return optimizer return optimizer
def print_immediately(s):
print(s)
sys.stdout.flush()
def train(args): def train(args):
if args.dataset == 'coco': if args.dataset == 'coco':
import lib.coco_reader as reader import lib.coco_reader as reader
...@@ -152,7 +159,7 @@ def train(args): ...@@ -152,7 +159,7 @@ def train(args):
loss = np.mean(np.array(loss)) loss = np.mean(np.array(loss))
print('Epoch [{:4d}/{:3d}] LR: {:.10f} ' print_immediately('Epoch [{:4d}/{:3d}] LR: {:.10f} '
'Loss = {:.5f}'.format( 'Loss = {:.5f}'.format(
batch_id, pass_id, current_lr[0], loss)) batch_id, pass_id, current_lr[0], loss))
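The `print_immediately` helper introduced in this commit prints and then flushes `sys.stdout`, which keeps log lines from sitting in a buffer when output is redirected to a file or pipe. The sketch below shows the same idea in a form that runs on both Python 2.7 and 3.6. The `stream` parameter and the `RecordingStream` class are hypothetical additions used only to make the flush observable.

```python
from __future__ import print_function

import sys


def print_immediately(s, stream=None):
    """Print `s`, then flush so the line is not held in a buffer."""
    # `stream` is a hypothetical extra parameter; the helper in the diff
    # always writes to sys.stdout.
    out = sys.stdout if stream is None else stream
    print(s, file=out)
    out.flush()


class RecordingStream(object):
    """Tiny fake stream that records writes and counts flush calls."""

    def __init__(self):
        self.chunks = []
        self.flushes = 0

    def write(self, text):
        self.chunks.append(text)

    def flush(self):
        self.flushes += 1


fake = RecordingStream()
print_immediately('Epoch [{:4d}/{:3d}] Loss = {:.5f}'.format(1, 0, 0.12345),
                  stream=fake)
print(''.join(fake.chunks), end='')
```

On Python 3.3 and later, `print(s, flush=True)` has the same effect; the explicit `sys.stdout.flush()` form is the one that also works on Python 2.7.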
......
...@@ -24,9 +24,9 @@ from collections import OrderedDict ...@@ -24,9 +24,9 @@ from collections import OrderedDict
import pickle import pickle
from utils.base_evaluator import BaseEvaluator from utils.base_evaluator import BaseEvaluator
from utils.nms_utils import oks_nms
from pycocotools.coco import COCO from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval from pycocotools.cocoeval import COCOeval
from nms.nms import oks_nms
class COCOEvaluator(BaseEvaluator): class COCOEvaluator(BaseEvaluator):
......
# Copyright (c) 2019-present, Baidu, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np


def oks_iou(g, d, a_g, a_d, sigmas=None, in_vis_thre=None):
    # Fall back to the COCO per-keypoint sigmas when none are given.
    if not isinstance(sigmas, np.ndarray):
        sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72,
                           .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
    vars = (sigmas * 2) ** 2
    xg = g[0::3]
    yg = g[1::3]
    vg = g[2::3]
    ious = np.zeros((d.shape[0]))
    for n_d in range(0, d.shape[0]):
        xd = d[n_d, 0::3]
        yd = d[n_d, 1::3]
        vd = d[n_d, 2::3]
        dx = xd - xg
        dy = yd - yg
        e = (dx ** 2 + dy ** 2) / vars / ((a_g + a_d[n_d]) / 2 + np.spacing(1)) / 2
        if in_vis_thre is not None:
            # Element-wise AND; `list(...) and list(...)` would only keep the
            # second mask whenever the first list is non-empty.
            ind = (vg > in_vis_thre) & (vd > in_vis_thre)
            e = e[ind]
        ious[n_d] = np.sum(np.exp(-e)) / e.shape[0] if e.shape[0] != 0 else 0.0
    return ious


def oks_nms(kpts_db, thresh, sigmas=None, in_vis_thre=None):
    """
    greedily select boxes with high confidence and overlap with current maximum <= thresh
    rule out overlap >= thresh, overlap = oks
    :param kpts_db
    :param thresh: retain overlap < thresh
    :return: indexes to keep
    """
    if len(kpts_db) == 0:
        return []
    scores = np.array([kpts_db[i]['score'] for i in range(len(kpts_db))])
    kpts = np.array([kpts_db[i]['keypoints'].flatten() for i in range(len(kpts_db))])
    areas = np.array([kpts_db[i]['area'] for i in range(len(kpts_db))])
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]], sigmas, in_vis_thre)
        inds = np.where(oks_ovr <= thresh)[0]
        order = order[inds + 1]
    return keep
\ No newline at end of file
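To illustrate what the new `utils/nms_utils.py` does, here is a condensed, self-contained sketch of greedy OKS-based suppression run on toy data. It follows the logic of the file above but drops the `in_vis_thre` filtering for brevity. The three poses, their scores and areas, and the `make_pose` helper are invented for the example.

```python
import numpy as np

# COCO per-keypoint sigmas, as in the file above.
SIGMAS = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72,
                   .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0


def oks_iou(g, d, a_g, a_d, sigmas=SIGMAS):
    """OKS between one pose `g` and each row of `d` (flattened x, y, v triples)."""
    variances = (sigmas * 2) ** 2
    xg, yg = g[0::3], g[1::3]
    ious = np.zeros(d.shape[0])
    for n in range(d.shape[0]):
        dx = d[n, 0::3] - xg
        dy = d[n, 1::3] - yg
        e = (dx ** 2 + dy ** 2) / variances / ((a_g + a_d[n]) / 2 + np.spacing(1)) / 2
        ious[n] = np.sum(np.exp(-e)) / e.shape[0]
    return ious


def oks_nms(kpts_db, thresh):
    """Greedy NMS: keep the best-scoring pose, drop poses whose OKS with it exceeds thresh."""
    if len(kpts_db) == 0:
        return []
    scores = np.array([k['score'] for k in kpts_db])
    kpts = np.array([k['keypoints'].flatten() for k in kpts_db])
    areas = np.array([k['area'] for k in kpts_db])
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]])
        order = order[np.where(ovr <= thresh)[0] + 1]
    return keep


def make_pose(base_xy, offset, score):
    """Hypothetical helper: 17 keypoints as (x, y, visibility) rows."""
    xyv = np.concatenate([base_xy + offset, np.ones((17, 1))], axis=1)
    return {'keypoints': xyv, 'score': score, 'area': 100.0 * 100.0}


base = np.random.RandomState(0).rand(17, 2) * 100.0
poses = [
    make_pose(base, 0.0, 0.9),    # reference pose
    make_pose(base, 1.0, 0.8),    # near-duplicate, shifted by one pixel
    make_pose(base, 200.0, 0.7),  # a different person far away
]
keep = oks_nms(poses, thresh=0.9)
print(keep)
```

The near-duplicate has an OKS close to 1 with the reference pose and is suppressed, while the distant pose has an OKS near 0 and survives.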
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
"""Functions for validation.""" """Functions for validation."""
import os import os
import sys
import argparse import argparse
import functools import functools
import paddle import paddle
...@@ -37,7 +38,6 @@ add_arg('use_gpu', bool, True, "Whether to use GPU or n ...@@ -37,7 +38,6 @@ add_arg('use_gpu', bool, True, "Whether to use GPU or n
add_arg('num_epochs', int, 140, "Number of epochs.") add_arg('num_epochs', int, 140, "Number of epochs.")
add_arg('total_images', int, 144406, "Training image number.") add_arg('total_images', int, 144406, "Training image number.")
add_arg('kp_dim', int, 16, "Class number.") add_arg('kp_dim', int, 16, "Class number.")
add_arg('model_save_dir', str, "output", "Model save directory")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.") add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.") add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('checkpoint', str, None, "Whether to resume checkpoint.") add_arg('checkpoint', str, None, "Whether to resume checkpoint.")
...@@ -49,6 +49,12 @@ add_arg('post_process', bool, True, "Post process") ...@@ -49,6 +49,12 @@ add_arg('post_process', bool, True, "Post process")
add_arg('data_root', str, "data/coco", "Root directory of dataset") add_arg('data_root', str, "data/coco", "Root directory of dataset")
# yapf: enable # yapf: enable
def print_immediately(s):
print(s)
sys.stdout.flush()
def valid(args): def valid(args):
if args.dataset == 'coco': if args.dataset == 'coco':
import lib.coco_reader as reader import lib.coco_reader as reader
...@@ -208,7 +214,7 @@ def valid(args): ...@@ -208,7 +214,7 @@ def valid(args):
idx += num_images idx += num_images
print('Epoch [{:4d}] ' print_immediately('Epoch [{:4d}] '
'Loss = {:.5f} ' 'Loss = {:.5f} '
'Acc = {:.5f}'.format(batch_id, loss, acc.avg)) 'Acc = {:.5f}'.format(batch_id, loss, acc.avg))
......