提交 635af633 编写于 作者: Z Zeyu Chen

simplify text-classification

上级 69d6f05b
export CUDA_VISIBLE_DEVICES=5 export CUDA_VISIBLE_DEVICES=5
# User can select senticorp, nlpcc_dbqa, lcqmc for different task # User can select chnsenticorp, nlpcc_dbqa, lcqmc for different task
DATASET="senticorp" DATASET="chnsenticorp"
CKPT_DIR="./ckpt_${DATASET}" CKPT_DIR="./ckpt_${DATASET}"
# Recommending hyper parameters for difference task # Recommending hyper parameters for difference task
# ChnSentiCorp: batch_size=24, weight_decay=0.01, num_epoch=3, max_seq_len=128, lr=5e-5 # ChnSentiCorp: batch_size=24, weight_decay=0.01, num_epoch=3, max_seq_len=128, lr=5e-5
......
export CUDA_VISIBLE_DEVICES=5 export CUDA_VISIBLE_DEVICES=5
CKPT_DIR="./ckpt_sentiment_cls/best_model" CKPT_DIR="./ckpt_chnsenticorp/best_model"
python -u cls_predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128
...@@ -43,7 +43,7 @@ if __name__ == '__main__': ...@@ -43,7 +43,7 @@ if __name__ == '__main__':
# Step2: Download dataset and use ClassifyReader to read dataset # Step2: Download dataset and use ClassifyReader to read dataset
dataset = None dataset = None
if args.dataset.lower() == "senticorp": if args.dataset.lower() == "chnsenticorp":
dataset = hub.dataset.ChnSentiCorp() dataset = hub.dataset.ChnSentiCorp()
elif args.dataset.lower() == "nlpcc_dbqa": elif args.dataset.lower() == "nlpcc_dbqa":
dataset = hub.dataset.NLPCC_DBQA() dataset = hub.dataset.NLPCC_DBQA()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册