From fa68672a0c65ff61c974b66d61b63b6b4a00afdc Mon Sep 17 00:00:00 2001
From: zhangxuefei
Date: Fri, 2 Aug 2019 19:39:11 +0800
Subject: [PATCH] Update text classification demo to adapt to ERNIE 2.0

---
 demo/text-classification/predict.py         | 53 +++++++++++++++++----
 demo/text-classification/run_classifier.sh  |  1 +
 demo/text-classification/run_predict.sh     |  2 +-
 demo/text-classification/text_classifier.py | 53 ++++++++++++++++-----
 paddlehub/reader/nlp_reader.py              |  2 +
 5 files changed, 89 insertions(+), 22 deletions(-)

diff --git a/demo/text-classification/predict.py b/demo/text-classification/predict.py
index ca05892e..bc54aff9 100644
--- a/demo/text-classification/predict.py
+++ b/demo/text-classification/predict.py
@@ -29,12 +29,14 @@ import paddlehub as hub
 # yapf: disable
 parser = argparse.ArgumentParser(__doc__)
-parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
+parser.add_argument("--checkpoint_dir", type=str, default="ckpt_20190802182531", help="Directory to model checkpoint")
 parser.add_argument("--batch_size", type=int, default=1, help="Total examples' number in batch for training.")
 parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest seqence.")
 parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for finetuning, input should be True or False")
 parser.add_argument("--use_pyreader", type=ast.literal_eval, default=False, help="Whether use pyreader to feed data.")
-parser.add_argument("--dataset", type=str, default="chnsenticorp", help="Directory to dataset")
+parser.add_argument("--dataset", type=str, default="chnsenticorp", help="The choice of dataset")
+parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether use data parallel.")
+parser.add_argument("--use_taskid", type=ast.literal_eval, default=False, help="Whether to use task ids; if True, the ERNIE 2.0 module is used.")
 args = parser.parse_args()
 # yapf: enable.
@@ -52,25 +54,46 @@ if __name__ == '__main__':
         module = hub.Module(name="ernie")
     elif args.dataset.lower() == "mrpc":
         dataset = hub.dataset.GLUE("MRPC")
-        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+        if args.use_taskid:
+            module = hub.Module(name="ernie_v2_eng_base")
+        else:
+            module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
     elif args.dataset.lower() == "qqp":
         dataset = hub.dataset.GLUE("QQP")
-        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+        if args.use_taskid:
+            module = hub.Module(name="ernie_v2_eng_base")
+        else:
+            module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
     elif args.dataset.lower() == "sst-2":
         dataset = hub.dataset.GLUE("SST-2")
-        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+        if args.use_taskid:
+            module = hub.Module(name="ernie_v2_eng_base")
+        else:
+            module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
     elif args.dataset.lower() == "cola":
         dataset = hub.dataset.GLUE("CoLA")
-        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+        if args.use_taskid:
+            module = hub.Module(name="ernie_v2_eng_base")
+        else:
+            module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
     elif args.dataset.lower() == "qnli":
         dataset = hub.dataset.GLUE("QNLI")
-        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+        if args.use_taskid:
+            module = hub.Module(name="ernie_v2_eng_base")
+        else:
+            module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
     elif args.dataset.lower() == "rte":
         dataset = hub.dataset.GLUE("RTE")
-        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+        if args.use_taskid:
+            module = hub.Module(name="ernie_v2_eng_base")
+        else:
+            module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
     elif args.dataset.lower() == "mnli":
         dataset = hub.dataset.GLUE("MNLI")
-        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+        if args.use_taskid:
+            module = hub.Module(name="ernie_v2_eng_base")
+        else:
+            module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
     elif args.dataset.lower().startswith("xnli"):
         dataset = hub.dataset.XNLI(language=args.dataset.lower()[-2:])
         module = hub.Module(name="bert_multi_cased_L-12_H-768_A-12")
@@ -82,7 +105,8 @@ if __name__ == '__main__':
     reader = hub.reader.ClassifyReader(
         dataset=dataset,
         vocab_path=module.get_vocab_path(),
-        max_seq_len=args.max_seq_len)
+        max_seq_len=args.max_seq_len,
+        use_task_id=args.use_taskid)
 
     # Construct transfer learning network
     # Use "pooled_output" for classification tasks on an entire sentence.
@@ -98,6 +122,15 @@ if __name__ == '__main__':
         inputs["input_mask"].name,
     ]
 
+    if args.use_taskid:
+        feed_list = [
+            inputs["input_ids"].name,
+            inputs["position_ids"].name,
+            inputs["segment_ids"].name,
+            inputs["input_mask"].name,
+            inputs["task_ids"].name,
+        ]
+
     # Setup runing config for PaddleHub Finetune API
     config = hub.RunConfig(
         use_data_parallel=False,
diff --git a/demo/text-classification/run_classifier.sh b/demo/text-classification/run_classifier.sh
index 3497e7a3..8016b2a3 100644
--- a/demo/text-classification/run_classifier.sh
+++ b/demo/text-classification/run_classifier.sh
@@ -26,3 +26,4 @@ python -u text_classifier.py \
     --num_epoch=3 \
     --use_pyreader=True \
     --use_data_parallel=True \
+    --use_taskid=False \
diff --git a/demo/text-classification/run_predict.sh b/demo/text-classification/run_predict.sh
index 6283a376..5c272f7d 100644
--- a/demo/text-classification/run_predict.sh
+++ b/demo/text-classification/run_predict.sh
@@ -15,4 +15,4 @@ CKPT_DIR="./ckpt_${DATASET}"
 # sw: Swahili th: Thai tr: Turkish
 # ur: Urdu vi: Vietnamese zh: Chinese (Simplified)
 
-python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 --use_gpu False --dataset=${DATASET}
+python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 --use_gpu False --dataset=${DATASET} --use_taskid False
diff --git a/demo/text-classification/text_classifier.py b/demo/text-classification/text_classifier.py
index 24f2e43d..3e1e87fa 100644
--- a/demo/text-classification/text_classifier.py
+++ b/demo/text-classification/text_classifier.py
@@ -22,17 +22,18 @@ import paddlehub as hub
 # yapf: disable
 parser = argparse.ArgumentParser(__doc__)
 parser.add_argument("--num_epoch", type=int, default=3, help="Number of epoches for fine-tuning.")
-parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for finetuning, input should be True or False")
-parser.add_argument("--dataset", type=str, default="chnsenticorp", help="Directory to dataset")
+parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether use GPU for finetuning, input should be True or False")
+parser.add_argument("--dataset", type=str, default="chnsenticorp", help="The choice of dataset")
 parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
 parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.")
-parser.add_argument("--warmup_proportion", type=float, default=0.0, help="Warmup proportion params for warmup strategy")
+parser.add_argument("--warmup_proportion", type=float, default=0.1, help="Warmup proportion params for warmup strategy")
 parser.add_argument("--data_dir", type=str, default=None, help="Path to training data.")
 parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
 parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest seqence.")
 parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number in batch for training.")
 parser.add_argument("--use_pyreader", type=ast.literal_eval, default=False, help="Whether use pyreader to feed data.")
 parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether use data parallel.")
+parser.add_argument("--use_taskid", type=ast.literal_eval, default=False, help="Whether to use task ids; if True, the ERNIE 2.0 module is used.")
 args = parser.parse_args()
 # yapf: enable.
@@ -50,25 +51,46 @@ if __name__ == '__main__':
         module = hub.Module(name="ernie")
     elif args.dataset.lower() == "mrpc":
         dataset = hub.dataset.GLUE("MRPC")
-        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+        if args.use_taskid:
+            module = hub.Module(name="ernie_v2_eng_base")
+        else:
+            module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
     elif args.dataset.lower() == "qqp":
         dataset = hub.dataset.GLUE("QQP")
-        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+        if args.use_taskid:
+            module = hub.Module(name="ernie_v2_eng_base")
+        else:
+            module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
     elif args.dataset.lower() == "sst-2":
         dataset = hub.dataset.GLUE("SST-2")
-        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+        if args.use_taskid:
+            module = hub.Module(name="ernie_v2_eng_base")
+        else:
+            module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
     elif args.dataset.lower() == "cola":
         dataset = hub.dataset.GLUE("CoLA")
-        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+        if args.use_taskid:
+            module = hub.Module(name="ernie_v2_eng_base")
+        else:
+            module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
     elif args.dataset.lower() == "qnli":
         dataset = hub.dataset.GLUE("QNLI")
-        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+        if args.use_taskid:
+            module = hub.Module(name="ernie_v2_eng_base")
+        else:
+            module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
     elif args.dataset.lower() == "rte":
         dataset = hub.dataset.GLUE("RTE")
-        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+        if args.use_taskid:
+            module = hub.Module(name="ernie_v2_eng_base")
+        else:
+            module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
     elif args.dataset.lower() == "mnli":
         dataset = hub.dataset.GLUE("MNLI")
-        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+        if args.use_taskid:
+            module = hub.Module(name="ernie_v2_eng_base")
+        else:
+            module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
     elif args.dataset.lower().startswith("xnli"):
         dataset = hub.dataset.XNLI(language=args.dataset.lower()[-2:])
         module = hub.Module(name="bert_multi_cased_L-12_H-768_A-12")
@@ -80,7 +102,8 @@ if __name__ == '__main__':
     reader = hub.reader.ClassifyReader(
         dataset=dataset,
         vocab_path=module.get_vocab_path(),
-        max_seq_len=args.max_seq_len)
+        max_seq_len=args.max_seq_len,
+        use_task_id=args.use_taskid)
 
     # Construct transfer learning network
     # Use "pooled_output" for classification tasks on an entire sentence.
@@ -96,6 +119,14 @@ if __name__ == '__main__':
         inputs["input_mask"].name,
     ]
 
+    if args.use_taskid:
+        feed_list = [
+            inputs["input_ids"].name,
+            inputs["position_ids"].name,
+            inputs["segment_ids"].name,
+            inputs["input_mask"].name,
+            inputs["task_ids"].name,
+        ]
     # Select finetune strategy, setup config and finetune
     strategy = hub.AdamWeightDecayStrategy(
         weight_decay=args.weight_decay,
diff --git a/paddlehub/reader/nlp_reader.py b/paddlehub/reader/nlp_reader.py
index ccc1ea34..4fdbaf2d 100644
--- a/paddlehub/reader/nlp_reader.py
+++ b/paddlehub/reader/nlp_reader.py
@@ -332,6 +332,8 @@ class ClassifyReader(BaseReader):
         ]
 
         if self.use_task_id:
+            padded_task_ids = np.ones_like(
+                padded_token_ids, dtype="int64") * self.task_id
             return_list = [
                 padded_token_ids, padded_position_ids, padded_text_type_ids,
                 input_mask, padded_task_ids
--
GitLab
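
Note (not part of the patch): the sketch below illustrates the mechanism the diff adds. The ClassifyReader change broadcasts a single task id to the same shape as the padded token ids, and the demos then feed that extra "task_ids" tensor to the ERNIE 2.0 module when --use_taskid is True. The array shapes and the task_id value here are illustrative assumptions, not values taken from the patch.

    import numpy as np

    # Illustrative batch shape only: [batch_size, max_seq_len, 1].
    padded_token_ids = np.zeros((2, 128, 1), dtype="int64")

    # Mirror of the nlp_reader.py change above: every token position in the batch
    # gets the same task id, so task_ids lines up with the token ids fed to ERNIE 2.0.
    task_id = 0  # assumed fine-tuning task id; the real value comes from the reader
    padded_task_ids = np.ones_like(padded_token_ids, dtype="int64") * task_id

    assert padded_task_ids.shape == padded_token_ids.shape

With the patch applied, ERNIE 2.0 prediction on an English GLUE task would be launched along the lines of: python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 --use_gpu False --dataset=mrpc --use_taskid True (the flag is parsed with ast.literal_eval, so it expects the literals True or False).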