diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 21bbee098ef19456d05165969a9ad400400f1264..33ed62125c0b59b5f23b72b5b8f6ecb3b0835cf3 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -313,6 +313,11 @@ def create_predictor(args, mode, logger): def get_infer_gpuid(): + if os.name == 'nt': + try: + return int(os.environ['CUDA_VISIBLE_DEVICES'].split(',')[0]) + except KeyError: + return 0 if not paddle.fluid.core.is_compiled_with_rocm(): cmd = "env | grep CUDA_VISIBLE_DEVICES" else: