提交 cbf2e9d7 编写于 作者: T tianxin

add check_cuda for ERNIE

上级 c0c669ab
......@@ -20,6 +20,7 @@ from __future__ import print_function
import six
import argparse
import paddle.fluid as fluid
def str2bool(v):
# because argparse does not support to parse "true, False" as python
......
......@@ -27,7 +27,7 @@ import reader.task_reader as task_reader
from model.ernie import ErnieConfig
from finetune.classifier import create_model, evaluate
from optimization import optimization
from utils.args import print_arguments
from utils.args import print_arguments, check_cuda
from utils.init import init_pretraining_params, init_checkpoint
from finetune_args import parser
......@@ -272,5 +272,5 @@ def main(args):
if __name__ == '__main__':
print_arguments(args)
check_cuda(args.use_cuda)
main(args)
......@@ -27,7 +27,7 @@ import reader.task_reader as task_reader
from model.ernie import ErnieConfig
from optimization import optimization
from utils.init import init_pretraining_params, init_checkpoint
from utils.args import print_arguments
from utils.args import print_arguments, check_cuda
from finetune.sequence_label import create_model, evaluate
from finetune_args import parser
......@@ -280,4 +280,5 @@ def main(args):
if __name__ == '__main__':
print_arguments(args)
check_cuda(args.use_cuda)
main(args)
......@@ -27,7 +27,7 @@ import paddle.fluid as fluid
from reader.pretraining import ErnieDataReader
from model.ernie import ErnieModel, ErnieConfig
from optimization import optimization
from utils.args import print_arguments
from utils.args import print_arguments, check_cuda
from utils.init import init_checkpoint, init_pretraining_params
from pretrain_args import parser
......@@ -351,6 +351,7 @@ def train(args):
if __name__ == '__main__':
print_arguments(args)
check_cuda(args.use_cuda)
if args.do_test:
test(args)
else:
......
......@@ -20,6 +20,8 @@ from __future__ import print_function
import six
import argparse
import paddle.fluid as fluid
def str2bool(v):
# because argparse does not support to parse "true, False" as python
......@@ -46,3 +48,15 @@ def print_arguments(args):
for arg, value in sorted(six.iteritems(vars(args))):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
def check_cuda(use_cuda, err = \
"\nYou can not set use_cuda = True in the model because you are using paddlepaddle-cpu.\n \
Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_cuda = False to run models on CPU.\n"
):
try:
if use_cuda == True and fluid.is_compiled_with_cuda() == False:
print(err)
sys.exit(1)
except Exception as e:
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册