From cbf2e9d7c38a54874aca3c0a564c6d82736f7306 Mon Sep 17 00:00:00 2001 From: tianxin Date: Tue, 9 Jul 2019 13:51:06 +0800 Subject: [PATCH] add check_cuda for ERNIE --- BERT/utils/args.py | 1 + ERNIE/run_classifier.py | 4 ++-- ERNIE/run_sequence_labeling.py | 3 ++- ERNIE/train.py | 3 ++- ERNIE/utils/args.py | 14 ++++++++++++++ 5 files changed, 21 insertions(+), 4 deletions(-) diff --git a/BERT/utils/args.py b/BERT/utils/args.py index 4bb20cc..c89c85d 100644 --- a/BERT/utils/args.py +++ b/BERT/utils/args.py @@ -20,6 +20,7 @@ from __future__ import print_function import six import argparse +import paddle.fluid as fluid def str2bool(v): # because argparse does not support to parse "true, False" as python diff --git a/ERNIE/run_classifier.py b/ERNIE/run_classifier.py index 409a690..61fda8f 100644 --- a/ERNIE/run_classifier.py +++ b/ERNIE/run_classifier.py @@ -27,7 +27,7 @@ import reader.task_reader as task_reader from model.ernie import ErnieConfig from finetune.classifier import create_model, evaluate from optimization import optimization -from utils.args import print_arguments +from utils.args import print_arguments, check_cuda from utils.init import init_pretraining_params, init_checkpoint from finetune_args import parser @@ -272,5 +272,5 @@ def main(args): if __name__ == '__main__': print_arguments(args) - + check_cuda(args.use_cuda) main(args) diff --git a/ERNIE/run_sequence_labeling.py b/ERNIE/run_sequence_labeling.py index feffb8f..1f752bc 100644 --- a/ERNIE/run_sequence_labeling.py +++ b/ERNIE/run_sequence_labeling.py @@ -27,7 +27,7 @@ import reader.task_reader as task_reader from model.ernie import ErnieConfig from optimization import optimization from utils.init import init_pretraining_params, init_checkpoint -from utils.args import print_arguments +from utils.args import print_arguments, check_cuda from finetune.sequence_label import create_model, evaluate from finetune_args import parser @@ -280,4 +280,5 @@ def main(args): if __name__ == '__main__': print_arguments(args) + check_cuda(args.use_cuda) main(args) diff --git a/ERNIE/train.py b/ERNIE/train.py index 4faec96..6cbd7bb 100644 --- a/ERNIE/train.py +++ b/ERNIE/train.py @@ -27,7 +27,7 @@ import paddle.fluid as fluid from reader.pretraining import ErnieDataReader from model.ernie import ErnieModel, ErnieConfig from optimization import optimization -from utils.args import print_arguments +from utils.args import print_arguments, check_cuda from utils.init import init_checkpoint, init_pretraining_params from pretrain_args import parser @@ -351,6 +351,7 @@ def train(args): if __name__ == '__main__': print_arguments(args) + check_cuda(args.use_cuda) if args.do_test: test(args) else: diff --git a/ERNIE/utils/args.py b/ERNIE/utils/args.py index b9be634..ebfb767 100644 --- a/ERNIE/utils/args.py +++ b/ERNIE/utils/args.py @@ -20,6 +20,8 @@ from __future__ import print_function import six import argparse +import paddle.fluid as fluid + def str2bool(v): # because argparse does not support to parse "true, False" as python @@ -46,3 +48,15 @@ def print_arguments(args): for arg, value in sorted(six.iteritems(vars(args))): print('%s: %s' % (arg, value)) print('------------------------------------------------') + + +def check_cuda(use_cuda, err = \ + "\nYou can not set use_cuda = True in the model because you are using paddlepaddle-cpu.\n \ + Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_cuda = False to run models on CPU.\n" + ): + try: + if use_cuda == True and fluid.is_compiled_with_cuda() == False: + print(err) + sys.exit(1) + except Exception as e: + pass -- GitLab