# __init__.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import print_function

import copy
import logging

from .transformer import MappedDataset, BatchedDataset
from .post_map import build_post_map
from .parallel_map import ParallelMappedDataset
from .operators import BaseOperator, registered_ops

__all__ = ['build_mapper', 'map', 'batch', 'batch_map']

logger = logging.getLogger(__name__)


def build_mapper(ops, context=None):
    """
    Build a mapper for operators in 'ops'

    Args:
        ops (list of operator.BaseOperator or list of op dict):
            configs for operators. Each entry is one of:
              * a BaseOperator instance, which is used as is;
              * {'op': 'DecodeImage', 'to_rgb': True}: every key except
                'op' is passed to the operator as a keyword argument;
              * {'name': 'DecodeImage', 'params': {'to_rgb': True}}.
            The top-level keys of a dict entry are lower-cased before use,
            so 'OP'/'Name'/'Params' are accepted as well.
        context (dict): a context object for mapper, it is deep-copied for
            every sample before being handed to the operators

    Returns:
        a mapper function which accept one argument 'sample' and
        return the processed result. An operator that raises is logged
        and skipped, the sample processed so far is passed on. The
        returned function carries an 'ops' attribute, a string listing
        the operators in order.
    """
    op_funcs = []
    op_repr = []
    for op in ops:
        if isinstance(op, BaseOperator):
            # Ready-made operator instances have no '.items()' and must not
            # go through the key normalisation applied to dict configs.
            o = op
        else:
            cfg = {key.lower(): value for key, value in op.items()}
            if 'op' in cfg:
                op_func = getattr(BaseOperator, cfg['op'])
                # Deep-copy so the operator never aliases caller-owned
                # config values.
                params = copy.deepcopy(cfg)
                del params['op']
            else:
                op_func = getattr(BaseOperator, cfg['name'])
                # A missing or empty 'params' entry means "no arguments".
                params = cfg.get('params') or {}
            o = op_func(**params)
        op_funcs.append(o)
        # '{{' and '}}' are literal braces, the middle '{}' is the field.
        op_repr.append('{{{}}}'.format(str(o)))
    op_repr = '[{}]'.format(','.join(op_repr))

    def _mapper(sample):
        ctx = {} if context is None else copy.deepcopy(context)
        for f in op_funcs:
            try:
                sample = f(sample, ctx)
            except Exception as e:
                # Best effort: report the failing operator and continue
                # with the sample as it was before this operator ran.
                logger.warning(
                    "fail to map op [{}] with error: {}".format(f, e))
        # Return 'sample' rather than a separate output variable so that an
        # empty operator list or a failing first operator still yields a
        # defined result.
        return sample

    _mapper.ops = op_repr
    return _mapper


def map(ds, mapper, worker_args=None):
    """
    Wrap 'ds' so that 'mapper' is applied to its samples

    Args:
        ds (instance of Dataset): dataset to be mapped
        mapper (function): action to be executed for every data sample
        worker_args (dict): configs for concurrent mapper; when it is
            None a plain MappedDataset is built, otherwise a
            ParallelMappedDataset is built with these configs
    Returns:
        a mapped dataset
    """

    # No worker configuration: use the plain mapped dataset.
    if worker_args is None:
        return MappedDataset(ds, mapper)

    return ParallelMappedDataset(ds, mapper, worker_args)


def batch(ds, batchsize, drop_last=False):
    """
    Group the samples of 'ds' into batches

    Args:
        ds (instance of Dataset): dataset whose samples are batched
        batchsize (int): number of samples for a batch
        drop_last (bool): drop last few samples if not enough for a batch

    Returns:
        a batched dataset
    """

    batched = BatchedDataset(ds, batchsize, drop_last=drop_last)
    return batched


def batch_map(ds, config):
    """
    Post process the batches.

    Args:
        ds (instance of Dataset): dataset to be mapped
        config (dict): keyword arguments passed to 'build_post_map' to
            build the action executed for every batch
    Returns:
        a batched dataset which is processed
    """

    return MappedDataset(ds, build_post_map(**config))


# Bind every registered operator to a module-level name of the same name
# and list it in __all__ alongside the builder functions.
for nm in registered_ops:
    op = getattr(BaseOperator, nm)
    # At module scope globals() is the module namespace itself.
    globals()[nm] = op

__all__.extend(registered_ops)