Commit 7dff4f85 authored by: zhangxuefei

Fix the bug that the user couldn't choose whether or not to use the GPU

Parent 42aa06ef
@@ -17,10 +17,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import os
-import time
 import argparse
+import ast
 import numpy as np
+import os
+import time
 
 import paddle
 import paddle.fluid as fluid
@@ -30,6 +31,7 @@ import paddlehub as hub
 parser = argparse.ArgumentParser(__doc__)
 parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
 parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest sequence.")
+parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for finetuning, input should be True or False")
 args = parser.parse_args()
 # yapf: enable.
@@ -46,7 +48,7 @@ if __name__ == '__main__':
         vocab_path=module.get_vocab_path(),
         max_seq_len=args.max_seq_len)
 
-    place = fluid.CUDAPlace(0)
+    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)
     with fluid.program_guard(program):
         label = fluid.layers.data(name="label", shape=[1], dtype='int64')
......
 export CUDA_VISIBLE_DEVICES=0
 
-CKPT_DIR="./ckpt_20190414203357/best_model"
-python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128
+CKPT_DIR="./ckpt_chnsenticorp/best_model"
+python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 --use_gpu False
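
For reference, the following is a minimal standalone sketch (not a file from this commit) of the pattern the fix relies on: parsing --use_gpu with ast.literal_eval so the literal strings "True"/"False" become real booleans, then choosing the execution place from the flag instead of hard-coding fluid.CUDAPlace(0).

# Minimal sketch of the device-selection pattern used in this commit
# (illustration only, assumes PaddlePaddle 1.x with the fluid API installed).
import argparse
import ast

import paddle.fluid as fluid

parser = argparse.ArgumentParser(__doc__)
# ast.literal_eval turns the strings "True"/"False" into booleans;
# with type=bool, any non-empty string (including "False") would parse as True.
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False,
                    help="Whether to use GPU, input should be True or False")
args = parser.parse_args()

# Pick the execution place from the flag instead of hard-coding CUDAPlace(0).
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)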