diff --git a/demo/text_classification/README.md b/demo/text_classification/README.md index 2de5b98c06c9e97e0819ffb2e7f9be660d94e8d0..930989474a665b60af251141e28551f1aac83d68 100644 --- a/demo/text_classification/README.md +++ b/demo/text_classification/README.md @@ -162,9 +162,9 @@ model = hub.Module( task='seq-cls', load_checkpoint='./test_ernie_text_cls/best_model/model.pdparams', label_map=label_map) -results = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False) +results, probs = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False, return_prob=True) for idx, text in enumerate(data): - print('Data: {} \t Lable: {}'.format(text[0], results[idx])) + print('Data: {} \t Lable: {} \t Prob: {}'.format(text[0], results[idx], probs[idx])) ``` 参数配置正确后,请执行脚本`python predict.py`, 加载模型具体可参见[加载](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc/api/paddle/framework/io/load_cn.html#load)。 diff --git a/demo/text_classification/predict.py b/demo/text_classification/predict.py index 48a5688bfe48bc4d3728d06c9cdc78281013b9d0..3b6facf9c7c731052c9ac64d7b6391d18d74f98e 100644 --- a/demo/text_classification/predict.py +++ b/demo/text_classification/predict.py @@ -30,4 +30,4 @@ if __name__ == '__main__': label_map=label_map) results, probs = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False, return_prob=True) for idx, text in enumerate(data): - print('Data: {} \t Lable: {}'.format(text[0], results[idx])) + print('Data: {} \t Lable: {} \t Prob: {}'.format(text[0], results[idx], probs[idx]))