Commit bc2c5e71 authored by Steffy-zxf, committed by wuzewu

Add multi label task (#51)

* Modify the coding format of senta_demo.py and lac_demo.py

* Add the toxic dataset and multi-label task

* Add the multi_label classification demo
Parent e6f4a801
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetuning on classification task """
import argparse
import ast
import paddle.fluid as fluid
import paddlehub as hub
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--num_epoch", type=int, default=3, help="Number of epoches for fine-tuning.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether use GPU for finetuning, input should be True or False")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.")
parser.add_argument("--warmup_proportion", type=float, default=0.1, help="Warmup proportion params for warmup strategy")
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
parser.add_argument("--max_seq_len", type=int, default=128, help="Number of words of the longest seqence.")
parser.add_argument("--batch_size", type=int, default=1, help="Total examples' number in batch for training.")
args = parser.parse_args()
# yapf: enable.
if __name__ == '__main__':
# Step1: load the PaddleHub BERT pretrained model
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
inputs, outputs, program = module.context(
trainable=True, max_seq_len=args.max_seq_len)
# Step2: Download the dataset and use MultiLabelClassifyReader to read it
dataset = hub.dataset.Toxic()
reader = hub.reader.MultiLabelClassifyReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len)
# Construct transfer learning network
# Use "pooled_output" for classification tasks on an entire sentence.
pooled_output = outputs["pooled_output"]
# Setup feed list for data feeder
# Must feed all the tensors that the BERT module needs
feed_list = [
inputs["input_ids"].name,
inputs["position_ids"].name,
inputs["segment_ids"].name,
inputs["input_mask"].name,
]
# Select finetune strategy, setup config and finetune
strategy = hub.AdamWeightDecayStrategy(
weight_decay=args.weight_decay,
learning_rate=args.learning_rate,
lr_scheduler="linear_decay")
# Set up the running config for the PaddleHub Fine-tune API
config = hub.RunConfig(
use_cuda=args.use_gpu,
num_epoch=args.num_epoch,
batch_size=args.batch_size,
checkpoint_dir=args.checkpoint_dir,
strategy=strategy)
# Define a multi-label classification fine-tune task by PaddleHub's API
multi_label_cls_task = hub.MultiLabelClassifierTask(
data_reader=reader,
feature=pooled_output,
feed_list=feed_list,
num_classes=dataset.num_labels,
config=config)
# Fine-tune and evaluate with PaddleHub's API;
# training, evaluation, testing and model saving are handled automatically
multi_label_cls_task.finetune_and_eval()
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetuning on classification task """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import ast
import numpy as np
import os
import time
import paddle
import paddle.fluid as fluid
import paddlehub as hub
import pandas as pd
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
parser.add_argument("--batch_size", type=int, default=1, help="Total examples' number in batch for training.")
parser.add_argument("--max_seq_len", type=int, default=128, help="Number of words of the longest seqence.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether use GPU for finetuning, input should be True or False")
args = parser.parse_args()
# yapf: enable.
if __name__ == '__main__':
# Load the PaddleHub BERT pretrained model
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
inputs, outputs, program = module.context(max_seq_len=args.max_seq_len)
# Multi-label classification dataset and reader
dataset = hub.dataset.Toxic()
num_label = len(dataset.get_labels())
reader = hub.reader.MultiLabelClassifyReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# Construct transfer learning network
# Use "pooled_output" for classification tasks on an entire sentence.
# Use "sequence_output" for token-level output.
pooled_output = outputs["pooled_output"]
# Setup feed list for data feeder
# Must feed all the tensors that the BERT module needs
feed_list = [
inputs["input_ids"].name,
inputs["position_ids"].name,
inputs["segment_ids"].name,
inputs["input_mask"].name,
]
# Set up the running config for the PaddleHub Fine-tune API
config = hub.RunConfig(
use_data_parallel=False,
use_pyreader=False,
use_cuda=args.use_gpu,
batch_size=args.batch_size,
enable_memory_optim=False,
checkpoint_dir=args.checkpoint_dir,
strategy=hub.finetune.strategy.DefaultFinetuneStrategy())
# Define a multi-label classification fine-tune task by PaddleHub's API
multi_label_cls_task = hub.MultiLabelClassifierTask(
data_reader=reader,
feature=pooled_output,
feed_list=feed_list,
num_classes=dataset.num_labels,
config=config)
# Data to be predicted
data = [
[
"Yes you did. And you admitted to doing it. See the Warren Kinsella talk page."
],
[
"I asked you a question. We both know you have my page on your watch list, so are why are you playing games and making me formally ping you? Makin'Bacon"
],
]
index = 0
results = multi_label_cls_task.predict(data=data)
for result in results:
# get the predicted label ids (argmax of each label's 2-way softmax output)
label_ids = []
for i in range(num_label):
label_val = np.argmax(result[i])
label_ids.append(label_val)
print("%s\tpredict=%s" % (data[index][0], label_ids))
index += 1
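For readability, the 0/1 vector printed above can be mapped back to label names. Below is a minimal post-processing sketch, assuming the same `dataset`, `data` and `results` objects created in the demo above; it is illustrative only and not part of the original script.

# Hypothetical helper: turn per-label argmax results into label names.
label_names = dataset.get_labels()
for text, result in zip(data, results):
    # result[i] is the 2-way softmax output for label i; argmax == 1 means the label applies
    predicted = [label_names[i] for i in range(len(label_names)) if np.argmax(result[i]) == 1]
    print("%s\tlabels=%s" % (text[0], predicted if predicted else ["none"]))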
export CUDA_VISIBLE_DEVICES=0
# Users can switch DATASET (e.g. chnsenticorp, nlpcc_dbqa, lcqmc) to run a different task
DATASET="toxic"
CKPT_DIR="./ckpt_${DATASET}"
# Recommended hyper-parameters for different tasks
# ChnSentiCorp: batch_size=24, weight_decay=0.01, num_epoch=3, max_seq_len=128, lr=5e-5
# NLPCC_DBQA: batch_size=8, weight_decay=0.01, num_epoch=3, max_seq_len=512, lr=2e-5
# LCQMC: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=2e-5
python -u multi_label_classifier.py \
--batch_size=32 \
--use_gpu=True \
--checkpoint_dir=${CKPT_DIR} \
--learning_rate=5e-5 \
--weight_decay=0.01 \
--max_seq_len=128 \
--num_epoch=3
export CUDA_VISIBLE_DEVICES=0
CKPT_DIR="./ckpt_toxic"
python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 --use_gpu True
@@ -47,6 +47,7 @@ from .finetune.task import ClassifierTask
from .finetune.task import TextClassifierTask
from .finetune.task import ImageClassifierTask
from .finetune.task import SequenceLabelTask
from .finetune.task import MultiLabelClassifierTask
from .finetune.config import RunConfig
from .finetune.strategy import AdamWeightDecayStrategy
from .finetune.strategy import DefaultStrategy
@@ -19,6 +19,7 @@ from .chnsenticorp import ChnSentiCorp
from .msra_ner import MSRA_NER
from .nlpcc_dbqa import NLPCC_DBQA
from .lcqmc import LCQMC
from .toxic import Toxic
# CV Dataset
from .dogcat import DogCatDataset as DogCat
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import codecs
import os
import pandas as pd
from numpy import nan
from paddlehub.dataset import InputExample, HubDataset
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger
_DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/toxic.tar.gz"
class Toxic(HubDataset):
"""
The Toxic dataset (from the Kaggle Toxic Comment Classification Challenge): an
English multi-label dataset in which each comment may carry any subset of six
toxicity labels.
"""
def __init__(self):
self.dataset_dir = os.path.join(DATA_HOME, "toxic")
if not os.path.exists(self.dataset_dir):
ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
else:
logger.info("Dataset {} already cached.".format(self.dataset_dir))
self._load_train_examples()
self._load_test_examples()
self._load_dev_examples()
def _load_train_examples(self):
self.train_file = os.path.join(self.dataset_dir, "train.csv")
self.train_examples = self._read_csv(self.train_file)
def _load_dev_examples(self):
self.dev_file = os.path.join(self.dataset_dir, "dev.csv")
self.dev_examples = self._read_csv(self.dev_file)
def _load_test_examples(self):
self.test_file = os.path.join(self.dataset_dir, "test.csv")
self.test_examples = self._read_csv(self.test_file)
def get_train_examples(self):
return self.train_examples
def get_dev_examples(self):
return self.dev_examples
def get_test_examples(self):
return self.test_examples
def get_labels(self):
return [
'toxic', 'severe_toxic', 'obscene', 'threat', 'insult',
'identity_hate'
]
@property
def num_labels(self):
"""
Return the number of labels in the dataset.
"""
return len(self.get_labels())
def _read_csv(self, input_file, quotechar=None):
"""Reads a tab separated value file."""
data = pd.read_csv(input_file, encoding="UTF-8")
examples = []
for index, row in data.iterrows():
guid = row["id"]
text = row["comment_text"]
labels = [int(value) for value in row[2:]]
example = InputExample(guid=guid, label=labels, text_a=text)
examples.append(example)
return examples
if __name__ == "__main__":
ds = Toxic()
for e in ds.get_train_examples():
print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
@@ -22,6 +22,7 @@ import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers.learning_rate_scheduler as lr_scheduler
from paddle.fluid.layers import control_flow
from paddlehub.common.logger import logger
def adam_weight_decay_optimization(loss,
@@ -31,21 +32,30 @@ def adam_weight_decay_optimization(loss,
main_program,
weight_decay,
scheduler='linear_decay'):
-    if warmup_steps > 0:
-        if scheduler == 'noam_decay':
+    if scheduler == 'noam_decay':
+        if warmup_steps > 0:
             scheduled_lr = fluid.layers.learning_rate_scheduler\
              .noam_decay(1/(warmup_steps *(learning_rate ** 2)),
                          warmup_steps)
-        elif scheduler == 'linear_decay':
-            scheduled_lr = linear_warmup_decay(learning_rate, warmup_steps,
-                                               main_program)
         else:
-            raise ValueError("Unkown learning rate scheduler, should be "
-                             "'noam_decay' or 'linear_decay'")
-        optimizer = fluid.optimizer.Adam(learning_rate=scheduled_lr)
+            logger.warning(
+                "Noam decay learning rate scheduler should have positive \
+                warmup steps, using constant learning rate instead!")
+            scheduled_lr = fluid.layers.create_global_var(
+                shape=[1],
+                value=learning_rate,
+                dtype='float32',
+                persistable=True,
+                name="learning_rate")
+    elif scheduler == 'linear_decay':
+        scheduled_lr = linear_warmup_decay(learning_rate, num_train_steps,
+                                           warmup_steps, main_program)
     else:
-        optimizer = fluid.optimizer.Adam(learning_rate=learning_rate)
-        scheduled_lr = learning_rate
+        raise ValueError("Unkown learning rate scheduler, should be "
+                         "'noam_decay' or 'linear_decay'")
+    optimizer = fluid.optimizer.Adam(learning_rate=scheduled_lr)
clip_norm_thres = 1.0
fluid.clip.set_gradient_clip(
@@ -81,13 +91,14 @@ def adam_weight_decay_optimization(loss,
return scheduled_lr
-def linear_warmup_decay(init_lr, num_warmup_steps, main_program):
+def linear_warmup_decay(init_lr, num_train_steps, num_warmup_steps,
+                        main_program):
with main_program._lr_schedule_guard():
global_step = lr_scheduler._decay_step_counter()
lr = fluid.layers.create_global_var(
shape=[1],
-        value=0.0,
+        value=init_lr,
dtype='float32',
persistable=True,
name="learning_rate")
@@ -97,8 +108,12 @@ def linear_warmup_decay(init_lr, num_warmup_steps, main_program):
decayed_lr = init_lr * global_step * 1.0 / num_warmup_steps
fluid.layers.assign(decayed_lr, lr)
with switch.default():
-                last_value_var = fluid.layers.fill_constant(
-                    shape=[1], dtype='float32', value=float(init_lr))
-                fluid.layers.assign(last_value_var, lr)
+                decayed_lr = lr_scheduler.polynomial_decay(
+                    learning_rate=init_lr,
+                    decay_steps=num_train_steps,
+                    end_learning_rate=0.0,
+                    power=1.0,
+                    cycle=False)
+                fluid.layers.assign(decayed_lr, lr)
return lr
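The net effect of the two changes above: the learning rate still warms up linearly, but after warmup it now decays linearly (polynomial decay with power=1.0) to zero at num_train_steps instead of staying constant. A rough pure-Python sketch of the resulting schedule under those assumptions, illustrative only and not the fluid graph code:

def sketch_lr(step, init_lr, num_train_steps, num_warmup_steps):
    """Illustrative value of the scheduled learning rate at a given global step."""
    if step < num_warmup_steps:
        # warmup branch: lr ramps from 0 up to init_lr
        return init_lr * step * 1.0 / num_warmup_steps
    # default branch: polynomial_decay with power=1.0 and end_learning_rate=0.0
    return init_lr * max(1.0 - float(step) / num_train_steps, 0.0)

# e.g. with init_lr=5e-5, 1000 train steps, 100 warmup steps:
# sketch_lr(50, 5e-5, 1000, 100)  -> 2.5e-05 (mid-warmup)
# sketch_lr(500, 5e-5, 1000, 100) -> 2.5e-05 (halfway through the decay)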
@@ -38,7 +38,7 @@ from paddlehub.finetune.config import RunConfig
__all__ = [
"ClassifierTask", "ImageClassifierTask", "TextClassifierTask",
"SequenceLabelTask"
"SequenceLabelTask", "MultiLabelClassifierTask"
]
@@ -859,10 +859,152 @@ class SequenceLabelTask(BasicTask):
feed_list += [self.seq_len.name]
return feed_list
class MultiLabelClassifierTask(ClassifierTask):
def __init__(self,
data_reader,
feature,
num_classes,
feed_list,
startup_program=None,
config=None,
hidden_units=None):
main_program = feature.block.program
super(MultiLabelClassifierTask, self).__init__(
data_reader=data_reader,
feature=feature,
num_classes=num_classes,
feed_list=feed_list,
startup_program=startup_program,
config=config,
hidden_units=hidden_units)
self.best_avg_auc = -1
def _build_net(self):
cls_feats = fluid.layers.dropout(
x=self.feature,
dropout_prob=0.1,
dropout_implementation="upscale_in_train")
if self.hidden_units is not None:
for n_hidden in self.hidden_units:
cls_feats = fluid.layers.fc(
input=cls_feats, size=n_hidden, act="relu")
probs = []
for i in range(self.num_classes):
probs.append(
fluid.layers.fc(
input=cls_feats,
size=2,
param_attr=fluid.ParamAttr(
name="cls_out_w_%d" % i,
initializer=fluid.initializer.TruncatedNormal(
scale=0.02)),
bias_attr=fluid.ParamAttr(
name="cls_out_b_%d" % i,
initializer=fluid.initializer.Constant(0.)),
act="softmax"))
return probs
def _add_label(self):
label = fluid.layers.data(
name="label", shape=[self.num_classes], dtype='int64')
return label
def _add_loss(self):
label_split = fluid.layers.split(self.label, self.num_classes, dim=-1)
total_loss = fluid.layers.fill_constant(
shape=[1], value=0.0, dtype='float64')
for index, probs in enumerate(self.output):
ce_loss = fluid.layers.cross_entropy(
input=probs, label=label_split[index])
total_loss += fluid.layers.reduce_sum(ce_loss)
loss = fluid.layers.mean(x=total_loss)
return loss
def _add_metrics(self):
label_split = fluid.layers.split(self.label, self.num_classes, dim=-1)
# metrics change to auc of every class
eval_list = []
for index, probs in enumerate(self.output):
current_auc, _, _ = fluid.layers.auc(
input=probs, label=label_split[index])
eval_list.append(current_auc)
return eval_list
def _build_env_end_event(self):
with self.log_writer.mode(self.phase) as logw:
self.env.loss_scalar = logw.scalar(
tag="Loss [{}]".format(self.phase))
self.env.auc_scalar_list = []
for i in range(self.num_classes):
self.env.auc_scalar_list.append(
logw.scalar(tag="AUC_{} [{}]".format(i, "train")))
self.env.avg_auc_scalar = logw.scalar(
tag="Average auc [{}]".format(self.phase))
def _calculate_metrics(self, run_states):
loss_sum = acc_sum = run_examples = 0
run_step = run_time_used = 0
for run_state in run_states:
run_examples += run_state.run_examples
run_step += run_state.run_step
loss_sum += np.mean(
run_state.run_results[-1]) * run_state.run_examples
auc_list = run_states[-1].run_results[:-1]
run_time_used = time.time() - run_states[0].run_time_begin
avg_loss = loss_sum / (run_examples * self.num_classes)
run_speed = run_step / run_time_used
return avg_loss, auc_list, run_speed
def _log_interval_event(self, run_states):
avg_loss, auc_list, run_speed = self._calculate_metrics(run_states)
if self.is_train_phase:
for index, auc_scalar in enumerate(self.env.auc_scalar_list):
auc_scalar.add_record(self.current_step, auc_list[index])
self.env.loss_scalar.add_record(self.current_step, avg_loss)
avg_auc = np.mean(auc_list)
self.env.avg_auc_scalar.add_record(self.current_step, avg_auc)
logger.info("step %d: loss=%.5f avg_auc=%.5f [step/sec: %.2f]" %
(self.current_step, avg_loss, avg_auc, run_speed))
for index, auc in enumerate(auc_list):
logger.info("label_%d_auc = %.5f" % (index, auc_list[index][0]))
def _eval_end_event(self, run_states):
eval_loss, auc_list, run_speed = self._calculate_metrics(run_states)
if self.is_train_phase:
for index, auc_scalar in enumerate(self.env.auc_scalar_list):
auc_scalar.add_record(self.current_step, auc_list[index])
avg_auc = np.mean(auc_list)
logger.info(
"[%s dataset evaluation result] loss=%.5f avg_auc=%.5f [step/sec: %.2f]"
% (self.phase, eval_loss, avg_auc, run_speed))
for index, auc in enumerate(auc_list):
logger.info("label_%d_auc = %.5f" % (index, auc_list[index][0]))
if self.phase in ["dev", "val"] and avg_auc > self.best_avg_auc:
self.env.loss_scalar.add_record(self.current_step, eval_loss)
for index, auc_scalar in enumerate(self.env.auc_scalar_list):
auc_scalar.add_record(self.current_step, auc_list[index])
self.env.avg_auc_scalar.add_record(self.current_step, avg_auc)
self.best_avg_auc = avg_auc
model_saved_dir = os.path.join(self.config.checkpoint_dir,
"best_model")
logger.info("best model saved to %s [best average auc=%.5f]" %
(model_saved_dir, self.best_avg_auc))
save_result = fluid.io.save_persistables(
executor=self.exe,
dirname=model_saved_dir,
main_program=self.main_program)
@property
def fetch_list(self):
if self.is_train_phase or self.is_test_phase:
return [metric.name for metric in self.metrics] + [self.loss.name]
elif self.is_predict_phase:
return [self.ret_infers.name] + [self.seq_len.name]
return [self.output.name]
return self.output
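A note on the formulation used by MultiLabelClassifierTask above: rather than one sigmoid output per label, each of the num_classes labels gets its own 2-way softmax head, the per-label cross-entropy losses are summed, and evaluation tracks the per-label AUC (averaged for best-checkpoint selection via best_avg_auc). The following numpy sketch of that loss is illustrative only, not the fluid implementation:

import numpy as np

def multi_label_loss(per_label_probs, labels):
    """per_label_probs: list (length num_classes) of [batch, 2] softmax outputs;
    labels: int array of shape [batch, num_classes] with 0/1 entries."""
    total = 0.0
    for i, probs in enumerate(per_label_probs):
        # probability assigned to the true class of label i for every example
        true_class_prob = probs[np.arange(labels.shape[0]), labels[:, i]]
        total += -np.log(true_class_prob).sum()  # cross_entropy + reduce_sum per label
    return total  # summed over labels and examples, as in _add_loss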
@@ -16,4 +16,5 @@
from .nlp_reader import ClassifyReader
from .nlp_reader import SequenceLabelReader
from .nlp_reader import LACClassifyReader
from .nlp_reader import MultiLabelClassifyReader
from .cv_reader import ImageClassificationReader
@@ -553,5 +553,113 @@ class LACClassifyReader(object):
return paddle.batch(_data_reader, batch_size=batch_size)
class MultiLabelClassifyReader(BaseReader):
def _pad_batch_records(self, batch_records, phase=None):
batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records]
# padding
padded_token_ids, input_mask = pad_batch_data(
batch_token_ids,
pad_idx=self.pad_id,
max_seq_len=self.max_seq_len,
return_input_mask=True)
padded_text_type_ids = pad_batch_data(
batch_text_type_ids,
max_seq_len=self.max_seq_len,
pad_idx=self.pad_id)
padded_position_ids = pad_batch_data(
batch_position_ids,
max_seq_len=self.max_seq_len,
pad_idx=self.pad_id)
if phase != "predict":
batch_labels_ids = [record.label_ids for record in batch_records]
num_label = len(self.dataset.get_labels())
batch_labels = np.array(batch_labels_ids).astype("int64").reshape(
[-1, num_label])
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, batch_labels
]
else:
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask
]
return return_list
def _convert_example_to_record(self,
example,
max_seq_length,
tokenizer,
phase=None):
"""Converts a single `Example` into a single `Record`."""
text_a = tokenization.convert_to_unicode(example.text_a)
tokens_a = tokenizer.tokenize(text_a)
tokens_b = None
if example.text_b is not None:
#if "text_b" in example._fields:
text_b = tokenization.convert_to_unicode(example.text_b)
tokens_b = tokenizer.tokenize(text_b)
if tokens_b:
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
self._truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
else:
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[0:(max_seq_length - 2)]
tokens = []
text_type_ids = []
tokens.append("[CLS]")
text_type_ids.append(0)
for token in tokens_a:
tokens.append(token)
text_type_ids.append(0)
tokens.append("[SEP]")
text_type_ids.append(0)
if tokens_b:
for token in tokens_b:
tokens.append(token)
text_type_ids.append(1)
tokens.append("[SEP]")
text_type_ids.append(1)
token_ids = tokenizer.convert_tokens_to_ids(tokens)
position_ids = list(range(len(token_ids)))
label_ids = []
for label in example.label:
label_ids.append(int(label))
if phase != "predict":
Record = namedtuple(
'Record',
['token_ids', 'text_type_ids', 'position_ids', 'label_ids'])
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids,
label_ids=label_ids)
else:
Record = namedtuple('Record',
['token_ids', 'text_type_ids', 'position_ids'])
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids)
return record
if __name__ == '__main__':
pass