Unverified commit dc24f38a authored by: Z zhongpu, committed by: GitHub

support math operator for variable (#23063)

* support math operator for variable, test=develop

* polish code, test=develop

* polish code, test=develop
Parent: baec0a07
...@@ -19,6 +19,9 @@ from ..framework import Variable, convert_np_dtype_to_dtype_ ...@@ -19,6 +19,9 @@ from ..framework import Variable, convert_np_dtype_to_dtype_
from ..layers.layer_function_generator import OpProtoHolder from ..layers.layer_function_generator import OpProtoHolder
from . import to_variable, no_grad from . import to_variable, no_grad
import sys

import numpy as np
import six
_supported_int_dtype_ = [ _supported_int_dtype_ = [
core.VarDesc.VarType.UINT8, core.VarDesc.VarType.UINT8,
core.VarDesc.VarType.INT8, core.VarDesc.VarType.INT8,
...@@ -116,6 +119,43 @@ def monkey_patch_math_varbase(): ...@@ -116,6 +119,43 @@ def monkey_patch_math_varbase():
def _neg_(var): def _neg_(var):
return _scalar_elementwise_op_(var, -1.0, 0.0) return _scalar_elementwise_op_(var, -1.0, 0.0)
def _float_(var):
numel = np.prod(var.shape)
assert numel == 1, "only one element variable can be converted to float."
tensor = var.value().get_tensor()
assert tensor._is_initialized(), "variable's tensor is not initialized"
return float(var.numpy().flatten()[0])
def _long_(var):
numel = np.prod(var.shape)
assert numel == 1, "only one element variable can be converted to long."
tensor = var.value().get_tensor()
assert tensor._is_initialized(), "variable's tensor is not initialized"
if six.PY2:
return long(var.numpy().flatten()[0])
else:
return int(var.numpy().flatten()[0])
def _int_(var):
numel = np.prod(var.shape)
assert numel == 1, "only one element variable can be converted to int."
tensor = var.value().get_tensor()
assert tensor._is_initialized(), "variable's tensor is not initialized"
return int(var.numpy().flatten()[0])
def _len_(var):
return var.shape[0]
def _index_(var):
numel = np.prod(var.shape)
assert numel == 1, "only one element variable can be converted to python index."
tensor = var.value().get_tensor()
assert tensor._is_initialized(), "variable's tensor is not initialized"
if six.PY2:
return long(var.numpy().flatten()[0])
else:
return int(var.numpy().flatten()[0])
def _scalar_elementwise_add_(var, value): def _scalar_elementwise_add_(var, value):
return _scalar_elementwise_op_(var, 1.0, value) return _scalar_elementwise_op_(var, 1.0, value)
...@@ -220,4 +260,9 @@ def monkey_patch_math_varbase(): ...@@ -220,4 +260,9 @@ def monkey_patch_math_varbase():
# b = -a # b = -a
core.VarBase.__neg__ = _neg_ core.VarBase.__neg__ = _neg_
# Register the Python scalar-conversion and sequence protocols on VarBase
# so float(x), long(x), int(x), len(x) and index use (e.g. range(x)) work.
core.VarBase.__float__ = _float_
core.VarBase.__long__ = _long_
core.VarBase.__int__ = _int_
core.VarBase.__len__ = _len_
core.VarBase.__index__ = _index_
core.VarBase.astype = astype core.VarBase.astype = astype
...@@ -18,6 +18,7 @@ import unittest ...@@ -18,6 +18,7 @@ import unittest
from decorator_helper import prog_scope from decorator_helper import prog_scope
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
import six
class TestMathOpPatchesVarBase(unittest.TestCase): class TestMathOpPatchesVarBase(unittest.TestCase):
...@@ -208,6 +209,34 @@ class TestMathOpPatchesVarBase(unittest.TestCase): ...@@ -208,6 +209,34 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
res = -a res = -a
self.assertTrue(np.array_equal(res.numpy(), -a_np)) self.assertTrue(np.array_equal(res.numpy(), -a_np))
def test_float_int_long(self):
    """Scalar variables should coerce via float()/int()/long()."""
    with fluid.dygraph.guard():
        scalar = fluid.dygraph.to_variable(np.array([100.1]))
        self.assertTrue(float(scalar) == 100.1)
        self.assertTrue(int(scalar) == 100)
        # long() only exists on Python 2; fall back to int() elsewhere.
        converter = long if six.PY2 else int
        self.assertTrue(converter(scalar) == 100)
def test_len(self):
    """len() of a variable equals the size of its first dimension."""
    a_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
    with fluid.dygraph.guard():
        a = fluid.dygraph.to_variable(a_np)
        # Compare against the configured shape rather than a hard-coded 10
        # so the assertion stays valid if self.shape ever changes.
        self.assertTrue(len(a) == self.shape[0])
def test_index(self):
    """A scalar variable should be usable wherever Python expects an index."""
    with fluid.dygraph.guard():
        idx_var = fluid.dygraph.to_variable(np.array([2]))
        # range() accepts the variable through __index__.
        expected = 0
        for value in range(idx_var):
            self.assertTrue(value == expected)
            expected += 1
        # List and string subscripting also go through __index__.
        sample_list = [1, 2, 3, 4, 5]
        self.assertTrue(sample_list[idx_var] == 3)
        sample_str = "just test"
        self.assertTrue(sample_str[idx_var] == 's')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册