From c6f47c6bdd7b3091e1d16a2a98624a0788279ffe Mon Sep 17 00:00:00 2001 From: chenjian Date: Tue, 15 Nov 2022 11:24:52 +0800 Subject: [PATCH] [Backend] integrate x2paddle tool (#1148) * integrate x2paddle tool * fix caffe convertion bug * add post support * fix tar file path * add download api * add tips for users --- requirements.txt | 2 + visualdl/component/inference/__init__.py | 14 + .../inference/model_convert_server.py | 136 ++++++++++ visualdl/component/inference/xarfile.py | 252 ++++++++++++++++++ visualdl/server/app.py | 11 + 5 files changed, 415 insertions(+) create mode 100644 visualdl/component/inference/__init__.py create mode 100644 visualdl/component/inference/model_convert_server.py create mode 100644 visualdl/component/inference/xarfile.py diff --git a/requirements.txt b/requirements.txt index c7510004..12588e37 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,5 @@ matplotlib pandas multiprocess packaging +x2paddle +rarfile diff --git a/visualdl/component/inference/__init__.py b/visualdl/component/inference/__init__.py new file mode 100644 index 00000000..9c19f7b8 --- /dev/null +++ b/visualdl/component/inference/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2022 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ======================================================================= diff --git a/visualdl/component/inference/model_convert_server.py b/visualdl/component/inference/model_convert_server.py new file mode 100644 index 00000000..6f66dd5c --- /dev/null +++ b/visualdl/component/inference/model_convert_server.py @@ -0,0 +1,136 @@ +# Copyright (c) 2022 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ======================================================================= +import base64 +import json +import os +import tempfile +from collections import deque +from threading import Lock + +from flask import request +from x2paddle.convert import caffe2paddle +from x2paddle.convert import onnx2paddle + +from .xarfile import archive +from .xarfile import unarchive +from visualdl.server.api import gen_result +from visualdl.server.api import result + + +class ModelConvertApi(object): + def __init__(self): + self.supported_formats = {'onnx', 'caffe'} + self.lock = Lock() + self.translated_models = deque( + maxlen=5) # used to store user's translated model for download + self.request_id = 0 # used to store user's request + + @result() + def convert_model(self, format): + file_handle = request.files['file'] + data = file_handle.stream.read() + if format not in self.supported_formats: + raise RuntimeError('Model format {} is not supported. \ + Only onnx and caffe models are supported now.'.format(format)) + result = {} + result['from'] = format + result['to'] = 'paddle' + # call x2paddle to convert models + with tempfile.TemporaryDirectory( + suffix='x2paddle_translated_models') as tmpdirname: + with tempfile.NamedTemporaryFile() as fp: + fp.write(data) + fp.flush() + try: + if format == 'onnx': + try: + import onnx # noqa: F401 + except Exception: + raise RuntimeError( + "[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\"." + ) + onnx2paddle(fp.name, tmpdirname) + elif format == 'caffe': + with tempfile.TemporaryDirectory() as unarchivedir: + unarchive(fp.name, unarchivedir) + prototxt_path = None + weight_path = None + for dirname, subdirs, filenames in os.walk( + unarchivedir): + for filename in filenames: + if '.prototxt' in filename: + prototxt_path = os.path.join( + dirname, filename) + if '.caffemodel' in filename: + weight_path = os.path.join( + dirname, filename) + if prototxt_path is None or weight_path is None: + raise RuntimeError( + ".prototxt or .caffemodel file is missing in your archive file, \ + please check files uploaded.") + caffe2paddle(prototxt_path, weight_path, + tmpdirname, None) + except Exception as e: + raise RuntimeError( + "[Convertion error] {}.\n Please open an issue at \ + https://github.com/PaddlePaddle/X2Paddle/issues to report your problem." + .format(e)) + with self.lock: + origin_dir = os.getcwd() + os.chdir(os.path.dirname(tmpdirname)) + archive_path = os.path.join( + os.path.dirname(tmpdirname), + archive(os.path.basename(tmpdirname))) + os.chdir(origin_dir) + result['request_id'] = self.request_id + self.request_id += 1 + with open(archive_path, 'rb') as archive_fp: + self.translated_models.append((result['request_id'], + archive_fp.read())) + with open( + os.path.join(tmpdirname, 'inference_model', + 'model.pdmodel'), 'rb') as model_fp: + model_encoded = base64.b64encode( + model_fp.read()).decode('utf-8') + result['pdmodel'] = model_encoded + if os.path.exists(archive_path): + os.remove(archive_path) + + return result + + @result('application/octet-stream') + def download_model(self, request_id): + for stored_request_id, data in self.translated_models: + if str(stored_request_id) == request_id: + return data + + +def create_model_convert_api_call(): + api = ModelConvertApi() + routes = { + 'convert': (api.convert_model, ['format']), + 'download': (api.download_model, ['request_id']) + } + + def call(path: str, args): + route = routes.get(path) + if not route: + return json.dumps(gen_result( + status=1, msg='api not found')), 'application/json', None + method, call_arg_names = route + call_args = [args.get(name) for name in call_arg_names] + return method(*call_args) + + return call diff --git a/visualdl/component/inference/xarfile.py b/visualdl/component/inference/xarfile.py new file mode 100644 index 00000000..ebf82313 --- /dev/null +++ b/visualdl/component/inference/xarfile.py @@ -0,0 +1,252 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import tarfile +import zipfile +from typing import Callable +from typing import Generator +from typing import List + +import rarfile + + +class XarInfo(object): + '''Informational class which holds the details about an archive member given by a XarFile.''' + + def __init__(self, _xarinfo, arctype='tar'): + self._info = _xarinfo + self.arctype = arctype + + @property + def name(self) -> str: + if self.arctype == 'tar': + return self._info.name + return self._info.filename + + @property + def size(self) -> int: + if self.arctype == 'tar': + return self._info.size + return self._info.file_size + + +class XarFile(object): + ''' + The XarFile Class provides an interface to tar/rar/zip archives. + + Args: + name(str) : file or directory name to be archived + mode(str) : specifies the mode in which the file is opened, it must be: + ======== ============================================================================================== + Charater Meaning + -------- ---------------------------------------------------------------------------------------------- + 'r' open for reading + 'w' open for writing, truncating the file first, file will be saved according to the arctype field + 'a' open for writing, appending to the end of the file if it exists + ======== =============================================================================================== + arctype(str) : archive type, support ['tar' 'rar' 'zip' 'tar.gz' 'tar.bz2' 'tar.xz' 'tgz' 'txz'], if + the mode if 'w' or 'a', the default is 'tar', if the mode is 'r', it will be based on actual + archive type of file + ''' + + def __init__(self, name: str, mode: str, arctype: str = 'tar', **kwargs): + # if mode is 'w', adjust mode according to arctype field + if mode == 'w': + if arctype in ['tar.gz', 'tgz']: + mode = 'w:gz' + self.arctype = 'tar' + elif arctype == 'tar.bz2': + mode = 'w:bz2' + self.arctype = 'tar' + elif arctype in ['tar.xz', 'txz']: + mode = 'w:xz' + self.arctype = 'tar' + else: + self.arctype = arctype + # if mode is 'r', adjust mode according to actual archive type of file + elif mode == 'r': + if tarfile.is_tarfile(name): + self.arctype = 'tar' + mode = 'r:*' + elif zipfile.is_zipfile(name): + self.arctype = 'zip' + elif rarfile.is_rarfile(name): + self.arctype = 'rar' + elif mode == 'a': + self.arctype = arctype + else: + raise RuntimeError('Unsupported mode {}'.format(mode)) + + if self.arctype in [ + 'tar.gz', 'tar.bz2', 'tar.xz', 'tar', 'tgz', 'txz' + ]: + self._archive_fp = tarfile.open(name, mode, **kwargs) + elif self.arctype == 'zip': + self._archive_fp = zipfile.ZipFile(name, mode, **kwargs) + elif self.arctype == 'rar': + self._archive_fp = rarfile.RarFile(name, mode, **kwargs) + else: + raise RuntimeError('Unsupported archive type {}'.format( + self.arctype)) + + def __del__(self): + self._archive_fp.close() + + def __enter__(self): + return self + + def __exit__(self, exit_exception, exit_value, exit_traceback): + if exit_exception: + print(exit_traceback) + raise exit_exception(exit_value) + self._archive_fp.close() + return self + + def add(self, + name: str, + arcname: str = None, + recursive: bool = True, + exclude: Callable = None): + ''' + Add the file `name' to the archive. `name' may be any type of file (directory, fifo, symbolic link, etc.). + If given, `arcname' specifies an alternative name for the file in the archive. Directories are added + recursively by default. This can be avoided by setting `recursive' to False. `exclude' is a function that + should return True for each filename to be excluded. + ''' + if self.arctype == 'tar': + self._archive_fp.add(name, arcname, recursive, filter=exclude) + else: + self._archive_fp.write(name) + if not recursive or not os.path.isdir(name): + return + items = [] + for _d, _sub_ds, _files in os.walk(name): + items += [os.path.join(_d, _file) for _file in _files] + items += [os.path.join(_d, _sub_d) for _sub_d in _sub_ds] + + for item in items: + if exclude and not exclude(item): + continue + self._archive_fp.write(item) + + def extract(self, name: str, path: str): + '''Extract a file from the archive to the specified path.''' + return self._archive_fp.extract(name, path) + + def extractall(self, path: str): + '''Extract all files from the archive to the specified path.''' + return self._archive_fp.extractall(path) + + def getnames(self) -> List[str]: + '''Return a list of file names in the archive.''' + if self.arctype == 'tar': + return self._archive_fp.getnames() + return self._archive_fp.namelist() + + def getxarinfo(self, name: str) -> List[XarInfo]: + '''Return the instance of XarInfo given 'name'.''' + if self.arctype == 'tar': + return XarInfo(self._archive_fp.getmember(name), self.arctype) + return XarInfo(self._archive_fp.getinfo(name), self.arctype) + + +def open(name: str, mode: str = 'w', **kwargs) -> XarFile: + ''' + Open a xar archive for reading, writing or appending. Return + an appropriate XarFile class. + ''' + return XarFile(name, mode, **kwargs) + + +def archive(filename: str, + recursive: bool = True, + exclude: Callable = None, + arctype: str = 'tar') -> str: + ''' + Archive a file or directory + + Args: + name(str) : file or directory path to be archived + recursive(bool) : whether to recursively archive directories + exclude(Callable) : function that should return True for each filename to be excluded + arctype(str) : archive type, support ['tar' 'rar' 'zip' 'tar.gz' 'tar.bz2' 'tar.xz' 'tgz' 'txz'] + + Returns: + str: archived file path + + Examples: + .. code-block:: python + + archive_path = '/PATH/TO/FILE' + archive(archive_path, arcname='output.tar.gz', arctype='tar.gz') + ''' + basename = os.path.splitext(os.path.basename(filename))[0] + savename = '{}.{}'.format(basename, arctype) + with open(savename, mode='w', arctype=arctype) as file: + file.add(filename, recursive=recursive, exclude=exclude) + + return savename + + +def unarchive(name: str, path: str): + ''' + Unarchive a file + + Args: + name(str) : file or directory name to be unarchived + path(str) : storage name of archive file + + Examples: + .. code-block:: python + + unarchive_path = '/PATH/TO/FILE' + unarchive(unarchive_path, path='./output') + ''' + with open(name, mode='r') as file: + file.extractall(path) + + +def unarchive_with_progress(name: str, path: str) -> Generator[str, int, int]: + ''' + Unarchive a file and return the unarchiving progress -> Generator[filename, extrace_size, total_size] + + Args: + name(str) : file or directory name to be unarchived + path(str) : storage name of archive file + + Examples: + .. code-block:: python + + unarchive_path = 'test.tar.gz' + for filename, extract_size, total_szie in unarchive_with_progress(unarchive_path, path='./output'): + print(filename, extract_size, total_size) + ''' + with open(name, mode='r') as file: + total_size = extract_size = 0 + for filename in file.getnames(): + total_size += file.getxarinfo(filename).size + + for filename in file.getnames(): + file.extract(filename, path) + extract_size += file.getxarinfo(filename).size + yield filename, extract_size, total_size + + +def is_xarfile(file: str) -> bool: + '''Return True if xarfile supports specific file, otherwise False''' + _x_func = [zipfile.is_zipfile, tarfile.is_tarfile, rarfile.is_rarfile] + for _f in _x_func: + if _f(file): + return True + return False diff --git a/visualdl/server/app.py b/visualdl/server/app.py index 39ebb409..06dff816 100644 --- a/visualdl/server/app.py +++ b/visualdl/server/app.py @@ -32,6 +32,7 @@ from flask_babel import Babel import visualdl.server from visualdl import __version__ +from visualdl.component.inference.model_convert_server import create_model_convert_api_call from visualdl.component.profiler.profiler_server import create_profiler_api_call from visualdl.server.api import create_api_call from visualdl.server.args import parse_args @@ -68,6 +69,7 @@ def create_app(args): # noqa: C901 babel = Babel(app) api_call = create_api_call(args.logdir, args.model, args.cache_timeout) profiler_api_call = create_profiler_api_call(args.logdir) + inference_api_call = create_model_convert_api_call() if args.telemetry: update_util.PbUpdater(args.product).start() @@ -141,6 +143,15 @@ def create_app(args): # noqa: C901 return make_response( Response(data, mimetype=mimetype, headers=headers)) + @app.route(api_path + '/inference/', methods=["GET", "POST"]) + def serve_inference_api(method): + if request.method == 'POST': + data, mimetype, headers = inference_api_call(method, request.form) + else: + data, mimetype, headers = inference_api_call(method, request.args) + return make_response( + Response(data, mimetype=mimetype, headers=headers)) + @app.route(check_live_path) def check_live(): return '', 204 -- GitLab