compat.py 6.4 KB
Newer Older
M
minqiyang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import six
M
minqiyang 已提交
16
import math
M
minqiyang 已提交
17

18
# Empty export list: a star-import of this module binds no names, so the
# helpers defined below have to be accessed through the module itself.
__all__ = []
M
minqiyang 已提交
19

M
minqiyang 已提交
20

M
minqiyang 已提交
21
#  str and bytes related functions
M
minqiyang 已提交
22
def to_text(obj, encoding='utf-8', inplace=False):
    """
    All string in PaddlePaddle should be represented as a literal string.

    This function converts an object to a literal (text) string. If the
    object is a list, set or dict container, every item (for a dict: every
    key and every value) is converted individually.

    In Python3:
        Decode the bytes type object to str type with specific encoding

    In Python2:
        Decode the str type object to unicode type with specific encoding

    Args:
        obj(unicode|str|bytes|list|set|dict) : The object to be decoded.
        encoding(str) : The encoding format to decode a string
        inplace(bool) : If True, a list/set/dict argument is modified and
            returned; otherwise a new container is created. Ignored for
            non-container objects.

    Returns:
        Decoded result of obj. None is returned unchanged.

    Examples:

        .. code-block:: python

            import paddle

            data = "paddlepaddle"
            data = paddle.compat.to_text(data)
            # paddlepaddle

    """
    if obj is None:
        return obj

    if isinstance(obj, list):
        converted = [_to_text(item, encoding) for item in obj]
        if not inplace:
            return converted
        # Slice assignment keeps the identity of the caller's list.
        obj[:] = converted
        return obj

    if isinstance(obj, set):
        # Convert everything before touching the set: removing and adding
        # elements while iterating over the same set can skip or revisit
        # items.
        converted = set(_to_text(item, encoding) for item in obj)
        if not inplace:
            return converted
        obj.clear()
        obj.update(converted)
        return obj

    if isinstance(obj, dict):
        converted = {}
        for key, value in six.iteritems(obj):
            converted[_to_text(key, encoding)] = _to_text(value, encoding)
        if not inplace:
            return converted
        # Drop the old entries first; a plain update() would leave the
        # original (e.g. bytes) keys next to their decoded counterparts.
        obj.clear()
        obj.update(converted)
        return obj

    return _to_text(obj, encoding)
M
minqiyang 已提交
87 88


M
minqiyang 已提交
89
def _to_text(obj, encoding):
    """
    Convert a single (non-container) object to a text string.

    In Python3:
        Decode the bytes type object to str type with specific encoding

    In Python2:
        Decode the str type object to unicode type with specific encoding,
        or we just return the unicode string of object

    Args:
        obj(unicode|str|bytes) : The object to be decoded.
        encoding(str) : The encoding format

    Returns:
        decoded result of obj; None, text, bool and float values are
        returned unchanged
    """
    # Binary data is the only input that actually needs decoding.
    if isinstance(obj, six.binary_type):
        return obj.decode(encoding)

    # None, already-decoded text, bools and floats pass through untouched.
    passthrough = obj is None or isinstance(obj,
                                            (six.text_type, bool, float))
    return obj if passthrough else six.u(obj)


118 119
def to_bytes(obj, encoding='utf-8', inplace=False):
    """
    All string in PaddlePaddle should be represented as a literal string.

    This function converts an object to bytes with a specific encoding.
    If the object is a list or set container, every item is converted
    individually.

    In Python3:
        Encode the str type object to bytes type with specific encoding

    In Python2:
        Encode the unicode type object to str type with specific encoding,
        or we just return the 8-bit string of object

    Args:
        obj(unicode|str|bytes|list|set) : The object to be encoded.
        encoding(str) : The encoding format to encode a string
        inplace(bool) : If True, a list/set argument is modified and
            returned; otherwise a new container is created. Ignored for
            non-container objects.

    Returns:
        Encoded result of obj. None is returned unchanged.

    Examples:

        .. code-block:: python

            import paddle

            data = "paddlepaddle"
            data = paddle.compat.to_bytes(data)
            # b'paddlepaddle'

    """
    if obj is None:
        return obj

    if isinstance(obj, list):
        converted = [_to_bytes(item, encoding) for item in obj]
        if not inplace:
            return converted
        # Slice assignment keeps the identity of the caller's list.
        obj[:] = converted
        return obj

    if isinstance(obj, set):
        # Convert everything before touching the set: removing and adding
        # elements while iterating over the same set can skip or revisit
        # items.
        converted = set(_to_bytes(item, encoding) for item in obj)
        if not inplace:
            return converted
        obj.clear()
        obj.update(converted)
        return obj

    return _to_bytes(obj, encoding)
M
minqiyang 已提交
172 173


M
minqiyang 已提交
174
def _to_bytes(obj, encoding):
    """
    Convert a single (non-container) object to a byte string.

    In Python3:
        Encode the str type object to bytes type with specific encoding

    In Python2:
        Encode the unicode type object to str type with specific encoding,
        or we just return the 8-bit string of object

    Args:
        obj(unicode|str|bytes) : The object to be encoded.
        encoding(str) : The encoding format; must not be None

    Returns:
        encoded result of obj; None and already-encoded bytes are
        returned unchanged
    """
    if obj is None:
        return obj

    assert encoding is not None

    # Already binary: nothing to encode.
    if isinstance(obj, six.binary_type):
        return obj
    # Text is encoded with the requested codec.
    if isinstance(obj, six.text_type):
        return obj.encode(encoding)
    # Anything else is handed to six's byte-literal helper.
    return six.b(obj)


# math related functions
def round(x, d=0):
    """
    Compatible round which act the same behaviour in Python3.

    On Python3 the value is rounded half away from zero (the Python2
    behaviour) instead of using the Python3 builtin's half-to-even rule.
    On Python2 the builtin round is called directly.

    Args:
        x(float) : The number to round halfway.
        d(int) : The number of decimal digits to keep. Default is 0.

    Returns:
        round result of x, as a float. NaN and infinities are returned
        unchanged.
    """
    if six.PY3:
        # NaN compares false against everything and used to fall through to
        # the zero branch (returning 0.0); infinities cannot be converted by
        # math.floor / math.ceil. Neither has anything to round.
        if math.isnan(x) or math.isinf(x):
            return float(x)

        # Preserve the sign of a (possibly negative) zero.
        if x == 0.0:
            return math.copysign(0.0, x)

        # The official workaround of round in Python3 is incorrect
        # we implement according this answer: https://www.techforgeek.info/round_python.html
        p = 10**d
        # Shift by half a unit away from zero, then truncate back towards
        # zero: floor for positive values, ceil for negative ones.
        snap = math.floor if x > 0.0 else math.ceil
        return float(snap((x * p) + math.copysign(0.5, x))) / p
    else:
        import __builtin__
        return __builtin__.round(x, d)