未验证 提交 092c0274 编写于 作者: C chenjian 提交者: GitHub

[New Feature] Add paddle2onnx component (#1228)

* add paddle2onnx component

* add comments

* fix

* supplement failure judgement

* fix format

* fix format
上级 e368bb7b
# flake8: noqa
# Copyright (c) 2022 VisualDL Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -13,20 +14,18 @@
# limitations under the License.
# =======================================================================
import base64
import glob
import hashlib
import json
import os
import shutil
import tempfile
from threading import Lock
import paddle2onnx
from flask import request
from x2paddle.convert import caffe2paddle
from x2paddle.convert import onnx2paddle
from .xarfile import archive
from .xarfile import unarchive
from visualdl.io.bfile import BosFileSystem
from visualdl.server.api import gen_result
from visualdl.server.api import result
from visualdl.utils.dir import X2PADDLE_CACHE_PATH
......@@ -35,126 +34,252 @@ _max_cache_numbers = 200
class ModelConvertApi(object):
'''!
Integrate multiple model convertion tools, and provide convertion service for users.
When user uploads a model to this server, convert model and upload the results to VDL Bos.
When user downloads the model, we get the data from Bos and send it to client.
Maybe users can download from bos directy if frontend can achieve it.
'''
def __init__(self):
self.supported_formats = {'onnx', 'caffe'}
self.lock = Lock()
self.server_count = 0 # we use this variable to count requests handled,
# and check the number of files every 100 requests.
# If more than _max_cache_numbers files in cache, we delete the last recent used 50 files.
'''
Initialize a object to provide service. Need a BosFileSystem client to write data.
'''
try:
self.bos_client = BosFileSystem()
self.bucket_name = os.getenv("BOS_BUCKET_NAME")
except Exception:
# When BOS_HOST, BOS_AK, BOS_SK, BOS_STS are not set in the environment variables.
# We use VDL BOS by default
self.bos_client = BosFileSystem(write_flag=False)
self.bos_client.renew_bos_client_from_server()
self.bucket_name = 'visualdl-server'
@result()
def convert_model(self, format):
file_handle = request.files['file']
data = file_handle.stream.read()
if format not in self.supported_formats:
raise RuntimeError('Model format {} is not supported. "\
"Only onnx and caffe models are supported now.'.format(format))
def onnx2paddle_model_convert(self, convert_to_lite, lite_valid_places,
lite_model_type): # noqa:C901
'''
Convert onnx model to paddle model.
'''
model_handle = request.files['model']
data = model_handle.stream.read()
result = {}
result['from'] = format
result['to'] = 'paddle'
# Do a simple data verification
if convert_to_lite in ['true', 'True', 'yes', 'Yes', 'y']:
convert_to_lite = True
else:
convert_to_lite = False
if lite_valid_places not in [
'arm', 'opencl', 'x86', 'metal', 'xpu', 'bm', 'mlu',
'intel_fpga', 'huawei_ascend_npu', 'imagination_nna',
'rockchip_npu', 'mediatek_apu', 'huawei_kirin_npu',
'amlogic_npu'
]:
lite_valid_places = 'arm'
if lite_model_type not in ['protobuf', 'naive_buffer']:
lite_model_type = 'naive_buffer'
# call x2paddle to convert models
hl = hashlib.md5()
hl.update(data)
identity = hl.hexdigest()
result['request_id'] = identity
target_path = os.path.join(X2PADDLE_CACHE_PATH, identity)
if os.path.exists(target_path):
if os.path.exists(
os.path.join(target_path, 'inference_model',
'model.pdmodel')): # if data in cache
with open(
os.path.join(target_path, 'inference_model',
'model.pdmodel'), 'rb') as model_fp:
model_encoded = base64.b64encode(
model_fp.read()).decode('utf-8')
result['pdmodel'] = model_encoded
# check whether model has been transfromed before
# if model has been transformed before, data is stored at bos
pdmodel_filename = 'bos://{}/onnx2paddle/{}/model.pdmodel'.format(
self.bucket_name, identity)
if self.bos_client.exists(pdmodel_filename):
remote_data = self.bos_client.read_file(pdmodel_filename)
if remote_data: # we should check data is not empty,
# in case convertion failed but empty data is still uploaded before due to unknown reasons
model_encoded = base64.b64encode(remote_data).decode('utf-8')
result['model'] = model_encoded
return result
else:
target_path = os.path.join(X2PADDLE_CACHE_PATH, 'onnx2paddle',
identity)
if not os.path.exists(target_path):
os.makedirs(target_path, exist_ok=True)
with tempfile.NamedTemporaryFile() as fp:
fp.write(data)
fp.flush()
try:
if format == 'onnx':
try:
import onnx # noqa: F401
except Exception:
raise RuntimeError(
"[ERROR] onnx is not installed, use \"pip install onnx>=1.6.0\"."
)
onnx2paddle(fp.name, target_path)
elif format == 'caffe':
with tempfile.TemporaryDirectory() as unarchivedir:
unarchive(fp.name, unarchivedir)
prototxt_path = None
weight_path = None
for dirname, subdirs, filenames in os.walk(
unarchivedir):
for filename in filenames:
if '.prototxt' in filename:
prototxt_path = os.path.join(
dirname, filename)
if '.caffemodel' in filename:
weight_path = os.path.join(
dirname, filename)
if prototxt_path is None or weight_path is None:
raise RuntimeError(
".prototxt or .caffemodel file is missing in your archive file, "
"please check files uploaded.")
caffe2paddle(prototxt_path, weight_path, target_path,
None)
try:
if convert_to_lite is False:
onnx2paddle(
fp.name, target_path, convert_to_lite=convert_to_lite)
else:
onnx2paddle(
fp.name,
target_path,
convert_to_lite=convert_to_lite,
lite_valid_places=lite_valid_places,
lite_model_type=lite_model_type)
except Exception as e:
raise RuntimeError(
"[Convertion error] {}.\n Please open an issue at "
"https://github.com/PaddlePaddle/X2Paddle/issues to report your problem."
.format(e))
with self.lock: # we need to enter dirname(target_path) to archive,
# in case unneccessary directory added in archive.
origin_dir = os.getcwd()
os.chdir(os.path.dirname(target_path))
archive(os.path.basename(target_path))
os.chdir(origin_dir)
self.server_count += 1
with open(
os.path.join(X2PADDLE_CACHE_PATH, 'onnx2paddle',
'{}.tar'.format(identity)), 'rb') as f:
# upload archived transformed model to vdl bos
data = f.read()
filename = 'bos://{}/onnx2paddle/{}.tar'.format(
self.bucket_name, identity)
try:
self.bos_client.write(filename, data)
except Exception as e:
print(
"Exception: Write file {}.tar to bos failed, due to {}"
.format(identity, e))
with open(
os.path.join(target_path, 'inference_model', 'model.pdmodel'),
'rb') as model_fp:
model_encoded = base64.b64encode(model_fp.read()).decode('utf-8')
result['pdmodel'] = model_encoded
# upload pdmodel file to bos, if some model has been transformed before, we can directly download from bos
filename = 'bos://{}/onnx2paddle/{}/model.pdmodel'.format(
self.bucket_name, identity)
data = model_fp.read()
try:
self.bos_client.write(filename, data)
except Exception as e:
print(
"Exception: Write file {}/model.pdmodel to bos failed, due to {}"
.format(identity, e))
# return transformed pdmodel file to frontend to show model structure graph
model_encoded = base64.b64encode(data).decode('utf-8')
# delete target_path
shutil.rmtree(target_path)
result['model'] = model_encoded
return result
@result('application/octet-stream')
def download_model(self, request_id):
if os.path.exists(
os.path.join(X2PADDLE_CACHE_PATH,
'{}.tar'.format(request_id))):
with open(
os.path.join(X2PADDLE_CACHE_PATH,
'{}.tar'.format(request_id)), 'rb') as f:
data = f.read()
if self.server_count % 100 == 0: # we check number of files every 100 request
file_paths = glob.glob(
os.path.join(X2PADDLE_CACHE_PATH, '*.tar'))
if len(file_paths) >= _max_cache_numbers:
file_paths = sorted(
file_paths, key=os.path.getctime, reverse=True)
for file_path in file_paths:
def onnx2paddle_model_download(self, request_id):
'''
Download converted paddle model from bos.
'''
filename = 'bos://{}/onnx2paddle/{}.tar'.format(
self.bucket_name, request_id)
data = None
if self.bos_client.exists(filename):
data = self.bos_client.read_file(filename)
if not data:
raise RuntimeError(
"The requested model can not be downloaded due to not existing or convertion failed."
)
return data
@result()
def paddle2onnx_convert(self, opset_version, deploy_backend):
'''
Convert paddle model to onnx model.
'''
model_handle = request.files['model']
params_handle = request.files['param']
model_data = model_handle.stream.read()
param_data = params_handle.stream.read()
result = {}
# Do a simple data verification
try:
os.remove(file_path)
shutil.rmtree(
os.path.join(
os.path.dirname(file_path),
os.path.splitext(
os.path.basename(file_path))[0]))
opset_version = int(opset_version)
except Exception:
pass
opset_version = 11
if deploy_backend not in ['onnxruntime', 'tensorrt', 'others']:
deploy_backend = 'onnxruntime'
# call paddle2onnx to convert models
hl = hashlib.md5()
hl.update(model_data + param_data)
identity = hl.hexdigest()
result['request_id'] = identity
# check whether model has been transfromed before
# if model has been transformed before, data is stored at bos
model_filename = 'bos://{}/paddle2onnx/{}/model.onnx'.format(
self.bucket_name, identity)
if self.bos_client.exists(model_filename):
remote_data = self.bos_client.read_file(model_filename)
if remote_data: # we should check data is not empty,
# in case convertion failed but empty data is still uploaded before due to unknown reasons
model_encoded = base64.b64encode(remote_data).decode('utf-8')
result['model'] = model_encoded
return result
with tempfile.NamedTemporaryFile() as model_fp:
with tempfile.NamedTemporaryFile() as param_fp:
model_fp.write(model_data)
param_fp.write(param_data)
model_fp.flush()
param_fp.flush()
try:
onnx_model = paddle2onnx.export(
model_fp.name,
param_fp.name,
opset_version=opset_version,
deploy_backend=deploy_backend)
except Exception as e:
raise RuntimeError(
"[Convertion error] {}.\n Please open an issue at "
"https://github.com/PaddlePaddle/Paddle2ONNX/issues to report your problem."
.format(e))
if not onnx_model:
raise RuntimeError(
"[Convertion error] Please check your input model and param files."
)
# upload transformed model to vdl bos
filename = 'bos://{}/paddle2onnx/{}/model.onnx'.format(
self.bucket_name, identity)
model_encoded = None
if onnx_model:
try:
self.bos_client.write(filename, onnx_model)
except Exception as e:
print(
"Exception: Write file {}/model.onnx to bos failed, due to {}"
.format(identity, e))
model_encoded = base64.b64encode(onnx_model).decode(
'utf-8')
result['model'] = model_encoded
return result
@result('application/octet-stream')
def paddle2onnx_download(self, request_id):
'''
Download converted onnx model from bos.
'''
filename = 'bos://{}/paddle2onnx/{}/model.onnx'.format(
self.bucket_name, request_id)
data = None
if self.bos_client.exists(filename):
data = self.bos_client.read_file(filename)
if not data:
raise RuntimeError(
"The requested model can not be downloaded due to not existing or convertion failed."
)
return data
def create_model_convert_api_call():
api = ModelConvertApi()
routes = {
'convert': (api.convert_model, ['format']),
'download': (api.download_model, ['request_id'])
'paddle2onnx/convert': (api.paddle2onnx_convert,
['opset_version', 'deploy_backend']),
'paddle2onnx/download': (api.paddle2onnx_download, ['request_id']),
'onnx2paddle/convert':
(api.onnx2paddle_model_convert,
['convert_to_lite', 'lite_valid_places', 'lite_model_type']),
'onnx2paddle/download': (api.onnx2paddle_model_download,
['request_id'])
}
def call(path: str, args):
......
......@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =======================================================================
import base64
import hashlib
import os
import tempfile
import hashlib
import base64
import time
try:
......@@ -251,6 +250,13 @@ class BosConfigClient(object):
return result
def upload_object_from_file(self, path, filename):
"""!
Upload a local file to baidu bos filesystem. The path can de divided as bucket name and prefix directory.
The file would be uploaded in bucket at path `join(prefix directory in path, filename)`
@param self object
@param path(str) bos directory path to store file, which consists of bucket_name + prefix directory.
@param filename(str) local file path to upload
"""
if not self.exists(path):
self.makedirs(path)
bucket_name, object_key = get_object_info(path)
......@@ -265,18 +271,21 @@ class BosConfigClient(object):
class BosFileSystem(object):
def __init__(self, write_flag=True):
if write_flag:
self.max_contents_count = 1
self.max_contents_time = 1
self.get_bos_config()
self.bos_client = BosClient(self.config)
self.file_length_map = {}
self._file_contents_to_add = b''
self._file_contents_count = 0
self._start_append_time = time.time()
if write_flag:
self.get_bos_config()
self.bos_client = BosClient(self.config)
def get_bos_config(self):
'''
Get Bos configuration from environment variables.
'''
bos_host = os.getenv("BOS_HOST")
if not bos_host:
raise KeyError('${BOS_HOST} is not found.')
......@@ -296,6 +305,9 @@ class BosFileSystem(object):
def set_bos_config(self, bos_ak, bos_sk, bos_sts,
bos_host="bj.bcebos.com"):
'''
Set Bos configuration and get bos client according to parameters.
'''
self.config = BceClientConfiguration(
credentials=BceCredentials(bos_ak, bos_sk),
endpoint=bos_host,
......@@ -303,6 +315,9 @@ class BosFileSystem(object):
self.bos_client = BosClient(self.config)
def renew_bos_client_from_server(self):
'''
Get bos client by visualdl provided ak, sk, and sts token
'''
import requests
import json
from visualdl.utils.dir import CONFIG_PATH
......@@ -407,12 +422,14 @@ class BosFileSystem(object):
data=init_data,
content_md5=content_md5(init_data),
content_length=len(init_data))
except (exception.BceServerError, exception.BceHttpClientError) as e:
except (exception.BceServerError,
exception.BceHttpClientError) as e:
if bucket_name == 'visualdl-server': # only sts token from visualdl-server, we can renew automatically
self.renew_bos_client_from_server()
# we should add a judgement for case 2
try:
self.bos_client.get_object_meta_data(bucket_name, object_key)
self.bos_client.get_object_meta_data(
bucket_name, object_key)
except exception.BceError:
# the file not exists, then create the file
self.bos_client.append_object(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册