From 635af63307719da4522ae62ee5dde1d489033e81 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Sun, 14 Apr 2019 19:19:36 +0800 Subject: [PATCH] simplify text-classification --- demo/text-classification/{cls_predict.py => predict.py} | 0 demo/text-classification/run_classifier.sh | 4 ++-- demo/text-classification/run_predict.sh | 4 ++-- demo/text-classification/text_classifier.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) rename demo/text-classification/{cls_predict.py => predict.py} (100%) diff --git a/demo/text-classification/cls_predict.py b/demo/text-classification/predict.py similarity index 100% rename from demo/text-classification/cls_predict.py rename to demo/text-classification/predict.py diff --git a/demo/text-classification/run_classifier.sh b/demo/text-classification/run_classifier.sh index e666aac7..9aaa61a5 100644 --- a/demo/text-classification/run_classifier.sh +++ b/demo/text-classification/run_classifier.sh @@ -1,7 +1,7 @@ export CUDA_VISIBLE_DEVICES=5 -# User can select senticorp, nlpcc_dbqa, lcqmc for different task -DATASET="senticorp" +# User can select chnsenticorp, nlpcc_dbqa, lcqmc for different task +DATASET="chnsenticorp" CKPT_DIR="./ckpt_${DATASET}" # Recommending hyper parameters for difference task # ChnSentiCorp: batch_size=24, weight_decay=0.01, num_epoch=3, max_seq_len=128, lr=5e-5 diff --git a/demo/text-classification/run_predict.sh b/demo/text-classification/run_predict.sh index d192c340..57b522ab 100644 --- a/demo/text-classification/run_predict.sh +++ b/demo/text-classification/run_predict.sh @@ -1,4 +1,4 @@ export CUDA_VISIBLE_DEVICES=5 -CKPT_DIR="./ckpt_sentiment_cls/best_model" -python -u cls_predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 +CKPT_DIR="./ckpt_chnsenticorp/best_model" +python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 diff --git a/demo/text-classification/text_classifier.py b/demo/text-classification/text_classifier.py index c13bfc4e..2dd7e895 100644 --- a/demo/text-classification/text_classifier.py +++ b/demo/text-classification/text_classifier.py @@ -43,7 +43,7 @@ if __name__ == '__main__': # Step2: Download dataset and use ClassifyReader to read dataset dataset = None - if args.dataset.lower() == "senticorp": + if args.dataset.lower() == "chnsenticorp": dataset = hub.dataset.ChnSentiCorp() elif args.dataset.lower() == "nlpcc_dbqa": dataset = hub.dataset.NLPCC_DBQA() -- GitLab