# -*- coding: UTF-8 -*-
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mask, padding and batching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np


def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3, dev_count=1):
    """
    Add mask for batch_tokens, return out, mask_label, mask_pos;
    Note: mask_pos responding the batch_tokens after padded;
    """
    max_len = max([len(sent) for sent in batch_tokens])

    multidev_batch_tokens = []
    multidev_mask_label = []
    multidev_mask_pos = []

    # split the batch evenly across dev_count devices; give up if the batch is
    # too small to split
    big_batch_tokens = batch_tokens
    stride = len(batch_tokens) // dev_count
    if stride == 0:
        return None, None, None
    p = stride

    for i in range(dev_count):
        batch_tokens = big_batch_tokens[p-stride:p]
        p += stride
        mask_label = []
        mask_pos = []
        prob_mask = np.random.rand(total_token_num)
        # Note: low=1 so that token id 0 is never sampled as a random replacement
        replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
        pre_sent_len = 0
        prob_index = 0
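        # Each token draws a uniform probability; the buckets below implement a
        # BERT-style 15% selection scheme:
        #   prob  > 0.15          -> left untouched (~85% of tokens)
        #   0.03  < prob <= 0.15  -> replaced with [MASK]      (~12%)
        #   0.015 < prob <= 0.03  -> replaced with a random id (~1.5%)
        #   prob <= 0.015         -> kept as-is but still predicted (~1.5%)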
        for sent_index, sent in enumerate(batch_tokens):
            mask_flag = False
            prob_index += pre_sent_len
            for token_index, token in enumerate(sent):
                prob = prob_mask[prob_index + token_index]
                if prob > 0.15:
                    continue
                elif 0.03 < prob <= 0.15:
                    # mask
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        sent[token_index] = MASK
                        mask_flag = True
                        mask_pos.append(sent_index * max_len + token_index)
                elif 0.015 < prob <= 0.03:
                    # random replace
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        sent[token_index] = replace_ids[prob_index + token_index]
                        mask_flag = True
                        mask_pos.append(sent_index * max_len + token_index)
                else:
                    # keep the original token
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        mask_pos.append(sent_index * max_len + token_index)
            pre_sent_len = len(sent)
            # ensure that at least one token in each sentence gets masked
            while not mask_flag:
                token_index = int(np.random.randint(1, high=len(sent) - 1, size=1))
                if sent[token_index] != SEP and sent[token_index] != CLS:
                    mask_label.append(sent[token_index])
                    sent[token_index] = MASK
                    mask_flag = True
                    mask_pos.append(sent_index * max_len + token_index)
        mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
        mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])

        multidev_batch_tokens.extend(batch_tokens)
        multidev_mask_label.append(mask_label)
        multidev_mask_pos.append(mask_pos)
    
    return multidev_batch_tokens, multidev_mask_label, multidev_mask_pos
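

# A minimal, self-contained sketch of calling mask() on its own; it is only an
# illustration, not part of the original pipeline. The ids used below
# (CLS=1, SEP=2, MASK=3, vocab_size=50) are toy assumptions.
def _demo_mask():
    np.random.seed(0)
    toy_batch = [[1, 10, 11, 12, 2], [1, 20, 21, 2]]
    total_tokens = sum(len(sent) for sent in toy_batch)
    out, mask_label, mask_pos = mask(
        toy_batch, total_tokens, vocab_size=50, CLS=1, SEP=2, MASK=3,
        dev_count=1)
    # out is the list of (in-place modified) sentences; mask_label / mask_pos
    # are per-device lists of int64 arrays with shape [-1, 1]
    print(out)
    print(mask_label[0].ravel(), mask_pos[0].ravel())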


def prepare_batch_data(insts,
                       total_token_num,
                       max_len=None,
                       voc_size=0,
                       pad_id=None,
                       cls_id=None,
                       sep_id=None,
                       mask_id=None,
                       task_id=0,
                       return_input_mask=True,
                       return_max_len=True,
                       return_num_token=False,
                       dev_count=1):
    """
    1. generate Tensor of data
    2. generate Tensor of position
    3. generate self attention mask, [shape: batch_size *  max_len * max_len]
    """
    batch_src_ids = [inst[0] for inst in insts]
    batch_sent_ids = [inst[1] for inst in insts]
    batch_pos_ids = [inst[2] for inst in insts]

    # Should these two steps be swapped??? Otherwise the word embeddings
    # unrolled in the task layer come from the padded batch, and the word
    # indices no longer match the indices before padding.
    # First step: do mask without padding
    out, mask_label, mask_pos = mask(
        batch_src_ids,
        total_token_num,
        vocab_size=voc_size,
        CLS=cls_id,
        SEP=sep_id,
        MASK=mask_id,
        dev_count=dev_count)
    # Second step: padding
    src_id, self_input_mask = pad_batch_data(
        out,
        max_len=max_len,
        pad_idx=pad_id,
        return_input_mask=True)

    pos_id = pad_batch_data(
        batch_pos_ids,
        max_len=max_len,
        pad_idx=pad_id,
        return_pos=False,
        return_input_mask=False)
    sent_id = pad_batch_data(
        batch_sent_ids,
        max_len=max_len,
        pad_idx=pad_id,
        return_pos=False,
        return_input_mask=False)
    task_ids = np.ones_like(src_id, dtype="int64") * task_id
    return_list = [
        src_id, pos_id, sent_id, self_input_mask, task_ids, mask_label, mask_pos
    ]
    return return_list


def pad_batch_data(insts,
                   max_len=None,
                   pad_idx=0,
                   return_pos=False,
                   return_input_mask=False,
                   return_max_len=False,
                   return_num_token=False):
    """
    Pad the instances to the max sequence length in batch, and generate the
    corresponding position data and input mask.
    """
    return_list = []
    if max_len is None:
        max_len = max(len(inst) for inst in insts)
    # Any token id in the vocabulary can be used for padding, since the loss on
    # padded positions is masked out by weights and has no effect on parameter
    # gradients.
    inst_data = np.array([
        list(inst) + list([pad_idx] * (max_len - len(inst))) for inst in insts
    ])
    return_list += [inst_data.astype("int64").reshape([-1, max_len])]
    # position data
    if return_pos:
        inst_pos = np.array([
            list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
            for inst in insts
        ])
        return_list += [inst_pos.astype("int64").reshape([-1, max_len])]
    if return_input_mask:
        # This is used to avoid attention on paddings.
        input_mask_data = np.array([[1] * len(inst) + [0] *
                                    (max_len - len(inst)) for inst in insts])
        input_mask_data = np.expand_dims(input_mask_data, axis=-1)
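        # resulting shape: [batch_size, max_len, 1]; the model side is expected
        # to broadcast/expand this into the full [batch_size, max_len, max_len]
        # self-attention mask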
        return_list += [input_mask_data.astype("float32")]
    if return_max_len:
        return_list += [max_len]
    if return_num_token:
        num_token = 0
        for inst in insts:
            num_token += len(inst)
        return_list += [num_token]
    return return_list if len(return_list) > 1 else return_list[0]


if __name__ == "__main__":
    pass
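    # A minimal end-to-end sketch of the mask -> pad pipeline, added only as an
    # illustration. The vocabulary size and the pad/cls/sep/mask ids below are
    # toy assumptions chosen to make the example self-contained.
    np.random.seed(42)
    toy_sents = [
        [1, 15, 27, 42, 56, 2],   # [CLS] w w w w [SEP]
        [1, 8, 33, 2],            # [CLS] w w [SEP]
    ]
    total_tokens = sum(len(sent) for sent in toy_sents)
    # each instance is (src_ids, sent_ids, pos_ids), as expected by
    # prepare_batch_data
    insts = [(list(sent), [0] * len(sent), list(range(len(sent))))
             for sent in toy_sents]
    batch = prepare_batch_data(
        insts,
        total_tokens,
        voc_size=100,
        pad_id=0,
        cls_id=1,
        sep_id=2,
        mask_id=3)
    src_id, pos_id, sent_id, input_mask, task_ids, mask_label, mask_pos = batch
    print("src_id", src_id.shape)          # (2, 6)
    print("input_mask", input_mask.shape)  # (2, 6, 1)
    print("masked positions", mask_pos[0].ravel())

    # pad_batch_data can also be used on its own, e.g. to get position ids too
    ids, pos, attn_mask = pad_batch_data(
        toy_sents, pad_idx=0, return_pos=True, return_input_mask=True)
    print("padded ids", ids.shape, "pos", pos.shape, "mask", attn_mask.shape)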