未验证 提交 3f787a82 编写于 作者: H haoyuying 提交者: GitHub

add convert and config

上级 12be5b91
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import re
import hashlib
import uuid
import time
from paddlehub.env import CONF_HOME
from paddlehub.commands import register
from paddlehub.utils.utils import md5
default_server_config = {
"server_url": ["http://paddlepaddle.org.cn/paddlehub"],
"resource_storage_server_url": "https://bj.bcebos.com/paddlehub-data/",
"debug": False,
"log_level": "DEBUG",
"hub_name": md5(str(uuid.uuid1())[-12:]) + "-" + str(int(time.time()))
}
@register(name='hub.config', description='Configure PaddleHub.')
class ConfigCommand:
@staticmethod
def show_config():
print("The current configuration is shown below.")
with open(os.path.join(CONF_HOME, "config.json"), "r") as fp:
print(json.dumps(json.load(fp), indent=4))
@staticmethod
def set_server_url(server_url):
with open(os.path.join(CONF_HOME, "config.json"), "r") as fp:
config = json.load(fp)
re_str = r"^(?:http(s)?:\/\/)?[\w.-]+(?:\.[\w\.-]+)+[\w\-\._~:/?#[\]@!\$&'\*\+,;=.]+$"
if re.match(re_str, server_url) is not None:
config["server_url"] = list([server_url])
ConfigCommand.set_config(config)
else:
print("The format of the input url is invalid.")
@staticmethod
def set_config(config):
with open(os.path.join(CONF_HOME, "config.json"), "w") as fp:
fp.write(json.dumps(config))
print("Set success! The current configuration is shown below.")
print(json.dumps(config, indent=4))
@staticmethod
def set_log_level(level):
level = str(level).upper()
if level not in ["NOLOG", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
print("Allowed values include: " "NOLOG, DEBUG, INFO, WARNING, ERROR, CRITICAL")
return
with open(os.path.join(CONF_HOME, "config.json"), "r") as fp:
current_config = json.load(fp)
with open(os.path.join(CONF_HOME, "config.json"), "w") as fp:
current_config["log_level"] = level
fp.write(json.dumps(current_config))
print("Set success! The current configuration is shown below.")
print(json.dumps(current_config, indent=4))
@staticmethod
def show_help():
str = "config <option>\n"
str += "\tShow PaddleHub config without any option.\n"
str += "option:\n"
str += "reset\n"
str += "\tReset config as default.\n"
str += "server==[URL]\n"
str += "\tSet PaddleHub Server url as [URL].\n"
str += "log==[LEVEL]\n"
str += "\tSet log level as [LEVEL:NOLOG, DEBUG, INFO, WARNING, ERROR, CRITICAL].\n"
print(str)
def execute(self, argv):
if not argv:
ConfigCommand.show_config()
for arg in argv:
if arg == "reset":
ConfigCommand.set_config(default_server_config)
elif arg.startswith("server=="):
ConfigCommand.set_server_url(arg.split("==")[1])
elif arg.startswith("log=="):
ConfigCommand.set_log_level(arg.split("==")[1])
else:
ConfigCommand.show_help()
return True
#coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
import tarfile
import shutil
from string import Template
from paddlehub.env import TMP_HOME as tmp_dir
from paddlehub.commands import register
from paddlehub.utils.xarfile import XarFile
INIT_FILE = '__init__.py'
MODULE_FILE = 'module.py'
SERVING_FILE = 'serving_client_demo.py'
TMPL_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'tmpl')
@register(name='hub.convert', description='Convert model to PaddleHub-Module.')
class ConvertCommand:
def __init__(self):
super(ConvertCommand, self).__init__()
self.parser = argparse.ArgumentParser()
self.parser.add_argument('command')
self.parser.add_argument('--module_name', '-n')
self.parser.add_argument('--module_version', '-v', nargs='?', default='1.0.0')
self.parser.add_argument('--model_dir', '-d')
self.parser.add_argument('--output_dir', '-o')
def create_module_tar(self):
if not os.path.exists(self.dest):
os.makedirs(self.dest)
tar_file = os.path.join(self.dest, '{}.tar.gz'.format(self.module))
tfp = XarFile(tar_file, 'w', 'tar.gz')
tfp.add(self.dest, self.module, False)
for root, dir, files in os.walk(self.src):
for file in files:
fullpath = os.path.join(root, file)
arcname = os.path.join(self.module, 'assets', file)
tfp.add(fullpath, arcname=arcname)
tfp.add(name=self.model_file, arcname=os.path.join(self.module, MODULE_FILE))
tfp.add(name=self.serving_file, arcname=os.path.join(self.module, SERVING_FILE))
tfp.add(name=self.init_file, arcname=os.path.join(self.module, INIT_FILE))
def create_module_py(self):
template_file = open(os.path.join(TMPL_DIR, 'x_model.tmpl'), 'r', encoding='utf-8')
tmpl = Template(template_file.read())
lines = []
lines.append(
tmpl.substitute(
NAME="'{}'".format(self.module),
TYPE="'CV'",
AUTHOR="'Baidu'",
SUMMARY="''",
VERSION="'{}'".format(self.version),
EMAIL="''"))
self.model_file = os.path.join(self._tmp_dir, MODULE_FILE)
with open(self.model_file, 'w', encoding='utf-8') as fp:
fp.writelines(lines)
def create_init_py(self):
self.init_file = os.path.join(self._tmp_dir, INIT_FILE)
if os.path.exists(self.init_file):
return
shutil.copyfile(os.path.join(TMPL_DIR, 'init_py.tmpl'), self.init_file)
def create_serving_demo_py(self):
template_file = open(os.path.join(TMPL_DIR, 'serving_demo.tmpl'), 'r', encoding='utf-8')
tmpl = Template(template_file.read())
lines = []
lines.append(tmpl.substitute(MODULE_NAME=self.module))
self.serving_file = os.path.join(self._tmp_dir, SERVING_FILE)
with open(self.serving_file, 'w', encoding='utf-8') as fp:
fp.writelines(lines)
@staticmethod
def show_help():
str = "convert --module <module> [--version <version>] --dest dest_dir --src srd_dir\n"
str += "\tConvert model to PaddleHub-Module.\n"
str += "--model_dir\n"
str += "\tDir of model you want to export.\n"
str += "--module_name:\n"
str += "\tSet name of module.\n"
str += "--module_version\n"
str += "\tSet version of module, default is `1.0.0`.\n"
str += "--output_dir\n"
str += "\tDir to save PaddleHub-Module after exporting, default is `.`.\n"
print(str)
return
def execute(self, argv):
args = self.parser.parse_args()
if not args.module_name or not args.model_dir:
ConvertCommand.show_help()
return False
self.module = args.module_name
self.version = args.module_version if args.module_version is not None else '1.0.0'
self.src = args.model_dir
if not os.path.isdir(self.src):
print('`{}` is not exists or not a directory path'.format(self.src))
return False
self.dest = args.output_dir if args.output_dir is not None else os.path.join('{}_{}'.format(
self.module, str(time.time())))
os.makedirs(self.dest)
self._tmp_dir = tmp_dir
self.create_module_py()
self.create_init_py()
self.create_serving_demo_py()
self.create_module_tar()
print('The converted module is stored in `{}`.'.format(self.dest))
return True
# coding: utf8
import requests
import json
import cv2
import base64
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
if __name__ == '__main__':
# 获取图片的base64编码格式
img1 = cv2_to_base64(cv2.imread("IMAGE_PATH1"))
img2 = cv2_to_base64(cv2.imread("IMAGE_PATH2"))
data = {'images': [img1, img2]}
# 指定content-type
headers = {"Content-type": "application/json"}
# 发送HTTP请求
url = "http://127.0.0.1:8866/predict/${MODULE_NAME}"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# 打印预测结果
print(r.json()["results"])
from __future__ import absolute_import
from __future__ import division
import os
import cv2
import argparse
import base64
import paddlex as pdx
import numpy as np
import paddlehub as hub
from paddlehub.module.module import moduleinfo, runnable, serving
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
def cv2_to_base64(image):
# return base64.b64encode(image)
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
def read_images(paths):
images = []
for path in paths:
images.append(cv2.imread(path))
return images
@moduleinfo(
name=${NAME},
type=${TYPE},
author=${AUTHOR},
author_email=${EMAIL},
summary=${SUMMARY},
version=${VERSION})
class MODULE(hub.Module):
def _initialize(self, **kwargs):
self.default_pretrained_model_path = os.path.join(
self.directory, 'assets')
self.model = pdx.deploy.Predictor(self.default_pretrained_model_path,
**kwargs)
def predict(self,
images=None,
paths=None,
data=None,
batch_size=1,
use_gpu=False,
**kwargs):
all_data = images if images is not None else read_images(paths)
total_num = len(all_data)
loop_num = int(np.ceil(total_num / batch_size))
res = []
for iter_id in range(loop_num):
batch_data = list()
handle_id = iter_id * batch_size
for image_id in range(batch_size):
try:
batch_data.append(all_data[handle_id + image_id])
except IndexError:
break
out = self.model.batch_predict(batch_data, **kwargs)
res.extend(out)
return res
@serving
def serving_method(self, images, **kwargs):
"""
Run as a service.
"""
images_decode = [base64_to_cv2(image) for image in images]
results = self.predict(images_decode, **kwargs)
res = []
for result in results:
if isinstance(result, dict):
# result_new = dict()
for key, value in result.items():
if isinstance(value, np.ndarray):
result[key] = cv2_to_base64(value)
elif isinstance(value, np.generic):
result[key] = np.asscalar(value)
elif isinstance(result, list):
for index in range(len(result)):
for key, value in result[index].items():
if isinstance(value, np.ndarray):
result[index][key] = cv2_to_base64(value)
elif isinstance(value, np.generic):
result[index][key] = np.asscalar(value)
else:
raise RuntimeError('The result cannot be used in serving.')
res.append(result)
return res
@runnable
def run_cmd(self, argvs):
"""
Run as a command.
"""
self.parser = argparse.ArgumentParser(
description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(
title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
title="Config options",
description=
"Run configuration for controlling module behavior, not required.")
self.add_module_config_arg()
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
results = self.predict(
paths=[args.input_path],
use_gpu=args.use_gpu)
return results
def add_module_config_arg(self):
"""
Add the command config options.
"""
self.arg_config_group.add_argument(
'--use_gpu',
type=bool,
default=False,
help="whether use GPU or not")
def add_module_input_arg(self):
"""
Add the command input options.
"""
self.arg_input_group.add_argument(
'--input_path', type=str, help="path to image.")
if __name__ == '__main__':
module = MODULE(directory='./new_model')
images = [cv2.imread('./cat.jpg'), cv2.imread('./cat.jpg'), cv2.imread('./cat.jpg')]
res = module.predict(images=images)
...@@ -41,16 +41,12 @@ class Version(packaging.version.Version): ...@@ -41,16 +41,12 @@ class Version(packaging.version.Version):
def match(self, condition: str) -> bool: def match(self, condition: str) -> bool:
''' '''
Determine whether the given condition are met Determine whether the given condition are met
Args: Args:
condition(str) : conditions for judgment condition(str) : conditions for judgment
Returns: Returns:
bool: True if the given version condition are met, else False bool: True if the given version condition are met, else False
Examples: Examples:
.. code-block:: python .. code-block:: python
Version('1.2.0').match('>=1.2.0a') Version('1.2.0').match('>=1.2.0a')
''' '''
if not condition: if not condition:
...@@ -184,14 +180,11 @@ def generate_tempdir(directory: str = None, **kwargs): ...@@ -184,14 +180,11 @@ def generate_tempdir(directory: str = None, **kwargs):
def download(url: str, path: str = None) -> str: def download(url: str, path: str = None) -> str:
''' '''
Download a file Download a file
Args: Args:
url (str) : url to be downloaded url (str) : url to be downloaded
path (str, optional) : path to store downloaded products, default is current work directory path (str, optional) : path to store downloaded products, default is current work directory
Examples: Examples:
.. code-block:: python .. code-block:: python
url = 'https://xxxxx.xx/xx.tar.gz' url = 'https://xxxxx.xx/xx.tar.gz'
download(url, path='./output') download(url, path='./output')
''' '''
...@@ -203,14 +196,11 @@ def download(url: str, path: str = None) -> str: ...@@ -203,14 +196,11 @@ def download(url: str, path: str = None) -> str:
def download_with_progress(url: str, path: str = None) -> Generator[str, int, int]: def download_with_progress(url: str, path: str = None) -> Generator[str, int, int]:
''' '''
Download a file and return the downloading progress -> Generator[filename, download_size, total_size] Download a file and return the downloading progress -> Generator[filename, download_size, total_size]
Args: Args:
url (str) : url to be downloaded url (str) : url to be downloaded
path (str, optional) : path to store downloaded products, default is current work directory path (str, optional) : path to store downloaded products, default is current work directory
Examples: Examples:
.. code-block:: python .. code-block:: python
url = 'https://xxxxx.xx/xx.tar.gz' url = 'https://xxxxx.xx/xx.tar.gz'
for filename, download_size, total_szie in download_with_progress(url, path='./output'): for filename, download_size, total_szie in download_with_progress(url, path='./output'):
print(filename, download_size, total_size) print(filename, download_size, total_size)
...@@ -236,7 +226,6 @@ def download_with_progress(url: str, path: str = None) -> Generator[str, int, in ...@@ -236,7 +226,6 @@ def download_with_progress(url: str, path: str = None) -> Generator[str, int, in
def load_py_module(python_path: str, py_module_name: str) -> types.ModuleType: def load_py_module(python_path: str, py_module_name: str) -> types.ModuleType:
''' '''
Load the specified python module. Load the specified python module.
Args: Args:
python_path(str) : The directory where the python module is located python_path(str) : The directory where the python module is located
py_module_name(str) : Module name to be loaded py_module_name(str) : Module name to be loaded
......
...@@ -118,7 +118,7 @@ class XarFile(object): ...@@ -118,7 +118,7 @@ class XarFile(object):
should return True for each filename to be excluded. should return True for each filename to be excluded.
''' '''
if self.arctype == 'tar': if self.arctype == 'tar':
self._archive_fp.add(name, arcname, recursive, exclude) self._archive_fp.add(name, arcname, recursive, filter=exclude)
else: else:
self._archive_fp.write(name) self._archive_fp.write(name)
if not recursive or not os.path.isdir(name): if not recursive or not os.path.isdir(name):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册