未验证 提交 abc05deb 编写于 作者: Z Zeyu Chen 提交者: GitHub

Add predefined networks for text classification

Add predefined networks for text classification
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetuning on classification task """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import ast
import numpy as np
import os
import time
import paddle
import paddle.fluid as fluid
import paddlehub as hub
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
parser.add_argument("--batch_size", type=int, default=1, help="Total examples' number in batch for training.")
parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest seqence.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for finetuning, input should be True or False")
parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether use data parallel.")
parser.add_argument("--network", type=str, default='bilstm', help="Pre-defined network which was connected after Transformer model, such as ERNIE, BERT ,RoBERTa and ELECTRA.")
args = parser.parse_args()
# yapf: enable.
if __name__ == '__main__':
# Load Paddlehub ERNIE Tiny pretrained model
module = hub.Module(name="ernie_tiny")
inputs, outputs, program = module.context(
trainable=True, max_seq_len=args.max_seq_len)
# Download dataset and use accuracy as metrics
# Choose dataset: GLUE/XNLI/ChinesesGLUE/NLPCC-DBQA/LCQMC
dataset = hub.dataset.ChnSentiCorp()
# For ernie_tiny, it use sub-word to tokenize chinese sentence
# If not ernie tiny, sp_model_path and word_dict_path should be set None
reader = hub.reader.ClassifyReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len,
sp_model_path=module.get_spm_path(),
word_dict_path=module.get_word_dict_path())
# Construct transfer learning network
# Use "pooled_output" for classification tasks on an entire sentence.
# Use "sequence_output" for token-level output.
token_feature = outputs["sequence_output"]
# Setup feed list for data feeder
# Must feed all the tensor of module need
feed_list = [
inputs["input_ids"].name,
inputs["position_ids"].name,
inputs["segment_ids"].name,
inputs["input_mask"].name,
]
# Setup runing config for PaddleHub Finetune API
config = hub.RunConfig(
use_data_parallel=args.use_data_parallel,
use_cuda=args.use_gpu,
batch_size=args.batch_size,
checkpoint_dir=args.checkpoint_dir,
strategy=hub.AdamWeightDecayStrategy())
# Define a classfication finetune task by PaddleHub's API
# network choice: bilstm, bow, cnn, dpcnn, gru, lstm (PaddleHub pre-defined network)
# If you wanna add network after ERNIE/BERT/RoBERTa/ELECTRA module,
# you must use the outputs["sequence_output"] as the token_feature of TextClassifierTask,
# rather than outputs["pooled_output"], and feature is None
cls_task = hub.TextClassifierTask(
data_reader=reader,
token_feature=token_feature,
feed_list=feed_list,
network=args.network,
num_classes=dataset.num_labels,
config=config)
# Data to be prdicted
data = [["这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般"], ["交通方便;环境很好;服务态度很好 房间较小"],
["19天硬盘就罢工了~~~算上运来的一周都没用上15天~~~可就是不能换了~~~唉~~~~你说这算什么事呀~~~"]]
print(cls_task.predict(data=data, return_result=True))
export FLAGS_eager_delete_tensor_gb=0.0
export CUDA_VISIBLE_DEVICES=0
CKPT_DIR="./ckpt_chnsenticorp"
python -u text_cls.py \
--batch_size=24 \
--use_gpu=True \
--checkpoint_dir=${CKPT_DIR} \
--learning_rate=5e-5 \
--weight_decay=0.01 \
--max_seq_len=128 \
--warmup_proportion=0.1 \
--num_epoch=3 \
--use_data_parallel=True
# The sugguested hyper parameters for difference task
# for ChineseGLUE:
# TNews: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=5e-5
# LCQMC: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=5e-5
# XNLI_zh: batch_size=32, weight_decay=0, num_epoch=2, max_seq_len=128, lr=5e-5
# INEWS: batch_size=4, weight_decay=0, num_epoch=3, max_seq_len=512, lr=5e-5
# DRCD: see demo: reading-comprehension
# CMRC2018: see demo: reading-comprehension
# BQ: batch_size=32, weight_decay=0, num_epoch=2, max_seq_len=100, lr=1e-5
# MSRANER: see demo: sequence-labeling
# THUCNEWS: batch_size=8, weight_decay=0, num_epoch=2, max_seq_len=512, lr=5e-5
# IFLYTEKDATA: batch_size=16, weight_decay=0, num_epoch=5, max_seq_len=256, lr=1e-5
# for other tasks:
# ChnSentiCorp: batch_size=24, weight_decay=0.01, num_epoch=3, max_seq_len=128, lr=5e-5
# NLPCC_DBQA: batch_size=8, weight_decay=0.01, num_epoch=3, max_seq_len=512, lr=2e-5
# LCQMC: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=2e-5
# QQP: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=5e-5
# QNLI: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=5e-5
# SST-2: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=5e-5
# CoLA: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=5e-5
# MRPC: batch_size=32, weight_decay=0.01, num_epoch=3, max_seq_len=128, lr=5e-5
# RTE: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=3e-5
# MNLI: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=5e-5
# Specify the matched/mismatched dev and test dataset with an underscore.
# mnli_m or mnli: dev and test in matched dataset.
# mnli_mm: dev and test in mismatched dataset.
# The difference can be seen in https://www.nyu.edu/projects/bowman/multinli/paper.pdf.
# If you are not sure which one to pick, just use mnli or mnli_m.
# XNLI: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=5e-5
# Specify the language with an underscore like xnli_zh.
# ar- Arabic bg- Bulgarian de- German
# el- Greek en- English es- Spanish
# fr- French hi- Hindi ru- Russian
# sw- Swahili th- Thai tr- Turkish
# ur- Urdu vi- Vietnamese zh- Chinese (Simplified)
export FLAGS_eager_delete_tensor_gb=0.0 export FLAGS_eager_delete_tensor_gb=0.0
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
CKPT_DIR="./ckpt_chnsenticorp" CKPT_DIR="./ckpt_chnsenticorp_predefine_net"
python -u text_classifier.py \ python -u text_cls_predefine_net.py \
--batch_size=24 \ --batch_size=24 \
--use_gpu=True \ --use_gpu=True \
--checkpoint_dir=${CKPT_DIR} \ --checkpoint_dir=${CKPT_DIR} \
...@@ -12,7 +12,8 @@ python -u text_classifier.py \ ...@@ -12,7 +12,8 @@ python -u text_classifier.py \
--max_seq_len=128 \ --max_seq_len=128 \
--warmup_proportion=0.1 \ --warmup_proportion=0.1 \
--num_epoch=3 \ --num_epoch=3 \
--use_data_parallel=True --use_data_parallel=True \
--network=bilstm
# The sugguested hyper parameters for difference task # The sugguested hyper parameters for difference task
# for ChineseGLUE: # for ChineseGLUE:
......
...@@ -3,7 +3,8 @@ export CUDA_VISIBLE_DEVICES=0 ...@@ -3,7 +3,8 @@ export CUDA_VISIBLE_DEVICES=0
CKPT_DIR="./ckpt_chnsenticorp" CKPT_DIR="./ckpt_chnsenticorp"
python -u predict.py --checkpoint_dir=$CKPT_DIR \ python -u predict.py \
--checkpoint_dir=$CKPT_DIR \
--max_seq_len=128 \ --max_seq_len=128 \
--use_gpu=True \ --use_gpu=True \
--batch_size=24 \ --batch_size=24
export FLAGS_eager_delete_tensor_gb=0.0
export CUDA_VISIBLE_DEVICES=0
CKPT_DIR="./ckpt_chnsenticorp_predefine_net"
python -u predict_predefine_net.py \
--checkpoint_dir=$CKPT_DIR \
--max_seq_len=128 \
--use_gpu=True \
--batch_size=24 \
--network=bilstm
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetuning on classification task """
import argparse
import ast
import paddlehub as hub
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--num_epoch", type=int, default=3, help="Number of epoches for fine-tuning.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether use GPU for finetuning, input should be True or False")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.")
parser.add_argument("--warmup_proportion", type=float, default=0.1, help="Warmup proportion params for warmup strategy")
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest seqence.")
parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number in batch for training.")
parser.add_argument("--network", type=str, default='bilstm', help="Pre-defined network which was connected after Transformer model, such as ERNIE, BERT ,RoBERTa and ELECTRA.")
parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether use data parallel.")
args = parser.parse_args()
# yapf: enable.
if __name__ == '__main__':
# Load Paddlehub ERNIE Tiny pretrained model
module = hub.Module(name="ernie_tiny")
inputs, outputs, program = module.context(
trainable=True, max_seq_len=args.max_seq_len)
# Download dataset and use accuracy as metrics
# Choose dataset: GLUE/XNLI/ChinesesGLUE/NLPCC-DBQA/LCQMC
# metric should be acc, f1 or matthews
dataset = hub.dataset.ChnSentiCorp()
metrics_choices = ["acc"]
# For ernie_tiny, it use sub-word to tokenize chinese sentence
# If not ernie tiny, sp_model_path and word_dict_path should be set None
reader = hub.reader.ClassifyReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len,
sp_model_path=module.get_spm_path(),
word_dict_path=module.get_word_dict_path())
# Construct transfer learning network
# Use "pooled_output" for classification tasks on an entire sentence.
# Use "sequence_output" for token-level output.
token_feature = outputs["sequence_output"]
# Setup feed list for data feeder
# Must feed all the tensor of module need
feed_list = [
inputs["input_ids"].name,
inputs["position_ids"].name,
inputs["segment_ids"].name,
inputs["input_mask"].name,
]
# Select finetune strategy, setup config and finetune
strategy = hub.AdamWeightDecayStrategy(
warmup_proportion=args.warmup_proportion,
weight_decay=args.weight_decay,
learning_rate=args.learning_rate)
# Setup runing config for PaddleHub Finetune API
config = hub.RunConfig(
use_data_parallel=args.use_data_parallel,
use_cuda=args.use_gpu,
num_epoch=args.num_epoch,
batch_size=args.batch_size,
checkpoint_dir=args.checkpoint_dir,
strategy=strategy)
# Define a classfication finetune task by PaddleHub's API
# network choice: bilstm, bow, cnn, dpcnn, gru, lstm (PaddleHub pre-defined network)
# If you wanna add network after ERNIE/BERT/RoBERTa/ELECTRA module,
# you must use the outputs["sequence_output"] as the token_feature of TextClassifierTask,
# rather than outputs["pooled_output"], and feature is None
cls_task = hub.TextClassifierTask(
data_reader=reader,
token_feature=token_feature,
feed_list=feed_list,
network=args.network,
num_classes=dataset.num_labels,
config=config,
metrics_choices=metrics_choices)
# Finetune and evaluate by PaddleHub's API
# will finish training, evaluation, testing, save model automatically
cls_task.finetune_and_eval()
...@@ -28,6 +28,7 @@ from . import io ...@@ -28,6 +28,7 @@ from . import io
from . import dataset from . import dataset
from . import finetune from . import finetune
from . import reader from . import reader
from . import network
from .common.dir import USER_HOME from .common.dir import USER_HOME
from .common.dir import HUB_HOME from .common.dir import HUB_HOME
......
#coding:utf-8 # coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License"
......
...@@ -998,11 +998,6 @@ class BaseTask(object): ...@@ -998,11 +998,6 @@ class BaseTask(object):
Returns: Returns:
RunState: the running result of predict phase RunState: the running result of predict phase
""" """
if not version_compare(paddle.__version__, "1.6.2") and accelerate_mode:
logger.warning(
"Fail to open predict accelerate mode as it does not support paddle < 1.6.2. Please update PaddlePaddle."
)
accelerate_mode = False
self.accelerate_mode = accelerate_mode self.accelerate_mode = accelerate_mode
with self.phase_guard(phase="predict"): with self.phase_guard(phase="predict"):
......
...@@ -17,12 +17,17 @@ from __future__ import absolute_import ...@@ -17,12 +17,17 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import time
from collections import OrderedDict from collections import OrderedDict
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import time
from paddlehub.common.logger import logger
from paddlehub.finetune.evaluate import calculate_f1_np, matthews_corrcoef from paddlehub.finetune.evaluate import calculate_f1_np, matthews_corrcoef
from paddlehub.reader.nlp_reader import ClassifyReader
import paddlehub.network as net
from .base_task import BaseTask from .base_task import BaseTask
...@@ -104,7 +109,7 @@ class ClassifierTask(BaseTask): ...@@ -104,7 +109,7 @@ class ClassifierTask(BaseTask):
run_examples += run_state.run_examples run_examples += run_state.run_examples
run_step += run_state.run_step run_step += run_state.run_step
loss_sum += np.mean( loss_sum += np.mean(
run_state.run_results[-1]) * run_state.run_examples run_state.run_results[-2]) * run_state.run_examples
acc_sum += np.mean( acc_sum += np.mean(
run_state.run_results[2]) * run_state.run_examples run_state.run_results[2]) * run_state.run_examples
np_labels = run_state.run_results[0] np_labels = run_state.run_results[0]
...@@ -147,7 +152,7 @@ class ClassifierTask(BaseTask): ...@@ -147,7 +152,7 @@ class ClassifierTask(BaseTask):
results = [] results = []
for batch_state in run_states: for batch_state in run_states:
batch_result = batch_state.run_results batch_result = batch_state.run_results
batch_infer = np.argmax(batch_result, axis=2)[0] batch_infer = np.argmax(batch_result[0], axis=1)
results += [id2label[sample_infer] for sample_infer in batch_infer] results += [id2label[sample_infer] for sample_infer in batch_infer]
return results return results
...@@ -156,21 +161,73 @@ ImageClassifierTask = ClassifierTask ...@@ -156,21 +161,73 @@ ImageClassifierTask = ClassifierTask
class TextClassifierTask(ClassifierTask): class TextClassifierTask(ClassifierTask):
"""
Create a text classification task.
It will use full-connect layer with softmax activation function to classify texts.
"""
def __init__(self, def __init__(self,
feature,
num_classes, num_classes,
feed_list, feed_list,
data_reader, data_reader,
feature=None,
token_feature=None,
network=None,
startup_program=None, startup_program=None,
config=None, config=None,
hidden_units=None, hidden_units=None,
metrics_choices="default"): metrics_choices="default"):
"""
Args:
num_classes: total labels of the text classification task.
feed_list(list): the variable name that will be feeded to the main program
data_reader(object): data reader for the task. It must be one of ClassifyReader and LACClassifyReader.
feature(Variable): the `feature` will be used to classify texts. It must be the sentence-level feature, shape as [-1, emb_size]. `Token_feature` and `feature` couldn't be setted at the same time. One of them must be setted as not None. Default None.
token_feature(Variable): the `feature` will be used to connect the pre-defined network. It must be the token-level feature, shape as [-1, seq_len, emb_size]. Default None.
network(str): the pre-defined network. Choices: 'bilstm', 'bow', 'cnn', 'dpcnn', 'gru' and 'lstm'. Default None. If network is setted, then `token_feature` must be setted and `feature` must be None.
main_program (object): the customized main program, default None.
startup_program (object): the customized startup program, default None.
config (RunConfig): run config for the task, such as batch_size, epoch, learning_rate setting and so on. Default None.
hidden_units(list): the element of `hidden_units` list is the full-connect layer size. It will add the full-connect layers to the program. Default None.
metrics_choices(list): metrics used to the task, default ["acc"].
"""
if (not feature) and (not token_feature):
logger.error(
'Both token_feature and feature are None, one of them must be setted.'
)
exit(1)
elif feature and token_feature:
logger.error(
'Both token_feature and feature are setted. One should be setted, the other should be None.'
)
exit(1)
if network:
assert network in [
'bilstm', 'bow', 'cnn', 'dpcnn', 'gru', 'lstm'
], 'network choice must be one of bilstm, bow, cnn, dpcnn, gru, lstm!'
assert token_feature and (
not feature
), 'If you wanna use network, you must set token_feature ranther than feature for TextClassifierTask!'
assert len(
token_feature.shape
) == 3, 'When you use network, the parameter token_feature must be the token-level feature, such as the sequence_output of ERNIE, BERT, RoBERTa and ELECTRA module.'
else:
assert feature and (
not token_feature
), 'If you do not use network, you must set feature ranther than token_feature for TextClassifierTask!'
assert len(
feature.shape
) == 2, 'When you do not use network, the parameter feture must be the sentence-level feature, such as the pooled_output of ERNIE, BERT, RoBERTa and ELECTRA module.'
self.network = network
if metrics_choices == "default": if metrics_choices == "default":
metrics_choices = ["acc"] metrics_choices = ["acc"]
super(TextClassifierTask, self).__init__( super(TextClassifierTask, self).__init__(
data_reader=data_reader, data_reader=data_reader,
feature=feature, feature=feature if feature else token_feature,
num_classes=num_classes, num_classes=num_classes,
feed_list=feed_list, feed_list=feed_list,
startup_program=startup_program, startup_program=startup_program,
...@@ -179,6 +236,29 @@ class TextClassifierTask(ClassifierTask): ...@@ -179,6 +236,29 @@ class TextClassifierTask(ClassifierTask):
metrics_choices=metrics_choices) metrics_choices=metrics_choices)
def _build_net(self): def _build_net(self):
if isinstance(self._base_data_reader, ClassifyReader):
# ClassifyReader will return the seqence length of an input text
self.seq_len = fluid.layers.data(
name="seq_len", shape=[1], dtype='int64', lod_level=0)
self.seq_len_used = fluid.layers.squeeze(self.seq_len, axes=[1])
# unpad the token_feature
unpad_feature = fluid.layers.sequence_unpad(
self.feature, length=self.seq_len_used)
if self.network:
# add pre-defined net
net_func = getattr(net.classification, self.network)
if self.network == 'dpcnn':
# deepcnn network is no need to unpad
cls_feats = net_func(
self.feature, emb_dim=self.feature.shape[-1])
else:
cls_feats = net_func(unpad_feature)
logger.info(
"%s has been added in the TextClassifierTask!" % self.network)
else:
# not use pre-defined net but to use fc net
cls_feats = fluid.layers.dropout( cls_feats = fluid.layers.dropout(
x=self.feature, x=self.feature,
dropout_prob=0.1, dropout_prob=0.1,
...@@ -204,6 +284,33 @@ class TextClassifierTask(ClassifierTask): ...@@ -204,6 +284,33 @@ class TextClassifierTask(ClassifierTask):
return [logits] return [logits]
@property
def feed_list(self):
feed_list = [varname for varname in self._base_feed_list]
if isinstance(self._base_data_reader, ClassifyReader):
# ClassifyReader will return the seqence length of an input text
feed_list += [self.seq_len.name]
if self.is_train_phase or self.is_test_phase:
feed_list += [self.labels[0].name]
return feed_list
@property
def fetch_list(self):
if self.is_train_phase or self.is_test_phase:
fetch_list = [
self.labels[0].name, self.ret_infers.name, self.metrics[0].name,
self.loss.name
]
else:
# predict phase
fetch_list = [self.outputs[0].name]
if isinstance(self._base_data_reader, ClassifyReader):
# to avoid save_inference_model to prune seq_len variable
fetch_list += [self.seq_len.name]
return fetch_list
class MultiLabelClassifierTask(ClassifierTask): class MultiLabelClassifierTask(ClassifierTask):
def __init__(self, def __init__(self,
......
...@@ -66,11 +66,7 @@ class SequenceLabelTask(BaseTask): ...@@ -66,11 +66,7 @@ class SequenceLabelTask(BaseTask):
def _build_net(self): def _build_net(self):
self.seq_len = fluid.layers.data( self.seq_len = fluid.layers.data(
name="seq_len", shape=[1], dtype='int64', lod_level=0) name="seq_len", shape=[1], dtype='int64', lod_level=0)
if version_compare(paddle.__version__, "1.6"):
self.seq_len_used = fluid.layers.squeeze(self.seq_len, axes=[1]) self.seq_len_used = fluid.layers.squeeze(self.seq_len, axes=[1])
else:
self.seq_len_used = self.seq_len
if self.add_crf: if self.add_crf:
unpad_feature = fluid.layers.sequence_unpad( unpad_feature = fluid.layers.sequence_unpad(
......
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import classification
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module provide nets for text classification
"""
import paddle
import paddle.fluid as fluid
def bilstm(token_embeddings, hid_dim=128, hid_dim2=96):
"""
bilstm net
"""
fc0 = fluid.layers.fc(input=token_embeddings, size=hid_dim * 4)
rfc0 = fluid.layers.fc(input=token_embeddings, size=hid_dim * 4)
lstm_h, c = fluid.layers.dynamic_lstm(
input=fc0, size=hid_dim * 4, is_reverse=False)
rlstm_h, c = fluid.layers.dynamic_lstm(
input=rfc0, size=hid_dim * 4, is_reverse=True)
lstm_last = fluid.layers.sequence_last_step(input=lstm_h)
rlstm_last = fluid.layers.sequence_last_step(input=rlstm_h)
lstm_last_tanh = fluid.layers.tanh(lstm_last)
rlstm_last_tanh = fluid.layers.tanh(rlstm_last)
# concat layer
lstm_concat = fluid.layers.concat(input=[lstm_last, rlstm_last], axis=1)
# full connect layer
fc = fluid.layers.fc(input=lstm_concat, size=hid_dim2, act='tanh')
return fc
def bow(token_embeddings, hid_dim=128, hid_dim2=96):
"""
bow net
"""
# bow layer
bow = fluid.layers.sequence_pool(input=token_embeddings, pool_type='sum')
bow_tanh = fluid.layers.tanh(bow)
# full connect layer
fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh")
fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh")
return fc_2
def cnn(token_embeddings, hid_dim=128, win_size=3):
"""
cnn net
"""
# cnn layer
conv = fluid.nets.sequence_conv_pool(
input=token_embeddings,
num_filters=hid_dim,
filter_size=win_size,
act="tanh",
pool_type="max")
# full connect layer
fc_1 = fluid.layers.fc(input=conv, size=hid_dim)
return fc_1
def dpcnn(token_embeddings,
hid_dim=128,
channel_size=250,
emb_dim=1024,
blocks=6):
"""
deepcnn net
"""
def _block(x):
x = fluid.layers.relu(x)
x = fluid.layers.conv2d(x, channel_size, (3, 1), padding=(1, 0))
x = fluid.layers.relu(x)
x = fluid.layers.conv2d(x, channel_size, (3, 1), padding=(1, 0))
return x
emb = fluid.layers.unsqueeze(token_embeddings, axes=[1])
region_embedding = fluid.layers.conv2d(
emb, channel_size, (3, emb_dim), padding=(1, 0))
conv_features = _block(region_embedding)
conv_features = conv_features + region_embedding
# multi-cnn layer
for i in range(blocks):
block_features = fluid.layers.pool2d(
conv_features,
pool_size=(3, 1),
pool_stride=(2, 1),
pool_padding=(1, 0))
conv_features = _block(block_features)
conv_features = block_features + conv_features
features = fluid.layers.pool2d(conv_features, global_pooling=True)
features = fluid.layers.squeeze(features, axes=[2, 3])
# full connect layer
fc_1 = fluid.layers.fc(input=features, size=hid_dim, act="tanh")
return fc_1
def gru(token_embeddings, hid_dim=128, hid_dim2=96):
"""
gru net
"""
fc0 = fluid.layers.fc(input=token_embeddings, size=hid_dim * 3)
gru_h = fluid.layers.dynamic_gru(input=fc0, size=hid_dim, is_reverse=False)
gru_max = fluid.layers.sequence_pool(input=gru_h, pool_type='max')
gru_max_tanh = fluid.layers.tanh(gru_max)
fc1 = fluid.layers.fc(input=gru_max_tanh, size=hid_dim2, act='tanh')
return fc1
def lstm(token_embeddings, hid_dim=128, hid_dim2=96):
"""
lstm net
"""
# lstm layer
fc0 = fluid.layers.fc(input=token_embeddings, size=hid_dim * 4)
lstm_h, c = fluid.layers.dynamic_lstm(
input=fc0, size=hid_dim * 4, is_reverse=False)
# max pooling layer
lstm_max = fluid.layers.sequence_pool(input=lstm_h, pool_type='max')
lstm_max_tanh = fluid.layers.tanh(lstm_max)
# full connect layer
fc1 = fluid.layers.fc(input=lstm_max_tanh, size=hid_dim2, act='tanh')
return fc1
...@@ -65,7 +65,6 @@ class BaseNLPReader(BaseReader): ...@@ -65,7 +65,6 @@ class BaseNLPReader(BaseReader):
logger.warning( logger.warning(
"use_task_id has been de discarded since PaddleHub v1.4.0, it's no necessary to feed task_ids now." "use_task_id has been de discarded since PaddleHub v1.4.0, it's no necessary to feed task_ids now."
) )
self.task_id = 0
self.Record_With_Label_Id = namedtuple( self.Record_With_Label_Id = namedtuple(
'Record', 'Record',
...@@ -272,11 +271,12 @@ class ClassifyReader(BaseNLPReader): ...@@ -272,11 +271,12 @@ class ClassifyReader(BaseNLPReader):
batch_text_type_ids = [record.text_type_ids for record in batch_records] batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records] batch_position_ids = [record.position_ids for record in batch_records]
padded_token_ids, input_mask = pad_batch_data( padded_token_ids, input_mask, batch_seq_lens = pad_batch_data(
batch_token_ids, batch_token_ids,
max_seq_len=self.max_seq_len, max_seq_len=self.max_seq_len,
pad_idx=self.pad_id, pad_idx=self.pad_id,
return_input_mask=True) return_input_mask=True,
return_seq_lens=True)
padded_text_type_ids = pad_batch_data( padded_text_type_ids = pad_batch_data(
batch_text_type_ids, batch_text_type_ids,
max_seq_len=self.max_seq_len, max_seq_len=self.max_seq_len,
...@@ -286,36 +286,16 @@ class ClassifyReader(BaseNLPReader): ...@@ -286,36 +286,16 @@ class ClassifyReader(BaseNLPReader):
max_seq_len=self.max_seq_len, max_seq_len=self.max_seq_len,
pad_idx=self.pad_id) pad_idx=self.pad_id)
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, batch_seq_lens
]
if phase != "predict": if phase != "predict":
batch_labels = [record.label_id for record in batch_records] batch_labels = [record.label_id for record in batch_records]
batch_labels = np.array(batch_labels).astype("int64").reshape( batch_labels = np.array(batch_labels).astype("int64").reshape(
[-1, 1]) [-1, 1])
return_list += [batch_labels]
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, batch_labels
]
if self.use_task_id:
padded_task_ids = np.ones_like(
padded_token_ids, dtype="int64") * self.task_id
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_task_ids, batch_labels
]
else:
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask
]
if self.use_task_id:
padded_task_ids = np.ones_like(
padded_token_ids, dtype="int64") * self.task_id
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_task_ids
]
return return_list return return_list
...@@ -369,40 +349,20 @@ class SequenceLabelReader(BaseNLPReader): ...@@ -369,40 +349,20 @@ class SequenceLabelReader(BaseNLPReader):
max_seq_len=self.max_seq_len, max_seq_len=self.max_seq_len,
pad_idx=self.pad_id) pad_idx=self.pad_id)
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask
]
if phase != "predict": if phase != "predict":
batch_label_ids = [record.label_id for record in batch_records] batch_label_ids = [record.label_id for record in batch_records]
padded_label_ids = pad_batch_data( padded_label_ids = pad_batch_data(
batch_label_ids, batch_label_ids,
max_seq_len=self.max_seq_len, max_seq_len=self.max_seq_len,
pad_idx=len(self.label_map) - 1) pad_idx=len(self.label_map) - 1)
return_list += [padded_label_ids, batch_seq_lens]
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_label_ids, batch_seq_lens
]
if self.use_task_id:
padded_task_ids = np.ones_like(
padded_token_ids, dtype="int64") * self.task_id
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_task_ids, padded_label_ids,
batch_seq_lens
]
else: else:
return_list = [ return_list += [batch_seq_lens]
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, batch_seq_lens
]
if self.use_task_id:
padded_task_ids = np.ones_like(
padded_token_ids, dtype="int64") * self.task_id
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_task_ids, batch_seq_lens
]
return return_list return return_list
...@@ -514,37 +474,18 @@ class MultiLabelClassifyReader(BaseNLPReader): ...@@ -514,37 +474,18 @@ class MultiLabelClassifyReader(BaseNLPReader):
max_seq_len=self.max_seq_len, max_seq_len=self.max_seq_len,
pad_idx=self.pad_id) pad_idx=self.pad_id)
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask
]
if phase != "predict": if phase != "predict":
batch_labels_ids = [record.label_id for record in batch_records] batch_labels_ids = [record.label_id for record in batch_records]
num_label = len(self.dataset.get_labels()) num_label = len(self.dataset.get_labels())
batch_labels = np.array(batch_labels_ids).astype("int64").reshape( batch_labels = np.array(batch_labels_ids).astype("int64").reshape(
[-1, num_label]) [-1, num_label])
return_list = [ return_list += [batch_labels]
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, batch_labels
]
if self.use_task_id:
padded_task_ids = np.ones_like(
padded_token_ids, dtype="int64") * self.task_id
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_task_ids, batch_labels
]
else:
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask
]
if self.use_task_id:
padded_task_ids = np.ones_like(
padded_token_ids, dtype="int64") * self.task_id
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_task_ids
]
return return_list return return_list
def _convert_example_to_record(self, def _convert_example_to_record(self,
...@@ -634,37 +575,17 @@ class RegressionReader(BaseNLPReader): ...@@ -634,37 +575,17 @@ class RegressionReader(BaseNLPReader):
max_seq_len=self.max_seq_len, max_seq_len=self.max_seq_len,
pad_idx=self.pad_id) pad_idx=self.pad_id)
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask
]
if phase != "predict": if phase != "predict":
batch_labels = [record.label_id for record in batch_records] batch_labels = [record.label_id for record in batch_records]
# the only diff with ClassifyReader: astype("float32") # the only diff with ClassifyReader: astype("float32")
batch_labels = np.array(batch_labels).astype("float32").reshape( batch_labels = np.array(batch_labels).astype("float32").reshape(
[-1, 1]) [-1, 1])
return_list = [ return_list += [batch_labels]
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, batch_labels
]
if self.use_task_id:
padded_task_ids = np.ones_like(
padded_token_ids, dtype="int64") * self.task_id
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_task_ids, batch_labels
]
else:
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask
]
if self.use_task_id:
padded_task_ids = np.ones_like(
padded_token_ids, dtype="int64") * self.task_id
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_task_ids
]
return return_list return return_list
...@@ -831,6 +752,10 @@ class ReadingComprehensionReader(BaseNLPReader): ...@@ -831,6 +752,10 @@ class ReadingComprehensionReader(BaseNLPReader):
pad_idx=self.pad_id, pad_idx=self.pad_id,
max_seq_len=self.max_seq_len) max_seq_len=self.max_seq_len)
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, batch_unique_ids
]
if phase != "predict": if phase != "predict":
batch_start_position = [ batch_start_position = [
record.start_position for record in batch_records record.start_position for record in batch_records
...@@ -843,33 +768,8 @@ class ReadingComprehensionReader(BaseNLPReader): ...@@ -843,33 +768,8 @@ class ReadingComprehensionReader(BaseNLPReader):
batch_end_position = np.array(batch_end_position).astype( batch_end_position = np.array(batch_end_position).astype(
"int64").reshape([-1, 1]) "int64").reshape([-1, 1])
return_list = [ return_list += [batch_start_position, batch_end_position]
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, batch_unique_ids, batch_start_position,
batch_end_position
]
if self.use_task_id:
padded_task_ids = np.ones_like(
padded_token_ids, dtype="int64") * self.task_id
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_task_ids, batch_unique_ids,
batch_start_position, batch_end_position
]
else:
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, batch_unique_ids
]
if self.use_task_id:
padded_task_ids = np.ones_like(
padded_token_ids, dtype="int64") * self.task_id
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_task_ids, batch_unique_ids
]
return return_list return return_list
def _prepare_batch_data(self, records, batch_size, phase=None): def _prepare_batch_data(self, records, batch_size, phase=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册