提交 09103084 编写于 作者: M minqiyang

Polish compat.py and add unittest for it

上级 c3fdf3ae
...@@ -202,7 +202,6 @@ std::vector<std::string> OpDesc::AttrNames() const { ...@@ -202,7 +202,6 @@ std::vector<std::string> OpDesc::AttrNames() const {
} }
void OpDesc::SetAttr(const std::string &name, const Attribute &v) { void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
VLOG(11) << "SetAttr: " << Type() << ", " << name << ", " << v.which();
// NOTICE(minqiyang): pybind11 will take the empty list in python as // NOTICE(minqiyang): pybind11 will take the empty list in python as
// the std::vector<int> type in C++; so we have to change the attr's type // the std::vector<int> type in C++; so we have to change the attr's type
// here if we meet this issue // here if we meet this issue
......
...@@ -15,18 +15,77 @@ ...@@ -15,18 +15,77 @@
import six import six
import math import math
__all__ = [
'to_literal_str',
'to_bytes',
'round',
'floor_division',
'get_exception_message',
]
# str and bytes related functions # str and bytes related functions
def to_literal_str(obj, encoding='utf-8'): def to_literal_str(obj, encoding='utf-8', inplace=False):
"""
All string in PaddlePaddle should be represented as a literal string.
This function will convert object to a literal string without any encoding.
Especially, if the object type is a list or set container, we will iterate
all items in the object and convert them to literal string.
In Python3:
Decode the bytes type object to str type with specific encoding
In Python2:
Decode the str type object to unicode type with specific encoding
Args:
obj(unicode|str|bytes|list|set) : The object to be decoded.
encoding(str) : The encoding format to decode a string
inplace(bool) : If we change the original object or we create a new one
Returns:
Decoded result of obj
"""
if obj is None:
return obj
if isinstance(obj, list): if isinstance(obj, list):
if inplace:
for i in six.moves.xrange(len(obj)):
obj[i] = _to_literal_str(obj[i], encoding)
return obj
else:
return [_to_literal_str(item, encoding) for item in obj] return [_to_literal_str(item, encoding) for item in obj]
elif isinstance(obj, set): elif isinstance(obj, set):
if inplace:
for item in obj:
obj.remove(item)
obj.add(_to_literal_str(item, encoding))
return obj
else:
return set([_to_literal_str(item, encoding) for item in obj]) return set([_to_literal_str(item, encoding) for item in obj])
else: else:
return _to_literal_str(obj, encoding) return _to_literal_str(obj, encoding)
def _to_literal_str(obj, encoding): def _to_literal_str(obj, encoding):
"""
In Python3:
Decode the bytes type object to str type with specific encoding
In Python2:
Decode the str type object to unicode type with specific encoding,
or we just return the unicode string of object
Args:
obj(unicode|str|bytes) : The object to be decoded.
encoding(str) : The encoding format
Returns:
decoded result of obj
"""
if obj is None:
return obj
if isinstance(obj, six.binary_type): if isinstance(obj, six.binary_type):
return obj.decode(encoding) return obj.decode(encoding)
elif isinstance(obj, six.text_type): elif isinstance(obj, six.text_type):
...@@ -35,16 +94,70 @@ def _to_literal_str(obj, encoding): ...@@ -35,16 +94,70 @@ def _to_literal_str(obj, encoding):
return six.u(obj) return six.u(obj)
def to_bytes(obj, encoding='utf-8'): def to_bytes(obj, encoding='utf-8', inplace=False):
"""
All string in PaddlePaddle should be represented as a literal string.
This function will convert object to a bytes with specific encoding.
Especially, if the object type is a list or set container, we will iterate
all items in the object and convert them to bytes.
In Python3:
Encode the str type object to bytes type with specific encoding
In Python2:
Encode the unicode type object to str type with specific encoding,
or we just return the 8-bit string of object
Args:
obj(unicode|str|bytes|list|set) : The object to be encoded.
encoding(str) : The encoding format to encode a string
inplace(bool) : If we change the original object or we create a new one
Returns:
Decoded result of obj
"""
if obj is None:
return obj
if isinstance(obj, list): if isinstance(obj, list):
if inplace:
for i in six.moves.xrange(len(obj)):
obj[i] = _to_bytes(obj[i], encoding)
return obj
else:
return [_to_bytes(item, encoding) for item in obj] return [_to_bytes(item, encoding) for item in obj]
elif isinstance(obj, set): elif isinstance(obj, set):
if inplace:
for item in obj:
obj.remove(item)
obj.add(_to_bytes(item, encoding))
return obj
else:
return set([_to_bytes(item, encoding) for item in obj]) return set([_to_bytes(item, encoding) for item in obj])
else: else:
return _to_bytes(obj, encoding) return _to_bytes(obj, encoding)
def _to_bytes(obj, encoding): def _to_bytes(obj, encoding):
"""
In Python3:
Encode the str type object to bytes type with specific encoding
In Python2:
Encode the unicode type object to str type with specific encoding,
or we just return the 8-bit string of object
Args:
obj(unicode|str|bytes) : The object to be encoded.
encoding(str) : The encoding format
Returns:
encoded result of obj
"""
if obj is None:
return obj
assert encoding is not None
if isinstance(obj, six.text_type): if isinstance(obj, six.text_type):
return obj.encode(encoding) return obj.encode(encoding)
elif isinstance(obj, six.binary_type): elif isinstance(obj, six.binary_type):
...@@ -64,15 +177,48 @@ def round(x, d=0): ...@@ -64,15 +177,48 @@ def round(x, d=0):
Returns: Returns:
round result of x round result of x
""" """
p = 10**d if six.PY3:
# The official walkaround of round in Python3 is incorrect
# we implement accroding this answer: https://www.techforgeek.info/round_python.html
if x > 0.0:
p = 10 ** d
return float(math.floor((x * p) + math.copysign(0.5, x))) / p return float(math.floor((x * p) + math.copysign(0.5, x))) / p
else:
p = 10 ** d
return float(math.ceil((x * p) + math.copysign(0.5, x))) / p
else:
import __builtin__
return __builtin__.round(x, d)
def floor_division(x, y): def floor_division(x, y):
"""
Compatible division which act the same behaviour in Python3 and Python2,
whose result will be a int value of floor(x / y) in Python3 and value of
(x / y) in Python2.
Args:
x(int|float) : The number to divide.
y(int|float) : The number to be divided
Returns:
division result of x // y
"""
return x // y return x // y
# exception related functions # exception related functions
def get_exception_message(exc): def get_exception_message(exc):
"""
Get the error message of a specific exception
Args:
exec(Exception) : The exception to get error message.
Returns:
the error message of exec
"""
assert exc is not None
if six.PY2: if six.PY2:
return exc.message return exc.message
else: else:
......
...@@ -20,6 +20,7 @@ from .layer_function_generator import autodoc, templatedoc ...@@ -20,6 +20,7 @@ from .layer_function_generator import autodoc, templatedoc
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from . import tensor from . import tensor
from . import nn from . import nn
from .. import compat as cpt
import math import math
import six import six
from functools import reduce from functools import reduce
...@@ -1104,7 +1105,8 @@ def multi_box_head(inputs, ...@@ -1104,7 +1105,8 @@ def multi_box_head(inputs,
mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1]) mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1])
new_shape = [ new_shape = [
mbox_loc.shape[0], mbox_loc.shape[0],
mbox_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3] / 4, 4 mbox_loc.shape[1] * mbox_loc.shape[2] * cpt.floor_division(mbox_loc.shape[3], 4),
4
] ]
mbox_loc_flatten = nn.reshape(mbox_loc, shape=new_shape) mbox_loc_flatten = nn.reshape(mbox_loc, shape=new_shape)
mbox_locs.append(mbox_loc_flatten) mbox_locs.append(mbox_loc_flatten)
...@@ -1119,8 +1121,9 @@ def multi_box_head(inputs, ...@@ -1119,8 +1121,9 @@ def multi_box_head(inputs,
stride=stride) stride=stride)
conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1]) conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1])
new_shape = [ new_shape = [
conf_loc.shape[0], conf_loc.shape[1] * conf_loc.shape[2] * conf_loc.shape[0],
conf_loc.shape[3] / num_classes, num_classes conf_loc.shape[1] * conf_loc.shape[2] * cpt.floor_division(conf_loc.shape[3], num_classes),
num_classes
] ]
conf_loc_flatten = nn.reshape(conf_loc, shape=new_shape) conf_loc_flatten = nn.reshape(conf_loc, shape=new_shape)
mbox_confs.append(conf_loc_flatten) mbox_confs.append(conf_loc_flatten)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid.compat as cpt
import six
class TestCompatible(unittest.TestCase):
def test_to_literal_str(self):
# Only support python2.x and python3.x now
self.assertTrue(six.PY2 | six.PY3)
if six.PY2:
# check None
self.assertIsNone(cpt.to_literal_str(None))
# check all string related types
self.assertTrue(isinstance(cpt.to_literal_str(str("")), unicode))
self.assertTrue(isinstance(cpt.to_literal_str(str("123")), unicode))
self.assertTrue(isinstance(cpt.to_literal_str(b""), unicode))
self.assertTrue(isinstance(cpt.to_literal_str(b""), unicode))
self.assertTrue(isinstance(cpt.to_literal_str(u""), unicode))
self.assertTrue(isinstance(cpt.to_literal_str(u""), unicode))
self.assertEqual(u"", cpt.to_literal_str(str("")))
self.assertEqual(u"123", cpt.to_literal_str(str("123")))
self.assertEqual(u"", cpt.to_literal_str(b""))
self.assertEqual(u"123", cpt.to_literal_str(b"123"))
self.assertEqual(u"", cpt.to_literal_str(u""))
self.assertEqual(u"123", cpt.to_literal_str(u"123"))
# check list types, not inplace
l = [""]
l2 = cpt.to_literal_str(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual([u""], l2)
l = ["", "123"]
l2 = cpt.to_literal_str(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual([u"", u"123"], l2)
l = ["", b'123', u"321"]
l2 = cpt.to_literal_str(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual([u"", u"123", u"321"], l2)
for i in l2:
self.assertTrue(isinstance(i, unicode))
# check list types, inplace
l = [""]
l2 = cpt.to_literal_str(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([u""], l2)
l = ["", "123"]
l2 = cpt.to_literal_str(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([u"", u"123"], l2)
l = ["", b"123", u"321"]
l2 = cpt.to_literal_str(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([u"", u"123", u"321"], l2)
# check set types, not inplace
l = set("")
l2 = cpt.to_literal_str(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(u""), l2)
l = set([b"", b"123"])
l2 = cpt.to_literal_str(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([u"", u"123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_literal_str(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([u"", u"123", u"321"]), l2)
for i in l2:
self.assertTrue(isinstance(i, unicode))
# check set types, inplace
l = set("")
l2 = cpt.to_literal_str(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(u""), l2)
l = set([b"", b"123"])
l2 = cpt.to_literal_str(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([u"", u"123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_literal_str(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([u"", u"123", u"321"]), l2)
elif six.PY3:
self.assertIsNone(cpt.to_literal_str(None))
self.assertTrue(isinstance(cpt.to_literal_str(str("")), str))
self.assertTrue(isinstance(cpt.to_literal_str(str("123")), str))
self.assertTrue(isinstance(cpt.to_literal_str(b""), str))
self.assertTrue(isinstance(cpt.to_literal_str(b""), str))
self.assertTrue(isinstance(cpt.to_literal_str(u""), str))
self.assertTrue(isinstance(cpt.to_literal_str(u""), str))
self.assertEqual("", cpt.to_literal_str(str("")))
self.assertEqual("123", cpt.to_literal_str(str("123")))
self.assertEqual("", cpt.to_literal_str(b""))
self.assertEqual("123", cpt.to_literal_str(b"123"))
self.assertEqual("", cpt.to_literal_str(u""))
self.assertEqual("123", cpt.to_literal_str(u"123"))
# check list types, not inplace
l = [""]
l2 = cpt.to_literal_str(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual([""], l2)
l = ["", "123"]
l2 = cpt.to_literal_str(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(["", "123"], l2)
l = ["", b"123", u"321"]
l2 = cpt.to_literal_str(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
self.assertEqual(["", "123", "321"], l2)
# check list types, inplace
l = [""]
l2 = cpt.to_literal_str(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([""], l2)
l = ["", b"123"]
l2 = cpt.to_literal_str(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(["", "123"], l2)
l = ["", b"123", u"321"]
l2 = cpt.to_literal_str(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(["", "123", "321"], l2)
for i in l2:
self.assertTrue(isinstance(i, str))
# check set types, not inplace
l = set("")
l2 = cpt.to_literal_str(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(""), l2)
l = set([b"", b"123"])
l2 = cpt.to_literal_str(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
self.assertEqual(set(["", "123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_literal_str(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
self.assertEqual(set(["", "123", "321"]), l2)
# check set types, inplace
l = set("")
l2 = cpt.to_literal_str(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(""), l2)
l = set([b"", b"123"])
l2 = cpt.to_literal_str(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(["", "123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_literal_str(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(["", "123", "321"]), l2)
for i in l2:
self.assertTrue(isinstance(i, str))
def test_to_bytes(self):
# Only support python2.x and python3.x now
self.assertTrue(six.PY2 | six.PY3)
if six.PY2:
# check None
self.assertIsNone(cpt.to_bytes(None))
# check all string related types
self.assertTrue(isinstance(cpt.to_bytes(str("")), bytes))
self.assertTrue(isinstance(cpt.to_bytes(str("123")), bytes))
self.assertTrue(isinstance(cpt.to_bytes(b""), bytes))
self.assertTrue(isinstance(cpt.to_bytes(b""), bytes))
self.assertTrue(isinstance(cpt.to_bytes(u""), bytes))
self.assertTrue(isinstance(cpt.to_bytes(u""), bytes))
self.assertEqual(b"", cpt.to_bytes(str("")))
self.assertEqual(b"123", cpt.to_bytes(str("123")))
self.assertEqual(b"", cpt.to_bytes(b""))
self.assertEqual(b"123", cpt.to_bytes(b"123"))
self.assertEqual(b"", cpt.to_bytes(u""))
self.assertEqual(b"123", cpt.to_bytes(u"123"))
# check list types, not inplace
l = [""]
l2 = cpt.to_bytes(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual([b""], l2)
l = ["", "123"]
l2 = cpt.to_bytes(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual([b"", b"123"], l2)
l = ["", b'123', u"321"]
l2 = cpt.to_bytes(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual([b"", b"123", b"321"], l2)
for i in l2:
self.assertTrue(isinstance(i, bytes))
# check list types, inplace
l = [""]
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([b""], l2)
l = ["", "123"]
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([b"", b"123"], l2)
l = ["", b"123", u"321"]
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([b"", b"123", b"321"], l2)
# check set types, not inplace
l = set("")
l2 = cpt.to_bytes(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(b""), l2)
l = set([b"", b"123"])
l2 = cpt.to_bytes(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([b"", b"123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_bytes(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([b"", b"123", b"321"]), l2)
for i in l2:
self.assertTrue(isinstance(i, bytes))
# check set types, inplace
l = set("")
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(b""), l2)
l = set([b"", b"123"])
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([b"", b"123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([b"", b"123", b"321"]), l2)
elif six.PY3:
self.assertIsNone(cpt.to_bytes(None))
self.assertTrue(isinstance(cpt.to_bytes(str("")), bytes))
self.assertTrue(isinstance(cpt.to_bytes(str("123")), bytes))
self.assertTrue(isinstance(cpt.to_bytes(b""), bytes))
self.assertTrue(isinstance(cpt.to_bytes(b""), bytes))
self.assertTrue(isinstance(cpt.to_bytes(u""), bytes))
self.assertTrue(isinstance(cpt.to_bytes(u""), bytes))
self.assertEqual(b"", cpt.to_bytes(str("")))
self.assertEqual(b"123", cpt.to_bytes(str("123")))
self.assertEqual(b"", cpt.to_bytes(b""))
self.assertEqual(b"123", cpt.to_bytes(b"123"))
self.assertEqual(b"", cpt.to_bytes(u""))
self.assertEqual(b"123", cpt.to_bytes(u"123"))
# check list types, not inplace
l = [""]
l2 = cpt.to_bytes(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
self.assertEqual([b""], l2)
l = ["", "123"]
l2 = cpt.to_bytes(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
self.assertEqual([b"", b"123"], l2)
l = ["", b"123", u"321"]
l2 = cpt.to_bytes(l)
self.assertTrue(isinstance(l2, list))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
self.assertEqual([b"", b"123", b"321"], l2)
# check list types, inplace
l = [""]
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([b""], l2)
l = ["", b"123"]
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([b"", b"123"], l2)
l = ["", b"123", u"321"]
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, list))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual([b"", b"123", b"321"], l2)
for i in l2:
self.assertTrue(isinstance(i, bytes))
# check set types, not inplace
l = set([""])
l2 = cpt.to_bytes(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
self.assertEqual(set([b""]), l2)
l = set([u"", u"123"])
l2 = cpt.to_bytes(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
self.assertEqual(set([b"", b"123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_bytes(l, inplace=False)
self.assertTrue(isinstance(l2, set))
self.assertFalse(l is l2)
self.assertNotEqual(l, l2)
self.assertEqual(set([b"", b"123", b"321"]), l2)
# check set types, inplace
l = set("")
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set(b""), l2)
l = set([u"", u"123"])
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([b"", b"123"]), l2)
l = set(["", b"123", u"321"])
l2 = cpt.to_bytes(l, inplace=True)
self.assertTrue(isinstance(l2, set))
self.assertTrue(l is l2)
self.assertEqual(l, l2)
self.assertEqual(set([b"", b"123", b"321"]), l2)
for i in l2:
self.assertTrue(isinstance(i, bytes))
def test_round(self):
self.assertEqual(3.0, cpt.round(3.4))
self.assertEqual(4.0, cpt.round(3.5))
self.assertEqual(0.0, cpt.round(0.1))
self.assertEqual(-0.0, cpt.round(-0.1))
self.assertEqual(-3.0, cpt.round(-3.4))
self.assertEqual(-4.0, cpt.round(-3.5))
self.assertEqual(5.0, cpt.round(5))
self.assertRaises(TypeError, cpt.round, None)
def test_floor_division(self):
self.assertEqual(0.0, cpt.floor_division(3, 4))
self.assertEqual(1.0, cpt.floor_division(4, 3))
self.assertEqual(2.0, cpt.floor_division(6, 3))
self.assertEqual(-2.0, cpt.floor_division(-4, 3))
self.assertEqual(-2.0, cpt.floor_division(-6, 3))
self.assertRaises(ZeroDivisionError, cpt.floor_division, 3, 0)
self.assertRaises(TypeError, cpt.floor_division, None, None)
def test_get_exception_message(self):
exception_message = "test_message"
self.assertRaises(AssertionError, cpt.get_exception_message, None)
if six.PY2:
self.assertRaises(AttributeError, cpt.get_exception_message, exception_message)
try:
raise RuntimeError(exception_message)
except Exception as e:
self.assertEqual(exception_message, cpt.get_exception_message(e))
self.assertIsNotNone(e)
try:
raise Exception(exception_message)
except Exception as e:
self.assertEqual(exception_message, cpt.get_exception_message(e))
self.assertIsNotNone(e)
if six.PY3:
try:
raise RuntimeError(exception_message)
except Exception as e:
self.assertEqual(exception_message, cpt.get_exception_message(e))
self.assertIsNotNone(e)
try:
raise Exception(exception_message)
except Exception as e:
self.assertEqual(exception_message, cpt.get_exception_message(e))
self.assertIsNotNone(e)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册