未验证 提交 300be16c 编写于 作者: P pkpk 提交者: GitHub

test=develop (#4175)

上级 071dc299
...@@ -23,8 +23,10 @@ import numpy as np ...@@ -23,8 +23,10 @@ import numpy as np
from dgu import tokenization from dgu import tokenization
from dgu.batching import prepare_batch_data from dgu.batching import prepare_batch_data
reload(sys) if sys.version[0] == '2':
sys.setdefaultencoding('utf-8') reload(sys)
sys.setdefaultencoding('utf-8')
class DataProcessor(object): class DataProcessor(object):
"""Base class for data converters for sequence classification data sets.""" """Base class for data converters for sequence classification data sets."""
...@@ -152,9 +154,9 @@ class DataProcessor(object): ...@@ -152,9 +154,9 @@ class DataProcessor(object):
if shuffle: if shuffle:
np.random.shuffle(examples) np.random.shuffle(examples)
for (index, example) in enumerate(examples): for (index, example) in enumerate(examples):
feature = self.convert_example( feature = self.convert_example(index, example,
index, example, self.get_labels(),
self.get_labels(), self.max_seq_len, self.tokenizer) self.max_seq_len, self.tokenizer)
instance = self.generate_instance(feature) instance = self.generate_instance(feature)
yield instance yield instance
...@@ -252,17 +254,22 @@ class InputFeatures(object): ...@@ -252,17 +254,22 @@ class InputFeatures(object):
class UDCProcessor(DataProcessor): class UDCProcessor(DataProcessor):
"""Processor for the UDC data set.""" """Processor for the UDC data set."""
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training and dev sets."""
examples = [] examples = []
print("UDC dataset is too big, loading data spent a long time, please wait patiently..................") print(
"UDC dataset is too big, loading data spent a long time, please wait patiently.................."
)
for (i, line) in enumerate(lines): for (i, line) in enumerate(lines):
if len(line) < 3: if len(line) < 3:
print("data format error: %s" % "\t".join(line)) print("data format error: %s" % "\t".join(line))
print("data row contains at least three parts: label\tconv1\t.....\tresponse") print(
"data row contains at least three parts: label\tconv1\t.....\tresponse"
)
continue continue
guid = "%s-%d" % (set_type, i) guid = "%s-%d" % (set_type, i)
text_a = "\t".join(line[1: -1]) text_a = "\t".join(line[1:-1])
text_a = tokenization.convert_to_unicode(text_a) text_a = tokenization.convert_to_unicode(text_a)
text_a = text_a.split('\t') text_a = text_a.split('\t')
text_b = line[-1] text_b = line[-1]
...@@ -302,6 +309,7 @@ class UDCProcessor(DataProcessor): ...@@ -302,6 +309,7 @@ class UDCProcessor(DataProcessor):
class SWDAProcessor(DataProcessor): class SWDAProcessor(DataProcessor):
"""Processor for the SWDA data set.""" """Processor for the SWDA data set."""
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training and dev sets."""
examples = create_multi_turn_examples(lines, set_type) examples = create_multi_turn_examples(lines, set_type)
...@@ -338,6 +346,7 @@ class SWDAProcessor(DataProcessor): ...@@ -338,6 +346,7 @@ class SWDAProcessor(DataProcessor):
class MRDAProcessor(DataProcessor): class MRDAProcessor(DataProcessor):
"""Processor for the MRDA data set.""" """Processor for the MRDA data set."""
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training and dev sets."""
examples = create_multi_turn_examples(lines, set_type) examples = create_multi_turn_examples(lines, set_type)
...@@ -374,13 +383,16 @@ class MRDAProcessor(DataProcessor): ...@@ -374,13 +383,16 @@ class MRDAProcessor(DataProcessor):
class ATISSlotProcessor(DataProcessor): class ATISSlotProcessor(DataProcessor):
"""Processor for the ATIS Slot data set.""" """Processor for the ATIS Slot data set."""
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training and dev sets."""
examples = [] examples = []
for (i, line) in enumerate(lines): for (i, line) in enumerate(lines):
if len(line) != 2: if len(line) != 2:
print("data format error: %s" % "\t".join(line)) print("data format error: %s" % "\t".join(line))
print("data row contains two parts: conversation_content \t label1 label2 label3") print(
"data row contains two parts: conversation_content \t label1 label2 label3"
)
continue continue
guid = "%s-%d" % (set_type, i) guid = "%s-%d" % (set_type, i)
text_a = line[0] text_a = line[0]
...@@ -423,21 +435,21 @@ class ATISSlotProcessor(DataProcessor): ...@@ -423,21 +435,21 @@ class ATISSlotProcessor(DataProcessor):
class ATISIntentProcessor(DataProcessor): class ATISIntentProcessor(DataProcessor):
"""Processor for the ATIS intent data set.""" """Processor for the ATIS intent data set."""
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets.""" """Creates examples for the training and dev sets."""
examples = [] examples = []
for (i, line) in enumerate(lines): for (i, line) in enumerate(lines):
if len(line) != 2: if len(line) != 2:
print("data format error: %s" % "\t".join(line)) print("data format error: %s" % "\t".join(line))
print("data row contains two parts: label \t conversation_content") print(
"data row contains two parts: label \t conversation_content")
continue continue
guid = "%s-%d" % (set_type, i) guid = "%s-%d" % (set_type, i)
text_a = line[1] text_a = line[1]
text_a = tokenization.convert_to_unicode(text_a) text_a = tokenization.convert_to_unicode(text_a)
label = tokenization.convert_to_unicode(line[0]) label = tokenization.convert_to_unicode(line[0])
examples.append( examples.append(InputExample(guid=guid, text_a=text_a, label=label))
InputExample(
guid=guid, text_a=text_a, label=label))
return examples return examples
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
...@@ -471,12 +483,13 @@ class ATISIntentProcessor(DataProcessor): ...@@ -471,12 +483,13 @@ class ATISIntentProcessor(DataProcessor):
class DSTC2Processor(DataProcessor): class DSTC2Processor(DataProcessor):
"""Processor for the DSTC2 data set.""" """Processor for the DSTC2 data set."""
def _create_turns(self, conv_example): def _create_turns(self, conv_example):
"""create multi turn dataset""" """create multi turn dataset"""
samples = [] samples = []
max_turns = 20 max_turns = 20
for i in range(len(conv_example)): for i in range(len(conv_example)):
conv_turns = conv_example[max(i - max_turns, 0): i + 1] conv_turns = conv_example[max(i - max_turns, 0):i + 1]
conv_info = "\1".join([sample[0] for sample in conv_turns]) conv_info = "\1".join([sample[0] for sample in conv_turns])
samples.append((conv_info.split('\1'), conv_example[i][1])) samples.append((conv_info.split('\1'), conv_example[i][1]))
return samples return samples
...@@ -490,7 +503,9 @@ class DSTC2Processor(DataProcessor): ...@@ -490,7 +503,9 @@ class DSTC2Processor(DataProcessor):
for (i, line) in enumerate(lines): for (i, line) in enumerate(lines):
if len(line) != 3: if len(line) != 3:
print("data format error: %s" % "\t".join(line)) print("data format error: %s" % "\t".join(line))
print("data row contains three parts: conversation_content \t question \1 answer \t state1 state2 state3......") print(
"data row contains three parts: conversation_content \t question \1 answer \t state1 state2 state3......"
)
continue continue
conv_no = line[0] conv_no = line[0]
text_a = line[1] text_a = line[1]
...@@ -502,7 +517,9 @@ class DSTC2Processor(DataProcessor): ...@@ -502,7 +517,9 @@ class DSTC2Processor(DataProcessor):
index += 1 index += 1
history = sample[0] history = sample[0]
dst_label = sample[1] dst_label = sample[1]
examples.append(InputExample(guid=guid, text_a=history, label=dst_label)) examples.append(
InputExample(
guid=guid, text_a=history, label=dst_label))
conv_example = [] conv_example = []
conv_id = conv_no conv_id = conv_no
if i == 0: if i == 0:
...@@ -515,7 +532,9 @@ class DSTC2Processor(DataProcessor): ...@@ -515,7 +532,9 @@ class DSTC2Processor(DataProcessor):
index += 1 index += 1
history = sample[0] history = sample[0]
dst_label = sample[1] dst_label = sample[1]
examples.append(InputExample(guid=guid, text_a=history, label=dst_label)) examples.append(
InputExample(
guid=guid, text_a=history, label=dst_label))
return examples return examples
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
...@@ -549,15 +568,17 @@ class DSTC2Processor(DataProcessor): ...@@ -549,15 +568,17 @@ class DSTC2Processor(DataProcessor):
class MULTIWOZProcessor(DataProcessor): class MULTIWOZProcessor(DataProcessor):
"""Processor for the MULTIWOZ data set.""" """Processor for the MULTIWOZ data set."""
def _create_turns(self, conv_example): def _create_turns(self, conv_example):
"""create multi turn dataset""" """create multi turn dataset"""
samples = [] samples = []
max_turns = 2 max_turns = 2
for i in range(len(conv_example)): for i in range(len(conv_example)):
prefix_turns = conv_example[max(i - max_turns, 0): i] prefix_turns = conv_example[max(i - max_turns, 0):i]
conv_info = "\1".join([turn[0] for turn in prefix_turns]) conv_info = "\1".join([turn[0] for turn in prefix_turns])
current_turns = conv_example[i][0] current_turns = conv_example[i][0]
samples.append((conv_info.split('\1'), current_turns.split('\1'), conv_example[i][1])) samples.append((conv_info.split('\1'), current_turns.split('\1'),
conv_example[i][1]))
return samples return samples
def _create_examples(self, lines, set_type): def _create_examples(self, lines, set_type):
...@@ -578,7 +599,12 @@ class MULTIWOZProcessor(DataProcessor): ...@@ -578,7 +599,12 @@ class MULTIWOZProcessor(DataProcessor):
history = sample[0] history = sample[0]
current = sample[1] current = sample[1]
dst_label = sample[2] dst_label = sample[2]
examples.append(InputExample(guid=guid, text_a=history, text_b=current, label=dst_label)) examples.append(
InputExample(
guid=guid,
text_a=history,
text_b=current,
label=dst_label))
conv_example = [] conv_example = []
conv_id = conv_no conv_id = conv_no
if i == 0: if i == 0:
...@@ -592,7 +618,12 @@ class MULTIWOZProcessor(DataProcessor): ...@@ -592,7 +618,12 @@ class MULTIWOZProcessor(DataProcessor):
history = sample[0] history = sample[0]
current = sample[1] current = sample[1]
dst_label = sample[2] dst_label = sample[2]
examples.append(InputExample(guid=guid, text_a=history, text_b=current, label=dst_label)) examples.append(
InputExample(
guid=guid,
text_a=history,
text_b=current,
label=dst_label))
return examples return examples
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
...@@ -629,8 +660,10 @@ def create_dialogue_examples(conv): ...@@ -629,8 +660,10 @@ def create_dialogue_examples(conv):
samples = [] samples = []
for i in range(len(conv)): for i in range(len(conv)):
cur_txt = "%s : %s" % (conv[i][2], conv[i][3]) cur_txt = "%s : %s" % (conv[i][2], conv[i][3])
pre_txt = ["%s : %s" % (c[2], c[3]) for c in conv[max(0, i - 5): i]] pre_txt = ["%s : %s" % (c[2], c[3]) for c in conv[max(0, i - 5):i]]
suf_txt = ["%s : %s" % (c[2], c[3]) for c in conv[i + 1: min(len(conv), i + 3)]] suf_txt = [
"%s : %s" % (c[2], c[3]) for c in conv[i + 1:min(len(conv), i + 3)]
]
sample = [conv[i][1], pre_txt, cur_txt, suf_txt] sample = [conv[i][1], pre_txt, cur_txt, suf_txt]
samples.append(sample) samples.append(sample)
return samples return samples
...@@ -645,7 +678,9 @@ def create_multi_turn_examples(lines, set_type): ...@@ -645,7 +678,9 @@ def create_multi_turn_examples(lines, set_type):
for (i, line) in enumerate(lines): for (i, line) in enumerate(lines):
if len(line) != 4: if len(line) != 4:
print("data format error: %s" % "\t".join(line)) print("data format error: %s" % "\t".join(line))
print("data row contains four parts: conversation_id \t label \t caller \t conversation_content") print(
"data row contains four parts: conversation_id \t label \t caller \t conversation_content"
)
continue continue
tokens = line tokens = line
conv_no = tokens[0] conv_no = tokens[0]
...@@ -659,7 +694,12 @@ def create_multi_turn_examples(lines, set_type): ...@@ -659,7 +694,12 @@ def create_multi_turn_examples(lines, set_type):
text_b = sample[2] text_b = sample[2]
text_c = sample[3] text_c = sample[3]
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, text_c=text_c, label=label)) InputExample(
guid=guid,
text_a=text_a,
text_b=text_b,
text_c=text_c,
label=label))
conv_example = [] conv_example = []
conv_id = conv_no conv_id = conv_no
if i == 0: if i == 0:
...@@ -675,7 +715,12 @@ def create_multi_turn_examples(lines, set_type): ...@@ -675,7 +715,12 @@ def create_multi_turn_examples(lines, set_type):
text_b = sample[2] text_b = sample[2]
text_c = sample[3] text_c = sample[3]
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, text_c=text_c, label=label)) InputExample(
guid=guid,
text_a=text_a,
text_b=text_b,
text_c=text_c,
label=label))
return examples return examples
...@@ -690,7 +735,7 @@ def convert_tokens(tokens, sep_id, tokenizer): ...@@ -690,7 +735,7 @@ def convert_tokens(tokens, sep_id, tokenizer):
ids = tokenizer.convert_tokens_to_ids(tok_text) ids = tokenizer.convert_tokens_to_ids(tok_text)
tokens_ids.extend(ids) tokens_ids.extend(ids)
tokens_ids.append(sep_id) tokens_ids.append(sep_id)
tokens_ids = tokens_ids[: -1] tokens_ids = tokens_ids[:-1]
else: else:
tok_text = tokenizer.tokenize(tokens) tok_text = tokenizer.tokenize(tokens)
tokens_ids = tokenizer.convert_tokens_to_ids(tok_text) tokens_ids = tokenizer.convert_tokens_to_ids(tok_text)
...@@ -746,23 +791,29 @@ def convert_single_example(ex_index, example, label_list, max_seq_length, ...@@ -746,23 +791,29 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
tokens_a_ids = tokens_a_ids[len(tokens_a_ids) - max_seq_length + 2:] tokens_a_ids = tokens_a_ids[len(tokens_a_ids) - max_seq_length + 2:]
if not tokens_c_ids: if not tokens_c_ids:
if len(tokens_a_ids) > max_seq_length - len(tokens_b_ids) - 3: if len(tokens_a_ids) > max_seq_length - len(tokens_b_ids) - 3:
tokens_a_ids = tokens_a_ids[len(tokens_a_ids) - max_seq_length + len(tokens_b_ids) + 3:] tokens_a_ids = tokens_a_ids[len(tokens_a_ids) - max_seq_length +
len(tokens_b_ids) + 3:]
else: else:
if len(tokens_a_ids) + len(tokens_b_ids) + len(tokens_c_ids) > max_seq_length - 4: if len(tokens_a_ids) + len(tokens_b_ids) + len(
tokens_c_ids) > max_seq_length - 4:
left_num = max_seq_length - len(tokens_b_ids) - 4 left_num = max_seq_length - len(tokens_b_ids) - 4
if len(tokens_a_ids) > len(tokens_c_ids): if len(tokens_a_ids) > len(tokens_c_ids):
suffix_num = int(left_num / 2) suffix_num = int(left_num / 2)
tokens_c_ids = tokens_c_ids[: min(len(tokens_c_ids), suffix_num)] tokens_c_ids = tokens_c_ids[:min(len(tokens_c_ids), suffix_num)]
prefix_num = left_num - len(tokens_c_ids) prefix_num = left_num - len(tokens_c_ids)
tokens_a_ids = tokens_a_ids[max(0, len(tokens_a_ids) - prefix_num):] tokens_a_ids = tokens_a_ids[max(
0, len(tokens_a_ids) - prefix_num):]
else: else:
if not tokens_a_ids: if not tokens_a_ids:
tokens_c_ids = tokens_c_ids[max(0, len(tokens_c_ids) - left_num):] tokens_c_ids = tokens_c_ids[max(
0, len(tokens_c_ids) - left_num):]
else: else:
prefix_num = int(left_num / 2) prefix_num = int(left_num / 2)
tokens_a_ids = tokens_a_ids[max(0, len(tokens_a_ids) - prefix_num):] tokens_a_ids = tokens_a_ids[max(
0, len(tokens_a_ids) - prefix_num):]
suffix_num = left_num - len(tokens_a_ids) suffix_num = left_num - len(tokens_a_ids)
tokens_c_ids = tokens_c_ids[: min(len(tokens_c_ids), suffix_num)] tokens_c_ids = tokens_c_ids[:min(
len(tokens_c_ids), suffix_num)]
input_ids = [] input_ids = []
segment_ids = [] segment_ids = []
...@@ -811,5 +862,3 @@ def convert_single_example(ex_index, example, label_list, max_seq_length, ...@@ -811,5 +862,3 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
label_id=label_id) label_id=label_id)
return feature return feature
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册