未验证 提交 d7ff713e 编写于 作者: Z zhiboniu 提交者: GitHub

add pose demo (#3042)

上级 34968d61
......@@ -6,17 +6,21 @@
- PaddleDetection KeyPoint部分紧跟业内最新最优算法方案,包含Top-Down、BottomUp两套方案,以满足用户的不同需求。
<div align="center">
<img src="./football_keypoint.gif" width='800'/>
</div>
#### Model Zoo
| 模型 | 输入尺寸 | 通道数 | AP(coco val) | 模型下载 | 配置文件 |
| :---------------- | -------- | ------ | :----------: | :----------------------------------------------------------: | ------------------------------------------------------------ |
| HigherHRNet | 512 | 32 | 67.1 | [higherhrnet_hrnet_w32_512.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml) |
| HigherHRNet | 640 | 32 | 68.3 | [higherhrnet_hrnet_w32_640.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_640.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_640.yml) |
| HigherHRNet+SWAHR | 512 | 32 | 68.9 | [higherhrnet_hrnet_w32_512_swahr.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512_swahr.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512_swahr.yml) |
| HRNet | 256x192 | 32 | 76.9 | [hrnet_w32_256x192.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/hrnet_w32_256x192.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/keypoint/hrnet/hrnet_w32_256x192.yml) |
| HRNet | 384x288 | 32 | 77.8 | [hrnet_w32_384x288.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/hrnet_w32_384x288.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/keypoint/hrnet/hrnet_w32_384x288.yml) |
| 模型 | 输入尺寸 | 通道数 | AP(coco val) | 模型下载 | 配置文件 |
| :---------------- | -------- | ------ | :----------: | :----------------------------------------------------------: | ----------------------------------------------------------- |
| HigherHRNet | 512 | 32 | 67.1 | [higherhrnet_hrnet_w32_512.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512.yml) |
| HigherHRNet | 640 | 32 | 68.3 | [higherhrnet_hrnet_w32_640.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_640.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_640.yml) |
| HigherHRNet+SWAHR | 512 | 32 | 68.9 | [higherhrnet_hrnet_w32_512_swahr.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512_swahr.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512_swahr.yml) |
| HRNet | 256x192 | 32 | 76.9 | [hrnet_w32_256x192.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/hrnet_w32_256x192.pdparams) | [config](./hrnet/hrnet_w32_256x192.yml) |
| HRNet | 384x288 | 32 | 77.8 | [hrnet_w32_384x288.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/hrnet_w32_384x288.pdparams) | [config](./hrnet/hrnet_w32_384x288.yml) |
......@@ -54,6 +58,8 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/keypoint/higherhrnet/hig
**模型预测:**
​ 注意:top-down模型只支持单人截图预测,如需使用多人图,请使用[联合部署推理]方式。或者使用bottom-up模型。
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/infer.py -c configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml -o weights=./output/higherhrnet_hrnet_w32_512/model_final.pdparams --infer_dir=../images/ --draw_threshold=0.5 --save_txt=True
```
......@@ -65,10 +71,10 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/infer.py -c configs/keypoint/higherhrnet/hi
python tools/export_model.py -c configs/keypoint/higherhrnet/higherhrnet_hrnet_w32_512.yml -o weights=output/higherhrnet_hrnet_w32_512/model_final.pdparams
#部署推理
#keypoint top-down/bottom-up 单独推理,图片
python deploy/python/keypoint_infer.py --model_dir=output_inference/higherhrnet_hrnet_w32_512/ --image_file=../images/xxx.jpeg --use_gpu=True --threshold=0.5
python deploy/python/keypoint_infer.py --model_dir=output_inference/hrnet_w32_384x288/ --image_file=../images/xxx.jpeg --use_gpu=True --threshold=0.5
#keypoint top-down/bottom-up 单独推理,该模式下top-down模型只支持单人截图预测。
python deploy/python/keypoint_infer.py --model_dir=output_inference/higherhrnet_hrnet_w32_512/ --image_file=./demo/000000014439_640x640.jpg --use_gpu=True --threshold=0.5
python deploy/python/keypoint_infer.py --model_dir=output_inference/hrnet_w32_384x288/ --image_file=./demo/hrnet_demo.jpg --use_gpu=True --threshold=0.5
#keypoint top-down + detector 与检测联合部署推理
#keypoint top-down模型 + detector 检测联合部署推理(联合推理只支持top-down方式)
python deploy/python/keypoint_det_unite_infer.py --det_model_dir=output_inference/ppyolo_r50vd_dcn_2x_coco/ --keypoint_model_dir=output_inference/hrnet_w32_384x288/ --video_file=../video/xxx.mp4
```
因为 它太大了无法显示 image diff 。你可以改为 查看blob
......@@ -26,19 +26,22 @@ from keypoint_infer import KeyPoint_Detector, PredictConfig_KeyPoint
from keypoint_visualize import draw_pose
def expand_crop(images, rect, expand_ratio=0.5):
def expand_crop(images, rect, expand_ratio=0.3):
imgh, imgw, c = images.shape
label, _, xmin, ymin, xmax, ymax = [int(x) for x in rect.tolist()]
label, conf, xmin, ymin, xmax, ymax = [int(x) for x in rect.tolist()]
if label != 0:
return None, None
return None, None, None
org_rect = [xmin, ymin, xmax, ymax]
h_half = (ymax - ymin) * (1 + expand_ratio) / 2.
w_half = (xmax - xmin) * (1 + expand_ratio) / 2.
if h_half > w_half * 4 / 3:
w_half = h_half * 0.75
center = [(ymin + ymax) / 2., (xmin + xmax) / 2.]
ymin = max(0, int(center[0] - h_half))
ymax = min(imgh - 1, int(center[0] + h_half))
xmin = max(0, int(center[1] - w_half))
xmax = min(imgw - 1, int(center[1] + w_half))
return images[ymin:ymax, xmin:xmax, :], [xmin, ymin, xmax, ymax]
return images[ymin:ymax, xmin:xmax, :], [xmin, ymin, xmax, ymax], org_rect
def get_person_from_rect(images, results):
......@@ -46,12 +49,14 @@ def get_person_from_rect(images, results):
mask = det_results[:, 1] > FLAGS.det_threshold
valid_rects = det_results[mask]
image_buff = []
org_rects = []
for rect in valid_rects:
rect_image, new_rect = expand_crop(images, rect)
rect_image, new_rect, org_rect = expand_crop(images, rect)
if rect_image is None:
continue
image_buff.append([rect_image, new_rect])
return image_buff
org_rects.append(org_rect)
return image_buff, org_rects
def affine_backto_orgimages(keypoint_result, batch_records):
......@@ -65,10 +70,10 @@ def topdown_unite_predict(detector, topdown_keypoint_detector, image_list):
for i, img_file in enumerate(image_list):
image, _ = decode_image(img_file, {})
results = detector.predict(image, FLAGS.det_threshold)
batchs_images = get_person_from_rect(image, results)
batchs_images, det_rects = get_person_from_rect(image, results)
keypoint_vector = []
score_vector = []
rect_vecotr = []
rect_vecotr = det_rects
for batch_images, batch_records in batchs_images:
keypoint_result = topdown_keypoint_detector.predict(
batch_images, FLAGS.keypoint_threshold)
......@@ -76,14 +81,18 @@ def topdown_unite_predict(detector, topdown_keypoint_detector, image_list):
batch_records)
keypoint_vector.append(orgkeypoints)
score_vector.append(scores)
rect_vecotr.append(batch_records)
keypoint_res = {}
keypoint_res['keypoint'] = [
np.vstack(keypoint_vector), np.vstack(score_vector)
]
keypoint_res['bbox'] = rect_vecotr
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
draw_pose(
img_file, keypoint_res, visual_thread=FLAGS.keypoint_threshold)
img_file,
keypoint_res,
visual_thread=FLAGS.keypoint_threshold,
save_dir=FLAGS.output_dir)
def topdown_unite_predict_video(detector, topdown_keypoint_detector, camera_id):
......@@ -92,8 +101,8 @@ def topdown_unite_predict_video(detector, topdown_keypoint_detector, camera_id):
video_name = 'output.mp4'
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.basename(
os.path.split(FLAGS.video_file + '.mp4')[-1])
video_name = os.path.splitext(os.path.basename(FLAGS.video_file))[
0] + '.mp4'
fps = 30
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
......@@ -114,10 +123,9 @@ def topdown_unite_predict_video(detector, topdown_keypoint_detector, camera_id):
frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
results = detector.predict(frame2, FLAGS.det_threshold)
batchs_images = get_person_from_rect(frame, results)
batchs_images, rect_vecotr = get_person_from_rect(frame2, results)
keypoint_vector = []
score_vector = []
rect_vecotr = []
for batch_images, batch_records in batchs_images:
keypoint_result = topdown_keypoint_detector.predict(
batch_images, FLAGS.keypoint_threshold)
......@@ -125,7 +133,6 @@ def topdown_unite_predict_video(detector, topdown_keypoint_detector, camera_id):
batch_records)
keypoint_vector.append(orgkeypoints)
score_vector.append(scores)
rect_vecotr.append(batch_records)
keypoint_res = {}
keypoint_res['keypoint'] = [
np.vstack(keypoint_vector), np.vstack(score_vector)
......
......@@ -332,7 +332,13 @@ def predict_image(detector, image_list):
print('Test iter {}, file name:{}'.format(i, img_file))
else:
results = detector.predict(img_file, FLAGS.threshold)
draw_pose(img_file, results, visual_thread=FLAGS.threshold)
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
draw_pose(
img_file,
results,
visual_thread=FLAGS.threshold,
save_dir=FLAGS.output_dir)
def predict_video(detector, camera_id):
......
......@@ -28,6 +28,7 @@ def draw_pose(imgfile,
results,
visual_thread=0.6,
save_name='pose.jpg',
save_dir='output',
returnimg=False):
try:
import matplotlib.pyplot as plt
......@@ -56,8 +57,7 @@ def draw_pose(imgfile,
bboxs = results['bbox']
for idx, rect in enumerate(bboxs):
xmin, ymin, xmax, ymax = rect
cv2.rectangle(img, (xmin, ymin), (xmax, ymax),
colors[idx % len(colors)], 2)
cv2.rectangle(img, (xmin, ymin), (xmax, ymax), colors[0], 1)
canvas = img.copy()
for i in range(17):
......@@ -100,7 +100,8 @@ def draw_pose(imgfile,
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
if returnimg:
return canvas
save_name = 'output/' + os.path.basename(imgfile)[:-4] + '_vis.jpg'
save_name = os.path.join(
save_dir, os.path.splitext(os.path.basename(imgfile))[0] + '_vis.jpg')
plt.imsave(save_name, canvas[:, :, ::-1])
print("keypoint visualize image saved to: " + save_name)
plt.close()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册