Commit 44aa7b51 authored by Zeyu Chen

remove useless file and organize bert reader

Parent 4b35d202
# PaddleHub
[![Build Status](https://travis-ci.org/PaddlePaddle/PaddleHub.svg?branch=master)](https://travis-ci.org/PaddlePaddle/PaddleHub)
[![Build Status](https://travis-ci.org/PaddlePaddle/PaddleHub.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/PaddleHub)
[![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
......@@ -26,56 +26,42 @@ import paddle
import paddle.fluid as fluid
import paddle_hub as hub
import reader.cls as reader
import reader.task_reader as task_reader
from utils.args import ArgumentGroup, print_arguments
from paddle_hub.finetune.config import FinetuneConfig
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 3, "Number of epoches for fine-tuning.")
train_g.add_arg("learning_rate", float, 5e-5, "Learning rate used to train with warmup.")
train_g.add_arg("hub_module_dir", str, None, "PaddleHub module directory")
train_g.add_arg("lr_scheduler", str, "linear_warmup_decay", "scheduler of learning rate.", choices=['linear_warmup_decay', 'noam_decay'])
train_g.add_arg("weight_decay", float, 0.01, "Weight decay rate for L2 regularizer.")
train_g.add_arg("warmup_proportion", float, 0.1,
"Proportion of training steps to perform linear learning rate warmup for.")
data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
data_g.add_arg("data_dir", str, None, "Path to training data.")
data_g.add_arg("checkpoint_dir", str, None, "Directory to model checkpoint")
data_g.add_arg("vocab_path", str, None, "Vocabulary path.")
data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest seqence.")
data_g.add_arg("batch_size", int, 32, "Total examples' number in batch for training. see also --in_tokens.")
parser.add_argument("--num_epoch", type=int, default=3, help="Number of epoches for fine-tuning.")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
parser.add_argument("--hub_module_dir", type=str, default=None, help="PaddleHub module directory")
parser.add_argument("--lr_scheduler", type=str, default="linear_warmup_decay",
help="scheduler of learning rate.", choices=['linear_warmup_decay', 'noam_decay'])
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.")
parser.add_argument("--data_dir", type=str, default=None, help="Path to training data.")
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest seqence.")
parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number in batch for training.")
args = parser.parse_args()
# yapf: enable.
if __name__ == '__main__':
print_arguments(args)
config = FinetuneConfig(
config = hub.FinetuneConfig(
log_interval=10,
eval_interval=100,
save_ckpt_interval=50,
use_cuda=True,
save_ckpt_interval=200,
checkpoint_dir=args.checkpoint_dir,
learning_rate=args.learning_rate,
num_epoch=args.epoch,
num_epoch=args.num_epoch,
batch_size=args.batch_size,
max_seq_len=args.max_seq_len,
weight_decay=args.weight_decay,
finetune_strategy="bert_finetune",
enable_memory_optim=True,
optimizer=None,
warmup_proportion=args.warmup_proportion)
finetune_strategy="bert_finetune")
# loading Paddlehub BERT
module = hub.Module(module_dir=args.hub_module_dir)
reader = reader.BERTClassifyReader(
data_dir=args.data_dir,
# Use BERTTokenizeReader to tokenize the dataset according to the model's
# vocabulary
reader = hub.reader.BERTTokenizeReader(
dataset=hub.dataset.ChnSentiCorp(),  # downloads the ChnSentiCorp dataset
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len)
......
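To make the new reader API concrete, here is a short sketch (not part of the original script) that pulls a single batch from the reader above to sanity-check the tokenized input before fine-tuning. It relies only on the `data_generator(batch_size, phase, shuffle)` method defined later in this commit.

```python
# Sanity-check sketch: draw one tokenized batch from the reader defined above.
# Assumes the `reader` and `args` objects from the example script.
train_batches = reader.data_generator(batch_size=args.batch_size, phase="train")

for batch in train_batches():
    # Each batch is the padded field list built by prepare_batch_data
    # (token ids, segment ids, position ids, labels, input mask, ...;
    # the exact order is defined in batching.prepare_batch_data).
    print("number of fields in one batch:", len(batch))
    break
```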
This diff is collapsed.
......@@ -8,12 +8,11 @@ HUB_MODULE_DIR="./hub_module/bert_chinese_L-12_H-768_A-12.hub_module"
CKPT_DIR="./ckpt"
#rm -rf $CKPT_DIR
python -u finetune_with_hub.py \
--batch_size 128 \
--batch_size 32 \
--hub_module_dir=$HUB_MODULE_DIR \
--data_dir ${DATA_PATH} \
--weight_decay 0.01 \
--checkpoint_dir $CKPT_DIR \
--warmup_proportion 0.0 \
--epoch 2 \
--max_seq_len 16 \
--num_epoch 3 \
--max_seq_len 128 \
--learning_rate 5e-5
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Arguments for configuration."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import argparse
def str2bool(v):
# argparse cannot parse strings such as "True"/"False" directly as Python
# booleans, so this helper is used as the argument type instead
return v.lower() in ("true", "t", "1")
class ArgumentGroup(object):
def __init__(self, parser, title, des):
self._group = parser.add_argument_group(title=title, description=des)
def add_arg(self, name, type, default, help, **kwargs):
type = str2bool if type == bool else type
self._group.add_argument(
"--" + name,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
def print_arguments(args):
print('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(vars(args))):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
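As a hedged illustration of how `ArgumentGroup`, `str2bool`, and `print_arguments` fit together, the sketch below registers a boolean flag in a script that imports this module; the `use_cuda` argument is hypothetical and serves only as an example.

```python
# Hypothetical usage sketch for the helpers above (not part of the repo).
import argparse

parser = argparse.ArgumentParser(__doc__)
run_g = ArgumentGroup(parser, "run", "running options.")
# A bool type is routed through str2bool, so "--use_cuda false" parses to False.
run_g.add_arg("use_cuda", bool, True, "Whether to use GPU for fine-tuning.")

args = parser.parse_args(["--use_cuda", "false"])
print_arguments(args)  # prints: use_cuda: False
```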
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
def cast_fp16_to_fp32(i, o, prog):
prog.global_block().append_op(
type="cast",
inputs={"X": i},
outputs={"Out": o},
attrs={
"in_dtype": fluid.core.VarDesc.VarType.FP16,
"out_dtype": fluid.core.VarDesc.VarType.FP32
})
def cast_fp32_to_fp16(i, o, prog):
prog.global_block().append_op(
type="cast",
inputs={"X": i},
outputs={"Out": o},
attrs={
"in_dtype": fluid.core.VarDesc.VarType.FP32,
"out_dtype": fluid.core.VarDesc.VarType.FP16
})
def copy_to_master_param(p, block):
v = block.vars.get(p.name, None)
if v is None:
raise ValueError("no param name %s found!" % p.name)
new_p = fluid.framework.Parameter(
block=block,
shape=v.shape,
dtype=fluid.core.VarDesc.VarType.FP32,
type=v.type,
lod_level=v.lod_level,
stop_gradient=p.stop_gradient,
trainable=p.trainable,
optimize_attr=p.optimize_attr,
regularizer=p.regularizer,
gradient_clip_attr=p.gradient_clip_attr,
error_clip=p.error_clip,
name=v.name + ".master")
return new_p
def create_master_params_grads(params_grads, main_prog, startup_prog,
loss_scaling):
master_params_grads = []
tmp_role = main_prog._current_role
OpRole = fluid.core.op_proto_and_checker_maker.OpRole
main_prog._current_role = OpRole.Backward
for p, g in params_grads:
# create master parameters
master_param = copy_to_master_param(p, main_prog.global_block())
startup_master_param = startup_prog.global_block()._clone_variable(
master_param)
startup_p = startup_prog.global_block().var(p.name)
cast_fp16_to_fp32(startup_p, startup_master_param, startup_prog)
# cast fp16 gradients to fp32 before apply gradients
if g.name.find("layer_norm") > -1:
if loss_scaling > 1:
scaled_g = g / float(loss_scaling)
else:
scaled_g = g
master_params_grads.append([p, scaled_g])
continue
master_grad = fluid.layers.cast(g, "float32")
if loss_scaling > 1:
master_grad = master_grad / float(loss_scaling)
master_params_grads.append([master_param, master_grad])
main_prog._current_role = tmp_role
return master_params_grads
def master_param_to_train_param(master_params_grads, params_grads, main_prog):
for idx, m_p_g in enumerate(master_params_grads):
train_p, _ = params_grads[idx]
if train_p.name.find("layer_norm") > -1:
continue
with main_prog._optimized_guard([m_p_g[0], m_p_g[1]]):
cast_fp32_to_fp16(m_p_g[0], train_p, main_prog)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import six
import ast
import copy
import numpy as np
import paddle.fluid as fluid
def cast_fp32_to_fp16(exe, main_program):
print("Cast parameters to float16 data format.")
for param in main_program.global_block().all_parameters():
if not param.name.endswith(".master"):
param_t = fluid.global_scope().find_var(param.name).get_tensor()
data = np.array(param_t)
if param.name.find("layer_norm") == -1:
param_t.set(np.float16(data).view(np.uint16), exe.place)
master_param_var = fluid.global_scope().find_var(param.name +
".master")
if master_param_var is not None:
master_param_var.get_tensor().set(data, exe.place)
def init_checkpoint(exe, init_checkpoint_path, main_program, use_fp16=False):
assert os.path.exists(
init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path
def existed_persitables(var):
if not fluid.io.is_persistable(var):
return False
return os.path.exists(os.path.join(init_checkpoint_path, var.name))
fluid.io.load_vars(
exe,
init_checkpoint_path,
main_program=main_program,
predicate=existed_persitables)
print("Load model from {}".format(init_checkpoint_path))
if use_fp16:
cast_fp32_to_fp16(exe, main_program)
def init_pretraining_params(exe,
pretraining_params_path,
main_program,
use_fp16=False):
assert os.path.exists(pretraining_params_path
), "[%s] cann't be found." % pretraining_params_path
def existed_params(var):
if not isinstance(var, fluid.framework.Parameter):
print("param {} not exsist!".format(var.name))
return False
return os.path.exists(os.path.join(pretraining_params_path, var.name))
fluid.io.load_vars(
exe,
pretraining_params_path,
main_program=main_program,
predicate=existed_params)
print(
"Load pretraining parameters from {}.".format(pretraining_params_path))
if use_fp16:
cast_fp32_to_fp16(exe, main_program)
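A minimal usage sketch for the loaders above; the parameter directory is a hypothetical path and is assumed to contain one file per persistable variable, as produced by `fluid.io.save_persistables`.

```python
# Hedged sketch: warm-start a program from pre-trained parameters.
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# "./pretrained_params" is a hypothetical directory of saved variables.
init_pretraining_params(
    exe,
    pretraining_params_path="./pretrained_params",
    main_program=fluid.default_main_program(),
    use_fp16=False)
```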
......@@ -14,6 +14,7 @@
from . import module
from . import common
from . import io
from . import dataset
from .common.dir import USER_HOME
from .common.dir import HUB_HOME
......@@ -34,3 +35,5 @@ from .finetune.network import append_mlp_classifier
from .finetune.finetune import finetune_and_eval
from .finetune.config import FinetuneConfig
from .finetune.task import Task
from .reader import BERTTokenizeReader
......@@ -14,6 +14,7 @@
import os
# TODO: rename dir.py; the current file name does not follow the naming convention
USER_HOME = os.path.expanduser('~')
HUB_HOME = os.path.join(USER_HOME, ".hub")
MODULE_HOME = os.path.join(HUB_HOME, "modules")
......
......@@ -88,7 +88,7 @@ class Downloader:
done = int(50 * dl / total_length)
if time.time() - starttime >= FLUSH_INTERVAL:
sys.stdout.write(
"\r%s : [%-50s]%.2f%%" %
"\r%s : [%-50s] %.2f%%" %
(save_name, '=' * done,
float(dl / total_length * 100)))
starttime = time.time()
......
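For reference, the corrected format string (note the added space before the percentage) renders as in the short illustration below; the file name and byte counts are made up.

```python
# Illustration of the progress-bar format string used above.
save_name, dl, total_length = "bert_module.tar.gz", 50.0, 100.0
done = int(50 * dl / total_length)
print("\r%s : [%-50s] %.2f%%" % (save_name, '=' * done,
                                 float(dl / total_length * 100)))
# -> bert_module.tar.gz : [=========================                         ] 50.00%
```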
......@@ -19,8 +19,7 @@ import logging
import math
class Logger:
class Logger(object):
PLACEHOLDER = '%'
NOLOG = "NOLOG"
......@@ -29,7 +28,7 @@ class Logger:
format='[%(asctime)-15s] [%(levelname)8s] - %(message)s')
if not name:
name = "paddle-hub"
name = "PaddleHub"
self.logger = logging.getLogger(name)
self.logLevel = "DEBUG"
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .dataset import InputExample, HubDataset
from .chnsenticorp import ChnSentiCorp
from .msra_ner import MSRA_NER
......@@ -12,31 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_hub.tools.downloader import default_downloader
from paddle_hub.dir import DATA_HOME
from paddle_hub.common.downloader import default_downloader
from paddle_hub.common.dir import DATA_HOME
import os
import csv
from paddle_hub.dataset import InputExample
from paddle_hub.dataset import HubDataset
from collections import namedtuple
DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/chnsenticorp_data.tar.gz"
class HubDataset(object):
def get_train_examples(self):
raise NotImplementedError()
def get_dev_examples(self):
raise NotImplementedError()
def get_test_examples(self):
raise NotImplementedError()
def get_val_examples(self):
return self.get_dev_examples()
class ChnSentiCorp(HubDataset):
"""
ChnSentiCorp: a Chinese sentiment analysis corpus for opinion mining,
built by Tan Songbo at ICT, Chinese Academy of Sciences.
"""
def __init__(self):
ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
url=DATA_URL, save_path=DATA_HOME, print_progress=True)
......@@ -66,15 +60,20 @@ class ChnSentiCorp(HubDataset):
def get_test_examples(self):
return self.test_examples
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _read_tsv(self, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
Example = namedtuple('Example', ["label", "text_a"])
examples = []
seq_id = 0
for line in reader:
example = Example(*line)
example = InputExample(
guid=seq_id, label=line[0], text_a=line[1])
seq_id += 1
examples.append(example)
return examples
......@@ -82,5 +81,5 @@ class ChnSentiCorp(HubDataset):
if __name__ == "__main__":
ds = ChnSentiCorp()
for e in ds.get_train_example():
for e in ds.get_train_examples():
print(e)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class InputExample(object):
"""
Input data structure for BERT/ERNIE. It covers single-sequence tasks such as
text classification and sequence labeling, as well as sequence-pair tasks
such as dialogue.
"""
def __init__(self, guid, text_a, text_b=None, label=None):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Must be specified only for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
class HubDataset(object):
def get_train_examples(self):
raise NotImplementedError()
def get_dev_examples(self):
raise NotImplementedError()
def get_test_examples(self):
raise NotImplementedError()
def get_val_examples(self):
return self.get_dev_examples()
def get_labels(self):
raise NotImplementedError()
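The interface above is easiest to see with a toy subclass. The sketch below is not part of the repository; the two in-memory examples are invented purely to show which methods a custom dataset has to implement.

```python
# Toy HubDataset implementation (illustrative only, not in the repo).
class ToySentimentDataset(HubDataset):
    def __init__(self):
        self._examples = [
            InputExample(guid=0, text_a="这个 酒店 很 不错", label="1"),
            InputExample(guid=1, text_a="质量 实在 太 差", label="0"),
        ]

    def get_train_examples(self):
        return self._examples

    def get_dev_examples(self):
        return self._examples

    def get_test_examples(self):
        return self._examples

    def get_labels(self):
        return ["0", "1"]
```

Note that `get_val_examples` is inherited from `HubDataset` and simply forwards to `get_dev_examples`.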
......@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_hub.tools.downloader import default_downloader
from paddle_hub.dir import DATA_HOME
from paddle_hub.common.downloader import default_downloader
from paddle_hub.common.dir import DATA_HOME
import os
import csv
......@@ -28,7 +28,6 @@ class MSRA_NER(object):
ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
url=DATA_URL, save_path=DATA_HOME, print_progress=True)
print(self.dataset_dir)
self._load_label_map()
self._load_train_examples()
......@@ -44,6 +43,10 @@ class MSRA_NER(object):
def get_train_examples(self):
return self.train_examples
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _read_tsv(self, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r") as f:
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import csv
import json
import numpy as np
from collections import namedtuple
import tokenization
from batching import pad_batch_data
class BaseReader(object):
def __init__(self,
vocab_path,
label_map_config=None,
max_seq_len=512,
do_lower_case=True,
in_tokens=False,
random_seed=None):
self.max_seq_len = max_seq_len
self.tokenizer = tokenization.FullTokenizer(
vocab_file=vocab_path, do_lower_case=do_lower_case)
self.vocab = self.tokenizer.vocab
self.pad_id = self.vocab["[PAD]"]
self.cls_id = self.vocab["[CLS]"]
self.sep_id = self.vocab["[SEP]"]
self.in_tokens = in_tokens
np.random.seed(random_seed)
self.current_example = 0
self.current_epoch = 0
self.num_examples = 0
if label_map_config:
with open(label_map_config) as f:
self.label_map = json.load(f)
else:
self.label_map = None
pass
def get_train_progress(self):
"""Gets progress for training phase."""
return self.current_example, self.current_epoch
def _read_tsv(self, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
headers = next(reader)
Example = namedtuple('Example', headers)
examples = []
for line in reader:
example = Example(*line)
examples.append(example)
return examples
def _truncate_seq_pair(self, tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
def _convert_example_to_record(self, example, max_seq_length, tokenizer):
"""Converts a single `Example` into a single `Record`."""
text_a = tokenization.convert_to_unicode(example.text_a)
tokens_a = tokenizer.tokenize(text_a)
tokens_b = None
if "text_b" in example._fields:
text_b = tokenization.convert_to_unicode(example.text_b)
tokens_b = tokenizer.tokenize(text_b)
if tokens_b:
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
self._truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
else:
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[0:(max_seq_length - 2)]
# The convention in BERT/ERNIE is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
# (b) For single sequences:
# tokens: [CLS] the dog is hairy . [SEP]
# type_ids: 0 0 0 0 0 0 0
#
# Where "type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambiguously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens = []
text_type_ids = []
tokens.append("[CLS]")
text_type_ids.append(0)
for token in tokens_a:
tokens.append(token)
text_type_ids.append(0)
tokens.append("[SEP]")
text_type_ids.append(0)
if tokens_b:
for token in tokens_b:
tokens.append(token)
text_type_ids.append(1)
tokens.append("[SEP]")
text_type_ids.append(1)
token_ids = tokenizer.convert_tokens_to_ids(tokens)
position_ids = list(range(len(token_ids)))
if self.label_map:
label_id = self.label_map[example.label]
else:
label_id = example.label
Record = namedtuple(
'Record',
['token_ids', 'text_type_ids', 'position_ids', 'label_id', 'qid'])
qid = None
if "qid" in example._fields:
qid = example.qid
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids,
label_id=label_id,
qid=qid)
return record
def _prepare_batch_data(self, examples, batch_size, phase=None):
"""generate batch records"""
batch_records, max_len = [], 0
for index, example in enumerate(examples):
if phase == "train":
self.current_example = index
record = self._convert_example_to_record(example, self.max_seq_len,
self.tokenizer)
max_len = max(max_len, len(record.token_ids))
if self.in_tokens:
to_append = (len(batch_records) + 1) * max_len <= batch_size
else:
to_append = len(batch_records) < batch_size
if to_append:
batch_records.append(record)
else:
yield self._pad_batch_records(batch_records)
batch_records, max_len = [record], len(record.token_ids)
if len(batch_records) > 0:
yield self._pad_batch_records(batch_records)
def get_num_examples(self, input_file):
examples = self._read_tsv(input_file)
return len(examples)
def data_generator(self,
input_file,
batch_size,
epoch,
shuffle=True,
phase=None):
examples = self._read_tsv(input_file)
def wrapper():
for epoch_index in range(epoch):
if phase == "train":
self.current_example = 0
self.current_epoch = epoch_index
if shuffle:
np.random.shuffle(examples)
for batch_data in self._prepare_batch_data(
examples, batch_size, phase=phase):
yield batch_data
return wrapper
class ClassifyReader(BaseReader):
def _read_tsv(self, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
headers = next(reader)
text_indices = [
index for index, h in enumerate(headers) if h != "label"
]
Example = namedtuple('Example', headers)
examples = []
for line in reader:
for index, text in enumerate(line):
if index in text_indices:
line[index] = text.replace(' ', '')
example = Example(*line)
examples.append(example)
return examples
def _pad_batch_records(self, batch_records):
batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records]
batch_labels = [record.label_id for record in batch_records]
batch_labels = np.array(batch_labels).astype("int64").reshape([-1, 1])
if batch_records[0].qid:
batch_qids = [record.qid for record in batch_records]
batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1])
else:
batch_qids = np.array([]).astype("int64").reshape([-1, 1])
# padding
padded_token_ids, next_sent_index, self_attn_bias = pad_batch_data(
batch_token_ids,
pad_idx=self.pad_id,
return_next_sent_pos=True,
return_attn_bias=True)
padded_text_type_ids = pad_batch_data(
batch_text_type_ids, pad_idx=self.pad_id)
padded_position_ids = pad_batch_data(
batch_position_ids, pad_idx=self.pad_id)
return_list = [
padded_token_ids, padded_text_type_ids, padded_position_ids,
self_attn_bias, batch_labels, next_sent_index, batch_qids
]
return return_list
class SequenceLabelReader(BaseReader):
def _pad_batch_records(self, batch_records):
batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records]
batch_label_ids = [record.label_ids for record in batch_records]
batch_seq_lens = [len(record.token_ids) for record in batch_records]
# padding
padded_token_ids, self_attn_bias = pad_batch_data(
batch_token_ids,
pad_idx=self.pad_id,
return_next_sent_pos=False,
return_attn_bias=True)
padded_text_type_ids = pad_batch_data(
batch_text_type_ids, pad_idx=self.pad_id)
padded_position_ids = pad_batch_data(
batch_position_ids, pad_idx=self.pad_id)
padded_label_ids = pad_batch_data(
batch_label_ids, pad_idx=len(self.label_map) - 1)
batch_seq_lens = np.array(batch_seq_lens).astype("int64").reshape(
[-1, 1])
return_list = [
padded_token_ids, padded_text_type_ids, padded_position_ids,
self_attn_bias, padded_label_ids, batch_seq_lens
]
return return_list
def _reseg_token_label(self, tokens, labels, tokenizer):
assert len(tokens) == len(labels)
ret_tokens = []
ret_labels = []
for token, label in zip(tokens, labels):
sub_token = tokenizer.tokenize(token)
if len(sub_token) == 0:
continue
ret_tokens.extend(sub_token)
ret_labels.append(label)
if len(sub_token) < 2:
continue
sub_label = label
if label.startswith("B-"):
sub_label = "I-" + label[2:]
ret_labels.extend([sub_label] * (len(sub_token) - 1))
assert len(ret_tokens) == len(ret_labels)
return ret_tokens, ret_labels
def _convert_example_to_record(self, example, max_seq_length, tokenizer):
tokens = tokenization.convert_to_unicode(example.text_a).split(u"\2")
labels = tokenization.convert_to_unicode(example.label).split(u"\2")
tokens, labels = self._reseg_token_label(tokens, labels, tokenizer)
if len(tokens) > max_seq_length - 2:
tokens = tokens[0:(max_seq_length - 2)]
labels = labels[0:(max_seq_length - 2)]
tokens = ["[CLS]"] + tokens + ["[SEP]"]
token_ids = tokenizer.convert_tokens_to_ids(tokens)
position_ids = list(range(len(token_ids)))
text_type_ids = [0] * len(token_ids)
no_entity_id = len(self.label_map) - 1
label_ids = [no_entity_id
] + [self.label_map[label]
for label in labels] + [no_entity_id]
Record = namedtuple(
'Record',
['token_ids', 'text_type_ids', 'position_ids', 'label_ids'])
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids,
label_ids=label_ids)
return record
if __name__ == '__main__':
pass
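To show how the readers above are driven, here is a hedged usage sketch for `ClassifyReader`; the vocabulary and TSV paths are hypothetical, and the TSV is assumed to have a header row containing a `label` column, as `_read_tsv` expects.

```python
# Hypothetical usage of ClassifyReader (paths are placeholders).
reader = ClassifyReader(
    vocab_path="./vocab.txt",      # BERT vocabulary containing [PAD]/[CLS]/[SEP]
    label_map_config=None,         # labels are already strings like "0"/"1"
    max_seq_len=128,
    do_lower_case=True)

train_gen = reader.data_generator(
    input_file="./data/train.tsv",
    batch_size=32,
    epoch=1,
    shuffle=True,
    phase="train")

for batch in train_gen():
    # batch is the padded return_list built by _pad_batch_records
    padded_token_ids = batch[0]
    break
```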
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import unicodedata
import six
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
fin = open(vocab_file)
for num, line in enumerate(fin):
items = convert_to_unicode(line.strip()).split("\t")
if len(items) > 2:
break
token = items[0]
index = items[1] if len(items) == 2 else num
token = token.strip()
vocab[token] = int(index)
return vocab
def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
output.append(vocab[item])
return output
def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens)
def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids)
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a peice of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class FullTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
class CharTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in text.lower().split(" "):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64)
or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
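A small usage sketch for `FullTokenizer`; `vocab.txt` stands in for the vocabulary file shipped with a BERT module (the path is an assumption).

```python
# Hedged example of end-to-end WordPiece tokenization with FullTokenizer.
tokenizer = FullTokenizer(vocab_file="vocab.txt", do_lower_case=True)

tokens = tokenizer.tokenize("PaddleHub makes BERT fine-tuning simple")
token_ids = tokenizer.convert_tokens_to_ids(tokens)

# Out-of-vocabulary words are split into "##" sub-tokens by the greedy
# longest-match-first WordPiece algorithm, e.g. "unaffable" ->
# ["un", "##aff", "##able"] when those pieces exist in the vocabulary.
```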
......@@ -36,14 +36,17 @@ def load_checkpoint(checkpoint_dir, exe):
fluid.io.load_persistables(exe, ckpt.latest_model_dir)
logger.info("Checkpoint loaded. current_epoch={},"
"global_step={}".format(current_epoch, global_step))
logger.info("PaddleHub model checkpoint loaded. current_epoch={}, "
"global_step={}".format(ckpt.current_epoch,
ckpt.global_step))
return ckpt.current_epoch, ckpt.global_step
else:
current_epoch = 1
global_step = 0
latest_model_dir = None
logger.info("Checkpoint not found, start training from scratch...")
logger.info(
"PaddleHub model checkpoint not found, start training from scratch..."
)
exe.run(fluid.default_startup_program())
return current_epoch, global_step
......
......@@ -40,7 +40,6 @@ def _get_running_device_info(config):
def _do_memory_optimization(task, config):
if config.enable_memory_optim:
logger.info("Memory optimization start...")
task_var_name = task.metric_variable_names()
......@@ -56,7 +55,7 @@ def _do_memory_optimization(task, config):
lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
program=fluid.default_main_program(), batch_size=config.batch_size)
logger.info("Theoretical memory usage in training: %.3f - %.3f %s" %
logger.info("Theoretical memory usage in training: %.2f - %.2f %s" %
(lower_mem, upper_mem, unit)),
......@@ -102,6 +101,7 @@ def _finetune_model(task, data_reader, feed_list, config=None, do_eval=False):
eval_loss_scalar = logw.scalar(tag="loss[evaluate]")
eval_acc_scalar = logw.scalar(tag="accuracy[evaluate]")
# Finetune loop
for epoch in range(current_epoch, num_epoch + 1):
train_reader = data_reader.data_generator(
batch_size=batch_size, phase='train')
......@@ -134,9 +134,6 @@ def _finetune_model(task, data_reader, feed_list, config=None, do_eval=False):
num_trained_examples = acc_sum = loss_sum = 0
if global_step % config.save_ckpt_interval == 0:
model_saved_dir = os.path.join(config.checkpoint_dir,
"step_%d" % global_step)
fluid.io.save_persistables(exe, dirname=model_saved_dir)
# NOTE: the current checkpoint-saving mechanism is not complete;
# it cannot restore the dataset training status
save_checkpoint(
......@@ -163,9 +160,6 @@ def _finetune_model(task, data_reader, feed_list, config=None, do_eval=False):
(model_saved_dir, best_eval_acc))
fluid.io.save_persistables(exe, dirname=model_saved_dir)
# update model and checkpoint
model_saved_dir = os.path.join(config.checkpoint_dir, "final_model")
fluid.io.save_persistables(exe, dirname=model_saved_dir)
# NOTE: the current checkpoint-saving mechanism is not complete; it cannot
# restore the dataset training status
save_checkpoint(
......@@ -188,6 +182,7 @@ def finetune(task, data_reader, feed_list, config=None):
def evaluate(task, data_reader, feed_list, phase="test", config=None):
logger.info("Evaluation on {} dataset start".format(phase))
inference_program = task.inference_program()
main_program = task.main_program()
loss = task.variable("loss")
......@@ -216,7 +211,8 @@ def evaluate(task, data_reader, feed_list, phase="test", config=None):
avg_loss = loss_sum / num_eval_examples
avg_acc = acc_sum / num_eval_examples
eval_speed = eval_step / eval_time_used
logger.info("[evaluation on %s set] loss=%.5f acc=%.5f [step/sec: %.2f]" %
(phase, avg_loss, avg_acc, eval_speed))
logger.info(
"[%s dataset evaluation result] loss=%.5f acc=%.5f [step/sec: %.2f]" %
(phase, avg_loss, avg_acc, eval_speed))
return avg_loss, avg_acc, eval_speed
......@@ -104,6 +104,7 @@ class Module(object):
self.module_info = None
self.processor = None
self.name = "temp"
# TODO(wuzewu): print more module loading info log
if url:
self._init_with_url(url=url)
elif module_dir:
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .nlp_reader import BERTTokenizeReader
......@@ -188,7 +188,3 @@ def pad_batch_data(insts,
return_list += [num_token]
return return_list if len(return_list) > 1 else return_list[0]
if __name__ == "__main__":
pass
......@@ -16,52 +16,50 @@ import os
import types
import csv
import numpy as np
import tokenization
#from paddle_hub import dataset
from paddle_hub.reader import tokenization
from .batching import prepare_batch_data
class DataProcessor(object):
class BERTTokenizeReader(object):
"""Base class for data converters for sequence classification data sets."""
def __init__(self,
data_dir,
dataset,
vocab_path,
max_seq_len,
do_lower_case=True,
in_tokens=False,
random_seed=None):
self.data_dir = data_dir
self.dataset = dataset
self.max_seq_len = max_seq_len
self.tokenizer = tokenization.FullTokenizer(
vocab_file=vocab_path, do_lower_case=do_lower_case)
self.vocab = self.tokenizer.vocab
self.in_tokens = in_tokens
np.random.seed(random_seed)
self.current_train_example = -1
self.num_examples = {'train': -1, 'dev': -1, 'test': -1}
self.current_train_epoch = -1
def get_train_examples(self, data_dir):
def get_train_examples(self):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
return self.dataset.get_train_examples()
def get_dev_examples(self, data_dir):
def get_dev_examples(self):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
return self.dataset.get_dev_examples()
def get_val_examples(self, data_dir):
def get_val_examples(self):
"""Gets a collection of `InputExample`s for the val set."""
raise NotImplementedError()
return self.dataset.get_val_examples()
def get_test_examples(self, data_dir):
def get_test_examples(self):
"""Gets a collection of `InputExample`s for prediction."""
raise NotImplementedError()
return self.dataset.get_test_examples()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
return self.dataset.get_labels()
def convert_example(self, index, example, labels, max_seq_len, tokenizer):
"""Converts a single `InputExample` into a single `InputFeatures`."""
......@@ -76,9 +74,10 @@ class DataProcessor(object):
Args:
feature: InputFeatures(object). A single set of features of data.
"""
input_pos = list(range(len(feature.input_ids)))
position_ids = list(range(len(feature.input_ids)))
return [
feature.input_ids, feature.segment_ids, input_pos, feature.label_id
feature.input_ids, feature.segment_ids, position_ids,
feature.label_id
]
def generate_batch_data(self,
......@@ -87,7 +86,7 @@ class DataProcessor(object):
voc_size=-1,
mask_id=-1,
return_input_mask=True,
return_max_len=True,
return_max_len=False,
return_num_token=False):
return prepare_batch_data(
batch_data,
......@@ -99,19 +98,9 @@ class DataProcessor(object):
sep_id=self.vocab["[SEP]"],
mask_id=-1,
return_input_mask=return_input_mask,
return_max_len=True,
return_max_len=return_max_len,
return_num_token=return_num_token)
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
lines.append(line)
return lines
def get_num_examples(self, phase):
"""Get number of examples for train, dev or test."""
if phase not in ['train', 'val', 'dev', 'test']:
......@@ -120,13 +109,9 @@ class DataProcessor(object):
)
return self.num_examples[phase]
def get_train_progress(self):
"""Gets progress for training phase."""
return self.current_train_example, self.current_train_epoch
def data_generator(self, batch_size, phase='train', epoch=1, shuffle=True):
def data_generator(self, batch_size, phase='train', shuffle=True):
"""
Generate data for train, dev or test.
Generate data for train, dev/val or test.
Args:
batch_size: int. The batch size of generated data.
......@@ -135,59 +120,49 @@ class DataProcessor(object):
shuffle: bool. Whether to shuffle examples.
"""
if phase == 'train':
examples = self.get_train_examples(self.data_dir)
examples = self.get_train_examples()
self.num_examples['train'] = len(examples)
elif phase == 'val' or phase == 'dev':
examples = self.get_dev_examples(self.data_dir)
examples = self.get_dev_examples()
self.num_examples['dev'] = len(examples)
elif phase == 'test':
examples = self.get_test_examples(self.data_dir)
examples = self.get_test_examples()
self.num_examples['test'] = len(examples)
else:
raise ValueError(
"Unknown phase, which should be in ['train', 'dev', 'test'].")
def instance_reader():
for epoch_index in range(epoch):
if shuffle:
np.random.shuffle(examples)
if phase == 'train':
self.current_train_epoch = epoch_index
for (index, example) in enumerate(examples):
if phase == 'train':
self.current_train_example = index + 1
feature = self.convert_example(index, example,
self.get_labels(),
self.max_seq_len,
self.tokenizer)
instance = self.generate_instance(feature)
yield instance
def batch_reader(reader, batch_size, in_tokens):
"""
convert a single instance to BERT input feature
"""
if shuffle:
np.random.shuffle(examples)
for (index, example) in enumerate(examples):
feature = self.convert_example(index, example,
self.get_labels(),
self.max_seq_len, self.tokenizer)
instance = self.generate_instance(feature)
yield instance
def batch_reader(reader, batch_size):
batch, total_token_num, max_len = [], 0, 0
for instance in reader():
token_ids, sent_ids, pos_ids, label = instance[:4]
max_len = max(max_len, len(token_ids))
if in_tokens:
to_append = (len(batch) + 1) * max_len <= batch_size
else:
to_append = len(batch) < batch_size
if to_append:
batch.append(instance)
total_token_num += len(token_ids)
else:
batch.append(instance)
total_token_num += len(token_ids)
if len(batch) == batch_size:
yield batch, total_token_num
batch, total_token_num, max_len = [
instance
], len(token_ids), len(token_ids)
batch, total_token_num, max_len = [], 0, 0
if len(batch) > 0:
yield batch, total_token_num
def wrapper():
for batch_data, total_token_num in batch_reader(
instance_reader, batch_size, self.in_tokens):
instance_reader, batch_size):
batch_data = self.generate_batch_data(
batch_data,
total_token_num,
......@@ -201,27 +176,6 @@ class DataProcessor(object):
return wrapper
class InputExample(object):
"""A single training/test example for simple sequence classification."""
def __init__(self, guid, text_a, text_b=None, label=None):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
......@@ -249,271 +203,6 @@ class InputFeatures(object):
self.label_id = label_id
class XnliProcessor(DataProcessor):
"""Processor for the XNLI data set."""
def get_train_examples(self, data_dir):
"""See base class."""
self.language = "zh"
lines = self._read_tsv(
os.path.join(data_dir, "multinli",
"multinli.train.%s.tsv" % self.language))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "train-%d" % (i)
text_a = tokenization.convert_to_unicode(line[0])
text_b = tokenization.convert_to_unicode(line[1])
label = tokenization.convert_to_unicode(line[2])
if label == tokenization.convert_to_unicode("contradictory"):
label = tokenization.convert_to_unicode("contradiction")
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
self.language = "zh"
lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "dev-%d" % (i)
language = tokenization.convert_to_unicode(line[0])
if language != tokenization.convert_to_unicode(self.language):
continue
text_a = tokenization.convert_to_unicode(line[6])
text_b = tokenization.convert_to_unicode(line[7])
label = tokenization.convert_to_unicode(line[1])
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir):
"""See base class."""
self.language = "zh"
lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "test-%d" % (i)
language = tokenization.convert_to_unicode(line[0])
if language != tokenization.convert_to_unicode(self.language):
continue
text_a = tokenization.convert_to_unicode(line[6])
text_b = tokenization.convert_to_unicode(line[7])
label = tokenization.convert_to_unicode(line[1])
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(
line[0]))
text_a = tokenization.convert_to_unicode(line[8])
text_b = tokenization.convert_to_unicode(line[9])
if set_type == "test":
label = "contradiction"
else:
label = tokenization.convert_to_unicode(line[-1])
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[3])
text_b = tokenization.convert_to_unicode(line[4])
if set_type == "test":
label = "0"
else:
label = tokenization.convert_to_unicode(line[0])
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
# Only the test set has a header
if set_type == "test" and i == 0:
continue
guid = "%s-%s" % (set_type, i)
if set_type == "test":
text_a = tokenization.convert_to_unicode(line[1])
label = "0"
else:
text_a = tokenization.convert_to_unicode(line[3])
label = tokenization.convert_to_unicode(line[1])
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=None, label=label))
return examples
class ChnsenticorpProcessor(DataProcessor):
"""Processor for the Chnsenticorp data set."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[1])
label = tokenization.convert_to_unicode(line[0])
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=None, label=label))
return examples
class BERTClassifyReader(DataProcessor):
"""Processor for the Chnsenticorp data set."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[1])
label = tokenization.convert_to_unicode(line[0])
examples.append(
InputExample(
guid=guid, text_a=text_a, text_b=None, label=label))
return examples
def convert_single_example_to_unicode(guid, single_example):
text_a = tokenization.convert_to_unicode(single_example[0])
text_b = tokenization.convert_to_unicode(single_example[1])
......