# batching.py (3.7 KB) -- committed by kgresearch
"""Mask, padding and batching."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


def mask(input_tokens, input_mask_type, max_len, mask_id):
    """
    Replace one token per sentence with ``mask_id``.

    For each sentence the first token ("MASK_HEAD") or the last token
    ("MASK_TAIL") is masked, according to the entry of ``input_mask_type``
    at the same index.

    Args:
        input_tokens: sequence of token-id sequences (not yet padded).
        input_mask_type: per-sentence mask type, "MASK_HEAD" or "MASK_TAIL".
        max_len: row stride used to compute flat positions, i.e. the length
            each sentence is expected to be padded to afterwards.
        mask_id: token id written at the masked position.

    Returns:
        A tuple ``(output_tokens, mask_label, mask_pos)``:
        - output_tokens: list of new lists, one per sentence, with the
          chosen token replaced by ``mask_id``; the inputs are not modified.
        - mask_label: int64 array of shape (n, 1) with the original tokens.
        - mask_pos: int64 array of shape (n, 1) with the flat positions
          ``sent_index * max_len + token_index``.

    Raises:
        ValueError: if a mask type is neither "MASK_HEAD" nor "MASK_TAIL".
        IndexError: if a sentence is empty.
    """
    output_tokens = []
    mask_label = []
    mask_pos = []
    for sent_index, sent in enumerate(input_tokens):
        mask_type = input_mask_type[sent_index]
        if mask_type == "MASK_HEAD":
            token_index = 0
        elif mask_type == "MASK_TAIL":
            token_index = len(sent) - 1
        else:
            raise ValueError(
                "Unknown mask type, which should be in ['MASK_HEAD', 'MASK_TAIL']."
            )
        mask_label.append(sent[token_index])
        mask_pos.append(sent_index * max_len + token_index)
        # list(sent) always yields an independent, mutable copy. A plain
        # slice (sent[:]) would stay immutable for tuples and would be a
        # view onto the caller's data for numpy arrays.
        sent_out = list(sent)
        sent_out[token_index] = mask_id
        output_tokens.append(sent_out)
    mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
    mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
    return output_tokens, mask_label, mask_pos


def pad_batch_data(insts,
                   max_len,
                   pad_idx=0,
                   return_pos=False,
                   return_input_mask=False):
    """
    Pad every instance to ``max_len`` and build the matching numpy arrays.

    Args:
        insts: iterable of token-id sequences, each at most ``max_len`` long.
        max_len: length every instance is padded to.
        pad_idx: value appended to short instances; it is also used to fill
            the padded slots of the position array.
        return_pos: if True, also return the position ids.
        return_input_mask: if True, also return the input mask.

    Returns:
        The padded token ids as an int64 array of shape
        (batch, max_len, 1). If ``return_pos`` and/or ``return_input_mask``
        is set, a list is returned instead, holding in order: the token ids,
        the position ids (int64, same shape) and the input mask (float32,
        same shape, 1 for real tokens and 0 for padding).

    Raises:
        ValueError: if any instance is longer than ``max_len``.
    """
    insts = list(insts)

    # Reject over-length instances up front: without this check a uniformly
    # over-long batch can be reshaped into extra rows without any error.
    for inst_index, inst in enumerate(insts):
        if len(inst) > max_len:
            raise ValueError(
                "Instance %d has length %d, which exceeds max_len %d." %
                (inst_index, len(inst), max_len))

    # An explicit batch dimension keeps all outputs the same shape, even for
    # an empty batch.
    out_shape = [len(insts), max_len, 1]
    return_list = []

    # Any token included in dict can be used to pad, since the paddings' loss
    # will be masked out by weights and make no effect on parameter gradients.
    inst_data = np.array(
        [list(inst) + [pad_idx] * (max_len - len(inst)) for inst in insts])
    return_list.append(inst_data.astype("int64").reshape(out_shape))

    # position data
    if return_pos:
        inst_pos = np.array([
            list(range(len(inst))) + [pad_idx] * (max_len - len(inst))
            for inst in insts
        ])
        return_list.append(inst_pos.astype("int64").reshape(out_shape))

    if return_input_mask:
        # This is used to avoid attention on paddings.
        input_mask_data = np.array([
            [1] * len(inst) + [0] * (max_len - len(inst)) for inst in insts
        ])
        return_list.append(
            input_mask_data.astype("float32").reshape(out_shape))

    return return_list if len(return_list) > 1 else return_list[0]


def prepare_batch_data(insts, max_len, pad_id=None, mask_id=None):
    """
    Mask, pad and convert a batch of examples into numpy arrays.

    Args:
        insts: sequence of examples; ``inst[0]`` holds the token ids and
            ``inst[1]`` the mask type ("MASK_HEAD" or "MASK_TAIL").
        max_len: length every token sequence is padded to.
        pad_id: token id used for padding. ``None`` falls back to 0, since
            ``None`` cannot be stored in the int64 output arrays.
        mask_id: token id written at the masked position. ``None`` or a
            negative value disables masking.

    Returns:
        ``[src_id, pos_id, input_mask, mask_label, mask_pos]`` when masking
        is enabled, otherwise ``[src_id, pos_id, input_mask]``.
    """
    batch_src_ids = [inst[0] for inst in insts]
    batch_mask_type = [inst[1] for inst in insts]

    # Guard the None default explicitly: on Python 3 the bare comparison
    # `None >= 0` raises TypeError.
    do_mask = mask_id is not None and mask_id >= 0

    # First step: do mask without padding
    if do_mask:
        out, mask_label, mask_pos = mask(
            input_tokens=batch_src_ids,
            input_mask_type=batch_mask_type,
            max_len=max_len,
            mask_id=mask_id)
    else:
        out = batch_src_ids

    # Second step: padding and turn into numpy arrays
    src_id, pos_id, input_mask = pad_batch_data(
        out,
        max_len=max_len,
        pad_idx=0 if pad_id is None else pad_id,
        return_pos=True,
        return_input_mask=True)

    if do_mask:
        return [src_id, pos_id, input_mask, mask_label, mask_pos]
    return [src_id, pos_id, input_mask]