Unverified commit 300be16c authored by pkpk, committed by GitHub

test=develop (#4175)

Parent 071dc299
@@ -23,8 +23,10 @@ import numpy as np
from dgu import tokenization
from dgu.batching import prepare_batch_data
reload(sys)
sys.setdefaultencoding('utf-8')
if sys.version[0] == '2':
reload(sys)
sys.setdefaultencoding('utf-8')
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
@@ -152,9 +154,9 @@ class DataProcessor(object):
if shuffle:
np.random.shuffle(examples)
for (index, example) in enumerate(examples):
feature = self.convert_example(
index, example,
self.get_labels(), self.max_seq_len, self.tokenizer)
feature = self.convert_example(index, example,
self.get_labels(),
self.max_seq_len, self.tokenizer)
instance = self.generate_instance(feature)
yield instance
@@ -252,17 +254,22 @@ class InputFeatures(object):
class UDCProcessor(DataProcessor):
"""Processor for the UDC data set."""
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
print("UDC dataset is too big, loading data spent a long time, please wait patiently..................")
print(
"UDC dataset is too big, loading data spent a long time, please wait patiently.................."
)
for (i, line) in enumerate(lines):
if len(line) < 3:
print("data format error: %s" % "\t".join(line))
print("data row contains at least three parts: label\tconv1\t.....\tresponse")
print(
"data row contains at least three parts: label\tconv1\t.....\tresponse"
)
continue
guid = "%s-%d" % (set_type, i)
text_a = "\t".join(line[1: -1])
text_a = "\t".join(line[1:-1])
text_a = tokenization.convert_to_unicode(text_a)
text_a = text_a.split('\t')
text_b = line[-1]
@@ -302,6 +309,7 @@ class UDCProcessor(DataProcessor):
class SWDAProcessor(DataProcessor):
"""Processor for the SWDA data set."""
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = create_multi_turn_examples(lines, set_type)
@@ -338,6 +346,7 @@ class SWDAProcessor(DataProcessor):
class MRDAProcessor(DataProcessor):
"""Processor for the MRDA data set."""
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = create_multi_turn_examples(lines, set_type)
@@ -374,13 +383,16 @@ class MRDAProcessor(DataProcessor):
class ATISSlotProcessor(DataProcessor):
"""Processor for the ATIS Slot data set."""
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if len(line) != 2:
print("data format error: %s" % "\t".join(line))
print("data row contains two parts: conversation_content \t label1 label2 label3")
print(
"data row contains two parts: conversation_content \t label1 label2 label3"
)
continue
guid = "%s-%d" % (set_type, i)
text_a = line[0]
@@ -423,21 +435,21 @@ class ATISSlotProcessor(DataProcessor):
class ATISIntentProcessor(DataProcessor):
"""Processor for the ATIS intent data set."""
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if len(line) != 2:
print("data format error: %s" % "\t".join(line))
print("data row contains two parts: label \t conversation_content")
print(
"data row contains two parts: label \t conversation_content")
continue
guid = "%s-%d" % (set_type, i)
text_a = line[1]
text_a = tokenization.convert_to_unicode(text_a)
label = tokenization.convert_to_unicode(line[0])
examples.append(
InputExample(
guid=guid, text_a=text_a, label=label))
examples.append(InputExample(guid=guid, text_a=text_a, label=label))
return examples
def get_train_examples(self, data_dir):
@@ -471,12 +483,13 @@ class ATISIntentProcessor(DataProcessor):
class DSTC2Processor(DataProcessor):
"""Processor for the DSTC2 data set."""
def _create_turns(self, conv_example):
"""create multi turn dataset"""
samples = []
max_turns = 20
for i in range(len(conv_example)):
conv_turns = conv_example[max(i - max_turns, 0): i + 1]
conv_turns = conv_example[max(i - max_turns, 0):i + 1]
conv_info = "\1".join([sample[0] for sample in conv_turns])
samples.append((conv_info.split('\1'), conv_example[i][1]))
return samples
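For reference, a minimal standalone sketch (toy data, not part of this commit) of the turn-window slicing that _create_turns performs above:

# Sketch only: mirrors conv_example[max(i - max_turns, 0):i + 1] on toy data,
# where each element is (question "\1" answer, dialogue-state label).
conv_example = [("hi\1hello", "greet"), ("price?\1cheap", "inform"), ("bye\1bye", "bye")]
max_turns = 20
samples = []
for i in range(len(conv_example)):
    # keep at most the last max_turns turns, up to and including turn i
    conv_turns = conv_example[max(i - max_turns, 0):i + 1]
    conv_info = "\1".join([sample[0] for sample in conv_turns])
    samples.append((conv_info.split('\1'), conv_example[i][1]))
print(samples[-1])  # (['hi', 'hello', 'price?', 'cheap', 'bye', 'bye'], 'bye')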
@@ -490,7 +503,9 @@ class DSTC2Processor(DataProcessor):
for (i, line) in enumerate(lines):
if len(line) != 3:
print("data format error: %s" % "\t".join(line))
print("data row contains three parts: conversation_content \t question \1 answer \t state1 state2 state3......")
print(
"data row contains three parts: conversation_content \t question \1 answer \t state1 state2 state3......"
)
continue
conv_no = line[0]
text_a = line[1]
@@ -502,7 +517,9 @@ class DSTC2Processor(DataProcessor):
index += 1
history = sample[0]
dst_label = sample[1]
examples.append(InputExample(guid=guid, text_a=history, label=dst_label))
examples.append(
InputExample(
guid=guid, text_a=history, label=dst_label))
conv_example = []
conv_id = conv_no
if i == 0:
@@ -515,7 +532,9 @@ class DSTC2Processor(DataProcessor):
index += 1
history = sample[0]
dst_label = sample[1]
examples.append(InputExample(guid=guid, text_a=history, label=dst_label))
examples.append(
InputExample(
guid=guid, text_a=history, label=dst_label))
return examples
def get_train_examples(self, data_dir):
@@ -549,15 +568,17 @@ class DSTC2Processor(DataProcessor):
class MULTIWOZProcessor(DataProcessor):
"""Processor for the MULTIWOZ data set."""
def _create_turns(self, conv_example):
"""create multi turn dataset"""
samples = []
max_turns = 2
for i in range(len(conv_example)):
prefix_turns = conv_example[max(i - max_turns, 0): i]
prefix_turns = conv_example[max(i - max_turns, 0):i]
conv_info = "\1".join([turn[0] for turn in prefix_turns])
current_turns = conv_example[i][0]
samples.append((conv_info.split('\1'), current_turns.split('\1'), conv_example[i][1]))
samples.append((conv_info.split('\1'), current_turns.split('\1'),
conv_example[i][1]))
return samples
def _create_examples(self, lines, set_type):
@@ -578,7 +599,12 @@ class MULTIWOZProcessor(DataProcessor):
history = sample[0]
current = sample[1]
dst_label = sample[2]
examples.append(InputExample(guid=guid, text_a=history, text_b=current, label=dst_label))
examples.append(
InputExample(
guid=guid,
text_a=history,
text_b=current,
label=dst_label))
conv_example = []
conv_id = conv_no
if i == 0:
@@ -592,7 +618,12 @@ class MULTIWOZProcessor(DataProcessor):
history = sample[0]
current = sample[1]
dst_label = sample[2]
examples.append(InputExample(guid=guid, text_a=history, text_b=current, label=dst_label))
examples.append(
InputExample(
guid=guid,
text_a=history,
text_b=current,
label=dst_label))
return examples
def get_train_examples(self, data_dir):
@@ -629,8 +660,10 @@ def create_dialogue_examples(conv):
samples = []
for i in range(len(conv)):
cur_txt = "%s : %s" % (conv[i][2], conv[i][3])
pre_txt = ["%s : %s" % (c[2], c[3]) for c in conv[max(0, i - 5): i]]
suf_txt = ["%s : %s" % (c[2], c[3]) for c in conv[i + 1: min(len(conv), i + 3)]]
pre_txt = ["%s : %s" % (c[2], c[3]) for c in conv[max(0, i - 5):i]]
suf_txt = [
"%s : %s" % (c[2], c[3]) for c in conv[i + 1:min(len(conv), i + 3)]
]
sample = [conv[i][1], pre_txt, cur_txt, suf_txt]
samples.append(sample)
return samples
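For reference, a minimal sketch (toy data, not part of this commit) of the prefix/suffix context windows built in create_dialogue_examples above: up to five preceding turns and two following turns around the current utterance.

# Sketch only: each turn is assumed to be (conversation_id, label, caller, content),
# matching the indices used in the function above.
conv = [("c1", "label%d" % i, "caller%d" % (i % 2), "utt%d" % i) for i in range(8)]
i = 6
cur_txt = "%s : %s" % (conv[i][2], conv[i][3])
pre_txt = ["%s : %s" % (c[2], c[3]) for c in conv[max(0, i - 5):i]]
suf_txt = ["%s : %s" % (c[2], c[3]) for c in conv[i + 1:min(len(conv), i + 3)]]
print(cur_txt)                     # caller0 : utt6
print(len(pre_txt), len(suf_txt))  # 5 1  -> turns 1-5 before, turn 7 after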
@@ -645,7 +678,9 @@ def create_multi_turn_examples(lines, set_type):
for (i, line) in enumerate(lines):
if len(line) != 4:
print("data format error: %s" % "\t".join(line))
print("data row contains four parts: conversation_id \t label \t caller \t conversation_content")
print(
"data row contains four parts: conversation_id \t label \t caller \t conversation_content"
)
continue
tokens = line
conv_no = tokens[0]
@@ -659,7 +694,12 @@ def create_multi_turn_examples(lines, set_type):
text_b = sample[2]
text_c = sample[3]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, text_c=text_c, label=label))
InputExample(
guid=guid,
text_a=text_a,
text_b=text_b,
text_c=text_c,
label=label))
conv_example = []
conv_id = conv_no
if i == 0:
@@ -675,7 +715,12 @@ def create_multi_turn_examples(lines, set_type):
text_b = sample[2]
text_c = sample[3]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, text_c=text_c, label=label))
InputExample(
guid=guid,
text_a=text_a,
text_b=text_b,
text_c=text_c,
label=label))
return examples
@@ -690,7 +735,7 @@ def convert_tokens(tokens, sep_id, tokenizer):
ids = tokenizer.convert_tokens_to_ids(tok_text)
tokens_ids.extend(ids)
tokens_ids.append(sep_id)
tokens_ids = tokens_ids[: -1]
tokens_ids = tokens_ids[:-1]
else:
tok_text = tokenizer.tokenize(tokens)
tokens_ids = tokenizer.convert_tokens_to_ids(tok_text)
@@ -746,23 +791,29 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
tokens_a_ids = tokens_a_ids[len(tokens_a_ids) - max_seq_length + 2:]
if not tokens_c_ids:
if len(tokens_a_ids) > max_seq_length - len(tokens_b_ids) - 3:
tokens_a_ids = tokens_a_ids[len(tokens_a_ids) - max_seq_length + len(tokens_b_ids) + 3:]
tokens_a_ids = tokens_a_ids[len(tokens_a_ids) - max_seq_length +
len(tokens_b_ids) + 3:]
else:
if len(tokens_a_ids) + len(tokens_b_ids) + len(tokens_c_ids) > max_seq_length - 4:
if len(tokens_a_ids) + len(tokens_b_ids) + len(
tokens_c_ids) > max_seq_length - 4:
left_num = max_seq_length - len(tokens_b_ids) - 4
if len(tokens_a_ids) > len(tokens_c_ids):
suffix_num = int(left_num / 2)
tokens_c_ids = tokens_c_ids[: min(len(tokens_c_ids), suffix_num)]
tokens_c_ids = tokens_c_ids[:min(len(tokens_c_ids), suffix_num)]
prefix_num = left_num - len(tokens_c_ids)
tokens_a_ids = tokens_a_ids[max(0, len(tokens_a_ids) - prefix_num):]
tokens_a_ids = tokens_a_ids[max(
0, len(tokens_a_ids) - prefix_num):]
else:
if not tokens_a_ids:
tokens_c_ids = tokens_c_ids[max(0, len(tokens_c_ids) - left_num):]
tokens_c_ids = tokens_c_ids[max(
0, len(tokens_c_ids) - left_num):]
else:
prefix_num = int(left_num / 2)
tokens_a_ids = tokens_a_ids[max(0, len(tokens_a_ids) - prefix_num):]
tokens_a_ids = tokens_a_ids[max(
0, len(tokens_a_ids) - prefix_num):]
suffix_num = left_num - len(tokens_a_ids)
tokens_c_ids = tokens_c_ids[: min(len(tokens_c_ids), suffix_num)]
tokens_c_ids = tokens_c_ids[:min(
len(tokens_c_ids), suffix_num)]
input_ids = []
segment_ids = []
@@ -811,5 +862,3 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
label_id=label_id)
return feature