Unverified commit 300be16c authored by pkpk, committed by GitHub

test=develop (#4175)

Parent 071dc299
@@ -23,19 +23,21 @@ import numpy as np
from dgu import tokenization
from dgu.batching import prepare_batch_data

if sys.version[0] == '2':
    reload(sys)
    sys.setdefaultencoding('utf-8')
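
# Note: the guard above is Python 2 only; reload(sys) followed by
# sys.setdefaultencoding('utf-8') keeps the byte/unicode mixing in the file
# readers below from raising UnicodeDecodeError. Python 3 strings are already
# unicode, so no shim is needed there.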

class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def __init__(self,
                 data_dir,
                 vocab_path,
                 max_seq_len,
                 do_lower_case,
                 in_tokens,
                 task_name,
                 random_seed=None):
        self.data_dir = data_dir
        self.max_seq_len = max_seq_len
@@ -92,7 +94,7 @@ class DataProcessor(object):
                            mask_id=-1,
                            return_input_mask=True,
                            return_max_len=False,
                            return_num_token=False):
        """generate batch data"""
        return prepare_batch_data(
            self.task_name,
@@ -114,7 +116,7 @@ class DataProcessor(object):
        f = io.open(input_file, "r", encoding="utf8")
        reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
        lines = []
        for line in reader:
            lines.append(line)
        return lines
@@ -147,21 +149,21 @@ class DataProcessor(object):
            raise ValueError(
                "Unknown phase, which should be in ['train', 'dev', 'test'].")

        def instance_reader():
            """generate instance data"""
            if shuffle:
                np.random.shuffle(examples)
            for (index, example) in enumerate(examples):
                feature = self.convert_example(index, example,
                                               self.get_labels(),
                                               self.max_seq_len, self.tokenizer)
                instance = self.generate_instance(feature)
                yield instance

        def batch_reader(reader, batch_size, in_tokens):
            """read batch data"""
            batch, total_token_num, max_len = [], 0, 0
            for instance in reader():
                token_ids, sent_ids, pos_ids, label = instance[:4]
                max_len = max(max_len, len(token_ids))
                if in_tokens:
@@ -179,13 +181,13 @@ class DataProcessor(object):
            if len(batch) > 0:
                yield batch, total_token_num

        def wrapper():
            """yield batch data to network"""
            for batch_data, total_token_num in batch_reader(
                    instance_reader, batch_size, self.in_tokens):
                if self.in_tokens:
                    max_seq = -1
                else:
                    max_seq = self.max_seq_len
                batch_data = self.generate_batch_data(
                    batch_data,
@@ -199,7 +201,7 @@ class DataProcessor(object):
                yield batch_data

        return wrapper
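
    # A minimal usage sketch (hypothetical paths; the name and signature of
    # the enclosing generator method, data_generator(batch_size, phase,
    # epoch, shuffle), are assumptions inferred from the phase check and the
    # readers above, not confirmed by this diff):
    #
    #   processor = UDCProcessor(data_dir='./data/udc',
    #                            vocab_path='./data/vocab.txt',
    #                            max_seq_len=128, do_lower_case=True,
    #                            in_tokens=False, task_name='udc')
    #   reader = processor.data_generator(batch_size=32, phase='train',
    #                                     epoch=1, shuffle=True)
    #   for batch in reader():  # wrapper() yields one padded batch per step
    #       pass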

class InputExample(object):
    """A single training/test example for simple sequence classification."""
@@ -250,19 +252,24 @@ class InputFeatures(object):
        self.label_id = label_id

class UDCProcessor(DataProcessor):
    """Processor for the UDC data set."""

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        print(
            "The UDC dataset is large, so loading the data may take a long time; please wait..."
        )
        for (i, line) in enumerate(lines):
            if len(line) < 3:
                print("data format error: %s" % "\t".join(line))
                print(
                    "a data row must contain at least three parts: label\tconv1\t.....\tresponse"
                )
                continue
            guid = "%s-%d" % (set_type, i)
            text_a = "\t".join(line[1:-1])
            text_a = tokenization.convert_to_unicode(text_a)
            text_a = text_a.split('\t')
            text_b = line[-1]
@@ -273,21 +280,21 @@ class UDCProcessor(DataProcessor):
                    guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

    def get_train_examples(self, data_dir):
        """See base class."""
        examples = []
        lines = self._read_tsv(os.path.join(data_dir, "train.txt"))
        examples = self._create_examples(lines, "train")
        return examples

    def get_dev_examples(self, data_dir):
        """See base class."""
        examples = []
        lines = self._read_tsv(os.path.join(data_dir, "dev.txt"))
        examples = self._create_examples(lines, "dev")
        return examples

    def get_test_examples(self, data_dir):
        """See base class."""
        examples = []
        lines = self._read_tsv(os.path.join(data_dir, "test.txt"))
@@ -295,19 +302,20 @@ class UDCProcessor(DataProcessor):
        return examples

    @staticmethod
    def get_labels():
        """See base class."""
        return ["0", "1"]

class SWDAProcessor(DataProcessor):
    """Processor for the SWDA data set."""

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = create_multi_turn_examples(lines, set_type)
        return examples

    def get_train_examples(self, data_dir):
        """See base class."""
        examples = []
        lines = self._read_tsv(os.path.join(data_dir, "train.txt"))
@@ -329,21 +337,22 @@ class SWDAProcessor(DataProcessor):
        return examples

    @staticmethod
    def get_labels():
        """See base class."""
        labels = range(42)
        labels = [str(label) for label in labels]
        return labels

class MRDAProcessor(DataProcessor):
    """Processor for the MRDA data set."""

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = create_multi_turn_examples(lines, set_type)
        return examples

    def get_train_examples(self, data_dir):
        """See base class."""
        examples = []
        lines = self._read_tsv(os.path.join(data_dir, "train.txt"))
@@ -365,22 +374,25 @@ class MRDAProcessor(DataProcessor):
        return examples

    @staticmethod
    def get_labels():
        """See base class."""
        labels = range(42)
        labels = [str(label) for label in labels]
        return labels

class ATISSlotProcessor(DataProcessor):
    """Processor for the ATIS Slot data set."""

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if len(line) != 2:
                print("data format error: %s" % "\t".join(line))
                print(
                    "a data row must contain two parts: conversation_content \t label1 label2 label3"
                )
                continue
            guid = "%s-%d" % (set_type, i)
            text_a = line[0]
@@ -392,7 +404,7 @@ class ATISSlotProcessor(DataProcessor):
                    guid=guid, text_a=text_a, label=label_list))
        return examples

    def get_train_examples(self, data_dir):
        """See base class."""
        examples = []
        lines = self._read_tsv(os.path.join(data_dir, "train.txt"))
@@ -414,30 +426,30 @@ class ATISSlotProcessor(DataProcessor):
        return examples

    @staticmethod
    def get_labels():
        """See base class."""
        labels = range(130)
        labels = [str(label) for label in labels]
        return labels

class ATISIntentProcessor(DataProcessor):
    """Processor for the ATIS intent data set."""

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if len(line) != 2:
                print("data format error: %s" % "\t".join(line))
                print(
                    "a data row must contain two parts: label \t conversation_content")
                continue
            guid = "%s-%d" % (set_type, i)
            text_a = line[1]
            text_a = tokenization.convert_to_unicode(text_a)
            label = tokenization.convert_to_unicode(line[0])
            examples.append(InputExample(guid=guid, text_a=text_a, label=label))
        return examples

    def get_train_examples(self, data_dir):
@@ -469,53 +481,60 @@ class ATISIntentProcessor(DataProcessor):
        return labels

class DSTC2Processor(DataProcessor):
    """Processor for the DSTC2 data set."""

    def _create_turns(self, conv_example):
        """create multi turn dataset"""
        samples = []
        max_turns = 20
        for i in range(len(conv_example)):
            conv_turns = conv_example[max(i - max_turns, 0):i + 1]
            conv_info = "\1".join([sample[0] for sample in conv_turns])
            samples.append((conv_info.split('\1'), conv_example[i][1]))
        return samples
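
    # Worked example for _create_turns: with max_turns = 20, turn i keeps at
    # most the 20 preceding turns plus itself, so a three-turn conversation
    # [(u0, s0), (u1, s1), (u2, s2)] yields ([u0], s0), ([u0, u1], s1) and
    # ([u0, u1, u2], s2): the accumulated utterance history paired with the
    # current turn's state labels.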

    def _create_examples(self, lines, set_type):
        """Creates examples for multi-turn dialogue sets."""
        examples = []
        conv_id = -1
        index = 0
        conv_example = []
        for (i, line) in enumerate(lines):
            if len(line) != 3:
                print("data format error: %s" % "\t".join(line))
                print(
                    "a data row must contain three parts: conversation_content \t question \1 answer \t state1 state2 state3......"
                )
                continue
            conv_no = line[0]
            text_a = line[1]
            label_list = line[2].split()
            if conv_no != conv_id and i != 0:
                samples = self._create_turns(conv_example)
                for sample in samples:
                    guid = "%s-%s" % (set_type, index)
                    index += 1
                    history = sample[0]
                    dst_label = sample[1]
                    examples.append(
                        InputExample(
                            guid=guid, text_a=history, label=dst_label))
                conv_example = []
                conv_id = conv_no
            if i == 0:
                conv_id = conv_no
            conv_example.append((text_a, label_list))
        if conv_example:
            samples = self._create_turns(conv_example)
            for sample in samples:
                guid = "%s-%s" % (set_type, index)
                index += 1
                history = sample[0]
                dst_label = sample[1]
                examples.append(
                    InputExample(
                        guid=guid, text_a=history, label=dst_label))
        return examples

    def get_train_examples(self, data_dir):
@@ -547,20 +566,22 @@ class DSTC2Processor(DataProcessor):
        return labels

class MULTIWOZProcessor(DataProcessor):
    """Processor for the MULTIWOZ data set."""

    def _create_turns(self, conv_example):
        """create multi turn dataset"""
        samples = []
        max_turns = 2
        for i in range(len(conv_example)):
            prefix_turns = conv_example[max(i - max_turns, 0):i]
            conv_info = "\1".join([turn[0] for turn in prefix_turns])
            current_turns = conv_example[i][0]
            samples.append((conv_info.split('\1'), current_turns.split('\1'),
                            conv_example[i][1]))
        return samples

    def _create_examples(self, lines, set_type):
        """Creates examples for multi-turn dialogue sets."""
        examples = []
        conv_id = -1
@@ -570,7 +591,7 @@ class MULTIWOZProcessor(DataProcessor):
            conv_no = line[0]
            text_a = line[2]
            label_list = line[1].split()
            if conv_no != conv_id and i != 0:
                samples = self._create_turns(conv_example)
                for sample in samples:
                    guid = "%s-%s" % (set_type, index)
@@ -578,13 +599,18 @@ class MULTIWOZProcessor(DataProcessor):
                    history = sample[0]
                    current = sample[1]
                    dst_label = sample[2]
                    examples.append(
                        InputExample(
                            guid=guid,
                            text_a=history,
                            text_b=current,
                            label=dst_label))
                conv_example = []
                conv_id = conv_no
            if i == 0:
                conv_id = conv_no
            conv_example.append((text_a, label_list))
        if conv_example:
            samples = self._create_turns(conv_example)
            for sample in samples:
                guid = "%s-%s" % (set_type, index)
@@ -592,10 +618,15 @@ class MULTIWOZProcessor(DataProcessor):
                history = sample[0]
                current = sample[1]
                dst_label = sample[2]
                examples.append(
                    InputExample(
                        guid=guid,
                        text_a=history,
                        text_b=current,
                        label=dst_label))
        return examples

    def get_train_examples(self, data_dir):
        """See base class."""
        examples = []
        lines = self._read_tsv(os.path.join(data_dir, "train.txt"))
@@ -624,34 +655,38 @@ class MULTIWOZProcessor(DataProcessor):
        return labels

def create_dialogue_examples(conv):
    """Creates dialogue sample"""
    samples = []
    for i in range(len(conv)):
        cur_txt = "%s : %s" % (conv[i][2], conv[i][3])
        pre_txt = ["%s : %s" % (c[2], c[3]) for c in conv[max(0, i - 5):i]]
        suf_txt = [
            "%s : %s" % (c[2], c[3]) for c in conv[i + 1:min(len(conv), i + 3)]
        ]
        sample = [conv[i][1], pre_txt, cur_txt, suf_txt]
        samples.append(sample)
    return samples
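
# Worked example for create_dialogue_examples: each turn i becomes
# [label_i, pre_txt, cur_txt, suf_txt], where pre_txt holds up to the 5
# preceding "caller : utterance" strings and suf_txt up to the 2 following
# ones, giving every utterance a small window of surrounding context.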

def create_multi_turn_examples(lines, set_type):
    """Creates examples for multi-turn dialogue sets."""
    conv_id = -1
    examples = []
    conv_example = []
    index = 0
    for (i, line) in enumerate(lines):
        if len(line) != 4:
            print("data format error: %s" % "\t".join(line))
            print(
                "a data row must contain four parts: conversation_id \t label \t caller \t conversation_content"
            )
            continue
        tokens = line
        conv_no = tokens[0]
        if conv_no != conv_id and i != 0:
            samples = create_dialogue_examples(conv_example)
            for sample in samples:
                guid = "%s-%s" % (set_type, index)
                index += 1
                label = sample[0]
@@ -659,15 +694,20 @@ def create_multi_turn_examples(lines, set_type):
                text_b = sample[2]
                text_c = sample[3]
                examples.append(
                    InputExample(
                        guid=guid,
                        text_a=text_a,
                        text_b=text_b,
                        text_c=text_c,
                        label=label))
            conv_example = []
            conv_id = conv_no
        if i == 0:
            conv_id = conv_no
        conv_example.append(tokens)
    if conv_example:
        samples = create_dialogue_examples(conv_example)
        for sample in samples:
            guid = "%s-%s" % (set_type, index)
            index += 1
            label = sample[0]
@@ -675,62 +715,67 @@ def create_multi_turn_examples(lines, set_type):
            text_b = sample[2]
            text_c = sample[3]
            examples.append(
                InputExample(
                    guid=guid,
                    text_a=text_a,
                    text_b=text_b,
                    text_c=text_c,
                    label=label))
    return examples

def convert_tokens(tokens, sep_id, tokenizer):
    """Converts tokens to ids"""
    tokens_ids = []
    if not tokens:
        return tokens_ids
    if isinstance(tokens, list):
        for text in tokens:
            tok_text = tokenizer.tokenize(text)
            ids = tokenizer.convert_tokens_to_ids(tok_text)
            tokens_ids.extend(ids)
            tokens_ids.append(sep_id)
        tokens_ids = tokens_ids[:-1]
    else:
        tok_text = tokenizer.tokenize(tokens)
        tokens_ids = tokenizer.convert_tokens_to_ids(tok_text)
    return tokens_ids
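
# Behaviour sketch for convert_tokens (illustrative input; `tok` stands for
# any tokenizer exposing tokenize/convert_tokens_to_ids):
#
#   convert_tokens(["hi there", "hello"], sep_id=1, tokenizer=tok)
#   # -> ids("hi there") + [1] + ids("hello")   (no trailing sep_id)
#   convert_tokens("hi there", sep_id=1, tokenizer=tok)
#   # -> ids("hi there")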

def convert_single_example(ex_index, example, label_list, max_seq_length,
                           tokenizer, task_name):
    """Converts a single DA `InputExample` into a single `InputFeatures`."""
    label_map = {}
    SEP = 102  # id of the [SEP] token in the BERT vocabulary
    CLS = 101  # id of the [CLS] token in the BERT vocabulary
    if task_name == 'udc':
        INNER_SEP = 1
        limit_length = 60
    elif task_name == 'swda':
        INNER_SEP = 1
        limit_length = 50
    elif task_name == 'mrda':
        INNER_SEP = 1
        limit_length = 50
    elif task_name == 'atis_intent':
        INNER_SEP = -1
        limit_length = -1
    elif task_name == 'atis_slot':
        INNER_SEP = -1
        limit_length = -1
    elif task_name == 'dstc2':
        INNER_SEP = 1
        limit_length = -1
    elif task_name == 'dstc2_asr':
        INNER_SEP = 1
        limit_length = -1
    elif task_name == 'multi-woz':
        INNER_SEP = 1
        limit_length = 200

    for (i, label) in enumerate(label_list):
        label_map[label] = i

    tokens_a = example.text_a
    tokens_b = example.text_b
    tokens_c = example.text_c
@@ -739,30 +784,36 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
    tokens_b_ids = convert_tokens(tokens_b, INNER_SEP, tokenizer)
    tokens_c_ids = convert_tokens(tokens_c, INNER_SEP, tokenizer)

    if tokens_b_ids:
        tokens_b_ids = tokens_b_ids[:min(limit_length, len(tokens_b_ids))]
    else:
        if len(tokens_a_ids) > max_seq_length - 2:
            tokens_a_ids = tokens_a_ids[len(tokens_a_ids) - max_seq_length + 2:]
    if not tokens_c_ids:
        if len(tokens_a_ids) > max_seq_length - len(tokens_b_ids) - 3:
            tokens_a_ids = tokens_a_ids[len(tokens_a_ids) - max_seq_length +
                                        len(tokens_b_ids) + 3:]
    else:
        if len(tokens_a_ids) + len(tokens_b_ids) + len(
                tokens_c_ids) > max_seq_length - 4:
            left_num = max_seq_length - len(tokens_b_ids) - 4
            if len(tokens_a_ids) > len(tokens_c_ids):
                suffix_num = int(left_num / 2)
                tokens_c_ids = tokens_c_ids[:min(len(tokens_c_ids), suffix_num)]
                prefix_num = left_num - len(tokens_c_ids)
                tokens_a_ids = tokens_a_ids[max(
                    0, len(tokens_a_ids) - prefix_num):]
            else:
                if not tokens_a_ids:
                    tokens_c_ids = tokens_c_ids[max(
                        0, len(tokens_c_ids) - left_num):]
                else:
                    prefix_num = int(left_num / 2)
                    tokens_a_ids = tokens_a_ids[max(
                        0, len(tokens_a_ids) - prefix_num):]
                    suffix_num = left_num - len(tokens_a_ids)
                    tokens_c_ids = tokens_c_ids[:min(
                        len(tokens_c_ids), suffix_num)]
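
    # Truncation budget above: max_seq_length must also hold the special
    # tokens, i.e. 2 ([CLS], [SEP]) with text_a alone, 3 with text_a and
    # text_b, and 4 when text_a, text_b and text_c are all present. text_a
    # (the dialogue history) is always trimmed from the front so the most
    # recent turns survive; text_c is trimmed from the back, except when
    # text_a is empty, in which case text_c is trimmed from the front instead.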

    input_ids = []
    segment_ids = []
@@ -772,31 +823,31 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
    segment_ids.extend([0] * len(tokens_a_ids))
    input_ids.append(SEP)
    segment_ids.append(0)
    if tokens_b_ids:
        input_ids.extend(tokens_b_ids)
        segment_ids.extend([1] * len(tokens_b_ids))
        input_ids.append(SEP)
        segment_ids.append(1)
    if tokens_c_ids:
        input_ids.extend(tokens_c_ids)
        segment_ids.extend([0] * len(tokens_c_ids))
        input_ids.append(SEP)
        segment_ids.append(0)
    input_mask = [1] * len(input_ids)
    if task_name == 'atis_slot':
        label_id = [0] + [label_map[l] for l in example.label] + [0]
    elif task_name in ['dstc2', 'dstc2_asr', 'multi-woz']:
        label_id_enty = [label_map[l] for l in example.label]
        label_id = []
        for i in range(len(label_map)):
            if i in label_id_enty:
                label_id.append(1)
            else:
                label_id.append(0)
    else:
        label_id = label_map[example.label]
    if ex_index < 5:
        print("*** Example ***")
        print("guid: %s" % (example.guid))
@@ -809,7 +860,5 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_id=label_id)
    return feature