# -*- coding: UTF-8 -*-
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import random
import numpy as np
import paddle
from paddle import fluid
from paddle.fluid import layers


def create_feed_batch_process_fn(net_inputs):
    """Build a function that re-keys a data batch into a feed dict.

    Args:
        net_inputs: dict mapping a data key to either a variable name
            (a string) or a variable object exposing a ``.name`` attribute.

    Returns:
        A function ``f(data)`` returning ``{var_name: data[key]}`` for every
        ``key -> var`` entry of ``net_inputs``.
    """
    # `unicode` only exists on Python 2; referencing it unconditionally raised
    # NameError on Python 3 whenever a non-str variable was encountered.
    try:
        string_types = (str, unicode)  # noqa: F821 (Python 2)
    except NameError:
        string_types = (str,)

    def feed_batch_process_fn(data):
        return {(var if isinstance(var, string_types) else var.name): data[q]
                for q, var in net_inputs.items()}

    return feed_batch_process_fn

def _check_and_adapt_shape_dtype(rt_val, attr, message=""):
    """Convert a yielded value to an ndarray and validate it against an attr.

    Non-ndarray values are converted with ``np.array``. If that yields
    float64, the result is downcast to float32. ndarray inputs are passed
    through without any dtype conversion.

    Args:
        rt_val: value yielded by a reader (ndarray or nested sequence).
        attr: ``(shape, dtype)`` pair. Shape entries that are None or
            negative are treated as variable and not checked.
        message: prefix prepended to every error message (e.g. the name of
            the output being checked).

    Returns:
        The (possibly converted) ndarray.

    Raises:
        AssertionError: if the converted value has object dtype (ragged
            input), or its dtype, rank or any fixed dimension differs from
            ``attr``.
    """
    if not isinstance(rt_val, np.ndarray):
        rt_val = np.array(rt_val)
        if rt_val.dtype == np.dtype('O'):
            raise AssertionError(message + "yielded data is not a valid tensor(number of elements on some dimension may differ).")
        if rt_val.dtype == np.dtype('float64'):
            rt_val = rt_val.astype('float32')

    shape, dtype = attr
    expected_dtype = np.dtype(dtype)
    # Explicit raises (instead of `assert`) keep these checks active under
    # `python -O` while preserving the AssertionError type callers may catch.
    if rt_val.dtype != expected_dtype:
        raise AssertionError(message + "yielded data type not consistent with attr settings. Expect: {}, receive: {}.".format(expected_dtype, rt_val.dtype))
    if len(shape) != rt_val.ndim:
        raise AssertionError(message + "yielded data rank(ndim) not consistent with attr settings. Expect: {}, receive: {}.".format(len(shape), rt_val.ndim))
    for actual, expected in zip(rt_val.shape, shape):
        if expected is None or expected < 0:
            continue
        if actual != expected:
            raise AssertionError(message + "yielded data shape is not consistent with attr settings. Expected: {}, Actual: {}.".format(expected, actual))
    return rt_val
    

def _zero_batch(attrs):
    """Return one all-zero array per ``(shape, dtype)`` attr.

    Every dimension that is falsy (None or 0) or negative is replaced by 1,
    so variable-size attrs still produce a concrete placeholder array.
    """
    def _concrete(dim):
        # Variable / unknown dims collapse to a single element.
        return dim if dim and dim > 0 else 1

    return [np.zeros(shape=[_concrete(dim) for dim in shape], dtype=dtype)
            for shape, dtype in attrs]


def _zero_batch_x(attrs, batch_size, max_seqlen=512):
    """Return one all-zero array per ``(shape, dtype)`` attr.

    A -1 in the first dimension is replaced by ``batch_size``, and a -1 in
    the second dimension by ``max_seqlen``. Other dimensions are used as-is.
    Shapes of rank 0 or 1 are supported; previously a rank-1 shape raised
    IndexError.

    Args:
        attrs: iterable of ``(shape, dtype)`` pairs.
        batch_size: value substituted for a -1 batch dimension.
        max_seqlen: value substituted for a -1 second dimension (defaults
            to the previously hard-coded 512).

    Returns:
        list of zero-filled ``np.ndarray``, one per attr.
    """
    batch = []
    for shape, dtype in attrs:
        filled = list(shape)
        if len(filled) > 0 and filled[0] == -1:
            filled[0] = batch_size
        if len(filled) > 1 and filled[1] == -1:
            filled[1] = max_seqlen
        batch.append(np.zeros(shape=filled, dtype=dtype))
    return batch


def create_net_inputs(input_attrs, is_async=False, iterator_fn=None, dev_count=1, n_prefetch=1, **kwargs):
    """Create one ``layers.data`` variable per input attr.

    Args:
        input_attrs: iterable of ``(name, shape, dtype)`` triples.
        is_async: if True, also build a non-iterable ``fluid.io.PyReader``
            over the created variables, decorate it with ``iterator_fn`` and
            start it. Legacy callers may still pass this as ``async``.
        iterator_fn: batch generator function; required when ``is_async``.
        dev_count: used as the PyReader ``capacity``.
        n_prefetch: accepted but currently unused.

    Returns:
        dict mapping each input name to its data variable.

    Raises:
        TypeError: on unexpected keyword arguments.
        AssertionError: if ``is_async`` is True and ``iterator_fn`` is None.
    """
    # `async` became a reserved keyword in Python 3.7, so it can no longer be
    # a parameter name (the module would not even parse). Keep accepting it as
    # a keyword argument for backward compatibility.
    if 'async' in kwargs:
        is_async = kwargs.pop('async')
    if kwargs:
        raise TypeError("create_net_inputs() got unexpected keyword argument(s): " + ", ".join(sorted(kwargs)))

    inputs = []
    ret = {}
    for name, shape, dtype in input_attrs:
        var = layers.data(name, shape=shape, dtype=dtype)
        ret[name] = var
        inputs.append(var)

    if is_async:
        assert iterator_fn is not None, "iterator_fn is needed for building async input layer."
        reader = fluid.io.PyReader(inputs, capacity=dev_count, iterable=False)
        reader.decorate_batch_generator(iterator_fn)
        reader.start()

    return ret


def create_iterator_fn(iterator, iterator_prefix, shape_and_dtypes, outname_to_pos, verbose=0, return_type='list'):
    """Wrap a single task iterator into a generator function of position-ordered batches.

    Args:
        iterator: iterator whose ``next()`` returns a dict mapping output
            names to values.
        iterator_prefix: each output ``name`` is also looked up in
            ``outname_to_pos`` as ``iterator_prefix + '.' + name``.
        shape_and_dtypes: list of ``[shape, dtype]`` entries indexed by
            position. Each placed value is validated/adapted against its entry.
        outname_to_pos: maps output names (plain or prefixed) to positions.
        verbose: accepted for interface compatibility; unused here.
        return_type: 'list' to yield a list indexed by position, or 'dict'
            to yield a dict keyed by the output name of each position.

    Returns:
        A generator function. Each batch has ``len(outname_to_pos)`` slots;
        slots the iterator did not produce are None. The generator ends
        when ``iterator`` is exhausted.

    Raises:
        ValueError: if ``return_type`` is neither 'list' nor 'dict'
            (previously the generator silently drained the iterator).
    """
    if return_type not in ('list', 'dict'):
        raise ValueError("return_type should be 'list' or 'dict', got {!r}.".format(return_type))

    pos_to_outname = {pos: name for name, pos in outname_to_pos.items()}

    def iterator_fn():
        while True:
            try:
                outputs = next(iterator)  # dict type
            except StopIteration:
                # PEP 479 (Python 3.7+): a StopIteration escaping a generator
                # becomes RuntimeError, so end the stream explicitly.
                return
            results = [None] * len(outname_to_pos)

            for outname, val in outputs.items():
                task_outname = iterator_prefix + '.' + outname

                if outname in outname_to_pos:
                    idx = outname_to_pos[outname]
                    val = _check_and_adapt_shape_dtype(val, shape_and_dtypes[idx], message=outname+': ')
                    results[idx] = val

                if task_outname in outname_to_pos:
                    idx = outname_to_pos[task_outname]
                    val = _check_and_adapt_shape_dtype(val, shape_and_dtypes[idx], message=task_outname+': ')
                    results[idx] = val

            if return_type == 'list':
                yield results
            else:
                yield {pos_to_outname[pos]: item for pos, item in enumerate(results)}

    return iterator_fn


def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes, mrs, outname_to_pos, dev_count=1, keep_one_task=True, verbose=0):
    """Build a generator function that yields batches sampled from several task iterators.

    joint_shape_and_dtypes is essentially derived from the attrs of the
    backbone (bb) and the task paradigms (parad). Its variable (-1)
    dimensions are filled in automatically from the reader attrs. Checking
    every batch from the iterators against it therefore gives a runtime
    correctness check of the batches.

    Args:
        iterators: list of per-task iterators; ``next()`` on each returns a
            dict mapping output names to values.
        iterator_prefixes: per-task prefixes. Task ``i``'s output ``name``
            is also placed under ``iterator_prefixes[i] + '.' + name``.
        joint_shape_and_dtypes: list of ``[shape, dtype]`` entries indexed
            by position.
        mrs: per-task sampling weights, normalized into probabilities.
        outname_to_pos: maps output names (plain or prefixed) to positions
            in the yielded list. Must map '__task_id' to position 0.
            'batch_size', 'seqlen' and 'batchsize_x_seqlen' are optional;
            when present they are filled from ``outputs['token_ids']``.
        dev_count: number of consecutive batches drawn from each sampled
            task. Forced to 1 when ``keep_one_task`` is False.
        keep_one_task: see ``dev_count``.
        verbose: number of yielded batches for which debug info is printed.

    Returns:
        A generator function. NOTE: it yields the *same* list object every
        time, updated in place. Positions not produced by the sampled task
        keep their previous values: initially one real batch of each task,
        or zeros for positions no task produces.

    Note:
        One batch is read eagerly from every iterator when this function is
        called. Those batches are buffered and yielded first for their task.
    """
    task_ids = range(len(iterators))
    weights = [mr / float(sum(mrs)) for mr in mrs]
    if not keep_one_task:
        dev_count = 1

    # Prime a "fake" batch: zeros everywhere, then overwritten with the first
    # real batch of every task. This way positions of tasks that are not
    # sampled still hold well-formed tensors.
    results = _zero_batch(joint_shape_and_dtypes)
    outbuf = {}
    for tid in task_ids:
        outputs = next(iterators[tid])  # dict type
        outbuf[tid] = outputs
        prefix = iterator_prefixes[tid]
        for outname, val in outputs.items():
            task_outname = prefix + '.' + outname

            if outname in outname_to_pos:
                idx = outname_to_pos[outname]
                val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx], message=outname+': ')
                results[idx] = val

            if task_outname in outname_to_pos:
                idx = outname_to_pos[task_outname]
                val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx], message=task_outname+': ')
                results[idx] = val

    fake_batch = results

    def _int64_vec(value):
        # Explicit int64: np.array([int]) is int32 on some platforms, which
        # would fail the int64 dtype check.
        return _check_and_adapt_shape_dtype(np.array([value], dtype='int64'), [[1], 'int64'])

    def iterator():
        v = verbose
        has_show_warn = False
        while True:
            task_id = np.random.choice(task_ids, p=weights)
            results = fake_batch
            if v > 0:
                print('----- debug joint iterator -----')
                print('sampled task id: '+str(task_id))
            task_id_tensor = np.array([[task_id]]).astype("int64")

            for _ in range(dev_count):

                results[outname_to_pos['__task_id']] = task_id_tensor
                assert outname_to_pos['__task_id'] == 0

                if task_id in outbuf:
                    # Consume the batch read eagerly while priming fake_batch.
                    outputs = outbuf[task_id]
                    del outbuf[task_id]
                else:
                    try:
                        outputs = next(iterators[task_id])  # dict type
                    except StopIteration:
                        # PEP 479 (Python 3.7+): a StopIteration escaping a
                        # generator becomes RuntimeError; end the stream
                        # explicitly instead.
                        return

                if 'token_ids' in outputs:
                    token_ids = outputs['token_ids']
                    batch_size = len(token_ids)
                    # The auxiliary slots are optional (see merge_input_attrs
                    # insert_* flags); only fill the ones that exist.
                    if 'batch_size' in outname_to_pos:
                        results[outname_to_pos['batch_size']] = _int64_vec(batch_size)

                    seqlen = len(token_ids[0])
                    if 'seqlen' in outname_to_pos:
                        results[outname_to_pos['seqlen']] = _int64_vec(seqlen)

                    if 'batchsize_x_seqlen' in outname_to_pos:
                        results[outname_to_pos['batchsize_x_seqlen']] = _int64_vec(batch_size * seqlen)
                else:
                    if not has_show_warn:
                        print('WARNING: token_ids not found in current batch, failed to yield batch_size, seqlen and batchsize_x_seqlen. (This message would be shown only once.)')
                        has_show_warn = True

                prefix = iterator_prefixes[task_id]
                for outname, val in outputs.items():
                    if v > 0:
                        print('reader generate: '+outname)
                    task_outname = prefix + '.' + outname

                    if outname in outname_to_pos:
                        idx = outname_to_pos[outname]
                        if v > 0:
                            print(outname + ' is insert in idx ' + str(idx))
                        val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx], message=outname+': ')
                        results[idx] = val

                    if task_outname in outname_to_pos:
                        idx = outname_to_pos[task_outname]
                        if v > 0:
                            print(task_outname + ' is insert in idx ' + str(idx))
                        val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[idx], message=task_outname+': ')
                        results[idx] = val

                if v > 0:
                    print('yielded batch len and shapes:')
                    print(len(results))
                    for arr in results:
                        print(np.shape(arr))
                    print('')
                    v -= 1
                yield results

    return iterator


def merge_input_attrs(backbone_attr, task_attrs, insert_taskid=True, insert_batchsize=True, insert_seqlen=True, insert_batchsize_x_seqlen=True):
    """Merge backbone and task input attrs into one position-ordered layout.

    The layout has three parts, in order:
      1. The enabled auxiliary inputs: '__task_id', 'batch_size', 'seqlen',
         'batchsize_x_seqlen'.
      2. The backbone attrs, sorted by name.
      3. Each task's attrs, sorted by name, task by task.

    Args:
        backbone_attr(dict): backbone input attributes, key=attr_name, val=[shape, dtype]
        task_attrs(list[dict]|dict): task input attributes, key=attr_name, val=[shape, dtype], support single task and nested tasks
        insert_taskid(bool): add '__task_id' with attr ([1, 1], 'int64').
        insert_batchsize(bool): add 'batch_size' with attr ([1], 'int64').
        insert_seqlen(bool): add 'seqlen' with attr ([1], 'int64').
        insert_batchsize_x_seqlen(bool): add 'batchsize_x_seqlen' with attr ([1], 'int64').

    Returns:
        tuple (names, attrs, name_to_position). ``names`` and ``attrs`` are
        parallel lists. ``name_to_position`` maps each name to its index in
        them; a name occurring more than once maps to its last occurrence.
    """
    if isinstance(task_attrs, dict):
        task_attrs = [task_attrs]

    auxiliary = [
        (insert_taskid, '__task_id', ([1, 1], 'int64')),
        (insert_batchsize, 'batch_size', ([1], 'int64')),
        (insert_seqlen, 'seqlen', ([1], 'int64')),
        (insert_batchsize_x_seqlen, u'batchsize_x_seqlen', ([1], 'int64')),
    ]
    names = []
    ret = []
    for enabled, name, attr in auxiliary:
        if enabled:
            names.append(name)
            ret.append(attr)

    backbone_names = sorted(backbone_attr.keys())
    names.extend(backbone_names)
    ret.extend([backbone_attr[k] for k in backbone_names])

    name_to_position = {k: pos for pos, k in enumerate(names)}
    for task_attr in task_attrs:
        task_names = sorted(task_attr.keys())
        # Positions must index into `names`/`ret`, so offset by their current
        # length. len(name_to_position) lags behind once any name repeats,
        # which would shift every later position onto the wrong attr.
        offset = len(names)
        names.extend(task_names)
        ret.extend([task_attr[k] for k in task_names])
        for pos, k in enumerate(task_names, start=offset):
            name_to_position[k] = pos
    return names, ret, name_to_position