# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetuning on classification task """
from__future__importabsolute_import
from__future__importdivision
from__future__importprint_function
importargparse
importast
importnumpyasnp
importos
importtime
importpaddle
importpaddle.fluidasfluid
importpaddlehubashub
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory for model checkpoints.")
parser.add_argument("--batch_size", type=int, default=1, help="Total number of examples in a training batch.")
parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words in the longest sequence.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether to use GPU for fine-tuning; input should be True or False.")
parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether to use data parallelism.")
parser.add_argument("--network", type=str, default='bilstm', help="Pre-defined network connected after the Transformer model (such as ERNIE, BERT, RoBERTa or ELECTRA). Choices: bilstm, bow, cnn, dpcnn, gru, lstm.")
parser.add_argument("--num_epoch",type=int,default=3,help="Number of epoches for fine-tuning.")
parser.add_argument("--use_gpu",type=ast.literal_eval,default=True,help="Whether use GPU for finetuning, input should be True or False")
parser.add_argument("--learning_rate",type=float,default=5e-5,help="Learning rate used to train with warmup.")
parser.add_argument("--weight_decay",type=float,default=0.01,help="Weight decay rate for L2 regularizer.")
parser.add_argument("--warmup_proportion",type=float,default=0.1,help="Warmup proportion params for warmup strategy")
parser.add_argument("--checkpoint_dir",type=str,default=None,help="Directory to model checkpoint")
parser.add_argument("--max_seq_len",type=int,default=512,help="Number of words of the longest seqence.")
parser.add_argument("--batch_size",type=int,default=32,help="Total examples' number in batch for training.")
parser.add_argument("--network",type=str,default='bilstm',help="Pre-defined network which was connected after Transformer model, such as ERNIE, BERT ,RoBERTa and ELECTRA.")
parser.add_argument("--use_data_parallel",type=ast.literal_eval,default=False,help="Whether use data parallel.")
class TextClassifierTask(ClassifierTask):
    """
    It will use a fully-connected layer with a softmax activation function to classify texts.
    """
    def __init__(self,
                 num_classes,
                 feed_list,
                 data_reader,
                 feature=None,
                 token_feature=None,
                 network=None,
                 startup_program=None,
                 config=None,
                 hidden_units=None,
                 metrics_choices="default"):
"""
Args:
num_classes: total labels of the text classification task.
feed_list(list): the variable name that will be feeded to the main program
data_reader(object): data reader for the task. It must be one of ClassifyReader and LACClassifyReader.
feature(Variable): the `feature` will be used to classify texts. It must be the sentence-level feature, shape as [-1, emb_size]. `Token_feature` and `feature` couldn't be setted at the same time. One of them must be setted as not None. Default None.
token_feature(Variable): the `feature` will be used to connect the pre-defined network. It must be the token-level feature, shape as [-1, seq_len, emb_size]. Default None.
network(str): the pre-defined network. Choices: 'bilstm', 'bow', 'cnn', 'dpcnn', 'gru' and 'lstm'. Default None. If network is setted, then `token_feature` must be setted and `feature` must be None.
main_program (object): the customized main program, default None.
startup_program (object): the customized startup program, default None.
config (RunConfig): run config for the task, such as batch_size, epoch, learning_rate setting and so on. Default None.
hidden_units(list): the element of `hidden_units` list is the full-connect layer size. It will add the full-connect layers to the program. Default None.
metrics_choices(list): metrics used to the task, default ["acc"].
"""
        if (not feature) and (not token_feature):
            logger.error(
                'Both token_feature and feature are None; one of them must be set.'
            )
            exit(1)
        elif feature and token_feature:
            logger.error(
                'Both token_feature and feature are set. One should be set; the other should be None.'
            )
            exit(1)

        if network:
            assert network in [
                'bilstm', 'bow', 'cnn', 'dpcnn', 'gru', 'lstm'
            ], 'network choice must be one of bilstm, bow, cnn, dpcnn, gru, lstm!'
            assert token_feature and (
                not feature
            ), 'If you want to use a network, you must set token_feature rather than feature for TextClassifierTask!'
            assert len(
                token_feature.shape
            ) == 3, 'When you use a network, the parameter token_feature must be a token-level feature, such as the sequence_output of the ERNIE, BERT, RoBERTa or ELECTRA module.'
        else:
            assert feature and (
                not token_feature
            ), 'If you do not use a network, you must set feature rather than token_feature for TextClassifierTask!'
            assert len(
                feature.shape
            ) == 2, 'When you do not use a network, the parameter feature must be a sentence-level feature, such as the pooled_output of the ERNIE, BERT, RoBERTa or ELECTRA module.'
        self.network = network

        if metrics_choices == "default":
            metrics_choices = ["acc"]

        super(TextClassifierTask, self).__init__(
            data_reader=data_reader,
            feature=feature if feature else token_feature,
            num_classes=num_classes,
            feed_list=feed_list,
            startup_program=startup_program,
            ...
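# Illustrative usage note: the constructor above accepts exactly one of `feature`
# and `token_feature`. Assuming `outputs`, `reader`, `feed_list`, `dataset` and
# `config` come from the sketch earlier in this file, the two modes look like:
#
#   # sentence-level feature, no pre-defined network:
#   hub.TextClassifierTask(
#       data_reader=reader, feed_list=feed_list, num_classes=dataset.num_labels,
#       feature=outputs["pooled_output"], config=config)
#
#   # token-level feature plus a pre-defined network such as 'bilstm':
#   hub.TextClassifierTask(
#       data_reader=reader, feed_list=feed_list, num_classes=dataset.num_labels,
#       token_feature=outputs["sequence_output"], network="bilstm", config=config)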