From 5708d4e13f6c1f2bd1145c1c6ba4dbd9c5264d4a Mon Sep 17 00:00:00 2001 From: LiuChiachi <709153940@qq.com> Date: Fri, 21 Jan 2022 06:56:51 +0000 Subject: [PATCH] supports predict of string list input --- python/paddle_serving_app/local_predict.py | 26 +++++++++++++--------- python/pipeline/channel.py | 4 +++- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/python/paddle_serving_app/local_predict.py b/python/paddle_serving_app/local_predict.py index 9081b8b3..5f922a28 100644 --- a/python/paddle_serving_app/local_predict.py +++ b/python/paddle_serving_app/local_predict.py @@ -160,12 +160,12 @@ class LocalPredictor(object): "use_trt:{}, use_lite:{}, use_xpu:{}, precision:{}, use_calib:{}, " "use_mkldnn:{}, mkldnn_cache_capacity:{}, mkldnn_op_list:{}, " "mkldnn_bf16_op_list:{}, use_feed_fetch_ops:{}, " - "use_ascend_cl:{}, min_subgraph_size:{}, dynamic_shape_info:{}".format( - model_path, use_gpu, gpu_id, use_profile, thread_num, mem_optim, - ir_optim, use_trt, use_lite, use_xpu, precision, use_calib, - use_mkldnn, mkldnn_cache_capacity, mkldnn_op_list, - mkldnn_bf16_op_list, use_feed_fetch_ops, use_ascend_cl, - min_subgraph_size, dynamic_shape_info)) + "use_ascend_cl:{}, min_subgraph_size:{}, dynamic_shape_info:{}". + format(model_path, use_gpu, gpu_id, use_profile, thread_num, + mem_optim, ir_optim, use_trt, use_lite, use_xpu, precision, + use_calib, use_mkldnn, mkldnn_cache_capacity, mkldnn_op_list, + mkldnn_bf16_op_list, use_feed_fetch_ops, use_ascend_cl, + min_subgraph_size, dynamic_shape_info)) self.feed_names_ = [var.alias_name for var in model_conf.feed_var] self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var] @@ -236,10 +236,10 @@ class LocalPredictor(object): kill_stop_process_by_pid("kill", os.getpgid(os.getpid())) if len(dynamic_shape_info): - config.set_trt_dynamic_shape_info( - dynamic_shape_info['min_input_shape'], - dynamic_shape_info['max_input_shape'], - dynamic_shape_info['opt_input_shape']) + config.set_trt_dynamic_shape_info( + dynamic_shape_info['min_input_shape'], + dynamic_shape_info['max_input_shape'], + dynamic_shape_info['opt_input_shape']) # set lite if use_lite: config.enable_lite_engine( @@ -338,7 +338,8 @@ class LocalPredictor(object): # Assemble the input data of paddle predictor, and filter invalid inputs. input_names = self.predictor.get_input_names() for name in input_names: - if isinstance(feed[name], list): + if isinstance(feed[name], list) and not isinstance(feed[name][0], + str): feed[name] = np.array(feed[name]).reshape(self.feed_shapes_[ name]) if self.feed_types_[name] == 0: @@ -365,6 +366,9 @@ class LocalPredictor(object): feed[name] = feed[name].astype("complex64") elif self.feed_types_[name] == 11: feed[name] = feed[name].astype("complex128") + elif isinstance(feed[name], list) and isinstance(feed[name][0], + str): + pass else: raise ValueError("local predictor receives wrong data type") diff --git a/python/pipeline/channel.py b/python/pipeline/channel.py index 9ef1c09c..8df5a625 100644 --- a/python/pipeline/channel.py +++ b/python/pipeline/channel.py @@ -34,6 +34,7 @@ from .error_catch import CustomExceptionCode as ChannelDataErrcode _LOGGER = logging.getLogger(__name__) + class ChannelDataType(enum.Enum): """ Channel data type @@ -167,7 +168,8 @@ class ChannelData(object): elif isinstance(npdata, dict): # batch_size = 1 for _, value in npdata.items(): - if not isinstance(value, np.ndarray): + if not isinstance(value, np.ndarray) and not (isinstance( + value, list) and isinstance(value[0], str)): error_code = ChannelDataErrcode.TYPE_ERROR.value error_info = "Failed to check data: the value " \ "of data must be np.ndarray, but get {}.".format( -- GitLab