From 3c6263b3c5a53b2c212213cf2c5dac3c8f26ac6f Mon Sep 17 00:00:00 2001 From: zlyi1225 Date: Wed, 10 Apr 2019 15:49:54 +0800 Subject: [PATCH] fix bugs and add python3 supports (#2021) --- PaddleCV/human_pose_estimation/README.md | 12 ++-- PaddleCV/human_pose_estimation/README_cn.md | 14 ++-- PaddleCV/human_pose_estimation/test.py | 11 ++- PaddleCV/human_pose_estimation/train.py | 9 ++- .../utils/coco_evaluator.py | 2 +- .../human_pose_estimation/utils/nms_utils.py | 71 +++++++++++++++++++ PaddleCV/human_pose_estimation/val.py | 10 ++- 7 files changed, 112 insertions(+), 17 deletions(-) create mode 100644 PaddleCV/human_pose_estimation/utils/nms_utils.py diff --git a/PaddleCV/human_pose_estimation/README.md b/PaddleCV/human_pose_estimation/README.md index d629c6b7..c563ce31 100644 --- a/PaddleCV/human_pose_estimation/README.md +++ b/PaddleCV/human_pose_estimation/README.md @@ -9,13 +9,13 @@ This is a simple demonstration of re-implementation in [PaddlePaddle.Fluid](http ## Requirements - - Python == 2.7 - - PaddlePaddle >= 1.1.0 + - Python == 2.7 or 3.6 + - PaddlePaddle >= 1.1.0 (<= 1.3.0) - opencv-python >= 3.3 ## Environment -The code is developed and tested under 4 Tesla K40/P40 GPUS cards on CentOS with installed CUDA-9.2/8.0 and cuDNN-7.1. +The code is developed and tested under 4 Tesla K40/P40 GPUS cards on CentOS with installed CUDA-9.0/8.0 and cuDNN-7.0. ## Results on MPII Val | Arch | Head | Shoulder | Elbow | Wrist | Hip | Knee | Ankle | Mean | Mean@0.1| Models | @@ -85,19 +85,21 @@ python2 setup.py install --user Downloading the checkpoints of Pose-ResNet-50 trained on MPII dataset from [here](https://paddlemodels.bj.bcebos.com/pose/pose-resnet50-mpii-384x384.tar.gz). Extract it into the folder `checkpoints` under the directory root of this repo. Then run ```bash -python val.py --dataset 'mpii' --checkpoint 'checkpoints/pose-resnet50-mpii-384x384' +python val.py --dataset 'mpii' --checkpoint 'checkpoints/pose-resnet50-mpii-384x384' --data_root 'data/mpii' ``` ### Perform Training ```bash -python train.py --dataset 'mpii' # or coco +python train.py --dataset 'mpii' --data_root 'data/mpii' ``` **Note**: Configurations for training are aggregated in the `lib/mpii_reader.py` and `lib/coco_reader.py`. ### Perform Test on Images +We also support to apply pre-trained models on customized images. + Put the images into the folder `test` under the directory root of this repo. Then run ```bash diff --git a/PaddleCV/human_pose_estimation/README_cn.md b/PaddleCV/human_pose_estimation/README_cn.md index 08c77201..86f71aa5 100644 --- a/PaddleCV/human_pose_estimation/README_cn.md +++ b/PaddleCV/human_pose_estimation/README_cn.md @@ -9,10 +9,10 @@ ## 环境依赖 -本目录下的代码均在4卡Tesla K40/P40 GPU,CentOS系统,CUDA-9.2/8.0,cuDNN-7.1环境下测试运行无误 +本目录下的代码均在4卡Tesla K40/P40 GPU,CentOS系统,CUDA-9.0/8.0,cuDNN-7.0环境下测试运行无误 - - Python == 2.7 - - PaddlePaddle >= 1.1.0 + - Python == 2.7 / 3.6 + - PaddlePaddle >= 1.1.0 (<= 1.3.0) - opencv-python >= 3.3 ## MPII Val结果 @@ -83,19 +83,21 @@ python2 setup.py install --user 下载COCO/MPII预训练模型(见上表最后一列所附链接),保存到根目录下的'checkpoints'文件夹中,运行: ```bash -python val.py --dataset 'mpii' --checkpoint 'checkpoints/pose-resnet50-mpii-384x384' +python val.py --dataset 'mpii' --checkpoint 'checkpoints/pose-resnet50-mpii-384x384' --data_root 'data/mpii' ``` ### 模型训练 ```bash -python train.py --dataset 'mpii' # or coco +python train.py --dataset 'mpii' --data_root 'data/mpii' ``` **说明** 详细参数配置已保存到`lib/mpii_reader.py` 和 `lib/coco_reader.py`文件中,通过设置dataset来选择使用具体的参数配置 ### 模型测试(任意图片,使用上述COCO或MPII预训练好的模型) +同时,我们支持使用预训练好的关键点检测模型预测任意图片 + 将测试图片放入根目录下的'test'文件夹中,执行 ```bash @@ -104,4 +106,4 @@ python test.py --checkpoint 'checkpoints/pose-resnet-50-384x384-mpii' ## 引用 -- Simple Baselines for Human Pose Estimation and Tracking in PyTorch [`code`](https://github.com/Microsoft/human-pose-estimation.pytorch#data-preparation) +- Simple Baselines for Human Pose Estimation and Tracking in PyTorch [`code`](https://github.com/Microsoft/human-pose-estimation.pytorch#data-preparation) \ No newline at end of file diff --git a/PaddleCV/human_pose_estimation/test.py b/PaddleCV/human_pose_estimation/test.py index aebbe517..8ede66d2 100644 --- a/PaddleCV/human_pose_estimation/test.py +++ b/PaddleCV/human_pose_estimation/test.py @@ -15,7 +15,7 @@ """Functions for inference.""" -import os +import sys import argparse import functools import paddle @@ -34,13 +34,18 @@ add_arg('batch_size', int, 32, "Minibatch size.") add_arg('dataset', str, 'mpii', "Dataset") add_arg('use_gpu', bool, True, "Whether to use GPU or not.") add_arg('kp_dim', int, 16, "Class number.") -add_arg('model_save_dir', str, "output", "Model save directory") add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.") add_arg('checkpoint', str, None, "Whether to resume checkpoint.") add_arg('flip_test', bool, True, "Flip test") add_arg('shift_heatmap', bool, True, "Shift heatmap") # yapf: enable + +def print_immediately(s): + print(s) + sys.stdout.flush() + + def test(args): import lib.mpii_reader as reader if args.dataset == 'coco': @@ -89,6 +94,7 @@ def test(args): fetch_list = [image.name, output.name] for batch_id, data in enumerate(test_reader()): + print_immediately("Processing batch #%d" % batch_id) num_images = len(data) file_ids = [] @@ -124,6 +130,7 @@ def test(args): out_heatmaps = (out_heatmaps + output_flipped) * 0.5 save_predict_results(input_image, out_heatmaps, file_ids, fold_name='results') + if __name__ == '__main__': args = parser.parse_args() test(args) diff --git a/PaddleCV/human_pose_estimation/train.py b/PaddleCV/human_pose_estimation/train.py index 27d0aa8e..c6d75732 100644 --- a/PaddleCV/human_pose_estimation/train.py +++ b/PaddleCV/human_pose_estimation/train.py @@ -16,6 +16,7 @@ """Functions for training.""" import os +import sys import numpy as np import cv2 import paddle @@ -75,6 +76,12 @@ def optimizer_setting(args, params): return optimizer + +def print_immediately(s): + print(s) + sys.stdout.flush() + + def train(args): if args.dataset == 'coco': import lib.coco_reader as reader @@ -152,7 +159,7 @@ def train(args): loss = np.mean(np.array(loss)) - print('Epoch [{:4d}/{:3d}] LR: {:.10f} ' + print_immediately('Epoch [{:4d}/{:3d}] LR: {:.10f} ' 'Loss = {:.5f}'.format( batch_id, pass_id, current_lr[0], loss)) diff --git a/PaddleCV/human_pose_estimation/utils/coco_evaluator.py b/PaddleCV/human_pose_estimation/utils/coco_evaluator.py index 56a2d2d1..9afb2e22 100644 --- a/PaddleCV/human_pose_estimation/utils/coco_evaluator.py +++ b/PaddleCV/human_pose_estimation/utils/coco_evaluator.py @@ -24,9 +24,9 @@ from collections import OrderedDict import pickle from utils.base_evaluator import BaseEvaluator +from utils.nms_utils import oks_nms from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval -from nms.nms import oks_nms class COCOEvaluator(BaseEvaluator): diff --git a/PaddleCV/human_pose_estimation/utils/nms_utils.py b/PaddleCV/human_pose_estimation/utils/nms_utils.py new file mode 100644 index 00000000..ea72ddba --- /dev/null +++ b/PaddleCV/human_pose_estimation/utils/nms_utils.py @@ -0,0 +1,71 @@ +# Copyright (c) 2019-present, Baidu, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + + +def oks_iou(g, d, a_g, a_d, sigmas=None, in_vis_thre=None): + if not isinstance(sigmas, np.ndarray): + sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0 + vars = (sigmas * 2) ** 2 + xg = g[0::3] + yg = g[1::3] + vg = g[2::3] + ious = np.zeros((d.shape[0])) + for n_d in range(0, d.shape[0]): + xd = d[n_d, 0::3] + yd = d[n_d, 1::3] + vd = d[n_d, 2::3] + dx = xd - xg + dy = yd - yg + e = (dx ** 2 + dy ** 2) / vars / ((a_g + a_d[n_d]) / 2 + np.spacing(1)) / 2 + if in_vis_thre is not None: + ind = list(vg > in_vis_thre) and list(vd > in_vis_thre) + e = e[ind] + ious[n_d] = np.sum(np.exp(-e)) / e.shape[0] if e.shape[0] != 0 else 0.0 + return ious + + +def oks_nms(kpts_db, thresh, sigmas=None, in_vis_thre=None): + """ + greedily select boxes with high confidence and overlap with current maximum <= thresh + rule out overlap >= thresh, overlap = oks + :param kpts_db + :param thresh: retain overlap < thresh + :return: indexes to keep + """ + if len(kpts_db) == 0: + return [] + + scores = np.array([kpts_db[i]['score'] for i in range(len(kpts_db))]) + kpts = np.array([kpts_db[i]['keypoints'].flatten() for i in range(len(kpts_db))]) + areas = np.array([kpts_db[i]['area'] for i in range(len(kpts_db))]) + + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + + oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]], sigmas, in_vis_thre) + + inds = np.where(oks_ovr <= thresh)[0] + order = order[inds + 1] + + return keep \ No newline at end of file diff --git a/PaddleCV/human_pose_estimation/val.py b/PaddleCV/human_pose_estimation/val.py index a867a674..4224ec1e 100644 --- a/PaddleCV/human_pose_estimation/val.py +++ b/PaddleCV/human_pose_estimation/val.py @@ -16,6 +16,7 @@ """Functions for validation.""" import os +import sys import argparse import functools import paddle @@ -37,7 +38,6 @@ add_arg('use_gpu', bool, True, "Whether to use GPU or n add_arg('num_epochs', int, 140, "Number of epochs.") add_arg('total_images', int, 144406, "Training image number.") add_arg('kp_dim', int, 16, "Class number.") -add_arg('model_save_dir', str, "output", "Model save directory") add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.") add_arg('pretrained_model', str, None, "Whether to use pretrained model.") add_arg('checkpoint', str, None, "Whether to resume checkpoint.") @@ -49,6 +49,12 @@ add_arg('post_process', bool, True, "Post process") add_arg('data_root', str, "data/coco", "Root directory of dataset") # yapf: enable + +def print_immediately(s): + print(s) + sys.stdout.flush() + + def valid(args): if args.dataset == 'coco': import lib.coco_reader as reader @@ -208,7 +214,7 @@ def valid(args): idx += num_images - print('Epoch [{:4d}] ' + print_immediately('Epoch [{:4d}] ' 'Loss = {:.5f} ' 'Acc = {:.5f}'.format(batch_id, loss, acc.avg)) -- GitLab