未验证 提交 c0c669ab 编写于 作者: Y Yibing Liu 提交者: GitHub

Merge pull request #193 from PaddlePaddle/model_check_for_bert

Add cuda check for bert
...@@ -30,7 +30,7 @@ import reader.cls as reader ...@@ -30,7 +30,7 @@ import reader.cls as reader
from model.bert import BertConfig from model.bert import BertConfig
from model.classifier import create_model from model.classifier import create_model
from optimization import optimization from optimization import optimization
from utils.args import ArgumentGroup, print_arguments from utils.args import ArgumentGroup, print_arguments, check_cuda
from utils.init import init_pretraining_params, init_checkpoint from utils.init import init_pretraining_params, init_checkpoint
import dist_utils import dist_utils
...@@ -415,4 +415,5 @@ def main(args): ...@@ -415,4 +415,5 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
print_arguments(args) print_arguments(args)
check_cuda(args.use_cuda)
main(args) main(args)
...@@ -28,7 +28,7 @@ import paddle.fluid as fluid ...@@ -28,7 +28,7 @@ import paddle.fluid as fluid
from reader.squad import DataProcessor, write_predictions from reader.squad import DataProcessor, write_predictions
from model.bert import BertConfig, BertModel from model.bert import BertConfig, BertModel
from utils.args import ArgumentGroup, print_arguments from utils.args import ArgumentGroup, print_arguments, check_cuda
from optimization import optimization from optimization import optimization
from utils.init import init_pretraining_params, init_checkpoint from utils.init import init_pretraining_params, init_checkpoint
...@@ -424,4 +424,5 @@ def train(args): ...@@ -424,4 +424,5 @@ def train(args):
if __name__ == '__main__': if __name__ == '__main__':
print_arguments(args) print_arguments(args)
check_cuda(args.use_cuda)
train(args) train(args)
...@@ -29,7 +29,7 @@ import paddle.fluid as fluid ...@@ -29,7 +29,7 @@ import paddle.fluid as fluid
from reader.pretraining import DataReader from reader.pretraining import DataReader
from model.bert import BertModel, BertConfig from model.bert import BertModel, BertConfig
from optimization import optimization from optimization import optimization
from utils.args import ArgumentGroup, print_arguments from utils.args import ArgumentGroup, print_arguments, check_cuda
from utils.init import init_checkpoint, init_pretraining_params from utils.init import init_checkpoint, init_pretraining_params
# yapf: disable # yapf: disable
...@@ -418,6 +418,7 @@ def train(args): ...@@ -418,6 +418,7 @@ def train(args):
if __name__ == '__main__': if __name__ == '__main__':
print_arguments(args) print_arguments(args)
check_cuda(args.use_cuda)
if args.do_test: if args.do_test:
test(args) test(args)
else: else:
......
...@@ -46,3 +46,14 @@ def print_arguments(args): ...@@ -46,3 +46,14 @@ def print_arguments(args):
for arg, value in sorted(six.iteritems(vars(args))): for arg, value in sorted(six.iteritems(vars(args))):
print('%s: %s' % (arg, value)) print('%s: %s' % (arg, value))
print('------------------------------------------------') print('------------------------------------------------')
def check_cuda(use_cuda, err = \
"\nYou can not set use_cuda = True in the model because you are using paddlepaddle-cpu.\n \
Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_cuda = False to run models on CPU.\n"
):
try:
if use_cuda == True and fluid.is_compiled_with_cuda() == False:
print(err)
sys.exit(1)
except Exception as e:
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册