提交 fdffdbd6 编写于 作者: F felixhjh

add ErrorCatch and ParamCheck for dynamic_shape_info

上级 6c362d20
doc/images/wechat_group_1.jpeg

334.9 KB | W: | H:

doc/images/wechat_group_1.jpeg

241.1 KB | W: | H:

doc/images/wechat_group_1.jpeg
doc/images/wechat_group_1.jpeg
doc/images/wechat_group_1.jpeg
doc/images/wechat_group_1.jpeg
  • 2-up
  • Swipe
  • Onion skin
......@@ -23,10 +23,13 @@ from .proto import general_model_config_pb2 as m_config
import paddle.inference as paddle_infer
import logging
import glob
from paddle_serving_server.pipeline.error_catch import ErrorCatch, CustomException, CustomExceptionCode, ParamChecker, ParamVerify
check_dynamic_shape_info=ParamVerify.check_dynamic_shape_info
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("LocalPredictor")
logger.setLevel(logging.INFO)
from paddle_serving_server.util import kill_stop_process_by_pid
precision_map = {
'int8': paddle_infer.PrecisionType.Int8,
......@@ -223,6 +226,15 @@ class LocalPredictor(object):
use_static=False,
use_calib_mode=use_calib)
@ErrorCatch
@ParamChecker
def dynamic_shape_info_helper(dynamic_shape_info:lambda dynamic_shape_info: check_dynamic_shape_info(dynamic_shape_info)):
pass
_, resp = dynamic_shape_info_helper(dynamic_shape_info)
if resp.err_no != CustomExceptionCode.OK.value:
print("dynamic_shape_info configure error, it should contain [min_input_shape', 'max_input_shape', 'opt_input_shape' {}".format(resp.err_msg))
kill_stop_process_by_pid("kill", os.getpgid(os.getpid()))
if len(dynamic_shape_info):
config.set_trt_dynamic_shape_info(
dynamic_shape_info['min_input_shape'],
......@@ -269,7 +281,18 @@ class LocalPredictor(object):
if mkldnn_bf16_op_list is not None:
config.set_bfloat16_op(mkldnn_bf16_op_list)
self.predictor = paddle_infer.create_predictor(config)
@ErrorCatch
def create_predictor_check(config):
predictor = paddle_infer.create_predictor(config)
return predictor
predictor, resp = create_predictor_check(config)
if resp.err_no != CustomExceptionCode.OK.value:
logger.critical(
"failed to create predictor: {}".format(resp.err_msg),
exc_info=False)
print("failed to create predictor: {}".format(resp.err_msg))
kill_stop_process_by_pid("kill", os.getpgid(os.getpid()))
self.predictor = predictor
def predict(self, feed=None, fetch=None, batch=False, log_id=0):
"""
......
......@@ -227,5 +227,17 @@ class ParamVerify(object):
if key not in right_fetch_list:
return False
return True
@staticmethod
def check_dynamic_shape_info(dynamic_shape_info):
if not isinstance(dynamic_shape_info, dict):
return False
if len(dynamic_shape_info) == 0:
return True
shape_info_keys = ["min_input_shape", "max_input_shape", "opt_input_shape"]
if all(key in dynamic_shape_info for key in shape_info_keys):
return True
else:
return False
ErrorCatch = ErrorCatch()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册