未验证 提交 7c5c9be0 编写于 作者: Z zhongpu 提交者: GitHub

fix ptb_lm for support cpu training (#4234)

* fix ptb_lm for support cpu, test=develop

* fix bug for arg_parse, test=develop

* update readme.me, test=develop
上级 8e4eebfc
......@@ -12,20 +12,22 @@
## 2. 效果说明
在small meidum large三个不同配置情况的ppl对比:
| small config | train | valid | test |
| :------------- | :---------: | :--------: | :----------: |
| paddle静态图 | 40.962 | 118.111 | 112.617 |
| paddle动态图 | | | |
| medium config | train | valid | test |
| :------------- | :---------: | :--------: | :----------: |
| paddle静态图 | 45.620 | 87.398 | 83.682 |
| paddle动态图 | | | |
| large config | train | valid | test |
| :------------- | :---------: | :--------: | :----------: |
| paddle静态图 | 37.221 | 82.358 | 78.137 |
| paddle动态图 | | | |
单卡V100,CUDA10 cudnn7,Python 3.7,CentOS release 6.3
| small config | train | valid | test | 训练速度 |
| :------------- | :---------: | :--------: | :----------: | :------------: |
| paddle静态图 | 40.962 | 118.111 | 112.617 | 41s/epoch |
| paddle动态图 | 40.566 | 119.541 | 115.300 | 93s/epoch |
| medium config | train | valid | test | 训练速度 |
| :------------- | :---------: | :--------: | :----------: | :------------: |
| paddle静态图 | 45.620 | 87.398 | 83.682 | 53s/epoch |
| paddle动态图 | 45.738 | 87.428 | 83.810 | 104s/epoch |
| large config | train | valid | test | 训练速度 |
| :------------- | :---------: | :--------: | :----------: | :------------: |
| paddle静态图 | 37.221 | 82.358 | 78.137 | 77s/epoch |
| paddle动态图 | 37.468 | 82.273 | 78.912 | 145s/epoch |
## 3. 数据集
......
......@@ -19,6 +19,13 @@ from __future__ import print_function
import argparse
import distutils.util
def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Unsupported value encountered.')
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
......@@ -36,7 +43,7 @@ def parse_args():
"--data_path", type=str, help="all the data for train,valid,test")
parser.add_argument('--para_init', action='store_true')
parser.add_argument(
'--use_gpu', type=bool, default=False, help='whether using gpu')
'--use_gpu', type=str2bool, default=True, help='whether using gpu')
parser.add_argument(
'--log_path',
help='path of the log file. If not set, logs are printed to console')
......
......@@ -215,6 +215,11 @@ def train_ptb_lm():
# check if set use_gpu=True in paddlepaddle cpu version
model_check.check_cuda(args.use_gpu)
place = core.CPUPlace()
if args.use_gpu == True:
place = core.CUDAPlace(0)
# check if paddlepaddle version is satisfied
model_check.check_version()
......@@ -273,7 +278,7 @@ def train_ptb_lm():
print("model type not support")
return
with fluid.dygraph.guard(core.CUDAPlace(0)):
with fluid.dygraph.guard(place):
if args.ce:
print("ce mode")
seed = 33
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册