diff --git a/demo/ernie-classification/finetune_with_hub.py b/demo/ernie-classification/finetune_with_hub.py index 0e6a284cd698472500f89a861c56d53859940979..8d6737aead8db524824acf209cf569a6cf5eaa85 100644 --- a/demo/ernie-classification/finetune_with_hub.py +++ b/demo/ernie-classification/finetune_with_hub.py @@ -79,7 +79,7 @@ if __name__ == '__main__': label.name ] # Define a classfication finetune task by PaddleHub's API - cls_task = hub.append_mlp_classifier( + cls_task = hub.create_text_classification_task( pooled_output, label, num_classes=num_labels) # Finetune and evaluate by PaddleHub's API diff --git a/demo/ernie-seq-label/finetune_with_hub.py b/demo/ernie-seq-label/finetune_with_hub.py index dbab65c0744190134927367d18753604c7a54d85..0559be40558ec9c8dfda334753b3b806e259eced 100644 --- a/demo/ernie-seq-label/finetune_with_hub.py +++ b/demo/ernie-seq-label/finetune_with_hub.py @@ -82,13 +82,13 @@ if __name__ == '__main__': label.name, seq_len ] # Define a classfication finetune task by PaddleHub's API - seq_label_task = hub.append_sequence_labeler( + seq_label_task = hub.create_seq_labeling_task( feature=sequence_output, labels=label, seq_len=seq_len, num_classes=num_labels) - # Finetune and evaluate by PaddleHub's API + # Finetune and evaluate model by PaddleHub's API # will finish training, evaluation, testing, save model automatically hub.finetune_and_eval( task=seq_label_task, diff --git a/paddlehub/__init__.py b/paddlehub/__init__.py index 8fce7887c925ba2b09d84aa88e7db304a3dc7026..89fb3cb99b9683168c7526d0838b3eb2c01f569e 100644 --- a/paddlehub/__init__.py +++ b/paddlehub/__init__.py @@ -32,11 +32,12 @@ from .module.manager import default_module_manager from .io.type import DataType -from .finetune.network import append_mlp_classifier -from .finetune.network import append_sequence_labeler +from .finetune.task import Task +from .finetune.task import create_seq_labeling_task +from .finetune.task import create_text_classification_task +from .finetune.task import create_img_classification_task from .finetune.finetune import finetune_and_eval from .finetune.config import RunConfig -from .finetune.task import Task from .finetune.strategy import BERTFinetuneStrategy from .finetune.strategy import DefaultStrategy diff --git a/paddlehub/finetune/network.py b/paddlehub/finetune/network.py deleted file mode 100644 index 5513063f046a987931411344651596538c10bcc7..0000000000000000000000000000000000000000 --- a/paddlehub/finetune/network.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License" -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import collections -import time -import multiprocessing - -import numpy as np -import paddle.fluid as fluid - -from paddlehub.finetune.task import Task - -__all__ = ['append_mlp_classifier'] - - -def append_mlp_classifier(feature, label, num_classes=2, hidden_units=None): - """ - Append a multi-layer perceptron classifier for binary classification base - on input feature - """ - cls_feats = fluid.layers.dropout( - x=feature, dropout_prob=0.1, dropout_implementation="upscale_in_train") - - # append fully connected layer according to hidden_units - if hidden_units is not None: - for n_hidden in hidden_units: - cls_feats = fluid.layers.fc(input=cls_feats, size=n_hidden) - - logits = fluid.layers.fc( - input=cls_feats, - size=num_classes, - param_attr=fluid.ParamAttr( - name="cls_out_w", - initializer=fluid.initializer.TruncatedNormal(scale=0.02)), - bias_attr=fluid.ParamAttr( - name="cls_out_b", initializer=fluid.initializer.Constant(0.))) - - ce_loss, probs = fluid.layers.softmax_with_cross_entropy( - logits=logits, label=label, return_softmax=True) - loss = fluid.layers.mean(x=ce_loss) - - num_example = fluid.layers.create_tensor(dtype='int64') - accuracy = fluid.layers.accuracy( - input=probs, label=label, total=num_example) - - graph_var_dict = { - "loss": loss, - "probs": probs, - "accuracy": accuracy, - "num_example": num_example - } - - task = Task("text_classification", graph_var_dict, - fluid.default_main_program(), fluid.default_startup_program()) - - return task - - -def append_mlp_multi_classifier(feature, - label, - num_classes, - hidden_units=None, - act=None): - pass - - -def append_sequence_labeler(feature, labels, seq_len, num_classes=None): - logits = fluid.layers.fc( - input=feature, - size=num_classes, - num_flatten_dims=2, - param_attr=fluid.ParamAttr( - name="cls_seq_label_out_w", - initializer=fluid.initializer.TruncatedNormal(scale=0.02)), - bias_attr=fluid.ParamAttr( - name="cls_seq_label_out_b", - initializer=fluid.initializer.Constant(0.))) - - ret_labels = fluid.layers.reshape(x=labels, shape=[-1, 1]) - ret_infers = fluid.layers.reshape( - x=fluid.layers.argmax(logits, axis=2), shape=[-1, 1]) - - labels = fluid.layers.flatten(labels, axis=2) - ce_loss, probs = fluid.layers.softmax_with_cross_entropy( - logits=fluid.layers.flatten(logits, axis=2), - label=labels, - return_softmax=True) - loss = fluid.layers.mean(x=ce_loss) - # accuracy = fluid.layers.accuracy( - # input=probs, label=labels, total=num_example) - - graph_var_dict = { - "loss": loss, - "probs": probs, - "labels": ret_labels, - "infers": ret_infers, - "seq_len": seq_len - } - - task = Task("sequence_labeling", graph_var_dict, - fluid.default_main_program(), fluid.default_startup_program()) - - return task diff --git a/paddlehub/finetune/task.py b/paddlehub/finetune/task.py index 24d27c52f96b30465a5b9b8b2d16f41a6bef2ace..3603e6bcd338fdb1115617e24b512bedd931cc0c 100644 --- a/paddlehub/finetune/task.py +++ b/paddlehub/finetune/task.py @@ -22,6 +22,11 @@ import paddle.fluid as fluid class Task(object): + """ + A simple transfer learning task definition, + including Paddle's main_program, startup_program and inference program + """ + def __init__(self, task_type, graph_var_dict, main_program, startup_program): self.task_type = task_type @@ -51,3 +56,130 @@ class Task(object): metric_variable_names.append(var_name) return metric_variable_names + + +def create_text_classification_task(feature, + label, + num_classes, + hidden_units=None): + """ + Append a multi-layer perceptron classifier for 
binary classification base + on input feature + """ + cls_feats = fluid.layers.dropout( + x=feature, dropout_prob=0.1, dropout_implementation="upscale_in_train") + + # append fully connected layer according to hidden_units + if hidden_units is not None: + for n_hidden in hidden_units: + cls_feats = fluid.layers.fc(input=cls_feats, size=n_hidden) + + logits = fluid.layers.fc( + input=cls_feats, + size=num_classes, + param_attr=fluid.ParamAttr( + name="cls_out_w", + initializer=fluid.initializer.TruncatedNormal(scale=0.02)), + bias_attr=fluid.ParamAttr( + name="cls_out_b", initializer=fluid.initializer.Constant(0.))) + + ce_loss, probs = fluid.layers.softmax_with_cross_entropy( + logits=logits, label=label, return_softmax=True) + loss = fluid.layers.mean(x=ce_loss) + + num_example = fluid.layers.create_tensor(dtype='int64') + accuracy = fluid.layers.accuracy( + input=probs, label=label, total=num_example) + + graph_var_dict = { + "loss": loss, + "probs": probs, + "accuracy": accuracy, + "num_example": num_example + } + + task = Task("text_classification", graph_var_dict, + fluid.default_main_program(), fluid.default_startup_program()) + + return task + + +def create_img_classification_task(feature, + label, + num_classes, + hidden_units=None): + """ + Append a multi-layer perceptron classifier for binary classification base + on input feature + """ + cls_feats = feature + # append fully connected layer according to hidden_units + if hidden_units is not None: + for n_hidden in hidden_units: + cls_feats = fluid.layers.fc(input=cls_feats, size=n_hidden) + + logits = fluid.layers.fc( + input=cls_feats, + size=num_classes, + param_attr=fluid.ParamAttr( + name="cls_out_w", + initializer=fluid.initializer.TruncatedNormal(scale=0.02)), + bias_attr=fluid.ParamAttr( + name="cls_out_b", initializer=fluid.initializer.Constant(0.))) + + ce_loss, probs = fluid.layers.softmax_with_cross_entropy( + logits=logits, label=label, return_softmax=True) + loss = fluid.layers.mean(x=ce_loss) + + num_example = fluid.layers.create_tensor(dtype='int64') + accuracy = fluid.layers.accuracy( + input=probs, label=label, total=num_example) + + graph_var_dict = { + "loss": loss, + "probs": probs, + "accuracy": accuracy, + "num_example": num_example + } + + task = Task("text_classification", graph_var_dict, + fluid.default_main_program(), fluid.default_startup_program()) + + return task + + +def create_seq_labeling_task(feature, labels, seq_len, num_classes=None): + logits = fluid.layers.fc( + input=feature, + size=num_classes, + num_flatten_dims=2, + param_attr=fluid.ParamAttr( + name="cls_seq_label_out_w", + initializer=fluid.initializer.TruncatedNormal(scale=0.02)), + bias_attr=fluid.ParamAttr( + name="cls_seq_label_out_b", + initializer=fluid.initializer.Constant(0.))) + + ret_labels = fluid.layers.reshape(x=labels, shape=[-1, 1]) + ret_infers = fluid.layers.reshape( + x=fluid.layers.argmax(logits, axis=2), shape=[-1, 1]) + + labels = fluid.layers.flatten(labels, axis=2) + ce_loss, probs = fluid.layers.softmax_with_cross_entropy( + logits=fluid.layers.flatten(logits, axis=2), + label=labels, + return_softmax=True) + loss = fluid.layers.mean(x=ce_loss) + + graph_var_dict = { + "loss": loss, + "probs": probs, + "labels": ret_labels, + "infers": ret_infers, + "seq_len": seq_len + } + + task = Task("sequence_labeling", graph_var_dict, + fluid.default_main_program(), fluid.default_startup_program()) + + return task diff --git a/paddlehub/reader/task_reader.py b/paddlehub/reader/task_reader.py new file mode 100644 index 
0000000000000000000000000000000000000000..5f501cd721086c98ce41e27786fd549b15267c4d --- /dev/null +++ b/paddlehub/reader/task_reader.py @@ -0,0 +1,398 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import json +import numpy as np +from collections import namedtuple + +from paddlehub.reader import tokenization +from .batching import pad_batch_data + + +class BaseReader(object): + def __init__(self, + dataset, + vocab_path, + label_map_config=None, + max_seq_len=512, + do_lower_case=True, + in_tokens=False, + random_seed=None): + self.max_seq_len = max_seq_len + self.tokenizer = tokenization.FullTokenizer( + vocab_file=vocab_path, do_lower_case=do_lower_case) + self.vocab = self.tokenizer.vocab + self.dataset = dataset + self.pad_id = self.vocab["[PAD]"] + self.cls_id = self.vocab["[CLS]"] + self.sep_id = self.vocab["[SEP]"] + self.in_tokens = in_tokens + + np.random.seed(random_seed) + + self.label_map = self.dataset.get_label_map() + + self.current_example = 0 + self.current_epoch = 0 + self.num_examples = 0 + + # if label_map_config: + # with open(label_map_config) as f: + # self.label_map = json.load(f) + # else: + # self.label_map = None + + self.num_examples = {'train': -1, 'dev': -1, 'test': -1} + + def get_train_examples(self): + """Gets a collection of `InputExample`s for the train set.""" + return self.dataset.get_train_examples() + + def get_dev_examples(self): + """Gets a collection of `InputExample`s for the dev set.""" + return self.dataset.get_dev_examples() + + def get_val_examples(self): + """Gets a collection of `InputExample`s for the val set.""" + return self.dataset.get_val_examples() + + def get_test_examples(self): + """Gets a collection of `InputExample`s for prediction.""" + return self.dataset.get_test_examples() + + def get_labels(self): + """Gets the list of labels for this data set.""" + return self.dataset.get_labels() + + def get_train_progress(self): + """Gets progress for training phase.""" + return self.current_example, self.current_epoch + + def _truncate_seq_pair(self, tokens_a, tokens_b, max_length): + """Truncates a sequence pair in place to the maximum length.""" + + # This is a simple heuristic which will always truncate the longer sequence + # one token at a time. This makes more sense than truncating an equal percent + # of tokens from each, since if one sequence is very short then each token + # that's truncated likely contains more information than a longer sequence. 
+ while True: + total_length = len(tokens_a) + len(tokens_b) + if total_length <= max_length: + break + if len(tokens_a) > len(tokens_b): + tokens_a.pop() + else: + tokens_b.pop() + + def _convert_example_to_record(self, example, max_seq_length, tokenizer): + """Converts a single `Example` into a single `Record`.""" + + text_a = tokenization.convert_to_unicode(example.text_a) + tokens_a = tokenizer.tokenize(text_a) + tokens_b = None + if example.text_b is not None: + #if "text_b" in example._fields: + text_b = tokenization.convert_to_unicode(example.text_b) + tokens_b = tokenizer.tokenize(text_b) + + if tokens_b: + # Modifies `tokens_a` and `tokens_b` in place so that the total + # length is less than the specified length. + # Account for [CLS], [SEP], [SEP] with "- 3" + self._truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) + else: + # Account for [CLS] and [SEP] with "- 2" + if len(tokens_a) > max_seq_length - 2: + tokens_a = tokens_a[0:(max_seq_length - 2)] + + # The convention in BERT/ERNIE is: + # (a) For sequence pairs: + # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] + # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 + # (b) For single sequences: + # tokens: [CLS] the dog is hairy . [SEP] + # type_ids: 0 0 0 0 0 0 0 + # + # Where "type_ids" are used to indicate whether this is the first + # sequence or the second sequence. The embedding vectors for `type=0` and + # `type=1` were learned during pre-training and are added to the wordpiece + # embedding vector (and position vector). This is not *strictly* necessary + # since the [SEP] token unambiguously separates the sequences, but it makes + # it easier for the model to learn the concept of sequences. + # + # For classification tasks, the first vector (corresponding to [CLS]) is + # used as as the "sentence vector". Note that this only makes sense because + # the entire model is fine-tuned. 
+ tokens = [] + text_type_ids = [] + tokens.append("[CLS]") + text_type_ids.append(0) + for token in tokens_a: + tokens.append(token) + text_type_ids.append(0) + tokens.append("[SEP]") + text_type_ids.append(0) + + if tokens_b: + for token in tokens_b: + tokens.append(token) + text_type_ids.append(1) + tokens.append("[SEP]") + text_type_ids.append(1) + + token_ids = tokenizer.convert_tokens_to_ids(tokens) + position_ids = list(range(len(token_ids))) + + if self.label_map: + label_id = self.label_map[example.label] + else: + label_id = example.label + + # Record = namedtuple( + # 'Record', + # ['token_ids', 'text_type_ids', 'position_ids', 'label_id', 'qid']) + + # qid = None + # if "qid" in example._fields: + # qid = example.qid + + # record = Record( + # token_ids=token_ids, + # text_type_ids=text_type_ids, + # position_ids=position_ids, + # label_id=label_id, + # qid=qid) + Record = namedtuple( + 'Record', + ['token_ids', 'text_type_ids', 'position_ids', 'label_id']) + + record = Record( + token_ids=token_ids, + text_type_ids=text_type_ids, + position_ids=position_ids, + label_id=label_id) + return record + + def _prepare_batch_data(self, examples, batch_size, phase=None): + """generate batch records""" + batch_records, max_len = [], 0 + for index, example in enumerate(examples): + if phase == "train": + self.current_example = index + record = self._convert_example_to_record(example, self.max_seq_len, + self.tokenizer) + max_len = max(max_len, len(record.token_ids)) + if self.in_tokens: + to_append = (len(batch_records) + 1) * max_len <= batch_size + else: + to_append = len(batch_records) < batch_size + if to_append: + batch_records.append(record) + else: + yield self._pad_batch_records(batch_records) + batch_records, max_len = [record], len(record.token_ids) + + if batch_records: + yield self._pad_batch_records(batch_records) + + # def get_num_examples(self, input_file): + # examples = self._read_tsv(input_file) + # return len(examples) + + def get_num_examples(self, phase): + """Get number of examples for train, dev or test.""" + if phase not in ['train', 'val', 'dev', 'test']: + raise ValueError( + "Unknown phase, which should be in ['train', 'val'/'dev', 'test']." 
+ ) + return self.num_examples[phase] + + def data_generator(self, batch_size, phase='train', shuffle=True): + + if phase == 'train': + examples = self.get_train_examples() + self.num_examples['train'] = len(examples) + elif phase == 'val' or phase == 'dev': + examples = self.get_dev_examples() + self.num_examples['dev'] = len(examples) + elif phase == 'test': + examples = self.get_test_examples() + self.num_examples['test'] = len(examples) + else: + raise ValueError( + "Unknown phase, which should be in ['train', 'dev', 'test'].") + + def wrapper(): + if shuffle: + np.random.shuffle(examples) + + for batch_data in self._prepare_batch_data( + examples, batch_size, phase=phase): + yield [batch_data] + + return wrapper + + +class ClassifyReader(BaseReader): + def _pad_batch_records(self, batch_records): + batch_token_ids = [record.token_ids for record in batch_records] + batch_text_type_ids = [record.text_type_ids for record in batch_records] + batch_position_ids = [record.position_ids for record in batch_records] + batch_labels = [record.label_id for record in batch_records] + batch_labels = np.array(batch_labels).astype("int64").reshape([-1, 1]) + + # if batch_records[0].qid: + # batch_qids = [record.qid for record in batch_records] + # batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1]) + # else: + # batch_qids = np.array([]).astype("int64").reshape([-1, 1]) + + # padding + padded_token_ids, input_mask = pad_batch_data( + batch_token_ids, + max_seq_len=self.max_seq_len, + pad_idx=self.pad_id, + return_input_mask=True) + padded_text_type_ids = pad_batch_data( + batch_text_type_ids, + max_seq_len=self.max_seq_len, + pad_idx=self.pad_id) + padded_position_ids = pad_batch_data( + batch_position_ids, + max_seq_len=self.max_seq_len, + pad_idx=self.pad_id) + + return_list = [ + padded_token_ids, padded_position_ids, padded_text_type_ids, + input_mask, batch_labels + ] + + return return_list + + +class SequenceLabelReader(BaseReader): + def _pad_batch_records(self, batch_records): + batch_token_ids = [record.token_ids for record in batch_records] + batch_text_type_ids = [record.text_type_ids for record in batch_records] + batch_position_ids = [record.position_ids for record in batch_records] + batch_label_ids = [record.label_ids for record in batch_records] + + # padding + padded_token_ids, input_mask, batch_seq_lens = pad_batch_data( + batch_token_ids, + pad_idx=self.pad_id, + max_seq_len=self.max_seq_len, + return_input_mask=True, + return_seq_lens=True) + padded_text_type_ids = pad_batch_data( + batch_text_type_ids, + max_seq_len=self.max_seq_len, + pad_idx=self.pad_id) + padded_position_ids = pad_batch_data( + batch_position_ids, + max_seq_len=self.max_seq_len, + pad_idx=self.pad_id) + padded_label_ids = pad_batch_data( + batch_label_ids, + max_seq_len=self.max_seq_len, + pad_idx=len(self.label_map) - 1) + + return_list = [ + padded_token_ids, padded_position_ids, padded_text_type_ids, + input_mask, padded_label_ids, batch_seq_lens + ] + return return_list + + def _reseg_token_label(self, tokens, labels, tokenizer): + assert len(tokens) == len(labels) + ret_tokens = [] + ret_labels = [] + for token, label in zip(tokens, labels): + sub_token = tokenizer.tokenize(token) + if len(sub_token) == 0: + continue + ret_tokens.extend(sub_token) + ret_labels.append(label) + if len(sub_token) < 2: + continue + sub_label = label + if label.startswith("B-"): + sub_label = "I-" + label[2:] + ret_labels.extend([sub_label] * (len(sub_token) - 1)) + + assert len(ret_tokens) == len(ret_labels) + 
return ret_tokens, ret_labels + + def _convert_example_to_record(self, example, max_seq_length, tokenizer): + tokens = tokenization.convert_to_unicode(example.text_a).split(u"\2") + labels = tokenization.convert_to_unicode(example.label).split(u"\2") + tokens, labels = self._reseg_token_label(tokens, labels, tokenizer) + + if len(tokens) > max_seq_length - 2: + tokens = tokens[0:(max_seq_length - 2)] + labels = labels[0:(max_seq_length - 2)] + + tokens = ["[CLS]"] + tokens + ["[SEP]"] + token_ids = tokenizer.convert_tokens_to_ids(tokens) + position_ids = list(range(len(token_ids))) + text_type_ids = [0] * len(token_ids) + no_entity_id = len(self.label_map) - 1 + label_ids = [no_entity_id + ] + [self.label_map[label] + for label in labels] + [no_entity_id] + + Record = namedtuple( + 'Record', + ['token_ids', 'text_type_ids', 'position_ids', 'label_ids']) + record = Record( + token_ids=token_ids, + text_type_ids=text_type_ids, + position_ids=position_ids, + label_ids=label_ids) + return record + + +class ExtractEmbeddingReader(BaseReader): + def _pad_batch_records(self, batch_records): + batch_token_ids = [record.token_ids for record in batch_records] + batch_text_type_ids = [record.text_type_ids for record in batch_records] + batch_position_ids = [record.position_ids for record in batch_records] + + # padding + padded_token_ids, input_mask, seq_lens = pad_batch_data( + batch_token_ids, + pad_idx=self.pad_id, + max_seq_len=self.max_seq_len, + return_input_mask=True, + return_seq_lens=True) + padded_text_type_ids = pad_batch_data( + batch_text_type_ids, + pad_idx=self.pad_id, + max_seq_len=self.max_seq_len) + padded_position_ids = pad_batch_data( + batch_position_ids, + pad_idx=self.pad_id, + max_seq_len=self.max_seq_len) + + return_list = [ + padded_token_ids, padded_text_type_ids, padded_position_ids, + input_mask, seq_lens + ] + + return return_list + + +if __name__ == '__main__': + pass
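
Usage sketch for the renamed task API, mirroring the updated demo/ernie-classification/finetune_with_hub.py above. This is a minimal illustration only: pooled_output, label, num_labels, reader, feed_list and config are placeholders standing in for the objects the demo script builds from the loaded ERNIE module and dataset, not names defined in this patch.

import paddlehub as hub

# Placeholders taken from the demo setup (not defined in this patch):
# pooled_output and label come from the loaded ERNIE module's program,
# num_labels from the dataset's label list.
cls_task = hub.create_text_classification_task(
    pooled_output, label, num_classes=num_labels)

# The task is then handed to hub.finetune_and_eval(task=cls_task, ...),
# which runs training, evaluation and model saving automatically, with the
# remaining arguments (data reader, feed list, run config) set up as in the
# demo scripts.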