diff --git a/python/paddle/fluid/compat.py b/python/paddle/fluid/compat.py index 6d3e7794de47a08415aeb232f63ecc7c84c2c9ab..62826c7ce9d751ef2b2317bf16afa8d5cfe26407 100644 --- a/python/paddle/fluid/compat.py +++ b/python/paddle/fluid/compat.py @@ -16,6 +16,7 @@ import six import math __all__ = [ + 'long_type', 'to_literal_str', 'to_bytes', 'round', @@ -23,6 +24,13 @@ __all__ = [ 'get_exception_message', ] +if six.PY2: + int_type = int + long_type = long +else: + int_type = int + long_type = int + # str and bytes related functions def to_literal_str(obj, encoding='utf-8', inplace=False): diff --git a/python/paddle/fluid/tests/unittests/test_compat.py b/python/paddle/fluid/tests/unittests/test_compat.py index 00216d33e1a2daa99cdcc490a78c973b439c26ab..525789ddb6d2f398607715def9f0ec6cd69e5993 100644 --- a/python/paddle/fluid/tests/unittests/test_compat.py +++ b/python/paddle/fluid/tests/unittests/test_compat.py @@ -18,6 +18,14 @@ import six class TestCompatible(unittest.TestCase): + def test_type(self): + if six.PY2: + self.assertEqual(cpt.int_type, int) + self.assertEqual(cpt.long_type, long) + else: + self.assertEqual(cpt.int_type, int) + self.assertEqual(cpt.long_type, int) + def test_to_literal_str(self): # Only support python2.x and python3.x now self.assertTrue(six.PY2 | six.PY3) diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py index ac25f432dffd544d4b336983ec868f2431a5b91a..a325422c316a55f4aa57bd0770b8ed50f4803f71 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py @@ -17,6 +17,7 @@ import numpy as np from op_test import OpTest import paddle.fluid.core as core from paddle.fluid.op import Operator +import paddle.fluid.compat as cpt class TestLookupTableOp(OpTest): @@ -71,7 +72,7 @@ class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds): flatten_idx = ids.flatten() padding_idx = np.random.choice(flatten_idx, 1)[0] self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31) - self.attrs = {'padding_idx': long(padding_idx)} + self.attrs = {'padding_idx': cpt.long_type(padding_idx)} self.check_output() def test_check_grad(self):