提交 0efcf0b6 编写于 作者: G gaotingquan

refactor: adapt to predict_cls

refer to commit: bd586f4a
上级 92ac78d1
......@@ -480,7 +480,7 @@ class PaddleClas(object):
input_data (Union[str, np.array]):
When the type is str, it is the path of image, or the directory containing images, or the URL of image from Internet.
When the type is np.array, it is the image data whose channel order is RGB.
print_pred (bool, optional): Whether print the prediction result. Defaults to False. Defaults to False.
print_pred (bool, optional): Whether print the prediction result. Defaults to False.
Raises:
ImageTypeError: Illegal input_data.
......@@ -489,12 +489,11 @@ class PaddleClas(object):
Generator[list, None, None]:
The prediction result(s) of input_data by batch_size. For every one image,
prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names".
The format is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...]
The format of batch prediction result(s) is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...]
"""
if isinstance(input_data, np.ndarray):
outputs = self.cls_predictor.predict(input_data)
yield self.cls_predictor.postprocess(outputs)
yield self.cls_predictor.predict(input_data)
elif isinstance(input_data, str):
if input_data.startswith("http") or input_data.startswith("https"):
image_storage_dir = partial(os.path.join, BASE_IMAGES_DIR)
......@@ -509,7 +508,7 @@ class PaddleClas(object):
image_list = get_image_list(input_data)
batch_size = self._config.Global.get("batch_size", 1)
topk = self._config.PostProcess.get('topk', 1)
topk = self._config.PostProcess.Topk.get('topk', 1)
img_list = []
img_path_list = []
......@@ -527,16 +526,15 @@ class PaddleClas(object):
cnt += 1
if cnt % batch_size == 0 or (idx + 1) == len(image_list):
outputs = self.cls_predictor.predict(img_list)
preds = self.cls_predictor.postprocess(outputs,
img_path_list)
preds = self.cls_predictor.predict(img_list)
if print_pred and preds:
for pred in preds:
filename = pred.pop("file_name")
for idx, pred in enumerate(preds):
pred_str = ", ".join(
[f"{k}: {pred[k]}" for k in pred])
print(
f"filename: {filename}, top-{topk}, {pred_str}")
f"filename: {img_path_list[idx]}, top-{topk}, {pred_str}"
)
img_list = []
img_path_list = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册