From 02ce5f94aacdb114066ce39fb9158e06860279b3 Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Wed, 8 Jun 2022 14:01:51 +0000 Subject: [PATCH] fix: use cpu instead when gpu is invalid --- paddleclas.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/paddleclas.py b/paddleclas.py index 38b89f05..affc4833 100644 --- a/paddleclas.py +++ b/paddleclas.py @@ -31,6 +31,7 @@ import cv2 import numpy as np from tqdm import tqdm from prettytable import PrettyTable +import paddle from deploy.python.predict_cls import ClsPredictor from deploy.utils.get_image_list import get_image_list @@ -201,8 +202,13 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs): if "batch_size" in kwargs and kwargs["batch_size"]: cfg.Global.batch_size = kwargs["batch_size"] + if "use_gpu" in kwargs and kwargs["use_gpu"]: cfg.Global.use_gpu = kwargs["use_gpu"] + if cfg.Global.use_gpu and not paddle.device.is_compiled_with_cuda(): + msg = "The current running environment does not support the use of GPU. CPU has been used instead." + logger.warning(msg) + cfg.Global.use_gpu = False if "infer_imgs" in kwargs and kwargs["infer_imgs"]: cfg.Global.infer_imgs = kwargs["infer_imgs"] @@ -267,17 +273,25 @@ def args_cfg(): type=str, help="The directory of model files. Valid when model_name not specifed." ) - parser.add_argument("--use_gpu", type=str, help="Whether use GPU.") - parser.add_argument("--gpu_mem", type=int, default=8000, help="") + parser.add_argument("--use_gpu", type=str2bool, help="Whether use GPU.") + parser.add_argument( + "--gpu_mem", + type=int, + help="The memory size of GPU allocated to predict.") parser.add_argument( "--enable_mkldnn", type=str2bool, - default=False, help="Whether use MKLDNN. Valid when use_gpu is False") - parser.add_argument("--cpu_num_threads", type=int, default=1, help="") parser.add_argument( - "--use_tensorrt", type=str2bool, default=False, help="") - parser.add_argument("--use_fp16", type=str2bool, default=False, help="") + "--cpu_num_threads", + type=int, + help="The threads number when predicting on CPU.") + parser.add_argument( + "--use_tensorrt", + type=str2bool, + help="Whether use TensorRT to accelerate. ") + parser.add_argument( + "--use_fp16", type=str2bool, help="Whether use FP16 to predict.") parser.add_argument("--batch_size", type=int, help="Batch size.") parser.add_argument( "--topk", -- GitLab