未验证 提交 879383e8 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: GitHub

add hub convert (#773)

上级 c728bd6e
......@@ -27,3 +27,4 @@ from . import config
from . import hub
from . import autofinetune
from . import serving
from . import convert
#coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
import tarfile
import shutil
from string import Template
from paddlehub.common import tmp_dir
from paddlehub.commands.base_command import BaseCommand, ENTRY
INIT_FILE = '__init__.py'
MODULE_FILE = 'module.py'
SERVING_FILE = 'serving_client_demo.py'
TMPL_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'tmpl')
class ConvertCommand(BaseCommand):
name = "convert"
def __init__(self, name):
super(ConvertCommand, self).__init__(name)
self.show_in_help = True
self.description = "Convert model to PaddleHub-Module."
self.parser = argparse.ArgumentParser(
description=self.__class__.__doc__,
prog='%s %s [COMMAND]' % (ENTRY, name),
usage='%(prog)s',
add_help=True)
self.parser.add_argument('command')
self.parser.add_argument('--module_name', '-n')
self.parser.add_argument(
'--module_version', '-v', nargs='?', default='1.0.0')
self.parser.add_argument('--model_dir', '-d')
self.parser.add_argument('--output_dir', '-o')
def create_module_tar(self):
if not os.path.exists(self.dest):
os.makedirs(self.dest)
tar_file = os.path.join(self.dest, '{}.tar.gz'.format(self.module))
with tarfile.open(tar_file, 'w:gz') as tfp:
tfp.add(self.dest, recursive=False, arcname=self.module)
for root, dir, files in os.walk(self.src):
for file in files:
fullpath = os.path.join(root, file)
arcname = os.path.join(self.module, 'assets', file)
tfp.add(fullpath, arcname=arcname)
tfp.add(
self.model_file, arcname=os.path.join(self.module, MODULE_FILE))
tfp.add(
self.serving_file,
arcname=os.path.join(self.module, SERVING_FILE))
tfp.add(
self.init_file, arcname=os.path.join(self.module, INIT_FILE))
def create_module_py(self):
template_file = open(os.path.join(TMPL_DIR, 'x_model.tmpl'), 'r')
tmpl = Template(template_file.read())
lines = []
lines.append(
tmpl.substitute(
NAME="'{}'".format(self.module),
TYPE="'CV'",
AUTHOR="'Baidu'",
SUMMARY="''",
VERSION="'{}'".format(self.version),
EMAIL="''"))
# self.model_file = os.path.join(self.dest, MODULE_FILE)
self.model_file = os.path.join(self._tmp_dir, MODULE_FILE)
if os.path.exists(self.model_file):
raise RuntimeError(
'File `{MODULE_FILE}` is already exists in src dir.'.format(
MODULE_FILE))
with open(self.model_file, 'w') as fp:
fp.writelines(lines)
def create_init_py(self):
# self.init_file = os.path.join(self.dest, INIT_FILE)
self.init_file = os.path.join(self._tmp_dir, INIT_FILE)
if os.path.exists(self.init_file):
return
shutil.copyfile(os.path.join(TMPL_DIR, 'init_py.tmpl'), self.init_file)
def create_serving_demo_py(self):
template_file = open(os.path.join(TMPL_DIR, 'serving_demo.tmpl'), 'r')
tmpl = Template(template_file.read())
lines = []
lines.append(tmpl.substitute(MODULE_NAME=self.module))
# self.serving_file = os.path.join(self.dest, SERVING_FILE)
self.serving_file = os.path.join(self._tmp_dir, SERVING_FILE)
if os.path.exists(self.serving_file):
raise RuntimeError(
'File `{}` is already exists in src dir.'.format(SERVING_FILE))
with open(self.serving_file, 'w') as fp:
fp.writelines(lines)
@staticmethod
def show_help():
str = "convert --module <module> [--version <version>] --dest dest_dir --src srd_dir\n"
str += "\tConvert model to PaddleHub-Module.\n"
str += "--model_dir\n"
str += "\tDir of model you want to export.\n"
str += "--module_name:\n"
str += "\tSet name of module.\n"
str += "--module_version\n"
str += "\tSet version of module, default is `1.0.0`.\n"
str += "--output_dir\n"
str += "\tDir to save PaddleHub-Module after exporting, default is `.`.\n"
print(str)
return
def execute(self, argv):
args = self.parser.parse_args()
if not args.module_name or not args.model_dir:
ConvertCommand.show_help()
return False
self.module = args.module_name
self.version = args.module_version if args.module_version is not None else '1.0.0'
self.src = args.model_dir
self.dest = args.output_dir if args.output_dir is not None else os.path.join(
'{}_{}'.format(self.module, str(time.time())))
os.makedirs(self.dest)
with tmp_dir() as _dir:
self._tmp_dir = _dir
self.create_module_py()
self.create_init_py()
self.create_serving_demo_py()
self.create_module_tar()
print('The converted module is stored in `{}`.'.format(self.dest))
return True
def run(self, module, version, src, dest):
self.module = module
self.version = version
self.src = src
self.dest = dest
os.makedirs(self.dest)
with tmp_dir() as _dir:
self._tmp_dir = _dir
self.create_module_py()
self.create_init_py()
self.create_serving_demo_py()
self.create_module_tar()
return True
command = ConvertCommand.instance()
if __name__ == '__main__':
command.run('test_module_name', '1.1.1', './new_model', './new_module')
# coding: utf8
import requests
import json
import cv2
import base64
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
if __name__ == '__main__':
# 获取图片的base64编码格式
img1 = cv2_to_base64(cv2.imread("IMAGE_PATH1"))
img2 = cv2_to_base64(cv2.imread("IMAGE_PATH2"))
data = {'images': [img1, img2]}
# 指定content-type
headers = {"Content-type": "application/json"}
# 发送HTTP请求
url = "http://127.0.0.1:8866/predict/${MODULE_NAME}"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# 打印预测结果
print(r.json()["results"])
from __future__ import absolute_import
from __future__ import division
import os
import cv2
import argparse
import base64
import paddlex as pdx
import numpy as np
import paddlehub as hub
from paddlehub.module.module import moduleinfo, runnable, serving
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
def cv2_to_base64(image):
return base64.b64encode(image)
# data = cv2.imencode('.jpg', image)[1]
# return base64.b64encode(data.tostring()).decode('utf8')
def read_images(paths):
images = []
for path in paths:
images.append(cv2.imread(path))
return images
@moduleinfo(
name=${NAME},
type=${TYPE},
author=${AUTHOR},
author_email=${EMAIL},
summary=${SUMMARY},
version=${VERSION})
class MODULE(hub.Module):
def _initialize(self, **kwargs):
self.default_pretrained_model_path = os.path.join(
self.directory, 'assets')
self.model = pdx.deploy.Predictor(self.default_pretrained_model_path,
**kwargs)
def predict(self,
images=None,
paths=None,
data=None,
batch_size=1,
use_gpu=False,
**kwargs):
all_data = images if images is not None else read_images(paths)
total_num = len(all_data)
loop_num = int(np.ceil(total_num / batch_size))
res = []
for iter_id in range(loop_num):
batch_data = list()
handle_id = iter_id * batch_size
for image_id in range(batch_size):
try:
batch_data.append(all_data[handle_id + image_id])
except IndexError:
break
out = self.model.batch_predict(batch_data, **kwargs)
res.extend(out)
return res
@serving
def serving_method(self, images, **kwargs):
"""
Run as a service.
"""
images_decode = [base64_to_cv2(image) for image in images]
results = self.predict(images_decode, **kwargs)
res = []
for result in results:
if isinstance(result, dict):
# result_new = dict()
for key, value in result.items():
if isinstance(value, np.ndarray):
result[key] = cv2_to_base64(value)
elif isinstance(value, np.generic):
result[key] = np.asscalar(value)
elif isinstance(result, list):
for index in range(len(result)):
for key, value in result[index].items():
if isinstance(value, np.ndarray):
result[index][key] = cv2_to_base64(value)
elif isinstance(value, np.generic):
result[index][key] = np.asscalar(value)
else:
raise RuntimeError('The result cannot be used in serving.')
res.append(result)
return res
@runnable
def run_cmd(self, argvs):
"""
Run as a command.
"""
self.parser = argparse.ArgumentParser(
description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(
title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
title="Config options",
description=
"Run configuration for controlling module behavior, not required.")
self.add_module_config_arg()
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
results = self.predict(
paths=[args.input_path],
use_gpu=args.use_gpu)
return results
def add_module_config_arg(self):
"""
Add the command config options.
"""
self.arg_config_group.add_argument(
'--use_gpu',
type=bool,
default=False,
help="whether use GPU or not")
def add_module_input_arg(self):
"""
Add the command input options.
"""
self.arg_input_group.add_argument(
'--input_path', type=str, help="path to image.")
if __name__ == '__main__':
module = MODULE(directory='./new_model')
images = [cv2.imread('./cat.jpg'), cv2.imread('./cat.jpg'), cv2.imread('./cat.jpg')]
res = module.predict(images=images)
......@@ -44,13 +44,23 @@ setup(
'paddlehub/serving/templates': [
'paddlehub/serving/templates/serving_config.json',
'paddlehub/serving/templates/main.html'
],
'paddlehub/command/tmpl': [
'paddlehub/command/tmpl/init_py.tmpl',
'paddlehub/command/tmpl/serving_demo.tmpl',
'paddlehub/command/tmpl/x_model.tmpl'
]
},
include_package_data=True,
data_files=[('paddlehub/serving/templates', [
'paddlehub/serving/templates/serving_config.json',
'paddlehub/serving/templates/main.html'
])],
]),
('paddlehub/commands/tmpl', [
'paddlehub/commands/tmpl/init_py.tmpl',
'paddlehub/commands/tmpl/serving_demo.tmpl',
'paddlehub/commands/tmpl/x_model.tmpl'
])],
include_data_files=True,
# PyPI package information.
classifiers=[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册