reader.py 29.8 KB
Newer Older
0
0YuanZhang0 已提交
1
# -*- coding: utf-8 -*-
0
0YuanZhang0 已提交
2
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Y
Yibing Liu 已提交
3 4 5 6 7 8 9 10 11 12 13 14 15 16
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""data reader"""
import os
0
0YuanZhang0 已提交
17
import io
Y
Yibing Liu 已提交
18
import csv
0
0YuanZhang0 已提交
19
import sys
0
0YuanZhang0 已提交
20
import types
Y
Yibing Liu 已提交
21
import numpy as np
0
0YuanZhang0 已提交
22

0
0YuanZhang0 已提交
23 24
from dgu import tokenization
from dgu.batching import prepare_batch_data
Y
Yibing Liu 已提交
25

P
pkpk 已提交
26 27 28 29
# Python 2 only: switch the interpreter-wide default string encoding to UTF-8.
# NOTE(review): `reload(sys)` is presumably needed to make
# `sys.setdefaultencoding` reachable again -- confirm against the Python 2 docs.
# On Python 3 this branch is never entered, so the undefined name `reload`
# is never evaluated.
if sys.version[0] == '2':
    reload(sys)
    sys.setdefaultencoding('utf-8')

Y
Yibing Liu 已提交
30 31 32 33

class DataProcessor(object):
    """Base class for data converters for sequence classification data sets.

    Subclasses provide the `get_*_examples` loaders and `get_labels`; this
    class turns the resulting examples into batches via `data_generator`.
    """

    def __init__(self,
                 data_dir,
                 vocab_path,
                 max_seq_len,
                 do_lower_case,
                 in_tokens,
                 task_name,
                 random_seed=None):
        """
        Store the reader configuration and build the tokenizer.

        Args:
            data_dir: directory handed to the `get_*_examples` loaders.
            vocab_path: vocabulary file given to `tokenization.FullTokenizer`.
            max_seq_len: maximum sequence length used when converting examples.
            do_lower_case: forwarded to `tokenization.FullTokenizer`.
            in_tokens: bool. If true, `batch_size` in `data_generator` is a
                budget of (padded) tokens instead of a number of examples.
            task_name: task identifier forwarded to the conversion and
                batching helpers.
            random_seed: seed passed to `np.random.seed`.
        """
        self.data_dir = data_dir
        self.max_seq_len = max_seq_len
        self.tokenizer = tokenization.FullTokenizer(
            vocab_file=vocab_path, do_lower_case=do_lower_case)
        self.vocab = self.tokenizer.vocab
        self.in_tokens = in_tokens

        # NOTE: this seeds numpy's *global* random state, which is the one
        # later used by `np.random.shuffle` in `data_generator`.
        np.random.seed(random_seed)

        # -1 means "not loaded yet"; filled in by `data_generator`.
        self.num_examples = {'train': -1, 'dev': -1, 'test': -1}
        self.task_name = task_name

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_test_examples(self, data_dir):
        """Gets a collection of `InputExample`s for prediction."""
        raise NotImplementedError()

    @staticmethod
    def get_labels():
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    def convert_example(self, index, example, labels, max_seq_len, tokenizer):
        """Converts a single `InputExample` into a single `InputFeatures`."""
        return convert_single_example(index, example, labels, max_seq_len,
                                      tokenizer, self.task_name)

    def generate_instance(self, feature):
        """
        Generate one instance from the given feature.

        Args:
            feature: InputFeatures(object). A single set of features of data.

        Returns:
            list: [input_ids, segment_ids, position_ids, label_id], where the
            position ids are 0 .. len(input_ids) - 1.
        """
        input_pos = list(range(len(feature.input_ids)))
        return [
            feature.input_ids, feature.segment_ids, input_pos, feature.label_id
        ]

    def generate_batch_data(self,
                            batch_data,
                            max_len,
                            total_token_num,
                            voc_size=-1,
                            mask_id=-1,
                            return_input_mask=True,
                            return_max_len=False,
                            return_num_token=False):
        """
        Turn a list of instances into batch data via `prepare_batch_data`.

        The special-token ids ([PAD], [CLS], [SEP]) are looked up in the
        tokenizer vocabulary; every other argument is forwarded as given.
        (Previously the keyword arguments were silently replaced by their
        default values.)
        """
        return prepare_batch_data(
            self.task_name,
            batch_data,
            max_len,
            total_token_num,
            voc_size=voc_size,
            pad_id=self.vocab["[PAD]"],
            cls_id=self.vocab["[CLS]"],
            sep_id=self.vocab["[SEP]"],
            mask_id=mask_id,
            return_input_mask=return_input_mask,
            return_max_len=return_max_len,
            return_num_token=return_num_token)

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """
        Reads a tab separated value file.

        Args:
            input_file: path of a UTF-8 encoded file.
            quotechar: forwarded to `csv.reader`.

        Returns:
            list: one list of fields per row.
        """
        # The context manager closes the handle even if parsing fails.
        with io.open(input_file, "r", encoding="utf8") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))

    def get_num_examples(self, phase):
        """
        Get number of examples for train, dev or test.

        Returns -1 for a phase that `data_generator` has not loaded yet.

        Raises:
            ValueError: if `phase` is not 'train', 'dev' or 'test'.
        """
        if phase not in ['train', 'dev', 'test']:
            raise ValueError(
                "Unknown phase, which should be in ['train', 'dev', 'test'].")
        return self.num_examples[phase]

    def data_generator(self, batch_size, phase='train', shuffle=False):
        """
        Generate data for train, dev or test.

        The examples of `phase` are loaded immediately (and their count is
        recorded for `get_num_examples`); conversion and batching happen
        lazily each time the returned callable is iterated.

        Args:
          batch_size: int. The batch size of generated data. Counted in
              padded tokens when `self.in_tokens` is true, otherwise in
              examples.
          phase: string. The phase for which to generate data.
          shuffle: bool. Whether to shuffle examples.

        Returns:
          A zero-argument callable returning a generator of batches.

        Raises:
          ValueError: if `phase` is not 'train', 'dev' or 'test'.
        """
        loaders = {
            'train': self.get_train_examples,
            'dev': self.get_dev_examples,
            'test': self.get_test_examples,
        }
        if phase not in loaders:
            raise ValueError(
                "Unknown phase, which should be in ['train', 'dev', 'test'].")
        examples = loaders[phase](self.data_dir)
        self.num_examples[phase] = len(examples)

        def instance_reader():
            """generate instance data"""
            if shuffle:
                # In-place shuffle; repeated on every pass over the data.
                np.random.shuffle(examples)
            for (index, example) in enumerate(examples):
                feature = self.convert_example(index, example,
                                               self.get_labels(),
                                               self.max_seq_len, self.tokenizer)
                yield self.generate_instance(feature)

        def batch_reader(reader, batch_size, in_tokens):
            """read batch data"""
            batch, total_token_num, max_len = [], 0, 0
            for instance in reader():
                token_ids = instance[0]
                max_len = max(max_len, len(token_ids))
                if in_tokens:
                    # Size of the batch after padding every row to max_len.
                    to_append = (len(batch) + 1) * max_len <= batch_size
                else:
                    to_append = len(batch) < batch_size
                if to_append:
                    batch.append(instance)
                    total_token_num += len(token_ids)
                else:
                    yield batch, total_token_num
                    # The instance that did not fit opens the next batch.
                    batch, total_token_num, max_len = [instance], len(
                        token_ids), len(token_ids)

            if len(batch) > 0:
                yield batch, total_token_num

        def wrapper():
            """yield batch data to network"""
            # -1 is passed as max_len in token mode, the configured maximum
            # length otherwise.
            max_seq = -1 if self.in_tokens else self.max_seq_len
            for batch_data, total_token_num in batch_reader(
                    instance_reader, batch_size, self.in_tokens):
                yield self.generate_batch_data(
                    batch_data,
                    max_seq,
                    total_token_num,
                    voc_size=-1,
                    mask_id=-1,
                    return_input_mask=True,
                    return_max_len=False,
                    return_num_token=False)

        return wrapper
P
pkpk 已提交
204

Y
Yibing Liu 已提交
205 206 207 208 209 210 211

class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, text_c=None, label=None):
        """Constructs a InputExample.

        Args:
            guid: Unique id for the example.
            text_a: The untokenized text of the first sequence. For single
                sequence tasks, only this sequence must be specified.
            text_b: (Optional) The untokenized text of the second sequence.
                Only must be specified for sequence pair tasks.
            text_c: (Optional) A third text field, stored unchanged.
            label: (Optional) The label of the example. This should be
                specified for train and dev examples, but not for test
                examples.
        """
        self.guid, self.label = guid, label
        self.text_a, self.text_b, self.text_c = text_a, text_b, text_c


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


class InputFeatures(object):
    """Container for the converted features of one example.

    Holds the four values passed to the constructor unchanged:
    `input_ids`, `input_mask`, `segment_ids` and `label_id`.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        self.input_ids, self.input_mask = input_ids, input_mask
        self.segment_ids, self.label_id = segment_ids, label_id


P
pkpk 已提交
255
class UDCProcessor(DataProcessor):
    """Processor for the UDC data set."""

    def _create_examples(self, lines, set_type):
        """Build one `InputExample` per well-formed row.

        A row needs at least three fields: the first is the label, the last
        becomes `text_b`, and everything in between becomes the list
        `text_a`. Shorter rows are reported on stdout and skipped.
        """
        print(
            "UDC dataset is too big, loading data spent a long time, please wait patiently.................."
        )
        examples = []
        for (row_no, fields) in enumerate(lines):
            if len(fields) < 3:
                print("data format error: %s" % "\t".join(fields))
                print(
                    "data row contains at least three parts: label\tconv1\t.....\tresponse"
                )
                continue
            # The middle fields are joined, converted in one go and split
            # again on tabs.
            context = tokenization.convert_to_unicode("\t".join(fields[1:-1]))
            response = tokenization.convert_to_unicode(fields[-1])
            examples.append(
                InputExample(
                    guid="%s-%d" % (set_type, row_no),
                    text_a=context.split('\t'),
                    text_b=response,
                    label=tokenization.convert_to_unicode(fields[0])))
        return examples

    def get_train_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "train.txt"))
        return self._create_examples(rows, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "dev.txt"))
        return self._create_examples(rows, "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "test.txt"))
        return self._create_examples(rows, "test")

    @staticmethod
    def get_labels():
        """See base class."""
        return ["0", "1"]


P
pkpk 已提交
310
class SWDAProcessor(DataProcessor):
    """Processor for the SWDA data set."""

    def _create_examples(self, lines, set_type):
        """Delegate example construction to `create_multi_turn_examples`."""
        return create_multi_turn_examples(lines, set_type)

    def get_train_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "train.txt"))
        return self._create_examples(rows, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "dev.txt"))
        return self._create_examples(rows, "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "test.txt"))
        return self._create_examples(rows, "test")

    @staticmethod
    def get_labels():
        """See base class. Labels are the strings "0" .. "41"."""
        return [str(label_no) for label_no in range(42)]


P
pkpk 已提交
347
class MRDAProcessor(DataProcessor):
    """Processor for the MRDA data set."""

    def _create_examples(self, lines, set_type):
        """Delegate example construction to `create_multi_turn_examples`."""
        return create_multi_turn_examples(lines, set_type)

    def get_train_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "train.txt"))
        return self._create_examples(rows, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "dev.txt"))
        return self._create_examples(rows, "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "test.txt"))
        return self._create_examples(rows, "test")

    @staticmethod
    def get_labels():
        """See base class. Labels are the strings "0" .. "41"."""
        return [str(label_no) for label_no in range(42)]


P
pkpk 已提交
384
class ATISSlotProcessor(DataProcessor):
    """Processor for the ATIS Slot data set."""

    def _create_examples(self, lines, set_type):
        """Build one `InputExample` per two-field row.

        The first field is the text; the second is split on whitespace into
        the label list. Rows with any other field count are reported on
        stdout and skipped.
        """
        examples = []
        for (row_no, fields) in enumerate(lines):
            if len(fields) != 2:
                print("data format error: %s" % "\t".join(fields))
                print(
                    "data row contains two parts: conversation_content \t label1 label2 label3"
                )
                continue
            raw_text, raw_labels = fields
            examples.append(
                InputExample(
                    guid="%s-%d" % (set_type, row_no),
                    text_a=tokenization.convert_to_unicode(raw_text),
                    label=raw_labels.split()))
        return examples

    def get_train_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "train.txt"))
        return self._create_examples(rows, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "dev.txt"))
        return self._create_examples(rows, "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "test.txt"))
        return self._create_examples(rows, "test")

    @staticmethod
    def get_labels():
        """See base class. Labels are the strings "0" .. "129"."""
        return [str(label_no) for label_no in range(130)]


P
pkpk 已提交
436
class ATISIntentProcessor(DataProcessor):
    """Processor for the ATIS intent data set."""

    def _create_examples(self, lines, set_type):
        """Build one `InputExample` per two-field row.

        The first field is the label, the second the text. Rows with any
        other field count are reported on stdout and skipped.
        """
        examples = []
        for (row_no, fields) in enumerate(lines):
            if len(fields) != 2:
                print("data format error: %s" % "\t".join(fields))
                print(
                    "data row contains two parts: label \t conversation_content")
                continue
            raw_label, raw_text = fields
            utterance = tokenization.convert_to_unicode(raw_text)
            intent = tokenization.convert_to_unicode(raw_label)
            examples.append(
                InputExample(
                    guid="%s-%d" % (set_type, row_no),
                    text_a=utterance,
                    label=intent))
        return examples

    def get_train_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "train.txt"))
        return self._create_examples(rows, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "dev.txt"))
        return self._create_examples(rows, "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "test.txt"))
        return self._create_examples(rows, "test")

    @staticmethod
    def get_labels():
        """See base class. Labels are the strings "0" .. "25"."""
        return [str(label_no) for label_no in range(26)]


P
pkpk 已提交
484
class DSTC2Processor(DataProcessor):
    """Processor for the DSTC2 data set."""

    def _create_turns(self, conv_example):
        """Expand one conversation into per-turn (history, label) samples.

        For every turn the history covers that turn plus at most 20
        preceding turns. The turn texts are joined and re-split on chr(1),
        so a text containing that character contributes several history
        items.
        """
        max_turns = 20
        samples = []
        for (turn_no, turn) in enumerate(conv_example):
            window = conv_example[max(turn_no - max_turns, 0):turn_no + 1]
            joined = "\1".join([item[0] for item in window])
            samples.append((joined.split('\1'), turn[1]))
        return samples

    def _create_examples(self, lines, set_type):
        """Creates examples for multi-turn dialogue sets.

        Consecutive rows sharing the same first field form one
        conversation; each conversation is expanded by `_create_turns`.
        Rows that do not have exactly three fields are reported on stdout
        and skipped.
        """
        examples = []

        def emit(turns):
            """Append the examples of one finished conversation."""
            for (history, dst_label) in self._create_turns(turns):
                # One example is appended per guid, so the running index is
                # always the current number of examples.
                examples.append(
                    InputExample(
                        guid="%s-%s" % (set_type, len(examples)),
                        text_a=history,
                        label=dst_label))

        conv_id = -1
        pending = []
        for (row_no, fields) in enumerate(lines):
            if len(fields) != 3:
                print("data format error: %s" % "\t".join(fields))
                print(
                    "data row contains three parts: conversation_content \t question \1 answer \t state1 state2 state3......"
                )
                continue
            conv_no, text_a, raw_labels = fields
            label_list = raw_labels.split()
            if row_no == 0:
                conv_id = conv_no
            elif conv_no != conv_id:
                # A new conversation id closes the one collected so far.
                emit(pending)
                pending = []
                conv_id = conv_no
            pending.append((text_a, label_list))
        if pending:
            emit(pending)
        return examples

    def get_train_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "train.txt"))
        return self._create_examples(rows, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "dev.txt"))
        return self._create_examples(rows, "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "test.txt"))
        return self._create_examples(rows, "test")

    @staticmethod
    def get_labels():
        """See base class. Labels are the strings "0" .. "216"."""
        return [str(label_no) for label_no in range(217)]


P
pkpk 已提交
569
class MULTIWOZProcessor(DataProcessor):
    """Processor for the MULTIWOZ data set."""

    def _create_turns(self, conv_example):
        """Expand one conversation into (history, current, label) samples.

        For every turn the history is built from at most the 2 preceding
        turns and the current turn is kept separately. Both are split on
        chr(1); an empty history therefore becomes [''].
        """
        max_turns = 2
        samples = []
        for (turn_no, turn) in enumerate(conv_example):
            prefix_turns = conv_example[max(turn_no - max_turns, 0):turn_no]
            conv_info = "\1".join([item[0] for item in prefix_turns])
            samples.append((conv_info.split('\1'), turn[0].split('\1'),
                            turn[1]))
        return samples

    def _create_examples(self, lines, set_type):
        """Creates examples for multi-turn dialogue sets.

        Each row is read as: field 0 = conversation id, field 1 =
        whitespace-separated labels, field 2 = turn text. Consecutive rows
        sharing the same conversation id form one conversation, which is
        expanded by `_create_turns`.

        Rows with fewer than three fields (for example a blank line, which
        the csv reader yields as an empty list) are reported on stdout and
        skipped, as the other processors in this file do, instead of
        raising IndexError.
        """
        examples = []

        def emit(turns):
            """Append the examples of one finished conversation."""
            for (history, current, dst_label) in self._create_turns(turns):
                # One example is appended per guid, so the running index is
                # always the current number of examples.
                examples.append(
                    InputExample(
                        guid="%s-%s" % (set_type, len(examples)),
                        text_a=history,
                        text_b=current,
                        label=dst_label))

        conv_id = -1
        pending = []
        for (row_no, fields) in enumerate(lines):
            if len(fields) < 3:
                print("data format error: %s" % "\t".join(fields))
                print(
                    "data row contains at least three parts: conversation_id \t label1 label2 label3 \t conversation_content"
                )
                continue
            conv_no = fields[0]
            label_list = fields[1].split()
            text_a = fields[2]
            if row_no == 0:
                conv_id = conv_no
            elif conv_no != conv_id:
                # A new conversation id closes the one collected so far.
                emit(pending)
                pending = []
                conv_id = conv_no
            pending.append((text_a, label_list))
        if pending:
            emit(pending)
        return examples

    def get_train_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "train.txt"))
        return self._create_examples(rows, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "dev.txt"))
        return self._create_examples(rows, "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "test.txt"))
        return self._create_examples(rows, "test")

    @staticmethod
    def get_labels():
        """See base class. Labels are the strings "0" .. "721"."""
        return [str(label_no) for label_no in range(722)]


P
pkpk 已提交
658
def create_dialogue_examples(conv):
    """Build one context-window sample per utterance of a conversation.

    Every row of `conv` is rendered as "<field 2> : <field 3>". The sample
    for row i is `[field 1 of row i, up to 5 preceding renderings,
    the rendering of row i, up to 2 following renderings]`.
    """
    rendered = ["%s : %s" % (row[2], row[3]) for row in conv]
    samples = []
    for (pos, row) in enumerate(conv):
        before = rendered[max(0, pos - 5):pos]
        after = rendered[pos + 1:pos + 3]
        samples.append([row[1], before, rendered[pos], after])
    return samples


P
pkpk 已提交
672
def create_multi_turn_examples(lines, set_type):
    """Creates examples for multi-turn dialogue sets.

    Consecutive rows sharing the same first field form one conversation,
    which `create_dialogue_examples` expands into per-utterance samples.
    Rows that do not have exactly four fields are reported on stdout and
    skipped.
    """
    examples = []

    def emit(rows):
        """Append the examples of one finished conversation."""
        for (label, text_a, text_b, text_c) in create_dialogue_examples(rows):
            # One example is appended per guid, so the running index is
            # always the current number of examples.
            examples.append(
                InputExample(
                    guid="%s-%s" % (set_type, len(examples)),
                    text_a=text_a,
                    text_b=text_b,
                    text_c=text_c,
                    label=label))

    conv_id = -1
    pending = []
    for (row_no, fields) in enumerate(lines):
        if len(fields) != 4:
            print("data format error: %s" % "\t".join(fields))
            print(
                "data row contains four parts: conversation_id \t label \t caller \t conversation_content"
            )
            continue
        conv_no = fields[0]
        if row_no == 0:
            conv_id = conv_no
        elif conv_no != conv_id:
            # A new conversation id closes the one collected so far.
            emit(pending)
            pending = []
            conv_id = conv_no
        pending.append(fields)
    if pending:
        emit(pending)
    return examples


P
pkpk 已提交
727
def convert_tokens(tokens, sep_id, tokenizer):
    """Tokenize text and map it to vocabulary ids.

    Args:
        tokens: a single text, a list of texts, or an empty/None value.
        sep_id: id inserted between the id runs of consecutive list items.
        tokenizer: object providing `tokenize` and `convert_tokens_to_ids`.

    Returns:
        An empty list for falsy input; the ids of the single text; or, for
        a list, the ids of every item separated (not terminated) by
        `sep_id`.
    """
    if not tokens:
        return []
    if not isinstance(tokens, list):
        return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(tokens))

    ids = []
    for (pos, text) in enumerate(tokens):
        if pos:
            ids.append(sep_id)
        ids.extend(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)))
    return ids


P
pkpk 已提交
745
def convert_single_example(ex_index, example, label_list, max_seq_length,
                           tokenizer, task_name):
    """Converts a single DA `InputExample` into a single `InputFeatures`.

    The three text fields of `example` are converted to ids, truncated to
    fit `max_seq_length`, and laid out as

        CLS text_a SEP [text_b SEP] [text_c SEP]

    with segment id 0 for the CLS/text_a/text_c parts and 1 for the text_b
    part. `input_mask` is all ones (no padding is added here).

    Args:
        ex_index: index of the example; the first five are printed for
            inspection.
        example: object with `guid`, `text_a`, `text_b`, `text_c` and
            `label` attributes. Each text field is anything accepted by
            `convert_tokens` (string, list of strings, or empty/None).
        label_list: ordered labels; a label's position is its id.
        max_seq_length: upper bound used when truncating the text fields.
        tokenizer: tokenizer forwarded to `convert_tokens`.
        task_name: one of 'udc', 'swda', 'mrda', 'atis_intent',
            'atis_slot', 'dstc2', 'dstc2_asr', 'multi-woz'.

    Returns:
        An `InputFeatures` holding `input_ids`, `input_mask`, `segment_ids`
        and `label_id`. `label_id` is a list of per-token ids for
        'atis_slot', a multi-hot 0/1 list over all labels for 'dstc2',
        'dstc2_asr' and 'multi-woz', and a single id otherwise.

    Raises:
        ValueError: if `task_name` is not one of the supported tasks.
        KeyError: if a label of `example` is not in `label_list`.
    """
    # NOTE(review): hard-coded ids, presumably the [SEP]/[CLS] entries of the
    # vocabulary in use -- confirm against the vocab file.
    SEP = 102
    CLS = 101

    # task_name -> (INNER_SEP, limit_length)
    #   INNER_SEP:    id placed between utterances of a multi-utterance field.
    #   limit_length: max number of ids kept for text_b; negative = no limit.
    task_settings = {
        'udc': (1, 60),
        'swda': (1, 50),
        'mrda': (1, 50),
        'atis_intent': (-1, -1),
        'atis_slot': (-1, -1),
        'dstc2': (1, -1),
        'dstc2_asr': (1, -1),
        'multi-woz': (1, 200),
    }
    if task_name not in task_settings:
        # Previously an unknown task surfaced later as an UnboundLocalError
        # on INNER_SEP; fail early with an explicit message instead.
        raise ValueError("Unsupported task_name %r, expected one of: %s" %
                         (task_name, ", ".join(sorted(task_settings))))
    INNER_SEP, limit_length = task_settings[task_name]

    label_map = {label: i for (i, label) in enumerate(label_list)}

    tokens_a_ids = convert_tokens(example.text_a, INNER_SEP, tokenizer)
    tokens_b_ids = convert_tokens(example.text_b, INNER_SEP, tokenizer)
    tokens_c_ids = convert_tokens(example.text_c, INNER_SEP, tokenizer)

    if tokens_b_ids:
        # Keep only the head of text_b. A negative limit must not be used as
        # a slice bound: `[:-1]` would silently drop the last id.
        # NOTE(review): treating -1 as "no limit" is the assumed intent.
        if limit_length >= 0:
            tokens_b_ids = tokens_b_ids[:limit_length]
    else:
        # Only text_a (and maybe text_c): keep the tail of text_a.
        budget = max_seq_length - 2
        if len(tokens_a_ids) > budget:
            tokens_a_ids = tokens_a_ids[len(tokens_a_ids) - budget:]

    if not tokens_c_ids:
        # Reserve room for text_b plus three special tokens, dropping ids
        # from the front of text_a.
        budget = max_seq_length - len(tokens_b_ids) - 3
        if len(tokens_a_ids) > budget:
            tokens_a_ids = tokens_a_ids[len(tokens_a_ids) - budget:]
    elif (len(tokens_a_ids) + len(tokens_b_ids) +
          len(tokens_c_ids)) > max_seq_length - 4:
        # Space left for text_a and text_c together once text_b and four
        # special tokens are accounted for.
        left_num = max_seq_length - len(tokens_b_ids) - 4
        if len(tokens_a_ids) > len(tokens_c_ids):
            # text_c gets at most half; text_a gets whatever remains and
            # keeps its tail.
            suffix_num = int(left_num / 2)
            tokens_c_ids = tokens_c_ids[:min(len(tokens_c_ids), suffix_num)]
            prefix_num = left_num - len(tokens_c_ids)
            tokens_a_ids = tokens_a_ids[max(0,
                                            len(tokens_a_ids) - prefix_num):]
        elif not tokens_a_ids:
            # No text_a at all: text_c takes the whole budget (tail kept).
            tokens_c_ids = tokens_c_ids[max(0, len(tokens_c_ids) - left_num):]
        else:
            # text_a gets at most half (tail kept); text_c gets whatever
            # remains and keeps its head.
            prefix_num = int(left_num / 2)
            tokens_a_ids = tokens_a_ids[max(0,
                                            len(tokens_a_ids) - prefix_num):]
            suffix_num = left_num - len(tokens_a_ids)
            tokens_c_ids = tokens_c_ids[:min(len(tokens_c_ids), suffix_num)]

    input_ids = [CLS] + list(tokens_a_ids) + [SEP]
    segment_ids = [0] * len(input_ids)
    if tokens_b_ids:
        input_ids.extend(tokens_b_ids)
        input_ids.append(SEP)
        segment_ids.extend([1] * (len(tokens_b_ids) + 1))
    if tokens_c_ids:
        input_ids.extend(tokens_c_ids)
        input_ids.append(SEP)
        segment_ids.extend([0] * (len(tokens_c_ids) + 1))

    input_mask = [1] * len(input_ids)

    if task_name == 'atis_slot':
        # One id per label, with a 0 on each side.
        # NOTE(review): the surrounding zeros presumably line up with the
        # CLS and SEP positions -- confirm against the consumer.
        label_id = [0] + [label_map[tag] for tag in example.label] + [0]
    elif task_name in ('dstc2', 'dstc2_asr', 'multi-woz'):
        # Multi-hot vector over all known labels.
        active_ids = set(label_map[tag] for tag in example.label)
        label_id = [1 if i in active_ids else 0 for i in range(len(label_map))]
    else:
        label_id = label_map[example.label]

    if ex_index < 5:
        print("*** Example ***")
        print("guid: %s" % (example.guid))
        print("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        print("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        print("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        print("label: %s (id = %s)" % (example.label, label_id))
    feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_id=label_id)

    return feature