It will use full-connect layer with softmax activation function to classify texts.
"""
def__init__(self,
feature,
num_classes,
feed_list,
data_reader,
token_feature=None,
feature=None,
network=None,
startup_program=None,
config=None,
hidden_units=None,
metrics_choices="default"):
"""
Args:
num_classes: total labels of the text classification task.
feed_list(list): the variable name that will be feeded to the main program
data_reader(object): data reader for the task. It must be one of ClassifyReader and LACClassifyReader.
token_feature(Variable): the feature will be used to connect the preset net. It must be the token-level feature, shape as [-1, seq_len, emb_size]. Default None.
feature(Variable): the feature will be used to classify texts. It must be the sentence-level feature, shape as [-1, emb_size]. Token_feature and feature couldn't be setted as the same time. One of them must be setted as not None. Default None.
network(str): the preset network. Choices: 'bilstm', 'bow', 'cnn', 'dpcnn', 'gru' and 'lstm'. Default None. If network is setted, then token_feature must be seted and feature must be None.
main_program (object): the customized main_program, default None.
startup_program (object): the customized startup_program, default None.
config (RunConfig): run config for the task, such as batch_size, epoch, learning_rate setting and so on. Default None.
hidden_units(list): the element of hidden_units list is the full-connect layer size. It will add the full-connect layers to the program. Default None.
metrics_choices(list): metrics used to the task, default ["acc"].
"""
if(notfeature)and(nottoken_feature):
logger.error(
'Both token_feature and feature are None, one of them must be setted.'
)
exit(1)
eliffeatureandtoken_feature:
logger.error(
'Both token_feature and feature are setted. One should be setted, the other should be None.'
)
exit(1)
ifnetwork:
assertnetworkin[
'bilstm','bow','cnn','dpcnn','gru','lstm'
],'network choice must be one of bilstm, bow, cnn, dpcnn, gru, lstm!'
asserttoken_featureand(
notfeature
),'If you wanna use network, you must set token_feature ranther than feature for TextClassifierTask!'
assertlen(
token_feature.shape
)==3,'When you use network, the parameter token_feature must be the token-level feature, such as the sequence_output of ERNIE, BERT, RoBERTa and ELECTRA module.'
else:
assertfeatureand(
nottoken_feature
),'If you do not use network, you must set feature ranther than token_feature for TextClassifierTask!'
assertlen(
feature.shape
)==2,'When you do not use network, the parameter feture must be the sentence-level feature, such as the pooled_output of ERNIE, BERT, RoBERTa and ELECTRA module.'
self.network=network
ifmetrics_choices=="default":
metrics_choices=["acc"]
self.network=network
super(TextClassifierTask,self).__init__(
data_reader=data_reader,
feature=feature,
feature=featureiffeatureelsetoken_feature,
num_classes=num_classes,
feed_list=feed_list,
startup_program=startup_program,
config=config,
hidden_units=hidden_units,
metrics_choices=metrics_choices)
ifself.network:
assertself.networkin[
'bilstm','bow','cnn','dpcnn','gru','lstm'
],'network choice must be one of bilstm, bow, cnn, dpcnn, gru, lstm!'
assertlen(
self.feature.shape
)==3,'The sequnece_output must be choosed rather than pooled_output of Transformer Model (ERNIE, BERT, RoBERTa and ELECTRA)!'