diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py index fc00a816169d8109f9a2dc3c99bcddca59fd0cf3..0424781cf3e4e2c08b7978c77e7f3b63765d8760 100644 --- a/python/paddle/fluid/dygraph/math_op_patch.py +++ b/python/paddle/fluid/dygraph/math_op_patch.py @@ -19,6 +19,9 @@ from ..framework import Variable, convert_np_dtype_to_dtype_ from ..layers.layer_function_generator import OpProtoHolder from . import to_variable, no_grad +import numpy as np +import six + _supported_int_dtype_ = [ core.VarDesc.VarType.UINT8, core.VarDesc.VarType.INT8, @@ -116,6 +119,43 @@ def monkey_patch_math_varbase(): def _neg_(var): return _scalar_elementwise_op_(var, -1.0, 0.0) + def _float_(var): + numel = np.prod(var.shape) + assert numel == 1, "only one element variable can be converted to float." + tensor = var.value().get_tensor() + assert tensor._is_initialized(), "variable's tensor is not initialized" + return float(var.numpy().flatten()[0]) + + def _long_(var): + numel = np.prod(var.shape) + assert numel == 1, "only one element variable can be converted to long." + tensor = var.value().get_tensor() + assert tensor._is_initialized(), "variable's tensor is not initialized" + if six.PY2: + return long(var.numpy().flatten()[0]) + else: + return int(var.numpy().flatten()[0]) + + def _int_(var): + numel = np.prod(var.shape) + assert numel == 1, "only one element variable can be converted to int." + tensor = var.value().get_tensor() + assert tensor._is_initialized(), "variable's tensor is not initialized" + return int(var.numpy().flatten()[0]) + + def _len_(var): + return var.shape[0] + + def _index_(var): + numel = np.prod(var.shape) + assert numel == 1, "only one element variable can be converted to python index." + tensor = var.value().get_tensor() + assert tensor._is_initialized(), "variable's tensor is not initialized" + if six.PY2: + return long(var.numpy().flatten()[0]) + else: + return int(var.numpy().flatten()[0]) + def _scalar_elementwise_add_(var, value): return _scalar_elementwise_op_(var, 1.0, value) @@ -220,4 +260,9 @@ def monkey_patch_math_varbase(): # b = -a core.VarBase.__neg__ = _neg_ + core.VarBase.__float__ = _float_ + core.VarBase.__long__ = _long_ + core.VarBase.__int__ = _int_ + core.VarBase.__len__ = _len_ + core.VarBase.__index__ = _index_ core.VarBase.astype = astype diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py b/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py index 2f454b174bf8dbd0772e77bb12d9b81eaef0c696..4a967d979646d23be4ccd7b75fc09adda76b80a5 100644 --- a/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py @@ -18,6 +18,7 @@ import unittest from decorator_helper import prog_scope import paddle.fluid as fluid import numpy as np +import six class TestMathOpPatchesVarBase(unittest.TestCase): @@ -208,6 +209,34 @@ class TestMathOpPatchesVarBase(unittest.TestCase): res = -a self.assertTrue(np.array_equal(res.numpy(), -a_np)) + def test_float_int_long(self): + with fluid.dygraph.guard(): + a = fluid.dygraph.to_variable(np.array([100.1])) + self.assertTrue(float(a) == 100.1) + self.assertTrue(int(a) == 100) + if six.PY2: + self.assertTrue(long(a) == 100) + else: + self.assertTrue(int(a) == 100) + + def test_len(self): + a_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype) + with fluid.dygraph.guard(): + a = fluid.dygraph.to_variable(a_np) + self.assertTrue(len(a) == 10) + + def test_index(self): + with fluid.dygraph.guard(): + var1 = fluid.dygraph.to_variable(np.array([2])) + i_tmp = 0 + for i in range(var1): + self.assertTrue(i == i_tmp) + i_tmp = i_tmp + 1 + list1 = [1, 2, 3, 4, 5] + self.assertTrue(list1[var1] == 3) + str1 = "just test" + self.assertTrue(str1[var1] == 's') + if __name__ == '__main__': unittest.main()