未验证 提交 092c0274 编写于 作者: C chenjian 提交者: GitHub

[New Feature] Add paddle2onnx component (#1228)

* add paddle2onnx component

* add comments

* fix

* supplement failure judgement

* fix format

* fix format
上级 e368bb7b
...@@ -9,7 +9,8 @@ six >= 1.14.0 ...@@ -9,7 +9,8 @@ six >= 1.14.0
matplotlib matplotlib
pandas pandas
packaging packaging
x2paddle x2paddle >= 1.4.0
paddle2onnx >= 1.0.5
rarfile rarfile
gradio gradio
tritonclient[all] tritonclient[all]
......
# flake8: noqa
# Copyright (c) 2022 VisualDL Authors. All Rights Reserve. # Copyright (c) 2022 VisualDL Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -13,20 +14,18 @@ ...@@ -13,20 +14,18 @@
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
import base64 import base64
import glob
import hashlib import hashlib
import json import json
import os import os
import shutil import shutil
import tempfile import tempfile
from threading import Lock
import paddle2onnx
from flask import request from flask import request
from x2paddle.convert import caffe2paddle
from x2paddle.convert import onnx2paddle from x2paddle.convert import onnx2paddle
from .xarfile import archive from .xarfile import archive
from .xarfile import unarchive from visualdl.io.bfile import BosFileSystem
from visualdl.server.api import gen_result from visualdl.server.api import gen_result
from visualdl.server.api import result from visualdl.server.api import result
from visualdl.utils.dir import X2PADDLE_CACHE_PATH from visualdl.utils.dir import X2PADDLE_CACHE_PATH
...@@ -35,126 +34,252 @@ _max_cache_numbers = 200 ...@@ -35,126 +34,252 @@ _max_cache_numbers = 200
class ModelConvertApi(object): class ModelConvertApi(object):
'''!
Integrate multiple model conversion tools and provide a conversion service for users.
When a user uploads a model to this server, convert the model and upload the results to VDL Bos.
When a user downloads the model, we get the data from Bos and send it to the client.
Users may be able to download from Bos directly if the frontend can support it.
'''
def __init__(self): def __init__(self):
self.supported_formats = {'onnx', 'caffe'} '''
self.lock = Lock() Initialize an object to provide the service. Needs a BosFileSystem client to write data.
self.server_count = 0 # we use this variable to count requests handled, '''
# and check the number of files every 100 requests. try:
# If more than _max_cache_numbers files in cache, we delete the last recent used 50 files. self.bos_client = BosFileSystem()
self.bucket_name = os.getenv("BOS_BUCKET_NAME")
except Exception:
# When BOS_HOST, BOS_AK, BOS_SK, BOS_STS are not set in the environment variables.
# We use VDL BOS by default
self.bos_client = BosFileSystem(write_flag=False)
self.bos_client.renew_bos_client_from_server()
self.bucket_name = 'visualdl-server'
@result() @result()
def convert_model(self, format): def onnx2paddle_model_convert(self, convert_to_lite, lite_valid_places,
file_handle = request.files['file'] lite_model_type): # noqa:C901
data = file_handle.stream.read() '''
if format not in self.supported_formats: Convert onnx model to paddle model.
raise RuntimeError('Model format {} is not supported. "\ '''
"Only onnx and caffe models are supported now.'.format(format)) model_handle = request.files['model']
data = model_handle.stream.read()
result = {} result = {}
result['from'] = format # Do a simple data verification
result['to'] = 'paddle' if convert_to_lite in ['true', 'True', 'yes', 'Yes', 'y']:
convert_to_lite = True
else:
convert_to_lite = False
if lite_valid_places not in [
'arm', 'opencl', 'x86', 'metal', 'xpu', 'bm', 'mlu',
'intel_fpga', 'huawei_ascend_npu', 'imagination_nna',
'rockchip_npu', 'mediatek_apu', 'huawei_kirin_npu',
'amlogic_npu'
]:
lite_valid_places = 'arm'
if lite_model_type not in ['protobuf', 'naive_buffer']:
lite_model_type = 'naive_buffer'
# call x2paddle to convert models # call x2paddle to convert models
hl = hashlib.md5() hl = hashlib.md5()
hl.update(data) hl.update(data)
identity = hl.hexdigest() identity = hl.hexdigest()
result['request_id'] = identity result['request_id'] = identity
target_path = os.path.join(X2PADDLE_CACHE_PATH, identity) # check whether model has been transfromed before
if os.path.exists(target_path): # if model has been transformed before, data is stored at bos
if os.path.exists( pdmodel_filename = 'bos://{}/onnx2paddle/{}/model.pdmodel'.format(
os.path.join(target_path, 'inference_model', self.bucket_name, identity)
'model.pdmodel')): # if data in cache if self.bos_client.exists(pdmodel_filename):
with open( remote_data = self.bos_client.read_file(pdmodel_filename)
os.path.join(target_path, 'inference_model', if remote_data: # we should check data is not empty,
'model.pdmodel'), 'rb') as model_fp: # in case convertion failed but empty data is still uploaded before due to unknown reasons
model_encoded = base64.b64encode( model_encoded = base64.b64encode(remote_data).decode('utf-8')
model_fp.read()).decode('utf-8') result['model'] = model_encoded
result['pdmodel'] = model_encoded
return result return result
else: target_path = os.path.join(X2PADDLE_CACHE_PATH, 'onnx2paddle',
identity)
if not os.path.exists(target_path):
os.makedirs(target_path, exist_ok=True) os.makedirs(target_path, exist_ok=True)
with tempfile.NamedTemporaryFile() as fp: with tempfile.NamedTemporaryFile() as fp:
fp.write(data) fp.write(data)
fp.flush() fp.flush()
try: try:
if format == 'onnx': import onnx # noqa: F401
try: except Exception:
import onnx # noqa: F401 raise RuntimeError(
except Exception: "[ERROR] onnx is not installed, use \"pip install onnx>=1.6.0\"."
raise RuntimeError( )
"[ERROR] onnx is not installed, use \"pip install onnx>=1.6.0\"." try:
) if convert_to_lite is False:
onnx2paddle(fp.name, target_path) onnx2paddle(
elif format == 'caffe': fp.name, target_path, convert_to_lite=convert_to_lite)
with tempfile.TemporaryDirectory() as unarchivedir: else:
unarchive(fp.name, unarchivedir) onnx2paddle(
prototxt_path = None fp.name,
weight_path = None target_path,
for dirname, subdirs, filenames in os.walk( convert_to_lite=convert_to_lite,
unarchivedir): lite_valid_places=lite_valid_places,
for filename in filenames: lite_model_type=lite_model_type)
if '.prototxt' in filename:
prototxt_path = os.path.join(
dirname, filename)
if '.caffemodel' in filename:
weight_path = os.path.join(
dirname, filename)
if prototxt_path is None or weight_path is None:
raise RuntimeError(
".prototxt or .caffemodel file is missing in your archive file, "
"please check files uploaded.")
caffe2paddle(prototxt_path, weight_path, target_path,
None)
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(
"[Convertion error] {}.\n Please open an issue at " "[Convertion error] {}.\n Please open an issue at "
"https://github.com/PaddlePaddle/X2Paddle/issues to report your problem." "https://github.com/PaddlePaddle/X2Paddle/issues to report your problem."
.format(e)) .format(e))
with self.lock: # we need to enter dirname(target_path) to archive,
# in case an unnecessary directory is added to the archive.
origin_dir = os.getcwd() origin_dir = os.getcwd()
os.chdir(os.path.dirname(target_path)) os.chdir(os.path.dirname(target_path))
archive(os.path.basename(target_path)) archive(os.path.basename(target_path))
os.chdir(origin_dir) os.chdir(origin_dir)
self.server_count += 1 with open(
os.path.join(X2PADDLE_CACHE_PATH, 'onnx2paddle',
'{}.tar'.format(identity)), 'rb') as f:
# upload archived transformed model to vdl bos
data = f.read()
filename = 'bos://{}/onnx2paddle/{}.tar'.format(
self.bucket_name, identity)
try:
self.bos_client.write(filename, data)
except Exception as e:
print(
"Exception: Write file {}.tar to bos failed, due to {}"
.format(identity, e))
with open( with open(
os.path.join(target_path, 'inference_model', 'model.pdmodel'), os.path.join(target_path, 'inference_model', 'model.pdmodel'),
'rb') as model_fp: 'rb') as model_fp:
model_encoded = base64.b64encode(model_fp.read()).decode('utf-8') # upload pdmodel file to bos, if some model has been transformed before, we can directly download from bos
result['pdmodel'] = model_encoded filename = 'bos://{}/onnx2paddle/{}/model.pdmodel'.format(
self.bucket_name, identity)
data = model_fp.read()
try:
self.bos_client.write(filename, data)
except Exception as e:
print(
"Exception: Write file {}/model.pdmodel to bos failed, due to {}"
.format(identity, e))
# return transformed pdmodel file to frontend to show model structure graph
model_encoded = base64.b64encode(data).decode('utf-8')
# delete target_path
shutil.rmtree(target_path)
result['model'] = model_encoded
return result return result
@result('application/octet-stream') @result('application/octet-stream')
def download_model(self, request_id): def onnx2paddle_model_download(self, request_id):
if os.path.exists( '''
os.path.join(X2PADDLE_CACHE_PATH, Download converted paddle model from bos.
'{}.tar'.format(request_id))): '''
with open( filename = 'bos://{}/onnx2paddle/{}.tar'.format(
os.path.join(X2PADDLE_CACHE_PATH, self.bucket_name, request_id)
'{}.tar'.format(request_id)), 'rb') as f: data = None
data = f.read() if self.bos_client.exists(filename):
if self.server_count % 100 == 0: # we check number of files every 100 request data = self.bos_client.read_file(filename)
file_paths = glob.glob( if not data:
os.path.join(X2PADDLE_CACHE_PATH, '*.tar')) raise RuntimeError(
if len(file_paths) >= _max_cache_numbers: "The requested model can not be downloaded due to not existing or convertion failed."
file_paths = sorted( )
file_paths, key=os.path.getctime, reverse=True) return data
for file_path in file_paths:
try: @result()
os.remove(file_path) def paddle2onnx_convert(self, opset_version, deploy_backend):
shutil.rmtree( '''
os.path.join( Convert paddle model to onnx model.
os.path.dirname(file_path), '''
os.path.splitext( model_handle = request.files['model']
os.path.basename(file_path))[0])) params_handle = request.files['param']
except Exception: model_data = model_handle.stream.read()
pass param_data = params_handle.stream.read()
return data result = {}
# Do a simple data verification
try:
opset_version = int(opset_version)
except Exception:
opset_version = 11
if deploy_backend not in ['onnxruntime', 'tensorrt', 'others']:
deploy_backend = 'onnxruntime'
# call paddle2onnx to convert models
hl = hashlib.md5()
hl.update(model_data + param_data)
identity = hl.hexdigest()
result['request_id'] = identity
# check whether model has been transformed before
# if model has been transformed before, data is stored at bos
model_filename = 'bos://{}/paddle2onnx/{}/model.onnx'.format(
self.bucket_name, identity)
if self.bos_client.exists(model_filename):
remote_data = self.bos_client.read_file(model_filename)
if remote_data: # we should check data is not empty,
# in case convertion failed but empty data is still uploaded before due to unknown reasons
model_encoded = base64.b64encode(remote_data).decode('utf-8')
result['model'] = model_encoded
return result
with tempfile.NamedTemporaryFile() as model_fp:
with tempfile.NamedTemporaryFile() as param_fp:
model_fp.write(model_data)
param_fp.write(param_data)
model_fp.flush()
param_fp.flush()
try:
onnx_model = paddle2onnx.export(
model_fp.name,
param_fp.name,
opset_version=opset_version,
deploy_backend=deploy_backend)
except Exception as e:
raise RuntimeError(
"[Convertion error] {}.\n Please open an issue at "
"https://github.com/PaddlePaddle/Paddle2ONNX/issues to report your problem."
.format(e))
if not onnx_model:
raise RuntimeError(
"[Convertion error] Please check your input model and param files."
)
# upload transformed model to vdl bos
filename = 'bos://{}/paddle2onnx/{}/model.onnx'.format(
self.bucket_name, identity)
model_encoded = None
if onnx_model:
try:
self.bos_client.write(filename, onnx_model)
except Exception as e:
print(
"Exception: Write file {}/model.onnx to bos failed, due to {}"
.format(identity, e))
model_encoded = base64.b64encode(onnx_model).decode(
'utf-8')
result['model'] = model_encoded
return result
@result('application/octet-stream')
def paddle2onnx_download(self, request_id):
'''
Download converted onnx model from bos.
'''
filename = 'bos://{}/paddle2onnx/{}/model.onnx'.format(
self.bucket_name, request_id)
data = None
if self.bos_client.exists(filename):
data = self.bos_client.read_file(filename)
if not data:
raise RuntimeError(
"The requested model can not be downloaded due to not existing or convertion failed."
)
return data
def create_model_convert_api_call(): def create_model_convert_api_call():
api = ModelConvertApi() api = ModelConvertApi()
routes = { routes = {
'convert': (api.convert_model, ['format']), 'paddle2onnx/convert': (api.paddle2onnx_convert,
'download': (api.download_model, ['request_id']) ['opset_version', 'deploy_backend']),
'paddle2onnx/download': (api.paddle2onnx_download, ['request_id']),
'onnx2paddle/convert':
(api.onnx2paddle_model_convert,
['convert_to_lite', 'lite_valid_places', 'lite_model_type']),
'onnx2paddle/download': (api.onnx2paddle_model_download,
['request_id'])
} }
def call(path: str, args): def call(path: str, args):
......
...@@ -12,11 +12,10 @@ ...@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ======================================================================= # =======================================================================
import base64
import hashlib
import os import os
import tempfile import tempfile
import hashlib
import base64
import time import time
try: try:
...@@ -251,6 +250,13 @@ class BosConfigClient(object): ...@@ -251,6 +250,13 @@ class BosConfigClient(object):
return result return result
def upload_object_from_file(self, path, filename): def upload_object_from_file(self, path, filename):
"""!
Upload a local file to the Baidu BOS filesystem. The path can be divided into a bucket name and a prefix directory.
The file will be uploaded to the bucket at path `join(prefix directory in path, filename)`.
@param self object
@param path(str) bos directory path to store file, which consists of bucket_name + prefix directory.
@param filename(str) local file path to upload
"""
if not self.exists(path): if not self.exists(path):
self.makedirs(path) self.makedirs(path)
bucket_name, object_key = get_object_info(path) bucket_name, object_key = get_object_info(path)
...@@ -265,18 +271,21 @@ class BosConfigClient(object): ...@@ -265,18 +271,21 @@ class BosConfigClient(object):
class BosFileSystem(object): class BosFileSystem(object):
def __init__(self, write_flag=True): def __init__(self, write_flag=True):
self.max_contents_count = 1
self.max_contents_time = 1
self.file_length_map = {}
self._file_contents_to_add = b''
self._file_contents_count = 0
self._start_append_time = time.time()
if write_flag: if write_flag:
self.max_contents_count = 1
self.max_contents_time = 1
self.get_bos_config() self.get_bos_config()
self.bos_client = BosClient(self.config) self.bos_client = BosClient(self.config)
self.file_length_map = {}
self._file_contents_to_add = b''
self._file_contents_count = 0
self._start_append_time = time.time()
def get_bos_config(self): def get_bos_config(self):
'''
Get Bos configuration from environment variables.
'''
bos_host = os.getenv("BOS_HOST") bos_host = os.getenv("BOS_HOST")
if not bos_host: if not bos_host:
raise KeyError('${BOS_HOST} is not found.') raise KeyError('${BOS_HOST} is not found.')
...@@ -296,6 +305,9 @@ class BosFileSystem(object): ...@@ -296,6 +305,9 @@ class BosFileSystem(object):
def set_bos_config(self, bos_ak, bos_sk, bos_sts, def set_bos_config(self, bos_ak, bos_sk, bos_sts,
bos_host="bj.bcebos.com"): bos_host="bj.bcebos.com"):
'''
Set Bos configuration and get bos client according to parameters.
'''
self.config = BceClientConfiguration( self.config = BceClientConfiguration(
credentials=BceCredentials(bos_ak, bos_sk), credentials=BceCredentials(bos_ak, bos_sk),
endpoint=bos_host, endpoint=bos_host,
...@@ -303,6 +315,9 @@ class BosFileSystem(object): ...@@ -303,6 +315,9 @@ class BosFileSystem(object):
self.bos_client = BosClient(self.config) self.bos_client = BosClient(self.config)
def renew_bos_client_from_server(self): def renew_bos_client_from_server(self):
'''
Get a bos client using the ak, sk, and sts token provided by visualdl
'''
import requests import requests
import json import json
from visualdl.utils.dir import CONFIG_PATH from visualdl.utils.dir import CONFIG_PATH
...@@ -407,12 +422,14 @@ class BosFileSystem(object): ...@@ -407,12 +422,14 @@ class BosFileSystem(object):
data=init_data, data=init_data,
content_md5=content_md5(init_data), content_md5=content_md5(init_data),
content_length=len(init_data)) content_length=len(init_data))
except (exception.BceServerError, exception.BceHttpClientError) as e: except (exception.BceServerError,
exception.BceHttpClientError) as e:
if bucket_name == 'visualdl-server': # only sts token from visualdl-server, we can renew automatically if bucket_name == 'visualdl-server': # only sts token from visualdl-server, we can renew automatically
self.renew_bos_client_from_server() self.renew_bos_client_from_server()
# we should add a judgement for case 2 # we should add a judgement for case 2
try: try:
self.bos_client.get_object_meta_data(bucket_name, object_key) self.bos_client.get_object_meta_data(
bucket_name, object_key)
except exception.BceError: except exception.BceError:
# the file not exists, then create the file # the file not exists, then create the file
self.bos_client.append_object( self.bos_client.append_object(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册