From 7c5c9be0402d4f1f873102b0365aaecaab5e6188 Mon Sep 17 00:00:00 2001 From: zhongpu <2013000149@qq.com> Date: Tue, 4 Feb 2020 18:17:10 +0800 Subject: [PATCH] fix ptb_lm for support cpu training (#4234) * fix ptb_lm for support cpu, test=develop * fix bug for arg_parse, test=develop * update readme.me, test=develop --- dygraph/ptb_lm/README.md | 30 ++++++++++++++++-------------- dygraph/ptb_lm/args.py | 9 ++++++++- dygraph/ptb_lm/ptb_dy.py | 7 ++++++- 3 files changed, 30 insertions(+), 16 deletions(-) diff --git a/dygraph/ptb_lm/README.md b/dygraph/ptb_lm/README.md index 3e13af5c..cd3f06d2 100644 --- a/dygraph/ptb_lm/README.md +++ b/dygraph/ptb_lm/README.md @@ -12,20 +12,22 @@ ## 2. 效果说明 在small meidum large三个不同配置情况的ppl对比: -| small config | train | valid | test | -| :------------- | :---------: | :--------: | :----------: | -| paddle静态图 | 40.962 | 118.111 | 112.617 | -| paddle动态图 | | | | - -| medium config | train | valid | test | -| :------------- | :---------: | :--------: | :----------: | -| paddle静态图 | 45.620 | 87.398 | 83.682 | -| paddle动态图 | | | | - -| large config | train | valid | test | -| :------------- | :---------: | :--------: | :----------: | -| paddle静态图 | 37.221 | 82.358 | 78.137 | -| paddle动态图 | | | | +单卡V100,CUDA10 cudnn7,Python 3.7,CentOS release 6.3 + +| small config | train | valid | test | 训练速度 | +| :------------- | :---------: | :--------: | :----------: | :------------: | +| paddle静态图 | 40.962 | 118.111 | 112.617 | 41s/epoch | +| paddle动态图 | 40.566 | 119.541 | 115.300 | 93s/epoch | + +| medium config | train | valid | test | 训练速度 | +| :------------- | :---------: | :--------: | :----------: | :------------: | +| paddle静态图 | 45.620 | 87.398 | 83.682 | 53s/epoch | +| paddle动态图 | 45.738 | 87.428 | 83.810 | 104s/epoch | + +| large config | train | valid | test | 训练速度 | +| :------------- | :---------: | :--------: | :----------: | :------------: | +| paddle静态图 | 37.221 | 82.358 | 78.137 | 77s/epoch | +| paddle动态图 | 37.468 | 82.273 | 78.912 | 145s/epoch | ## 3. 数据集 diff --git a/dygraph/ptb_lm/args.py b/dygraph/ptb_lm/args.py index ad33ea1a..6449b274 100644 --- a/dygraph/ptb_lm/args.py +++ b/dygraph/ptb_lm/args.py @@ -19,6 +19,13 @@ from __future__ import print_function import argparse import distutils.util +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Unsupported value encountered.') def parse_args(): parser = argparse.ArgumentParser(description=__doc__) @@ -36,7 +43,7 @@ def parse_args(): "--data_path", type=str, help="all the data for train,valid,test") parser.add_argument('--para_init', action='store_true') parser.add_argument( - '--use_gpu', type=bool, default=False, help='whether using gpu') + '--use_gpu', type=str2bool, default=True, help='whether using gpu') parser.add_argument( '--log_path', help='path of the log file. If not set, logs are printed to console') diff --git a/dygraph/ptb_lm/ptb_dy.py b/dygraph/ptb_lm/ptb_dy.py index befa4ad6..86411a02 100644 --- a/dygraph/ptb_lm/ptb_dy.py +++ b/dygraph/ptb_lm/ptb_dy.py @@ -215,6 +215,11 @@ def train_ptb_lm(): # check if set use_gpu=True in paddlepaddle cpu version model_check.check_cuda(args.use_gpu) + + place = core.CPUPlace() + if args.use_gpu == True: + place = core.CUDAPlace(0) + # check if paddlepaddle version is satisfied model_check.check_version() @@ -273,7 +278,7 @@ def train_ptb_lm(): print("model type not support") return - with fluid.dygraph.guard(core.CUDAPlace(0)): + with fluid.dygraph.guard(place): if args.ce: print("ce mode") seed = 33 -- GitLab