提交 cbf2e9d7 编写于 作者: T tianxin

add check_cuda for ERNIE

上级 c0c669ab
...@@ -20,6 +20,7 @@ from __future__ import print_function ...@@ -20,6 +20,7 @@ from __future__ import print_function
import six import six
import argparse import argparse
import paddle.fluid as fluid
def str2bool(v): def str2bool(v):
# because argparse does not support to parse "true, False" as python # because argparse does not support to parse "true, False" as python
......
...@@ -27,7 +27,7 @@ import reader.task_reader as task_reader ...@@ -27,7 +27,7 @@ import reader.task_reader as task_reader
from model.ernie import ErnieConfig from model.ernie import ErnieConfig
from finetune.classifier import create_model, evaluate from finetune.classifier import create_model, evaluate
from optimization import optimization from optimization import optimization
from utils.args import print_arguments from utils.args import print_arguments, check_cuda
from utils.init import init_pretraining_params, init_checkpoint from utils.init import init_pretraining_params, init_checkpoint
from finetune_args import parser from finetune_args import parser
...@@ -272,5 +272,5 @@ def main(args): ...@@ -272,5 +272,5 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
print_arguments(args) print_arguments(args)
check_cuda(args.use_cuda)
main(args) main(args)
...@@ -27,7 +27,7 @@ import reader.task_reader as task_reader ...@@ -27,7 +27,7 @@ import reader.task_reader as task_reader
from model.ernie import ErnieConfig from model.ernie import ErnieConfig
from optimization import optimization from optimization import optimization
from utils.init import init_pretraining_params, init_checkpoint from utils.init import init_pretraining_params, init_checkpoint
from utils.args import print_arguments from utils.args import print_arguments, check_cuda
from finetune.sequence_label import create_model, evaluate from finetune.sequence_label import create_model, evaluate
from finetune_args import parser from finetune_args import parser
...@@ -280,4 +280,5 @@ def main(args): ...@@ -280,4 +280,5 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
print_arguments(args) print_arguments(args)
check_cuda(args.use_cuda)
main(args) main(args)
...@@ -27,7 +27,7 @@ import paddle.fluid as fluid ...@@ -27,7 +27,7 @@ import paddle.fluid as fluid
from reader.pretraining import ErnieDataReader from reader.pretraining import ErnieDataReader
from model.ernie import ErnieModel, ErnieConfig from model.ernie import ErnieModel, ErnieConfig
from optimization import optimization from optimization import optimization
from utils.args import print_arguments from utils.args import print_arguments, check_cuda
from utils.init import init_checkpoint, init_pretraining_params from utils.init import init_checkpoint, init_pretraining_params
from pretrain_args import parser from pretrain_args import parser
...@@ -351,6 +351,7 @@ def train(args): ...@@ -351,6 +351,7 @@ def train(args):
if __name__ == '__main__': if __name__ == '__main__':
print_arguments(args) print_arguments(args)
check_cuda(args.use_cuda)
if args.do_test: if args.do_test:
test(args) test(args)
else: else:
......
...@@ -20,6 +20,8 @@ from __future__ import print_function ...@@ -20,6 +20,8 @@ from __future__ import print_function
import six import six
import argparse import argparse
import paddle.fluid as fluid
def str2bool(v): def str2bool(v):
# because argparse does not support to parse "true, False" as python # because argparse does not support to parse "true, False" as python
...@@ -46,3 +48,15 @@ def print_arguments(args): ...@@ -46,3 +48,15 @@ def print_arguments(args):
for arg, value in sorted(six.iteritems(vars(args))): for arg, value in sorted(six.iteritems(vars(args))):
print('%s: %s' % (arg, value)) print('%s: %s' % (arg, value))
print('------------------------------------------------') print('------------------------------------------------')
def check_cuda(use_cuda, err = \
"\nYou can not set use_cuda = True in the model because you are using paddlepaddle-cpu.\n \
Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_cuda = False to run models on CPU.\n"
):
try:
if use_cuda == True and fluid.is_compiled_with_cuda() == False:
print(err)
sys.exit(1)
except Exception as e:
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册