提交 6988aea9 编写于 作者: M minqiyang

Fix long type in Python3

上级 fae5c1f5
......@@ -16,6 +16,7 @@ import six
import math
__all__ = [
'long_type',
'to_literal_str',
'to_bytes',
'round',
......@@ -23,6 +24,13 @@ __all__ = [
'get_exception_message',
]
if six.PY2:
int_type = int
long_type = long
else:
int_type = int
long_type = int
# str and bytes related functions
def to_literal_str(obj, encoding='utf-8', inplace=False):
......
......@@ -18,6 +18,14 @@ import six
class TestCompatible(unittest.TestCase):
def test_type(self):
if six.PY2:
self.assertEqual(cpt.int_type, int)
self.assertEqual(cpt.long_type, long)
else:
self.assertEqual(cpt.int_type, int)
self.assertEqual(cpt.long_type, int)
def test_to_literal_str(self):
# Only support python2.x and python3.x now
self.assertTrue(six.PY2 | six.PY3)
......
......@@ -17,6 +17,7 @@ import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid.compat as cpt
class TestLookupTableOp(OpTest):
......@@ -71,7 +72,7 @@ class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds):
flatten_idx = ids.flatten()
padding_idx = np.random.choice(flatten_idx, 1)[0]
self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
self.attrs = {'padding_idx': long(padding_idx)}
self.attrs = {'padding_idx': cpt.long_type(padding_idx)}
self.check_output()
def test_check_grad(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册