未验证 提交 300be16c 编写于 作者: P pkpk 提交者: GitHub

test=develop (#4175)

上级 071dc299
......@@ -23,19 +23,21 @@ import numpy as np
from dgu import tokenization
from dgu.batching import prepare_batch_data
# Python 2 compatibility: force UTF-8 as the default string encoding.
# `reload(sys)` is required because CPython 2 deletes `setdefaultencoding`
# from the `sys` module after start-up. On Python 3 the guard is false and
# this is a no-op (UTF-8 is already the default text encoding).
if sys.version[0] == '2':
    reload(sys)
    sys.setdefaultencoding('utf-8')
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def __init__(self,
data_dir,
vocab_path,
max_seq_len,
do_lower_case,
def __init__(self,
data_dir,
vocab_path,
max_seq_len,
do_lower_case,
in_tokens,
task_name,
task_name,
random_seed=None):
self.data_dir = data_dir
self.max_seq_len = max_seq_len
......@@ -92,7 +94,7 @@ class DataProcessor(object):
mask_id=-1,
return_input_mask=True,
return_max_len=False,
return_num_token=False):
return_num_token=False):
"""generate batch data"""
return prepare_batch_data(
self.task_name,
......@@ -114,7 +116,7 @@ class DataProcessor(object):
f = io.open(input_file, "r", encoding="utf8")
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
for line in reader:
lines.append(line)
return lines
......@@ -147,21 +149,21 @@ class DataProcessor(object):
raise ValueError(
"Unknown phase, which should be in ['train', 'dev', 'test'].")
def instance_reader():
def instance_reader():
"""generate instance data"""
if shuffle:
np.random.shuffle(examples)
for (index, example) in enumerate(examples):
feature = self.convert_example(
index, example,
self.get_labels(), self.max_seq_len, self.tokenizer)
for (index, example) in enumerate(examples):
feature = self.convert_example(index, example,
self.get_labels(),
self.max_seq_len, self.tokenizer)
instance = self.generate_instance(feature)
yield instance
def batch_reader(reader, batch_size, in_tokens):
def batch_reader(reader, batch_size, in_tokens):
"""read batch data"""
batch, total_token_num, max_len = [], 0, 0
for instance in reader():
for instance in reader():
token_ids, sent_ids, pos_ids, label = instance[:4]
max_len = max(max_len, len(token_ids))
if in_tokens:
......@@ -179,13 +181,13 @@ class DataProcessor(object):
if len(batch) > 0:
yield batch, total_token_num
def wrapper():
def wrapper():
"""yield batch data to network"""
for batch_data, total_token_num in batch_reader(
instance_reader, batch_size, self.in_tokens):
if self.in_tokens:
instance_reader, batch_size, self.in_tokens):
if self.in_tokens:
max_seq = -1
else:
else:
max_seq = self.max_seq_len
batch_data = self.generate_batch_data(
batch_data,
......@@ -199,7 +201,7 @@ class DataProcessor(object):
yield batch_data
return wrapper
class InputExample(object):
"""A single training/test example for simple sequence classification."""
......@@ -250,19 +252,24 @@ class InputFeatures(object):
self.label_id = label_id
class UDCProcessor(DataProcessor):
class UDCProcessor(DataProcessor):
"""Processor for the UDC data set."""
def _create_examples(self, lines, set_type):
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
print("UDC dataset is too big, loading data spent a long time, please wait patiently..................")
for (i, line) in enumerate(lines):
if len(line) < 3:
print(
"UDC dataset is too big, loading data spent a long time, please wait patiently.................."
)
for (i, line) in enumerate(lines):
if len(line) < 3:
print("data format error: %s" % "\t".join(line))
print("data row contains at least three parts: label\tconv1\t.....\tresponse")
print(
"data row contains at least three parts: label\tconv1\t.....\tresponse"
)
continue
guid = "%s-%d" % (set_type, i)
text_a = "\t".join(line[1: -1])
text_a = "\t".join(line[1:-1])
text_a = tokenization.convert_to_unicode(text_a)
text_a = text_a.split('\t')
text_b = line[-1]
......@@ -273,21 +280,21 @@ class UDCProcessor(DataProcessor):
guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_train_examples(self, data_dir):
    """See base class.

    Loads the tab-separated ``train.txt`` from *data_dir* and converts
    every row into an ``InputExample`` via ``_create_examples``.
    """
    lines = self._read_tsv(os.path.join(data_dir, "train.txt"))
    # The original pre-initialized `examples = []` and immediately
    # overwrote it; the dead assignment is dropped.
    return self._create_examples(lines, "train")
def get_dev_examples(self, data_dir):
    """See base class.

    Loads the tab-separated ``dev.txt`` from *data_dir* and converts
    every row into an ``InputExample`` via ``_create_examples``.
    """
    lines = self._read_tsv(os.path.join(data_dir, "dev.txt"))
    # The original pre-initialized `examples = []` and immediately
    # overwrote it; the dead assignment is dropped.
    return self._create_examples(lines, "dev")
def get_test_examples(self, data_dir):
def get_test_examples(self, data_dir):
"""See base class."""
examples = []
lines = self._read_tsv(os.path.join(data_dir, "test.txt"))
......@@ -295,19 +302,20 @@ class UDCProcessor(DataProcessor):
return examples
@staticmethod
def get_labels():
    """See base class.

    UDC response selection is binary: "0" (bad response) / "1" (good
    response) — string labels, as the label map keys on strings.
    """
    return ["0", "1"]
class SWDAProcessor(DataProcessor):
class SWDAProcessor(DataProcessor):
"""Processor for the SWDA data set."""
def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets.

    Delegates to the module-level ``create_multi_turn_examples`` helper,
    which groups rows by conversation id and builds context windows.
    """
    return create_multi_turn_examples(lines, set_type)
def get_train_examples(self, data_dir):
def get_train_examples(self, data_dir):
"""See base class."""
examples = []
lines = self._read_tsv(os.path.join(data_dir, "train.txt"))
......@@ -329,21 +337,22 @@ class SWDAProcessor(DataProcessor):
return examples
@staticmethod
def get_labels():
    """See base class.

    SWDA uses 42 dialogue-act classes, encoded as the strings
    "0" .. "41" (the label map keys on strings).
    """
    return [str(label) for label in range(42)]
class MRDAProcessor(DataProcessor):
class MRDAProcessor(DataProcessor):
"""Processor for the MRDA data set."""
def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets.

    Delegates to the module-level ``create_multi_turn_examples`` helper,
    which groups rows by conversation id and builds context windows.
    """
    return create_multi_turn_examples(lines, set_type)
def get_train_examples(self, data_dir):
def get_train_examples(self, data_dir):
"""See base class."""
examples = []
lines = self._read_tsv(os.path.join(data_dir, "train.txt"))
......@@ -365,22 +374,25 @@ class MRDAProcessor(DataProcessor):
return examples
@staticmethod
def get_labels():
    """See base class.

    MRDA uses 42 dialogue-act classes, encoded as the strings
    "0" .. "41" (the label map keys on strings).
    """
    return [str(label) for label in range(42)]
class ATISSlotProcessor(DataProcessor):
class ATISSlotProcessor(DataProcessor):
"""Processor for the ATIS Slot data set."""
def _create_examples(self, lines, set_type):
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if len(line) != 2:
for (i, line) in enumerate(lines):
if len(line) != 2:
print("data format error: %s" % "\t".join(line))
print("data row contains two parts: conversation_content \t label1 label2 label3")
print(
"data row contains two parts: conversation_content \t label1 label2 label3"
)
continue
guid = "%s-%d" % (set_type, i)
text_a = line[0]
......@@ -392,7 +404,7 @@ class ATISSlotProcessor(DataProcessor):
guid=guid, text_a=text_a, label=label_list))
return examples
def get_train_examples(self, data_dir):
def get_train_examples(self, data_dir):
"""See base class."""
examples = []
lines = self._read_tsv(os.path.join(data_dir, "train.txt"))
......@@ -414,30 +426,30 @@ class ATISSlotProcessor(DataProcessor):
return examples
@staticmethod
def get_labels():
    """See base class.

    ATIS slot filling uses 130 slot-tag classes, encoded as the strings
    "0" .. "129" (the label map keys on strings).
    """
    return [str(label) for label in range(130)]
class ATISIntentProcessor(DataProcessor):
class ATISIntentProcessor(DataProcessor):
"""Processor for the ATIS intent data set."""
def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets.

    Each well-formed row is ``label \t conversation_content``; column 1
    becomes ``text_a`` and column 0 the label, both normalized to
    unicode. Malformed rows are reported and skipped (best-effort).
    """
    examples = []
    for (i, line) in enumerate(lines):
        if len(line) != 2:
            # Report the malformed row and keep going rather than abort.
            print("data format error: %s" % "\t".join(line))
            print(
                "data row contains two parts: label \t conversation_content")
            continue
        guid = "%s-%d" % (set_type, i)
        text_a = line[1]
        text_a = tokenization.convert_to_unicode(text_a)
        label = tokenization.convert_to_unicode(line[0])
        examples.append(InputExample(guid=guid, text_a=text_a, label=label))
    return examples
def get_train_examples(self, data_dir):
......@@ -469,53 +481,60 @@ class ATISIntentProcessor(DataProcessor):
return labels
class DSTC2Processor(DataProcessor):
class DSTC2Processor(DataProcessor):
"""Processor for the DSTC2 data set."""
def _create_turns(self, conv_example):
    """Create the multi-turn dataset for one conversation.

    For each turn ``i`` the sample's history is the (up to) 20 most
    recent turns *including* turn ``i``, joined and re-split on the raw
    "\\1" separator used by the data files, paired with turn ``i``'s
    label list.
    """
    samples = []
    max_turns = 20  # history window size (in turns)
    for i in range(len(conv_example)):
        conv_turns = conv_example[max(i - max_turns, 0):i + 1]
        conv_info = "\1".join([sample[0] for sample in conv_turns])
        samples.append((conv_info.split('\1'), conv_example[i][1]))
    return samples
def _create_examples(self, lines, set_type):
    """Creates examples for multi-turn dialogue sets.

    Rows are grouped by conversation id (column 0). Whenever the id
    changes — and once more at end of input — the buffered conversation
    is expanded into per-turn samples via ``_create_turns`` and emitted
    as ``InputExample``s. Malformed rows (not exactly 3 columns) are
    reported and skipped.
    """
    examples = []
    conv_id = -1
    index = 0
    conv_example = []
    for (i, line) in enumerate(lines):
        if len(line) != 3:
            # Best-effort: report the bad row and keep going.
            print("data format error: %s" % "\t".join(line))
            print(
                "data row contains three parts: conversation_content \t question \1 answer \t state1 state2 state3......"
            )
            continue
        conv_no = line[0]
        text_a = line[1]
        label_list = line[2].split()
        if conv_no != conv_id and i != 0:
            # Conversation boundary: flush the buffered turns.
            samples = self._create_turns(conv_example)
            for sample in samples:
                guid = "%s-%s" % (set_type, index)
                index += 1
                history = sample[0]
                dst_label = sample[1]
                examples.append(
                    InputExample(
                        guid=guid, text_a=history, label=dst_label))
            conv_example = []
            conv_id = conv_no
        if i == 0:
            conv_id = conv_no
        conv_example.append((text_a, label_list))
    if conv_example:
        # Flush the final conversation, which has no id change after it.
        samples = self._create_turns(conv_example)
        for sample in samples:
            guid = "%s-%s" % (set_type, index)
            index += 1
            history = sample[0]
            dst_label = sample[1]
            examples.append(
                InputExample(
                    guid=guid, text_a=history, label=dst_label))
    return examples
def get_train_examples(self, data_dir):
......@@ -547,20 +566,22 @@ class DSTC2Processor(DataProcessor):
return labels
class MULTIWOZProcessor(DataProcessor):
class MULTIWOZProcessor(DataProcessor):
"""Processor for the MULTIWOZ data set."""
def _create_turns(self, conv_example):
    """Create the multi-turn dataset for one conversation.

    For each turn ``i`` the sample is (history, current, label): history
    is built from at most the 2 preceding turns joined/re-split on the
    raw "\\1" separator, current is turn ``i``'s utterance list.
    NOTE(review): an empty prefix yields [''] (not []) after split —
    downstream presumably tolerates this; confirm against the caller.
    """
    samples = []
    max_turns = 2  # history window size (in turns), excluding the current turn
    for i in range(len(conv_example)):
        prefix_turns = conv_example[max(i - max_turns, 0):i]
        conv_info = "\1".join([turn[0] for turn in prefix_turns])
        current_turns = conv_example[i][0]
        samples.append((conv_info.split('\1'), current_turns.split('\1'),
                        conv_example[i][1]))
    return samples
def _create_examples(self, lines, set_type):
def _create_examples(self, lines, set_type):
"""Creates examples for multi-turn dialogue sets."""
examples = []
conv_id = -1
......@@ -570,7 +591,7 @@ class MULTIWOZProcessor(DataProcessor):
conv_no = line[0]
text_a = line[2]
label_list = line[1].split()
if conv_no != conv_id and i != 0:
if conv_no != conv_id and i != 0:
samples = self._create_turns(conv_example)
for sample in samples:
guid = "%s-%s" % (set_type, index)
......@@ -578,13 +599,18 @@ class MULTIWOZProcessor(DataProcessor):
history = sample[0]
current = sample[1]
dst_label = sample[2]
examples.append(InputExample(guid=guid, text_a=history, text_b=current, label=dst_label))
examples.append(
InputExample(
guid=guid,
text_a=history,
text_b=current,
label=dst_label))
conv_example = []
conv_id = conv_no
if i == 0:
if i == 0:
conv_id = conv_no
conv_example.append((text_a, label_list))
if conv_example:
if conv_example:
samples = self._create_turns(conv_example)
for sample in samples:
guid = "%s-%s" % (set_type, index)
......@@ -592,10 +618,15 @@ class MULTIWOZProcessor(DataProcessor):
history = sample[0]
current = sample[1]
dst_label = sample[2]
examples.append(InputExample(guid=guid, text_a=history, text_b=current, label=dst_label))
examples.append(
InputExample(
guid=guid,
text_a=history,
text_b=current,
label=dst_label))
return examples
def get_train_examples(self, data_dir):
def get_train_examples(self, data_dir):
"""See base class."""
examples = []
lines = self._read_tsv(os.path.join(data_dir, "train.txt"))
......@@ -624,34 +655,38 @@ class MULTIWOZProcessor(DataProcessor):
return labels
def create_dialogue_examples(conv):
    """Creates dialogue samples with surrounding context.

    ``conv`` is a list of 4-field rows (conversation_id, label, caller,
    content). For each turn the sample is
    ``[label, previous (<=5) turns, current turn, following (<=2) turns]``
    where each turn is rendered as "caller : content".
    """
    samples = []
    for i in range(len(conv)):
        cur_txt = "%s : %s" % (conv[i][2], conv[i][3])
        # Up to five preceding turns as left context.
        pre_txt = ["%s : %s" % (c[2], c[3]) for c in conv[max(0, i - 5):i]]
        # Up to two following turns as right context.
        suf_txt = [
            "%s : %s" % (c[2], c[3]) for c in conv[i + 1:min(len(conv), i + 3)]
        ]
        sample = [conv[i][1], pre_txt, cur_txt, suf_txt]
        samples.append(sample)
    return samples
def create_multi_turn_examples(lines, set_type):
def create_multi_turn_examples(lines, set_type):
"""Creates examples for multi-turn dialogue sets."""
conv_id = -1
examples = []
conv_example = []
index = 0
for (i, line) in enumerate(lines):
if len(line) != 4:
for (i, line) in enumerate(lines):
if len(line) != 4:
print("data format error: %s" % "\t".join(line))
print("data row contains four parts: conversation_id \t label \t caller \t conversation_content")
print(
"data row contains four parts: conversation_id \t label \t caller \t conversation_content"
)
continue
tokens = line
conv_no = tokens[0]
if conv_no != conv_id and i != 0:
if conv_no != conv_id and i != 0:
samples = create_dialogue_examples(conv_example)
for sample in samples:
for sample in samples:
guid = "%s-%s" % (set_type, index)
index += 1
label = sample[0]
......@@ -659,15 +694,20 @@ def create_multi_turn_examples(lines, set_type):
text_b = sample[2]
text_c = sample[3]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, text_c=text_c, label=label))
InputExample(
guid=guid,
text_a=text_a,
text_b=text_b,
text_c=text_c,
label=label))
conv_example = []
conv_id = conv_no
if i == 0:
if i == 0:
conv_id = conv_no
conv_example.append(tokens)
if conv_example:
if conv_example:
samples = create_dialogue_examples(conv_example)
for sample in samples:
for sample in samples:
guid = "%s-%s" % (set_type, index)
index += 1
label = sample[0]
......@@ -675,62 +715,67 @@ def create_multi_turn_examples(lines, set_type):
text_b = sample[2]
text_c = sample[3]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, text_c=text_c, label=label))
InputExample(
guid=guid,
text_a=text_a,
text_b=text_b,
text_c=text_c,
label=label))
return examples
def convert_tokens(tokens, sep_id, tokenizer):
    """Converts tokens to ids.

    ``tokens`` may be a single string or a list of strings. A list is
    tokenized piece by piece with ``sep_id`` inserted between pieces
    (the trailing separator is stripped); a plain string is tokenized
    directly. Empty/falsy input returns an empty list.
    """
    tokens_ids = []
    if not tokens:
        return tokens_ids
    if isinstance(tokens, list):
        for text in tokens:
            tok_text = tokenizer.tokenize(text)
            ids = tokenizer.convert_tokens_to_ids(tok_text)
            tokens_ids.extend(ids)
            tokens_ids.append(sep_id)
        # Drop the separator appended after the last piece.
        tokens_ids = tokens_ids[:-1]
    else:
        tok_text = tokenizer.tokenize(tokens)
        tokens_ids = tokenizer.convert_tokens_to_ids(tok_text)
    return tokens_ids
def convert_single_example(ex_index, example, label_list, max_seq_length,
def convert_single_example(ex_index, example, label_list, max_seq_length,
tokenizer, task_name):
"""Converts a single DA `InputExample` into a single `InputFeatures`."""
label_map = {}
SEP = 102
SEP = 102
CLS = 101
if task_name == 'udc':
if task_name == 'udc':
INNER_SEP = 1
limit_length = 60
elif task_name == 'swda':
elif task_name == 'swda':
INNER_SEP = 1
limit_length = 50
elif task_name == 'mrda':
elif task_name == 'mrda':
INNER_SEP = 1
limit_length = 50
elif task_name == 'atis_intent':
elif task_name == 'atis_intent':
INNER_SEP = -1
limit_length = -1
elif task_name == 'atis_slot':
elif task_name == 'atis_slot':
INNER_SEP = -1
limit_length = -1
elif task_name == 'dstc2':
elif task_name == 'dstc2':
INNER_SEP = 1
limit_length = -1
elif task_name == 'dstc2_asr':
elif task_name == 'dstc2_asr':
INNER_SEP = 1
limit_length = -1
elif task_name == 'multi-woz':
elif task_name == 'multi-woz':
INNER_SEP = 1
limit_length = 200
for (i, label) in enumerate(label_list):
for (i, label) in enumerate(label_list):
label_map[label] = i
tokens_a = example.text_a
tokens_b = example.text_b
tokens_c = example.text_c
......@@ -739,30 +784,36 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
tokens_b_ids = convert_tokens(tokens_b, INNER_SEP, tokenizer)
tokens_c_ids = convert_tokens(tokens_c, INNER_SEP, tokenizer)
if tokens_b_ids:
if tokens_b_ids:
tokens_b_ids = tokens_b_ids[:min(limit_length, len(tokens_b_ids))]
else:
else:
if len(tokens_a_ids) > max_seq_length - 2:
tokens_a_ids = tokens_a_ids[len(tokens_a_ids) - max_seq_length + 2:]
if not tokens_c_ids:
if len(tokens_a_ids) > max_seq_length - len(tokens_b_ids) - 3:
tokens_a_ids = tokens_a_ids[len(tokens_a_ids) - max_seq_length + len(tokens_b_ids) + 3:]
else:
if len(tokens_a_ids) + len(tokens_b_ids) + len(tokens_c_ids) > max_seq_length - 4:
if not tokens_c_ids:
if len(tokens_a_ids) > max_seq_length - len(tokens_b_ids) - 3:
tokens_a_ids = tokens_a_ids[len(tokens_a_ids) - max_seq_length +
len(tokens_b_ids) + 3:]
else:
if len(tokens_a_ids) + len(tokens_b_ids) + len(
tokens_c_ids) > max_seq_length - 4:
left_num = max_seq_length - len(tokens_b_ids) - 4
if len(tokens_a_ids) > len(tokens_c_ids):
if len(tokens_a_ids) > len(tokens_c_ids):
suffix_num = int(left_num / 2)
tokens_c_ids = tokens_c_ids[: min(len(tokens_c_ids), suffix_num)]
tokens_c_ids = tokens_c_ids[:min(len(tokens_c_ids), suffix_num)]
prefix_num = left_num - len(tokens_c_ids)
tokens_a_ids = tokens_a_ids[max(0, len(tokens_a_ids) - prefix_num):]
else:
if not tokens_a_ids:
tokens_c_ids = tokens_c_ids[max(0, len(tokens_c_ids) - left_num):]
else:
tokens_a_ids = tokens_a_ids[max(
0, len(tokens_a_ids) - prefix_num):]
else:
if not tokens_a_ids:
tokens_c_ids = tokens_c_ids[max(
0, len(tokens_c_ids) - left_num):]
else:
prefix_num = int(left_num / 2)
tokens_a_ids = tokens_a_ids[max(0, len(tokens_a_ids) - prefix_num):]
tokens_a_ids = tokens_a_ids[max(
0, len(tokens_a_ids) - prefix_num):]
suffix_num = left_num - len(tokens_a_ids)
tokens_c_ids = tokens_c_ids[: min(len(tokens_c_ids), suffix_num)]
tokens_c_ids = tokens_c_ids[:min(
len(tokens_c_ids), suffix_num)]
input_ids = []
segment_ids = []
......@@ -772,31 +823,31 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
segment_ids.extend([0] * len(tokens_a_ids))
input_ids.append(SEP)
segment_ids.append(0)
if tokens_b_ids:
if tokens_b_ids:
input_ids.extend(tokens_b_ids)
segment_ids.extend([1] * len(tokens_b_ids))
input_ids.append(SEP)
segment_ids.append(1)
if tokens_c_ids:
if tokens_c_ids:
input_ids.extend(tokens_c_ids)
segment_ids.extend([0] * len(tokens_c_ids))
input_ids.append(SEP)
segment_ids.append(0)
input_mask = [1] * len(input_ids)
if task_name == 'atis_slot':
if task_name == 'atis_slot':
label_id = [0] + [label_map[l] for l in example.label] + [0]
elif task_name in ['dstc2', 'dstc2_asr', 'multi-woz']:
elif task_name in ['dstc2', 'dstc2_asr', 'multi-woz']:
label_id_enty = [label_map[l] for l in example.label]
label_id = []
for i in range(len(label_map)):
if i in label_id_enty:
for i in range(len(label_map)):
if i in label_id_enty:
label_id.append(1)
else:
else:
label_id.append(0)
else:
else:
label_id = label_map[example.label]
if ex_index < 5:
print("*** Example ***")
print("guid: %s" % (example.guid))
......@@ -809,7 +860,5 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
input_mask=input_mask,
segment_ids=segment_ids,
label_id=label_id)
return feature
return feature
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册