提交 9b582337 编写于 作者: L LiuHao 提交者: pkpk

add cuda_check (#2761)

更新文件readme,run_classifier.py, run_ernie_classifier.py,新增支持cuda_check部分代码
上级 06dd8fe8
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
python版本依赖python 2.7 python版本依赖python 2.7
注意:该模型同时支持cpu和gpu训练和预测,用户可以根据自身需求,选择安装对应的paddlepaddle-gpu或paddlepaddle版本。
#### 安装代码 #### 安装代码
克隆数据集代码库到本地 克隆数据集代码库到本地
......
...@@ -13,12 +13,14 @@ import numpy as np ...@@ -13,12 +13,14 @@ import numpy as np
import multiprocessing import multiprocessing
import sys import sys
sys.path.append("../models/classification/") sys.path.append("../models/classification/")
sys.path.append("../")
from nets import bow_net from nets import bow_net
from nets import lstm_net from nets import lstm_net
from nets import cnn_net from nets import cnn_net
from nets import bilstm_net from nets import bilstm_net
from nets import gru_net from nets import gru_net
from models.model_check import check_cuda
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -370,4 +372,5 @@ def main(args): ...@@ -370,4 +372,5 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
print_arguments(args) print_arguments(args)
check_cuda(args.use_cuda)
main(args) main(args)
...@@ -31,6 +31,7 @@ from preprocess.ernie import task_reader ...@@ -31,6 +31,7 @@ from preprocess.ernie import task_reader
from models.representation.ernie import ErnieConfig from models.representation.ernie import ErnieConfig
from models.representation.ernie import ernie_encoder, ernie_encoder_with_paddle_hub from models.representation.ernie import ernie_encoder, ernie_encoder_with_paddle_hub
from models.representation.ernie import ernie_pyreader from models.representation.ernie import ernie_pyreader
from models.model_check import check_cuda
from utils import ArgumentGroup from utils import ArgumentGroup
from utils import print_arguments from utils import print_arguments
from utils import init_checkpoint from utils import init_checkpoint
...@@ -425,4 +426,5 @@ def main(args): ...@@ -425,4 +426,5 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
print_arguments(args) print_arguments(args)
check_cuda(args.use_cuda)
main(args) main(args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册