From 5d94fd1b474818472f1e70a5446e248e78219eb2 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Sat, 6 Jul 2019 06:54:32 +0000 Subject: [PATCH] Add cuda check for bert --- BERT/run_classifier.py | 5 +++-- BERT/run_squad.py | 3 ++- BERT/train.py | 3 ++- BERT/utils/args.py | 11 +++++++++++ 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/BERT/run_classifier.py b/BERT/run_classifier.py index 4c37b33..473a8d5 100644 --- a/BERT/run_classifier.py +++ b/BERT/run_classifier.py @@ -30,7 +30,7 @@ import reader.cls as reader from model.bert import BertConfig from model.classifier import create_model from optimization import optimization -from utils.args import ArgumentGroup, print_arguments +from utils.args import ArgumentGroup, print_arguments, check_cuda from utils.init import init_pretraining_params, init_checkpoint import dist_utils @@ -281,7 +281,7 @@ def main(args): exec_strategy=exec_strategy, build_strategy = build_strategy, main_program=train_program) - + train_pyreader.decorate_tensor_provider(train_data_generator) else: train_exe = None @@ -415,4 +415,5 @@ def main(args): if __name__ == '__main__': print_arguments(args) + check_cuda(args.use_cuda) main(args) diff --git a/BERT/run_squad.py b/BERT/run_squad.py index f93c7bf..3d4a23f 100644 --- a/BERT/run_squad.py +++ b/BERT/run_squad.py @@ -28,7 +28,7 @@ import paddle.fluid as fluid from reader.squad import DataProcessor, write_predictions from model.bert import BertConfig, BertModel -from utils.args import ArgumentGroup, print_arguments +from utils.args import ArgumentGroup, print_arguments, check_cuda from optimization import optimization from utils.init import init_pretraining_params, init_checkpoint @@ -424,4 +424,5 @@ def train(args): if __name__ == '__main__': print_arguments(args) + check_cuda(args.use_cuda) train(args) diff --git a/BERT/train.py b/BERT/train.py index aaf2a87..2d95568 100644 --- a/BERT/train.py +++ b/BERT/train.py @@ -29,7 +29,7 @@ import paddle.fluid as fluid from reader.pretraining import DataReader from model.bert import BertModel, BertConfig from optimization import optimization -from utils.args import ArgumentGroup, print_arguments +from utils.args import ArgumentGroup, print_arguments, check_cuda from utils.init import init_checkpoint, init_pretraining_params # yapf: disable @@ -418,6 +418,7 @@ def train(args): if __name__ == '__main__': print_arguments(args) + check_cuda(args.use_cuda) if args.do_test: test(args) else: diff --git a/BERT/utils/args.py b/BERT/utils/args.py index b9be634..4bb20cc 100644 --- a/BERT/utils/args.py +++ b/BERT/utils/args.py @@ -46,3 +46,14 @@ def print_arguments(args): for arg, value in sorted(six.iteritems(vars(args))): print('%s: %s' % (arg, value)) print('------------------------------------------------') + +def check_cuda(use_cuda, err = \ + "\nYou can not set use_cuda = True in the model because you are using paddlepaddle-cpu.\n \ + Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_cuda = False to run models on CPU.\n" + ): + try: + if use_cuda == True and fluid.is_compiled_with_cuda() == False: + print(err) + sys.exit(1) + except Exception as e: + pass -- GitLab