未验证 提交 17f205f8 编写于 作者: F freshield.eth 提交者: GitHub

add python example for face detection part (#7543)

* add python example for face detection part

为了方便二次开发,这里添加了直接使用Python脚本的形式进行人脸检测的示例。

* change the demo image
上级 e4eb82f4
......@@ -111,6 +111,58 @@ legend_name = 'Paddle-BlazeFace';
matlab -nodesktop -nosplash -nojvm -r "run wider_eval.m;quit;"
```
### Python脚本预测
为了支持二次开发,这里提供通过Python脚本使用Paddle Detection whl包来进行预测的示例。
```python
import cv2
import paddle
import numpy as np
from ppdet.core.workspace import load_config
from ppdet.engine import Trainer
from ppdet.metrics import get_infer_results
from ppdet.data.transform.operators import NormalizeImage, Permute
if __name__ == '__main__':
# 准备基础的参数
config_path = 'PaddleDetection/configs/face_detection/blazeface_1000e.yml'
cfg = load_config(config_path)
weight_path = 'PaddleDetection/output/blazeface_1000e.pdparams'
infer_img_path = 'PaddleDetection/demo/hrnet_demo.jpg'
cfg.weights = weight_path
bbox_thre = 0.8
paddle.set_device('gpu')
# 创建所需的类
trainer = Trainer(cfg, mode='test')
trainer.load_weights(cfg.weights)
trainer.model.eval()
normaler = NormalizeImage(mean=[123, 117, 104], std=[127.502231, 127.502231, 127.502231], is_scale=False)
permuter = Permute()
# 进行图片读取
im = cv2.imread(infer_img_path)
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
# 准备数据字典
data_dict = {'image': im}
data_dict = normaler(data_dict)
data_dict = permuter(data_dict)
h, w, c = im.shape
data_dict['im_id'] = paddle.Tensor(np.array([[0]]))
data_dict['im_shape'] = paddle.Tensor(np.array([[h, w]], dtype=np.float32))
data_dict['scale_factor'] = paddle.Tensor(np.array([[1., 1.]], dtype=np.float32))
data_dict['image'] = paddle.Tensor(data_dict['image'].reshape((1, c, h, w)))
data_dict['curr_iter'] = paddle.Tensor(np.array([0]))
# 进行预测
outs = trainer.model(data_dict)
# 对预测的数据进行后处理得到最终的bbox信息
for key in ['im_shape', 'scale_factor', 'im_id']:
outs[key] = data_dict[key]
for key, value in outs.items():
outs[key] = value.numpy()
clsid2catid, catid2name = {0: 'face'}, {0: 0}
batch_res = get_infer_results(outs, clsid2catid)
bbox = [sub_dict for sub_dict in batch_res['bbox'] if sub_dict['score'] > bbox_thre]
print(bbox)
```
## Citations
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册