# finetune_batching.py
# Committed by tangjiji
#    Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");                                                      
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
        
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" prepare data format for finetuning tasks """

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from six.moves import xrange


def prepare_batch_data(batch_records, num_choice, pad_id, task_index, task_num):
    """
    Prepare one padded batch for multiple-choice finetuning tasks.

    Every record contributes several candidate token sequences; candidates
    are flattened so the leading axis of the text arrays has
    ``num_choice * len(batch_records)`` rows, and the per-record image arrays
    are repeated ``num_choice`` times so they line up with those rows.

    Args:
        batch_records: sequence of dicts. The keys read here are
            ``"input_ids"``, ``"input_pos"``, ``"segment_ids"`` (each a list
            of per-candidate lists), ``"input_lens"`` (iterable of candidate
            lengths), ``"target"`` (integer index of the correct candidate),
            ``"features"`` and ``"boxes"`` (per-region rows, padded by
            ``pad_feature_data``) and ``"anno_id"``.
            NOTE(review): each record is assumed to hold exactly
            ``num_choice`` candidates; the reshapes below fail otherwise.
        num_choice: number of candidates per record.
        pad_id: value appended to ``input_ids``, ``input_pos`` and
            ``segment_ids`` to reach the longest length in the batch.
        task_index: index of the current task; emitted as an int64 array and
            used to build the one-hot task flags.
        task_num: total number of tasks, i.e. how many task flags to append.

    Returns:
        A list in this order:
        ``[src_ids, src_pos, src_seg, src_task, src_masks, image_embedding,
        image_loc, image_mask, labels, batch_anno_ids, task_index_array,
        binary_labels, flag_0, ..., flag_{task_num - 1}]``.
        The ``src_*`` arrays have shape
        ``[num_choice * num_sample, max_len, 1]``; ``src_task`` is all zeros;
        ``labels`` and ``batch_anno_ids`` are int64 of shape ``[-1, 1]``;
        each flag is a float32 array ``[1.0]`` for ``task_index`` and
        ``[0.0]`` otherwise.

    Raises:
        ValueError: if ``batch_records`` is empty (no maximum length exists)
            or a record does not reshape to ``num_choice`` candidates.
    """
    num_sample = len(batch_records)
    num_rows = num_choice * num_sample

    batch_lens = [record["input_lens"] for record in batch_records]
    batch_labels = [record["target"] for record in batch_records]
    image_features = [record["features"] for record in batch_records]
    image_boxes = [record["boxes"] for record in batch_records]

    # One-hot target over the flattened (record, candidate) rows.
    binary_labels = np.zeros([num_rows, 1], dtype='float32')
    for i, label in enumerate(batch_labels):
        binary_labels[i * num_choice + label] = 1.0
    labels = np.array(batch_labels).astype("int64").reshape([-1, 1])
    batch_anno_ids = np.array(
        [record["anno_id"] for record in batch_records]).astype("int64").reshape([-1, 1])

    # Longest candidate in the whole batch; every sequence is padded to it.
    max_len = max(max(lens) for lens in batch_lens)

    batch_input_ids = []
    batch_input_pos = []
    batch_seg_ids = []
    batch_input_masks = []
    for record in batch_records:
        batch_input_ids.append(
            [inst + [pad_id] * (max_len - len(inst)) for inst in record["input_ids"]])
        batch_input_pos.append(
            [inst + [pad_id] * (max_len - len(inst)) for inst in record["input_pos"]])
        batch_seg_ids.append(
            [inst + [pad_id] * (max_len - len(inst)) for inst in record["segment_ids"]])
        # 1 marks a real token, 0 marks padding.
        batch_input_masks.append(
            [[1] * len(inst) + [0] * (max_len - len(inst)) for inst in record["input_ids"]])

    text_shape = [num_rows, max_len, 1]
    src_ids = np.array(batch_input_ids).astype("int64").reshape(text_shape)
    src_pos = np.array(batch_input_pos).astype("int64").reshape(text_shape)
    src_seg = np.array(batch_seg_ids).astype("int64").reshape(text_shape)
    src_masks = np.array(batch_input_masks).astype("float32").reshape(text_shape)
    src_task = np.zeros(src_ids.shape, dtype="int64")

    image_embedding, image_mask = pad_feature_data(image_features, return_mask=True)
    image_loc = pad_feature_data(image_boxes)
    # Repeat each record's image arrays num_choice times along the batch axis
    # so row k of the image arrays pairs with row k of the text arrays. Using
    # repeat (instead of tile + reshape) keeps the trailing dimensions as they
    # are, so the box width is no longer required to be exactly 5.
    image_embedding = np.repeat(image_embedding, num_choice, axis=0)
    image_mask = np.repeat(image_mask, num_choice, axis=0)
    image_loc = np.repeat(image_loc, num_choice, axis=0)

    return_list = [src_ids, src_pos, src_seg, src_task, src_masks,
                   image_embedding, image_loc, image_mask, labels, batch_anno_ids]
    return_list.append(np.array([task_index]).astype('int64'))
    return_list.append(binary_labels)
    # One scalar flag per task: 1.0 for the active task, 0.0 for the others.
    for i in range(task_num):
        flag = 1.0 if i == task_index else 0.0
        return_list.append(np.array([flag]).astype("float32"))
    return return_list


def pad_feature_data(data, pad_value=0.0, dtype="float32", return_mask=False):
    """
    Pad variable-length visual features into one dense array.

    Args:
        data: non-empty sequence of items, each a sequence of equal-width
            feature rows (nested lists or 2-D arrays). Items may have
            different numbers of rows, including zero.
        pad_value: value written to positions beyond an item's own rows.
        dtype: dtype of the returned array(s).
        return_mask: when true, also return a mask marking the real rows.

    Returns:
        ``out_data`` of shape ``(len(data), max_rows, width)``, where
        ``width`` is taken from the first item that has at least one row
        (0 if every item is empty). When ``return_mask`` is true, returns
        ``(out_data, out_mask)`` with ``out_mask`` of shape
        ``(len(data), max_rows, 1)`` holding 1.0 on real rows and 0.0 on
        padding.

    Raises:
        ValueError: if ``data`` is empty.
    """
    if len(data) == 0:
        raise ValueError("pad_feature_data() requires at least one item")

    lengths = [len(item) for item in data]
    max_length = max(lengths)

    # Read the feature width from the first item that actually has a row;
    # indexing data[0][0] directly fails when the first item is empty.
    data_width = 0
    for item, length in zip(data, lengths):
        if length > 0:
            data_width = len(item[0])
            break

    out_data = np.full((len(data), max_length, data_width), pad_value, dtype=dtype)
    out_mask = np.zeros((len(data), max_length, 1), dtype=dtype)
    for i, (item, length) in enumerate(zip(data, lengths)):
        if length == 0:
            # Nothing to copy; an empty item cannot be broadcast into the
            # (0, width) slice, so leave the row fully padded and masked out.
            continue
        out_data[i, :length, :] = item
        if return_mask:
            out_mask[i, :length, :] = 1.0

    if return_mask:
        return out_data, out_mask
    return out_data

# Running this module as a script does nothing: the guard body is only `pass`.
if __name__ == "__main__":
    pass