# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import inspect

import yaml
from .schema import SharedConfig

__all__ = ['serializable', 'Callable']


def represent_dictionary_order(self, dict_data):
    """
    Represent ``dict_data`` as a plain YAML map node.

    The key/value pairs are handed to the dumper as an items view rather
    than as the dict itself, presumably so the dumper emits them in the
    dict's own order instead of sorting the keys (PyYAML only sorts objects
    that expose an ``items`` method).

    Args:
        self: the PyYAML dumper this function is registered on
        dict_data: the mapping to represent

    Returns: the node produced by ``self.represent_mapping``
    """
    map_tag = 'tag:yaml.org,2002:map'
    ordered_pairs = dict_data.items()
    return self.represent_mapping(map_tag, ordered_pairs)


def setup_orderdict():
    """
    Register ``represent_dictionary_order`` with PyYAML as the representer
    for ``collections.OrderedDict``, so ordered dicts are dumped as plain
    ``tag:yaml.org,2002:map`` mappings.
    """
    import collections
    yaml.add_representer(collections.OrderedDict, represent_dictionary_order)


def _make_python_constructor(cls):
    def python_constructor(loader, node):
        if isinstance(node, yaml.SequenceNode):
            args = loader.construct_sequence(node, deep=True)
            return cls(*args)
        else:
            kwargs = loader.construct_mapping(node, deep=True)
            try:
                return cls(**kwargs)
            except Exception as ex:
                print("Error when construct {} instance from yaml config".
                      format(cls.__name__))
                raise ex

    return python_constructor


def _make_python_representer(cls):
    # python 2 compatibility
    if hasattr(inspect, 'getfullargspec'):
        argspec = inspect.getfullargspec(cls)
    else:
55
        argspec = inspect.getfullargspec(cls.__init__)
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
    argnames = [arg for arg in argspec.args if arg != 'self']

    def python_representer(dumper, obj):
        if argnames:
            data = {name: getattr(obj, name) for name in argnames}
        else:
            data = obj.__dict__
        if '_id' in data:
            del data['_id']
        return dumper.represent_mapping(u'!{}'.format(cls.__name__), data)

    return python_representer


def serializable(cls):
    """
    Class decorator that registers YAML load/dump support for ``cls``.

    A node tagged ``!<ClassName>`` is loaded as an instance of ``cls``, and
    instances of ``cls`` are dumped as a mapping with that tag. The class
    must be "trivially serializable": when its ``__init__`` takes
    arguments, each one must be readable back as an attribute of the same
    name.

    Args:
        cls: class to be serialized

    Returns: cls, unchanged, so this can be used as a decorator
    """
    yaml_tag = u'!{}'.format(cls.__name__)
    yaml.add_constructor(yaml_tag, _make_python_constructor(cls))
    yaml.add_representer(cls, _make_python_representer(cls))
    return cls


def _represent_shared_config(dumper, shared):
    # A SharedConfig is dumped as its bare ``default_value``, not as a
    # tagged object of its own.
    return dumper.represent_data(shared.default_value)


yaml.add_representer(SharedConfig, _represent_shared_config)


@serializable
class Callable(object):
    """
    Helper to be used in Yaml for creating arbitrary class objects

    Calling the instance imports the target named by ``full_type`` and
    returns ``target(*args, **kwargs)``.

    NOTE: ``full_type`` can name any importable callable, so only load
    configs that contain this tag from trusted sources.

    Args:
        full_type (str): the full module path to target function, e.g.
            ``'collections.OrderedDict'``. A name without a dot is looked
            up among the builtins.
        args (list): positional arguments for the target; defaults to a
            new empty list
        kwargs (dict): keyword arguments for the target; defaults to a new
            empty dict
    """

    def __init__(self, full_type, args=None, kwargs=None):
        super(Callable, self).__init__()
        self.full_type = full_type
        # Fresh containers per instance. A ``[]``/``{}`` default would be a
        # single object shared by every Callable constructed without them.
        self.args = [] if args is None else args
        self.kwargs = {} if kwargs is None else kwargs

    def __call__(self):
        if '.' in self.full_type:
            idx = self.full_type.rfind('.')
            module = importlib.import_module(self.full_type[:idx])
            func_name = self.full_type[idx + 1:]
        else:
            try:
                module = importlib.import_module('builtins')
            except ImportError:
                # python 2 compatibility
                module = importlib.import_module('__builtin__')
            func_name = self.full_type

        func = getattr(module, func_name)
        return func(*self.args, **self.kwargs)