From a3dbba0df5f583fda640f8b8ad588d38284c8861 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Sun, 14 Apr 2019 12:17:35 +0800 Subject: [PATCH] add inference program clone and update README.md --- demo/ernie-classification/README.md | 2 +- demo/ernie-classification/cls_predict.py | 28 +++++++++++++++++------- demo/ernie-classification/run_predict.sh | 2 +- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/demo/ernie-classification/README.md b/demo/ernie-classification/README.md index 2527b588..24b6b62e 100644 --- a/demo/ernie-classification/README.md +++ b/demo/ernie-classification/README.md @@ -120,5 +120,5 @@ python -u cls_predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 参数配置正确后,请执行脚本`sh run_predict.sh`,即可看到以下文本分类预测结果。如需了解更多预测步骤,请参考`cls_predict.py` ``` -text=风扇确实够响的,尤其是到晚上周围安静下来。风扇频频开启,发热量有些惊人 label=0 predict=[0.99244046 0.00755955] +text=键盘缝隙大进灰,装系统自己不会装~~屏幕有点窄玩游戏人物有点变形 label=0 predict=0 ``` diff --git a/demo/ernie-classification/cls_predict.py b/demo/ernie-classification/cls_predict.py index 6da7889d..9c9fef01 100644 --- a/demo/ernie-classification/cls_predict.py +++ b/demo/ernie-classification/cls_predict.py @@ -57,6 +57,11 @@ if __name__ == '__main__': # Setup feed list for data feeder # Must feed all the tensor of ERNIE's module need + feed_list = [ + input_dict["input_ids"].name, input_dict["position_ids"].name, + input_dict["segment_ids"].name, input_dict["input_mask"].name, + label.name + ] # Define a classfication finetune task by PaddleHub's API cls_task = hub.create_text_classification_task( @@ -65,19 +70,26 @@ if __name__ == '__main__': # classificatin probability tensor probs = cls_task.variable("probs") + pred = fluid.layers.argmax(probs, axis=1) + # load best model checkpoint fluid.io.load_persistables(exe, args.checkpoint_dir) - feed_list = [ - input_dict["input_ids"].name, input_dict["position_ids"].name, - input_dict["segment_ids"].name, input_dict["input_mask"].name, - label.name - ] + inference_program = program.clone(for_test=True) data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place) test_reader = reader.data_generator(phase='test', shuffle=False) test_examples = dataset.get_test_examples() + total = 0 + correct = 0 for index, batch in enumerate(test_reader()): - probs_v = exe.run( - feed=data_feeder.feed(batch), fetch_list=[probs.name]) - print("%s\tpredict=%s" % (test_examples[index], probs_v[0][0])) + pred_v = exe.run( + feed=data_feeder.feed(batch), + fetch_list=[pred.name], + program=inference_program) + total += 1 + if (pred_v[0][0] == int(test_examples[index].label)): + correct += 1 + acc = 1.0 * correct / total + print("%s\tpredict=%s" % (test_examples[index], pred_v[0][0])) + print("accuracy = %f" % acc) diff --git a/demo/ernie-classification/run_predict.sh b/demo/ernie-classification/run_predict.sh index e3b46f49..d192c340 100644 --- a/demo/ernie-classification/run_predict.sh +++ b/demo/ernie-classification/run_predict.sh @@ -1,4 +1,4 @@ -export CUDA_VISIBLE_DEVICES=1 +export CUDA_VISIBLE_DEVICES=5 CKPT_DIR="./ckpt_sentiment_cls/best_model" python -u cls_predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 -- GitLab