reader.py 7.7 KB
Newer Older
Y
Yibing Liu 已提交
1 2 3 4
"""
Reader for deep attention matching network
"""

Y
Yibing Liu 已提交
5
import six
Y
Yibing Liu 已提交
6 7
import numpy as np

Y
Yibing Liu 已提交
8 9 10 11 12
try:
    import cPickle as pickle  #python 2
except ImportError as e:
    import pickle  #python 3

Y
Yibing Liu 已提交
13 14

def _as_row_array(seq):
    """Convert ``seq`` to a NumPy array indexed by sample (first axis)."""
    try:
        return np.array(seq)
    except ValueError:
        # Ragged sequences (e.g. variable-length token lists) cannot form a
        # rectangular array on NumPy >= 1.24; store each item as an object.
        out = np.empty(len(seq), dtype=object)
        for i, item in enumerate(seq):
            out[i] = item
        return out


def unison_shuffle(data, seed=None):
    """
    Shuffle labels, contexts and responses together with one permutation.

    Args:
        data: dict holding parallel sequences under the byte-string keys
            b'y', b'c' and b'r' (labels, contexts, responses).
        seed: optional integer. When given, ``np.random.seed(seed)`` is called
            first so the permutation is reproducible. This reseeds NumPy's
            global random state.

    Returns:
        A new dict with the same three keys whose values are NumPy arrays
        reordered by the same random permutation.
    """
    if seed is not None:
        np.random.seed(seed)

    # b'...' literals equal six.b('...') on both Python 2 (str) and 3 (bytes).
    y = _as_row_array(data[b'y'])
    c = _as_row_array(data[b'c'])
    r = _as_row_array(data[b'r'])

    assert len(y) == len(c) == len(r)
    p = np.random.permutation(len(y))
    shuffle_data = {b'y': y[p], b'c': c[p], b'r': r[p]}
    return shuffle_data


def split_c(c, split_id):
    """
    Split a flat context into turns.

    c is a list of token ids for one example context, and split_id is the
    integer separating turns (e.g. conf['_EOS_']). Returns a nested list with
    one sub-list per turn. Separators are dropped, and consecutive separators
    produce empty turns. A single trailing empty turn left by a final
    separator is removed unless it would be the only turn.
    """
    finished = []
    current = []
    for token in c:
        if token == split_id:
            finished.append(current)
            current = []
        else:
            current.append(token)
    # Keep the last turn unless it is empty and other turns already exist.
    if current or not finished:
        finished.append(current)
    return finished


def normalize_length(_list, length, cut_type='tail'):
    """Pad or truncate a sequence to exactly ``length`` items.

    _list is a list or nested list, e.g. the turns of a context, a response,
    or the tokens of a single turn. It is never modified.

    Args:
        _list: sequence to normalize (any sequence; it is copied to a list).
        length: target number of items.
        cut_type: 'head' keeps the first ``length`` items, 'tail' keeps the
            last ``length`` items. Only consulted when ``_list`` is longer
            than ``length``.

    Returns:
        (normalized, real_length), where ``normalized`` is a new list and
        ``real_length`` is ``min(len(_list), length)``.
        - Short flat inputs are padded with 0.
        - Short nested inputs (first item is a list) are padded with
          independent empty lists.
        - An empty input yields ``([0] * length, 0)``.

    Raises:
        ValueError: if truncation is needed and ``cut_type`` is neither
            'head' nor 'tail'.
    """
    items = list(_list)
    real_length = len(items)
    if real_length == 0:
        return [0] * length, 0

    if real_length <= length:
        missing = length - real_length
        if isinstance(items[0], list):
            # A fresh list per slot so padded turns never alias each other.
            padding = [[] for _ in range(missing)]
        else:
            padding = [0] * missing
        return items + padding, real_length

    if cut_type == 'head':
        return items[:length], length
    if cut_type == 'tail':
        # items[-length:] would return every item when length == 0.
        return items[real_length - length:], length
    raise ValueError(
        "cut_type must be 'head' or 'tail', got {!r}".format(cut_type))


def produce_one_sample(data,
                       index,
                       split_id,
                       max_turn_num,
                       max_turn_len,
                       turn_cut_type='tail',
                       term_cut_type='tail'):
    """Build one normalized sample from ``data`` at position ``index``.

    The context is split into turns on ``split_id``. The list of turns is
    normalized to ``max_turn_num`` entries (e.g. 10) using ``turn_cut_type``.
    Every turn and the response are then normalized to ``max_turn_len``
    tokens (e.g. 50) using ``term_cut_type``.

    Returns:
        y, nor_turns_nor_c, nor_r, turn_len, term_len, r_len
    """
    context = data[six.b('c')][index]
    # Copy the response so padding can never modify the stored entry.
    response = data[six.b('r')][index][:]
    label = data[six.b('y')][index]

    turns = split_c(context, split_id)
    nor_turns, turn_len = normalize_length(turns, max_turn_num, turn_cut_type)

    # Normalize every turn (padding turns included) to max_turn_len tokens.
    normalized = [
        normalize_length(turn, max_turn_len, term_cut_type)
        for turn in nor_turns
    ]
    nor_turns_nor_c = [tokens for tokens, _ in normalized]
    term_len = [count for _, count in normalized]

    nor_r, r_len = normalize_length(response, max_turn_len, term_cut_type)
    return label, nor_turns_nor_c, nor_r, turn_len, term_len, r_len


def build_one_batch(data,
                    batch_index,
                    conf,
                    turn_cut_type='tail',
                    term_cut_type='tail'):
    """
    Build one batch.

    Produces conf['batch_size'] consecutive samples with produce_one_sample,
    starting at sample batch_index * conf['batch_size'].

    Returns:
        (turns, tt_turns_len, every_turn_len, response, response_len, label),
        each a list holding one entry per sample of the batch.
    """
    batch_size = conf['batch_size']
    first = batch_index * batch_size
    samples = [
        produce_one_sample(data, first + offset, conf['_EOS_'],
                           conf['max_turn_num'], conf['max_turn_len'],
                           turn_cut_type, term_cut_type)
        for offset in six.moves.xrange(batch_size)
    ]

    # Each sample is (y, turns, response, turn_len, term_len, r_len).
    labels = [sample[0] for sample in samples]
    turns = [sample[1] for sample in samples]
    responses = [sample[2] for sample in samples]
    turn_counts = [sample[3] for sample in samples]
    token_counts = [sample[4] for sample in samples]
    response_lens = [sample[5] for sample in samples]

    return turns, turn_counts, token_counts, responses, response_lens, labels


def build_one_batch_dict(data,
                         batch_index,
                         conf,
                         turn_cut_type='tail',
                         term_cut_type='tail'):
    """
    Build one batch dict.

    Same content as build_one_batch, but keyed by field name: 'turns',
    'tt_turns_len', 'every_turn_len', 'response', 'response_len', 'label'.
    """
    fields = ('turns', 'tt_turns_len', 'every_turn_len', 'response',
              'response_len', 'label')
    batch = build_one_batch(data, batch_index, conf, turn_cut_type,
                            term_cut_type)
    return dict(zip(fields, batch))


def build_batches(data, conf, turn_cut_type='tail', term_cut_type='tail'):
    """
    Build batches.

    Splits ``data`` into len(data[b'y']) // conf['batch_size'] full batches
    via build_one_batch. Trailing samples that do not fill a batch are
    dropped. ``turn_cut_type`` and ``term_cut_type`` ('head' or 'tail') are
    forwarded to every build_one_batch call.

    Returns:
        dict with keys "turns", "tt_turns_len", "every_turn_len", "response",
        "response_len" and "label". Each value is a list holding one entry
        per batch.
    """
    # Same order as the tuple returned by build_one_batch.
    keys = ("turns", "tt_turns_len", "every_turn_len", "response",
            "response_len", "label")
    ans = {key: [] for key in keys}

    batch_len = len(data[six.b('y')]) // conf['batch_size']
    for batch_index in six.moves.range(batch_len):
        # Forward the caller's cut types instead of hard-coding 'tail'.
        batch = build_one_batch(
            data,
            batch_index,
            conf,
            turn_cut_type=turn_cut_type,
            term_cut_type=term_cut_type)
        for key, value in zip(keys, batch):
            ans[key].append(value)

    return ans


def make_one_batch_input(data_batches, index):
    """Split turns and return feeding data.

    Args:
        data_batches: All data batches. A dict of per-batch lists under the
            keys "turns", "every_turn_len", "response", "response_len" and
            "label" (as produced by build_batches).
        index: The index for current batch

    Return:
        feeding list, in this order:
          * max_turn_num int64 arrays of shape (batch_size, max_turn_len, 1),
            one per turn;
          * max_turn_num float32 masks of the same shape, one per turn
            (1.0 at positions below that row's turn length, else 0.0);
          * the int64 response, shape (batch_size, max_turn_len, 1);
          * the float32 response mask of the same shape;
          * the float32 labels, shape (batch_size, 1).
    """
    turns = np.array(data_batches["turns"][index]).astype('int64')
    every_turn_len = np.array(data_batches["every_turn_len"][index]).astype(
        'int64')
    response = np.array(data_batches["response"][index]).astype('int64')
    response_len = np.array(data_batches["response_len"][index]).astype('int64')

    # turns is indexed as (batch_size, max_turn_num, max_turn_len).
    max_turn_num = turns.shape[1]
    max_turn_len = turns.shape[2]
    positions = np.arange(max_turn_len)

    def _length_mask(lengths):
        """Return a (batch, max_turn_len, 1) float32 mask: 1 where pos < len."""
        # Lengths come from normalize_length and are non-negative.
        mask = positions[np.newaxis, :] < lengths[:, np.newaxis]
        return mask.astype("float32")[:, :, np.newaxis]

    feed_list = [
        np.expand_dims(turns[:, i, :], axis=-1) for i in range(max_turn_num)
    ]
    feed_list.extend(
        _length_mask(every_turn_len[:, i]) for i in range(max_turn_num))

    feed_list.append(np.expand_dims(response, axis=-1))
    feed_list.append(_length_mask(response_len))

    label = np.array([data_batches["label"][index]]).reshape(
        [-1, 1]).astype("float32")
    feed_list.append(label)

    return feed_list
Y
Yibing Liu 已提交
249 250 251 252 253 254 255 256 257


if __name__ == '__main__':
    conf = dict(
        batch_size=256,
        max_turn_num=10,
        max_turn_len=50,
        _EOS_=28270,
    )

    # On Python 3, encoding="bytes" keeps the pickle's string keys as bytes,
    # matching the six.b(...) lookups above (the file was presumably written
    # by Python 2 — confirm).
    load_kwargs = {} if six.PY2 else {'encoding': 'bytes'}
    with open('../ubuntu/data/data_small.pkl', 'rb') as f:
        # NOTE(review): pickle can execute arbitrary code; load trusted files only.
        train, val, test = pickle.load(f, **load_kwargs)
    print('load data success')

    train_batches = build_batches(train, conf)
    val_batches = build_batches(val, conf)
    test_batches = build_batches(test, conf)
    print('build batches success')