From 2133f3dd78d990aab756c1f9e1964353eb66549d Mon Sep 17 00:00:00 2001
From: Zhou Wei <1183042833@qq.com>
Date: Wed, 8 Sep 2021 16:34:45 +0800
Subject: [PATCH] add API Tensor.T for reverse dim of Tensor (#35379)

---
 python/paddle/fluid/dygraph/math_op_patch.py  | 11 +++++
 .../fluid/dygraph/varbase_patch_methods.py    | 12 ++---
 python/paddle/fluid/framework.py              | 49 +++++++++++++++++++
 .../tests/unittests/test_math_op_patch.py     | 14 ++++++
 .../unittests/test_math_op_patch_var_base.py  |  6 +++
 python/paddle/tensor/__init__.py              |  4 +-
 6 files changed, 88 insertions(+), 8 deletions(-)

diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py
index 9fbf176c22c..6b57544329e 100644
--- a/python/paddle/fluid/dygraph/math_op_patch.py
+++ b/python/paddle/fluid/dygraph/math_op_patch.py
@@ -148,6 +148,16 @@ def monkey_patch_math_varbase():
     def _size_(var):
         return np.prod(var.shape)
 
+    @property
+    def _T_(var):
+        if len(var.shape) == 1:
+            return var
+        perm = []
+        for i in range(len(var.shape)):
+            perm.insert(0, i)
+        out, _ = _C_ops.transpose2(var, 'axis', perm)
+        return out
+
     def _scalar_add_(var, value):
         return _scalar_elementwise_op_(var, 1.0, value)
 
@@ -271,6 +281,7 @@ def monkey_patch_math_varbase():
         ('ndimension', lambda x: len(x.shape)),
         ('ndim', _ndim_),
         ('size', _size_),
+        ('T', _T_),
         ('__add__',
          _binary_creator_('__add__', 'elementwise_add', False, _scalar_add_)),
         ##  a+b == b+a. Do not need to reverse explicitly
diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py
index 83e7d0ae1e0..e39a86e961d 100644
--- a/python/paddle/fluid/dygraph/varbase_patch_methods.py
+++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py
@@ -88,17 +88,17 @@ def monkey_patch_varbase():
         """
 
         # Note: getattr(self, attr, None) will call x.grad=x.gradient(), but gradient() only available in dygraph.
-        # It will fail. So, for propery in dygraph only, should not let it getattr(self, attr, None).
-        attr_not_need_keys = ['grad']
+        # It will fail. So, for properties that differ between dynamic and static graph, do not call getattr(self, attr, None).
+        attr_not_need_keys = ['grad', 'T']
         if isinstance(self, ParamBase):
             attr_kwargs = self.__dict__.copy()
         else:
             attr_names = []
             for name in dir(self):
-                if name not in attr_not_need_keys and not (
-                        inspect.ismethod(getattr(self, name)) or
-                        name.startswith('_')):
-                    attr_names.append(name)
+                if name not in attr_not_need_keys:
+                    if not inspect.ismethod(getattr(
+                            self, name)) and not name.startswith('_'):
+                        attr_names.append(name)
             attr_kwargs = {name: getattr(self, name) for name in attr_names}
 
         attr_keys = ['block', 'shape', 'dtype', 'type', 'name', 'persistable']
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 13477fd3422..6c95c9fad56 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -1503,6 +1503,55 @@ class Variable(object):
         """
         return self.desc.type()
 
+    @property
+    def T(self):
+        """
+        Permute the current Variable with its dimensions reversed.
+
+        If `n` is the number of dimensions of `x`, `x.T` is equivalent to `x.transpose([n-1, n-2, ..., 0])`.
+
+        Examples:
+
+            .. code-block:: python
+
+                import paddle
+                paddle.enable_static()
+
+                x = paddle.ones(shape=[2, 3, 5])
+                x_T = x.T
+
+                exe = paddle.static.Executor()
+                x_T_np = exe.run(paddle.static.default_main_program(), fetch_list=[x_T])[0]
+                print(x_T_np.shape)
+                # (5, 3, 2)
+        """
+        if len(self.shape) == 1:
+            return self
+        perm = []
+        for i in range(len(self.shape)):
+            perm.insert(0, i)
+
+        out = self.block.create_var(
+            name=unique_name.generate_with_ignorable_key(self.name + '.tmp'),
+            dtype=self.dtype,
+            type=self.type,
+            persistable=False,
+            stop_gradient=False)
+        input_shape = self.block.create_var(
+            name=unique_name.generate_with_ignorable_key(self.name + '.tmp'),
+            dtype=self.dtype,
+            type=core.VarDesc.VarType.LOD_TENSOR,
+            persistable=False,
+            stop_gradient=False)
+
+        self.block.append_op(
+            type='transpose2',
+            inputs={'X': [self]},
+            outputs={'Out': [out],
+                     'XShape': [input_shape]},
+            attrs={'axis': perm})
+        return out
+
     def clone(self):
         """
         Returns a new static Variable, which is the clone of the original static
diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch.py b/python/paddle/fluid/tests/unittests/test_math_op_patch.py
index cef5adbc5d3..258543631f9 100644
--- a/python/paddle/fluid/tests/unittests/test_math_op_patch.py
+++ b/python/paddle/fluid/tests/unittests/test_math_op_patch.py
@@ -335,6 +335,20 @@ class TestMathOpPatches(unittest.TestCase):
                       fetch_list=[z])
         self.assertTrue(np.array_equal(out[0], out_np))
 
+    @prog_scope()
+    def test_T(self):
+        x_np = np.random.randint(-100, 100, [2, 8, 5, 3]).astype("int32")
+        out_np = x_np.T
+
+        x = paddle.static.data(name="x", shape=[2, 8, 5, 3], dtype="int32")
+        z = x.T
+
+        exe = fluid.Executor()
+        out = exe.run(fluid.default_main_program(),
+                      feed={"x": x_np},
+                      fetch_list=[z])
+        self.assertTrue(np.array_equal(out[0], out_np))
+
     @prog_scope()
     def test_ndim(self):
         a = paddle.static.data(name="a", shape=[10, 1])
diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py b/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py
index 0afc9ee6253..3f611a31921 100644
--- a/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py
+++ b/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py
@@ -527,6 +527,12 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
             np.array_equal(
                 x.where(a, b).numpy(), paddle.where(x, a, b).numpy()))
 
+        x_np = np.random.randn(3, 6, 9, 7)
+        x = paddle.to_tensor(x_np)
+        x_T = x.T
+        self.assertEqual(x_T.shape, [7, 9, 6, 3])
+        self.assertTrue(np.array_equal(x_T.numpy(), x_np.T))
+
         self.assertTrue(inspect.ismethod(a.dot))
         self.assertTrue(inspect.ismethod(a.logsumexp))
         self.assertTrue(inspect.ismethod(a.multiplex))
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 375375c8604..a67b015f8ff 100755
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -367,8 +367,8 @@ tensor_method_func = [  #noqa
     'real',
     'imag',
     'digamma',
-    'diagonal'
-    'trunc'
+    'diagonal',
+    'trunc',
     'bitwise_and',
     'bitwise_or',
     'bitwise_xor',
-- 
GitLab
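
A minimal usage sketch follows (an editorial addition, not part of the patch itself): it exercises the new `T` property in both dynamic and static graph mode, assuming a Paddle 2.x build that includes this change; the variable names are illustrative only.

.. code-block:: python

    import numpy as np
    import paddle

    # Dynamic graph: `.T` reverses all dimensions, matching numpy's `ndarray.T`.
    x_np = np.arange(24).reshape([2, 3, 4]).astype("float32")
    x = paddle.to_tensor(x_np)
    print(x.T.shape)  # [4, 3, 2]
    assert np.array_equal(x.T.numpy(), x_np.T)

    # A 1-D tensor is returned unchanged, per the shortcut in `_T_` and `Variable.T`.
    v = paddle.to_tensor([1.0, 2.0, 3.0])
    assert v.T.shape == [3]

    # Static graph: `Variable.T` appends a `transpose2` op to the current block.
    paddle.enable_static()
    y = paddle.static.data(name="y", shape=[2, 3, 4], dtype="float32")
    exe = paddle.static.Executor()
    out = exe.run(paddle.static.default_main_program(),
                  feed={"y": x_np},
                  fetch_list=[y.T])[0]
    assert out.shape == (4, 3, 2)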