diff --git a/demo/text_classification/predict.py b/demo/text_classification/predict.py index 2afea928051c3ec7feb863e17c6080f0d0e71312..55da4d304fac2d5b481d23fb2910c2c937f010c0 100644 --- a/demo/text_classification/predict.py +++ b/demo/text_classification/predict.py @@ -60,7 +60,7 @@ if __name__ == '__main__': # Construct transfer learning network # Use "pooled_output" for classification tasks on an entire sentence. # Use "sequence_output" for token-level output. - pooled_output = outputs["sequence_output"] + token_feature = outputs["sequence_output"] # Setup feed list for data feeder # Must feed all the tensor of module need @@ -82,11 +82,11 @@ if __name__ == '__main__': # Define a classfication finetune task by PaddleHub's API # network choice: bilstm, bow, cnn, dpcnn, gru, lstm # If you wanna add network after ERNIE/BERT/RoBERTa/ELECTRA module, - # you must use the outputs["sequence_output"] as the feature of TextClassifierTask - # rather than outputs["pooled_output"] + # you must use the outputs["sequence_output"] as the token_feature of TextClassifierTask, + # rather than outputs["pooled_output"], and feature is None cls_task = hub.TextClassifierTask( data_reader=reader, - feature=pooled_output, + token_feature=token_feature, feed_list=feed_list, network=args.network, num_classes=dataset.num_labels, diff --git a/demo/text_classification/run_classifier.sh b/demo/text_classification/run_classifier.sh index 08dd3bc05182290671bcd8d8fcb0c42ec268b1bf..afb9e712dc30ec22d2805278a8b429db62e9b27f 100644 --- a/demo/text_classification/run_classifier.sh +++ b/demo/text_classification/run_classifier.sh @@ -1,5 +1,5 @@ export FLAGS_eager_delete_tensor_gb=0.0 -export CUDA_VISIBLE_DEVICES=0,1,2,3 +export CUDA_VISIBLE_DEVICES=0 CKPT_DIR="./ckpt_chnsenticorp" diff --git a/demo/text_classification/text_classifier.py b/demo/text_classification/text_classifier.py index 8f2243022c6c8098779f213c39eb6795ce628f24..7f1eb0f4dc2892dc2c4b7c0f45f5af638bd51510 100644 --- 
a/demo/text_classification/text_classifier.py +++ b/demo/text_classification/text_classifier.py @@ -58,7 +58,7 @@ if __name__ == '__main__': # Construct transfer learning network # Use "pooled_output" for classification tasks on an entire sentence. # Use "sequence_output" for token-level output. - pooled_output = outputs["sequence_output"] + token_feature = outputs["sequence_output"] # Setup feed list for data feeder # Must feed all the tensor of module need @@ -87,11 +87,11 @@ if __name__ == '__main__': # Define a classfication finetune task by PaddleHub's API # network choice: bilstm, bow, cnn, dpcnn, gru, lstm # If you wanna add network after ERNIE/BERT/RoBERTa/ELECTRA module, - # you must use the outputs["sequence_output"] as the feature of TextClassifierTask - # rather than outputs["pooled_output"] + # you must use the outputs["sequence_output"] as the token_feature of TextClassifierTask, + # rather than outputs["pooled_output"], and feature is None cls_task = hub.TextClassifierTask( data_reader=reader, - feature=pooled_output, + token_feature=token_feature, feed_list=feed_list, network=args.network, num_classes=dataset.num_labels, diff --git a/paddlehub/finetune/task/classifier_task.py b/paddlehub/finetune/task/classifier_task.py index 4b965aaf50731caedd0696980bc545e87be1eeca..bdbad9aea1caef07b00cc1abc674a63f88c31cf5 100644 --- a/paddlehub/finetune/task/classifier_task.py +++ b/paddlehub/finetune/task/classifier_task.py @@ -17,15 +17,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import time from collections import OrderedDict import numpy as np import paddle import paddle.fluid as fluid +import time +from paddlehub.common.logger import logger from paddlehub.finetune.evaluate import calculate_f1_np, matthews_corrcoef -from paddlehub.common.utils import version_compare import paddlehub.network as net +from paddlehub.reader.nlp_reader import ClassifyReader + from .base_task import BaseTask @@ 
-159,60 +161,113 @@ ImageClassifierTask = ClassifierTask class TextClassifierTask(ClassifierTask): + """ + Create a text classification task. + It will use full-connect layer with softmax activation function to classify texts. + """ + def __init__(self, - feature, num_classes, feed_list, data_reader, + token_feature=None, + feature=None, network=None, startup_program=None, config=None, hidden_units=None, metrics_choices="default"): + """ + Args: + num_classes: total labels of the text classification task. + feed_list(list): the variable name that will be fed to the main program + data_reader(object): data reader for the task. It must be one of ClassifyReader and LACClassifyReader. + token_feature(Variable): the feature will be used to connect the preset net. It must be the token-level feature, shape as [-1, seq_len, emb_size]. Default None. + feature(Variable): the feature will be used to classify texts. It must be the sentence-level feature, shape as [-1, emb_size]. Token_feature and feature can't be set at the same time. One of them must be set to not None. Default None. + network(str): the preset network. Choices: 'bilstm', 'bow', 'cnn', 'dpcnn', 'gru' and 'lstm'. Default None. If network is set, then token_feature must be set and feature must be None. + main_program (object): the customized main_program, default None. + startup_program (object): the customized startup_program, default None. + config (RunConfig): run config for the task, such as batch_size, epoch, learning_rate setting and so on. Default None. + hidden_units(list): the element of hidden_units list is the full-connect layer size. It will add the full-connect layers to the program. Default None. + metrics_choices(list): metrics used in the task, default ["acc"]. + """ + if (not feature) and (not token_feature): + logger.error( + 'Both token_feature and feature are None, one of them must be setted.' 
+ ) + exit(1) + elif feature and token_feature: + logger.error( + 'Both token_feature and feature are setted. One should be setted, the other should be None.' + ) + exit(1) + + if network: + assert network in [ + 'bilstm', 'bow', 'cnn', 'dpcnn', 'gru', 'lstm' + ], 'network choice must be one of bilstm, bow, cnn, dpcnn, gru, lstm!' + assert token_feature and ( + not feature + ), 'If you wanna use network, you must set token_feature ranther than feature for TextClassifierTask!' + assert len( + token_feature.shape + ) == 3, 'When you use network, the parameter token_feature must be the token-level feature, such as the sequence_output of ERNIE, BERT, RoBERTa and ELECTRA module.' + else: + assert feature and ( + not token_feature + ), 'If you do not use network, you must set feature ranther than token_feature for TextClassifierTask!' + assert len( + feature.shape + ) == 2, 'When you do not use network, the parameter feture must be the sentence-level feature, such as the pooled_output of ERNIE, BERT, RoBERTa and ELECTRA module.' + + self.network = network if metrics_choices == "default": metrics_choices = ["acc"] - self.network = network + super(TextClassifierTask, self).__init__( data_reader=data_reader, - feature=feature, + feature=feature if feature else token_feature, num_classes=num_classes, feed_list=feed_list, startup_program=startup_program, config=config, hidden_units=hidden_units, metrics_choices=metrics_choices) - if self.network: - assert self.network in [ - 'bilstm', 'bow', 'cnn', 'dpcnn', 'gru', 'lstm' - ], 'network choice must be one of bilstm, bow, cnn, dpcnn, gru, lstm!' - assert len( - self.feature.shape - ) == 3, 'The sequnece_output must be choosed rather than pooled_output of Transformer Model (ERNIE, BERT, RoBERTa and ELECTRA)!' 
def _build_net(self): - self.seq_len = fluid.layers.data( - name="seq_len", shape=[1], dtype='int64', lod_level=0) - self.seq_len_used = fluid.layers.squeeze(self.seq_len, axes=[1]) - unpad_feature = fluid.layers.sequence_unpad( - self.feature, length=self.seq_len_used) + if isinstance(self._base_data_reader, ClassifyReader): + # ClassifyReader will return the sequence length of an input text + self.seq_len = fluid.layers.data( + name="seq_len", shape=[1], dtype='int64', lod_level=0) + self.seq_len_used = fluid.layers.squeeze(self.seq_len, axes=[1]) + + # unpad the token_feature + unpad_feature = fluid.layers.sequence_unpad( + self.feature, length=self.seq_len_used) if self.network: + # add preset net net_func = getattr(net.classification, self.network) if self.network == 'dpcnn': - cls_feats = net_func(self.feature) + # the dpcnn network does not need unpadding + cls_feats = net_func( + self.feature, emb_dim=self.feature.shape[-1]) else: cls_feats = net_func(unpad_feature) + logger.info( + "%s has been added in the TextClassifierTask!" 
% self.network) else: + # no preset net; use fc layers instead cls_feats = fluid.layers.dropout( x=self.feature, dropout_prob=0.1, dropout_implementation="upscale_in_train") - if self.hidden_units is not None: - for n_hidden in self.hidden_units: - cls_feats = fluid.layers.fc( - input=cls_feats, size=n_hidden, act="relu") + if self.hidden_units is not None: + for n_hidden in self.hidden_units: + cls_feats = fluid.layers.fc( + input=cls_feats, size=n_hidden, act="relu") logits = fluid.layers.fc( input=cls_feats, @@ -231,7 +286,10 @@ class TextClassifierTask(ClassifierTask): @property def feed_list(self): - feed_list = self._base_feed_list + [self.seq_len.name] + feed_list = [varname for varname in self._base_feed_list] + if isinstance(self._base_data_reader, ClassifyReader): + # ClassifyReader will return the sequence length of an input text + feed_list += [self.seq_len.name] if self.is_train_phase or self.is_test_phase: feed_list += [self.labels[0].name] return feed_list @@ -239,11 +297,19 @@ class TextClassifierTask(ClassifierTask): @property def fetch_list(self): if self.is_train_phase or self.is_test_phase: - return [ + fetch_list = [ self.labels[0].name, self.ret_infers.name, self.metrics[0].name, - self.loss.name, self.seq_len.name + self.loss.name ] - return [self.outputs[0].name, self.seq_len.name] + else: + # predict phase + fetch_list = [self.outputs[0].name] + + if isinstance(self._base_data_reader, ClassifyReader): + # keep save_inference_model from pruning the seq_len variable + fetch_list += [self.seq_len.name] + + return fetch_list class MultiLabelClassifierTask(ClassifierTask): diff --git a/paddlehub/finetune/task/sequence_task.py b/paddlehub/finetune/task/sequence_task.py index 372eb3b218b23a578ea80b14c9da856829000598..ac46c990a2cf990f5c9cf7e9bbde3cfc9ae1f270 100644 --- a/paddlehub/finetune/task/sequence_task.py +++ b/paddlehub/finetune/task/sequence_task.py @@ -66,11 +66,7 @@ class SequenceLabelTask(BaseTask): def _build_net(self): self.seq_len = 
fluid.layers.data( name="seq_len", shape=[1], dtype='int64', lod_level=0) - - if version_compare(paddle.__version__, "1.6"): - self.seq_len_used = fluid.layers.squeeze(self.seq_len, axes=[1]) - else: - self.seq_len_used = self.seq_len + self.seq_len_used = fluid.layers.squeeze(self.seq_len, axes=[1]) if self.add_crf: unpad_feature = fluid.layers.sequence_unpad( diff --git a/paddlehub/network/classification.py b/paddlehub/network/classification.py index 24ebaced95ef80595ffe2da3773b30d72df22bd0..132ae008759bcb78d9187d1c19bab4e6c5e79d32 100644 --- a/paddlehub/network/classification.py +++ b/paddlehub/network/classification.py @@ -21,43 +21,6 @@ import paddle import paddle.fluid as fluid -def bilstm_net(token_embeddings, hid_dim=128, hid_dim2=96): - """ - bilstm net - """ - main_program = token_embeddings.block.program - start_program = fluid.Program() - with fluid.program_guard(main_program, start_program): - with fluid.unique_name.guard('preste_'): - seq_len = fluid.layers.data( - name="seq_len", shape=[1], dtype='int64', lod_level=0) - - if version_compare(paddle.__version__, "1.6"): - seq_len_used = fluid.layers.squeeze(seq_len, axes=[1]) - else: - seq_len_used = seq_len - - unpad_feature = fluid.layers.sequence_unpad( - token_embeddings, length=seq_len_used) - - fc0 = fluid.layers.fc(input=unpad_feature, size=hid_dim * 4) - rfc0 = fluid.layers.fc(input=unpad_feature, size=hid_dim * 4) - lstm_h, c = fluid.layers.dynamic_lstm( - input=fc0, size=hid_dim * 4, is_reverse=False) - rlstm_h, c = fluid.layers.dynamic_lstm( - input=rfc0, size=hid_dim * 4, is_reverse=True) - lstm_last = fluid.layers.sequence_last_step(input=lstm_h) - rlstm_last = fluid.layers.sequence_last_step(input=rlstm_h) - lstm_last_tanh = fluid.layers.tanh(lstm_last) - rlstm_last_tanh = fluid.layers.tanh(rlstm_last) - - # concat layer - lstm_concat = fluid.layers.concat(input=[lstm_last, rlstm_last], axis=1) - # full connect layer - fc = fluid.layers.fc(input=lstm_concat, size=hid_dim2, act='tanh') - 
return seq_len, fc, start_program - - def bilstm(token_embeddings, hid_dim=128, hid_dim2=96): """ bilstm net