提交 4519cfa9 编写于 作者: S ShiningZhang

set trt dynamic shape for ocr

上级 37840afb
...@@ -65,6 +65,10 @@ message EngineDesc { ...@@ -65,6 +65,10 @@ message EngineDesc {
optional int32 batch_infer_size = 31 [ default = 32 ]; optional int32 batch_infer_size = 31 [ default = 32 ];
optional bool enable_overrun = 32 [ default = false ]; optional bool enable_overrun = 32 [ default = false ];
optional bool allow_split_request = 33 [ default = true ]; optional bool allow_split_request = 33 [ default = true ];
optional int32 min_subgraph_size = 34 [ default = 3 ];
map<string,string> min_input_shape = 35;
map<string,string> max_input_shape = 36;
map<string,string> opt_input_shape = 37;
}; };
// model_toolkit conf // model_toolkit conf
......
...@@ -244,7 +244,7 @@ int GeneralDetectionOp::inference() { ...@@ -244,7 +244,7 @@ int GeneralDetectionOp::inference() {
databuf_char_out = reinterpret_cast<char*>(databuf_data_out); databuf_char_out = reinterpret_cast<char*>(databuf_data_out);
paddle::PaddleBuf paddleBuf(databuf_char_out, databuf_size_out); paddle::PaddleBuf paddleBuf(databuf_char_out, databuf_size_out);
paddle::PaddleTensor tensor_out; paddle::PaddleTensor tensor_out;
tensor_out.name = "image"; tensor_out.name = "x";
tensor_out.dtype = paddle::PaddleDType::FLOAT32; tensor_out.dtype = paddle::PaddleDType::FLOAT32;
tensor_out.shape = output_shape; tensor_out.shape = output_shape;
tensor_out.data = paddleBuf; tensor_out.data = paddleBuf;
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
## Get Model ## Get Model
``` ```
wget https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/OCR/ocr_rec.tar.gz python3 -m paddle_serving_app.package --get_model ocr_rec
tar -xzvf ocr_rec.tar.gz tar -xzvf ocr_rec.tar.gz
wget https://paddle-serving.bj.bcebos.com/ocr/ocr_det.tar.gz python3 -m paddle_serving_app.package --get_model ocr_det
tar -xzvf ocr_det.tar.gz tar -xzvf ocr_det.tar.gz
``` ```
...@@ -108,7 +108,7 @@ python3 rec_web_client.py ...@@ -108,7 +108,7 @@ python3 rec_web_client.py
When a service starts the concatenation of two models, it only needs to pass in the relative path of the model folder in order after `--model`, and the custom C++ OP class name after `--op`. The order of the model after `--model` and the class name after `--OP` needs to correspond. Here, it is assumed that we have defined the two OPs as GeneralDetectionOp and GeneralRecOp respectively, The script code is as follows: When a service starts the concatenation of two models, it only needs to pass in the relative path of the model folder in order after `--model`, and the custom C++ OP class name after `--op`. The order of the model after `--model` and the class name after `--OP` needs to correspond. Here, it is assumed that we have defined the two OPs as GeneralDetectionOp and GeneralRecOp respectively, The script code is as follows:
```python ```python
#One service starts the concatenation of two models #One service starts the concatenation of two models
python3 -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --op GeneralDetectionOp GeneralRecOp --port 9293 python3 -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --op GeneralDetectionOp GeneralInferOp --port 9293
#ocr_det_model correspond to GeneralDetectionOp, ocr_rec_model correspond to GeneralRecOp #ocr_det_model correspond to GeneralDetectionOp, ocr_rec_model correspond to GeneralRecOp
``` ```
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
## 获取模型 ## 获取模型
``` ```
wget https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/OCR/ocr_rec.tar.gz python3 -m paddle_serving_app.package --get_model ocr_rec
tar -xzvf ocr_rec.tar.gz tar -xzvf ocr_rec.tar.gz
wget https://paddle-serving.bj.bcebos.com/ocr/ocr_det.tar.gz python3 -m paddle_serving_app.package --get_model ocr_det
tar -xzvf ocr_det.tar.gz tar -xzvf ocr_det.tar.gz
``` ```
## 获取数据集(可选) ## 获取数据集(可选)
...@@ -106,7 +106,7 @@ python3 rec_web_client.py ...@@ -106,7 +106,7 @@ python3 rec_web_client.py
一个服务启动两个模型串联,只需要在`--model后依次按顺序传入模型文件夹的相对路径`,且需要在`--op后依次传入自定义C++OP类名称`,其中--model后面的模型与--op后面的类名称的顺序需要对应,`这里假设我们已经定义好了两个OP分别为GeneralDetectionOp和GeneralRecOp`,则脚本代码如下: 一个服务启动两个模型串联,只需要在`--model后依次按顺序传入模型文件夹的相对路径`,且需要在`--op后依次传入自定义C++OP类名称`,其中--model后面的模型与--op后面的类名称的顺序需要对应,`这里假设我们已经定义好了两个OP分别为GeneralDetectionOp和GeneralRecOp`,则脚本代码如下:
```python ```python
#一个服务启动多模型串联 #一个服务启动多模型串联
python3 -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --op GeneralDetectionOp GeneralRecOp --port 9293 python3 -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --op GeneralDetectionOp GeneralInferOp --port 9293
#多模型串联 ocr_det_model对应GeneralDetectionOp ocr_rec_model对应GeneralRecOp #多模型串联 ocr_det_model对应GeneralDetectionOp ocr_rec_model对应GeneralRecOp
``` ```
......
...@@ -47,18 +47,18 @@ class OCRService(WebService): ...@@ -47,18 +47,18 @@ class OCRService(WebService):
}) })
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
data = base64.b64decode(feed[0]["image"].encode('utf8')) data = base64.b64decode(feed[0]["x"].encode('utf8'))
data = np.fromstring(data, np.uint8) data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR) im = cv2.imdecode(data, cv2.IMREAD_COLOR)
self.ori_h, self.ori_w, _ = im.shape self.ori_h, self.ori_w, _ = im.shape
det_img = self.det_preprocess(im) det_img = self.det_preprocess(im)
_, self.new_h, self.new_w = det_img.shape _, self.new_h, self.new_w = det_img.shape
return { return {
"image": det_img[np.newaxis, :].copy() "x": det_img[np.newaxis, :].copy()
}, ["concat_1.tmp_0"], True }, ["save_infer_model/scale_0.tmp_1"], True
def postprocess(self, feed={}, fetch=[], fetch_map=None): def postprocess(self, feed={}, fetch=[], fetch_map=None):
det_out = fetch_map["concat_1.tmp_0"] det_out = fetch_map["save_infer_model/scale_0.tmp_1"]
ratio_list = [ ratio_list = [
float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w
] ]
......
...@@ -47,17 +47,17 @@ class OCRService(WebService): ...@@ -47,17 +47,17 @@ class OCRService(WebService):
}) })
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
data = base64.b64decode(feed[0]["image"].encode('utf8')) data = base64.b64decode(feed[0]["x"].encode('utf8'))
data = np.fromstring(data, np.uint8) data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR) im = cv2.imdecode(data, cv2.IMREAD_COLOR)
self.ori_h, self.ori_w, _ = im.shape self.ori_h, self.ori_w, _ = im.shape
det_img = self.det_preprocess(im) det_img = self.det_preprocess(im)
_, self.new_h, self.new_w = det_img.shape _, self.new_h, self.new_w = det_img.shape
print(det_img) print(det_img)
return {"image": det_img}, ["concat_1.tmp_0"], False return {"x": det_img}, ["save_infer_model/scale_0.tmp_1"], False
def postprocess(self, feed={}, fetch=[], fetch_map=None): def postprocess(self, feed={}, fetch=[], fetch_map=None):
det_out = fetch_map["concat_1.tmp_0"] det_out = fetch_map["save_infer_model/scale_0.tmp_1"]
ratio_list = [ ratio_list = [
float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w
] ]
......
...@@ -42,13 +42,11 @@ for img_file in os.listdir(test_img_dir): ...@@ -42,13 +42,11 @@ for img_file in os.listdir(test_img_dir):
image_data = file.read() image_data = file.read()
image = cv2_to_base64(image_data) image = cv2_to_base64(image_data)
fetch_map = client.predict( fetch_map = client.predict(
feed={"image": image}, feed={"x": image},
fetch=["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"], fetch=["save_infer_model/scale_0.tmp_1"],
batch=True) batch=True)
result = {} result = {}
result["score"] = fetch_map["softmax_0.tmp_0"] rec_res = OCRReader().postprocess_ocrv2(fetch_map, with_score=False)
del fetch_map["softmax_0.tmp_0"]
rec_res = OCRReader().postprocess(fetch_map, with_score=False)
res_lst = [] res_lst = []
for res in rec_res: for res in rec_res:
res_lst.append(res[0]) res_lst.append(res[0])
......
...@@ -48,7 +48,7 @@ class OCRService(WebService): ...@@ -48,7 +48,7 @@ class OCRService(WebService):
self.ocr_reader = OCRReader() self.ocr_reader = OCRReader()
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
data = base64.b64decode(feed[0]["image"].encode('utf8')) data = base64.b64decode(feed[0]["x"].encode('utf8'))
data = np.fromstring(data, np.uint8) data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR) im = cv2.imdecode(data, cv2.IMREAD_COLOR)
ori_h, ori_w, _ = im.shape ori_h, ori_w, _ = im.shape
...@@ -57,7 +57,7 @@ class OCRService(WebService): ...@@ -57,7 +57,7 @@ class OCRService(WebService):
det_img = det_img[np.newaxis, :] det_img = det_img[np.newaxis, :]
det_img = det_img.copy() det_img = det_img.copy()
det_out = self.det_client.predict( det_out = self.det_client.predict(
feed={"image": det_img}, fetch=["concat_1.tmp_0"], batch=True) feed={"x": det_img}, fetch=["save_infer_model/scale_0.tmp_1"], batch=True)
filter_func = FilterBoxes(10, 10) filter_func = FilterBoxes(10, 10)
post_func = DBPostProcess({ post_func = DBPostProcess({
"thresh": 0.3, "thresh": 0.3,
...@@ -68,7 +68,7 @@ class OCRService(WebService): ...@@ -68,7 +68,7 @@ class OCRService(WebService):
}) })
sorted_boxes = SortedBoxes() sorted_boxes = SortedBoxes()
ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w] ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w]
dt_boxes_list = post_func(det_out["concat_1.tmp_0"], [ratio_list]) dt_boxes_list = post_func(det_out["save_infer_model/scale_0.tmp_1"], [ratio_list])
dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w]) dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w])
dt_boxes = sorted_boxes(dt_boxes) dt_boxes = sorted_boxes(dt_boxes)
get_rotate_crop_image = GetRotateCropImage() get_rotate_crop_image = GetRotateCropImage()
...@@ -88,12 +88,12 @@ class OCRService(WebService): ...@@ -88,12 +88,12 @@ class OCRService(WebService):
for id, img in enumerate(img_list): for id, img in enumerate(img_list):
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio) norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
imgs[id] = norm_img imgs[id] = norm_img
feed = {"image": imgs.copy()} feed = {"x": imgs.copy()}
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] fetch = ["save_infer_model/scale_0.tmp_1"]
return feed, fetch, True return feed, fetch, True
def postprocess(self, feed={}, fetch=[], fetch_map=None): def postprocess(self, feed={}, fetch=[], fetch_map=None):
rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) rec_res = self.ocr_reader.postprocess_ocrv2(fetch_map, with_score=True)
res_lst = [] res_lst = []
for res in rec_res: for res in rec_res:
res_lst.append(res[0]) res_lst.append(res[0])
......
...@@ -34,7 +34,7 @@ for img_file in os.listdir(test_img_dir): ...@@ -34,7 +34,7 @@ for img_file in os.listdir(test_img_dir):
with open(os.path.join(test_img_dir, img_file), 'rb') as file: with open(os.path.join(test_img_dir, img_file), 'rb') as file:
image_data1 = file.read() image_data1 = file.read()
image = cv2_to_base64(image_data1) image = cv2_to_base64(image_data1)
data = {"feed": [{"image": image}], "fetch": ["res"]} data = {"feed": [{"x": image}], "fetch": ["res"]}
r = requests.post(url=url, headers=headers, data=json.dumps(data)) r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r) print(r)
print(r.json()) print(r.json())
...@@ -44,13 +44,13 @@ class OCRService(WebService): ...@@ -44,13 +44,13 @@ class OCRService(WebService):
self.ocr_reader = OCRReader() self.ocr_reader = OCRReader()
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
data = base64.b64decode(feed[0]["image"].encode('utf8')) data = base64.b64decode(feed[0]["x"].encode('utf8'))
data = np.fromstring(data, np.uint8) data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR) im = cv2.imdecode(data, cv2.IMREAD_COLOR)
ori_h, ori_w, _ = im.shape ori_h, ori_w, _ = im.shape
det_img = self.det_preprocess(im) det_img = self.det_preprocess(im)
det_out = self.det_client.predict( det_out = self.det_client.predict(
feed={"image": det_img}, fetch=["concat_1.tmp_0"], batch=False) feed={"x": det_img}, fetch=["save_infer_model/scale_0.tmp_1"], batch=False)
_, new_h, new_w = det_img.shape _, new_h, new_w = det_img.shape
filter_func = FilterBoxes(10, 10) filter_func = FilterBoxes(10, 10)
post_func = DBPostProcess({ post_func = DBPostProcess({
...@@ -62,7 +62,7 @@ class OCRService(WebService): ...@@ -62,7 +62,7 @@ class OCRService(WebService):
}) })
sorted_boxes = SortedBoxes() sorted_boxes = SortedBoxes()
ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w] ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w]
dt_boxes_list = post_func(det_out["concat_1.tmp_0"], [ratio_list]) dt_boxes_list = post_func(det_out["save_infer_model/scale_0.tmp_1"], [ratio_list])
dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w]) dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w])
dt_boxes = sorted_boxes(dt_boxes) dt_boxes = sorted_boxes(dt_boxes)
get_rotate_crop_image = GetRotateCropImage() get_rotate_crop_image = GetRotateCropImage()
...@@ -78,12 +78,12 @@ class OCRService(WebService): ...@@ -78,12 +78,12 @@ class OCRService(WebService):
for img in img_list: for img in img_list:
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio) norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
feed_list.append(norm_img[np.newaxis, :]) feed_list.append(norm_img[np.newaxis, :])
feed_batch = {"image": np.concatenate(feed_list, axis=0)} feed_batch = {"x": np.concatenate(feed_list, axis=0)}
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] fetch = ["save_infer_model/scale_0.tmp_1"]
return feed_batch, fetch, True return feed_batch, fetch, True
def postprocess(self, feed={}, fetch=[], fetch_map=None): def postprocess(self, feed={}, fetch=[], fetch_map=None):
rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) rec_res = self.ocr_reader.postprocess_ocrv2(fetch_map, with_score=True)
res_lst = [] res_lst = []
for res in rec_res: for res in rec_res:
res_lst.append(res[0]) res_lst.append(res[0])
......
...@@ -38,7 +38,7 @@ class OCRService(WebService): ...@@ -38,7 +38,7 @@ class OCRService(WebService):
def preprocess(self, feed=[], fetch=[]): def preprocess(self, feed=[], fetch=[]):
img_list = [] img_list = []
for feed_data in feed: for feed_data in feed:
data = base64.b64decode(feed_data["image"].encode('utf8')) data = base64.b64decode(feed_data["x"].encode('utf8'))
data = np.fromstring(data, np.uint8) data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR) im = cv2.imdecode(data, cv2.IMREAD_COLOR)
img_list.append(im) img_list.append(im)
...@@ -53,12 +53,12 @@ class OCRService(WebService): ...@@ -53,12 +53,12 @@ class OCRService(WebService):
for i, img in enumerate(img_list): for i, img in enumerate(img_list):
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio) norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
imgs[i] = norm_img imgs[i] = norm_img
feed = {"image": imgs.copy()} feed = {"x": imgs.copy()}
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] fetch = ["save_infer_model/scale_0.tmp_1"]
return feed, fetch, True return feed, fetch, True
def postprocess(self, feed={}, fetch=[], fetch_map=None): def postprocess(self, feed={}, fetch=[], fetch_map=None):
rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) rec_res = self.ocr_reader.postprocess_ocrv2(fetch_map, with_score=True)
res_lst = [] res_lst = []
for res in rec_res: for res in rec_res:
res_lst.append(res[0]) res_lst.append(res[0])
......
...@@ -36,6 +36,6 @@ for img_file in os.listdir(test_img_dir): ...@@ -36,6 +36,6 @@ for img_file in os.listdir(test_img_dir):
image_data1 = file.read() image_data1 = file.read()
image = cv2_to_base64(image_data1) image = cv2_to_base64(image_data1)
#data = {"feed": [{"image": image}], "fetch": ["res"]} #data = {"feed": [{"image": image}], "fetch": ["res"]}
data = {"feed": [{"image": image}] * 3, "fetch": ["res"]} data = {"feed": [{"x": image}] * 3, "fetch": ["res"]}
r = requests.post(url=url, headers=headers, data=json.dumps(data)) r = requests.post(url=url, headers=headers, data=json.dumps(data))
print(r.json()) print(r.json())
...@@ -39,7 +39,7 @@ class OCRService(WebService): ...@@ -39,7 +39,7 @@ class OCRService(WebService):
# TODO: to handle batch rec images # TODO: to handle batch rec images
img_list = [] img_list = []
for feed_data in feed: for feed_data in feed:
data = base64.b64decode(feed_data["image"].encode('utf8')) data = base64.b64decode(feed_data["x"].encode('utf8'))
data = np.fromstring(data, np.uint8) data = np.fromstring(data, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR) im = cv2.imdecode(data, cv2.IMREAD_COLOR)
img_list.append(im) img_list.append(im)
...@@ -55,12 +55,12 @@ class OCRService(WebService): ...@@ -55,12 +55,12 @@ class OCRService(WebService):
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio) norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
imgs[i] = norm_img imgs[i] = norm_img
feed = {"image": imgs.copy()} feed = {"x": imgs.copy()}
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] fetch = ["save_infer_model/scale_0.tmp_1"]
return feed, fetch, True return feed, fetch, True
def postprocess(self, feed={}, fetch=[], fetch_map=None): def postprocess(self, feed={}, fetch=[], fetch_map=None):
rec_res = self.ocr_reader.postprocess(fetch_map, with_score=True) rec_res = self.ocr_reader.postprocess_ocrv2(fetch_map, with_score=True)
res_lst = [] res_lst = []
for res in rec_res: for res in rec_res:
res_lst.append(res[0]) res_lst.append(res[0])
......
...@@ -225,6 +225,12 @@ class PaddleInferenceEngine : public EngineCore { ...@@ -225,6 +225,12 @@ class PaddleInferenceEngine : public EngineCore {
config.SwitchIrOptim(true); config.SwitchIrOptim(true);
} }
int local_min_subgraph_size = min_subgraph_size;
if (engine_conf.has_min_subgraph_size()) {
local_min_subgraph_size = engine_conf.min_subgraph_size();
LOG(INFO) << "local_min_subgraph_size=" << local_min_subgraph_size;
}
if (engine_conf.has_use_trt() && engine_conf.use_trt()) { if (engine_conf.has_use_trt() && engine_conf.use_trt()) {
config.SwitchIrOptim(true); config.SwitchIrOptim(true);
if (!engine_conf.has_use_gpu() || !engine_conf.use_gpu()) { if (!engine_conf.has_use_gpu() || !engine_conf.use_gpu()) {
...@@ -236,10 +242,55 @@ class PaddleInferenceEngine : public EngineCore { ...@@ -236,10 +242,55 @@ class PaddleInferenceEngine : public EngineCore {
} }
config.EnableTensorRtEngine(1 << 20, config.EnableTensorRtEngine(1 << 20,
max_batch, max_batch,
min_subgraph_size, local_min_subgraph_size,
precision_type, precision_type,
false, false,
FLAGS_use_calib); FLAGS_use_calib);
std::map<std::string, std::vector<int>> min_input_shape;
std::map<std::string, std::vector<int>> max_input_shape;
std::map<std::string, std::vector<int>> optim_input_shape;
if (engine_conf.min_input_shape_size() > 0) {
for (auto& iter : engine_conf.min_input_shape()) {
std::string key = iter.first;
std::string value = iter.second;
std::istringstream ss(value);
std::string word;
std::vector<int> arr;
while(ss >> word) {
arr.push_back(std::stoi(word));
}
min_input_shape[key] = arr;
}
}
if (engine_conf.max_input_shape_size() > 0) {
for (auto& iter : engine_conf.max_input_shape()) {
std::string key = iter.first;
std::string value = iter.second;
std::istringstream ss(value);
std::string word;
std::vector<int> arr;
while(ss >> word) {
arr.push_back(std::stoi(word));
}
max_input_shape[key] = arr;
}
}
if (engine_conf.opt_input_shape_size() > 0) {
for (auto& iter : engine_conf.opt_input_shape()) {
std::string key = iter.first;
std::string value = iter.second;
std::istringstream ss(value);
std::string word;
std::vector<int> arr;
while(ss >> word) {
arr.push_back(std::stoi(word));
}
optim_input_shape[key] = arr;
}
}
config.SetTRTDynamicShapeInfo(min_input_shape,
max_input_shape,
optim_input_shape);
LOG(INFO) << "create TensorRT predictor"; LOG(INFO) << "create TensorRT predictor";
} }
......
...@@ -222,6 +222,8 @@ def serve_args(): ...@@ -222,6 +222,8 @@ def serve_args():
"--prometheus_port", type=int, default=19393, help="Port of the Prometheus") "--prometheus_port", type=int, default=19393, help="Port of the Prometheus")
parser.add_argument( parser.add_argument(
"--request_cache_size", type=int, default=0, help="Port of the Prometheus") "--request_cache_size", type=int, default=0, help="Port of the Prometheus")
parser.add_argument(
"--min_subgraph_size", type=int, default="", nargs="+", help="gpu ids")
return parser.parse_args() return parser.parse_args()
...@@ -272,11 +274,14 @@ def start_gpu_card_model(gpu_mode, port, args): # pylint: disable=doc-string-mi ...@@ -272,11 +274,14 @@ def start_gpu_card_model(gpu_mode, port, args): # pylint: disable=doc-string-mi
read_op = op_maker.create('GeneralReaderOp') read_op = op_maker.create('GeneralReaderOp')
op_seq_maker.add_op(read_op) op_seq_maker.add_op(read_op)
is_ocr = False
#如果dag_list_op不是空,那么证明通过--op 传入了自定义OP或自定义的DAG串联关系。 #如果dag_list_op不是空,那么证明通过--op 传入了自定义OP或自定义的DAG串联关系。
#此时,根据--op 传入的顺序去组DAG串联关系 #此时,根据--op 传入的顺序去组DAG串联关系
if len(dag_list_op) > 0: if len(dag_list_op) > 0:
for single_op in dag_list_op: for single_op in dag_list_op:
op_seq_maker.add_op(op_maker.create(single_op)) op_seq_maker.add_op(op_maker.create(single_op))
if single_op == "GeneralDetectionOp":
is_ocr = True
#否则,仍然按照原有方式根虎--model去串联。 #否则,仍然按照原有方式根虎--model去串联。
else: else:
for idx, single_model in enumerate(model): for idx, single_model in enumerate(model):
...@@ -287,6 +292,7 @@ def start_gpu_card_model(gpu_mode, port, args): # pylint: disable=doc-string-mi ...@@ -287,6 +292,7 @@ def start_gpu_card_model(gpu_mode, port, args): # pylint: disable=doc-string-mi
# 以后可能考虑不用python脚本来生成配置 # 以后可能考虑不用python脚本来生成配置
if len(model) == 2 and idx == 0 and single_model == "ocr_det_model": if len(model) == 2 and idx == 0 and single_model == "ocr_det_model":
infer_op_name = "GeneralDetectionOp" infer_op_name = "GeneralDetectionOp"
is_ocr = True
else: else:
infer_op_name = "GeneralInferOp" infer_op_name = "GeneralInferOp"
general_infer_op = op_maker.create(infer_op_name) general_infer_op = op_maker.create(infer_op_name)
...@@ -306,10 +312,14 @@ def start_gpu_card_model(gpu_mode, port, args): # pylint: disable=doc-string-mi ...@@ -306,10 +312,14 @@ def start_gpu_card_model(gpu_mode, port, args): # pylint: disable=doc-string-mi
server.set_enable_prometheus(args.enable_prometheus) server.set_enable_prometheus(args.enable_prometheus)
server.set_prometheus_port(args.prometheus_port) server.set_prometheus_port(args.prometheus_port)
server.set_request_cache_size(args.request_cache_size) server.set_request_cache_size(args.request_cache_size)
server.set_min_subgraph_size(args.min_subgraph_size)
if args.use_trt and device == "gpu": if args.use_trt and device == "gpu":
server.set_trt() server.set_trt()
server.set_ir_optimize(True) server.set_ir_optimize(True)
if is_ocr:
info = set_ocr_dynamic_shape_info()
server.set_trt_dynamic_shape_info(info)
if args.gpu_multi_stream and device == "gpu": if args.gpu_multi_stream and device == "gpu":
server.set_gpu_multi_stream() server.set_gpu_multi_stream()
...@@ -344,6 +354,51 @@ def start_gpu_card_model(gpu_mode, port, args): # pylint: disable=doc-string-mi ...@@ -344,6 +354,51 @@ def start_gpu_card_model(gpu_mode, port, args): # pylint: disable=doc-string-mi
use_encryption_model=args.use_encryption_model) use_encryption_model=args.use_encryption_model)
server.run_server() server.run_server()
def set_ocr_dynamic_shape_info():
info = []
min_input_shape = {
"x": [1, 3, 50, 50],
"conv2d_182.tmp_0": [1, 1, 20, 20],
"nearest_interp_v2_2.tmp_0": [1, 1, 20, 20],
"nearest_interp_v2_3.tmp_0": [1, 1, 20, 20],
"nearest_interp_v2_4.tmp_0": [1, 1, 20, 20],
"nearest_interp_v2_5.tmp_0": [1, 1, 20, 20]
}
max_input_shape = {
"x": [1, 3, 1536, 1536],
"conv2d_182.tmp_0": [20, 200, 960, 960],
"nearest_interp_v2_2.tmp_0": [20, 200, 960, 960],
"nearest_interp_v2_3.tmp_0": [20, 200, 960, 960],
"nearest_interp_v2_4.tmp_0": [20, 200, 960, 960],
"nearest_interp_v2_5.tmp_0": [20, 200, 960, 960],
}
opt_input_shape = {
"x": [1, 3, 960, 960],
"conv2d_182.tmp_0": [3, 96, 240, 240],
"nearest_interp_v2_2.tmp_0": [3, 96, 240, 240],
"nearest_interp_v2_3.tmp_0": [3, 24, 240, 240],
"nearest_interp_v2_4.tmp_0": [3, 24, 240, 240],
"nearest_interp_v2_5.tmp_0": [3, 24, 240, 240],
}
det_info = {
"min_input_shape": min_input_shape,
"max_input_shape": max_input_shape,
"opt_input_shape": opt_input_shape,
}
info.append(det_info)
min_input_shape = {"x": [1, 3, 32, 10], "lstm_1.tmp_0": [1, 1, 128]}
max_input_shape = {
"x": [50, 3, 32, 1000],
"lstm_1.tmp_0": [500, 50, 128]
}
opt_input_shape = {"x": [6, 3, 32, 100], "lstm_1.tmp_0": [25, 5, 128]}
rec_info = {
"min_input_shape": min_input_shape,
"max_input_shape": max_input_shape,
"opt_input_shape": opt_input_shape,
}
info.append(rec_info)
return info
def start_multi_card(args, serving_port=None): # pylint: disable=doc-string-missing def start_multi_card(args, serving_port=None): # pylint: disable=doc-string-missing
......
...@@ -101,6 +101,8 @@ class Server(object): ...@@ -101,6 +101,8 @@ class Server(object):
self.enable_prometheus = False self.enable_prometheus = False
self.prometheus_port = 19393 self.prometheus_port = 19393
self.request_cache_size = 0 self.request_cache_size = 0
self.min_subgraph_size = []
self.trt_dynamic_shape_info = []
def get_fetch_list(self, infer_node_idx=-1): def get_fetch_list(self, infer_node_idx=-1):
fetch_names = [ fetch_names = [
...@@ -211,6 +213,13 @@ class Server(object): ...@@ -211,6 +213,13 @@ class Server(object):
def set_request_cache_size(self, request_cache_size): def set_request_cache_size(self, request_cache_size):
self.request_cache_size = request_cache_size self.request_cache_size = request_cache_size
def set_min_subgraph_size(self, min_subgraph_size):
if isinstance(min_subgraph_size, list):
self.min_subgraph_size = list(map(int, min_subgraph_size))
def set_trt_dynamic_shape_info(self, info):
self.trt_dynamic_shape_info = info
def _prepare_engine(self, model_config_paths, device, use_encryption_model): def _prepare_engine(self, model_config_paths, device, use_encryption_model):
self.device = device self.device = device
if self.model_toolkit_conf == None: if self.model_toolkit_conf == None:
...@@ -292,6 +301,25 @@ class Server(object): ...@@ -292,6 +301,25 @@ class Server(object):
if use_encryption_model: if use_encryption_model:
engine.encrypted_model = True engine.encrypted_model = True
engine.type = "PADDLE_INFER" engine.type = "PADDLE_INFER"
if len(self.min_subgraph_size) > index:
engine.min_subgraph_size = self.min_subgraph_size[index]
if len(self.trt_dynamic_shape_info) > index:
dynamic_shape_info = self.trt_dynamic_shape_info[index]
try:
for key,value in dynamic_shape_info.items():
shape_type = key
if shape_type == "min_input_shape":
local_map = engine.min_input_shape
if shape_type == "max_input_shape":
local_map = engine.max_input_shape
if shape_type == "opt_input_shape":
local_map = engine.opt_input_shape
for name,shape in value.items():
local_value = ' '.join(str(i) for i in shape)
local_map[name] = local_value
except:
raise ValueError("Set TRT dynamic shape info error!")
self.model_toolkit_conf.append(server_sdk.ModelToolkitConf()) self.model_toolkit_conf.append(server_sdk.ModelToolkitConf())
self.model_toolkit_conf[-1].engines.extend([engine]) self.model_toolkit_conf[-1].engines.extend([engine])
index = index + 1 index = index + 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册