# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import print_function

import copy
import logging
import traceback

from .transformer import MappedDataset, BatchedDataset
from .post_map import build_post_map
from .parallel_map import ParallelMappedDataset
from .operators import BaseOperator, registered_ops

__all__ = ['build_mapper', 'map', 'batch', 'batch_map']

logger = logging.getLogger(__name__)


def build_mapper(ops, context=None):
    """
    Build a mapper for operators in 'ops'

    Args:
        ops (list of operator.BaseOperator or list of op dict):
            configs for operators. Each entry is either an already built
            BaseOperator instance (used as is) or a dict in one of two forms:
            [{'name': 'DecodeImage', 'params': {'to_rgb': True}}, {xxx}]
            [{'op': 'DecodeImage', 'to_rgb': True}, {xxx}]
            Top-level dict keys are matched case-insensitively (they are
            lower-cased before use, which in the 'op' form also lower-cases
            the keyword-argument names passed to the operator).
        context (dict): a context object for mapper; a deep copy of it is
            handed to the operators on every call of the mapper

    Returns:
        a mapper function which accept one argument 'sample' and
        return the processed result. The function carries an 'ops'
        attribute holding a string representation of the operators.

    Raises:
        AttributeError: if an operator name is not an attribute of
            BaseOperator
        KeyError: if an op dict has neither an 'op' nor a 'name' key
    """
    op_funcs = []
    for op in ops:
        if isinstance(op, BaseOperator):
            # Ready-made operators must bypass the dict normalization below,
            # they have no '.items()' to iterate over.
            op_funcs.append(op)
            continue

        cfg = {key.lower(): value for key, value in op.items()}
        if 'op' in cfg:
            # Flat form: every key except 'op' is a constructor argument.
            params = copy.deepcopy(cfg)
            op_name = params.pop('op')
        else:
            # Nested form: constructor arguments live under 'params';
            # a missing or empty 'params' entry means "no arguments".
            op_name = cfg['name']
            params = cfg.get('params') or {}
        op_funcs.append(getattr(BaseOperator, op_name)(**params))

    op_repr = '[{}]'.format(
        ','.join('{{{}}}'.format(str(o)) for o in op_funcs))

    def _mapper(sample):
        # Each sample gets its own context so operators cannot leak state
        # from one sample to the next through it.
        ctx = {} if context is None else copy.deepcopy(context)
        for f in op_funcs:
            try:
                sample = f(sample, ctx)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warning(
                    "fail to map op [{}] with error: {} and stack:\n{}".
                    format(f, e, str(stack_info)))
                # Bare raise keeps the original traceback intact.
                raise

        # With no operators configured the sample passes through unchanged.
        return sample

    _mapper.ops = op_repr
    return _mapper


def map(ds, mapper, worker_args=None):
    """
    Wrap 'ds' so that 'mapper' is applied to its samples

    Args:
        ds (instance of Dataset): dataset to be mapped
        mapper (function): action to be executed for every data sample
        worker_args (dict): configs for concurrent mapper; when it is None
            a plain MappedDataset is built, otherwise a
            ParallelMappedDataset receives these configs
    Returns:
        a mapped dataset
    """
    # No worker configuration requested: use the non-parallel wrapper.
    if worker_args is None:
        return MappedDataset(ds, mapper)

    return ParallelMappedDataset(ds, mapper, worker_args)


def batch(ds, batchsize, drop_last=False, drop_empty=True):
    """
    Batch data samples to batches

    Args:
        ds (instance of Dataset): dataset whose samples are batched
        batchsize (int): number of samples for a batch
        drop_last (bool): drop last few samples if not enough for a batch
        drop_empty (bool): forwarded unchanged to BatchedDataset
            (presumably skips empty samples -- confirm in BatchedDataset)

    Returns:
        a batched dataset
    """
    return BatchedDataset(
        ds, batchsize, drop_last=drop_last, drop_empty=drop_empty)


def batch_map(ds, config):
    """
    Post process the batches.

    Args:
        ds (instance of Dataset): dataset to be mapped
        config (dict): keyword arguments passed to 'build_post_map' to
            create the function that is executed for every batch
    Returns:
        a batched dataset which is processed
    """
    return MappedDataset(ds, build_post_map(**config))


# Expose every registered operator as an attribute of this package.
# At module scope locals() is the module's own namespace, so writing into it
# creates a module-level name. The loop variables 'nm' and 'op' also remain
# defined in the module afterwards.
for nm in registered_ops:
    op = getattr(BaseOperator, nm)
    locals()[nm] = op

# Add the operator names to the public API declared in __all__.
__all__ += registered_ops