提交 6988aea9 编写于 作者: M minqiyang

Fix long type in Python3

上级 fae5c1f5
...@@ -16,6 +16,7 @@ import six ...@@ -16,6 +16,7 @@ import six
import math import math
__all__ = [ __all__ = [
'long_type',
'to_literal_str', 'to_literal_str',
'to_bytes', 'to_bytes',
'round', 'round',
...@@ -23,6 +24,13 @@ __all__ = [ ...@@ -23,6 +24,13 @@ __all__ = [
'get_exception_message', 'get_exception_message',
] ]
if six.PY2:
int_type = int
long_type = long
else:
int_type = int
long_type = int
# str and bytes related functions # str and bytes related functions
def to_literal_str(obj, encoding='utf-8', inplace=False): def to_literal_str(obj, encoding='utf-8', inplace=False):
......
...@@ -18,6 +18,14 @@ import six ...@@ -18,6 +18,14 @@ import six
class TestCompatible(unittest.TestCase): class TestCompatible(unittest.TestCase):
def test_type(self):
if six.PY2:
self.assertEqual(cpt.int_type, int)
self.assertEqual(cpt.long_type, long)
else:
self.assertEqual(cpt.int_type, int)
self.assertEqual(cpt.long_type, int)
def test_to_literal_str(self): def test_to_literal_str(self):
# Only support python2.x and python3.x now # Only support python2.x and python3.x now
self.assertTrue(six.PY2 | six.PY3) self.assertTrue(six.PY2 | six.PY3)
......
...@@ -17,6 +17,7 @@ import numpy as np ...@@ -17,6 +17,7 @@ import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.op import Operator from paddle.fluid.op import Operator
import paddle.fluid.compat as cpt
class TestLookupTableOp(OpTest): class TestLookupTableOp(OpTest):
...@@ -71,7 +72,7 @@ class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds): ...@@ -71,7 +72,7 @@ class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds):
flatten_idx = ids.flatten() flatten_idx = ids.flatten()
padding_idx = np.random.choice(flatten_idx, 1)[0] padding_idx = np.random.choice(flatten_idx, 1)[0]
self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31) self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
self.attrs = {'padding_idx': long(padding_idx)} self.attrs = {'padding_idx': cpt.long_type(padding_idx)}
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册