Commit fa68672a authored by Z zhangxuefei

Update the text classification demo to adapt it to ERNIE v2

Parent 5edc2a6f
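In short, this commit adds a --use_taskid flag to the text-classification demo: when it is True, the English GLUE tasks load ERNIE v2 instead of BERT, the reader emits task ids, and the feed list gains a task_ids entry. The per-dataset module choice repeated throughout the diff could be factored into one helper; a minimal sketch, reusing only names that appear in the diff:

import paddlehub as hub

# Sketch: collapse the repeated if/else pairs below into a single helper.
def select_english_module(use_taskid):
    name = "ernie_v2_eng_base" if use_taskid else "bert_uncased_L-12_H-768_A-12"
    return hub.Module(name=name)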
......@@ -29,12 +29,14 @@ import paddlehub as hub
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
parser.add_argument("--checkpoint_dir", type=str, default="ckpt_20190802182531", help="Directory to model checkpoint")
parser.add_argument("--batch_size", type=int, default=1, help="Total examples' number in batch for training.")
parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest seqence.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for finetuning, input should be True or False")
parser.add_argument("--use_pyreader", type=ast.literal_eval, default=False, help="Whether use pyreader to feed data.")
parser.add_argument("--dataset", type=str, default="chnsenticorp", help="Directory to dataset")
parser.add_argument("--dataset", type=str, default="chnsenticorp", help="The choice of dataset")
parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether use data parallel.")
parser.add_argument("--use_taskid", type=ast.literal_eval, default=False, help="Whether to use taskid ,if yes to use ernie v2.")
args = parser.parse_args()
# yapf: enable.
......@@ -52,25 +54,46 @@ if __name__ == '__main__':
module = hub.Module(name="ernie")
elif args.dataset.lower() == "mrpc":
dataset = hub.dataset.GLUE("MRPC")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "qqp":
dataset = hub.dataset.GLUE("QQP")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "sst-2":
dataset = hub.dataset.GLUE("SST-2")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "cola":
dataset = hub.dataset.GLUE("CoLA")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "qnli":
dataset = hub.dataset.GLUE("QNLI")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "rte":
dataset = hub.dataset.GLUE("RTE")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "mnli":
dataset = hub.dataset.GLUE("MNLI")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower().startswith("xnli"):
dataset = hub.dataset.XNLI(language=args.dataset.lower()[-2:])
module = hub.Module(name="bert_multi_cased_L-12_H-768_A-12")
......@@ -82,7 +105,8 @@ if __name__ == '__main__':
reader = hub.reader.ClassifyReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len)
max_seq_len=args.max_seq_len,
use_task_id=args.use_taskid)
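The inputs dict referenced in the feed-list hunk below comes from the module's context, which the diff elides; a hedged sketch of the usual PaddleHub 1.x pattern (the exact call is assumed, not shown in this hunk):

# Assumed setup: build the program and fetch input/output tensors.
inputs, outputs, program = module.context(
    trainable=True, max_seq_len=args.max_seq_len)
pooled_output = outputs["pooled_output"]  # whole-sentence feature for classification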
# Construct transfer learning network
# Use "pooled_output" for classification tasks on an entire sentence.
......@@ -98,6 +122,15 @@ if __name__ == '__main__':
inputs["input_mask"].name,
]
if args.use_taskid:
feed_list = [
inputs["input_ids"].name,
inputs["position_ids"].name,
inputs["segment_ids"].name,
inputs["input_mask"].name,
inputs["task_ids"].name,
]
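The two feed lists above differ only in the trailing task_ids entry, so the conditional can be written as a single append; an equivalent sketch, not part of the commit:

feed_list = [
    inputs["input_ids"].name,
    inputs["position_ids"].name,
    inputs["segment_ids"].name,
    inputs["input_mask"].name,
]
if args.use_taskid:
    # ERNIE v2 expects a fifth input carrying the pre-training task id.
    feed_list.append(inputs["task_ids"].name)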
# Set up running config for PaddleHub Fine-tune API
config = hub.RunConfig(
use_data_parallel=False,
......
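The hub.RunConfig(...) call is truncated by the diff; a hedged sketch of a complete call, with every keyword other than use_data_parallel assumed from the script's flags rather than shown in this hunk:

# Hypothetical completion; the real demo may pass different arguments.
config = hub.RunConfig(
    use_data_parallel=False,
    use_pyreader=args.use_pyreader,
    use_cuda=args.use_gpu,
    batch_size=args.batch_size,
    checkpoint_dir=args.checkpoint_dir)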
......@@ -26,3 +26,4 @@ python -u text_classifier.py \
--num_epoch=3 \
--use_pyreader=True \
--use_data_parallel=True \
--use_taskid=False \
......@@ -15,4 +15,4 @@ CKPT_DIR="./ckpt_${DATASET}"
# sw: Swahili th: Thai tr: Turkish
# ur: Urdu vi: Vietnamese zh: Chinese (Simplified)
python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 --use_gpu False --dataset=${DATASET}
python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 --use_gpu False --dataset=${DATASET} --use_taskid False
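The boolean flags in these scripts are parsed with ast.literal_eval, which is why the shell commands pass capitalized True/False; a quick illustration:

import ast

print(ast.literal_eval("False"))  # -> False, a real bool rather than a truthy string
print(ast.literal_eval("True"))   # -> True
# ast.literal_eval("false") raises ValueError: capitalization matters.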
......@@ -22,17 +22,18 @@ import paddlehub as hub
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--num_epoch", type=int, default=3, help="Number of epoches for fine-tuning.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for finetuning, input should be True or False")
parser.add_argument("--dataset", type=str, default="chnsenticorp", help="Directory to dataset")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether use GPU for finetuning, input should be True or False")
parser.add_argument("--dataset", type=str, default="chnsenticorp", help="The choice of dataset")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.")
parser.add_argument("--warmup_proportion", type=float, default=0.0, help="Warmup proportion params for warmup strategy")
parser.add_argument("--warmup_proportion", type=float, default=0.1, help="Warmup proportion params for warmup strategy")
parser.add_argument("--data_dir", type=str, default=None, help="Path to training data.")
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest seqence.")
parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number in batch for training.")
parser.add_argument("--use_pyreader", type=ast.literal_eval, default=False, help="Whether use pyreader to feed data.")
parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether use data parallel.")
parser.add_argument("--use_taskid", type=ast.literal_eval, default=False, help="Whether to use taskid ,if yes to use ernie v2.")
args = parser.parse_args()
# yapf: enable.
......@@ -50,25 +51,46 @@ if __name__ == '__main__':
module = hub.Module(name="ernie")
elif args.dataset.lower() == "mrpc":
dataset = hub.dataset.GLUE("MRPC")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "qqp":
dataset = hub.dataset.GLUE("QQP")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "sst-2":
dataset = hub.dataset.GLUE("SST-2")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "cola":
dataset = hub.dataset.GLUE("CoLA")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "qnli":
dataset = hub.dataset.GLUE("QNLI")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "rte":
dataset = hub.dataset.GLUE("RTE")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "mnli":
dataset = hub.dataset.GLUE("MNLI")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
if args.use_taskid:
module = hub.Module(name="ernie_v2_eng_base")
else:
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower().startswith("xnli"):
dataset = hub.dataset.XNLI(language=args.dataset.lower()[-2:])
module = hub.Module(name="bert_multi_cased_L-12_H-768_A-12")
......@@ -80,7 +102,8 @@ if __name__ == '__main__':
reader = hub.reader.ClassifyReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len)
max_seq_len=args.max_seq_len,
use_task_id=args.use_taskid)
# Construct transfer learning network
# Use "pooled_output" for classification tasks on an entire sentence.
......@@ -96,6 +119,14 @@ if __name__ == '__main__':
inputs["input_mask"].name,
]
if args.use_taskid:
feed_list = [
inputs["input_ids"].name,
inputs["position_ids"].name,
inputs["segment_ids"].name,
inputs["input_mask"].name,
inputs["task_ids"].name,
]
# Select finetune strategy, setup config and finetune
strategy = hub.AdamWeightDecayStrategy(
weight_decay=args.weight_decay,
......
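The AdamWeightDecayStrategy call is cut off by the diff; a hedged completion using PaddleHub 1.x keyword names (the learning_rate and warmup_proportion lines are assumptions, only weight_decay appears in this hunk):

# Hypothetical completion of the elided strategy setup.
strategy = hub.AdamWeightDecayStrategy(
    weight_decay=args.weight_decay,
    learning_rate=args.learning_rate,
    warmup_proportion=args.warmup_proportion)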
......@@ -332,6 +332,8 @@ class ClassifyReader(BaseReader):
]
if self.use_task_id:
padded_task_ids = np.ones_like(
padded_token_ids, dtype="int64") * self.task_id
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_task_ids
......
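The padded_task_ids hunk fills a tensor shaped like the padded token ids with a single task id; a self-contained illustration of what np.ones_like produces here (the shape and task id value are invented for the example):

import numpy as np

padded_token_ids = np.zeros((2, 5), dtype="int64")  # stand-in: batch 2, max_seq_len 5
task_id = 0  # hypothetical value; ClassifyReader reads the real one from self.task_id

padded_task_ids = np.ones_like(padded_token_ids, dtype="int64") * task_id
print(padded_task_ids.shape)  # (2, 5) -- one task id per token position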