Unverified · Commit 8c88895d authored by TeslaZhao, committed by GitHub

Merge pull request #1590 from PaddlePaddle/dynamic_tensorrt

configure dynamic shape tensorrt
......@@ -4,9 +4,9 @@
## Get Model
```
python3 -m paddle_serving_app.package --get_model ocr_rec
wget https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/OCR/ocr_rec.tar.gz
tar -xzvf ocr_rec.tar.gz
python3 -m paddle_serving_app.package --get_model ocr_det
wget https://paddle-serving.bj.bcebos.com/ocr/ocr_det.tar.gz
tar -xzvf ocr_det.tar.gz
```
......
......@@ -4,9 +4,9 @@
## Get Model
```
python3 -m paddle_serving_app.package --get_model ocr_rec
wget https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/OCR/ocr_rec.tar.gz
tar -xzvf ocr_rec.tar.gz
python3 -m paddle_serving_app.package --get_model ocr_det
wget https://paddle-serving.bj.bcebos.com/ocr/ocr_det.tar.gz
tar -xzvf ocr_det.tar.gz
```
## Get Dataset (Optional)
......
......@@ -37,7 +37,7 @@ op:
model_config: ocr_det_model
#Fetch result list, based on the alias_name of fetch_var in client_config
fetch_list: ["concat_1.tmp_0"]
fetch_list: ["save_infer_model/scale_0.tmp_1"]
# device_type, 0=cpu, 1=gpu, 2=tensorRT, 3=arm cpu, 4=kunlun xpu
device_type: 0
......@@ -53,6 +53,9 @@ op:
#ir_optim
ir_optim: True
#After TensorRT is enabled, the minimum number of nodes a subgraph must contain to be optimized
#min_subgraph_size: 13
rec:
#Concurrency; when is_thread_op=True this is thread-level concurrency, otherwise process-level
concurrency: 3
......@@ -73,7 +76,7 @@ op:
model_config: ocr_rec_model
#Fetch result list, based on the alias_name of fetch_var in client_config
fetch_list: ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
fetch_list: ["save_infer_model/scale_0.tmp_1"]
# device_type, 0=cpu, 1=gpu, 2=tensorRT, 3=arm cpu, 4=kunlun xpu
device_type: 0
......@@ -88,3 +91,6 @@ op:
#ir_optim
ir_optim: True
#After TensorRT is enabled, the minimum number of nodes a subgraph must contain to be optimized
#min_subgraph_size: 3
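For reference, these TensorRT-related keys sit under each op's `local_service_conf` in `config.yml` in the pipeline examples. A minimal sketch of how such a config might be read (the file name and nesting are assumptions based on the example layout):

```
# Sketch: reading the TensorRT-related keys added in this PR from config.yml.
# Assumes the pipeline example layout: op -> <name> -> local_service_conf.
import yaml

with open("config.yml") as f:
    conf = yaml.safe_load(f)

det_conf = conf["op"]["det"]["local_service_conf"]
use_trt = det_conf.get("device_type") == 2                # 2 = tensorRT (see above)
min_subgraph_size = det_conf.get("min_subgraph_size", 3)  # PR default is 3
```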
......@@ -39,6 +39,42 @@ class DetOp(Op):
"unclip_ratio": 1.5,
"min_size": 3
})
"""
when opening tensorrt(configure in config.yml) and each time the input shape
for inferring is different, using this method for configuring tensorrt
dynamic shape to infer in each op model
"""
def set_dynamic_shape_info(self):
min_input_shape = {
"x": [1, 3, 50, 50],
"conv2d_182.tmp_0": [1, 1, 20, 20],
"nearest_interp_v2_2.tmp_0": [1, 1, 20, 20],
"nearest_interp_v2_3.tmp_0": [1, 1, 20, 20],
"nearest_interp_v2_4.tmp_0": [1, 1, 20, 20],
"nearest_interp_v2_5.tmp_0": [1, 1, 20, 20]
}
max_input_shape = {
"x": [1, 3, 1536, 1536],
"conv2d_182.tmp_0": [20, 200, 960, 960],
"nearest_interp_v2_2.tmp_0": [20, 200, 960, 960],
"nearest_interp_v2_3.tmp_0": [20, 200, 960, 960],
"nearest_interp_v2_4.tmp_0": [20, 200, 960, 960],
"nearest_interp_v2_5.tmp_0": [20, 200, 960, 960],
}
opt_input_shape = {
"x": [1, 3, 960, 960],
"conv2d_182.tmp_0": [3, 96, 240, 240],
"nearest_interp_v2_2.tmp_0": [3, 96, 240, 240],
"nearest_interp_v2_3.tmp_0": [3, 24, 240, 240],
"nearest_interp_v2_4.tmp_0": [3, 24, 240, 240],
"nearest_interp_v2_5.tmp_0": [3, 24, 240, 240],
}
self.dynamic_shape_info = {
"min_input_shape": min_input_shape,
"max_input_shape": max_input_shape,
"opt_input_shape": opt_input_shape,
}
def preprocess(self, input_dicts, data_id, log_id):
(_, input_dict), = input_dicts.items()
......@@ -52,11 +88,11 @@ class DetOp(Op):
det_img = self.det_preprocess(self.im)
_, self.new_h, self.new_w = det_img.shape
imgs.append(det_img[np.newaxis, :].copy())
return {"image": np.concatenate(imgs, axis=0)}, False, None, ""
return {"x": np.concatenate(imgs, axis=0)}, False, None, ""
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
# print(fetch_dict)
det_out = fetch_dict["concat_1.tmp_0"]
det_out = fetch_dict["save_infer_model/scale_0.tmp_1"]
ratio_list = [
float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w
]
......@@ -71,6 +107,30 @@ class RecOp(Op):
self.ocr_reader = OCRReader()
self.get_rotate_crop_image = GetRotateCropImage()
self.sorted_boxes = SortedBoxes()
"""
when opening tensorrt(configure in config.yml) and each time the input shape
for inferring is different, using this method for configuring tensorrt
dynamic shape to infer in each op model
"""
def set_dynamic_shape_info(self):
min_input_shape = {
"x": [1, 3, 32, 10],
"lstm_1.tmp_0": [1, 1, 128]
}
max_input_shape = {
"x": [50, 3, 32, 1000],
"lstm_1.tmp_0": [500, 50, 128]
}
opt_input_shape = {
"x": [6, 3, 32, 100],
"lstm_1.tmp_0": [25, 5, 128]
}
self.dynamic_shape_info = {
"min_input_shape": min_input_shape,
"max_input_shape": max_input_shape,
"opt_input_shape": opt_input_shape,
}
def preprocess(self, input_dicts, data_id, log_id):
(_, input_dict), = input_dicts.items()
......@@ -143,7 +203,7 @@ class RecOp(Op):
for id, img in enumerate(img_list):
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
imgs[id] = norm_img
feed = {"image": imgs.copy()}
feed = {"x": imgs.copy()}
feed_list.append(feed)
#_LOGGER.info("feed_list : {}".format(feed_list))
......@@ -153,13 +213,13 @@ class RecOp(Op):
res_list = []
if isinstance(fetch_data, dict):
if len(fetch_data) > 0:
rec_batch_res = self.ocr_reader.postprocess(
rec_batch_res = self.ocr_reader.postprocess_ocrv2(
fetch_data, with_score=True)
for res in rec_batch_res:
res_list.append(res[0])
elif isinstance(fetch_data, list):
for one_batch in fetch_data:
one_batch_res = self.ocr_reader.postprocess(
one_batch_res = self.ocr_reader.postprocess_ocrv2(
one_batch, with_score=True)
for res in one_batch_res:
res_list.append(res[0])
......
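The shape dictionaries above are model-specific: `x` is the model input, while `conv2d_182.tmp_0`, the `nearest_interp_v2_*.tmp_0` tensors, and `lstm_1.tmp_0` are intermediate tensors whose shapes TensorRT must also bound. If an input ever falls outside the configured [min, max] range, TensorRT inference fails at runtime, so the bounds must cover everything `preprocess` can emit. A small hypothetical sanity check (not part of the PR; numpy only):

```
import numpy as np

def in_range(shape, lo, hi):
    # True if every dim of shape lies within [lo, hi].
    return all(l <= s <= h for s, l, h in zip(shape, lo, hi))

det_batch = np.zeros((1, 3, 960, 960), dtype="float32")   # a det input batch
assert in_range(det_batch.shape, [1, 3, 50, 50], [1, 3, 1536, 1536])

rec_batch = np.zeros((6, 3, 32, 100), dtype="float32")    # a rec input batch
assert in_range(rec_batch.shape, [1, 3, 32, 10], [50, 3, 32, 1000])
```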
......@@ -88,7 +88,9 @@ class LocalPredictor(object):
mkldnn_op_list=None,
mkldnn_bf16_op_list=None,
use_feed_fetch_ops=False,
use_ascend_cl=False):
use_ascend_cl=False,
min_subgraph_size=3,
dynamic_shape_info={}):
"""
Load model configs and create the paddle predictor by Paddle Inference API.
......@@ -102,6 +104,9 @@ class LocalPredictor(object):
ir_optim: enable computation graph optimization, False default.
use_trt: use nvidia TensorRT optimization, False default
use_lite: use Paddle-Lite engine, False default
use_xpu: run predict on Baidu Kunlun, False default
precision: precision mode, "fp32" default
use_calib: use TensorRT calibration, False default
......@@ -111,6 +116,8 @@ class LocalPredictor(object):
mkldnn_bf16_op_list: op list accelerated using MKLDNN bf16, None default.
use_feed_fetch_ops: use feed/fetch ops, False default.
use_ascend_cl: run predict on Huawei Ascend, False default
min_subgraph_size: the minimum number of nodes in a subgraph for TensorRT to optimize, 3 default
dynamic_shape_info: dict containing min_input_shape, max_input_shape and opt_input_shape, {} default
"""
gpu_id = int(gpu_id)
client_config = "{}/serving_server_conf.prototxt".format(model_path)
......@@ -150,11 +157,12 @@ class LocalPredictor(object):
"use_trt:{}, use_lite:{}, use_xpu:{}, precision:{}, use_calib:{}, "
"use_mkldnn:{}, mkldnn_cache_capacity:{}, mkldnn_op_list:{}, "
"mkldnn_bf16_op_list:{}, use_feed_fetch_ops:{}, "
"use_ascend_cl:{} ".format(
"use_ascend_cl:{}, min_subgraph_size:{}, dynamic_shape_info:{}".format(
model_path, use_gpu, gpu_id, use_profile, thread_num, mem_optim,
ir_optim, use_trt, use_lite, use_xpu, precision, use_calib,
use_mkldnn, mkldnn_cache_capacity, mkldnn_op_list,
mkldnn_bf16_op_list, use_feed_fetch_ops, use_ascend_cl))
mkldnn_bf16_op_list, use_feed_fetch_ops, use_ascend_cl,
min_subgraph_size, dynamic_shape_info))
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
......@@ -211,9 +219,15 @@ class LocalPredictor(object):
precision_mode=precision_type,
workspace_size=1 << 20,
max_batch_size=32,
min_subgraph_size=3,
min_subgraph_size=min_subgraph_size,
use_static=False,
use_calib_mode=False)
if len(dynamic_shape_info):
config.set_trt_dynamic_shape_info(
dynamic_shape_info['min_input_shape'],
dynamic_shape_info['max_input_shape'],
dynamic_shape_info['opt_input_shape'])
# set lite
if use_lite:
config.enable_lite_engine(
......
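Outside the serving wrapper, the same three dictionaries feed Paddle Inference directly. A minimal standalone sketch (model paths are illustrative, and only the input tensor `x` is bounded here; the PR additionally bounds several intermediate tensors):

```
from paddle.inference import Config, PrecisionType, create_predictor

config = Config("ocr_det_model/__model__", "ocr_det_model/__params__")
config.enable_use_gpu(2000, 0)      # initial GPU memory pool (MB), device id
config.enable_tensorrt_engine(
    workspace_size=1 << 20,
    max_batch_size=32,
    min_subgraph_size=3,
    precision_mode=PrecisionType.Float32,
    use_static=False,
    use_calib_mode=False)
config.set_trt_dynamic_shape_info(
    {"x": [1, 3, 50, 50]},          # min_input_shape
    {"x": [1, 3, 1536, 1536]},      # max_input_shape
    {"x": [1, 3, 960, 960]})        # opt_input_shape
predictor = create_predictor(config)
```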
......@@ -32,17 +32,18 @@ class ServingModels(object):
self.model_dict["ImageClassification"] = [
"resnet_v2_50_imagenet", "mobilenet_v2_imagenet"
]
self.model_dict["TextDetection"] = ["ocr_det"]
self.model_dict["OCR"] = ["ocr_rec"]
#self.model_dict["TextDetection"] = ["ocr_det"]
self.model_dict["OCR"] = ["ocr_rec", "ocr_det"]
image_class_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/ImageClassification/"
image_seg_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/ImageSegmentation/"
object_detection_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/ObjectDetection/"
ocr_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/OCR/"
#ocr_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/OCR/"
ocr_url = "https://paddle-serving.bj.bcebos.com/ocr_v2/"
senta_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/SentimentAnalysis/"
semantic_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/SemanticModel/"
wordseg_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/LexicalAnalysis/"
ocr_det_url = "https://paddle-serving.bj.bcebos.com/ocr/"
#ocr_det_url = "https://paddle-serving.bj.bcebos.com/ocr/"
self.url_dict = {}
......@@ -58,7 +59,7 @@ class ServingModels(object):
pack_url(self.model_dict, "ImageSegmentation", image_seg_url)
pack_url(self.model_dict, "ImageClassification", image_class_url)
pack_url(self.model_dict, "OCR", ocr_url)
pack_url(self.model_dict, "TextDetection", ocr_det_url)
#pack_url(self.model_dict, "TextDetection", ocr_det_url)
def get_model_list(self):
return self.model_dict
......
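With this change both OCR models download through the `OCR` category, now pointing at the `ocr_v2` bucket. A quick hedged check (import path assumed from the paddle_serving_app layout):

```
from paddle_serving_app.models.model_list import ServingModels  # assumed path

models = ServingModels()
print(models.get_model_list()["OCR"])   # expected: ['ocr_rec', 'ocr_det']
```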
......@@ -118,6 +118,111 @@ class CharacterOps(object):
% (self.loss_type)
assert False, err
class BaseRecLabelDecode(object):
""" Convert between text-label and text-index """
def __init__(self, config):
support_character_type = [
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
'ne', 'EN'
]
character_type = config['character_type']
character_dict_path = config['character_dict_path']
use_space_char = True
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type)
self.beg_str = "sos"
self.end_str = "eos"
if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
elif character_type == "EN_symbol":
# same with ASTER setting (use 94 char).
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
elif character_type in support_character_type:
self.character_str = ""
assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
character_type)
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str += line
if use_space_char:
self.character_str += " "
dict_character = list(self.character_str)
else:
raise NotImplementedError
self.character_type = character_type
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
def add_special_char(self, dict_character):
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] in ignored_tokens:
continue
if is_remove_duplicate:
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list)))
return result_list
def get_ignored_tokens(self):
return [0] # for ctc blank
class CTCLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(
self,
config,
#character_dict_path=None,
#character_type='ch',
#use_space_char=False,
**kwargs):
super(CTCLabelDecode, self).__init__(config)
def __call__(self, preds, label=None, *args, **kwargs):
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
if label is None:
return text
label = self.decode(label)
return text, label
def add_special_char(self, dict_character):
dict_character = ['blank'] + dict_character
return dict_character
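To see the decode rules in action (CTC blank at index 0 via `add_special_char`, consecutive duplicates collapsed), here is a small synthetic walk-through, assuming `CTCLabelDecode` above is importable; the probabilities are synthetic:

```
import numpy as np

decoder = CTCLabelDecode({"character_type": "en", "character_dict_path": None})
# Greedy path "h h <blank> i": the repeat collapses and the blank is dropped.
preds = np.zeros((1, 4, len(decoder.character)), dtype="float32")
for t, ch in enumerate(["h", "h", "blank", "i"]):
    preds[0, t, decoder.character.index(ch)] = 1.0
print(decoder(preds))   # -> [('hi', 1.0)]
```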
class OCRReader(object):
def __init__(self,
......@@ -134,6 +239,7 @@ class OCRReader(object):
char_ops_params["character_dict_path"] = char_dict_path
char_ops_params['loss_type'] = 'ctc'
self.char_ops = CharacterOps(char_ops_params)
self.label_ops = CTCLabelDecode(char_ops_params)
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
......@@ -202,3 +308,15 @@ class OCRReader(object):
else:
rec_res.append([preds_text])
return rec_res
def postprocess_ocrv2(self, outputs, with_score=False):
preds = outputs["save_infer_model/scale_0.tmp_1"]
try:
preds = preds.numpy()
except AttributeError:  # preds may already be a numpy ndarray
pass
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.label_ops.decode(
preds_idx, preds_prob, is_remove_duplicate=True)
return text
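A hedged usage sketch of the new method (the dict path, batch and class sizes are illustrative; the fetch name is the one introduced in this PR):

```
import numpy as np

reader = OCRReader(char_dict_path="ppocr_keys_v1.txt")  # illustrative path
preds = np.random.rand(1, 25, 6625).astype("float32")   # [batch, time, classes]
print(reader.postprocess_ocrv2({"save_infer_model/scale_0.tmp_1": preds}))
```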
......@@ -50,7 +50,9 @@ class LocalServiceHandler(object):
use_mkldnn=False,
mkldnn_cache_capacity=0,
mkldnn_op_list=None,
mkldnn_bf16_op_list=None):
mkldnn_bf16_op_list=None,
min_subgraph_size=3,
dynamic_shape_info={}):
"""
Initialization of LocalServiceHandler
......@@ -92,6 +94,8 @@ class LocalServiceHandler(object):
self._mkldnn_cache_capacity = 0
self._mkldnn_op_list = None
self._mkldnn_bf16_op_list = None
self.min_subgraph_size = 3
self.dynamic_shape_info = {}
if device_type == -1:
# device_type is not set, determined by `devices`,
......@@ -120,6 +124,8 @@ class LocalServiceHandler(object):
self._use_gpu = True
devices = [int(x) for x in devices.split(",")]
self._use_trt = True
self.min_subgraph_size = min_subgraph_size
self.dynamic_shape_info = dynamic_shape_info
elif device_type == 3:
# ARM CPU
self._device_name = "arm"
......@@ -176,14 +182,16 @@ class LocalServiceHandler(object):
"mem_optim:{}, ir_optim:{}, use_profile:{}, thread_num:{}, "
"client_type:{}, fetch_names:{}, precision:{}, use_mkldnn:{}, "
"mkldnn_cache_capacity:{}, mkldnn_op_list:{}, "
"mkldnn_bf16_op_list:{}, use_ascend_cl:{}".format(
"mkldnn_bf16_op_list:{}, use_ascend_cl:{}, min_subgraph_size:{},"
"is_set_dynamic_shape_info:{}".format(
model_config, self._device_name, self._use_gpu, self._use_trt,
self._use_lite, self._use_xpu, device_type, self._devices,
self._mem_optim, self._ir_optim, self._use_profile,
self._thread_num, self._client_type, self._fetch_names,
self._precision, self._use_mkldnn, self._mkldnn_cache_capacity,
self._mkldnn_op_list, self._mkldnn_bf16_op_list,
self._use_ascend_cl))
self._use_ascend_cl, self.min_subgraph_size,
bool(len(self.dynamic_shape_info))))
def get_fetch_list(self):
return self._fetch_names
......@@ -240,7 +248,9 @@ class LocalServiceHandler(object):
mkldnn_cache_capacity=self._mkldnn_cache_capacity,
mkldnn_op_list=self._mkldnn_op_list,
mkldnn_bf16_op_list=self._mkldnn_bf16_op_list,
use_ascend_cl=self._use_ascend_cl)
use_ascend_cl=self._use_ascend_cl,
min_subgraph_size=self.min_subgraph_size,
dynamic_shape_info=self.dynamic_shape_info)
return self._local_predictor_client
def get_client_config(self):
......
......@@ -116,6 +116,16 @@ class Op(object):
self._for_close_op_lock = threading.Lock()
self._succ_init_op = False
self._succ_close_op = False
self.dynamic_shape_info = {}
self.set_dynamic_shape_info()
def set_dynamic_shape_info(self):
"""
when opening tensorrt(configure in config.yml) and each time the input shape
for inferring is different, using this method for configuring tensorrt
dynamic shape to infer in each op model
"""
pass
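A hypothetical override in a user-defined Op (the tensor name `x` and the bounds are illustrative; real models usually also need bounds on intermediate tensors, as the det/rec ops above show):

```
class MyTrtOp(Op):
    def set_dynamic_shape_info(self):
        self.dynamic_shape_info = {
            "min_input_shape": {"x": [1, 3, 50, 50]},
            "max_input_shape": {"x": [1, 3, 1536, 1536]},
            "opt_input_shape": {"x": [1, 3, 960, 960]},
        }
```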
# for feed/fetch dict check
@staticmethod
......@@ -182,6 +192,7 @@ class Op(object):
self.mkldnn_cache_capacity = 0
self.mkldnn_op_list = None
self.mkldnn_bf16_op_list = None
self.min_subgraph_size = 3
if self._server_endpoints is None:
server_endpoints = conf.get("server_endpoints", [])
......@@ -212,6 +223,8 @@ class Op(object):
"mkldnn_op_list")
self.mkldnn_bf16_op_list = local_service_conf.get(
"mkldnn_bf16_op_list")
self.min_subgraph_size = local_service_conf.get(
"min_subgraph_size")
if self.model_config is None:
self.with_serving = False
......@@ -233,7 +246,9 @@ class Op(object):
mkldnn_cache_capacity=self.
mkldnn_cache_capacity,
mkldnn_op_list=self.mkldnn_op_list,
mkldnn_bf16_op_list=self.mkldnn_bf16_op_list)
mkldnn_bf16_op_list=self.mkldnn_bf16_op_list,
min_subgraph_size=self.min_subgraph_size,
dynamic_shape_info=self.dynamic_shape_info)
service_handler.prepare_server() # get fetch_list
serivce_ports = service_handler.get_port_list()
self._server_endpoints = [
......@@ -261,7 +276,9 @@ class Op(object):
mkldnn_cache_capacity=self.
mkldnn_cache_capacity,
mkldnn_op_list=self.mkldnn_op_list,
mkldnn_bf16_op_list=self.mkldnn_bf16_op_list)
mkldnn_bf16_op_list=self.mkldnn_bf16_op_list,
min_subgraph_size=self.min_subgraph_size,
dynamic_shape_info=self.dynamic_shape_info)
if self._client_config is None:
self._client_config = service_handler.get_client_config(
)
......@@ -766,7 +783,9 @@ class Op(object):
self.ir_optim, self.precision, self.use_mkldnn,
self.mkldnn_cache_capacity, self.mkldnn_op_list,
self.mkldnn_bf16_op_list, self.is_jump_op(),
self.get_output_channels_of_jump_ops()))
self.get_output_channels_of_jump_ops(),
self.min_subgraph_size,
self.dynamic_shape_info))
p.daemon = True
p.start()
process.append(p)
......@@ -803,7 +822,9 @@ class Op(object):
self.ir_optim, self.precision, self.use_mkldnn,
self.mkldnn_cache_capacity, self.mkldnn_op_list,
self.mkldnn_bf16_op_list, self.is_jump_op(),
self.get_output_channels_of_jump_ops()))
self.get_output_channels_of_jump_ops(),
self.min_subgraph_size,
self.dynamic_shape_info))
# When a process exits, it attempts to terminate
# all of its daemonic child processes.
t.daemon = True
......@@ -1264,7 +1285,7 @@ class Op(object):
is_thread_op, trace_buffer, model_config, workdir, thread_num,
device_type, devices, mem_optim, ir_optim, precision, use_mkldnn,
mkldnn_cache_capacity, mkldnn_op_list, mkldnn_bf16_op_list,
is_jump_op, output_channels_of_jump_ops):
is_jump_op, output_channels_of_jump_ops, min_subgraph_size, dynamic_shape_info):
"""
_run() is the entry function of OP process / thread model. When client
type is local_predictor in process mode, the CUDA environment needs to
......@@ -1316,7 +1337,9 @@ class Op(object):
use_mkldnn=use_mkldnn,
mkldnn_cache_capacity=mkldnn_cache_capacity,
mkldnn_op_list=mkldnn_op_list,
mkldnn_bf16_op_list=mkldnn_bf16_op_list)
mkldnn_bf16_op_list=mkldnn_bf16_op_list,
min_subgraph_size=min_subgraph_size,
dynamic_shape_info=dynamic_shape_info)
_LOGGER.info("Init cuda env in process {}".format(
concurrency_idx))
......
......@@ -260,6 +260,7 @@ class PipelineServer(object):
"use_calib": False,
"use_mkldnn": False,
"mkldnn_cache_capacity": 0,
"min_subgraph_size": 3,
},
}
for op in self._used_op:
......