compat.py 6.4 KB
Newer Older
M
minqiyang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import six
M
minqiyang 已提交
16
import math
M
minqiyang 已提交
17

18 19 20 21 22 23 24
# Names exported by `from <this module> import *`. Note that 'round' is
# exported under the same name as the builtin function.
__all__ = [
    'to_literal_str',
    'to_bytes',
    'round',
    'floor_division',
    'get_exception_message',
]
M
minqiyang 已提交
25 26

#  str and bytes related functions
27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
def to_literal_str(obj, encoding='utf-8', inplace=False):
    """
    Convert an object, or every item of a list or set, to a literal string.

    All strings in PaddlePaddle should be represented as literal strings.
    If the object is a list or set container, every item in it is converted
    with _to_literal_str; any other object is converted directly.

    In Python3:
        Decode the bytes type object to str type with specific encoding

    In Python2:
        Decode the str type object to unicode type with specific encoding

    Args:
        obj(unicode|str|bytes|list|set) : The object to be decoded.
        encoding(str) : The encoding format to decode a string
        inplace(bool) : Only used when obj is a list or set. If True, the
            items of the original container are replaced and that same
            container is returned; otherwise a new container is returned.

    Returns:
        Decoded result of obj, or None if obj is None
    """
    if obj is None:
        return obj

    if isinstance(obj, list):
        if inplace:
            # The list length never changes here, so assigning by index
            # while enumerating is safe.
            for i, item in enumerate(obj):
                obj[i] = _to_literal_str(item, encoding)
            return obj
        return [_to_literal_str(item, encoding) for item in obj]

    if isinstance(obj, set):
        converted = set(_to_literal_str(item, encoding) for item in obj)
        if inplace:
            # A set must not be modified while it is being iterated, so the
            # converted items are built first and then swapped in.
            obj.clear()
            obj.update(converted)
            return obj
        return converted

    return _to_literal_str(obj, encoding)
M
minqiyang 已提交
68 69


M
minqiyang 已提交
70
def _to_literal_str(obj, encoding):
    """
    Convert a single object to a literal (text) string.

    In Python3:
        Decode the bytes type object to str type with specific encoding

    In Python2:
        Decode the str type object to unicode type with specific encoding

    Text strings are returned unchanged, and any object that is neither
    binary nor text is handed to six.u.

    Args:
        obj(unicode|str|bytes) : The object to be decoded.
        encoding(str) : The encoding format

    Returns:
        decoded result of obj, or None if obj is None
    """
    if obj is None:
        return None

    # Already text: nothing to decode.
    if isinstance(obj, six.text_type):
        return obj

    if isinstance(obj, six.binary_type):
        return obj.decode(encoding)

    return six.u(obj)


97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
def to_bytes(obj, encoding='utf-8', inplace=False):
    """
    Convert an object, or every item of a list or set, to bytes.

    All strings in PaddlePaddle should be represented as literal strings.
    This function converts them to bytes with a specific encoding. If the
    object is a list or set container, every item in it is converted with
    _to_bytes; any other object is converted directly.

    In Python3:
        Encode the str type object to bytes type with specific encoding

    In Python2:
        Encode the unicode type object to str type with specific encoding,
        or we just return the 8-bit string of object

    Args:
        obj(unicode|str|bytes|list|set) : The object to be encoded.
        encoding(str) : The encoding format to encode a string
        inplace(bool) : Only used when obj is a list or set. If True, the
            items of the original container are replaced and that same
            container is returned; otherwise a new container is returned.

    Returns:
        Encoded result of obj, or None if obj is None
    """
    if obj is None:
        return obj

    if isinstance(obj, list):
        if inplace:
            # The list length never changes here, so assigning by index
            # while enumerating is safe.
            for i, item in enumerate(obj):
                obj[i] = _to_bytes(item, encoding)
            return obj
        return [_to_bytes(item, encoding) for item in obj]

    if isinstance(obj, set):
        converted = set(_to_bytes(item, encoding) for item in obj)
        if inplace:
            # A set must not be modified while it is being iterated, so the
            # converted items are built first and then swapped in.
            obj.clear()
            obj.update(converted)
            return obj
        return converted

    return _to_bytes(obj, encoding)
M
minqiyang 已提交
139 140


M
minqiyang 已提交
141
def _to_bytes(obj, encoding):
    """
    Convert a single object to a binary string.

    In Python3:
        Encode the str type object to bytes type with specific encoding

    In Python2:
        Encode the unicode type object to str type with specific encoding

    Binary strings are returned unchanged, and any object that is neither
    text nor binary is handed to six.b.

    Args:
        obj(unicode|str|bytes) : The object to be encoded.
        encoding(str) : The encoding format, must not be None

    Returns:
        encoded result of obj, or None if obj is None
    """
    if obj is None:
        return None

    # The encoding is checked for every non-None input, even when the
    # object turns out to need no encoding.
    assert encoding is not None

    # Already binary: nothing to encode.
    if isinstance(obj, six.binary_type):
        return obj

    if isinstance(obj, six.text_type):
        return obj.encode(encoding)

    return six.b(obj)


# math related functions
def round(x, d=0):
    """
    Compatible round which rounds halfway cases away from zero on both
    Python2 and Python3 (the behaviour of the Python2 builtin round).

    Args:
        x(float) : The number to round.
        d(int) : The number of decimal places to round to, 0 by default.

    Returns:
        round result of x as a float. In Python3, NaN and infinities are
        returned unchanged and a zero keeps its sign.
    """
    if six.PY3:
        # NaN fails every comparison and floor/ceil cannot convert an
        # infinity to an integer, so non-finite values are passed through.
        if math.isnan(x) or math.isinf(x):
            return float(x)

        if x == 0.0:
            # Keep the sign of a (possibly negative) zero.
            return math.copysign(0.0, x)

        # The official workaround for round in Python3 is incorrect;
        # we implement it according to this answer:
        # https://www.techforgeek.info/round_python.html
        p = 10 ** d
        shifted = (x * p) + math.copysign(0.5, x)
        # Move half a unit away from zero, then truncate back towards zero.
        if x > 0.0:
            return float(math.floor(shifted)) / p
        return float(math.ceil(shifted)) / p
    else:
        import __builtin__
        return __builtin__.round(x, d)
M
minqiyang 已提交
194 195 196


def floor_division(x, y):
    """
    Floor division that is spelled the same way in Python3 and Python2.

    Args:
        x(int|float) : The dividend.
        y(int|float) : The divisor.

    Returns:
        division result of x // y
    """
    quotient = x // y
    return quotient
M
minqiyang 已提交
210 211 212

# exception related functions
def get_exception_message(exc):
    """
    Get the error message of a specific exception

    Args:
        exc(Exception) : The exception to get the error message from.

    Returns:
        exc.message in Python2, str(exc) otherwise
    """
    assert exc is not None

    return exc.message if six.PY2 else str(exc)