未验证 提交 ce7d5263 编写于 作者: Z Zhou Wei 提交者: GitHub

[2.0API]Bind method for tensor and Variable (#26416)

* binding tensor method

* binding tensor method

* binding tensor method

* Binding methods for class Tensor and Variable
上级 e662d1e0
......@@ -31,6 +31,10 @@ import paddle.reader
import paddle.dataset
import paddle.batch
batch = batch.batch
from .fluid import monkey_patch_variable
from .fluid.dygraph import monkey_patch_math_varbase
monkey_patch_variable()
monkey_patch_math_varbase()
import paddle.framework
from .framework import VarBase as Tensor
from .framework import ComplexVariable as ComplexTensor
......
......@@ -59,6 +59,8 @@ from .rnn import *
from . import amp
from .amp import *
from .math_op_patch import monkey_patch_math_varbase
__all__ = []
__all__ += layers.__all__
__all__ += base.__all__
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
from .. import core
from ..framework import Variable, convert_np_dtype_to_dtype_, _varbase_creator
from ..layers.layer_function_generator import OpProtoHolder
from ..layers import common_methods
from . import to_variable, no_grad
import numpy as np
......@@ -30,6 +31,8 @@ _supported_int_dtype_ = [
core.VarDesc.VarType.INT64,
]
_already_patch_varbase = False
def monkey_patch_math_varbase():
"""
......@@ -140,22 +143,27 @@ def monkey_patch_math_varbase():
else:
return int(var.numpy().flatten()[0])
def _scalar_elementwise_add_(var, value):
@property
def _ndim_(var):
return len(var.shape)
def _scalar_add_(var, value):
return _scalar_elementwise_op_(var, 1.0, value)
def _scalar_elementwise_sub_(var, value):
def _scalar_sub_(var, value):
return _scalar_elementwise_op_(var, 1.0, -value)
def _scalar_elementwise_rsub_(var, value):
def _scalar_rsub_(var, value):
return _scalar_elementwise_op_(var, -1.0, value)
def _scalar_elementwise_mul_(var, value):
def _scalar_mul_(var, value):
return _scalar_elementwise_op_(var, value, 0.0)
def _scalar_elementwise_div_(var, value):
def _scalar_div_(var, value):
return _scalar_elementwise_op_(var, 1.0 / value, 0.0)
def _elemwise_method_creator_(method_name,
# for binary operator such as elementwise, compare
def _binary_creator_(method_name,
op_type,
reverse=False,
scalar_method=None):
......@@ -200,60 +208,119 @@ def monkey_patch_math_varbase():
__impl__.__doc__ = """
{0}
Args:
self(Variable): left hand variable
other_var(Variable|float|int): right hand variable
self(Tensor): left hand Tensor
other_var(Tensor|float|int): right hand Tensor
Returns:
Variable
Tensor
""".format(comment)
__impl__.__name__ = method_name
return __impl__
# inject methods
for method_name, op_type, reverse, scalar_method in (
("__add__", "elementwise_add", False, _scalar_elementwise_add_),
# a+b == b+a. Do not need to reverse explicitly
("__radd__", "elementwise_add", False, _scalar_elementwise_add_),
("__sub__", "elementwise_sub", False, _scalar_elementwise_sub_),
("__rsub__", "elementwise_sub", True, _scalar_elementwise_rsub_),
("__mul__", "elementwise_mul", False, _scalar_elementwise_mul_),
# a*b == b*a. Do not need to reverse explicitly
("__rmul__", "elementwise_mul", False, _scalar_elementwise_mul_),
("__div__", "elementwise_div", False, _scalar_elementwise_div_),
("__truediv__", "elementwise_div", False, _scalar_elementwise_div_),
("__rdiv__", "elementwise_div", True, None),
("__rtruediv__", "elementwise_div", True, None),
("__pow__", "elementwise_pow", False, None),
("__rpow__", "elementwise_pow", True, None),
("__floordiv__", "elementwise_floordiv", False, None),
("__mod__", "elementwise_mod", False, None),
# for logical compare
("__eq__", "equal", False, None),
("__ne__", "not_equal", False, None),
("__lt__", "less_than", False, None),
("__le__", "less_equal", False, None),
("__gt__", "greater_than", False, None),
("__ge__", "greater_equal", False, None)):
setattr(core.VarBase, method_name,
_elemwise_method_creator_(method_name, op_type, reverse,
scalar_method))
# b = -a
core.VarBase.__neg__ = _neg_
core.VarBase.__float__ = _float_
core.VarBase.__long__ = _long_
core.VarBase.__int__ = _int_
core.VarBase.__len__ = _len_
core.VarBase.__index__ = _index_
core.VarBase.astype = astype
"""
When code is written like this
y = np.pi * var
ndarray.__mul__(self, var) is called, var will be traced as an array(by using __len__, __getitem__), which is not right.
when var.__array_ufunc__ is set to None, var.__rmul__(self, np) will be called.
# Todo(zhouwei): implement dygraph template to adapt to any function, receive('op_type', 'arg_template')
# Such as _method_creator_('addmm', 'x, y, alpha=1.0, beta=1.0, name=None'). It can reduce call time.
def _method_creator_(op_type, arg_template=None):
def __impl__(self):
op = getattr(core.ops, op_type)
return op(self)
The details can be seen bellow:
https://docs.scipy.org/doc/numpy-1.13.0/neps/ufunc-overrides.html#behavior-in-combination-with-python-s-binary-operations
"""
core.VarBase.__array_ufunc__ = None
__impl__.__doc__ = """
See paddle.{}""".format(op_type)
__impl__.__name__ = op_type
return __impl__
varbase_methods = [
# Type1: From custom fun or lambda
## b=-a
('__neg__', _neg_),
('__float__', _float_),
('__long__', _long_),
('__int__', _int_),
('__len__', _len_),
('__index__', _index_),
('astype', astype),
('dim', lambda x: len(x.shape)),
('ndimension', lambda x: len(x.shape)),
('ndim', _ndim_),
('size', lambda x: x.shape),
# Type2: From Template that create core.ops automatically. It's recommended.
('__add__',
_binary_creator_('__add__', 'elementwise_add', False, _scalar_add_)),
## a+b == b+a. Do not need to reverse explicitly
('__radd__',
_binary_creator_('__radd__', 'elementwise_add', False, _scalar_add_)),
('__sub__', _binary_creator_('__sub__', 'elementwise_sub', False,
_scalar_sub_)),
('__rsub__', _binary_creator_('__rsub__', 'elementwise_sub', True,
_scalar_rsub_)),
('__mul__', _binary_creator_('__mul__', 'elementwise_mul', False,
_scalar_mul_)),
## a*b == b*a. Do not need to reverse explicitly
('__rmul__',
_binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)),
('__div__', _binary_creator_('__div__', 'elementwise_div', False,
_scalar_div_)),
('__truediv__', _binary_creator_('__truediv__', 'elementwise_div',
False, _scalar_div_)),
('__rdiv__', _binary_creator_('__rdiv__', 'elementwise_div', True,
None)),
('__rtruediv__', _binary_creator_('rtruediv__', 'elementwise_div', True,
None)),
('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False,
None)),
('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True,
None)),
('__floordiv__', _binary_creator_('__floordiv__',
'elementwise_floordiv', False, None)),
('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False,
None)),
## for logical compare
('__eq__', _binary_creator_('__eq__', 'equal', False, None)),
('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)),
('__lt__', _binary_creator_('__lt__', 'less_than', False, None)),
('__le__', _binary_creator_('__le__', 'less_equal', False, None)),
('__gt__', _binary_creator_('__gt__', 'greater_than', False, None)),
('__ge__', _binary_creator_('__ge__', 'greater_equal', False, None)),
('__array_ufunc__', None),
('sigmoid', _method_creator_('sigmoid', 'name=None')),
('logsigmoid', _method_creator_('logsigmoid', 'name=None')),
('exp', _method_creator_('exp', 'name=None')),
('tanh', _method_creator_('tanh', 'name=None')),
('atan', _method_creator_('atan', 'name=None')),
('tanh_shrink', _method_creator_('tanh_shrink', 'name=None')),
('sqrt', _method_creator_('sqrt', 'name=None')),
('rsqrt', _method_creator_('rsqrt', 'name=None')),
('abs', _method_creator_('abs', 'name=None')),
('ceil', _method_creator_('ceil', 'name=None')),
('floor', _method_creator_('floor', 'name=None')),
('cos', _method_creator_('cos', 'name=None')),
('acos', _method_creator_('acos', 'name=None')),
('asin', _method_creator_('asin', 'name=None')),
('sin', _method_creator_('sin', 'name=None')),
('sinh', _method_creator_('sinh', 'name=None')),
('cosh', _method_creator_('cosh', 'name=None')),
('round', _method_creator_('round', 'name=None')),
('reciprocal', _method_creator_('reciprocal', 'name=None')),
('square', _method_creator_('square', 'name=None')),
('softplus', _method_creator_('softplus', 'name=None')),
('softsign', _method_creator_('softsign', 'name=None')),
# Type3: Form module 'paddle.tensor' defaultly.
# It's not a goodway, because it will increase call time.
]
global _already_patch_varbase
if not _already_patch_varbase:
for method in varbase_methods:
method_name = method[0]
method_impl = method[1]
setattr(core.VarBase, method_name, method_impl)
else:
import paddle.tensor
for method_name in common_methods:
if hasattr(core.VarBase, method_name): continue
method_impl = getattr(paddle.tensor, method_name, None)
if method_impl: setattr(core.VarBase, method_name, method_impl)
_already_patch_varbase = True
......@@ -54,6 +54,31 @@ EXPRESSION_MAP = {
"__ge__": "A >= B"
}
# method for Tensor from paddle.tensor
# edit it when paddle.tensor has new method about Tensor operation
common_methods = [
'exp', 'tanh', 'atan', 'sqrt', 'rsqrt', 'abs', 'ceil', 'floor', 'cos',
'acos', 'asin', 'sin', 'sinh', 'cosh', 'round', 'reciprocal', 'square',
'rank', 'matmul', 'dot', 'norm', 'transpose', 'dist', 't', 'cross',
'cholesky', 'bmm', 'histogram', 'equal', 'greater_equal', 'greater_than',
'is_empty', 'isfinite', 'less_equal', 'less_than', 'logical_and',
'logical_not', 'logical_or', 'logical_xor', 'not_equal', 'reduce_all',
'reduce_any', 'allclose', 'equal_all', 'cast', 'expand', 'expand_as',
'tile', 'flatten', 'gather', 'gather_nd', 'reshape', 'reverse', 'scatter',
'scatter_nd_add', 'scatter_nd', 'shard_index', 'slice', 'split', 'squeeze',
'strided_slice', 'unique', 'unique_with_counts', 'unsqueeze', 'flip',
'unbind', 'roll', 'cumsum', 'increment', 'log', 'pow', 'reciprocal',
'round', 'rsqrt', 'scale', 'sign', 'stanh', 'sum', 'reduce_prod', 'max',
'min', 'mm', 'div', 'multiply', 'add', 'logsumexp', 'log1p', 'erf',
'addcmul', 'addmm', 'clamp', 'trace', 'kron', 'argmax', 'argmin', 'argsort',
'has_inf', 'has_nan', 'topk', 'index_select', 'nonzero', 'sort',
'index_sample', 'mean', 'std', 'var', 'elementwise_add', 'elementwise_div',
'elementwise_floordiv', 'elementwise_mod', 'elementwise_pow',
'elementwise_sub'
]
_already_patch_variable = False
def monkey_patch_variable():
def unique_tmp_name():
......@@ -179,7 +204,7 @@ def monkey_patch_variable():
"out_dtype": out.dtype})
return out
def _scalar_elementwise_op_(var, scale, bias):
def _scalar_op_(var, scale, bias):
block = current_block(var)
out = create_new_tmp_var(block, var.dtype)
block.append_op(
......@@ -191,24 +216,24 @@ def monkey_patch_variable():
return out
def _neg_(var):
return _scalar_elementwise_op_(var, -1.0, 0.0)
return _scalar_op_(var, -1.0, 0.0)
def _scalar_elementwise_add_(var, value):
return _scalar_elementwise_op_(var, 1.0, value)
def _scalar_add_(var, value):
return _scalar_op_(var, 1.0, value)
def _scalar_elementwise_sub_(var, value):
return _scalar_elementwise_op_(var, 1.0, -value)
def _scalar_sub_(var, value):
return _scalar_op_(var, 1.0, -value)
def _scalar_elementwise_rsub_(var, value):
return _scalar_elementwise_op_(var, -1.0, value)
def _scalar_rsub_(var, value):
return _scalar_op_(var, -1.0, value)
def _scalar_elementwise_mul_(var, value):
return _scalar_elementwise_op_(var, value, 0.0)
def _scalar_mul_(var, value):
return _scalar_op_(var, value, 0.0)
def _scalar_elementwise_div_(var, value):
return _scalar_elementwise_op_(var, 1.0 / value, 0.0)
def _scalar_div_(var, value):
return _scalar_op_(var, 1.0 / value, 0.0)
def _elemwise_method_creator_(method_name,
def _binary_creator_(method_name,
op_type,
reverse=False,
scalar_method=None):
......@@ -296,35 +321,60 @@ def monkey_patch_variable():
__impl__.__name__ = method_name
return __impl__
# inject methods
for method_name, op_type, reverse, scalar_method in (
("__add__", "elementwise_add", False, _scalar_elementwise_add_),
variable_methods = [
# b=-a
('__neg__', _neg_),
('astype', astype),
('__add__', _binary_creator_('__add__', 'elementwise_add', False,
_scalar_add_)),
# a+b == b+a. Do not need to reverse explicitly
("__radd__", "elementwise_add", False, _scalar_elementwise_add_),
("__sub__", "elementwise_sub", False, _scalar_elementwise_sub_),
("__rsub__", "elementwise_sub", True, _scalar_elementwise_rsub_),
("__mul__", "elementwise_mul", False, _scalar_elementwise_mul_),
('__radd__',
_binary_creator_('__radd__', 'elementwise_add', False, _scalar_add_)),
('__sub__', _binary_creator_('__sub__', 'elementwise_sub', False,
_scalar_sub_)),
('__rsub__', _binary_creator_('__rsub__', 'elementwise_sub', True,
_scalar_rsub_)),
('__mul__', _binary_creator_('__mul__', 'elementwise_mul', False,
_scalar_mul_)),
# a*b == b*a. Do not need to reverse explicitly
("__rmul__", "elementwise_mul", False, _scalar_elementwise_mul_),
("__div__", "elementwise_div", False, _scalar_elementwise_div_),
("__truediv__", "elementwise_div", False, _scalar_elementwise_div_),
("__rdiv__", "elementwise_div", True, None),
("__rtruediv__", "elementwise_div", True, None),
("__pow__", "elementwise_pow", False, None),
("__rpow__", "elementwise_pow", True, None),
("__floordiv__", "elementwise_floordiv", False, None),
("__mod__", "elementwise_mod", False, None),
('__rmul__',
_binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)),
('__div__', _binary_creator_('__div__', 'elementwise_div', False,
_scalar_div_)),
('__truediv__', _binary_creator_('__truediv__', 'elementwise_div',
False, _scalar_div_)),
('__rdiv__', _binary_creator_('__rdiv__', 'elementwise_div', True,
None)),
('__rtruediv__', _binary_creator_('__rtruediv__', 'elementwise_div',
True, None)),
('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False,
None)),
('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True,
None)),
('__floordiv__', _binary_creator_('__floordiv__',
'elementwise_floordiv', False, None)),
('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False,
None)),
# for logical compare
("__eq__", "equal", False, None),
("__ne__", "not_equal", False, None),
("__lt__", "less_than", False, None),
("__le__", "less_equal", False, None),
("__gt__", "greater_than", False, None),
("__ge__", "greater_equal", False, None)):
setattr(Variable, method_name,
_elemwise_method_creator_(method_name, op_type, reverse,
scalar_method))
# b = -a
Variable.__neg__ = _neg_
Variable.astype = astype
('__eq__', _binary_creator_('__eq__', 'equal', False, None)),
('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)),
('__lt__', _binary_creator_('__lt__', 'less_than', False, None)),
('__le__', _binary_creator_('__le__', 'less_equal', False, None)),
('__gt__', _binary_creator_('__gt__', 'greater_than', False, None)),
('__ge__', _binary_creator_('__ge__', 'greater_equal', False, None))
]
global _already_patch_variable
if not _already_patch_variable:
for method in variable_methods:
method_name = method[0]
method_impl = method[1]
setattr(Variable, method_name, method_impl)
else:
import paddle.tensor
for method_name in common_methods:
if hasattr(Variable, method_name): continue
method_impl = getattr(paddle.tensor, method_name, None)
if method_impl: setattr(Variable, method_name, method_impl)
_already_patch_variable = True
......@@ -15,6 +15,7 @@
from __future__ import print_function
import unittest
import paddle
import paddle.fluid as fluid
import numpy as np
import six
......@@ -284,6 +285,223 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
self.assertEqual((a != b).dtype, fluid.core.VarDesc.VarType.BOOL)
self.assertTrue(np.array_equal((a != b).numpy(), a_np != b_np))
def test_tensor_patch_method(self):
paddle.disable_static()
x_np = np.random.uniform(-1, 1, [2, 3]).astype(self.dtype)
y_np = np.random.uniform(-1, 1, [2, 3]).astype(self.dtype)
z_np = np.random.uniform(-1, 1, [6, 9]).astype(self.dtype)
x = paddle.to_tensor(x_np)
y = paddle.to_tensor(y_np)
z = paddle.to_tensor(z_np)
a = paddle.to_tensor([[1, 1], [2, 2], [3, 3]])
b = paddle.to_tensor([[1, 1], [2, 2], [3, 3]])
# 1. Unary operation for Tensor
self.assertEqual(x.dim(), 2)
self.assertEqual(x.ndimension(), 2)
self.assertEqual(x.ndim, 2)
self.assertEqual(x.size(), [2, 3])
self.assertTrue(
np.array_equal(x.sigmoid().numpy(), fluid.layers.sigmoid(x).numpy(
)))
self.assertTrue(
np.array_equal(x.logsigmoid().numpy(),
fluid.layers.logsigmoid(x).numpy()))
self.assertTrue(np.array_equal(x.exp().numpy(), paddle.exp(x).numpy()))
self.assertTrue(
np.array_equal(x.tanh().numpy(), paddle.tanh(x).numpy()))
self.assertTrue(
np.array_equal(x.atan().numpy(), paddle.atan(x).numpy()))
self.assertTrue(
np.array_equal(x.tanh_shrink().numpy(),
fluid.layers.tanh_shrink(x).numpy()))
self.assertTrue(np.array_equal(x.abs().numpy(), paddle.abs(x).numpy()))
m = x.abs()
self.assertTrue(
np.array_equal(m.sqrt().numpy(), paddle.sqrt(m).numpy()))
self.assertTrue(
np.array_equal(m.rsqrt().numpy(), paddle.rsqrt(m).numpy()))
self.assertTrue(
np.array_equal(x.ceil().numpy(), paddle.ceil(x).numpy()))
self.assertTrue(
np.array_equal(x.floor().numpy(), paddle.floor(x).numpy()))
self.assertTrue(np.array_equal(x.cos().numpy(), paddle.cos(x).numpy()))
self.assertTrue(
np.array_equal(x.acos().numpy(), paddle.acos(x).numpy()))
self.assertTrue(
np.array_equal(x.asin().numpy(), paddle.asin(x).numpy()))
self.assertTrue(np.array_equal(x.sin().numpy(), paddle.sin(x).numpy()))
self.assertTrue(
np.array_equal(x.sinh().numpy(), paddle.sinh(x).numpy()))
self.assertTrue(
np.array_equal(x.cosh().numpy(), paddle.cosh(x).numpy()))
self.assertTrue(
np.array_equal(x.round().numpy(), paddle.round(x).numpy()))
self.assertTrue(
np.array_equal(x.reciprocal().numpy(), paddle.reciprocal(x).numpy(
)))
self.assertTrue(
np.array_equal(x.square().numpy(), paddle.square(x).numpy()))
self.assertTrue(
np.array_equal(x.softplus().numpy(),
fluid.layers.softplus(x).numpy()))
self.assertTrue(
np.array_equal(x.softsign().numpy(),
fluid.layers.softsign(x).numpy()))
self.assertTrue(
np.array_equal(x.rank().numpy(), paddle.rank(x).numpy()))
self.assertTrue(
np.array_equal(x[0].t().numpy(), paddle.t(x[0]).numpy()))
m = paddle.to_tensor(np.random.uniform(1, 2, [3, 3]), 'float32')
m = m.matmul(m.t())
self.assertTrue(
np.array_equal(m.cholesky().numpy(), paddle.cholesky(m).numpy()))
self.assertTrue(
np.array_equal(x.is_empty().numpy(), paddle.is_empty(x).numpy()))
self.assertTrue(
np.array_equal(x.isfinite().numpy(), paddle.isfinite(x).numpy()))
self.assertTrue(
np.array_equal(
x.cast('int32').numpy(), paddle.cast(x, 'int32').numpy()))
self.assertTrue(
np.array_equal(
x.expand([3, 2, 3]).numpy(),
paddle.expand(x, [3, 2, 3]).numpy()))
self.assertTrue(
np.array_equal(
x.tile([2, 2]).numpy(), paddle.tile(x, [2, 2]).numpy()))
self.assertTrue(
np.array_equal(x.flatten().numpy(), paddle.flatten(x).numpy()))
index = paddle.to_tensor([0, 1])
self.assertTrue(
np.array_equal(
x.gather(index).numpy(), paddle.gather(x, index).numpy()))
index = paddle.to_tensor([[0, 1], [1, 2]])
self.assertTrue(
np.array_equal(
x.gather_nd(index).numpy(), paddle.gather_nd(x, index).numpy()))
self.assertTrue(
np.array_equal(
x.reverse([0, 1]).numpy(), paddle.reverse(x, [0, 1]).numpy()))
self.assertTrue(
np.array_equal(
a.reshape([3, 2]).numpy(), paddle.reshape(a, [3, 2]).numpy()))
self.assertTrue(
np.array_equal(
x.slice([0, 1], [0, 0], [1, 2]).numpy(),
paddle.slice(x, [0, 1], [0, 0], [1, 2]).numpy()))
self.assertTrue(
np.array_equal(
x.split(2)[0].numpy(), paddle.split(x, 2)[0].numpy()))
m = paddle.to_tensor(
np.random.uniform(-1, 1, [1, 6, 1, 1]).astype(self.dtype))
self.assertTrue(
np.array_equal(
m.squeeze([]).numpy(), paddle.squeeze(m, []).numpy()))
self.assertTrue(
np.array_equal(
m.squeeze([1, 2]).numpy(), paddle.squeeze(m, [1, 2]).numpy()))
m = paddle.to_tensor([2, 3, 3, 1, 5, 3], 'float32')
self.assertTrue(
np.array_equal(m.unique()[0].numpy(), paddle.unique(m)[0].numpy()))
self.assertTrue(
np.array_equal(m.unique_with_counts()[2],
paddle.unique_with_counts(m)[2]))
self.assertTrue(np.array_equal(x.flip([0]), paddle.flip(x, [0])))
self.assertTrue(np.array_equal(x.unbind(0), paddle.unbind(x, 0)))
self.assertTrue(np.array_equal(x.roll(1), paddle.roll(x, 1)))
self.assertTrue(np.array_equal(x.cumsum(1), paddle.cumsum(x, 1)))
m = paddle.to_tensor(1)
self.assertTrue(np.array_equal(m.increment(), paddle.increment(m)))
m = x.abs()
self.assertTrue(np.array_equal(m.log(), paddle.log(m)))
self.assertTrue(np.array_equal(x.pow(2), paddle.pow(x, 2)))
self.assertTrue(np.array_equal(x.reciprocal(), paddle.reciprocal(x)))
# 2. Binary operation
self.assertTrue(
np.array_equal(
x.matmul(y, True, False).numpy(),
paddle.matmul(x, y, True, False).numpy()))
self.assertTrue(
np.array_equal(
x.norm(
p='fro', axis=[0, 1]).numpy(),
paddle.norm(
x, p='fro', axis=[0, 1]).numpy()))
self.assertTrue(
np.array_equal(x.dist(y).numpy(), paddle.dist(x, y).numpy()))
self.assertTrue(
np.array_equal(x.cross(y).numpy(), paddle.cross(x, y).numpy()))
m = x.expand([2, 2, 3])
n = y.expand([2, 2, 3]).transpose([0, 2, 1])
self.assertTrue(
np.array_equal(m.bmm(n).numpy(), paddle.bmm(m, n).numpy()))
self.assertTrue(
np.array_equal(
x.histogram(5, -1, 1).numpy(),
paddle.histogram(x, 5, -1, 1).numpy()))
self.assertTrue(
np.array_equal(x.equal(y).numpy(), paddle.equal(x, y).numpy()))
self.assertTrue(
np.array_equal(
x.greater_equal(y).numpy(), paddle.greater_equal(x, y).numpy()))
self.assertTrue(
np.array_equal(
x.greater_than(y).numpy(), paddle.greater_than(x, y).numpy()))
self.assertTrue(
np.array_equal(
x.less_equal(y).numpy(), paddle.less_equal(x, y).numpy()))
self.assertTrue(
np.array_equal(
x.less_than(y).numpy(), paddle.less_than(x, y).numpy()))
self.assertTrue(
np.array_equal(
x.not_equal(y).numpy(), paddle.not_equal(x, y).numpy()))
self.assertTrue(
np.array_equal(
x.equal_all(y).numpy(), paddle.equal_all(x, y).numpy()))
self.assertTrue(
np.array_equal(
x.allclose(y).numpy(), paddle.allclose(x, y).numpy()))
m = x.expand([2, 2, 3])
self.assertTrue(
np.array_equal(
x.expand_as(m).numpy(), paddle.expand_as(x, m).numpy()))
index = paddle.to_tensor([2, 1, 0])
self.assertTrue(
np.array_equal(
a.scatter(index, b).numpy(),
paddle.scatter(a, index, b).numpy()))
# 3. Bool tensor operation
x = paddle.to_tensor([[True, False], [True, False]])
y = paddle.to_tensor([[False, False], [False, True]])
self.assertTrue(
np.array_equal(x.reduce_all().numpy(), paddle.reduce_all(x).numpy(
)))
self.assertTrue(
np.array_equal(x.reduce_any().numpy(), paddle.reduce_any(x).numpy(
)))
self.assertTrue(
np.array_equal(
x.logical_and(y).numpy(), paddle.logical_and(x, y).numpy()))
self.assertTrue(
np.array_equal(
x.logical_not(y).numpy(), paddle.logical_not(x, y).numpy()))
self.assertTrue(
np.array_equal(
x.logical_or(y).numpy(), paddle.logical_or(x, y).numpy()))
self.assertTrue(
np.array_equal(
x.logical_xor(y).numpy(), paddle.logical_xor(x, y).numpy()))
self.assertTrue(
np.array_equal(
x.logical_and(y).numpy(), paddle.logical_and(x, y).numpy()))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册