未验证 提交 e7a63d72 编写于 作者: Y Yibing Liu 提交者: GitHub

Add paddle version check (#4399)

上级 d7bdc908
......@@ -72,7 +72,7 @@
```
## 安装
本项目依赖于 Paddle Fluid **1.6.0** 及以上版本,请参考[安装指南](http://www.paddlepaddle.org/#quick-start)进行安装。如果需要进行 TensorFlow 模型到 Paddle Fluid 参数的转换,则需要同时安装 TensorFlow 1.12。
本项目依赖于 Paddle Fluid **1.7.1** 及以上版本,请参考[安装指南](http://www.paddlepaddle.org/#quick-start)进行安装。如果需要进行 TensorFlow 模型到 Paddle Fluid 参数的转换,则需要同时安装 TensorFlow 1.12。
## 预训练
......
......@@ -38,7 +38,7 @@ import reader.cls as reader
from model.bert import BertConfig
from model.classifier import create_model
from optimization import optimization
from utils.args import ArgumentGroup, print_arguments, check_cuda
from utils.args import ArgumentGroup, print_arguments, check_cuda, check_version
from utils.init import init_pretraining_params, init_checkpoint
from utils.cards import get_cards
import dist_utils
......@@ -447,4 +447,5 @@ def main(args):
if __name__ == '__main__':
print_arguments(args)
check_cuda(args.use_cuda)
check_version()
main(args)
......@@ -34,7 +34,7 @@ import paddle.fluid as fluid
from reader.squad import DataProcessor, write_predictions
from model.bert import BertConfig, BertModel
from utils.args import ArgumentGroup, print_arguments, check_cuda
from utils.args import ArgumentGroup, print_arguments, check_cuda, check_version
from optimization import optimization
from utils.init import init_pretraining_params, init_checkpoint
......@@ -424,4 +424,5 @@ def train(args):
if __name__ == '__main__':
print_arguments(args)
check_cuda(args.use_cuda)
check_version()
train(args)
......@@ -35,7 +35,7 @@ import paddle.fluid as fluid
from reader.pretraining import DataReader
from model.bert import BertModel, BertConfig
from optimization import optimization
from utils.args import ArgumentGroup, print_arguments, check_cuda
from utils.args import ArgumentGroup, print_arguments, check_cuda, check_version
from utils.init import init_checkpoint, init_pretraining_params
# yapf: disable
......@@ -433,6 +433,7 @@ def train(args):
if __name__ == '__main__':
print_arguments(args)
check_cuda(args.use_cuda)
check_version()
if args.do_test:
test(args)
else:
......
......@@ -59,3 +59,19 @@ def check_cuda(use_cuda, err = \
sys.exit(1)
except Exception as e:
pass
def check_version():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.7.1 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.7.1')
except Exception as e:
print(err)
sys.exit(1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册