diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 3de00d83a8f9f55af9b89d5d2cd5c877399c5930..5c75e0c480eac6796d6d4b7075d1b38d254380fd 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -101,6 +101,7 @@ class TextDetector(object): if args.benchmark: import auto_log pid = os.getpid() + gpu_id = utility.get_infer_gpuid() self.autolog = auto_log.AutoLogger( model_name="det", model_precision=args.precision, @@ -110,7 +111,7 @@ class TextDetector(object): inference_config=self.config, pids=pid, process_name=None, - gpu_ids=0, + gpu_ids=gpu_id if args.use_gpu else None, time_keys=[ 'preprocess_time', 'inference_time', 'postprocess_time' ], diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index bb4a31706471b9b1745519ac9f390d01b60d5d44..97dfa5214628123d0c9b7edd7d94060a2bfd2a1e 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -68,6 +68,7 @@ class TextRecognizer(object): if args.benchmark: import auto_log pid = os.getpid() + gpu_id = utility.get_infer_gpuid() self.autolog = auto_log.AutoLogger( model_name="rec", model_precision=args.precision, @@ -77,7 +78,7 @@ class TextRecognizer(object): inference_config=self.config, pids=pid, process_name=None, - gpu_ids=0 if args.use_gpu else None, + gpu_ids=gpu_id if args.use_gpu else None, time_keys=[ 'preprocess_time', 'inference_time', 'postprocess_time' ], diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 1c82280099f17f6d3bf848669e47439505f10576..527a811d620efac33ece9cdbd4b6196e18a8497d 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -280,6 +280,20 @@ def create_predictor(args, mode, logger): return predictor, input_tensor, output_tensors, config +def get_infer_gpuid(): + cmd = "nvidia-smi" + res = os.popen(cmd).readlines() + if len(res) == 0: + return None + cmd = "env | grep CUDA_VISIBLE_DEVICES" + env_cuda = os.popen(cmd).readlines() + if len(env_cuda) == 0: + return 0 + else: + gpu_id = env_cuda[0].strip().split("=")[1] + return int(gpu_id[0]) + + def draw_e2e_res(dt_boxes, strs, img_path): src_im = cv2.imread(img_path) for box, str in zip(dt_boxes, strs):