diff --git a/docs/source/api_doc/common/object.rst b/docs/source/api_doc/common/object.rst
index 1f3e0a8a25b315cf5e5d06d3565d6c23868c0a7d..7c6d3a738e212c04ed51cdfc4dcaba2547e3de9d 100644
--- a/docs/source/api_doc/common/object.rst
+++ b/docs/source/api_doc/common/object.rst
@@ -7,5 +7,5 @@
 Object
 -----------------
 
 .. autoclass:: Object
-    :members: __init__
+    :members: __init__, all, any
diff --git a/test/common/test_object.py b/test/common/test_object.py
index c35d6aeb80aaac2f6f72febaaa36f0d6e3a89b60..9a8489e917c96b7c605bb8e41b767512f966652a 100644
--- a/test/common/test_object.py
+++ b/test/common/test_object.py
@@ -14,3 +14,13 @@ class TestCommonObject:
         assert Object({'a': 1, 'b': 2}) == typetrans(TreeValue({
             'a': 1, 'b': 2
         }), Object)
+
+    def test_all(self):
+        assert not Object({'a': False, 'b': {'x': False}}).all()
+        assert not Object({'a': True, 'b': {'x': False}}).all()
+        assert Object({'a': True, 'b': {'x': True}}).all()
+
+    def test_any(self):
+        assert not Object({'a': False, 'b': {'x': False}}).any()
+        assert Object({'a': True, 'b': {'x': False}}).any()
+        assert Object({'a': True, 'b': {'x': True}}).any()
diff --git a/test/torch/funcs/__init__.py b/test/torch/funcs/__init__.py
index 7dd0fdff5eb2e9607fe5f296929e0e5cae664ea2..8c7d59bf93f86ffd9243b22973858781efe5c291 100644
--- a/test/torch/funcs/__init__.py
+++ b/test/torch/funcs/__init__.py
@@ -1,3 +1,4 @@
+from .test_autograd import TestTorchFuncsAutograd
 from .test_comparison import TestTorchFuncsComparison
 from .test_construct import TestTorchFuncsConstruct
 from .test_math import TestTorchFuncsMath
diff --git a/test/torch/funcs/test_autograd.py b/test/torch/funcs/test_autograd.py
new file mode 100644
index 0000000000000000000000000000000000000000..f56fa60936f81b7a67965342028bdd77627c8af8
--- /dev/null
+++ b/test/torch/funcs/test_autograd.py
@@ -0,0 +1,31 @@
+import treetensor.torch as ttorch
+from .base import choose_mark
+
+
+# noinspection DuplicatedCode,PyUnresolvedReferences
+class TestTorchFuncsAutograd:
+
+    @choose_mark()
+    def test_detach(self):
+        tt1 = ttorch.tensor({
+            'a': [2, 3, 4.0],
+            'b': {'x': [[5, 6], [7, 8.0]]}
+        }, requires_grad=True)
+        assert tt1.requires_grad.all()
+
+        tt1r = ttorch.detach(tt1)
+        assert tt1.requires_grad.all()
+        assert tt1r is not tt1
+        assert not tt1r.requires_grad.any()
+
+    @choose_mark()
+    def test_detach_(self):
+        tt1 = ttorch.tensor({
+            'a': [2, 3, 4.0],
+            'b': {'x': [[5, 6], [7, 8.0]]}
+        }, requires_grad=True)
+        assert tt1.requires_grad.all()
+
+        tt1r = ttorch.detach_(tt1)
+        assert tt1r is tt1
+        assert not tt1.requires_grad.any()
diff --git a/test/torch/tensor/__init__.py b/test/torch/tensor/__init__.py
index 3ed12ec2ac5e5308df69372b634f084bd0f175df..7f603a2bc065636060d489c3f24f785958cd60f3 100644
--- a/test/torch/tensor/__init__.py
+++ b/test/torch/tensor/__init__.py
@@ -1,3 +1,4 @@
+from .test_autograd import TestTorchTensorAutograd
 from .test_clazz import TestTorchTensorClass
 from .test_comparison import TestTorchTensorComparison
 from .test_math import TestTorchTensorMath
diff --git a/test/torch/tensor/test_autograd.py b/test/torch/tensor/test_autograd.py
new file mode 100644
index 0000000000000000000000000000000000000000..73257d1011d6d53033f03676446b736a36e05e64
--- /dev/null
+++ b/test/torch/tensor/test_autograd.py
@@ -0,0 +1,80 @@
+import treetensor.torch as ttorch
+from .base import choose_mark
+
+
+# noinspection DuplicatedCode,PyUnresolvedReferences
+class TestTorchTensorAutograd:
+    @choose_mark()
+    def test_requires_grad(self):
+        tt1 = ttorch.tensor({
+            'a': [2, 3, 4.0],
+            'b': {'x': [[5, 6], [7, 8.0]]}
+        }, requires_grad=True)
+        assert tt1.requires_grad.all()
+
+        tt1.a.requires_grad_(False)
+        assert not tt1.requires_grad.all()
+        assert tt1.requires_grad.any()
+
+        tt1.b.x.requires_grad_(False)
+        assert not tt1.requires_grad.all()
+        assert not tt1.requires_grad.any()
+
+    @choose_mark()
+    def test_requires_grad_(self):
+        tt1 = ttorch.tensor({
+            'a': [2, 3, 4.0],
+            'b': {'x': [[5, 6], [7, 8.0]]}
+        })
+        assert not tt1.requires_grad.any()
+
+        tt1.requires_grad_(True)
+        assert tt1.requires_grad.all()
+
+        tt1.a.requires_grad_(False)
+        assert not tt1.requires_grad.all()
+        assert tt1.requires_grad.any()
+
+        tt1.b.x.requires_grad_(False)
+        assert not tt1.requires_grad.all()
+        assert not tt1.requires_grad.any()
+
+    @choose_mark()
+    def test_grad(self):
+        tt1 = ttorch.tensor({
+            'a': [2, 3, 4.0],
+            'b': {'x': [[5, 6], [7, 8.0]]}
+        }, requires_grad=True)
+
+        mq = tt1.mean() ** 2
+        mq.backward()
+        assert ttorch.isclose(tt1.grad, ttorch.tensor({
+            'a': [1.4286, 1.4286, 1.4286],
+            'b': {'x': [[1.4286, 1.4286],
+                        [1.4286, 1.4286]]},
+        }), atol=1e-4).all()
+
+    @choose_mark()
+    def test_detach(self):
+        tt1 = ttorch.tensor({
+            'a': [2, 3, 4.0],
+            'b': {'x': [[5, 6], [7, 8.0]]}
+        }, requires_grad=True)
+        assert tt1.requires_grad.all()
+
+        tt1r = tt1.detach()
+        assert tt1.requires_grad.all()
+        assert tt1r is not tt1
+        assert not tt1r.requires_grad.any()
+
+    @choose_mark()
+    def test_detach_(self):
+        tt1 = ttorch.tensor({
+            'a': [2, 3, 4.0],
+            'b': {'x': [[5, 6], [7, 8.0]]}
+        }, requires_grad=True)
+        assert tt1.requires_grad.all()
+
+        tt1r = tt1.detach_()
+        assert tt1r is tt1
+        assert not tt1.requires_grad.any()
diff --git a/treetensor/common/object.py b/treetensor/common/object.py
index f16271f79bbf52ed1a3648cba30b7f9208f3b3e2..a493a7ef243ae90c33a808a5e968558e96933a5a 100644
--- a/treetensor/common/object.py
+++ b/treetensor/common/object.py
@@ -1,4 +1,9 @@
+import builtins
+
+from treevalue import method_treelize
+
 from .trees import BaseTreeStruct, clsmeta
+from .wrappers import ireduce
 
 __all__ = [
     "Object",
@@ -33,3 +38,41 @@ class Object(BaseTreeStruct, metaclass=clsmeta(_object, allow_dict=True)):
             └── c --> 233
         """
         super(BaseTreeStruct, self).__init__(data)
+
+    @ireduce(builtins.all, piter=list)
+    @method_treelize()
+    def all(self):
+        """
+        Return whether all the values in this tree are true.
+
+        Examples::
+
+            >>> from treetensor.common import Object
+            >>> Object({'a': False, 'b': {'x': False}}).all()
+            False
+            >>> Object({'a': True, 'b': {'x': False}}).all()
+            False
+            >>> Object({'a': True, 'b': {'x': True}}).all()
+            True
+
+        """
+        return not not self  # double negation coerces each value to bool
+
+    @ireduce(builtins.any, piter=list)
+    @method_treelize()
+    def any(self):
+        """
+        Return whether any of the values in this tree is true.
+
+        Examples::
+
+            >>> from treetensor.common import Object
+            >>> Object({'a': False, 'b': {'x': False}}).any()
+            False
+            >>> Object({'a': True, 'b': {'x': False}}).any()
+            True
+            >>> Object({'a': True, 'b': {'x': True}}).any()
+            True
+
+        """
+        return not not self  # double negation coerces each value to bool
diff --git a/treetensor/torch/funcs/__init__.py b/treetensor/torch/funcs/__init__.py
index d2a53802a3f42c756e42c98978fa3904ae8d67ec..98b029bf8912dc4817e48b24762cae5c0b188509 100644
--- a/treetensor/torch/funcs/__init__.py
+++ b/treetensor/torch/funcs/__init__.py
@@ -1,5 +1,7 @@
 import sys
 
+from .autograd import *
+from .autograd import __all__ as _autograd_all
 from .comparison import *
 from .comparison import __all__ as _comparison_all
 from .construct import *
@@ -15,6 +17,7 @@ from .reduction import __all__ as _reduction_all
 from ...utils import module_autoremove
 
 __all__ = [
+    *_autograd_all,
     *_comparison_all,
     *_construct_all,
     *_math_all,
diff --git a/treetensor/torch/funcs/autograd.py b/treetensor/torch/funcs/autograd.py
new file mode 100644
index 0000000000000000000000000000000000000000..e84314f67c1a11c7ebc7de5ceb1d65223ef7f08c
--- /dev/null
+++ b/treetensor/torch/funcs/autograd.py
@@ -0,0 +1,83 @@
+import torch
+
+from .base import doc_from_base, func_treelize
+from ...common import return_self
+
+__all__ = [
+    'detach', 'detach_'
+]
+
+
+# noinspection PyShadowingBuiltins
+@doc_from_base()
+@func_treelize()
+def detach(input):
+    """
+    Detach tensor from the computation graph.
+
+    Examples::
+
+        >>> import torch
+        >>> import treetensor.torch as ttorch
+        >>> tt = ttorch.randn({
+        ...     'a': (2, 3),
+        ...     'b': {'x': (3, 4)},
+        ... })
+        >>> tt.requires_grad_(True)
+        >>> tt
+        <Tensor 0x...>
+        ├── a --> tensor([[ 2.5262,  0.7398,  0.7966],
+        │                 [ 1.3164,  1.2248, -2.2494]], requires_grad=True)
+        └── b --> <Tensor 0x...>
+            └── x --> tensor([[ 0.3578,  0.4611, -0.6668,  0.5356],
+                              [-1.4392, -1.2899, -0.0394,  0.8457],
+                              [ 0.4492, -0.5188, -0.2375, -1.2649]], requires_grad=True)
+
+        >>> ttorch.detach(tt)
+        <Tensor 0x...>
+        ├── a --> tensor([[ 2.5262,  0.7398,  0.7966],
+        │                 [ 1.3164,  1.2248, -2.2494]])
+        └── b --> <Tensor 0x...>
+            └── x --> tensor([[ 0.3578,  0.4611, -0.6668,  0.5356],
+                              [-1.4392, -1.2899, -0.0394,  0.8457],
+                              [ 0.4492, -0.5188, -0.2375, -1.2649]])
+    """
+    return torch.detach(input)
+
+
+# noinspection PyShadowingBuiltins
+@doc_from_base()
+@return_self
+@func_treelize()
+def detach_(input):
+    """
+    In-place version of :func:`treetensor.torch.detach`.
+
+    Examples::
+
+        >>> import torch
+        >>> import treetensor.torch as ttorch
+        >>> tt = ttorch.randn({
+        ...     'a': (2, 3),
+        ...     'b': {'x': (3, 4)},
+        ... })
+        >>> tt.requires_grad_(True)
+        >>> tt
+        <Tensor 0x...>
+        ├── a --> tensor([[-0.1631, -1.1573,  1.3109],
+        │                 [ 2.7277, -0.0745, -1.2577]], requires_grad=True)
+        └── b --> <Tensor 0x...>
+            └── x --> tensor([[-0.5876,  0.9836,  1.9584, -0.1513],
+                              [ 0.5369, -1.3986,  0.9361,  0.6765],
+                              [ 0.6465, -0.2212,  1.5499, -1.2156]], requires_grad=True)
+
+        >>> ttorch.detach_(tt)
+        <Tensor 0x...>
+        ├── a --> tensor([[-0.1631, -1.1573,  1.3109],
+        │                 [ 2.7277, -0.0745, -1.2577]])
+        └── b --> <Tensor 0x...>
+            └── x --> tensor([[-0.5876,  0.9836,  1.9584, -0.1513],
+                              [ 0.5369, -1.3986,  0.9361,  0.6765],
+                              [ 0.6465, -0.2212,  1.5499, -1.2156]])
+    """
+    return torch.detach_(input)
diff --git a/treetensor/torch/funcs/operation.py b/treetensor/torch/funcs/operation.py
index 6aef1fd41c724608a816c49712f1575cc56a7b36..456471030d55dce86532d3a319e626a220c3c395 100644
--- a/treetensor/torch/funcs/operation.py
+++ b/treetensor/torch/funcs/operation.py
@@ -117,10 +117,8 @@ def cat(tensors, *args, **kwargs):
 
 # noinspection PyShadowingNames
 @doc_from_base()
-@post_process(lambda r: tuple(r))
 @post_process(auto_tensor)
-@func_treelize(return_type=TreeValue, rise=dict(template=[None]))
-@post_process(lambda r: list(r))
+@func_treelize(return_type=TreeValue, rise=True)
 def split(tensor, split_size_or_sections, *args, **kwargs):
     """
     Splits the tensor into chunks. Each chunk is a view of the original tensor.
@@ -208,10 +206,8 @@ def split(tensor, split_size_or_sections, *args, **kwargs):
 
 # noinspection PyShadowingBuiltins
 @doc_from_base()
-@post_process(lambda r: tuple(r))
 @post_process(auto_tensor)
-@func_treelize(return_type=TreeValue, rise=dict(template=[None]))
-@post_process(lambda r: list(r))
+@func_treelize(return_type=TreeValue, rise=True)
 def chunk(input, chunks, *args, **kwargs):
     """
     Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor.
diff --git a/treetensor/torch/tensor.py b/treetensor/torch/tensor.py
index a222a370fdf54f3e9486f7b85bb46f76f500f538..0f6d18a68da0f81d59ac74f1a6a6cd9e432c5e9a 100644
--- a/treetensor/torch/tensor.py
+++ b/treetensor/torch/tensor.py
@@ -174,6 +174,72 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
         """
         return self.shape
 
+    @property
+    @method_treelize()
+    def grad(self):
+        """
+        Return the gradients of the tensors in this tree.
+
+        Examples::
+
+            >>> import torch
+            >>> import treetensor.torch as ttorch
+            >>> tt = ttorch.randn({
+            ...     'a': (2, 3),
+            ...     'b': {'x': (3, 4)},
+            ... })
+            >>> tt.requires_grad_(True)
+            >>> tt
+            <Tensor 0x...>
+            ├── a --> tensor([[-1.4375,  0.0988,  1.2198],
+            │                 [-0.7627, -0.8797, -0.9299]], requires_grad=True)
+            └── b --> <Tensor 0x...>
+                └── x --> tensor([[ 0.2149, -0.5839, -0.6049, -0.9151],
+                                  [ 1.5381, -1.4386,  0.1831,  0.2018],
+                                  [-0.0725, -0.9062, -2.6212,  0.5929]], requires_grad=True)
+            >>> mq = tt.mean() ** 2
+            >>> mq.backward()
+            >>> tt.grad
+            <Tensor 0x...>
+            ├── a --> tensor([[-0.0438, -0.0438, -0.0438],
+            │                 [-0.0438, -0.0438, -0.0438]])
+            └── b --> <Tensor 0x...>
+                └── x --> tensor([[-0.0438, -0.0438, -0.0438, -0.0438],
+                                  [-0.0438, -0.0438, -0.0438, -0.0438],
+                                  [-0.0438, -0.0438, -0.0438, -0.0438]])
+        """
+        return self.grad
+
+    @property
+    @method_treelize(return_type=Object)
+    def requires_grad(self):
+        """
+        Return whether autograd should record operations on each tensor in this tree.
+
+        Examples::
+
+            >>> import torch
+            >>> import treetensor.torch as ttorch
+            >>> tt = ttorch.randn({
+            ...     'a': (2, 3),
+            ...     'b': {'x': (3, 4)},
+            ... })
+            >>> tt.requires_grad_(True)
+            >>> tt.requires_grad
+            <Object 0x...>
+            ├── a --> True
+            └── b --> <Object 0x...>
+                └── x --> True
+
+            >>> tt.a.requires_grad_(False)
+            >>> tt.requires_grad
+            <Object 0x...>
+            ├── a --> False
+            └── b --> <Object 0x...>
+                └── x --> True
+        """
+        return self.requires_grad
+
     @doc_from_base()
     @return_self
     @method_treelize()
@@ -181,9 +247,44 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
         """
         Change if autograd should record operations on this tensor: sets this tensor’s
         ``requires_grad`` attribute in-place. Returns this tensor.
+
+        Examples::
+
+            >>> import torch
+            >>> import treetensor.torch as ttorch
+            >>> tt = ttorch.randn({
+            ...     'a': (2, 3),
+            ...     'b': {'x': (3, 4)},
+            ... })
+            >>> tt.requires_grad_(True)
+            >>> tt
+            <Tensor 0x...>
+            ├── a --> tensor([[ 1.4754,  1.1167,  1.5431],
+            │                 [-0.5816,  0.4746,  0.8392]], requires_grad=True)
+            └── b --> <Tensor 0x...>
+                └── x --> tensor([[ 0.3361,  0.8194,  0.1297, -0.5547],
+                                  [ 0.2531, -0.0637,  0.9822,  2.1618],
+                                  [ 2.0140, -0.0929,  0.9304,  1.5430]], requires_grad=True)
         """
         return self.requires_grad_(requires_grad)
 
+    @doc_from_base()
+    @method_treelize()
+    def detach(self):
+        """
+        See :func:`treetensor.torch.detach`.
+        """
+        return self.detach()
+
+    @doc_from_base()
+    @return_self
+    @method_treelize()
+    def detach_(self):
+        """
+        In-place version of :meth:`Tensor.detach`.
+        """
+        return self.detach_()
+
     # noinspection PyShadowingBuiltins,PyUnusedLocal
     @post_reduce(torch.all)
     @method_treelize(return_type=Object)
@@ -715,8 +816,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
 
     @doc_from_base()
     @post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r))
-    @method_treelize(return_type=TreeValue, rise=dict(template=[None]))
-    @post_process(lambda r: list(r))
+    @method_treelize(return_type=TreeValue, rise=True)
     def split(self, split_size, *args, **kwargs):
         """
         See :func:`treetensor.torch.split`.
@@ -725,8 +825,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
 
     @doc_from_base()
     @post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r))
-    @method_treelize(return_type=TreeValue, rise=dict(template=[None]))
-    @post_process(lambda r: list(r))
+    @method_treelize(return_type=TreeValue, rise=True)
     def chunk(self, chunks, *args, **kwargs):
         """
         See :func:`treetensor.torch.chunk`.
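
Below is a minimal usage sketch of the APIs added in this diff (``Object.all``/``Object.any``, ``ttorch.detach``/``ttorch.detach_``, and the matching ``Tensor`` methods), distilled from the new tests above; the tensor values are arbitrary:

    import treetensor.torch as ttorch

    # Build a tree tensor with gradient tracking enabled.
    tt = ttorch.tensor({
        'a': [2, 3, 4.0],
        'b': {'x': [[5, 6], [7, 8.0]]},
    }, requires_grad=True)
    assert tt.requires_grad.all()       # Object.all(): every leaf requires grad

    # detach() returns a new tree detached from the graph; the original is untouched.
    td = tt.detach()
    assert td is not tt
    assert not td.requires_grad.any()   # Object.any(): no leaf requires grad
    assert tt.requires_grad.all()

    # detach_() detaches in place and returns the same tree object.
    assert tt.detach_() is tt
    assert not tt.requires_grad.any()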