提交 74ec0647 编写于 作者: Z Zeyu Chen

reorg finetune and reader

上级 810cdd3a
...@@ -49,6 +49,7 @@ if __name__ == '__main__': ...@@ -49,6 +49,7 @@ if __name__ == '__main__':
# Setup runing config for PaddleHub Finetune API # Setup runing config for PaddleHub Finetune API
config = hub.RunConfig( config = hub.RunConfig(
eval_interval=10,
use_cuda=True, use_cuda=True,
num_epoch=args.num_epoch, num_epoch=args.num_epoch,
batch_size=args.batch_size, batch_size=args.batch_size,
......
...@@ -40,6 +40,3 @@ from .finetune.finetune import finetune_and_eval ...@@ -40,6 +40,3 @@ from .finetune.finetune import finetune_and_eval
from .finetune.config import RunConfig from .finetune.config import RunConfig
from .finetune.strategy import BERTFinetuneStrategy from .finetune.strategy import BERTFinetuneStrategy
from .finetune.strategy import DefaultStrategy from .finetune.strategy import DefaultStrategy
from .reader import BERTTokenizeReader
from .reader.cv_reader import ImageClassificationReader
...@@ -13,3 +13,4 @@ ...@@ -13,3 +13,4 @@
# limitations under the License. # limitations under the License.
from . import utils from . import utils
from .utils import get_running_device_info
...@@ -17,6 +17,8 @@ from __future__ import division ...@@ -17,6 +17,8 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import time
import multiprocessing
import hashlib import hashlib
import paddle import paddle
...@@ -185,6 +187,17 @@ def is_yaml_file(file_path): ...@@ -185,6 +187,17 @@ def is_yaml_file(file_path):
return get_file_ext(file_path) == ".yml" return get_file_ext(file_path) == ".yml"
def get_running_device_info(config):
if config.use_cuda:
place = fluid.CUDAPlace(0)
dev_count = fluid.core.get_cuda_device_count()
else:
place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
return place, dev_count
if __name__ == "__main__": if __name__ == "__main__":
print(is_yaml_file("test.yml")) print(is_yaml_file("test.yml"))
print(is_csv_file("test.yml")) print(is_csv_file("test.yml"))
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from datetime import datetime
import time import time
from datetime import datetime
from paddlehub.finetune.strategy import DefaultStrategy from paddlehub.finetune.strategy import DefaultStrategy
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
......
...@@ -12,6 +12,14 @@ ...@@ -12,6 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
import paddlehub as hub
from paddlehub.common.logger import logger
def evaluate_cls_task(task, data_reader, feed_list, phase="test", config=None): def evaluate_cls_task(task, data_reader, feed_list, phase="test", config=None):
logger.info("Evaluation on {} dataset start".format(phase)) logger.info("Evaluation on {} dataset start".format(phase))
...@@ -20,7 +28,7 @@ def evaluate_cls_task(task, data_reader, feed_list, phase="test", config=None): ...@@ -20,7 +28,7 @@ def evaluate_cls_task(task, data_reader, feed_list, phase="test", config=None):
loss = task.variable("loss") loss = task.variable("loss")
accuracy = task.variable("accuracy") accuracy = task.variable("accuracy")
batch_size = config.batch_size batch_size = config.batch_size
place, dev_count = _get_running_device_info(config) place, dev_count = hub.common.get_running_device_info(config)
exe = fluid.Executor(place=place) exe = fluid.Executor(place=place)
with fluid.program_guard(inference_program): with fluid.program_guard(inference_program):
data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place) data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
...@@ -64,7 +72,7 @@ def evaluate_seq_labeling_task(task, ...@@ -64,7 +72,7 @@ def evaluate_seq_labeling_task(task,
logger.info("Evaluation on {} dataset start".format(phase)) logger.info("Evaluation on {} dataset start".format(phase))
inference_program = task.inference_program() inference_program = task.inference_program()
batch_size = config.batch_size batch_size = config.batch_size
place, dev_count = _get_running_device_info(config) place, dev_count = hub.common.get_running_device_info(config)
exe = fluid.Executor(place=place) exe = fluid.Executor(place=place)
num_labels = len(data_reader.get_labels()) num_labels = len(data_reader.get_labels())
with fluid.program_guard(inference_program): with fluid.program_guard(inference_program):
......
...@@ -18,10 +18,10 @@ from __future__ import print_function ...@@ -18,10 +18,10 @@ from __future__ import print_function
import os import os
import time import time
import multiprocessing
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddlehub as hub
import numpy as np import numpy as np
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
...@@ -29,18 +29,6 @@ from paddlehub.finetune.strategy import BERTFinetuneStrategy, DefaultStrategy ...@@ -29,18 +29,6 @@ from paddlehub.finetune.strategy import BERTFinetuneStrategy, DefaultStrategy
from paddlehub.finetune.checkpoint import load_checkpoint, save_checkpoint from paddlehub.finetune.checkpoint import load_checkpoint, save_checkpoint
from paddlehub.finetune.evaluate import evaluate_cls_task, evaluate_seq_labeling_task from paddlehub.finetune.evaluate import evaluate_cls_task, evaluate_seq_labeling_task
from visualdl import LogWriter from visualdl import LogWriter
import paddlehub as hub
def _get_running_device_info(config):
if config.use_cuda:
place = fluid.CUDAPlace(0)
dev_count = fluid.core.get_cuda_device_count()
else:
place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
return place, dev_count
def _do_memory_optimization(task, config): def _do_memory_optimization(task, config):
...@@ -80,7 +68,7 @@ def _finetune_seq_label_task(task, ...@@ -80,7 +68,7 @@ def _finetune_seq_label_task(task,
num_epoch = config.num_epoch num_epoch = config.num_epoch
batch_size = config.batch_size batch_size = config.batch_size
place, dev_count = _get_running_device_info(config) place, dev_count = hub.common.get_running_device_info(config)
with fluid.program_guard(main_program, startup_program): with fluid.program_guard(main_program, startup_program):
exe = fluid.Executor(place=place) exe = fluid.Executor(place=place)
data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place) data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
...@@ -177,7 +165,7 @@ def _finetune_cls_task(task, data_reader, feed_list, config=None, ...@@ -177,7 +165,7 @@ def _finetune_cls_task(task, data_reader, feed_list, config=None,
log_writter = LogWriter( log_writter = LogWriter(
os.path.join(config.checkpoint_dir, "vdllog"), sync_cycle=10) os.path.join(config.checkpoint_dir, "vdllog"), sync_cycle=10)
place, dev_count = _get_running_device_info(config) place, dev_count = hub.common.get_running_device_info(config)
with fluid.program_guard(main_program, startup_program): with fluid.program_guard(main_program, startup_program):
exe = fluid.Executor(place=place) exe = fluid.Executor(place=place)
data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place) data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
......
...@@ -12,6 +12,6 @@ ...@@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .nlp_reader import BERTTokenizeReader from .nlp_reader import ClassifyReader
from .task_reader import ClassifyReader from .nlp_reader import SequenceLabelReader
from .task_reader import SequenceLabelReader from .cv_reader import ImageClassificationReader
...@@ -12,33 +12,52 @@ ...@@ -12,33 +12,52 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import types
import csv import csv
import json
import numpy as np import numpy as np
from collections import namedtuple
#from paddlehub import dataset
from paddlehub.reader import tokenization from paddlehub.reader import tokenization
from paddlehub.reader.batching import prepare_batch_data from .batching import pad_batch_data
class BERTTokenizeReader(object): class BaseReader(object):
"""Base class for data converters for sequence classification data sets."""
def __init__(self, def __init__(self,
dataset, dataset,
vocab_path, vocab_path,
max_seq_len, label_map_config=None,
max_seq_len=512,
do_lower_case=True, do_lower_case=True,
in_tokens=False,
random_seed=None): random_seed=None):
self.dataset = dataset
self.max_seq_len = max_seq_len self.max_seq_len = max_seq_len
self.tokenizer = tokenization.FullTokenizer( self.tokenizer = tokenization.FullTokenizer(
vocab_file=vocab_path, do_lower_case=do_lower_case) vocab_file=vocab_path, do_lower_case=do_lower_case)
self.vocab = self.tokenizer.vocab self.vocab = self.tokenizer.vocab
self.dataset = dataset
self.pad_id = self.vocab["[PAD]"]
self.cls_id = self.vocab["[CLS]"]
self.sep_id = self.vocab["[SEP]"]
self.in_tokens = in_tokens
np.random.seed(random_seed) np.random.seed(random_seed)
# generate label map
self.label_map = {}
for index, label in enumerate(self.dataset.get_labels()):
self.label_map[label] = index
print("Dataset label map = {}".format(self.label_map))
self.current_example = 0
self.current_epoch = 0
self.num_examples = 0
# if label_map_config:
# with open(label_map_config) as f:
# self.label_map = json.load(f)
# else:
# self.label_map = None
self.num_examples = {'train': -1, 'dev': -1, 'test': -1} self.num_examples = {'train': -1, 'dev': -1, 'test': -1}
def get_train_examples(self): def get_train_examples(self):
...@@ -61,116 +80,11 @@ class BERTTokenizeReader(object): ...@@ -61,116 +80,11 @@ class BERTTokenizeReader(object):
"""Gets the list of labels for this data set.""" """Gets the list of labels for this data set."""
return self.dataset.get_labels() return self.dataset.get_labels()
def convert_example(self, index, example, labels, max_seq_len, tokenizer): def get_train_progress(self):
"""Converts a single `InputExample` into a single `InputFeatures`.""" """Gets progress for training phase."""
feature = convert_single_example(index, example, labels, max_seq_len, return self.current_example, self.current_epoch
tokenizer)
return feature
def generate_instance(self, feature):
"""
generate instance with given feature
Args:
feature: InputFeatures(object). A single set of features of data.
"""
position_ids = list(range(len(feature.input_ids)))
return [
feature.input_ids, feature.segment_ids, position_ids,
feature.label_id
]
def generate_batch_data(self,
batch_data,
total_token_num,
return_input_mask=True,
return_max_len=False,
return_num_token=False):
return prepare_batch_data(
batch_data,
total_token_num,
max_seq_len=self.max_seq_len,
pad_id=self.vocab["[PAD]"],
cls_id=self.vocab["[CLS]"],
sep_id=self.vocab["[SEP]"],
return_input_mask=return_input_mask,
return_max_len=return_max_len,
return_num_token=return_num_token)
def get_num_examples(self, phase):
"""Get number of examples for train, dev or test."""
if phase not in ['train', 'val', 'dev', 'test']:
raise ValueError(
"Unknown phase, which should be in ['train', 'val'/'dev', 'test']."
)
return self.num_examples[phase]
def data_generator(self, batch_size, phase='train', shuffle=True):
"""
Generate data for train, dev/val or test.
Args:
batch_size: int. The batch size of generated data.
phase: string. The phase for which to generate data.
epoch: int. Total epoches to generate data.
shuffle: bool. Whether to shuffle examples.
"""
if phase == 'train':
examples = self.get_train_examples()
self.num_examples['train'] = len(examples)
elif phase == 'val' or phase == 'dev':
examples = self.get_dev_examples()
self.num_examples['dev'] = len(examples)
elif phase == 'test':
examples = self.get_test_examples()
self.num_examples['test'] = len(examples)
else:
raise ValueError(
"Unknown phase, which should be in ['train', 'dev', 'test'].")
def instance_reader():
"""
convert a single instance to BERT input feature
"""
if shuffle:
np.random.shuffle(examples)
for (index, example) in enumerate(examples):
feature = self.convert_example(index, example,
self.get_labels(),
self.max_seq_len, self.tokenizer)
instance = self.generate_instance(feature)
yield instance
def batch_reader(reader, batch_size):
batch, total_token_num, max_len = [], 0, 0
for instance in reader():
token_ids, sent_ids, pos_ids, label = instance[:4]
max_len = max(max_len, len(token_ids))
batch.append(instance)
total_token_num += len(token_ids)
if len(batch) == batch_size:
yield batch, total_token_num
batch, total_token_num, max_len = [], 0, 0
if len(batch) > 0:
yield batch, total_token_num
def wrapper():
for batch_data, total_token_num in batch_reader(
instance_reader, batch_size):
batch_data = self.generate_batch_data(
batch_data,
total_token_num,
return_input_mask=True,
return_max_len=True,
return_num_token=False)
yield [batch_data]
return wrapper
def _truncate_seq_pair(tokens_a, tokens_b, max_length): def _truncate_seq_pair(self, tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length.""" """Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence # This is a simple heuristic which will always truncate the longer sequence
...@@ -186,47 +100,28 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length): ...@@ -186,47 +100,28 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
else: else:
tokens_b.pop() tokens_b.pop()
def _convert_example_to_record(self, example, max_seq_length, tokenizer):
"""Converts a single `Example` into a single `Record`."""
class InputFeatures(object): text_a = tokenization.convert_to_unicode(example.text_a)
"""A single set of features of data.""" tokens_a = tokenizer.tokenize(text_a)
def __init__(self, input_ids, input_mask, segment_ids, label_id):
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id
def convert_single_example_to_unicode(guid, single_example):
text_a = tokenization.convert_to_unicode(single_example[0])
text_b = tokenization.convert_to_unicode(single_example[1])
label = tokenization.convert_to_unicode(single_example[2])
return InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
def convert_single_example(ex_index, example, label_list, max_seq_length,
tokenizer):
"""Converts a single `InputExample` into a single `InputFeatures`."""
label_map = {}
for (i, label) in enumerate(label_list):
label_map[label] = i
tokens_a = tokenizer.tokenize(example.text_a)
tokens_b = None tokens_b = None
if example.text_b: if example.text_b is not None:
tokens_b = tokenizer.tokenize(example.text_b) #if "text_b" in example._fields:
text_b = tokenization.convert_to_unicode(example.text_b)
tokens_b = tokenizer.tokenize(text_b)
if tokens_b: if tokens_b:
# Modifies `tokens_a` and `tokens_b` in place so that the total # Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length. # length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3" # Account for [CLS], [SEP], [SEP] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) self._truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
else: else:
# Account for [CLS] and [SEP] with "- 2" # Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2: if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[0:(max_seq_length - 2)] tokens_a = tokens_a[0:(max_seq_length - 2)]
# The convention in BERT is: # The convention in BERT/ERNIE is:
# (a) For sequence pairs: # (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
...@@ -245,52 +140,262 @@ def convert_single_example(ex_index, example, label_list, max_seq_length, ...@@ -245,52 +140,262 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
# used as as the "sentence vector". Note that this only makes sense because # used as as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned. # the entire model is fine-tuned.
tokens = [] tokens = []
segment_ids = [] text_type_ids = []
tokens.append("[CLS]") tokens.append("[CLS]")
segment_ids.append(0) text_type_ids.append(0)
for token in tokens_a: for token in tokens_a:
tokens.append(token) tokens.append(token)
segment_ids.append(0) text_type_ids.append(0)
tokens.append("[SEP]") tokens.append("[SEP]")
segment_ids.append(0) text_type_ids.append(0)
if tokens_b: if tokens_b:
for token in tokens_b: for token in tokens_b:
tokens.append(token) tokens.append(token)
segment_ids.append(1) text_type_ids.append(1)
tokens.append("[SEP]") tokens.append("[SEP]")
segment_ids.append(1) text_type_ids.append(1)
input_ids = tokenizer.convert_tokens_to_ids(tokens) token_ids = tokenizer.convert_tokens_to_ids(tokens)
position_ids = list(range(len(token_ids)))
# The mask has 1 for real tokens and 0 for padding tokens. Only real if self.label_map:
# tokens are attended to. label_id = self.label_map[example.label]
input_mask = [1] * len(input_ids) else:
label_id = example.label
# Record = namedtuple(
# 'Record',
# ['token_ids', 'text_type_ids', 'position_ids', 'label_id', 'qid'])
# qid = None
# if "qid" in example._fields:
# qid = example.qid
# record = Record(
# token_ids=token_ids,
# text_type_ids=text_type_ids,
# position_ids=position_ids,
# label_id=label_id,
# qid=qid)
Record = namedtuple(
'Record',
['token_ids', 'text_type_ids', 'position_ids', 'label_id'])
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids,
label_id=label_id)
return record
def _prepare_batch_data(self, examples, batch_size, phase=None):
"""generate batch records"""
batch_records, max_len = [], 0
for index, example in enumerate(examples):
if phase == "train":
self.current_example = index
record = self._convert_example_to_record(example, self.max_seq_len,
self.tokenizer)
max_len = max(max_len, len(record.token_ids))
if self.in_tokens:
to_append = (len(batch_records) + 1) * max_len <= batch_size
else:
to_append = len(batch_records) < batch_size
if to_append:
batch_records.append(record)
else:
yield self._pad_batch_records(batch_records)
batch_records, max_len = [record], len(record.token_ids)
label_id = label_map[example.label] if batch_records:
yield self._pad_batch_records(batch_records)
feature = InputFeatures( # def get_num_examples(self, input_file):
input_ids=input_ids, # examples = self._read_tsv(input_file)
input_mask=input_mask, # return len(examples)
segment_ids=segment_ids,
label_id=label_id)
return feature
def get_num_examples(self, phase):
"""Get number of examples for train, dev or test."""
if phase not in ['train', 'val', 'dev', 'test']:
raise ValueError(
"Unknown phase, which should be in ['train', 'val'/'dev', 'test']."
)
return self.num_examples[phase]
def data_generator(self, batch_size, phase='train', shuffle=True):
def convert_examples_to_features(examples, label_list, max_seq_length, if phase == 'train':
tokenizer): examples = self.get_train_examples()
"""Convert a set of `InputExample`s to a list of `InputFeatures`.""" self.num_examples['train'] = len(examples)
elif phase == 'val' or phase == 'dev':
examples = self.get_dev_examples()
self.num_examples['dev'] = len(examples)
elif phase == 'test':
examples = self.get_test_examples()
self.num_examples['test'] = len(examples)
else:
raise ValueError(
"Unknown phase, which should be in ['train', 'dev', 'test'].")
features = [] def wrapper():
for (ex_index, example) in enumerate(examples): if shuffle:
if ex_index % 10000 == 0: np.random.shuffle(examples)
print("Writing example %d of %d" % (ex_index, len(examples)))
feature = convert_single_example(ex_index, example, label_list, for batch_data in self._prepare_batch_data(
max_seq_length, tokenizer) examples, batch_size, phase=phase):
yield [batch_data]
return wrapper
class ClassifyReader(BaseReader):
def _pad_batch_records(self, batch_records):
batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records]
batch_labels = [record.label_id for record in batch_records]
batch_labels = np.array(batch_labels).astype("int64").reshape([-1, 1])
# if batch_records[0].qid:
# batch_qids = [record.qid for record in batch_records]
# batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1])
# else:
# batch_qids = np.array([]).astype("int64").reshape([-1, 1])
# padding
padded_token_ids, input_mask = pad_batch_data(
batch_token_ids,
max_seq_len=self.max_seq_len,
pad_idx=self.pad_id,
return_input_mask=True)
padded_text_type_ids = pad_batch_data(
batch_text_type_ids,
max_seq_len=self.max_seq_len,
pad_idx=self.pad_id)
padded_position_ids = pad_batch_data(
batch_position_ids,
max_seq_len=self.max_seq_len,
pad_idx=self.pad_id)
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, batch_labels
]
return return_list
class SequenceLabelReader(BaseReader):
def _pad_batch_records(self, batch_records):
batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records]
batch_label_ids = [record.label_ids for record in batch_records]
# padding
padded_token_ids, input_mask, batch_seq_lens = pad_batch_data(
batch_token_ids,
pad_idx=self.pad_id,
max_seq_len=self.max_seq_len,
return_input_mask=True,
return_seq_lens=True)
padded_text_type_ids = pad_batch_data(
batch_text_type_ids,
max_seq_len=self.max_seq_len,
pad_idx=self.pad_id)
padded_position_ids = pad_batch_data(
batch_position_ids,
max_seq_len=self.max_seq_len,
pad_idx=self.pad_id)
padded_label_ids = pad_batch_data(
batch_label_ids,
max_seq_len=self.max_seq_len,
pad_idx=len(self.label_map) - 1)
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_label_ids, batch_seq_lens
]
return return_list
def _reseg_token_label(self, tokens, labels, tokenizer):
assert len(tokens) == len(labels)
ret_tokens = []
ret_labels = []
for token, label in zip(tokens, labels):
sub_token = tokenizer.tokenize(token)
if len(sub_token) == 0:
continue
ret_tokens.extend(sub_token)
ret_labels.append(label)
if len(sub_token) < 2:
continue
sub_label = label
if label.startswith("B-"):
sub_label = "I-" + label[2:]
ret_labels.extend([sub_label] * (len(sub_token) - 1))
assert len(ret_tokens) == len(ret_labels)
return ret_tokens, ret_labels
def _convert_example_to_record(self, example, max_seq_length, tokenizer):
tokens = tokenization.convert_to_unicode(example.text_a).split(u"")
labels = tokenization.convert_to_unicode(example.label).split(u"")
tokens, labels = self._reseg_token_label(tokens, labels, tokenizer)
if len(tokens) > max_seq_length - 2:
tokens = tokens[0:(max_seq_length - 2)]
labels = labels[0:(max_seq_length - 2)]
tokens = ["[CLS]"] + tokens + ["[SEP]"]
token_ids = tokenizer.convert_tokens_to_ids(tokens)
position_ids = list(range(len(token_ids)))
text_type_ids = [0] * len(token_ids)
no_entity_id = len(self.label_map) - 1
label_ids = [no_entity_id
] + [self.label_map[label]
for label in labels] + [no_entity_id]
Record = namedtuple(
'Record',
['token_ids', 'text_type_ids', 'position_ids', 'label_ids'])
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids,
label_ids=label_ids)
return record
class ExtractEmbeddingReader(BaseReader):
def _pad_batch_records(self, batch_records):
batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records]
# padding
padded_token_ids, input_mask, seq_lens = pad_batch_data(
batch_token_ids,
pad_idx=self.pad_id,
max_seq_len=self.max_seq_len,
return_input_mask=True,
return_seq_lens=True)
padded_text_type_ids = pad_batch_data(
batch_text_type_ids,
pad_idx=self.pad_id,
max_seq_len=self.max_seq_len)
padded_position_ids = pad_batch_data(
batch_position_ids,
pad_idx=self.pad_id,
max_seq_len=self.max_seq_len)
return_list = [
padded_token_ids, padded_text_type_ids, padded_position_ids,
input_mask, seq_lens
]
features.append(feature) return return_list
return features
if __name__ == '__main__': if __name__ == '__main__':
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import json
import numpy as np
from collections import namedtuple
from paddlehub.reader import tokenization
from .batching import pad_batch_data
class BaseReader(object):
def __init__(self,
dataset,
vocab_path,
label_map_config=None,
max_seq_len=512,
do_lower_case=True,
in_tokens=False,
random_seed=None):
self.max_seq_len = max_seq_len
self.tokenizer = tokenization.FullTokenizer(
vocab_file=vocab_path, do_lower_case=do_lower_case)
self.vocab = self.tokenizer.vocab
self.dataset = dataset
self.pad_id = self.vocab["[PAD]"]
self.cls_id = self.vocab["[CLS]"]
self.sep_id = self.vocab["[SEP]"]
self.in_tokens = in_tokens
np.random.seed(random_seed)
# generate label map
self.label_map = {}
for index, label in enumerate(self.dataset.get_labels()):
self.label_map[label] = index
print("Dataset label map = {}".format(self.label_map))
self.current_example = 0
self.current_epoch = 0
self.num_examples = 0
# if label_map_config:
# with open(label_map_config) as f:
# self.label_map = json.load(f)
# else:
# self.label_map = None
self.num_examples = {'train': -1, 'dev': -1, 'test': -1}
def get_train_examples(self):
"""Gets a collection of `InputExample`s for the train set."""
return self.dataset.get_train_examples()
def get_dev_examples(self):
"""Gets a collection of `InputExample`s for the dev set."""
return self.dataset.get_dev_examples()
def get_val_examples(self):
"""Gets a collection of `InputExample`s for the val set."""
return self.dataset.get_val_examples()
def get_test_examples(self):
"""Gets a collection of `InputExample`s for prediction."""
return self.dataset.get_test_examples()
def get_labels(self):
"""Gets the list of labels for this data set."""
return self.dataset.get_labels()
def get_train_progress(self):
"""Gets progress for training phase."""
return self.current_example, self.current_epoch
def _truncate_seq_pair(self, tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
def _convert_example_to_record(self, example, max_seq_length, tokenizer):
"""Converts a single `Example` into a single `Record`."""
text_a = tokenization.convert_to_unicode(example.text_a)
tokens_a = tokenizer.tokenize(text_a)
tokens_b = None
if example.text_b is not None:
#if "text_b" in example._fields:
text_b = tokenization.convert_to_unicode(example.text_b)
tokens_b = tokenizer.tokenize(text_b)
if tokens_b:
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
self._truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
else:
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[0:(max_seq_length - 2)]
# The convention in BERT/ERNIE is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
# (b) For single sequences:
# tokens: [CLS] the dog is hairy . [SEP]
# type_ids: 0 0 0 0 0 0 0
#
# Where "type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambiguously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens = []
text_type_ids = []
tokens.append("[CLS]")
text_type_ids.append(0)
for token in tokens_a:
tokens.append(token)
text_type_ids.append(0)
tokens.append("[SEP]")
text_type_ids.append(0)
if tokens_b:
for token in tokens_b:
tokens.append(token)
text_type_ids.append(1)
tokens.append("[SEP]")
text_type_ids.append(1)
token_ids = tokenizer.convert_tokens_to_ids(tokens)
position_ids = list(range(len(token_ids)))
if self.label_map:
label_id = self.label_map[example.label]
else:
label_id = example.label
# Record = namedtuple(
# 'Record',
# ['token_ids', 'text_type_ids', 'position_ids', 'label_id', 'qid'])
# qid = None
# if "qid" in example._fields:
# qid = example.qid
# record = Record(
# token_ids=token_ids,
# text_type_ids=text_type_ids,
# position_ids=position_ids,
# label_id=label_id,
# qid=qid)
Record = namedtuple(
'Record',
['token_ids', 'text_type_ids', 'position_ids', 'label_id'])
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids,
label_id=label_id)
return record
def _prepare_batch_data(self, examples, batch_size, phase=None):
"""generate batch records"""
batch_records, max_len = [], 0
for index, example in enumerate(examples):
if phase == "train":
self.current_example = index
record = self._convert_example_to_record(example, self.max_seq_len,
self.tokenizer)
max_len = max(max_len, len(record.token_ids))
if self.in_tokens:
to_append = (len(batch_records) + 1) * max_len <= batch_size
else:
to_append = len(batch_records) < batch_size
if to_append:
batch_records.append(record)
else:
yield self._pad_batch_records(batch_records)
batch_records, max_len = [record], len(record.token_ids)
if batch_records:
yield self._pad_batch_records(batch_records)
# def get_num_examples(self, input_file):
# examples = self._read_tsv(input_file)
# return len(examples)
def get_num_examples(self, phase):
"""Get number of examples for train, dev or test."""
if phase not in ['train', 'val', 'dev', 'test']:
raise ValueError(
"Unknown phase, which should be in ['train', 'val'/'dev', 'test']."
)
return self.num_examples[phase]
def data_generator(self, batch_size, phase='train', shuffle=True):
if phase == 'train':
examples = self.get_train_examples()
self.num_examples['train'] = len(examples)
elif phase == 'val' or phase == 'dev':
examples = self.get_dev_examples()
self.num_examples['dev'] = len(examples)
elif phase == 'test':
examples = self.get_test_examples()
self.num_examples['test'] = len(examples)
else:
raise ValueError(
"Unknown phase, which should be in ['train', 'dev', 'test'].")
def wrapper():
if shuffle:
np.random.shuffle(examples)
for batch_data in self._prepare_batch_data(
examples, batch_size, phase=phase):
yield [batch_data]
return wrapper
class ClassifyReader(BaseReader):
def _pad_batch_records(self, batch_records):
batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records]
batch_labels = [record.label_id for record in batch_records]
batch_labels = np.array(batch_labels).astype("int64").reshape([-1, 1])
# if batch_records[0].qid:
# batch_qids = [record.qid for record in batch_records]
# batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1])
# else:
# batch_qids = np.array([]).astype("int64").reshape([-1, 1])
# padding
padded_token_ids, input_mask = pad_batch_data(
batch_token_ids,
max_seq_len=self.max_seq_len,
pad_idx=self.pad_id,
return_input_mask=True)
padded_text_type_ids = pad_batch_data(
batch_text_type_ids,
max_seq_len=self.max_seq_len,
pad_idx=self.pad_id)
padded_position_ids = pad_batch_data(
batch_position_ids,
max_seq_len=self.max_seq_len,
pad_idx=self.pad_id)
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, batch_labels
]
return return_list
class SequenceLabelReader(BaseReader):
def _pad_batch_records(self, batch_records):
batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records]
batch_label_ids = [record.label_ids for record in batch_records]
# padding
padded_token_ids, input_mask, batch_seq_lens = pad_batch_data(
batch_token_ids,
pad_idx=self.pad_id,
max_seq_len=self.max_seq_len,
return_input_mask=True,
return_seq_lens=True)
padded_text_type_ids = pad_batch_data(
batch_text_type_ids,
max_seq_len=self.max_seq_len,
pad_idx=self.pad_id)
padded_position_ids = pad_batch_data(
batch_position_ids,
max_seq_len=self.max_seq_len,
pad_idx=self.pad_id)
padded_label_ids = pad_batch_data(
batch_label_ids,
max_seq_len=self.max_seq_len,
pad_idx=len(self.label_map) - 1)
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_label_ids, batch_seq_lens
]
return return_list
def _reseg_token_label(self, tokens, labels, tokenizer):
assert len(tokens) == len(labels)
ret_tokens = []
ret_labels = []
for token, label in zip(tokens, labels):
sub_token = tokenizer.tokenize(token)
if len(sub_token) == 0:
continue
ret_tokens.extend(sub_token)
ret_labels.append(label)
if len(sub_token) < 2:
continue
sub_label = label
if label.startswith("B-"):
sub_label = "I-" + label[2:]
ret_labels.extend([sub_label] * (len(sub_token) - 1))
assert len(ret_tokens) == len(ret_labels)
return ret_tokens, ret_labels
def _convert_example_to_record(self, example, max_seq_length, tokenizer):
tokens = tokenization.convert_to_unicode(example.text_a).split(u"")
labels = tokenization.convert_to_unicode(example.label).split(u"")
tokens, labels = self._reseg_token_label(tokens, labels, tokenizer)
if len(tokens) > max_seq_length - 2:
tokens = tokens[0:(max_seq_length - 2)]
labels = labels[0:(max_seq_length - 2)]
tokens = ["[CLS]"] + tokens + ["[SEP]"]
token_ids = tokenizer.convert_tokens_to_ids(tokens)
position_ids = list(range(len(token_ids)))
text_type_ids = [0] * len(token_ids)
no_entity_id = len(self.label_map) - 1
label_ids = [no_entity_id
] + [self.label_map[label]
for label in labels] + [no_entity_id]
Record = namedtuple(
'Record',
['token_ids', 'text_type_ids', 'position_ids', 'label_ids'])
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids,
label_ids=label_ids)
return record
class ExtractEmbeddingReader(BaseReader):
def _pad_batch_records(self, batch_records):
batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records]
# padding
padded_token_ids, input_mask, seq_lens = pad_batch_data(
batch_token_ids,
pad_idx=self.pad_id,
max_seq_len=self.max_seq_len,
return_input_mask=True,
return_seq_lens=True)
padded_text_type_ids = pad_batch_data(
batch_text_type_ids,
pad_idx=self.pad_id,
max_seq_len=self.max_seq_len)
padded_position_ids = pad_batch_data(
batch_position_ids,
pad_idx=self.pad_id,
max_seq_len=self.max_seq_len)
return_list = [
padded_token_ids, padded_text_type_ids, padded_position_ids,
input_mask, seq_lens
]
return return_list
if __name__ == '__main__':
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册