From f4063b2f69d4206aed813a966c618bfd5b9d4034 Mon Sep 17 00:00:00 2001 From: HansBug Date: Wed, 22 Sep 2021 19:37:55 +0800 Subject: [PATCH] dev, test, doc(hansbug): add Tensor.sign_ --- test/torch/test_tensor.py | 20 ++++++++++++++++++++ treetensor/torch/tensor.py | 9 +++++++++ 2 files changed, 29 insertions(+) diff --git a/test/torch/test_tensor.py b/test/torch/test_tensor.py index bd5dd0d6b..dbc3440dc 100644 --- a/test/torch/test_tensor.py +++ b/test/torch/test_tensor.py @@ -435,6 +435,26 @@ class TestTorchTensor: [0, -1]]}, })).all() + @choose_mark() + def test_sign_(self): + t1 = ttorch.tensor([12, 0, -3]) + t1r = t1.sign_() + assert t1r is t1 + assert isinstance(t1, torch.Tensor) + assert (t1 == ttorch.tensor([1, 0, -1])).all() + + t2 = ttorch.tensor({ + 'a': [12, 0, -3], + 'b': {'x': [[-3, 1], [0, -2]]}, + }) + t2r = t2.sign_() + assert t2r is t2 + assert (t2 == ttorch.tensor({ + 'a': [1, 0, -1], + 'b': {'x': [[-1, 1], + [0, -1]]}, + })).all() + @choose_mark() def test_round(self): t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]).round() diff --git a/treetensor/torch/tensor.py b/treetensor/torch/tensor.py index 3fe0d3182..aed77e6aa 100644 --- a/treetensor/torch/tensor.py +++ b/treetensor/torch/tensor.py @@ -363,6 +363,15 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.sign(*args, **kwargs) + @doc_from_base() + @return_self + @method_treelize() + def sign_(self, *args, **kwargs): + """ + In-place version of :meth:`Tensor.sign`. + """ + return self.sign_(*args, **kwargs) + @doc_from_base() @method_treelize() def sigmoid(self, *args, **kwargs): -- GitLab