diff --git a/demo/text-classification/predict.py b/demo/text-classification/predict.py index 71fe91cc0d9e0c7f2504a0704bdd636bb0647ea3..8f69f46eac5d1815a3f6fb78329b04b8d2555de5 100644 --- a/demo/text-classification/predict.py +++ b/demo/text-classification/predict.py @@ -17,10 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import time import argparse +import ast import numpy as np +import os +import time import paddle import paddle.fluid as fluid @@ -30,6 +31,7 @@ import paddlehub as hub parser = argparse.ArgumentParser(__doc__) parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint") parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest seqence.") +parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether to use GPU for prediction; input should be True or False") args = parser.parse_args() # yapf: enable. @@ -46,7 +48,7 @@ if __name__ == '__main__': vocab_path=module.get_vocab_path(), max_seq_len=args.max_seq_len) - place = fluid.CUDAPlace(0) + place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() exe = fluid.Executor(place) with fluid.program_guard(program): label = fluid.layers.data(name="label", shape=[1], dtype='int64') diff --git a/demo/text-classification/run_predict.sh b/demo/text-classification/run_predict.sh index 1fe8680f44fea98a743d74507799abaea91757ca..90a6ddfb8d7dfd439ef76982714425924a4dcd33 100644 --- a/demo/text-classification/run_predict.sh +++ b/demo/text-classification/run_predict.sh @@ -1,4 +1,4 @@ export CUDA_VISIBLE_DEVICES=0 -CKPT_DIR="./ckpt_20190414203357/best_model" -python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 +CKPT_DIR="./ckpt_chnsenticorp/best_model" +python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 --use_gpu False