提交 a3dbba0d 编写于 作者: Z Zeyu Chen

add inference program clone and update README.md

上级 dbb29416
......@@ -120,5 +120,5 @@ python -u cls_predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128
参数配置正确后,请执行脚本`sh run_predict.sh`,即可看到以下文本分类预测结果。如需了解更多预测步骤,请参考`cls_predict.py`
```
text=风扇确实够响的,尤其是到晚上周围安静下来。风扇频频开启,发热量有些惊人 label=0 predict=[0.99244046 0.00755955]
text=键盘缝隙大进灰,装系统自己不会装~~屏幕有点窄玩游戏人物有点变形 label=0 predict=0
```
......@@ -57,6 +57,11 @@ if __name__ == '__main__':
# Setup feed list for data feeder
# Must feed all the tensor of ERNIE's module need
feed_list = [
input_dict["input_ids"].name, input_dict["position_ids"].name,
input_dict["segment_ids"].name, input_dict["input_mask"].name,
label.name
]
# Define a classfication finetune task by PaddleHub's API
cls_task = hub.create_text_classification_task(
......@@ -65,19 +70,26 @@ if __name__ == '__main__':
# classificatin probability tensor
probs = cls_task.variable("probs")
pred = fluid.layers.argmax(probs, axis=1)
# load best model checkpoint
fluid.io.load_persistables(exe, args.checkpoint_dir)
feed_list = [
input_dict["input_ids"].name, input_dict["position_ids"].name,
input_dict["segment_ids"].name, input_dict["input_mask"].name,
label.name
]
inference_program = program.clone(for_test=True)
data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
test_reader = reader.data_generator(phase='test', shuffle=False)
test_examples = dataset.get_test_examples()
total = 0
correct = 0
for index, batch in enumerate(test_reader()):
probs_v = exe.run(
feed=data_feeder.feed(batch), fetch_list=[probs.name])
print("%s\tpredict=%s" % (test_examples[index], probs_v[0][0]))
pred_v = exe.run(
feed=data_feeder.feed(batch),
fetch_list=[pred.name],
program=inference_program)
total += 1
if (pred_v[0][0] == int(test_examples[index].label)):
correct += 1
acc = 1.0 * correct / total
print("%s\tpredict=%s" % (test_examples[index], pred_v[0][0]))
print("accuracy = %f" % acc)
export CUDA_VISIBLE_DEVICES=1
export CUDA_VISIBLE_DEVICES=5
CKPT_DIR="./ckpt_sentiment_cls/best_model"
python -u cls_predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册