diff --git a/demo/ernie-classification/cls_predict.py b/demo/ernie-classification/cls_predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..08f538d410bf10b10a969f56cde63d7538ad1da5
--- /dev/null
+++ b/demo/ernie-classification/cls_predict.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Finetuning on classification task """
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+import argparse
+import numpy as np
+
+import paddle
+import paddle.fluid as fluid
+import paddlehub as hub
+
+# yapf: disable
+parser = argparse.ArgumentParser(__doc__)
+parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
+parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest sequence.")
+args = parser.parse_args()
+# yapf: enable
+
+if __name__ == '__main__':
+    # Load the PaddleHub ERNIE pretrained model
+    module = hub.Module(name="ernie")
+    input_dict, output_dict, program = module.context(
+        max_seq_len=args.max_seq_len)
+
+    # Sentence classification dataset reader
+    dataset = hub.dataset.ChnSentiCorp()
+    reader = hub.reader.ClassifyReader(
+        dataset=dataset,
+        vocab_path=module.get_vocab_path(),
+        max_seq_len=args.max_seq_len)
+
+    place = fluid.CUDAPlace(0)
+    exe = fluid.Executor(place)
+    with fluid.program_guard(program):
+        label = fluid.layers.data(name="label", shape=[1], dtype='int64')
+
+        # Use "pooled_output" for classification tasks on an entire sentence.
+        # Use "sequence_output" for token-level output.
+        pooled_output = output_dict["pooled_output"]
+
+        # Setup feed list for data feeder
+        # Must feed all the tensors that the ERNIE module needs
+
+        # Define a classification finetune task by PaddleHub's API
+        cls_task = hub.create_text_classification_task(
+            feature=pooled_output, label=label, num_classes=dataset.num_labels)
+
+        # classification probability tensor
+        probs = cls_task.variable("probs")
+
+        # load the best model checkpoint
+        fluid.io.load_persistables(exe, args.checkpoint_dir)
+
+        feed_list = [
+            input_dict["input_ids"].name, input_dict["position_ids"].name,
+            input_dict["segment_ids"].name, input_dict["input_mask"].name,
+            label.name
+        ]
+
+        data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
+        test_reader = reader.data_generator(phase='test', shuffle=False)
+        test_examples = dataset.get_test_examples()
+        for index, batch in enumerate(test_reader()):
+            probs_v = exe.run(
+                feed=data_feeder.feed(batch), fetch_list=[probs.name])
+            print(test_examples[index], probs_v[0][0])
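Note on the prediction loop in cls_predict.py: exe.run returns one ndarray per fetch_list entry, so probs_v[0] is the (batch_size, num_classes) probability array and probs_v[0][0] is the vector for the single example in the batch. Pairing enumerate(test_reader()) with dataset.get_test_examples() only lines up because data_generator now defaults to batch_size=1 (see the nlp_reader.py hunk below) and shuffle=False is passed. A minimal sketch of turning such a fetch result into a predicted label; the probability values here are made up for illustration:

    import numpy as np

    # Hypothetical fetch result from exe.run(..., fetch_list=[probs.name])
    # for one batch of size 1 with two classes.
    probs_v = [np.array([[0.13, 0.87]])]

    labels = ["0", "1"]  # matches ChnSentiCorp.get_labels()
    example_probs = probs_v[0][0]  # probability vector of the single example
    predicted = labels[int(np.argmax(example_probs))]
    print("prediction:", predicted, "confidence:", float(example_probs.max()))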
diff --git a/demo/ernie-classification/ernie_tiny_demo.py b/demo/ernie-classification/ernie_tiny_demo.py
index a8c24a10d10e712a15db0f39b97675b5410fa90e..e8cf88f8f5b33916ffbb79253141afaceff203ff 100644
--- a/demo/ernie-classification/ernie_tiny_demo.py
+++ b/demo/ernie-classification/ernie_tiny_demo.py
@@ -6,10 +6,9 @@ module = hub.Module(name="ernie")
 inputs, outputs, program = module.context(trainable=True, max_seq_len=128)
 
 # Step2
+dataset = hub.dataset.ChnSentiCorp()
 reader = hub.reader.ClassifyReader(
-    dataset=hub.dataset.ChnSentiCorp(),
-    vocab_path=module.get_vocab_path(),
-    max_seq_len=128)
+    dataset=dataset, vocab_path=module.get_vocab_path(), max_seq_len=128)
 
 # Step3
 with fluid.program_guard(program):
@@ -18,7 +17,7 @@ with fluid.program_guard(program):
     pooled_output = outputs["pooled_output"]
 
     cls_task = hub.create_text_classification_task(
-        feature=pooled_output, label=label, num_classes=reader.get_num_labels())
+        feature=pooled_output, label=label, num_classes=dataset.num_labels)
 
 # Step4
 strategy = hub.AdamWeightDecayStrategy(
diff --git a/demo/ernie-classification/question_answering.py b/demo/ernie-classification/question_answering.py
index 65406c45fcefb716dabef5c807e366567e5e730f..0190a28c64887e8b6a52c6e1199710cfc547cc36 100644
--- a/demo/ernie-classification/question_answering.py
+++ b/demo/ernie-classification/question_answering.py
@@ -37,11 +37,11 @@ if __name__ == '__main__':
         trainable=True, max_seq_len=args.max_seq_len)
 
     # Step2: Download dataset and use ClassifyReader to read dataset
+    dataset = hub.dataset.NLPCC_DBQA()
     reader = hub.reader.ClassifyReader(
-        dataset=hub.dataset.NLPCC_DBQA(),
+        dataset=dataset,
         vocab_path=module.get_vocab_path(),
         max_seq_len=args.max_seq_len)
-    num_labels = len(reader.get_labels())
 
     # Step3: construct transfer learning network
     with fluid.program_guard(program):
@@ -59,7 +59,7 @@ if __name__ == '__main__':
         ]
         # Define a classification finetune task by PaddleHub's API
         cls_task = hub.create_text_classification_task(
-            pooled_output, label, num_classes=num_labels)
+            pooled_output, label, num_classes=dataset.num_labels)
 
     # Step4: Select finetune strategy, setup config and finetune
     strategy = hub.AdamWeightDecayStrategy(
diff --git a/demo/ernie-classification/question_matching.py b/demo/ernie-classification/question_matching.py
index 351922c19c79c32458055213d2e95699d5cbb351..acb26a89c401a2c9c98e4a4d505bbc9cf7fa0a0a 100644
--- a/demo/ernie-classification/question_matching.py
+++ b/demo/ernie-classification/question_matching.py
@@ -37,11 +37,11 @@ if __name__ == '__main__':
         trainable=True, max_seq_len=args.max_seq_len)
 
     # Step2: Download dataset and use ClassifyReader to read dataset
+    dataset = hub.dataset.LCQMC()
     reader = hub.reader.ClassifyReader(
-        dataset=hub.dataset.LCQMC(),
+        dataset=dataset,
         vocab_path=module.get_vocab_path(),
         max_seq_len=args.max_seq_len)
-    num_labels = len(reader.get_labels())
 
     # Step3: construct transfer learning network
     with fluid.program_guard(program):
@@ -59,7 +59,7 @@ if __name__ == '__main__':
         ]
         # Define a classification finetune task by PaddleHub's API
         cls_task = hub.create_text_classification_task(
-            pooled_output, label, num_classes=num_labels)
+            pooled_output, label, num_classes=dataset.num_labels)
 
     # Step4: Select finetune strategy, setup config and finetune
     strategy = hub.AdamWeightDecayStrategy(
diff --git a/demo/ernie-classification/run_predict.sh b/demo/ernie-classification/run_predict.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e3b46f49cea7d29404a6ff9d63e41e69e7ba6a97
--- /dev/null
+++ b/demo/ernie-classification/run_predict.sh
@@ -0,0 +1,4 @@
+export CUDA_VISIBLE_DEVICES=1
+
+CKPT_DIR="./ckpt_sentiment_cls/best_model"
+python -u cls_predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128
diff --git a/demo/ernie-classification/sentiment_cls.py b/demo/ernie-classification/sentiment_cls.py
index fddc8c3ab8cdda3a8abcd4461dff047bff8122d3..a08d1a7b09efd3d9d8dcda1e18feac90e4aae9f1 100644
--- a/demo/ernie-classification/sentiment_cls.py
+++ b/demo/ernie-classification/sentiment_cls.py
@@ -37,8 +37,9 @@ if __name__ == '__main__':
         trainable=True, max_seq_len=args.max_seq_len)
 
     # Step2: Download dataset and use ClassifyReader to read dataset
+    dataset = hub.dataset.ChnSentiCorp()
     reader = hub.reader.ClassifyReader(
-        dataset=hub.dataset.ChnSentiCorp(),
+        dataset=dataset,
         vocab_path=module.get_vocab_path(),
         max_seq_len=args.max_seq_len)
 
@@ -58,7 +59,9 @@ if __name__ == '__main__':
         ]
         # Define a classification finetune task by PaddleHub's API
         cls_task = hub.create_text_classification_task(
-            pooled_output, label, num_classes=reader.get_num_labels())
+            feature=pooled_output,
+            label=label,
+            num_classes=dataset.num_labels)
 
     # Step4: Select finetune strategy, setup config and finetune
     strategy = hub.AdamWeightDecayStrategy(
diff --git a/paddlehub/dataset/chnsenticorp.py b/paddlehub/dataset/chnsenticorp.py
index c237a575f9a18511ca2018af60be048b4d41e3d8..af3beb863f900bb3aa8f229957a362b91998db0e 100644
--- a/paddlehub/dataset/chnsenticorp.py
+++ b/paddlehub/dataset/chnsenticorp.py
@@ -70,6 +70,13 @@ class ChnSentiCorp(HubDataset):
     def get_labels(self):
         return ["0", "1"]
 
+    @property
+    def num_labels(self):
+        """
+        Return the number of labels in the dataset.
+        """
+        return len(self.get_labels())
+
     def _read_tsv(self, input_file, quotechar=None):
         """Reads a tab separated value file."""
         with open(input_file, "r") as f:
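The num_labels property added to ChnSentiCorp above is repeated verbatim in lcqmc.py, msra_ner.py, and nlpcc_dbqa.py below. Since every HubDataset subclass already implements get_labels(), the property could instead live once on the base class; a sketch of that alternative design (not what this patch does):

    class HubDataset(object):
        def get_labels(self):
            raise NotImplementedError()

        @property
        def num_labels(self):
            # Derived from get_labels(), so any subclass that implements
            # get_labels() inherits num_labels without restating it.
            return len(self.get_labels())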
+ """ + return len(self.get_labels()) + def _read_tsv(self, input_file, quotechar=None): """Reads a tab separated value file.""" with open(input_file, "r") as f: diff --git a/paddlehub/dataset/dataset.py b/paddlehub/dataset/dataset.py index 8ce5046f00a7a6901c1bdb0212997c2040f3db83..aca777476628acfd126194e53d34bed34ab19a19 100644 --- a/paddlehub/dataset/dataset.py +++ b/paddlehub/dataset/dataset.py @@ -40,6 +40,13 @@ class InputExample(object): self.text_b = text_b self.label = label + def __str__(self): + if self.text_b is None: + return "text={}\tlabel={}".format(self.text_a, self.label) + else: + return "text_a={}\ttext_b{},label={}".format( + self.text_a, self.text_b, label) + class HubDataset(object): def get_train_examples(self): @@ -56,3 +63,6 @@ class HubDataset(object): def get_labels(self): raise NotImplementedError() + + def num_labels(self): + raise NotImplementedError() diff --git a/paddlehub/dataset/lcqmc.py b/paddlehub/dataset/lcqmc.py index 169f44b93aef75abb595c687a840476a22326765..06654fa14828a6d30ec9defb5d2faca7c8ce0ebf 100644 --- a/paddlehub/dataset/lcqmc.py +++ b/paddlehub/dataset/lcqmc.py @@ -66,6 +66,13 @@ class LCQMC(HubDataset): """See base class.""" return ["0", "1"] + @property + def num_labels(self): + """ + Return the number of labels in the dataset. + """ + return len(self.get_labels()) + def _read_tsv(self, input_file, quotechar=None): """Reads a tab separated value file.""" with open(input_file, "r") as f: diff --git a/paddlehub/dataset/msra_ner.py b/paddlehub/dataset/msra_ner.py index aeade65ca7c6f95cbbdf5b17f25774f857539dc0..460a2fb21f7034d77d359605c414a4b57506582a 100644 --- a/paddlehub/dataset/msra_ner.py +++ b/paddlehub/dataset/msra_ner.py @@ -79,6 +79,13 @@ class MSRA_NER(HubDataset): def get_labels(self): return ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"] + @property + def num_labels(self): + """ + Return the number of labels in the dataset. + """ + return len(self.get_labels()) + def get_label_map(self): return self.label_map diff --git a/paddlehub/dataset/nlpcc_dbqa.py b/paddlehub/dataset/nlpcc_dbqa.py index 5ec57cbab07d17d844189ede810b356493783536..beedba95a6e9d20027baa6efc44edc4dba8b07e7 100644 --- a/paddlehub/dataset/nlpcc_dbqa.py +++ b/paddlehub/dataset/nlpcc_dbqa.py @@ -72,6 +72,13 @@ class NLPCC_DBQA(HubDataset): """See base class.""" return ["0", "1"] + @property + def num_labels(self): + """ + Return the number of labels in the dataset. + """ + return len(self.get_labels()) + def _read_tsv(self, input_file, quotechar=None): """Reads a tab separated value file.""" with open(input_file, "r") as f: diff --git a/paddlehub/reader/nlp_reader.py b/paddlehub/reader/nlp_reader.py index b605886163927370170778b8541b215684920a0d..eaca1339932706c2f97bb5abf60c6f5767307860 100644 --- a/paddlehub/reader/nlp_reader.py +++ b/paddlehub/reader/nlp_reader.py @@ -80,9 +80,6 @@ class BaseReader(object): """Gets the list of labels for this data set.""" return self.dataset.get_labels() - def get_num_labels(self): - return len(self.dataset.get_labels()) - def get_train_progress(self): """Gets progress for training phase.""" return self.current_example, self.current_epoch @@ -211,7 +208,7 @@ class BaseReader(object): ) return self.num_examples[phase] - def data_generator(self, batch_size, phase='train', shuffle=True): + def data_generator(self, batch_size=1, phase='train', shuffle=True): if phase == 'train': examples = self.get_train_examples()