# flat.py
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import numbers
import numpy as np

# ``collections.abc`` is the preferred home of the container ABCs; fall back
# to the legacy ``collections`` location only when that import is unavailable.
# Catch ImportError specifically so unrelated failures (KeyboardInterrupt,
# SystemExit, ...) are not silently swallowed.
try:
    from collections.abc import Sequence, Mapping
except ImportError:
    from collections import Sequence, Mapping

# Prefix of the placeholder strings stored in a flattened batch structure;
# the digits following the prefix are the field's index in the flat batch.
FIELD_PREFIX = "_paddle_field_"


def _flatten_batch(batch):
    """
    Split a batch into its array/tensor payload and a structure skeleton.

    lod_blocking_queue only receives a tensor array, so every
    ``numpy.ndarray`` / ``paddle.Tensor`` field found in ``batch`` is moved,
    in traversal order, into a flat list and replaced in the skeleton by the
    placeholder string ``FIELD_PREFIX + <index in the flat list>``. Fields of
    any other type (str, bytes, numbers, unrecognized objects) are kept in
    the skeleton unchanged. Nested sequences are rebuilt as lists and nested
    mappings as dicts.

    Args:
        batch: a single field, or an arbitrarily nested sequence/mapping of
            fields.

    Returns:
        tuple: ``(flat_batch, structure)``. ``flat_batch`` is the list of
        extracted arrays/tensors; ``structure`` is the skeleton that
        ``_restore_batch`` uses to rebuild the batch.

    Raises:
        TypeError: if a container handed to the recursive walk is neither a
            Sequence nor a Mapping.
    """

    def _flatten_field(field, flat_batch, field_idx):
        # Return (skeleton entry for ``field``, next free flat index).
        if isinstance(field, (np.ndarray, paddle.Tensor)):
            flat_batch.append(field)
            return '{}{}'.format(FIELD_PREFIX, field_idx), field_idx + 1
        # str/bytes must be tested before Sequence: they are Sequences too,
        # but are stored as opaque leaf values.
        if isinstance(field, (str, bytes, numbers.Number)):
            return field, field_idx
        if isinstance(field, Sequence):
            return _flatten(field, flat_batch, [], field_idx)
        if isinstance(field, Mapping):
            return _flatten(field, flat_batch, {}, field_idx)
        # Unknown leaf type: keep it in the skeleton untouched.
        return field, field_idx

    def _flatten(batch, flat_batch, structure, field_idx):
        # Walk one container level, filling ``structure`` (list or dict).
        if isinstance(batch, Sequence):
            for field in batch:
                entry, field_idx = _flatten_field(field, flat_batch,
                                                  field_idx)
                structure.append(entry)
        elif isinstance(batch, Mapping):
            for k, field in batch.items():
                entry, field_idx = _flatten_field(field, flat_batch,
                                                  field_idx)
                structure[k] = entry
        else:
            raise TypeError("wrong flat data type: {}".format(type(batch)))

        return structure, field_idx

    flat_batch = []

    # sample only contains single fields; a bare str/bytes is a Sequence as
    # well, but it is one field and must not be split into its items
    if not isinstance(batch, Sequence) or isinstance(batch, (str, bytes)):
        structure, _ = _flatten([batch], flat_batch, [], 0)
        return flat_batch, structure[0]

    structure, _ = _flatten(batch, flat_batch, [], 0)
    return flat_batch, structure


def _restore_batch(flat_batch, structure):
    """
    Rebuild a batch from the data read back from lod_blocking_queue.

    Every placeholder string ``FIELD_PREFIX + <idx>`` found in ``structure``
    is replaced by ``flat_batch[idx]``. Other str/bytes/number fields are
    left as they are, and nested sequences/mappings are restored
    recursively.

    Args:
        flat_batch (list|tuple): data read from the queue, indexed by the
            number embedded in each placeholder. It is not modified.
        structure: skeleton produced by ``_flatten_batch``. Nested
            containers are updated in place.

    Returns:
        ``structure`` with all placeholders substituted. When ``flat_batch``
        is empty, ``structure`` is returned untouched; when ``structure`` is
        itself a single placeholder string, ``flat_batch[0]`` is returned.

    Raises:
        AssertionError: if ``flat_batch`` is not a sequence, a placeholder
            index is referenced more than once, a single-field structure is
            not the index-0 placeholder, or not every entry of
            ``flat_batch`` was consumed.
        TypeError: if a container to restore is neither a Sequence nor a
            Mapping.
    """
    assert isinstance(flat_batch, Sequence), \
            "flat_batch is not a list or tuple"

    # no np.array in dataset, no output tensor from blocking queue
    # simply return structure
    if len(flat_batch) == 0:
        return structure

    # sample only contains single fields
    if isinstance(structure, (str, bytes)):
        assert structure == '{}{}'.format(FIELD_PREFIX, 0), \
                "invalid structure: {}".format(structure)
        return flat_batch[0]

    # Indices already substituted. Tracked separately so that flat_batch is
    # never written to: it may be a tuple and it belongs to the caller.
    consumed = set()

    def _restore(container):
        if isinstance(container, Sequence):
            entries = enumerate(container)
        elif isinstance(container, Mapping):
            # Only values of existing keys are reassigned below, so the
            # mapping does not change size while it is being iterated.
            entries = container.items()
        else:
            raise TypeError("wrong flat data type: {}".format(
                type(container)))

        for key, field in entries:
            if isinstance(field, str) and field.startswith(FIELD_PREFIX):
                cur_field_idx = int(field.replace(FIELD_PREFIX, ''))
                assert cur_field_idx not in consumed, \
                    "flat_batch[{}] parsed repeatedly".format(cur_field_idx)
                container[key] = flat_batch[cur_field_idx]
                consumed.add(cur_field_idx)
            elif isinstance(field, (str, bytes, numbers.Number)):
                continue
            elif isinstance(field, (Sequence, Mapping)):
                _restore(field)

    _restore(structure)
    # Every queue output must have been placed exactly once.
    assert len(consumed) == len(flat_batch), "Tensor parse incomplete"
    return structure