diff --git a/demo/text-classification/cls_predict.py b/demo/text-classification/predict.py similarity index 100% rename from demo/text-classification/cls_predict.py rename to demo/text-classification/predict.py diff --git a/demo/text-classification/run_classifier.sh b/demo/text-classification/run_classifier.sh index e666aac729c4bfa0bb4bd78ba76401faf86871b1..9aaa61a5de0146926c27779c0cece85ab3eab0d3 100644 --- a/demo/text-classification/run_classifier.sh +++ b/demo/text-classification/run_classifier.sh @@ -1,7 +1,7 @@ export CUDA_VISIBLE_DEVICES=5 -# User can select senticorp, nlpcc_dbqa, lcqmc for different task -DATASET="senticorp" +# User can select chnsenticorp, nlpcc_dbqa, lcqmc for different task +DATASET="chnsenticorp" CKPT_DIR="./ckpt_${DATASET}" # Recommending hyper parameters for difference task # ChnSentiCorp: batch_size=24, weight_decay=0.01, num_epoch=3, max_seq_len=128, lr=5e-5 diff --git a/demo/text-classification/run_predict.sh b/demo/text-classification/run_predict.sh index d192c3400c95144248b66250c45a7ddd45b0206f..57b522ab54247dc21711e561b35532229baf17f7 100644 --- a/demo/text-classification/run_predict.sh +++ b/demo/text-classification/run_predict.sh @@ -1,4 +1,4 @@ export CUDA_VISIBLE_DEVICES=5 -CKPT_DIR="./ckpt_sentiment_cls/best_model" -python -u cls_predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 +CKPT_DIR="./ckpt_chnsenticorp/best_model" +python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 diff --git a/demo/text-classification/text_classifier.py b/demo/text-classification/text_classifier.py index c13bfc4eb751be337319ede5682771a80ee04efb..2dd7e8958c46f0e564900b513f9c26c0c7bacf48 100644 --- a/demo/text-classification/text_classifier.py +++ b/demo/text-classification/text_classifier.py @@ -43,7 +43,7 @@ if __name__ == '__main__': # Step2: Download dataset and use ClassifyReader to read dataset dataset = None - if args.dataset.lower() == "senticorp": + if args.dataset.lower() == "chnsenticorp": dataset = hub.dataset.ChnSentiCorp() elif args.dataset.lower() == "nlpcc_dbqa": dataset = hub.dataset.NLPCC_DBQA()