From 8a23df163712ab41a43b35cdcd495383726578fc Mon Sep 17 00:00:00 2001 From: FlyingQianMM <245467267@qq.com> Date: Tue, 22 Sep 2020 11:59:32 +0000 Subject: [PATCH] add transforms in predict/bath_predict in deploy.py --- paddlex/deploy.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/paddlex/deploy.py b/paddlex/deploy.py index b04f46e..cbd78d9 100644 --- a/paddlex/deploy.py +++ b/paddlex/deploy.py @@ -247,13 +247,16 @@ class Predictor: [output_tensor.copy_to_cpu(), output_tensor_lod]) return output_results - def predict(self, image, topk=1): + def predict(self, image, topk=1, transforms=None): """ 图片预测 Args: image(str|np.ndarray): 图像路径;或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。 - topk(int): 分类预测时使用,表示预测前topk的结果 + topk(int): 分类预测时使用,表示预测前topk的结果。 + transforms (paddlex.cls.transforms): 数据预处理操作。 """ + if transforms is not None: + self.transforms = transforms preprocessed_input = self.preprocess([image]) model_pred = self.raw_predict(preprocessed_input) im_shape = None if 'im_shape' not in preprocessed_input else preprocessed_input[ @@ -269,15 +272,18 @@ class Predictor: return results[0] - def batch_predict(self, image_list, topk=1): + def batch_predict(self, image_list, topk=1, transforms=None): """ 图片预测 Args: image_list(list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径 也可以是解码后的排列格式为(H,W,C)且类型为float32且为BGR格式的数组。 - topk(int): 分类预测时使用,表示预测前topk的结果 + topk(int): 分类预测时使用,表示预测前topk的结果。 + transforms (paddlex.cls.transforms): 数据预处理操作。 """ + if transforms is not None: + self.transforms = transforms preprocessed_input = self.preprocess(image_list, self.thread_pool) model_pred = self.raw_predict(preprocessed_input) im_shape = None if 'im_shape' not in preprocessed_input else preprocessed_input[ -- GitLab