reader_helper.py 10.7 KB
Newer Older
X
xixiaoyao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
# -*- coding: UTF-8 -*-
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import random
import numpy as np
import paddle
from paddle import fluid
from paddle.fluid import layers


X
xixiaoyao 已提交
25 26 27 28 29 30 31 32 33 34 35 36 37
def create_feed_batch_process_fn(net_inputs):
    """Build a function that re-keys a batch dict by network variable name.

    Args:
        net_inputs (dict): maps a batch field name to either a variable name
            (a string) or an object exposing a ``name`` attribute.

    Returns:
        callable: ``feed_batch_process_fn(data)``, which returns a new dict
        whose keys are the variable names and whose values are ``data[field]``
        for every field listed in ``net_inputs``.
    """
    # `unicode` only exists on Python 2. Resolve the accepted string types once
    # so that non-string variables do not raise NameError on Python 3.
    try:
        string_types = (str, unicode)
    except NameError:
        string_types = (str,)

    def feed_batch_process_fn(data):
        feed = {}
        for field, var in net_inputs.items():
            key = var if isinstance(var, string_types) else var.name
            feed[key] = data[field]
        return feed

    return feed_batch_process_fn

X
xixiaoyao 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58

def create_multihead_feed_batch_process_fn(net_inputs):
    """Build a batch re-keying function that can select one of several input maps.

    Args:
        net_inputs: either a single mapping of batch field name to variable
            (a string name or an object exposing ``name``), or an indexable
            collection of such mappings.

    Returns:
        callable: ``feed_batch_process_fn(data, id=-1)``. With ``id == -1``
        the whole ``net_inputs`` is used as the mapping; otherwise
        ``net_inputs[id]`` is used. The result is a new dict keyed by variable
        name with values taken from ``data``.
    """
    # `unicode` only exists on Python 2. Resolve the accepted string types once
    # so that non-string variables do not raise NameError on Python 3.
    try:
        string_types = (str, unicode)
    except NameError:
        string_types = (str,)

    def feed_batch_process_fn(data, id=-1):
        inputs = net_inputs[id] if id != -1 else net_inputs

        feed = {}
        for field, var in inputs.items():
            key = var if isinstance(var, string_types) else var.name
            feed[key] = data[field]
        return feed

    return feed_batch_process_fn


X
xixiaoyao 已提交
59
def _check_and_adapt_shape_dtype(rt_val, attr, message=""):
X
xixiaoyao 已提交
60 61 62 63 64 65 66
    if not isinstance(rt_val, np.ndarray):
        rt_val = np.array(rt_val)
        assert rt_val.dtype != np.dtype('O'), "yielded data is not a valid tensor(number of elements on some dimension may differ)."
        if rt_val.dtype == np.dtype('float64'):
            rt_val = rt_val.astype('float32')
    
    shape, dtype = attr
X
xixiaoyao 已提交
67 68
    assert rt_val.dtype == np.dtype(dtype), message+"yielded data type not consistent with attr settings. Expect: {}, receive: {}.".format(rt_val.dtype, np.dtype(dtype))
    assert len(shape) == rt_val.ndim, message+"yielded data rank(ndim) not consistent with attr settings. Expect: {}, receive: {}.".format(len(shape), rt_val.ndim)
X
xixiaoyao 已提交
69 70 71
    for rt, exp in zip(rt_val.shape, shape):
        if exp is None or exp < 0:
            continue
X
xixiaoyao 已提交
72
        assert rt == exp, "yielded data shape is not consistent with attr settings.Expected:{}Actual:{}".format(exp, rt)
X
xixiaoyao 已提交
73 74 75 76 77 78 79 80 81 82 83 84
    return rt_val
    

def _zero_batch(attrs):
    pos_attrs = []
    for shape, dtype in attrs:
        pos_shape = [size if size and size > 0 else 1 for size in shape]
        pos_attrs.append([pos_shape, dtype])

    return [np.zeros(shape=shape, dtype=dtype) for shape, dtype in pos_attrs]


X
xixiaoyao 已提交
85 86 87 88 89 90 91 92 93 94 95 96 97
def _zero_batch_x(attrs, batch_size):
    pos_attrs = []
    for shape, dtype in attrs:
        pos_shape = [size for size in shape]
        if pos_shape[0] == -1:
            pos_shape[0] = batch_size
        if pos_shape[1] == -1:
            pos_shape[1] = 512 # max seq len
        pos_attrs.append([pos_shape, dtype])

    return [np.zeros(shape=shape, dtype=dtype) for shape, dtype in pos_attrs]


X
xixiaoyao 已提交
98 99 100 101 102 103 104 105 106 107
def create_net_inputs(input_attrs, is_async=False, iterator_fn=None, dev_count=1, n_prefetch=1, **kwargs):
    """Create one data layer per input attr and optionally start a PyReader on them.

    Args:
        input_attrs: iterable of ``(name, shape, dtype)`` triples.
        is_async (bool): when true, a non-iterable ``fluid.io.PyReader`` is
            built over the created layers, fed by ``iterator_fn`` and started.
            This parameter used to be spelled ``async``, which is a reserved
            keyword since Python 3.7; the old keyword is still accepted.
        iterator_fn: batch generator function; required when ``is_async`` is
            true.
        dev_count (int): used as the PyReader capacity.
        n_prefetch (int): accepted for interface compatibility; not used here.

    Returns:
        dict: maps each input name to its created layer.

    Raises:
        TypeError: on unexpected keyword arguments.
        AssertionError: if ``is_async`` is true and ``iterator_fn`` is None.
    """
    # Legacy spelling: existing callers may still pass `async=...` by keyword.
    is_async = kwargs.pop('async', is_async)
    if kwargs:
        raise TypeError("create_net_inputs() got unexpected keyword argument(s): {}".format(', '.join(sorted(kwargs))))

    inputs = []
    ret = {}
    for name, shape, dtype in input_attrs:
        p = layers.data(name, shape=shape, dtype=dtype)
        ret[name] = p
        inputs.append(p)

    if is_async:
        assert iterator_fn is not None, "iterator_fn is needed for building async input layer."
        reader = fluid.io.PyReader(inputs, capacity=dev_count, iterable=False)
        reader.decorate_batch_generator(iterator_fn)
        reader.start()

    return ret


X
xixiaoyao 已提交
115
def create_iterator_fn(iterator, iterator_prefix, shape_and_dtypes, outname_to_pos, verbose=0, return_type='list'):
    """Wrap a reader iterator so that it yields batches in positional layout.

    Args:
        iterator: an iterator yielding dicts of ``output name -> value``.
        iterator_prefix (str): prefix; ``prefix + '.' + name`` is looked up in
            ``outname_to_pos`` in addition to the bare name.
        shape_and_dtypes: list of ``(shape, dtype)`` attrs indexed by position,
            used to validate each placed value.
        outname_to_pos (dict): maps an output name to its slot position.
        verbose (int): accepted for interface compatibility; not used here.
        return_type (str): ``'list'`` to yield a positional list, ``'dict'``
            to yield a dict keyed by the names in ``outname_to_pos``.

    Returns:
        callable: a generator function. Slots that no output maps to are left
        as ``None``. The generator ends when ``iterator`` is exhausted.

    Raises:
        ValueError: if ``return_type`` is neither ``'list'`` nor ``'dict'``.
    """
    if return_type not in ('list', 'dict'):
        # Any other value would consume the iterator without ever yielding.
        raise ValueError("return_type must be 'list' or 'dict', got {!r}".format(return_type))

    pos_to_outname = {pos: name for name, pos in outname_to_pos.items()}

    def iterator_fn():
        while True:
            try:
                outputs = next(iterator)  # dict type
            except StopIteration:
                # Under PEP 479 a StopIteration escaping a generator becomes a
                # RuntimeError; end the generator explicitly instead.
                return

            results = [None] * len(outname_to_pos)
            for outname, val in outputs.items():
                # An output may fill a shared slot (bare name) and/or a
                # prefixed slot.
                for key in (outname, iterator_prefix + '.' + outname):
                    if key in outname_to_pos:
                        idx = outname_to_pos[key]
                        val = _check_and_adapt_shape_dtype(val, shape_and_dtypes[idx])
                        results[idx] = val

            if return_type == 'list':
                yield results
            else:
                yield {pos_to_outname[pos]: item for pos, item in enumerate(results)}

    return iterator_fn
X
xixiaoyao 已提交
149 150


X
xixiaoyao 已提交
151
def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes, mrs, outname_to_pos, dev_count=1, keep_one_task=True, verbose=0):
    """Build a generator function that samples a task and yields its batch in the joint layout.

    Args:
        iterators: list of per-task iterators, each yielding a dict of
            ``output name -> value``.
        iterator_prefixes: list of per-task prefixes; ``prefix + '.' + name``
            is looked up in ``outname_to_pos`` in addition to the bare name.
        joint_shape_and_dtypes: list of ``(shape, dtype)`` attrs indexed by
            slot position. Per the original note: essentially derived from the
            attrs of "bb" and "parad" (presumably backbone and paradigm), with
            the -1 (variable) dimensions filled in automatically from the
            reader attrs, so validating iterator outputs against it serves as
            the runtime batch-correctness check.
        mrs: per-task mix ratios; normalised into sampling probabilities.
        outname_to_pos (dict): maps an output name to its slot position. Must
            contain ``'__task_id'`` at position 0.
        dev_count (int): number of consecutive batches yielded for each
            sampled task; forced to 1 when ``keep_one_task`` is false.
        keep_one_task (bool): see ``dev_count``.
        verbose (int): number of yielded batches for which debug output is
            printed.

    Returns:
        callable: a generator function yielding the joint batch list. The
        generator ends if the sampled task's iterator is exhausted.
    """

    def _scatter(outputs, prefix, slots, debug=False):
        # Validate each reader output and store it in its slot(s) of `slots`.
        for outname, val in outputs.items():
            if debug:
                print('reader generate: '+outname)
            for key in (outname, prefix + '.' + outname):
                if key in outname_to_pos:
                    idx = outname_to_pos[key]
                    if debug:
                        print(key + ' is insert in idx ' + str(idx))
                    val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx], message=key+': ')
                    slots[idx] = val

    task_ids = list(range(len(iterators)))
    total_mr = float(sum(mrs))
    weights = [mr / total_mr for mr in mrs]
    if not keep_one_task:
        dev_count = 1

    # Pull one batch from every task so that each slot of the template batch
    # holds a value; those first batches are kept in `outbuf` and served first.
    fake_batch = _zero_batch(joint_shape_and_dtypes)
    outbuf = {}
    for task_id in task_ids:
        outputs = next(iterators[task_id])  # dict type
        outbuf[task_id] = outputs
        _scatter(outputs, iterator_prefixes[task_id], fake_batch)

    def iterator():
        v = verbose
        has_show_warn = False
        while True:
            task_id = np.random.choice(task_ids, p=weights)
            # NOTE: the same list object is reused and mutated for every
            # yield, so slots of tasks other than the sampled one keep
            # whatever was written last.
            results = fake_batch
            if v > 0:
                print('----- debug joint iterator -----')
                print('sampled task id: '+str(task_id))
            task_id_tensor = np.array([[task_id]]).astype("int64")

            for _ in range(dev_count):

                results[outname_to_pos['__task_id']] = task_id_tensor
                assert outname_to_pos['__task_id'] == 0

                if task_id in outbuf:
                    outputs = outbuf.pop(task_id)
                else:
                    try:
                        outputs = next(iterators[task_id])  # dict type
                    except StopIteration:
                        # Under PEP 479 a StopIteration escaping a generator
                        # becomes a RuntimeError; end the generator instead.
                        return

                if 'token_ids' in outputs:
                    n_rows = len(outputs['token_ids'])
                    n_cols = len(outputs['token_ids'][0])
                    derived = (
                        ('batch_size', n_rows),
                        ('seqlen', n_cols),
                        ('batchsize_x_seqlen', n_rows * n_cols),
                    )
                    for name, size in derived:
                        # These slots are optional in the joint layout; skip
                        # the ones that were not requested.
                        if name in outname_to_pos:
                            results[outname_to_pos[name]] = _check_and_adapt_shape_dtype([size], [[1], 'int64'])
                else:
                    if not has_show_warn:
                        print('WARNING: token_ids not found in current batch, failed to yield batch_size, seqlen and batchsize_x_seqlen. (This message would be shown only once.)')
                        has_show_warn = True

                _scatter(outputs, iterator_prefixes[task_id], results, debug=v > 0)

                if v > 0:
                    print('yielded batch len and shapes:')
                    print(len(results))
                    for item in results:
                        print(np.shape(item))
                    print('')
                    v -= 1
                yield results

    return iterator


X
xixiaoyao 已提交
253
def merge_input_attrs(backbone_attr, task_attrs, insert_taskid=True, insert_batchsize=True, insert_seqlen=True, insert_batchsize_x_seqlen=True):
    """Merge backbone and task input attributes into one positional layout.

    The layout is: the enabled built-in slots (``__task_id``, ``batch_size``,
    ``seqlen``, ``batchsize_x_seqlen``, in that order), then the backbone
    attrs sorted by name, then each task's attrs sorted by name.

    Args:
        backbone_attr (dict): key=attr_name, val=[shape, dtype].
        task_attrs(list[dict]|dict): task input attributes, key=attr_name, val=[shape, dtype], support single task and nested tasks
        insert_taskid (bool): add the ``__task_id`` slot, shape ``[1, 1]``.
        insert_batchsize (bool): add the ``batch_size`` slot, shape ``[1]``.
        insert_seqlen (bool): add the ``seqlen`` slot, shape ``[1]``.
        insert_batchsize_x_seqlen (bool): add the ``batchsize_x_seqlen`` slot,
            shape ``[1]``.

    Returns:
        tuple: ``(names, attrs, name_to_position)`` where ``names[i]`` and
        ``attrs[i]`` describe slot ``i`` and ``name_to_position`` maps a name
        to its slot index. If a name occurs more than once, the last
        occurrence wins.
    """
    if isinstance(task_attrs, dict):
        task_attrs = [task_attrs]

    builtin_slots = (
        (insert_taskid, '__task_id', [1, 1]),
        (insert_batchsize, 'batch_size', [1]),
        (insert_seqlen, 'seqlen', [1]),
        (insert_batchsize_x_seqlen, u'batchsize_x_seqlen', [1]),
    )

    names = []
    ret = []
    for enabled, name, shape in builtin_slots:
        if enabled:
            names.append(name)
            ret.append((shape, 'int64'))

    backbone_names = sorted(backbone_attr.keys())
    names += backbone_names
    ret.extend([backbone_attr[k] for k in backbone_names])

    name_to_position = {k: pos for pos, k in enumerate(names)}

    for task_attr in task_attrs:
        task_names = sorted(task_attr.keys())
        # Offset by the number of slots emitted so far, not by the number of
        # distinct names: the two differ as soon as any name repeats, which
        # would make later positions point at the wrong attr.
        offset = len(names)
        names.extend(task_names)
        ret.extend([task_attr[k] for k in task_names])
        for pos, k in enumerate(task_names, start=offset):
            name_to_position[k] = pos
    return names, ret, name_to_position