diff --git a/demo/qa_classification/classifier.py b/demo/qa_classification/classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c7427b257042e635fec2803c8870bf1f29d31c6
--- /dev/null
+++ b/demo/qa_classification/classifier.py
@@ -0,0 +1,98 @@
+#coding:utf-8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Finetuning on classification task """
+
+import argparse
+import ast
+
+import paddle.fluid as fluid
+import paddlehub as hub
+
+# yapf: disable
+# Pass the docstring as `description`: the first positional argument of
+# ArgumentParser is `prog`, which would put the text in the usage line.
+parser = argparse.ArgumentParser(description=__doc__)
+parser.add_argument("--num_epoch", type=int, default=3, help="Number of epoches for fine-tuning.")
+parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for finetuning, input should be True or False")
+parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
+parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.")
+parser.add_argument("--warmup_proportion", type=float, default=0.0, help="Warmup proportion params for warmup strategy")
+parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
+parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest seqence.")
+parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number in batch for training.")
+parser.add_argument("--use_pyreader", type=ast.literal_eval, default=False, help="Whether use pyreader to feed data.")
+parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether use data parallel.")
+args = parser.parse_args()
+# yapf: enable.
+
+if __name__ == '__main__':
+    # Load the PaddleHub ERNIE pretrained model
+    module = hub.Module(name="ernie")
+    # module = hub.Module(name="bert_multi_cased_L-12_H-768_A-12")
+    inputs, outputs, program = module.context(
+        trainable=True, max_seq_len=args.max_seq_len)
+
+    # Download dataset and use ClassifyReader to read dataset
+    dataset = hub.dataset.NLPCC_DBQA()
+
+    reader = hub.reader.ClassifyReader(
+        dataset=dataset,
+        vocab_path=module.get_vocab_path(),
+        max_seq_len=args.max_seq_len)
+
+    # Construct transfer learning network
+    # Use "pooled_output" for classification tasks on an entire sentence.
+    # Use "sequence_output" for token-level output.
+    pooled_output = outputs["pooled_output"]
+
+    # Setup feed list for data feeder
+    # Must feed all the tensors that ERNIE's module needs
+    feed_list = [
+        inputs["input_ids"].name,
+        inputs["position_ids"].name,
+        inputs["segment_ids"].name,
+        inputs["input_mask"].name,
+    ]
+
+    # Select finetune strategy, setup config and finetune
+    # Forward --warmup_proportion so the command-line flag actually takes
+    # effect instead of being parsed and then ignored.
+    strategy = hub.AdamWeightDecayStrategy(
+        warmup_proportion=args.warmup_proportion,
+        weight_decay=args.weight_decay,
+        learning_rate=args.learning_rate,
+        lr_scheduler="linear_decay")
+
+    # Setup running config for PaddleHub Finetune API
+    config = hub.RunConfig(
+        use_data_parallel=args.use_data_parallel,
+        use_pyreader=args.use_pyreader,
+        use_cuda=args.use_gpu,
+        num_epoch=args.num_epoch,
+        batch_size=args.batch_size,
+        checkpoint_dir=args.checkpoint_dir,
+        strategy=strategy)
+
+    # Define a classification finetune task by PaddleHub's API
+    cls_task = hub.TextClassifierTask(
+        data_reader=reader,
+        feature=pooled_output,
+        feed_list=feed_list,
+        num_classes=dataset.num_labels,
+        config=config)
+
+    # Finetune and evaluate by PaddleHub's API
+    # will finish training, evaluation, testing, save model automatically
+    cls_task.finetune_and_eval()
diff --git a/demo/qa_classification/predict.py b/demo/qa_classification/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..813cc91ed3c3027db16f550287adc555ed18ccaf
--- /dev/null
+++ b/demo/qa_classification/predict.py
@@ -0,0 +1,104 @@
+#coding:utf-8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Predict with a finetuned classification task """
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import ast
+import numpy as np
+import os
+import time
+
+import paddle
+import paddle.fluid as fluid
+import paddlehub as hub
+
+# yapf: disable
+parser = argparse.ArgumentParser(description=__doc__)
+parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
+parser.add_argument("--batch_size", type=int, default=1, help="Total examples' number in batch for prediction.")
+parser.add_argument("--max_seq_len", type=int, default=128, help="Number of words of the longest seqence.")
+parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for finetuning, input should be True or False")
+parser.add_argument("--use_pyreader", type=ast.literal_eval, default=False, help="Whether use pyreader to feed data.")
+args = parser.parse_args()
+# yapf: enable.
+
+if __name__ == '__main__':
+    # Load the PaddleHub ERNIE pretrained model
+    module = hub.Module(name="ernie")
+    inputs, outputs, program = module.context(max_seq_len=args.max_seq_len)
+
+    # Sentence classification dataset reader
+    dataset = hub.dataset.NLPCC_DBQA()
+    reader = hub.reader.ClassifyReader(
+        dataset=dataset,
+        vocab_path=module.get_vocab_path(),
+        max_seq_len=args.max_seq_len)
+
+    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
+    # Construct transfer learning network
+    # Use "pooled_output" for classification tasks on an entire sentence.
+    # Use "sequence_output" for token-level output.
+    pooled_output = outputs["pooled_output"]
+
+    # Setup feed list for data feeder
+    # Must feed all the tensors that ERNIE's module needs
+    feed_list = [
+        inputs["input_ids"].name,
+        inputs["position_ids"].name,
+        inputs["segment_ids"].name,
+        inputs["input_mask"].name,
+    ]
+
+    # Setup running config for PaddleHub Finetune API
+    config = hub.RunConfig(
+        use_data_parallel=False,
+        use_pyreader=args.use_pyreader,
+        use_cuda=args.use_gpu,
+        batch_size=args.batch_size,
+        enable_memory_optim=False,
+        checkpoint_dir=args.checkpoint_dir,
+        strategy=hub.finetune.strategy.DefaultFinetuneStrategy())
+
+    # Define a classification finetune task by PaddleHub's API
+    cls_task = hub.TextClassifierTask(
+        data_reader=reader,
+        feature=pooled_output,
+        feed_list=feed_list,
+        num_classes=dataset.num_labels,
+        config=config)
+
+    # Data to be predicted: [question, candidate answer] pairs
+    data = [["北京奥运博物馆的场景效果负责人是谁?", "主要承担奥运文物征集、保管、研究和爱国主义教育基地建设相关工作。"],
+            ["北京奥运博物馆的场景效果负责人是谁", "于海勃,美国加利福尼亚大学教授 场景效果负责人 总设计师"],
+            ["北京奥运博物馆的场景效果负责人是谁?", "洪麦恩,清华大学美术学院教授 内容及主展线负责人 总设计师"]]
+
+    run_states = cls_task.predict(data=data)
+    results = [run_state.run_results for run_state in run_states]
+    # Scan every row of every batch (not just row 0) so --batch_size > 1 works
+    max_probs, max_flag, offset = 0, 0, 0
+    for batch_result in results:
+        for row, probs in enumerate(batch_result[0]):
+            if max_probs <= probs[1]:
+                max_probs, max_flag = probs[1], offset + row
+        offset += len(batch_result[0])
+
+    print("question:%s\tthe predict answer:%s\t" % (data[max_flag][0],
+                                                    data[max_flag][1]))
diff --git a/demo/qa_classification/run_classifier.sh b/demo/qa_classification/run_classifier.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e41b4a50699e1b2b4a39a95f92e82fed7c2a023b
--- /dev/null
+++ b/demo/qa_classification/run_classifier.sh
@@ -0,0 +1,22 @@
+export FLAGS_eager_delete_tensor_gb=0.0
+export CUDA_VISIBLE_DEVICES=0
+
+
+CKPT_DIR="./ckpt_qa"
+# Recommended hyperparameters for different tasks
+# ChnSentiCorp: batch_size=24, weight_decay=0.01, num_epoch=3, max_seq_len=128, lr=5e-5
+# NLPCC_DBQA: batch_size=8, weight_decay=0.01, num_epoch=3, max_seq_len=512, lr=2e-5
+# LCQMC: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=2e-5
+# NOTE(review): the flags below do not match the NLPCC_DBQA row above
+# (24/128/5e-5 vs 8/512/2e-5) - confirm this is intentional.
+
+python -u classifier.py \
+                   --batch_size=24 \
+                   --use_gpu=True \
+                   --checkpoint_dir=${CKPT_DIR} \
+                   --learning_rate=5e-5 \
+                   --weight_decay=0.01 \
+                   --max_seq_len=128 \
+                   --num_epoch=3 \
+                   --use_pyreader=False \
+                   --use_data_parallel=False
diff --git a/demo/qa_classification/run_predict.sh b/demo/qa_classification/run_predict.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a308117b9e2550c0734092b4a4631cdf387c1d32
--- /dev/null
+++ b/demo/qa_classification/run_predict.sh
@@ -0,0 +1,5 @@
+export FLAGS_eager_delete_tensor_gb=0.0
+export CUDA_VISIBLE_DEVICES=0
+
+CKPT_DIR="./ckpt_qa"
+python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 --use_gpu False