diff --git a/demo/text_classification/run_predict.sh b/demo/text_classification/run_predict.sh
index c4c0d6002a932fd48d074672624640df131d961e..f09aa160c61a3354cde1ab81aa7cb14e0cbcbed0 100644
--- a/demo/text_classification/run_predict.sh
+++ b/demo/text_classification/run_predict.sh
@@ -7,4 +7,4 @@ python -u predict.py \
     --checkpoint_dir=$CKPT_DIR \
     --max_seq_len=128 \
     --use_gpu=True \
-    --batch_size=24
+    --batch_size=1
diff --git a/demo/text_generation/predict.py b/demo/text_generation/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..808d28c44a99db9d7fe8b3a3f3b91758358dc800
--- /dev/null
+++ b/demo/text_generation/predict.py
@@ -0,0 +1,84 @@
+#coding:utf-8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Prediction on text generation task."""
+
+import argparse
+import ast
+
+import paddlehub as hub
+
+# yapf: disable
+parser = argparse.ArgumentParser(__doc__)
+parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
+parser.add_argument("--batch_size", type=int, default=1, help="Total examples' number in batch for prediction.")
+parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest sequence.")
+parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether to use GPU for prediction, input should be True or False")
+parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether to use data parallel.")
+args = parser.parse_args()
+# yapf: enable.
+
+if __name__ == '__main__':
+    # Load PaddleHub ERNIE Tiny pretrained model
+    module = hub.Module(name="ernie_tiny")
+    inputs, outputs, program = module.context(
+        trainable=True, max_seq_len=args.max_seq_len)
+
+    # Download the dataset and get its label list and label num.
+    # If you just want the label information, you can omit the tokenizer
+    # parameter to avoid preprocessing the train set.
+    dataset = hub.dataset.Couplet()
+    num_classes = dataset.num_labels
+    label_list = dataset.get_labels()
+
+    # Setup RunConfig for PaddleHub Fine-tune API
+    config = hub.RunConfig(
+        use_data_parallel=args.use_data_parallel,
+        use_cuda=args.use_gpu,
+        batch_size=args.batch_size,
+        checkpoint_dir=args.checkpoint_dir,
+        strategy=hub.AdamWeightDecayStrategy())
+
+    # Construct transfer learning network.
+    # Use "pooled_output" for classification tasks on an entire sentence.
+    # Use "sequence_output" for token-level output.
+    pooled_output = outputs["pooled_output"]
+    sequence_output = outputs["sequence_output"]
+
+    # Define a text generation fine-tune task by PaddleHub's API
+    gen_task = hub.TextGenerationTask(
+        feature=pooled_output,
+        token_feature=sequence_output,
+        max_seq_len=args.max_seq_len,
+        num_classes=dataset.num_labels,
+        config=config,
+        metrics_choices=["bleu"])
+
+    # Data to be predicted
+    text_a = ["人增福寿年增岁", "风吹云乱天垂泪", "若有经心风过耳"]
+
+    # Add 0x02 between characters to match the format of the training data;
+    # otherwise, the length of the prediction results will not match the
+    # input string if the input string contains non-Chinese characters.
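+    # For example, "风吹云乱天垂泪" becomes
+    # "风\002吹\002云\002乱\002天\002垂\002泪".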
+    formatted_text_a = list(map("\002".join, text_a))
+
+    # Use the appropriate tokenizer to preprocess the data.
+    # For ernie_tiny, it uses BertTokenizer too.
+    tokenizer = hub.BertTokenizer(vocab_file=module.get_vocab_path())
+    encoded_data = [
+        tokenizer.encode(text=text, max_seq_len=args.max_seq_len)
+        for text in formatted_text_a
+    ]
+    print(
+        gen_task.predict(
+            data=encoded_data, label_list=label_list, accelerate_mode=False))
diff --git a/demo/text_generation/run_predict.sh b/demo/text_generation/run_predict.sh
new file mode 100644
index 0000000000000000000000000000000000000000..cd4325b8f4f742b518ce721112268bf8abd698f7
--- /dev/null
+++ b/demo/text_generation/run_predict.sh
@@ -0,0 +1,10 @@
+export FLAGS_eager_delete_tensor_gb=0.0
+export CUDA_VISIBLE_DEVICES=0
+
+CKPT_DIR="./ckpt_generation"
+
+python -u predict.py \
+    --checkpoint_dir=$CKPT_DIR \
+    --max_seq_len=128 \
+    --use_gpu=True \
+    --batch_size=1
diff --git a/demo/text_generation/run_text_gen.sh b/demo/text_generation/run_text_gen.sh
new file mode 100644
index 0000000000000000000000000000000000000000..cc7a09e930f528f3680496578f75e6afe2932518
--- /dev/null
+++ b/demo/text_generation/run_text_gen.sh
@@ -0,0 +1,12 @@
+export FLAGS_eager_delete_tensor_gb=0.0
+export CUDA_VISIBLE_DEVICES=0
+
+CKPT_DIR="./ckpt_generation"
+python -u text_gen.py \
+    --batch_size 16 \
+    --num_epoch 30 \
+    --checkpoint_dir $CKPT_DIR \
+    --max_seq_len 50 \
+    --learning_rate 5e-3 \
+    --cut_fraction 0.1 \
+    --use_data_parallel True
diff --git a/demo/text_generation/text_gen.py b/demo/text_generation/text_gen.py
new file mode 100644
index 0000000000000000000000000000000000000000..afaa70adaacd9627f10169c2a834d23d821b5700
--- /dev/null
+++ b/demo/text_generation/text_gen.py
@@ -0,0 +1,83 @@
+#coding:utf-8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning on text generation task."""
+
+import argparse
+import ast
+
+import paddlehub as hub
+
+# yapf: disable
+parser = argparse.ArgumentParser(__doc__)
+parser.add_argument("--num_epoch", type=int, default=3, help="Number of epochs for fine-tuning.")
+parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether to use GPU for fine-tuning, input should be True or False")
+parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
+parser.add_argument("--cut_fraction", type=float, default=0.1, help="Warmup proportion params for warmup strategy")
+parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
+parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest sequence.")
+parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number in batch for training.")
+parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether to use data parallel.")
+args = parser.parse_args()
+# yapf: enable.
+
+if __name__ == '__main__':
+
+    # Load PaddleHub ERNIE Tiny pretrained model
+    module = hub.Module(name="ernie_tiny")
+    inputs, outputs, program = module.context(
+        trainable=True, max_seq_len=args.max_seq_len)
+
+    # Use the appropriate tokenizer to preprocess the data set.
+    # For ernie_tiny, it uses BertTokenizer too.
+    tokenizer = hub.BertTokenizer(vocab_file=module.get_vocab_path())
+    dataset = hub.dataset.Couplet(
+        tokenizer=tokenizer, max_seq_len=args.max_seq_len)
+
+    # Construct transfer learning network.
+    # Use "pooled_output" for classification tasks on an entire sentence.
+    # Use "sequence_output" for token-level output.
+    pooled_output = outputs["pooled_output"]
+    sequence_output = outputs["sequence_output"]
+
+    # Select fine-tune strategy, setup config and fine-tune
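+    # ULMFiTStrategy combines a slanted triangular learning rate
+    # (cut_fraction) with discriminative fine-tuning (dis_params_layer)
+    # and gradual unfreezing (frz_params_layer); passing
+    # module.get_params_layer() applies both schemes to the layers of the
+    # pretrained module.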
+    strategy = hub.ULMFiTStrategy(
+        learning_rate=args.learning_rate,
+        optimizer_name="adam",
+        cut_fraction=args.cut_fraction,
+        dis_params_layer=module.get_params_layer(),
+        frz_params_layer=module.get_params_layer())
+
+    # Setup RunConfig for PaddleHub Fine-tune API
+    config = hub.RunConfig(
+        use_data_parallel=args.use_data_parallel,
+        use_cuda=args.use_gpu,
+        num_epoch=args.num_epoch,
+        batch_size=args.batch_size,
+        checkpoint_dir=args.checkpoint_dir,
+        strategy=strategy)
+
+    # Define a text generation fine-tune task by PaddleHub's API
+    gen_task = hub.TextGenerationTask(
+        dataset=dataset,
+        feature=pooled_output,
+        token_feature=sequence_output,
+        max_seq_len=args.max_seq_len,
+        num_classes=dataset.num_labels,
+        config=config,
+        metrics_choices=["bleu"])
+
+    # Fine-tune and evaluate by PaddleHub's API; it will finish training,
+    # evaluation, testing and model saving automatically.
+    gen_task.finetune_and_eval()
diff --git a/paddlehub/__init__.py b/paddlehub/__init__.py
index 6cd198fcc22e42323b007feb3dc51a999047db52..265db1ab0362a41ddd7ef2b568f6c5381e3aa54c 100644
--- a/paddlehub/__init__.py
+++ b/paddlehub/__init__.py
@@ -60,6 +60,7 @@ from .finetune.task import SequenceLabelTask
 from .finetune.task import MultiLabelClassifierTask
 from .finetune.task import RegressionTask
 from .finetune.task import ReadingComprehensionTask
+from .finetune.task import TextGenerationTask
 from .finetune.config import RunConfig
 from .finetune.strategy import AdamWeightDecayStrategy
 from .finetune.strategy import DefaultStrategy
diff --git a/paddlehub/common/paddle_helper.py b/paddlehub/common/paddle_helper.py
index 5515ea6329ed417b9c6152cc63532dbd888b9c54..ec7f609aa9b6a8507ea5a70b960249a250d41ea7 100644
--- a/paddlehub/common/paddle_helper.py
+++ b/paddlehub/common/paddle_helper.py
@@ -52,13 +52,19 @@ def get_variable_info(var):
 
     var_info = {
         'name': var.name,
-        'dtype': convert_dtype_to_string(var.dtype),
-        'lod_level': var.lod_level,
-        'shape': var.shape,
         'stop_gradient': var.stop_gradient,
         'is_data': var.is_data,
-        'error_clip': var.error_clip
+        'error_clip': var.error_clip,
+        'type': var.type
     }
+
+    try:
+        var_info['dtype'] = convert_dtype_to_string(var.dtype)
+        var_info['lod_level'] = var.lod_level
+        var_info['shape'] = var.shape
+    except Exception:
+        pass
+
     if isinstance(var, fluid.framework.Parameter):
         var_info['trainable'] = var.trainable
         var_info['optimize_attr'] = var.optimize_attr
@@ -153,17 +159,34 @@ def _copy_vars_and_ops_in_blocks(from_block, to_block):
         to_block.create_var(**var_info)
 
     for op in from_block.ops:
+        all_attrs = op.all_attrs()
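+        # Control flow ops (e.g. while, conditional_block) carry a
+        # 'sub_block' attribute; deep-copying it directly would not recreate
+        # the nested block in the target program, so create a fresh block
+        # and copy its vars and ops recursively.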
+        if 'sub_block' in all_attrs:
+            _sub_block = to_block.program._create_block()
+            _copy_vars_and_ops_in_blocks(all_attrs['sub_block'], _sub_block)
+            to_block.program._rollback()
+            new_attrs = {'sub_block': _sub_block}
+            for key, value in all_attrs.items():
+                if key == 'sub_block':
+                    continue
+                new_attrs[key] = copy.deepcopy(value)
+        else:
+            new_attrs = copy.deepcopy(all_attrs)
+
         op_info = {
             'type': op.type,
             'inputs': {
-                input: [to_block.var(var) for var in op.input(input)]
+                input:
+                [to_block._find_var_recursive(var) for var in op.input(input)]
                 for input in op.input_names
             },
             'outputs': {
-                output: [to_block.var(var) for var in op.output(output)]
+                output: [
+                    to_block._find_var_recursive(var)
+                    for var in op.output(output)
+                ]
                 for output in op.output_names
             },
-            'attrs': copy.deepcopy(op.all_attrs())
+            'attrs': new_attrs
         }
         to_block.append_op(**op_info)
diff --git a/paddlehub/dataset/__init__.py b/paddlehub/dataset/__init__.py
index 04d823aded9a3946b5ef0fec4c808763ba568fe4..e1b1c2a266dbd0bcf1e620e97ebc1ad593d00203 100644
--- a/paddlehub/dataset/__init__.py
+++ b/paddlehub/dataset/__init__.py
@@ -30,6 +30,7 @@ from .cmrc2018 import CMRC2018
 from .bq import BQ
 from .iflytek import IFLYTEK
 from .thucnews import THUCNEWS
+from .couplet import Couplet
 
 # CV Dataset
 from .dogcat import DogCatDataset as DogCat
diff --git a/paddlehub/dataset/base_nlp_dataset.py b/paddlehub/dataset/base_nlp_dataset.py
index ddb47c540492d12cbc3be0b698c3cccbfda12a49..b4fffbd410020951d3b62fa7fdff0cc9bcafc770 100644
--- a/paddlehub/dataset/base_nlp_dataset.py
+++ b/paddlehub/dataset/base_nlp_dataset.py
@@ -68,7 +68,8 @@ class BaseNLPDataset(BaseDataset):
         if not self.tokenizer or not examples:
             return []
         logger.info("Processing the train set...")
-        self._train_records = self._convert_examples_to_records(examples)
+        self._train_records = self._convert_examples_to_records(
+            examples, phase="train")
         return self._train_records
 
     @property
@@ -78,7 +79,8 @@ class BaseNLPDataset(BaseDataset):
         if not self.tokenizer or not examples:
             return []
         logger.info("Processing the dev set...")
-        self._dev_records = self._convert_examples_to_records(examples)
+        self._dev_records = self._convert_examples_to_records(
+            examples, phase="dev")
         return self._dev_records
 
     @property
@@ -88,7 +90,8 @@ class BaseNLPDataset(BaseDataset):
         if not self.tokenizer or not examples:
             return []
         logger.info("Processing the test set...")
-        self._test_records = self._convert_examples_to_records(examples)
+        self._test_records = self._convert_examples_to_records(
+            examples, phase="test")
         return self._test_records
 
     @property
@@ -98,7 +101,8 @@ class BaseNLPDataset(BaseDataset):
         if not self.tokenizer or not examples:
             return []
         logger.info("Processing the predict set...")
-        self._predict_records = self._convert_examples_to_records(examples)
+        self._predict_records = self._convert_examples_to_records(
+            examples, phase="predict")
         return self._predict_records
 
     def _read_file(self, input_file, phase=None):
@@ -148,12 +152,14 @@ class BaseNLPDataset(BaseDataset):
             examples.append(example)
         return examples
 
-    def _convert_examples_to_records(self, examples):
+    def _convert_examples_to_records(self, examples, phase):
         """
         Returns a list[dict] including all the input information what the model need.
 
         Args:
             examples (list): the data example, returned by _read_file.
+            phase (str): the processing phase, can be "train", "dev", "test" or "predict".
 
         Returns:
             a list with all the examples record.
@@ -166,8 +172,8 @@ class BaseNLPDataset(BaseDataset):
                 text_pair=example.text_b,
                 max_seq_len=self.max_seq_len)
             if example.label:
-                record["label"] = self.label_list.index(
-                    example.label) if self.label_list else float(example.label)
+                record["label"] = self.label_index[
+                    example.label] if self.label_list else float(example.label)
             records.append(record)
         return records
@@ -281,12 +287,14 @@ class BaseNLPDataset(BaseDataset):
 
 
 class TextClassificationDataset(BaseNLPDataset):
-    def _convert_examples_to_records(self, examples):
+    def _convert_examples_to_records(self, examples, phase):
         """
         Returns a list[dict] including all the input information what the model need.
 
         Args:
             examples (list): the data example, returned by _read_file.
+            phase (str): the processing phase, can be "train", "dev", "test" or "predict".
 
         Returns:
             a list with all the examples record.
@@ -299,18 +307,19 @@ class TextClassificationDataset(BaseNLPDataset):
                 text_pair=example.text_b,
                 max_seq_len=self.max_seq_len)
             if example.label:
-                record["label"] = self.label_list.index(example.label)
+                record["label"] = self.label_index[example.label]
             records.append(record)
         return records
 
 
 class RegressionDataset(BaseNLPDataset):
-    def _convert_examples_to_records(self, examples):
+    def _convert_examples_to_records(self, examples, phase):
         """
         Returns a list[dict] including all the input information what the model need.
 
         Args:
             examples (list): the data example, returned by _read_file.
+            phase (str): the processing phase, can be "train", "dev", "test" or "predict".
 
         Returns:
             a list with all the examples record.
@@ -328,6 +337,80 @@ class RegressionDataset(BaseNLPDataset):
         return records
 
 
+class GenerationDataset(BaseNLPDataset):
+    def __init__(self,
+                 base_path,
+                 train_file=None,
+                 dev_file=None,
+                 test_file=None,
+                 predict_file=None,
+                 label_file=None,
+                 label_list=None,
+                 train_file_with_header=False,
+                 dev_file_with_header=False,
+                 test_file_with_header=False,
+                 predict_file_with_header=False,
+                 tokenizer=None,
+                 max_seq_len=128,
+                 split_char="\002",
+                 start_token="<s>",
+                 end_token="</s>",
+                 unk_token="<unk>"):
+        self.split_char = split_char
+        self.start_token = start_token
+        self.end_token = end_token
+        self.unk_token = unk_token
+        super(GenerationDataset, self).__init__(
+            base_path=base_path,
+            train_file=train_file,
+            dev_file=dev_file,
+            test_file=test_file,
+            predict_file=predict_file,
+            label_file=label_file,
+            label_list=label_list,
+            train_file_with_header=train_file_with_header,
+            dev_file_with_header=dev_file_with_header,
+            test_file_with_header=test_file_with_header,
+            predict_file_with_header=predict_file_with_header,
+            tokenizer=tokenizer,
+            max_seq_len=max_seq_len)
+
+    def _convert_examples_to_records(self, examples, phase):
+        """
+        Returns a list[dict] including all the input information that the model needs.
+
+        Args:
+            examples (list): the data example, returned by _read_file.
+            phase (str): the processing phase, can be "train", "dev", "test" or "predict".
+
+        Returns:
+            a list with all the examples record.
+        """
+        records = []
+        for example in examples:
+            record = self.tokenizer.encode(
+                text=example.text_a.split(self.split_char),
+                text_pair=example.text_b.split(self.split_char)
+                if example.text_b else None,
+                max_seq_len=self.max_seq_len)
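+            # Teacher forcing: dec_input is the target sequence prefixed
+            # with start_token, and label is the same sequence shifted one
+            # step left so it ends with end_token; both are padded with
+            # end_token up to max_seq_len.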
+            if example.label:
+                expand_label = [self.start_token] + example.label.split(
+                    self.split_char)[:self.max_seq_len - 2] + [self.end_token]
+                expand_label_id = [
+                    self.label_index.get(label,
+                                         self.label_index[self.unk_token])
+                    for label in expand_label
+                ]
+                record["label"] = expand_label_id[1:] + [
+                    self.label_index[self.end_token]
+                ] * (self.max_seq_len - len(expand_label) + 1)
+                record["dec_input"] = expand_label_id[:-1] + [
+                    self.label_index[self.end_token]
+                ] * (self.max_seq_len - len(expand_label) + 1)
+            records.append(record)
+        return records
+
+
 class SeqLabelingDataset(BaseNLPDataset):
     def __init__(self,
                  base_path,
@@ -363,12 +446,13 @@ class SeqLabelingDataset(BaseNLPDataset):
             tokenizer=tokenizer,
             max_seq_len=max_seq_len)
 
-    def _convert_examples_to_records(self, examples):
+    def _convert_examples_to_records(self, examples, phase):
         """
         Returns a list[dict] including all the input information what the model need.
 
         Args:
             examples (list): the data examples, returned by _read_file.
+            phase (str): the processing phase, can be "train", "dev", "test" or "predict".
 
         Returns:
             a list with all the examples record.
@@ -389,11 +473,11 @@ class SeqLabelingDataset(BaseNLPDataset):
                     if tokens_index < len(
                             tokens) and token == tokens[tokens_index]:
                         record["label"].append(
-                            self.label_list.index(labels[tokens_index]))
+                            self.label_index[labels[tokens_index]])
                         tokens_index += 1
                     else:
                         record["label"].append(
-                            self.label_list.index(self.no_entity_label))
+                            self.label_index[self.no_entity_label])
             records.append(record)
         return records
 
@@ -435,13 +519,13 @@ class SeqLabelingDataset(BaseNLPDataset):
 
 
 class MultiLabelDataset(BaseNLPDataset):
-    def _convert_examples_to_records(self, examples):
+    def _convert_examples_to_records(self, examples, phase):
         """
         Returns a list[dict] including all the input information what the model need.
 
         Args:
             examples (list): the data examples, returned by _read_file.
-            max_seq_len (int): padding to the max sequence length.
+            phase (str): the processing phase, can be "train", "dev", "test" or "predict".
 
         Returns:
             a list with all the examples record.
@@ -631,7 +715,16 @@ class MRCDataset(BaseNLPDataset):
         return special_tokens_num, special_tokens_num_before_doc
 
     def _convert_examples_to_records_and_features(self, examples, phase):
-        """Loads a data file into a list of `InputBatch`s."""
+        """
+        Returns a list[dict] including all the input information that the model needs.
+
+        Args:
+            examples (list): the data examples, returned by _read_file.
+            phase (str): the processing phase, can be "train", "dev", "test" or "predict".
+
+        Returns:
+            a list with all the examples record.
+        """
         features = []
         records = []
         unique_id = 1000000000
diff --git a/paddlehub/dataset/couplet.py b/paddlehub/dataset/couplet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7dc34aa48f8560a4c4d2de85d32622f8aa213d8
--- /dev/null
+++ b/paddlehub/dataset/couplet.py
@@ -0,0 +1,84 @@
+# coding:utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import codecs
+import csv
+
+from paddlehub.dataset import InputExample
+from paddlehub.common.dir import DATA_HOME
+from paddlehub.dataset.base_nlp_dataset import GenerationDataset
+
+_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/couplet.tar.gz"
+
+
+class Couplet(GenerationDataset):
+    """
+    An open source couplet dataset, see https://github.com/v-zich/couplet-clean-dataset for details.
+    """
+
+    def __init__(self, tokenizer=None, max_seq_len=None):
+        dataset_dir = os.path.join(DATA_HOME, "couplet")
+        base_path = self._download_dataset(dataset_dir, url=_DATA_URL)
+        with open(os.path.join(dataset_dir, "vocab.txt")) as vocab_file:
+            label_list = [line.strip() for line in vocab_file.readlines()]
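+        # vocab.txt doubles as the label list here: for generation, every
+        # target character is a "class", and the file is expected to contain
+        # the <s>, </s> and <unk> tokens that GenerationDataset relies on.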
+        super(Couplet, self).__init__(
+            base_path=base_path,
+            train_file="train.tsv",
+            dev_file="dev.tsv",
+            test_file="test.tsv",
+            label_list=label_list,
+            tokenizer=tokenizer,
+            max_seq_len=max_seq_len)
+
+    def _read_file(self, input_file, phase=None):
+        """Reads a tab separated value file."""
+        with codecs.open(input_file, "r", encoding="UTF-8") as f:
+            reader = csv.reader(f, delimiter="\t", quotechar=None)
+            examples = []
+            seq_id = 0
+            for line in reader:
+                example = InputExample(
+                    guid=seq_id, label=line[1], text_a=line[0])
+                seq_id += 1
+                examples.append(example)
+
+            return examples
+
+
+if __name__ == "__main__":
+    from paddlehub.tokenizer.bert_tokenizer import BertTokenizer
+    tokenizer = BertTokenizer(vocab_file='vocab.txt')
+    ds = Couplet(tokenizer=tokenizer, max_seq_len=30)
+    print("first 10 train")
+    for e in ds.get_train_examples()[:10]:
+        print("guid: {}\ttext_a: {}\ttext_b: {}\tlabel: {}".format(
+            e.guid, e.text_a, e.text_b, e.label))
+    print("first 10 dev")
+    for e in ds.get_dev_examples()[:10]:
+        print("guid: {}\ttext_a: {}\ttext_b: {}\tlabel: {}".format(
+            e.guid, e.text_a, e.text_b, e.label))
+    print("first 10 test")
+    for e in ds.get_test_examples()[:10]:
+        print("guid: {}\ttext_a: {}\ttext_b: {}\tlabel: {}".format(
+            e.guid, e.text_a, e.text_b, e.label))
+    print(ds)
+    print("first 10 dev records")
+    for e in ds.get_dev_records()[:10]:
+        print(e)
diff --git a/paddlehub/dataset/dataset.py b/paddlehub/dataset/dataset.py
index ac2fda31cf429b7bc6e7aa049020d59380912cb7..9b202bd2007ceea97b0c8b75f5bd56ade065a28e 100644
--- a/paddlehub/dataset/dataset.py
+++ b/paddlehub/dataset/dataset.py
@@ -106,6 +106,9 @@ class BaseDataset(object):
                 "As label_list has been assigned, label_file is noneffective"
             )
 
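+        # Cache a label -> index mapping (e.g. ["a", "b"] -> {"a": 0, "b": 1})
+        # so record conversion can use dict lookups instead of repeated
+        # list.index() calls.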
+        self.label_index = dict(
+            zip(self.label_list, range(len(self.label_list))))
+
     def get_train_examples(self):
         return self.train_examples
diff --git a/paddlehub/finetune/evaluate.py b/paddlehub/finetune/evaluate.py
index 5f7d0457552d453c0222ed854609530dfaf14d86..417b4b745eb32578dad84c06f8f414e82ef7b3fc 100644
--- a/paddlehub/finetune/evaluate.py
+++ b/paddlehub/finetune/evaluate.py
@@ -12,10 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+import collections
+import math
 
 import numpy as np
@@ -153,7 +151,7 @@ def matthews_corrcoef(preds, labels):
 
 
 def recall_nk(data, n, k, m):
-    '''
+    """
     This metric can be used to evaluate whether the model can find the correct response B
     for question A
     Note: Only applies to each question A only has one correct response B1.
@@ -170,7 +168,7 @@
         m: int. For every m examples, there's going to be a positive sample.
             eg. data= [A1,B1,1], [A1,B2,0], [A1,B3,0], [A2,B1,1], [A2,B2,0], [A2,B3,0]
             For every 3 examples, there will be one positive sample. so m=3, and n can be 1,2 or 3.
-    '''
+    """
 
     def get_p_at_n_in_m(data, n, k, ind):
         """
@@ -194,3 +192,94 @@
         correct_num += get_p_at_n_in_m(data, n, k, ind)
 
     return correct_num / length
+
+
+def _get_ngrams(segment, max_order):
+    """
+    Extracts all n-grams up to a given maximum order from an input segment.
+
+    Args:
+        segment: text segment from which n-grams will be extracted.
+        max_order: maximum length in tokens of the n-grams returned by this
+            method.
+
+    Returns:
+        The Counter containing all n-grams up to max_order in segment
+        with a count of how many times each n-gram occurred.
+    """
+    ngram_counts = collections.Counter()
+    for order in range(1, max_order + 1):
+        for i in range(0, len(segment) - order + 1):
+            ngram = tuple(segment[i:i + order])
+            ngram_counts[ngram] += 1
+    return ngram_counts
+
+
+def compute_bleu(reference_corpus,
+                 translation_corpus,
+                 max_order=4,
+                 smooth=False):
+    """
+    Computes BLEU score of translated segments against their references.
+
+    Args:
+        reference_corpus: list of references, one per translation. Each
+            reference should be tokenized into a list of tokens.
+        translation_corpus: list of translations to score. Each translation
+            should be tokenized into a list of tokens.
+        max_order: Maximum n-gram order to use when computing BLEU score.
+        smooth: Whether or not to apply Lin et al. 2004 smoothing.
+
+    Returns:
+        6-tuple with the BLEU score, n-gram precisions, brevity penalty,
+        length ratio, translation length and reference length.
+    """
+    matches_by_order = [0] * max_order
+    possible_matches_by_order = [0] * max_order
+    reference_length = 0
+    translation_length = 0
+    for (reference, translation) in zip(reference_corpus, translation_corpus):
+        reference_length += len(reference)
+        translation_length += len(translation)
+
+        merged_ref_ngram_counts = collections.Counter()
+        merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
+        translation_ngram_counts = _get_ngrams(translation, max_order)
+        overlap = translation_ngram_counts & merged_ref_ngram_counts
+        for ngram in overlap:
+            matches_by_order[len(ngram) - 1] += overlap[ngram]
+        for order in range(1, max_order + 1):
+            possible_matches = len(translation) - order + 1
+            if possible_matches > 0:
+                possible_matches_by_order[order - 1] += possible_matches
+
+    precisions = [0] * max_order
+    for i in range(0, max_order):
+        if smooth:
+            precisions[i] = ((matches_by_order[i] + 1.) /
+                             (possible_matches_by_order[i] + 1.))
+        else:
+            if possible_matches_by_order[i] > 0:
+                precisions[i] = (
+                    float(matches_by_order[i]) / possible_matches_by_order[i])
+            else:
+                precisions[i] = 0.0
+
+    if min(precisions) > 0:
+        p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
+        geo_mean = math.exp(p_log_sum)
+    else:
+        geo_mean = 0
+
+    ratio = float(translation_length) / reference_length
+
+    if ratio > 1.0:
+        bp = 1.
+    elif ratio > 0.0:
+        bp = math.exp(1 - 1. / ratio)
+    else:
+        bp = 0
+
+    bleu = geo_mean * bp
+
+    return (bleu, precisions, bp, ratio, translation_length, reference_length)
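+
+
+# Quick sanity check (illustrative only, not part of the public API):
+# an identical single reference and candidate pair scores BLEU = 1.0.
+#
+#     bleu = compute_bleu([["晚", "风"]], [["晚", "风"]], max_order=1)[0]
+#     assert abs(bleu - 1.0) < 1e-6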
diff --git a/paddlehub/finetune/task/__init__.py b/paddlehub/finetune/task/__init__.py
index e4457699bfaa5fe0156d64a22304952821c0a088..7a2031c304880d46c4b259379ecd7120359f1ea1 100644
--- a/paddlehub/finetune/task/__init__.py
+++ b/paddlehub/finetune/task/__init__.py
@@ -19,3 +19,4 @@ from .detection_task import DetectionTask
 from .reading_comprehension_task import ReadingComprehensionTask
 from .regression_task import RegressionTask
 from .sequence_task import SequenceLabelTask
+from .generation_task import TextGenerationTask
diff --git a/paddlehub/finetune/task/base_task.py b/paddlehub/finetune/task/base_task.py
index 287a25d70162a365a872849ebd3caffd520b322d..a6fa1781fb0446704bc167d944b7f60a4eefe2cd 100644
--- a/paddlehub/finetune/task/base_task.py
+++ b/paddlehub/finetune/task/base_task.py
@@ -344,6 +344,8 @@ class BaseTask(object):
         self.dataset = dataset
         if dataset:
             self._label_list = dataset.get_labels()
+        else:
+            self._label_list = None
         # Compatible code for usage deprecated in paddlehub v1.8.
         self._base_data_reader = data_reader
         self._base_feed_list = feed_list
@@ -1099,23 +1101,29 @@
                 or a plaintext string list when the task is initialized with data_reader param (deprecated in paddlehub v1.8).
             label_list (list): the label list, used to proprocess the output.
             load_best_model (bool): load the best model or not
-            return_result (bool): return a readable result or just the raw run result. Always True when the task is not initialized with data_reader param.
+            return_result (bool): return a readable result or just the raw run result. Always True when the task is initialized with the dataset param rather than the deprecated data_reader param.
             accelerate_mode (bool): use high-performance predictor or not
 
         Returns:
             RunState: the running result of predict phase
         """
-        if accelerate_mode and isinstance(self._base_data_reader,
-                                          hub.reader.LACClassifyReader):
-            logger.warning(
-                "LACClassifyReader does not support predictor, the accelerate_mode is closed now."
-            )
-            accelerate_mode = False
+        if accelerate_mode:
+            if isinstance(self._base_data_reader, hub.reader.LACClassifyReader):
+                logger.warning(
+                    "LACClassifyReader does not support predictor, the accelerate_mode is closed now."
+                )
+                accelerate_mode = False
+            elif isinstance(self, hub.TextGenerationTask):
+                logger.warning(
+                    "TextGenerationTask does not support predictor, the accelerate_mode is closed now."
+                )
+                accelerate_mode = False
         self.accelerate_mode = accelerate_mode
 
         with self.phase_guard(phase="predict"):
             self._predict_data = data
-            self._label_list = label_list
+            if label_list:
+                self._label_list = label_list
             self._predict_start_event()
 
             if load_best_model:
diff --git a/paddlehub/finetune/task/generation_task.py b/paddlehub/finetune/task/generation_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..efed83868ba90340df2a40332b1bad2128afb20c
--- /dev/null
+++ b/paddlehub/finetune/task/generation_task.py
@@ -0,0 +1,342 @@
+#coding:utf-8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+from collections import OrderedDict
+
+import numpy as np
+import paddle.fluid as fluid
+from paddle.fluid import ParamAttr
+from paddle.fluid.layers import RNNCell, LSTMCell, rnn, BeamSearchDecoder, dynamic_decode
+from paddlehub.finetune.evaluate import compute_bleu
+
+from .base_task import BaseTask
+
+
+class AttentionDecoderCell(RNNCell):
+    def __init__(self, num_layers, hidden_size, dropout_prob=0.,
+                 init_scale=0.1):
+        self.num_layers = num_layers
+        self.hidden_size = hidden_size
+        self.dropout_prob = dropout_prob
+        self.lstm_cells = []
+        self.init_scale = init_scale
+        param_attr = ParamAttr(
+            initializer=fluid.initializer.UniformInitializer(
+                low=-init_scale, high=init_scale))
+        bias_attr = ParamAttr(initializer=fluid.initializer.Constant(0.0))
+        for i in range(num_layers):
+            self.lstm_cells.append(LSTMCell(hidden_size, param_attr, bias_attr))
+
+    def attention(self, query, enc_output, mask=None):
+        query = fluid.layers.unsqueeze(query, [1])
+        memory = fluid.layers.fc(
+            enc_output,
+            self.hidden_size,
+            num_flatten_dims=2,
+            param_attr=ParamAttr(
+                name="dec_memory_w",
+                initializer=fluid.initializer.UniformInitializer(
+                    low=-self.init_scale, high=self.init_scale)))
+        attn = fluid.layers.matmul(query, memory, transpose_y=True)
+
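+        # mask is (sequence_mask - 1): real tokens are 0 and padded
+        # positions are -1, so adding mask * 1e9 drives padded attention
+        # logits to -1e9 and softmax gives them near-zero weight.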
+        if mask is not None:
+            attn = fluid.layers.transpose(attn, [1, 0, 2])
+            attn = fluid.layers.elementwise_add(attn, mask * 1000000000, -1)
+            attn = fluid.layers.transpose(attn, [1, 0, 2])
+        weight = fluid.layers.softmax(attn)
+        weight_memory = fluid.layers.matmul(weight, memory)
+
+        return weight_memory
+
+    def call(self, step_input, states, enc_output, enc_padding_mask=None):
+        lstm_states, input_feed = states
+        new_lstm_states = []
+        step_input = fluid.layers.concat([step_input, input_feed], 1)
+        for i in range(self.num_layers):
+            out, new_lstm_state = self.lstm_cells[i](step_input, lstm_states[i])
+            step_input = fluid.layers.dropout(
+                out,
+                self.dropout_prob,
+                dropout_implementation='upscale_in_train'
+            ) if self.dropout_prob > 0 else out
+            new_lstm_states.append(new_lstm_state)
+        dec_att = self.attention(step_input, enc_output, enc_padding_mask)
+        dec_att = fluid.layers.squeeze(dec_att, [1])
+        concat_att_out = fluid.layers.concat([dec_att, step_input], 1)
+        out = fluid.layers.fc(
+            concat_att_out,
+            self.hidden_size,
+            param_attr=ParamAttr(
+                name="dec_out_w",
+                initializer=fluid.initializer.UniformInitializer(
+                    low=-self.init_scale, high=self.init_scale)))
+        return out, [new_lstm_states, out]
+
+
+class TextGenerationTask(BaseTask):
+    """
+    TextGenerationTask uses an RNN as the decoder and applies beam search at predict time.
+
+    Args:
+        feature(Variable): The sentence-level feature, shape as [-1, emb_size].
+        token_feature(Variable): The token-level feature, shape as [-1, seq_len, emb_size].
+        max_seq_len(int): the decoder max sequence length.
+        num_classes(int): total labels of the task.
+        dataset(GenerationDataset): the dataset containing training set, development set and so on. If you want to fine-tune the model, you should set it.
+            Otherwise, if you just want to use the model to predict, you can omit it. Default None
+        num_layers(int): the decoder rnn layers number. Default 1
+        hidden_size(int): the decoder rnn hidden size. Default 512
+        dropout(float): the decoder dropout rate. Default 0.
+        beam_size(int): the beam search size during predict phase. Default 10.
+        beam_max_step_num(int): the beam search max step number. Default 30.
+        start_token(str): the beam search start token. Default "<s>"
+        end_token(str): the beam search end token. Default "</s>"
+        startup_program(Program): the customized startup_program, default None
+        config(RunConfig): the config for the task, default None
+        metrics_choices(list): metrics used to the task, default ["bleu"]
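+
+        Example (mirroring demo/text_generation/text_gen.py in this PR):
+
+        .. code-block:: python
+
+            gen_task = hub.TextGenerationTask(
+                dataset=dataset,
+                feature=pooled_output,
+                token_feature=sequence_output,
+                max_seq_len=args.max_seq_len,
+                num_classes=dataset.num_labels,
+                config=config,
+                metrics_choices=["bleu"])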
+    """
+
+    def __init__(
+            self,
+            feature,
+            token_feature,
+            max_seq_len,
+            num_classes,
+            dataset=None,
+            num_layers=1,
+            hidden_size=512,
+            dropout=0.,
+            beam_size=10,
+            beam_max_step_num=30,
+            start_token="<s>",
+            end_token="</s>",
+            startup_program=None,
+            config=None,
+            metrics_choices="default",
+    ):
+        if metrics_choices == "default":
+            metrics_choices = ["bleu"]
+        main_program = feature.block.program
+        super(TextGenerationTask, self).__init__(
+            dataset=dataset,
+            main_program=main_program,
+            startup_program=startup_program,
+            config=config,
+            metrics_choices=metrics_choices)
+
+        self.num_layers = num_layers
+        self.hidden_size = hidden_size
+        self.dropout = dropout
+        self.token_feature = token_feature
+        self.feature = feature
+        self.max_seq_len = max_seq_len
+        self.num_classes = num_classes
+        self.beam_size = beam_size
+        self.beam_max_step_num = beam_max_step_num
+        self.start_token = start_token
+        self.end_token = end_token
+
+    def _add_label(self):
+        label = fluid.layers.data(
+            name="label", shape=[self.max_seq_len, 1], dtype='int64')
+        return [label]
+
+    def _build_net(self):
+        self.seq_len = fluid.layers.data(
+            name="seq_len", shape=[1], dtype='int64', lod_level=0)
+        self.seq_len_used = fluid.layers.squeeze(self.seq_len, axes=[1])
+        src_mask = fluid.layers.sequence_mask(
+            self.seq_len_used, maxlen=self.max_seq_len, dtype='float32')
+        enc_padding_mask = (src_mask - 1.0)
+
+        # Define decoder and initialize it.
+        dec_cell = AttentionDecoderCell(self.num_layers, self.hidden_size,
+                                        self.dropout)
+        enc_last_step = fluid.layers.slice(
+            self.token_feature,
+            axes=[1],
+            starts=[-1],
+            ends=[self.token_feature.shape[1] + 1])
+        dec_init_cell = fluid.layers.fc(
+            input=enc_last_step,
+            size=self.hidden_size,
+            num_flatten_dims=1,
+            param_attr=fluid.ParamAttr(
+                name="dec_init_cell_w",
+                initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
+            bias_attr=fluid.ParamAttr(
+                name="dec_init_cell_b",
+                initializer=fluid.initializer.Constant(0.)))
+        dec_init_hidden = fluid.layers.fc(
+            input=self.feature,
+            size=self.hidden_size,
+            num_flatten_dims=1,
+            param_attr=fluid.ParamAttr(
+                name="dec_init_hidden_w",
+                initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
+            bias_attr=fluid.ParamAttr(
+                name="dec_init_hidden_b",
+                initializer=fluid.initializer.Constant(0.)))
+        # TODO: maybe dec_init_hidden can use self.feature, and dec_init_cell can be get_initial_states
+        dec_initial_states = [
+            [[dec_init_hidden, dec_init_cell]] * self.num_layers,
+            dec_cell.get_initial_states(
+                batch_ref=self.token_feature, shape=[self.hidden_size])
+        ]
+        tar_vocab_size = len(self._label_list)
+        tar_embeder = lambda x: fluid.embedding(
+            input=x,
+            size=[tar_vocab_size, self.hidden_size],
+            dtype='float32',
+            is_sparse=False,
+            param_attr=fluid.ParamAttr(
+                name='target_embedding',
+                initializer=fluid.initializer.UniformInitializer(
+                    low=-0.1, high=0.1)))
+        start_token_id = self._label_list.index(self.start_token)
+        end_token_id = self._label_list.index(self.end_token)
+        if not self.is_predict_phase:
+            self.dec_input = fluid.layers.data(
+                name="dec_input", shape=[self.max_seq_len], dtype='int64')
+            tar_emb = tar_embeder(self.dec_input)
+            dec_output, _ = rnn(
+                cell=dec_cell,
+                inputs=tar_emb,
+                initial_states=dec_initial_states,
+                sequence_length=None,
+                enc_output=self.token_feature,
+                enc_padding_mask=enc_padding_mask)
+            self.logits = fluid.layers.fc(
+                dec_output,
+                size=tar_vocab_size,
+                num_flatten_dims=len(dec_output.shape) - 1,
+                param_attr=fluid.ParamAttr(
+                    name="output_w",
+                    initializer=fluid.initializer.UniformInitializer(
+                        low=-0.1, high=0.1)))
+            self.ret_infers = fluid.layers.reshape(
+                x=fluid.layers.argmax(self.logits, axis=2), shape=[-1, 1])
+            logits = self.logits
+            logits = fluid.layers.softmax(logits)
+            return [logits]
+        else:
+            output_layer = lambda x: fluid.layers.fc(
+                x,
+                size=tar_vocab_size,
+                num_flatten_dims=len(x.shape) - 1,
+                param_attr=fluid.ParamAttr(name="output_w"))
+            beam_search_decoder = BeamSearchDecoder(
+                dec_cell,
+                start_token_id,
+                end_token_id,
+                self.beam_size,
+                embedding_fn=tar_embeder,
+                output_fn=output_layer)
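+            # Beam search keeps beam_size hypotheses per example, so the
+            # encoder output and its padding mask are tiled from
+            # [batch, ...] to [batch * beam_size, ...] to stay aligned with
+            # the decoder states.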
+            enc_output = beam_search_decoder.tile_beam_merge_with_batch(
+                self.token_feature, self.beam_size)
+            enc_padding_mask = beam_search_decoder.tile_beam_merge_with_batch(
+                enc_padding_mask, self.beam_size)
+            self.ret_infers, _ = dynamic_decode(
+                beam_search_decoder,
+                inits=dec_initial_states,
+                max_step_num=self.beam_max_step_num,
+                enc_output=enc_output,
+                enc_padding_mask=enc_padding_mask)
+            return self.ret_infers
+
+    def _postprocessing(self, run_states):
+        results = []
+        for batch_states in run_states:
+            batch_results = batch_states.run_results
+            batch_infers = batch_results[0].astype(np.int32)
+            seq_lens = batch_results[1].reshape([-1]).astype(np.int32).tolist()
+            for i, sample_infers in enumerate(batch_infers):
+                beam_result = []
+                for beam_infer in sample_infers.T:
+                    seq_result = [
+                        self._label_list[infer]
+                        for infer in beam_infer.tolist()[:seq_lens[i] - 2]
+                    ]
+                    beam_result.append(seq_result)
+                results.append(beam_result)
+        return results
+
+    def _add_metrics(self):
+        self.ret_labels = fluid.layers.reshape(x=self.labels[0], shape=[-1, 1])
+        return [self.ret_labels, self.ret_infers, self.seq_len_used]
+
+    def _add_loss(self):
+        loss = fluid.layers.cross_entropy(
+            input=self.outputs[0], label=self.labels[0], soft_label=False)
+        loss = fluid.layers.unsqueeze(loss, axes=[2])
+        max_tar_seq_len = fluid.layers.shape(self.dec_input)[1]
+        tar_sequence_length = fluid.layers.elementwise_sub(
+            self.seq_len_used, fluid.layers.ones_like(self.seq_len_used))
+        tar_mask = fluid.layers.sequence_mask(
+            tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32')
+        loss = loss * tar_mask
+        loss = fluid.layers.reduce_mean(loss, dim=[0])
+        loss = fluid.layers.reduce_sum(loss)
+        return loss
+
+    @property
+    def fetch_list(self):
+        if self.is_train_phase or self.is_test_phase:
+            return [metric.name for metric in self.metrics] + [self.loss.name]
+        elif self.is_predict_phase:
+            return [self.ret_infers.name] + [self.seq_len_used.name]
+        return [output.name for output in self.outputs]
+
+    def _calculate_metrics(self, run_states):
+        loss_sum = 0
+        run_step = run_examples = 0
+        labels = []
+        results = []
+        for run_state in run_states:
+            loss_sum += np.mean(run_state.run_results[-1])
+            np_labels = run_state.run_results[0]
+            np_infers = run_state.run_results[1]
+            np_lens = run_state.run_results[2]
+            batch_size = len(np_lens)
+            max_len = len(np_labels) // batch_size
+            for i in range(batch_size):
+                label = [
+                    self.dataset.label_list[int(id)]
+                    for id in np_labels[i * max_len:i * max_len + np_lens[i] -
+                                        2]
+                ]  # -2 for CLS and SEP
+                result = [
+                    self.dataset.label_list[int(id)]
+                    for id in np_infers[i * max_len:i * max_len + np_lens[i] -
+                                        2]
+                ]
+                labels.append(label)
+                results.append(result)
+
+            run_examples += run_state.run_examples
+            run_step += run_state.run_step
+
+        run_time_used = time.time() - run_states[0].run_time_begin
+        run_speed = run_step / run_time_used
+        avg_loss = loss_sum / run_examples
+
+        # The first key will be used as main metrics to update the best model
+        scores = OrderedDict()
+        for metric in self.metrics_choices:
+            if metric == "bleu":
+                scores["bleu"] = compute_bleu(labels, results, max_order=1)[0]
+            else:
+                raise ValueError("Not Support Metric: \"%s\"" % metric)
+        return scores, avg_loss, run_speed