From c7421e641394f5149ed1e0dbd3d6f9f8eab35c74 Mon Sep 17 00:00:00 2001 From: jm12138 <2286040843@qq.com> Date: Mon, 16 Nov 2020 11:42:54 +0800 Subject: [PATCH] Add a hand_pose_localization module --- .../hand_pose_localization/README.md | 112 +++++++++++++++ .../hand_pose_localization/model.py | 70 ++++++++++ .../hand_pose_localization/module.py | 56 ++++++++ .../hand_pose_localization/processor.py | 127 ++++++++++++++++++ 4 files changed, 365 insertions(+) create mode 100644 hub_module/modules/image/keypoint_detection/hand_pose_localization/README.md create mode 100644 hub_module/modules/image/keypoint_detection/hand_pose_localization/model.py create mode 100644 hub_module/modules/image/keypoint_detection/hand_pose_localization/module.py create mode 100644 hub_module/modules/image/keypoint_detection/hand_pose_localization/processor.py diff --git a/hub_module/modules/image/keypoint_detection/hand_pose_localization/README.md b/hub_module/modules/image/keypoint_detection/hand_pose_localization/README.md new file mode 100644 index 00000000..eb143ada --- /dev/null +++ b/hub_module/modules/image/keypoint_detection/hand_pose_localization/README.md @@ -0,0 +1,112 @@ +## 模型概述 +openpose 手部关键点检测模型 + +模型详情请参考[openpose开源项目](https://github.com/CMU-Perceptual-Computing-Lab/openpose) + +## 模型安装 + +```shell +$hub install hand_pose_localization +``` + +## API 说明 + +```python +def keypoint_detection( + self, + images=None, + paths=None, + batch_size=1, + output_dir='output', + visualization=False +) +``` + +预测API,识别出人体手部关键点。 + +![手部关键点](https://ai-studio-static-online.cdn.bcebos.com/97e1ae7c1e68477d85b37f53ee997fbc4ef0fc12c7634301bc08749bd003cac0) + +**参数** + +* images (list\[numpy.ndarray\]): 图片数据,ndarray.shape 为 \[H, W, C\], 默认设为 None; +* paths (list\[str\]): 图片的路径, 默认设为 None; +* batch\_size (int): batch 的大小,默认设为 1; +* visualization (bool): 是否将识别结果保存为图片文件,默认设为 False; +* output\_dir (str): 图片的保存路径,默认设为 output。 + +**返回** + +* res (list[list[list[int]]]): 
class Model():
    """Wrapper around a Paddle inference predictor loaded from a model directory.

    The directory is expected to contain the files ``__model__`` and
    ``__params__``. Only the first input and the first output of the
    predictor are used.
    """

    def __init__(self, modelpath, use_gpu):
        """Build the predictor and cache handles to its first input/output tensor.

        Args:
            modelpath: Directory holding ``__model__`` and ``__params__``.
            use_gpu: Whether to try running on GPU (see ``load_model``).
        """
        predictor = self.load_model(modelpath, use_gpu)
        self.predictor = predictor

        # Names of every input/output, plus tensor handles for the first of each.
        self.input_names = predictor.get_input_names()
        self.output_names = predictor.get_output_names()
        self.input_tensor = predictor.get_input_tensor(self.input_names[0])
        self.output_tensor = predictor.get_output_tensor(self.output_names[0])

    def load_model(self, modelpath, use_gpu):
        """Configure and create the predictor.

        When ``use_gpu`` is true the device id is read from the
        ``CUDA_VISIBLE_DEVICES`` environment variable; if that fails for any
        reason an error line is printed and the CPU path is used instead.

        Args:
            modelpath: Directory holding ``__model__`` and ``__params__``.
            use_gpu: Whether GPU execution was requested.

        Returns:
            The object returned by ``create_paddle_predictor``.
        """
        device_id = None
        if use_gpu:
            try:
                visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
                # NOTE(review): only the first character is parsed, so a value
                # such as "10" yields 1 -- confirm this is the intended id.
                device_id = int(visible_devices[0])
            except Exception as e:
                print('Error: %s. Please set the environment variables "CUDA_VISIBLE_DEVICES".' % e)
                use_gpu = False

        # Locations of the serialized program and its parameters.
        model_file = os.path.join(modelpath, "__model__")
        params_file = os.path.join(modelpath, "__params__")
        config = AnalysisConfig(model_file, params_file)

        if not use_gpu:
            config.disable_gpu()
            config.enable_mkldnn()
        else:
            # First argument is presumably the initial GPU memory pool size
            # in MB -- TODO confirm against the Paddle API.
            config.enable_use_gpu(100, device_id)

        config.disable_glog_info()
        config.switch_ir_optim(True)
        config.switch_use_feed_fetch_ops(False)
        config.switch_specify_input_names(True)

        return create_paddle_predictor(config)

    def predict(self, input_datas):
        """Run every batch through the predictor and join the results.

        Args:
            input_datas: Iterable of batches; each one is copied into the
                input tensor as-is.

        Returns:
            The per-batch outputs concatenated along axis 0.
        """

        def _run_batch(batch):
            # Feed one batch, execute, and fetch the first output.
            self.input_tensor.copy_from_cpu(batch)
            self.predictor.zero_copy_run()
            return self.output_tensor.copy_to_cpu()

        batch_outputs = [_run_batch(batch) for batch in input_datas]
        return np.concatenate(batch_outputs, 0)
@moduleinfo(
    name="hand_pose_localization",  # module name
    type="CV/keypoint_detection",  # module type
    author="jm12138",  # author name
    author_email="jm12138@qq.com",  # author e-mail
    summary="hand_pose_localization",  # short description
    version="1.0.0"  # version number
)
class Hand_Pose_Localization(Module):
    """PaddleHub module exposing hand keypoint detection."""

    def _initialize(self, use_gpu=False):
        """Locate the bundled model directory and load the predictor.

        Args:
            use_gpu: Forwarded to ``Model``; whether to request GPU inference.
        """
        # The model files live in a sub-directory of the module directory.
        self.model_path = os.path.join(self.directory, "hand_pose_localization")
        self.model = Model(self.model_path, use_gpu)

    def keypoint_detection(
        self,
        images=None,
        paths=None,
        batch_size=1,
        output_dir='output',
        visualization=False
    ):
        """Detect hand keypoints on the given images.

        Args:
            images: Optional list of image arrays.
            paths: Optional list of image file paths.
            batch_size: Batch size used when splitting the inputs.
            output_dir: Directory for visualization images.
            visualization: Whether to save annotated images to ``output_dir``.

        Returns:
            The result of ``Processor.postprocess``: one list of keypoints
            per input image.
        """
        # The processor loads and pre-processes the inputs on construction.
        processor = Processor(images, paths, batch_size, output_dir)
        outputs = self.model.predict(processor.input_datas)
        results = processor.postprocess(outputs, visualization)
        return results

    @serving
    def serving_method(self, images, **kwargs):
        """Hub Serving entry point: decode base64 images, then detect keypoints.

        Args:
            images: Iterable of base64-encoded image strings.
            **kwargs: Forwarded to ``keypoint_detection``.

        Returns:
            The value returned by ``keypoint_detection``.
        """
        images_decode = [base64_to_cv2(image) for image in images]
        results = self.keypoint_detection(images_decode, **kwargs)
        return results


def check_dir(dir_path):
    """Make sure ``dir_path`` exists as a directory.

    A regular file occupying the path is deleted and replaced by a
    directory; an existing directory is left untouched.

    Args:
        dir_path: Path of the directory to ensure.
    """
    if os.path.isfile(dir_path):
        os.remove(dir_path)
    # exist_ok keeps the call idempotent when the directory is already there.
    os.makedirs(dir_path, exist_ok=True)


def base64_to_cv2(b64str):
    """Decode a base64-encoded image string into an OpenCV colour image.

    Args:
        b64str: Base64 text of an encoded image (e.g. JPEG bytes).

    Returns:
        The result of ``cv2.imdecode`` with ``cv2.IMREAD_COLOR``.
    """
    raw = base64.b64decode(b64str.encode('utf8'))
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    buf = np.frombuffer(raw, np.uint8)
    return cv2.imdecode(buf, cv2.IMREAD_COLOR)
class Processor():
    """Load, pre-process and post-process images for the hand keypoint model.

    Construction loads the inputs (``self.datas``) and builds the batched
    network inputs (``self.input_datas``).
    """

    def __init__(self, images=None, paths=None, batch_size=1, output_dir='output'):
        """Store the settings, then load and pre-process the inputs.

        Args:
            images: Optional list of image arrays of shape (H, W, C). When
                given, these are used instead of anything loaded from ``paths``.
            paths: Optional list of image file paths.
            batch_size: Batch size used to decide how many chunks to split into.
            output_dir: Directory where visualization images are written.

        Raises:
            ValueError: If no input image is available, or a file in
                ``paths`` cannot be decoded.
            AssertionError: If an entry of ``paths`` is not an existing file.
        """
        # Fixed model settings: number of keypoints, network input height,
        # and the score above which a keypoint counts as detected.
        self.num_points = 21
        self.inHeight = 368
        self.threshold = 0.1
        # Index pairs of keypoints joined by a line when drawing.
        self.point_pairs = [[0, 1], [1, 2], [2, 3], [3, 4],
                            [0, 5], [5, 6], [6, 7], [7, 8],
                            [0, 9], [9, 10], [10, 11], [11, 12],
                            [0, 13], [13, 14], [14, 15], [15, 16],
                            [0, 17], [17, 18], [18, 19], [19, 20]]

        self.images = images
        self.paths = paths
        self.batch_size = batch_size
        self.output_dir = output_dir

        self.datas = self.load_datas()
        self.input_datas = self.preprocess()

    def load_datas(self):
        """Collect the raw input images.

        Files listed in ``self.paths`` are read with ``cv2.imread`` and cast
        to float32. If ``self.images`` is not None it is returned instead,
        unchanged.

        Returns:
            The list of images to process.

        Raises:
            AssertionError: If a path is not an existing file.
            ValueError: If ``cv2.imread`` cannot decode a file.
        """
        datas = []

        if self.paths is not None:
            for im_path in self.paths:
                assert os.path.isfile(im_path), "The {} isn't a valid file path.".format(im_path)
                im = cv2.imread(im_path)
                # imread signals an unreadable/undecodable file by returning None.
                if im is None:
                    raise ValueError("Failed to decode image file: {}".format(im_path))
                datas.append(im.astype('float32'))

        if self.images is not None:
            datas = self.images

        return datas

    def preprocess(self):
        """Convert every image to a network blob and split the blobs into chunks.

        Each image is resized to height ``self.inHeight`` with a width derived
        from its aspect ratio and scaled by 1/255. All blobs are concatenated
        along axis 0, so every image must produce the same blob width.

        Returns:
            A list of arrays produced by ``np.array_split``.

        Raises:
            ValueError: If there are no input images.
        """
        num_images = len(self.datas)
        if num_images == 0:
            raise ValueError("No input images: provide `images` or `paths`.")

        input_datas = []
        for img in self.datas:
            img_height, img_width, _ = img.shape
            aspect_ratio = img_width / img_height
            # Evaluates to floor(aspect_ratio * inHeight).
            # NOTE(review): possibly meant to snap to a multiple of 8 -- confirm.
            inWidth = int(((aspect_ratio * self.inHeight) * 8) // 8)
            inpBlob = cv2.dnn.blobFromImage(
                img, 1.0 / 255, (inWidth, self.inHeight), (0, 0, 0), swapRB=False, crop=False)
            input_datas.append(inpBlob)

        input_datas = np.concatenate(input_datas, 0)
        # Ceiling division: number of chunks needed for the requested batch size.
        split_num = -(-num_images // self.batch_size)
        return np.array_split(input_datas, split_num)

    def postprocess(self, outputs, visualization):
        """Turn the network score maps into keypoint coordinates.

        Args:
            outputs: Array indexed as ``outputs[image, keypoint, :, :]``.
            visualization: When true, an annotated copy of every image is
                written to ``self.output_dir``.

        Returns:
            One list per image holding ``self.num_points`` entries, each
            either ``[x, y]`` in original-image pixels or ``None`` when the
            peak score does not exceed ``self.threshold``.
        """
        if visualization:
            # Prepare the output directory once rather than once per image.
            check_dir(self.output_dir)

        all_points = []
        for im_id, img in enumerate(self.datas):
            img_height, img_width = img.shape[:2]
            points = []
            for idx in range(self.num_points):
                # Upscale the score map to image size and take its maximum.
                probMap = cv2.resize(outputs[im_id, idx, :, :], (img_width, img_height))
                _, prob, _, point = cv2.minMaxLoc(probMap)
                if prob > self.threshold:
                    points.append([int(point[0]), int(point[1])])
                else:
                    points.append(None)

            all_points.append(points)

            if visualization:
                self.vis_pose(img, points, im_id)

        return all_points

    def vis_pose(self, img, points, im_id):
        """Draw the detected skeleton and save it as a JPEG.

        Drawing happens on a copy so the caller's image is left unmodified.

        Args:
            img: Image the keypoints belong to.
            points: Keypoint list for this image (entries may be ``None``).
            im_id: Index of the image, used in the output file name.
        """
        canvas = img.copy()
        for partA, partB in self.point_pairs:
            # Only draw a segment when both of its endpoints were detected.
            if points[partA] and points[partB]:
                cv2.line(canvas, tuple(points[partA]), tuple(points[partB]), (0, 255, 255), 3)
                cv2.circle(canvas, tuple(points[partA]), 8, (0, 0, 255), thickness=-1, lineType=cv2.FILLED)

        # File name is "<image index>_<unix time in seconds>.jpg".
        cv2.imwrite(os.path.join(self.output_dir, '%d_%d.jpg' % (im_id, time.time())), canvas)