未验证 提交 6b42963d 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update hand_pose_localization (#1967)

* update hand_pose_localization

* add clean func
上级 0a26a1fa
......@@ -3,74 +3,167 @@ import numpy as np
from paddle.inference import create_predictor, Config
__all__ = ['Model']
__all__ = ['InferenceModel']
class Model():
class InferenceModel:
# 初始化函数
def __init__(self, modelpath, use_gpu=False, use_mkldnn=True, combined=True):
# 加载模型预测器
self.predictor = self.load_model(modelpath, use_gpu, use_mkldnn, combined)
def __init__(self,
modelpath,
use_gpu=False,
gpu_id=0,
use_mkldnn=False,
cpu_threads=1):
'''
init the inference model
modelpath: inference model path
use_gpu: use gpu or not
use_mkldnn: use mkldnn or not
'''
# 加载模型配置
self.config = self.load_config(modelpath, use_gpu, gpu_id, use_mkldnn, cpu_threads)
# 获取模型的输入输出
self.input_names = self.predictor.get_input_names()
self.output_names = self.predictor.get_output_names()
self.input_handle = self.predictor.get_input_handle(self.input_names[0])
self.output_handle = self.predictor.get_output_handle(self.output_names[0])
# 打印函数
def __repr__(self):
'''
get the numbers and name of inputs and outputs
'''
return 'input_num: %d\ninput_names: %s\noutput_num: %d\noutput_names: %s' % (
self.input_num,
str(self.input_names),
self.output_num,
str(self.output_names)
)
# 模型加载函数
def load_model(self, modelpath, use_gpu, use_mkldnn, combined):
# 类调用函数
def __call__(self, *input_datas, batch_size=1):
'''
call function
'''
return self.forward(*input_datas, batch_size=batch_size)
# 模型参数加载函数
def load_config(self, modelpath, use_gpu, gpu_id, use_mkldnn, cpu_threads):
'''
load the model config
modelpath: inference model path
use_gpu: use gpu or not
use_mkldnn: use mkldnn or not
'''
# 对运行位置进行配置
if use_gpu:
try:
int(os.environ.get('CUDA_VISIBLE_DEVICES'))
except Exception:
print(
'Error! Unable to use GPU. Please set the environment variables "CUDA_VISIBLE_DEVICES=GPU_id" to use GPU.'
)
'''Error! Unable to use GPU. Please set the environment variables "CUDA_VISIBLE_DEVICES=GPU_id" to use GPU. Now switch to CPU to continue...''')
use_gpu = False
# 加载模型参数
if combined:
model = os.path.join(modelpath, "__model__")
params = os.path.join(modelpath, "__params__")
if os.path.isdir(modelpath):
if os.path.exists(os.path.join(modelpath, "__params__")):
# __model__ + __params__
model = os.path.join(modelpath, "__model__")
params = os.path.join(modelpath, "__params__")
config = Config(model, params)
elif os.path.exists(os.path.join(modelpath, "params")):
# model + params
model = os.path.join(modelpath, "model")
params = os.path.join(modelpath, "params")
config = Config(model, params)
elif os.path.exists(os.path.join(modelpath, "__model__")):
# __model__ + others
config = Config(modelpath)
else:
raise Exception(
"Error! Can\'t find the model in: %s. Please check your model path." % os.path.abspath(modelpath))
elif os.path.exists(modelpath + ".pdmodel"):
# *.pdmodel + *.pdiparams
model = modelpath + ".pdmodel"
params = modelpath + ".pdiparams"
config = Config(model, params)
elif isinstance(modelpath, Config):
config = modelpath
else:
config = Config(modelpath)
raise Exception(
"Error! Can\'t find the model in: %s. Please check your model path." % os.path.abspath(modelpath))
# 设置参数
if use_gpu:
config.enable_use_gpu(100, 0)
config.enable_use_gpu(100, gpu_id)
else:
config.disable_gpu()
config.set_cpu_math_library_num_threads(cpu_threads)
if use_mkldnn:
config.enable_mkldnn()
config.disable_glog_info()
config.switch_ir_optim(True)
config.enable_memory_optim()
config.switch_use_feed_fetch_ops(False)
config.switch_specify_input_names(True)
# 通过参数加载模型预测器
predictor = create_predictor(config)
# 返回配置
return config
# 返回预测器
return predictor
# 预测器创建函数
def eval(self):
'''
create the model predictor by model config
'''
# 创建预测器
self.predictor = create_predictor(self.config)
# 获取模型的输入输出名称
self.input_names = self.predictor.get_input_names()
self.output_names = self.predictor.get_output_names()
# 模型预测函数
def predict(self, input_datas):
outputs = []
# 获取模型的输入输出节点数量
self.input_num = len(self.input_names)
self.output_num = len(self.output_names)
# 获取输入
self.input_handles = []
for input_name in self.input_names:
self.input_handles.append(
self.predictor.get_input_handle(input_name))
# 获取输出
self.output_handles = []
for output_name in self.output_names:
self.output_handles.append(
self.predictor.get_output_handle(output_name))
# 前向计算函数
def forward(self, *input_datas, batch_size=1):
"""
model inference
batch_size: batch size
*input_datas: x1, x2, ..., xn
"""
# 切分输入数据
datas_num = input_datas[0].shape[0]
split_num = datas_num // batch_size + \
1 if datas_num % batch_size != 0 else datas_num // batch_size
input_datas = [np.array_split(input_data, split_num)
for input_data in input_datas]
# 遍历输入数据进行预测
for input_data in input_datas:
inputs = input_data.copy()
self.input_handle.copy_from_cpu(inputs)
outputs = {}
for step in range(split_num):
for i in range(self.input_num):
input_data = input_datas[i][step].copy()
self.input_handles[i].copy_from_cpu(input_data)
self.predictor.run()
output = self.output_handle.copy_to_cpu()
outputs.append(output)
for i in range(self.output_num):
output = self.output_handles[i].copy_to_cpu()
if i in outputs:
outputs[i].append(output)
else:
outputs[i] = [output]
# 预测结果合并
outputs = np.concatenate(outputs, 0)
for key in outputs.keys():
outputs[key] = np.concatenate(outputs[key], 0)
outputs = [v for v in outputs.values()]
# 返回预测结果
return outputs
return tuple(outputs) if len(outputs) > 1 else outputs[0]
\ No newline at end of file
# coding=utf-8
import os
from paddlehub import Module
import numpy as np
from paddlehub.module.module import moduleinfo, serving
from hand_pose_localization.model import Model
from hand_pose_localization.processor import base64_to_cv2, Processor
from .model import InferenceModel
from .processor import base64_to_cv2, Processor
@moduleinfo(
......@@ -14,16 +14,18 @@ from hand_pose_localization.processor import base64_to_cv2, Processor
author="jm12138", # 作者名称
author_email="jm12138@qq.com", # 作者邮箱
summary="hand_pose_localization", # 模型介绍
version="1.0.2" # 版本号
version="1.1.0" # 版本号
)
class Hand_Pose_Localization(Module):
class Hand_Pose_Localization:
# 初始化函数
def __init__(self, name=None, use_gpu=False):
def __init__(self, use_gpu=False):
# 设置模型路径
self.model_path = os.path.join(self.directory, "hand_pose_localization")
self.model_path = os.path.join(self.directory, "hand_pose_localization", "model")
# 加载模型
self.model = Model(modelpath=self.model_path, use_gpu=use_gpu, use_mkldnn=False, combined=True)
self.model = InferenceModel(modelpath=self.model_path, use_gpu=use_gpu)
self.model.eval()
# 关键点检测函数
def keypoint_detection(self, images=None, paths=None, batch_size=1, output_dir='output', visualization=False):
......@@ -31,7 +33,11 @@ class Hand_Pose_Localization(Module):
processor = Processor(images, paths, batch_size, output_dir)
# 模型预测
outputs = self.model.predict(processor.input_datas)
outputs = []
for input_data in processor.input_datas:
output = self.model(input_data)
outputs.append(output)
outputs = np.concatenate(outputs, 0)
# 结果后处理
results = processor.postprocess(outputs, visualization)
......
......@@ -130,8 +130,10 @@
适配paddlehub 2.0
* 1.1.0
* ```shell
$ hub install hand_pose_localization==1.0.1
$ hub install hand_pose_localization==1.1.0
```
import os
import shutil
import unittest
import cv2
import requests
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/8UAUuP97RlY/download?ixid=MnwxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjYxODQxMzI1&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="hand_pose_localization")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('output')
def test_keypoint_detection1(self):
results = self.module.keypoint_detection(
paths=['tests/test.jpg'],
visualization=False
)
kps = results[0]
self.assertIsInstance(kps, list)
def test_keypoint_detection2(self):
results = self.module.keypoint_detection(
images=[cv2.imread('tests/test.jpg')],
visualization=False
)
kps = results[0]
self.assertIsInstance(kps, list)
def test_keypoint_detection3(self):
results = self.module.keypoint_detection(
images=[cv2.imread('tests/test.jpg')],
visualization=True
)
kps = results[0]
self.assertIsInstance(kps, list)
def test_keypoint_detection4(self):
self.module = hub.Module(name="hand_pose_localization", use_gpu=True)
results = self.module.keypoint_detection(
images=[cv2.imread('tests/test.jpg')],
visualization=False
)
kps = results[0]
self.assertIsInstance(kps, list)
def test_keypoint_detection5(self):
self.assertRaises(
AssertionError,
self.module.keypoint_detection,
paths=['no.jpg']
)
def test_keypoint_detection6(self):
self.assertRaises(
AttributeError,
self.module.keypoint_detection,
images=['test.jpg']
)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册