# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mask, padding and batching."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

try:
    from six.moves import xrange
except ImportError:
    # Without six installed, fall back to the builtin so the module still
    # imports; on Python 3 `range` is already lazy like `xrange`.
    xrange = range


def mask(batch_tokens,
         seg_labels,
         mask_word_tags,
         total_token_num,
         vocab_size,
         CLS=1,
         SEP=2,
         MASK=3):
    """Randomly mask tokens of a batch in place and report what was selected.

    Every selected token has its original id appended to ``mask_label`` and
    its flat position appended to ``mask_pos``. A selected token is then
    either overwritten with ``MASK``, overwritten with a random id, or left
    unchanged.

    Args:
        batch_tokens: list of sentences, each a mutable list of token ids.
            The sentences are modified in place.
        seg_labels: per-sentence, per-token labels, indexed like
            ``batch_tokens``. Only read for sentences whose
            ``mask_word_tags`` entry is truthy. ``1`` extends the currently
            open span, ``-1`` closes the open span and is never part of one,
            any other value closes the open span and starts a new one.
            (Presumably word-segmentation tags; confirm against the caller.)
        mask_word_tags: per-sentence flag. Truthy selects span-level masking
            driven by ``seg_labels``; falsy selects independent per-token
            masking.
        total_token_num: number of random draws to generate. Draws are
            indexed by the running token offset in the batch, so this must be
            at least the total number of tokens in ``batch_tokens``.
        vocab_size: exclusive upper bound of random replacement ids
            (ids are drawn from ``[1, vocab_size)``).
        CLS: token id that is never selected by per-token masking.
        SEP: token id that is never selected by per-token masking.
        MASK: token id written over masked tokens.

    Returns:
        A tuple ``(batch_tokens, mask_label, mask_pos)`` where
        ``batch_tokens`` is the same (mutated) object that was passed in,
        ``mask_label`` is a 1-D int64 array of the original ids of the
        selected tokens, and ``mask_pos`` is a 1-D int64 array of
        ``sent_index * max_len + token_index`` positions, i.e. positions in
        the batch after padding every sentence to the longest one.

    Raises:
        ValueError: if ``batch_tokens`` is empty.
    """
    max_len = max(len(sent) for sent in batch_tokens)
    mask_label = []
    mask_pos = []
    # Draw order (rand, then randint) is kept so seeded runs are reproducible.
    prob_mask = np.random.rand(total_token_num)
    # Note: the first token is [CLS], so [low=1]; id 0 is never drawn.
    replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)

    # Offset of the current sentence's first token in prob_mask/replace_ids.
    prob_index = 0
    for sent_index, sent in enumerate(batch_tokens):
        row_offset = sent_index * max_len
        if mask_word_tags[sent_index]:
            # `beg` is the start of the open span; 0 doubles as "no span
            # open", so a span can never start at token index 0.
            beg = 0
            for token_index, _ in enumerate(sent):
                seg_label = seg_labels[sent_index][token_index]
                if seg_label == 1:
                    continue
                if beg == 0:
                    if seg_label != -1:
                        beg = token_index
                    continue

                # The span [beg, token_index) is complete. The draw of its
                # first token decides whether the whole span is selected.
                if prob_mask[prob_index + beg] <= 0.15:
                    for index in xrange(beg, token_index):
                        prob = prob_mask[prob_index + index]
                        # The first token already passed the 0.15 gate, so
                        # its thresholds are scaled to that range; the rest
                        # of the span uses the full [0, 1) range.
                        base_prob = 0.15 if index == beg else 1.0
                        # Record the original id before any overwrite.
                        mask_label.append(sent[index])
                        mask_pos.append(row_offset + index)
                        if base_prob * 0.2 < prob <= base_prob:
                            sent[index] = MASK
                        elif base_prob * 0.1 < prob <= base_prob * 0.2:
                            sent[index] = replace_ids[prob_index + index]
                        # otherwise keep the original token

                if seg_label == -1:
                    beg = 0
                else:
                    beg = token_index
        else:
            for token_index, token in enumerate(sent):
                prob = prob_mask[prob_index + token_index]
                if prob > 0.15 or token == SEP or token == CLS:
                    continue
                # Record the original id before any overwrite.
                mask_label.append(token)
                mask_pos.append(row_offset + token_index)
                if prob > 0.03:
                    # mask
                    sent[token_index] = MASK
                elif prob > 0.015:
                    # random replace
                    sent[token_index] = replace_ids[prob_index + token_index]
                # otherwise keep the original token

        prob_index += len(sent)

    mask_label = np.array(mask_label).astype("int64").reshape([-1])
    mask_pos = np.array(mask_pos).astype("int64").reshape([-1])
    return batch_tokens, mask_label, mask_pos


def pad_batch_data(insts,
                   pad_idx=0,
                   return_pos=False,
                   return_input_mask=False,
                   return_max_len=False,
                   return_num_token=False,
                   return_seq_lens=False):
    """Pad the instances to the max sequence length in the batch.

    Args:
        insts: non-empty sequence of instances, each a sequence of ids
            (lists, tuples or other sized iterables).
        pad_idx: value appended to short instances. It is also used to pad
            the position data.
        return_pos: also return per-token positions ``0..len(inst)-1``,
            padded with ``pad_idx``, as an int64 array of shape
            ``[batch, max_len]``.
        return_input_mask: also return a float32 array of shape
            ``[batch, max_len, 1]`` holding 1 for real tokens and 0 for
            padding.
        return_max_len: also return the max sequence length (int).
        return_num_token: also return the total number of unpadded tokens.
        return_seq_lens: also return the unpadded lengths as a 1-D int64
            array.

    Returns:
        The padded int64 id array of shape ``[batch, max_len]`` when no
        optional output is requested. Otherwise a list starting with that
        array, followed by the requested outputs in the order: positions,
        input mask, max length, token count, sequence lengths.

    Raises:
        ValueError: if ``insts`` is empty.
    """
    lengths = [len(inst) for inst in insts]
    max_len = max(lengths)

    # Any token included in dict can be used to pad, since the paddings' loss
    # will be masked out by weights and make no effect on parameter gradients.
    # list(inst) lets tuples and other sequences be padded like lists.
    inst_data = np.array([
        list(inst) + [pad_idx] * (max_len - length)
        for inst, length in zip(insts, lengths)
    ])
    return_list = [inst_data.astype("int64").reshape([-1, max_len])]

    # position data
    if return_pos:
        inst_pos = np.array([
            list(range(0, length)) + [pad_idx] * (max_len - length)
            for length in lengths
        ])
        return_list.append(inst_pos.astype("int64").reshape([-1, max_len]))

    if return_input_mask:
        # This is used to avoid attention on paddings.
        input_mask_data = np.array(
            [[1] * length + [0] * (max_len - length) for length in lengths])
        input_mask_data = np.expand_dims(input_mask_data, axis=-1)
        return_list.append(input_mask_data.astype("float32"))

    if return_max_len:
        return_list.append(max_len)

    if return_num_token:
        return_list.append(sum(lengths))

    if return_seq_lens:
        seq_lens = np.array(lengths)
        return_list.append(seq_lens.astype("int64").reshape([-1]))

    return return_list if len(return_list) > 1 else return_list[0]


if __name__ == "__main__":
    pass