提交 5d94fd1b 编写于 作者: Y Yibing Liu

Add cuda check for bert

上级 67edb3e1
...@@ -30,7 +30,7 @@ import reader.cls as reader ...@@ -30,7 +30,7 @@ import reader.cls as reader
from model.bert import BertConfig from model.bert import BertConfig
from model.classifier import create_model from model.classifier import create_model
from optimization import optimization from optimization import optimization
from utils.args import ArgumentGroup, print_arguments from utils.args import ArgumentGroup, print_arguments, check_cuda
from utils.init import init_pretraining_params, init_checkpoint from utils.init import init_pretraining_params, init_checkpoint
import dist_utils import dist_utils
...@@ -415,4 +415,5 @@ def main(args): ...@@ -415,4 +415,5 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
print_arguments(args) print_arguments(args)
check_cuda(args.use_cuda)
main(args) main(args)
...@@ -28,7 +28,7 @@ import paddle.fluid as fluid ...@@ -28,7 +28,7 @@ import paddle.fluid as fluid
from reader.squad import DataProcessor, write_predictions from reader.squad import DataProcessor, write_predictions
from model.bert import BertConfig, BertModel from model.bert import BertConfig, BertModel
from utils.args import ArgumentGroup, print_arguments from utils.args import ArgumentGroup, print_arguments, check_cuda
from optimization import optimization from optimization import optimization
from utils.init import init_pretraining_params, init_checkpoint from utils.init import init_pretraining_params, init_checkpoint
...@@ -424,4 +424,5 @@ def train(args): ...@@ -424,4 +424,5 @@ def train(args):
if __name__ == '__main__': if __name__ == '__main__':
print_arguments(args) print_arguments(args)
check_cuda(args.use_cuda)
train(args) train(args)
...@@ -29,7 +29,7 @@ import paddle.fluid as fluid ...@@ -29,7 +29,7 @@ import paddle.fluid as fluid
from reader.pretraining import DataReader from reader.pretraining import DataReader
from model.bert import BertModel, BertConfig from model.bert import BertModel, BertConfig
from optimization import optimization from optimization import optimization
from utils.args import ArgumentGroup, print_arguments from utils.args import ArgumentGroup, print_arguments, check_cuda
from utils.init import init_checkpoint, init_pretraining_params from utils.init import init_checkpoint, init_pretraining_params
# yapf: disable # yapf: disable
...@@ -418,6 +418,7 @@ def train(args): ...@@ -418,6 +418,7 @@ def train(args):
if __name__ == '__main__': if __name__ == '__main__':
print_arguments(args) print_arguments(args)
check_cuda(args.use_cuda)
if args.do_test: if args.do_test:
test(args) test(args)
else: else:
......
...@@ -46,3 +46,14 @@ def print_arguments(args): ...@@ -46,3 +46,14 @@ def print_arguments(args):
for arg, value in sorted(six.iteritems(vars(args))): for arg, value in sorted(six.iteritems(vars(args))):
print('%s: %s' % (arg, value)) print('%s: %s' % (arg, value))
print('------------------------------------------------') print('------------------------------------------------')
def check_cuda(use_cuda, err = \
"\nYou can not set use_cuda = True in the model because you are using paddlepaddle-cpu.\n \
Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_cuda = False to run models on CPU.\n"
):
try:
if use_cuda == True and fluid.is_compiled_with_cuda() == False:
print(err)
sys.exit(1)
except Exception as e:
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册