# -*- coding: UTF-8 -*-
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mask, padding and batching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np


def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
    """
    Randomly mask ``batch_tokens`` in place for masked-LM training.

    One uniform draw is consumed per token, in batch order.  Tokens equal to
    ``CLS`` or ``SEP`` are never selected; any other token is handled by its
    draw ``prob``:

    * ``prob > 0.15``          -- left untouched and not selected;
    * ``0.03 < prob <= 0.15``  -- replaced by ``MASK``;
    * ``0.015 < prob <= 0.03`` -- replaced by a random id in ``[1, vocab_size)``;
    * ``prob <= 0.015``        -- kept unchanged, but still selected.

    Every selected token contributes its original id to ``mask_label`` and its
    position to ``mask_pos``.  If no token of a sentence was replaced, one
    token at a position in ``1 .. len(sent) - 2`` that is neither ``CLS`` nor
    ``SEP`` is forced to ``MASK``.  A sentence with no such token gets no
    forced mask.

    Note: ``mask_pos`` refers to ``batch_tokens`` after padding to the longest
    sentence of the batch, i.e. ``sent_index * max_len + token_index``.

    Args:
        batch_tokens: list of token-id lists; modified in place.
        total_token_num: number of random draws to make; must be at least the
            total number of tokens in ``batch_tokens``.
        vocab_size: exclusive upper bound of random replacement ids.
        CLS, SEP: ids that are never masked.
        MASK: id written over masked tokens.

    Returns:
        ``(batch_tokens, mask_label, mask_pos)``: the mutated input list, and
        two 1-D int64 arrays with the original ids and the flattened
        positions of the selected tokens.
    """
    max_len = max([len(sent) for sent in batch_tokens])
    mask_label = []
    mask_pos = []
    prob_mask = np.random.rand(total_token_num)
    # Note: the first token is [CLS], so [low=1]
    replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
    offset = 0  # index of the current sentence's first token in the draws
    for sent_index, sent in enumerate(batch_tokens):
        base_pos = sent_index * max_len
        mask_flag = False  # True once a token was replaced (MASK or random)
        kept = set()  # positions selected by the "keep" branch
        for token_index, token in enumerate(sent):
            prob = prob_mask[offset + token_index]
            if prob > 0.15 or token == SEP or token == CLS:
                continue
            mask_label.append(token)
            mask_pos.append(base_pos + token_index)
            if prob > 0.03:
                # mask
                sent[token_index] = MASK
                mask_flag = True
            elif prob > 0.015:
                # random replace
                sent[token_index] = replace_ids[offset + token_index]
                mask_flag = True
            else:
                # keep the original token
                kept.add(token_index)
        offset += len(sent)

        if mask_flag:
            continue
        # Ensure at least one token of the sentence is masked.  Candidates are
        # positions 1 .. len(sent) - 2 that hold a non-special token; without
        # any, drawing would never succeed (or numpy would reject the range),
        # so such a sentence is left as is.
        if not any(tok != SEP and tok != CLS for tok in sent[1:len(sent) - 1]):
            continue
        while not mask_flag:
            token_index = int(np.random.randint(1, high=len(sent) - 1))
            token = sent[token_index]
            if token == SEP or token == CLS:
                continue
            # A "kept" token is already recorded; only overwrite it so the
            # same position is not reported twice.
            if token_index not in kept:
                mask_label.append(token)
                mask_pos.append(base_pos + token_index)
            sent[token_index] = MASK
            mask_flag = True

    mask_label = np.array(mask_label).astype("int64").reshape([-1])
    mask_pos = np.array(mask_pos).astype("int64").reshape([-1])
    return batch_tokens, mask_label, mask_pos


def prepare_batch_data(insts,
                       total_token_num,
                       max_len=None,
                       voc_size=0,
                       pad_id=None,
                       cls_id=None,
                       sep_id=None,
                       mask_id=None,
                       task_id=0,
                       return_input_mask=True,
                       return_max_len=True,
                       return_num_token=False):
    """
    Build one masked-LM batch: mask the source tokens, then pad everything.

    1. generate Tensor of data
    2. generate Tensor of position
    3. generate self attention mask

    Args:
        insts: sequence of instances where ``inst[0]`` holds the source token
            ids, ``inst[1]`` the sentence ids and ``inst[2]`` the position
            ids.  NOTE: the source-id lists are masked in place.
        total_token_num: forwarded to ``mask``; must be at least the total
            number of source tokens in the batch.
        max_len: length every row is padded to; ``None`` pads to the longest
            instance of the batch.
        voc_size: vocabulary size used for random replacement ids.
        pad_id, cls_id, sep_id, mask_id: special token ids.
        task_id: value filled into the returned ``task_ids`` array.
        return_input_mask, return_max_len, return_num_token: accepted for
            interface compatibility but ignored; the result always has the
            seven items listed below.

    Returns:
        ``[src_id, pos_id, sent_id, self_input_mask, task_ids, mask_label,
        mask_pos]`` where ``mask_pos`` indexes the flattened padded
        ``src_id`` (``row * padded_len + column``).

    Raises:
        ValueError: if ``max_len`` is shorter than the longest source
            sequence (checked before any input is modified).
    """
    batch_src_ids = [inst[0] for inst in insts]
    batch_sent_ids = [inst[1] for inst in insts]
    batch_pos_ids = [inst[2] for inst in insts]

    # ``mask`` reports positions using the longest source sequence of the
    # batch as the row stride, which matches the padded layout only when we
    # pad to that same length.
    batch_max_len = max(len(src) for src in batch_src_ids)
    if max_len is not None and max_len < batch_max_len:
        raise ValueError(
            "max_len=%d is shorter than the longest source sequence (%d)" %
            (max_len, batch_max_len))

    # First step: do mask without padding
    out, mask_label, mask_pos = mask(
        batch_src_ids,
        total_token_num,
        vocab_size=voc_size,
        CLS=cls_id,
        SEP=sep_id,
        MASK=mask_id)

    # When padding to an explicit, longer ``max_len`` re-base the positions so
    # they keep pointing at the same tokens in the padded, flattened src_id.
    if max_len is not None and max_len != batch_max_len:
        rows = mask_pos // batch_max_len
        cols = mask_pos % batch_max_len
        mask_pos = (rows * max_len + cols).astype("int64")

    # Second step: padding
    src_id, self_input_mask = pad_batch_data(
        out, max_len=max_len, pad_idx=pad_id, return_input_mask=True)
    pos_id = pad_batch_data(
        batch_pos_ids,
        max_len=max_len,
        pad_idx=pad_id,
        return_pos=False,
        return_input_mask=False)
    sent_id = pad_batch_data(
        batch_sent_ids,
        max_len=max_len,
        pad_idx=pad_id,
        return_pos=False,
        return_input_mask=False)
    task_ids = np.ones_like(src_id, dtype="int64") * task_id
    return [
        src_id, pos_id, sent_id, self_input_mask, task_ids, mask_label, mask_pos
    ]


def pad_batch_data(insts,
                   max_len=None,
                   pad_idx=0,
                   return_pos=False,
                   return_input_mask=False,
                   return_max_len=False,
                   return_num_token=False):
    """
    Pad the instances to a common length and optionally build extra outputs.

    Args:
        insts: sequence of token-id sequences.
        max_len: target length; ``None`` uses the longest instance.
        pad_idx: value appended to short instances (also used to pad the
            position data).
        return_pos: also return ``0 .. len(inst) - 1`` position ids per
            instance, padded with ``pad_idx``.
        return_input_mask: also return a float32 ``[batch, max_len, 1]``
            array that is 1 on real tokens and 0 on padding.
        return_max_len: also return the padded length.
        return_num_token: also return the total number of real tokens.

    Returns:
        The int64 ``[batch, max_len]`` padded data alone when no extra output
        is requested; otherwise a list starting with it, followed by the
        requested outputs in the order of the flags above.

    Raises:
        ValueError: if an instance is longer than an explicit ``max_len``
            (it would otherwise be left unpadded and the final reshape could
            silently scramble rows).
    """
    lengths = [len(inst) for inst in insts]
    if max_len is None:
        max_len = max(lengths)
    elif any(length > max_len for length in lengths):
        raise ValueError("instance of length %d exceeds max_len=%d" %
                         (max(lengths), max_len))

    return_list = []
    # Any token included in dict can be used to pad, since the paddings' loss
    # will be masked out by weights and make no effect on parameter gradients.
    inst_data = np.array(
        [list(inst) + [pad_idx] * (max_len - len(inst)) for inst in insts])
    return_list += [inst_data.astype("int64").reshape([-1, max_len])]
    # position data
    if return_pos:
        inst_pos = np.array([
            list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
            for inst in insts
        ])
        return_list += [inst_pos.astype("int64").reshape([-1, max_len])]
    if return_input_mask:
        # This is used to avoid attention on paddings.
        input_mask_data = np.array([[1] * len(inst) + [0] *
                                    (max_len - len(inst)) for inst in insts])
        input_mask_data = np.expand_dims(input_mask_data, axis=-1)
        return_list += [input_mask_data.astype("float32")]
    if return_max_len:
        return_list += [max_len]
    if return_num_token:
        return_list += [sum(lengths)]
    return return_list if len(return_list) > 1 else return_list[0]


if __name__ == "__main__":
    # This module only provides batching helpers; running it directly does nothing.
    pass