diff --git a/ppcls/utils/config.py b/ppcls/utils/config.py index 25a918afe0e3fcf55dce846be0199a349dfd3bd0..3bde9dfbfe525070efa56c3b274e82fbb4004a6e 100644 --- a/ppcls/utils/config.py +++ b/ppcls/utils/config.py @@ -98,8 +98,9 @@ def check_config(config): """ check.check_version() - mode = config.get('mode', 'train') - check.check_gpu() + use_gpu = config.get('use_gpu', True) + if use_gpu: + check.check_gpu() architecture = config.get('ARCHITECTURE') check.check_architecture(architecture) @@ -110,6 +111,7 @@ def check_config(config): classes_num = config.get('classes_num') check.check_classes_num(classes_num) + mode = config.get('mode', 'train') if mode.lower() == 'train': check.check_function_params(config, 'LEARNING_RATE') check.check_function_params(config, 'OPTIMIZER')