Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
ERNIE
提交
5d94fd1b
E
ERNIE
项目概览
PaddlePaddle
/
ERNIE
大约 1 年 前同步成功
通知
109
Star
5997
Fork
1270
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
29
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
E
ERNIE
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
29
Issue
29
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5d94fd1b
编写于
7月 06, 2019
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add cuda check for bert
上级
67edb3e1
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
18 addition
and
4 deletion
+18
-4
BERT/run_classifier.py
BERT/run_classifier.py
+3
-2
BERT/run_squad.py
BERT/run_squad.py
+2
-1
BERT/train.py
BERT/train.py
+2
-1
BERT/utils/args.py
BERT/utils/args.py
+11
-0
未找到文件。
BERT/run_classifier.py
浏览文件 @
5d94fd1b
...
...
@@ -30,7 +30,7 @@ import reader.cls as reader
from
model.bert
import
BertConfig
from
model.classifier
import
create_model
from
optimization
import
optimization
from
utils.args
import
ArgumentGroup
,
print_arguments
from
utils.args
import
ArgumentGroup
,
print_arguments
,
check_cuda
from
utils.init
import
init_pretraining_params
,
init_checkpoint
import
dist_utils
...
...
@@ -281,7 +281,7 @@ def main(args):
exec_strategy
=
exec_strategy
,
build_strategy
=
build_strategy
,
main_program
=
train_program
)
train_pyreader
.
decorate_tensor_provider
(
train_data_generator
)
else
:
train_exe
=
None
...
...
@@ -415,4 +415,5 @@ def main(args):
if
__name__
==
'__main__'
:
print_arguments
(
args
)
check_cuda
(
args
.
use_cuda
)
main
(
args
)
BERT/run_squad.py
浏览文件 @
5d94fd1b
...
...
@@ -28,7 +28,7 @@ import paddle.fluid as fluid
from
reader.squad
import
DataProcessor
,
write_predictions
from
model.bert
import
BertConfig
,
BertModel
from
utils.args
import
ArgumentGroup
,
print_arguments
from
utils.args
import
ArgumentGroup
,
print_arguments
,
check_cuda
from
optimization
import
optimization
from
utils.init
import
init_pretraining_params
,
init_checkpoint
...
...
@@ -424,4 +424,5 @@ def train(args):
if
__name__
==
'__main__'
:
print_arguments
(
args
)
check_cuda
(
args
.
use_cuda
)
train
(
args
)
BERT/train.py
浏览文件 @
5d94fd1b
...
...
@@ -29,7 +29,7 @@ import paddle.fluid as fluid
from
reader.pretraining
import
DataReader
from
model.bert
import
BertModel
,
BertConfig
from
optimization
import
optimization
from
utils.args
import
ArgumentGroup
,
print_arguments
from
utils.args
import
ArgumentGroup
,
print_arguments
,
check_cuda
from
utils.init
import
init_checkpoint
,
init_pretraining_params
# yapf: disable
...
...
@@ -418,6 +418,7 @@ def train(args):
if
__name__
==
'__main__'
:
print_arguments
(
args
)
check_cuda
(
args
.
use_cuda
)
if
args
.
do_test
:
test
(
args
)
else
:
...
...
BERT/utils/args.py
浏览文件 @
5d94fd1b
...
...
@@ -46,3 +46,14 @@ def print_arguments(args):
for
arg
,
value
in
sorted
(
six
.
iteritems
(
vars
(
args
))):
print
(
'%s: %s'
%
(
arg
,
value
))
print
(
'------------------------------------------------'
)
def
check_cuda
(
use_cuda
,
err
=
\
"
\n
You can not set use_cuda = True in the model because you are using paddlepaddle-cpu.
\n
\
Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_cuda = False to run models on CPU.
\n
"
):
try
:
if
use_cuda
==
True
and
fluid
.
is_compiled_with_cuda
()
==
False
:
print
(
err
)
sys
.
exit
(
1
)
except
Exception
as
e
:
pass
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录