未验证 提交 755a59e1 编写于 作者: Z Zeyu Chen 提交者: GitHub

Merge pull request #17 from Steffy-zxf/fix-bug-use-gpu-demo-text-classification-prediction

Fix the bug that the user couldn't choose whether or not to use the GPU
...@@ -17,10 +17,11 @@ from __future__ import absolute_import ...@@ -17,10 +17,11 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import time
import argparse import argparse
import ast
import numpy as np import numpy as np
import os
import time
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -30,6 +31,7 @@ import paddlehub as hub ...@@ -30,6 +31,7 @@ import paddlehub as hub
parser = argparse.ArgumentParser(__doc__) parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint") parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest sequence.") parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest sequence.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether to use GPU for fine-tuning; input should be True or False")
args = parser.parse_args() args = parser.parse_args()
# yapf: enable. # yapf: enable.
...@@ -46,7 +48,7 @@ if __name__ == '__main__': ...@@ -46,7 +48,7 @@ if __name__ == '__main__':
vocab_path=module.get_vocab_path(), vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len) max_seq_len=args.max_seq_len)
place = fluid.CUDAPlace(0) place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
with fluid.program_guard(program): with fluid.program_guard(program):
label = fluid.layers.data(name="label", shape=[1], dtype='int64') label = fluid.layers.data(name="label", shape=[1], dtype='int64')
......
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
CKPT_DIR="./ckpt_20190414203357/best_model" CKPT_DIR="./ckpt_chnsenticorp/best_model"
python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 --use_gpu False
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册