未验证 提交 2133f3dd 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

add API Tensor.T for reverse dim of Tensor (#35379)

上级 12155358
......@@ -148,6 +148,16 @@ def monkey_patch_math_varbase():
def _size_(var):
return np.prod(var.shape)
@property
def _T_(var):
if len(var.shape) == 1:
return var
perm = []
for i in range(len(var.shape)):
perm.insert(0, i)
out, _ = _C_ops.transpose2(var, 'axis', perm)
return out
def _scalar_add_(var, value):
return _scalar_elementwise_op_(var, 1.0, value)
......@@ -271,6 +281,7 @@ def monkey_patch_math_varbase():
('ndimension', lambda x: len(x.shape)),
('ndim', _ndim_),
('size', _size_),
('T', _T_),
('__add__',
_binary_creator_('__add__', 'elementwise_add', False, _scalar_add_)),
## a+b == b+a. Do not need to reverse explicitly
......
......@@ -88,17 +88,17 @@ def monkey_patch_varbase():
"""
# Note: getattr(self, attr, None) will call x.grad=x.gradient(), but gradient() only available in dygraph.
# It will fail. So, for propery in dygraph only, should not let it getattr(self, attr, None).
attr_not_need_keys = ['grad']
# It will fail. So, for propery that different between dynamic and static graph, should not getattr(self, attr, None).
attr_not_need_keys = ['grad', 'T']
if isinstance(self, ParamBase):
attr_kwargs = self.__dict__.copy()
else:
attr_names = []
for name in dir(self):
if name not in attr_not_need_keys and not (
inspect.ismethod(getattr(self, name)) or
name.startswith('_')):
attr_names.append(name)
if name not in attr_not_need_keys:
if not inspect.ismethod(getattr(
self, name)) and not name.startswith('_'):
attr_names.append(name)
attr_kwargs = {name: getattr(self, name) for name in attr_names}
attr_keys = ['block', 'shape', 'dtype', 'type', 'name', 'persistable']
......
......@@ -1503,6 +1503,55 @@ class Variable(object):
"""
return self.desc.type()
@property
def T(self):
"""
Permute current Variable with its dimensions reversed.
If `n` is the dimensions of `x` , `x.T` is equivalent to `x.transpose([n-1, n-2, ..., 0])`.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
x = paddle.ones(shape=[2, 3, 5])
x_T = x.T
exe = paddle.static.Executor()
x_T_np = exe.run(paddle.static.default_main_program(), fetch_list=[x_T])[0]
print(x_T_np.shape)
# (5, 3, 2)
"""
if len(self.shape) == 1:
return self
perm = []
for i in range(len(self.shape)):
perm.insert(0, i)
out = self.block.create_var(
name=unique_name.generate_with_ignorable_key(self.name + '.tmp'),
dtype=self.dtype,
type=self.type,
persistable=False,
stop_gradient=False)
input_shape = self.block.create_var(
name=unique_name.generate_with_ignorable_key(self.name + '.tmp'),
dtype=self.dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
self.block.append_op(
type='transpose2',
inputs={'X': [self]},
outputs={'Out': [out],
'XShape': [input_shape]},
attrs={'axis': perm})
return out
def clone(self):
"""
Returns a new static Variable, which is the clone of the original static
......
......@@ -335,6 +335,20 @@ class TestMathOpPatches(unittest.TestCase):
fetch_list=[z])
self.assertTrue(np.array_equal(out[0], out_np))
@prog_scope()
def test_T(self):
x_np = np.random.randint(-100, 100, [2, 8, 5, 3]).astype("int32")
out_np = x_np.T
x = paddle.static.data(name="x", shape=[2, 8, 5, 3], dtype="int32")
z = x.T
exe = fluid.Executor()
out = exe.run(fluid.default_main_program(),
feed={"x": x_np},
fetch_list=[z])
self.assertTrue(np.array_equal(out[0], out_np))
@prog_scope()
def test_ndim(self):
a = paddle.static.data(name="a", shape=[10, 1])
......
......@@ -527,6 +527,12 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
np.array_equal(
x.where(a, b).numpy(), paddle.where(x, a, b).numpy()))
x_np = np.random.randn(3, 6, 9, 7)
x = paddle.to_tensor(x_np)
x_T = x.T
self.assertTrue(x_T.shape, [7, 9, 6, 3])
self.assertTrue(np.array_equal(x_T.numpy(), x_np.T))
self.assertTrue(inspect.ismethod(a.dot))
self.assertTrue(inspect.ismethod(a.logsumexp))
self.assertTrue(inspect.ismethod(a.multiplex))
......
......@@ -367,8 +367,8 @@ tensor_method_func = [ #noqa
'real',
'imag',
'digamma',
'diagonal'
'trunc'
'diagonal',
'trunc',
'bitwise_and',
'bitwise_or',
'bitwise_xor',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册