未验证 提交 e3801c55 编写于 作者: T Tingquan Gao 提交者: GitHub

Fix the bug about calling create_predictor repeatedly in hubserving (#462)

上级 022a4114
...@@ -25,7 +25,7 @@ import numpy as np ...@@ -25,7 +25,7 @@ import numpy as np
import paddlehub as hub import paddlehub as hub
import tools.infer.predict as paddle_predict import tools.infer.predict as paddle_predict
from tools.infer.utils import Base64ToCV2 from tools.infer.utils import Base64ToCV2, create_paddle_predictor
from deploy.hubserving.clas.params import read_params from deploy.hubserving.clas.params import read_params
...@@ -96,6 +96,7 @@ class ClasSystem(hub.Module): ...@@ -96,6 +96,7 @@ class ClasSystem(hub.Module):
assert predicted_data != [], "There is not any image to be predicted. Please check the input data." assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
predictor = create_paddle_predictor(self.args)
all_results = [] all_results = []
for img in predicted_data: for img in predicted_data:
if img is None: if img is None:
...@@ -106,7 +107,7 @@ class ClasSystem(hub.Module): ...@@ -106,7 +107,7 @@ class ClasSystem(hub.Module):
self.args.image_file = img self.args.image_file = img
self.args.top_k = top_k self.args.top_k = top_k
classes, scores = paddle_predict.main(self.args) classes, scores = paddle_predict.predict(self.args, predictor)
elapse = time.time() - starttime elapse = time.time() - starttime
logger.info("Predict time: {}".format(elapse)) logger.info("Predict time: {}".format(elapse))
......
...@@ -138,7 +138,7 @@ hub serving start -c deploy/hubserving/clas/config.json ...@@ -138,7 +138,7 @@ hub serving start -c deploy/hubserving/clas/config.json
```hub uninstall clas_system``` ```hub uninstall clas_system```
- 4、 安装修改后的新服务包 - 4、 安装修改后的新服务包
```hub install deploy/hubserving/clas_system/``` ```hub install deploy/hubserving/clas/```
- 5、重新启动服务 - 5、重新启动服务
```hub serving start -m clas_system``` ```hub serving start -m clas_system```
...@@ -144,7 +144,7 @@ hub uninstall clas_system ...@@ -144,7 +144,7 @@ hub uninstall clas_system
``` ```
- 4. Install modified service module - 4. Install modified service module
```shell ```shell
hub install deploy/hubserving/clas_system/ hub install deploy/hubserving/clas/
``` ```
- 5. Restart service - 5. Restart service
```shell ```shell
......
...@@ -19,51 +19,8 @@ import numpy as np ...@@ -19,51 +19,8 @@ import numpy as np
import cv2 import cv2
import time import time
from paddle.inference import Config
from paddle.inference import create_predictor
def create_paddle_predictor(args):
config = Config(args.model_file, args.params_file)
if args.use_gpu:
config.enable_use_gpu(args.gpu_mem, 0)
else:
config.disable_gpu()
if args.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
config.disable_glog_info()
config.switch_ir_optim(args.ir_optim) # default true
if args.use_tensorrt:
config.enable_tensorrt_engine(
precision_mode=Config.Precision.Half
if args.use_fp16 else Config.Precision.Float32,
max_batch_size=args.batch_size)
config.enable_memory_optim()
# use zero copy
config.switch_use_feed_fetch_ops(False)
predictor = create_predictor(config)
return predictor
def main(args):
if not args.enable_benchmark:
assert args.batch_size == 1
assert args.use_fp16 is False
else:
assert args.use_gpu is True
assert args.model is not None
# HALF precission predict only work when using tensorrt
if args.use_fp16 is True:
assert args.use_tensorrt is True
predictor = create_paddle_predictor(args)
def predict(args, predictor):
input_names = predictor.get_input_names() input_names = predictor.get_input_names()
input_tensor = predictor.get_input_handle(input_names[0]) input_tensor = predictor.get_input_handle(input_names[0])
...@@ -91,10 +48,11 @@ def main(args): ...@@ -91,10 +48,11 @@ def main(args):
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
classes, scores = utils.postprocess(output, args) classes, scores = utils.postprocess(output, args)
if args.hubserving:
return classes, scores
print("Current image file: {}".format(args.image_file)) print("Current image file: {}".format(args.image_file))
print("\ttop-1 class: {0}".format(classes[0])) print("\ttop-1 class: {0}".format(classes[0]))
print("\ttop-1 score: {0}".format(scores[0])) print("\ttop-1 score: {0}".format(scores[0]))
else: else:
for i in range(0, test_num + 10): for i in range(0, test_num + 10):
inputs = np.random.rand(args.batch_size, 3, 224, inputs = np.random.rand(args.batch_size, 3, 224,
...@@ -117,6 +75,19 @@ def main(args): ...@@ -117,6 +75,19 @@ def main(args):
/ test_num)) / test_num))
def main(args):
if not args.enable_benchmark:
assert args.batch_size == 1
else:
assert args.model is not None
# HALF precission predict only work when using tensorrt
if args.use_fp16 is True:
assert args.use_tensorrt is True
predictor = utils.create_paddle_predictor(args)
predict(args, predictor)
if __name__ == "__main__": if __name__ == "__main__":
args = utils.parse_args() args = utils.parse_args()
main(args) main(args)
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
import argparse import argparse
import cv2 import cv2
import numpy as np import numpy as np
from paddle.inference import Config
from paddle.inference import create_predictor
def parse_args(): def parse_args():
...@@ -65,6 +67,34 @@ def parse_args(): ...@@ -65,6 +67,34 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def create_paddle_predictor(args):
config = Config(args.model_file, args.params_file)
if args.use_gpu:
config.enable_use_gpu(args.gpu_mem, 0)
else:
config.disable_gpu()
if args.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
config.disable_glog_info()
config.switch_ir_optim(args.ir_optim) # default true
if args.use_tensorrt:
config.enable_tensorrt_engine(
precision_mode=Config.Precision.Half
if args.use_fp16 else Config.Precision.Float32,
max_batch_size=args.batch_size)
config.enable_memory_optim()
# use zero copy
config.switch_use_feed_fetch_ops(False)
predictor = create_predictor(config)
return predictor
def preprocess(img, args): def preprocess(img, args):
resize_op = ResizeImage(resize_short=args.resize_short) resize_op = ResizeImage(resize_short=args.resize_short)
img = resize_op(img) img = resize_op(img)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册