From a1957a0ad311fd90dd7b420d14c6f94078ee307f Mon Sep 17 00:00:00 2001
From: HansBug
Date: Tue, 28 Sep 2021 22:10:15 +0800
Subject: [PATCH] dev(hansbug): complete autograd part

---
 docs/source/api_doc/common/object.rst |   2 +-
 test/common/test_object.py            |  10 +++
 test/torch/funcs/__init__.py          |   1 +
 test/torch/funcs/test_autograd.py     |  31 ++++
 test/torch/tensor/__init__.py         |   1 +
 test/torch/tensor/test_autograd.py    |  80 +++++++++++++++
 treetensor/common/object.py           |  43 +++++++++
 treetensor/torch/funcs/__init__.py    |   3 +
 treetensor/torch/funcs/autograd.py    |  83 ++++++++++++++++
 treetensor/torch/funcs/operation.py   |   8 +-
 treetensor/torch/tensor.py            | 107 +++++++++++++++++++++++++-
 11 files changed, 358 insertions(+), 11 deletions(-)
 create mode 100644 test/torch/funcs/test_autograd.py
 create mode 100644 test/torch/tensor/test_autograd.py
 create mode 100644 treetensor/torch/funcs/autograd.py

diff --git a/docs/source/api_doc/common/object.rst b/docs/source/api_doc/common/object.rst
index 1f3e0a8a2..7c6d3a738 100644
--- a/docs/source/api_doc/common/object.rst
+++ b/docs/source/api_doc/common/object.rst
@@ -7,5 +7,5 @@
 Object
 -----------------
 
 .. autoclass:: Object
-    :members: __init__
+    :members: __init__, all, any
diff --git a/test/common/test_object.py b/test/common/test_object.py
index c35d6aeb8..9a8489e91 100644
--- a/test/common/test_object.py
+++ b/test/common/test_object.py
@@ -14,3 +14,13 @@ class TestCommonObject:
         assert Object({'a': 1, 'b': 2}) == typetrans(TreeValue({
             'a': 1, 'b': 2
         }), Object)
+
+    def test_all(self):
+        assert not Object({'a': False, 'b': {'x': False}}).all()
+        assert not Object({'a': True, 'b': {'x': False}}).all()
+        assert Object({'a': True, 'b': {'x': True}}).all()
+
+    def test_any(self):
+        assert not Object({'a': False, 'b': {'x': False}}).any()
+        assert Object({'a': True, 'b': {'x': False}}).any()
+        assert Object({'a': True, 'b': {'x': True}}).any()
diff --git a/test/torch/funcs/__init__.py b/test/torch/funcs/__init__.py
index 7dd0fdff5..8c7d59bf9 100644
--- a/test/torch/funcs/__init__.py
+++ b/test/torch/funcs/__init__.py
@@ -1,3 +1,4 @@
+from .test_autograd import TestTorchFuncsAutograd
 from .test_comparison import TestTorchFuncsComparison
 from .test_construct import TestTorchFuncsConstruct
 from .test_math import TestTorchFuncsMath
diff --git a/test/torch/funcs/test_autograd.py b/test/torch/funcs/test_autograd.py
new file mode 100644
index 000000000..f56fa6093
--- /dev/null
+++ b/test/torch/funcs/test_autograd.py
@@ -0,0 +1,31 @@
+import treetensor.torch as ttorch
+from .base import choose_mark
+
+
+# noinspection DuplicatedCode,PyUnresolvedReferences
+class TestTorchFuncsAutograd:
+
+    @choose_mark()
+    def test_detach(self):
+        tt1 = ttorch.tensor({
+            'a': [2, 3, 4.0],
+            'b': {'x': [[5, 6], [7, 8.0]]}
+        }, requires_grad=True)
+        assert tt1.requires_grad.all()
+
+        tt1r = ttorch.detach(tt1)
+        assert tt1.requires_grad.all()
+        assert tt1r is not tt1
+        assert not tt1r.requires_grad.any()
+
+    @choose_mark()
+    def test_detach_(self):
+        tt1 = ttorch.tensor({
+            'a': [2, 3, 4.0],
+            'b': {'x': [[5, 6], [7, 8.0]]}
+        }, requires_grad=True)
+        assert tt1.requires_grad.all()
+
+        tt1r = ttorch.detach_(tt1)
+        assert tt1r is tt1
+        assert not tt1.requires_grad.any()
diff --git a/test/torch/tensor/__init__.py b/test/torch/tensor/__init__.py
index 3ed12ec2a..7f603a2bc 100644
--- a/test/torch/tensor/__init__.py
+++ b/test/torch/tensor/__init__.py
@@ -1,3 +1,4 @@
+from .test_autograd import TestTorchTensorAutograd
 from .test_clazz import TestTorchTensorClass
 from .test_comparison import TestTorchTensorComparison
 from .test_math import TestTorchTensorMath
diff --git a/test/torch/tensor/test_autograd.py b/test/torch/tensor/test_autograd.py
new file mode 100644
index 000000000..73257d101
--- /dev/null
+++ b/test/torch/tensor/test_autograd.py
@@ -0,0 +1,80 @@
+import treetensor.torch as ttorch
+from .base import choose_mark
+
+
+# noinspection DuplicatedCode,PyUnresolvedReferences
+class TestTorchTensorAutograd:
+    @choose_mark()
+    def test_requires_grad(self):
+        tt1 = ttorch.tensor({
+            'a': [2, 3, 4.0],
+            'b': {'x': [[5, 6], [7, 8.0]]}
+        }, requires_grad=True)
+        assert tt1.requires_grad.all()
+
+        tt1.a.requires_grad_(False)
+        assert not tt1.requires_grad.all()
+        assert tt1.requires_grad.any()
+
+        tt1.b.x.requires_grad_(False)
+        assert not tt1.requires_grad.all()
+        assert not tt1.requires_grad.any()
+
+    @choose_mark()
+    def test_requires_grad_(self):
+        tt1 = ttorch.tensor({
+            'a': [2, 3, 4.0],
+            'b': {'x': [[5, 6], [7, 8.0]]}
+        })
+        assert not tt1.requires_grad.any()
+
+        tt1.requires_grad_(True)
+        assert tt1.requires_grad.all()
+
+        tt1.a.requires_grad_(False)
+        assert not tt1.requires_grad.all()
+        assert tt1.requires_grad.any()
+
+        tt1.b.x.requires_grad_(False)
+        assert not tt1.requires_grad.all()
+        assert not tt1.requires_grad.any()
+
+    @choose_mark()
+    def test_grad(self):
+        tt1 = ttorch.tensor({
+            'a': [2, 3, 4.0],
+            'b': {'x': [[5, 6], [7, 8.0]]}
+        }, requires_grad=True)
+
+        mq = tt1.mean() ** 2
+        mq.backward()
+        assert ttorch.isclose(tt1.grad, ttorch.tensor({
+            'a': [1.4286, 1.4286, 1.4286],
+            'b': {'x': [[1.4286, 1.4286],
+                        [1.4286, 1.4286]]},
+        }), atol=1e-4).all()
+
+    @choose_mark()
+    def test_detach(self):
+        tt1 = ttorch.tensor({
+            'a': [2, 3, 4.0],
+            'b': {'x': [[5, 6], [7, 8.0]]}
+        }, requires_grad=True)
+        assert tt1.requires_grad.all()
+
+        tt1r = tt1.detach()
+        assert tt1.requires_grad.all()
+        assert tt1r is not tt1
+        assert not tt1r.requires_grad.any()
+
+    @choose_mark()
+    def test_detach_(self):
+        tt1 = ttorch.tensor({
+            'a': [2, 3, 4.0],
+            'b': {'x': [[5, 6], [7, 8.0]]}
+        }, requires_grad=True)
+        assert tt1.requires_grad.all()
+
+        tt1r = tt1.detach_()
+        assert tt1r is tt1
+        assert not tt1.requires_grad.any()
diff --git a/treetensor/common/object.py b/treetensor/common/object.py
index f16271f79..a493a7ef2 100644
--- a/treetensor/common/object.py
+++ b/treetensor/common/object.py
@@ -1,4 +1,9 @@
+import builtins
+
+from treevalue import method_treelize
+
 from .trees import BaseTreeStruct, clsmeta
+from .wrappers import ireduce
 
 __all__ = [
     "Object",
@@ -33,3 +38,41 @@ class Object(BaseTreeStruct, metaclass=clsmeta(_object, allow_dict=True)):
             └── c --> 233
         """
         super(BaseTreeStruct, self).__init__(data)
+
+    @ireduce(builtins.all, piter=list)
+    @method_treelize()
+    def all(self):
+        """
+        Return whether all the values in this tree are true.
+
+        Examples::
+
+            >>> from treetensor.common import Object
+            >>> Object({'a': False, 'b': {'x': False}}).all()
+            False
+            >>> Object({'a': True, 'b': {'x': False}}).all()
+            False
+            >>> Object({'a': True, 'b': {'x': True}}).all()
+            True
+
+        """
+        return not not self
+
+    @ireduce(builtins.any, piter=list)
+    @method_treelize()
+    def any(self):
+        """
+        Return whether any of the values in this tree is true.
+
+        Examples::
+
+            >>> from treetensor.common import Object
+            >>> Object({'a': False, 'b': {'x': False}}).any()
+            False
+            >>> Object({'a': True, 'b': {'x': False}}).any()
+            True
+            >>> Object({'a': True, 'b': {'x': True}}).any()
+            True
+
+        """
+        return not not self
diff --git a/treetensor/torch/funcs/__init__.py b/treetensor/torch/funcs/__init__.py
index d2a53802a..98b029bf8 100644
--- a/treetensor/torch/funcs/__init__.py
+++ b/treetensor/torch/funcs/__init__.py
@@ -1,5 +1,7 @@
 import sys
 
+from .autograd import *
+from .autograd import __all__ as _autograd_all
 from .comparison import *
 from .comparison import __all__ as _comparison_all
 from .construct import *
@@ -15,6 +17,7 @@ from .reduction import __all__ as _reduction_all
 from ...utils import module_autoremove
 
 __all__ = [
+    *_autograd_all,
     *_comparison_all,
     *_construct_all,
     *_math_all,
diff --git a/treetensor/torch/funcs/autograd.py b/treetensor/torch/funcs/autograd.py
new file mode 100644
index 000000000..e84314f67
--- /dev/null
+++ b/treetensor/torch/funcs/autograd.py
@@ -0,0 +1,83 @@
+import torch
+
+from .base import doc_from_base, func_treelize
+from ...common import return_self
+
+__all__ = [
+    'detach', 'detach_'
+]
+
+
+# noinspection PyShadowingBuiltins
+@doc_from_base()
+@func_treelize()
+def detach(input):
+    """
+    Detach tensor from calculation graph.
+
+    Examples::
+
+        >>> import torch
+        >>> import treetensor.torch as ttorch
+        >>> tt = ttorch.randn({
+        ...     'a': (2, 3),
+        ...     'b': {'x': (3, 4)},
+        ... })
+        >>> tt.requires_grad_(True)
+        >>> tt
+
+        ├── a --> tensor([[ 2.5262,  0.7398,  0.7966],
+        │                 [ 1.3164,  1.2248, -2.2494]], requires_grad=True)
+        └── b -->
+            └── x --> tensor([[ 0.3578,  0.4611, -0.6668,  0.5356],
+                              [-1.4392, -1.2899, -0.0394,  0.8457],
+                              [ 0.4492, -0.5188, -0.2375, -1.2649]], requires_grad=True)
+
+        >>> ttorch.detach(tt)
+
+        ├── a --> tensor([[ 2.5262,  0.7398,  0.7966],
+        │                 [ 1.3164,  1.2248, -2.2494]])
+        └── b -->
+            └── x --> tensor([[ 0.3578,  0.4611, -0.6668,  0.5356],
+                              [-1.4392, -1.2899, -0.0394,  0.8457],
+                              [ 0.4492, -0.5188, -0.2375, -1.2649]])
+    """
+    return torch.detach(input)
+
+
+# noinspection PyShadowingBuiltins
+@doc_from_base()
+@return_self
+@func_treelize()
+def detach_(input):
+    """
+    In-place version of :func:`treetensor.torch.detach`.
+
+    Examples::
+
+        >>> import torch
+        >>> import treetensor.torch as ttorch
+        >>> tt = ttorch.randn({
+        ...     'a': (2, 3),
+        ...     'b': {'x': (3, 4)},
+        ... })
+        >>> tt.requires_grad_(True)
+        >>> tt
+
+        ├── a --> tensor([[-0.1631, -1.1573,  1.3109],
+        │                 [ 2.7277, -0.0745, -1.2577]], requires_grad=True)
+        └── b -->
+            └── x --> tensor([[-0.5876,  0.9836,  1.9584, -0.1513],
+                              [ 0.5369, -1.3986,  0.9361,  0.6765],
+                              [ 0.6465, -0.2212,  1.5499, -1.2156]], requires_grad=True)
+
+        >>> ttorch.detach_(tt)
+
+        ├── a --> tensor([[-0.1631, -1.1573,  1.3109],
+        │                 [ 2.7277, -0.0745, -1.2577]])
+        └── b -->
+            └── x --> tensor([[-0.5876,  0.9836,  1.9584, -0.1513],
+                              [ 0.5369, -1.3986,  0.9361,  0.6765],
+                              [ 0.6465, -0.2212,  1.5499, -1.2156]])
+    """
+    return torch.detach_(input)
diff --git a/treetensor/torch/funcs/operation.py b/treetensor/torch/funcs/operation.py
index 6aef1fd41..456471030 100644
--- a/treetensor/torch/funcs/operation.py
+++ b/treetensor/torch/funcs/operation.py
@@ -117,10 +117,8 @@ def cat(tensors, *args, **kwargs):
 
 # noinspection PyShadowingNames
 @doc_from_base()
-@post_process(lambda r: tuple(r))
 @post_process(auto_tensor)
-@func_treelize(return_type=TreeValue, rise=dict(template=[None]))
-@post_process(lambda r: list(r))
+@func_treelize(return_type=TreeValue, rise=True)
 def split(tensor, split_size_or_sections, *args, **kwargs):
     """
     Splits the tensor into chunks. Each chunk is a view of the original tensor.
@@ -208,10 +206,8 @@
 
 # noinspection PyShadowingBuiltins
 @doc_from_base()
-@post_process(lambda r: tuple(r))
 @post_process(auto_tensor)
-@func_treelize(return_type=TreeValue, rise=dict(template=[None]))
-@post_process(lambda r: list(r))
+@func_treelize(return_type=TreeValue, rise=True)
 def chunk(input, chunks, *args, **kwargs):
     """
     Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor.
diff --git a/treetensor/torch/tensor.py b/treetensor/torch/tensor.py
index a222a370f..0f6d18a68 100644
--- a/treetensor/torch/tensor.py
+++ b/treetensor/torch/tensor.py
@@ -174,6 +174,72 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
         """
         return self.shape
 
+    @property
+    @method_treelize()
+    def grad(self):
+        """
+        Return the grad data of the whole tree.
+
+        Examples::
+
+            >>> import torch
+            >>> import treetensor.torch as ttorch
+            >>> tt = ttorch.randn({
+            ...     'a': (2, 3),
+            ...     'b': {'x': (3, 4)},
+            ... })
+            >>> tt.requires_grad_(True)
+            >>> tt
+
+            ├── a --> tensor([[-1.4375,  0.0988,  1.2198],
+            │                 [-0.7627, -0.8797, -0.9299]], requires_grad=True)
+            └── b -->
+                └── x --> tensor([[ 0.2149, -0.5839, -0.6049, -0.9151],
+                                  [ 1.5381, -1.4386,  0.1831,  0.2018],
+                                  [-0.0725, -0.9062, -2.6212,  0.5929]], requires_grad=True)
+            >>> mq = tt.mean() ** 2
+            >>> mq.backward()
+            >>> tt.grad
+
+            ├── a --> tensor([[-0.0438, -0.0438, -0.0438],
+            │                 [-0.0438, -0.0438, -0.0438]])
+            └── b -->
+                └── x --> tensor([[-0.0438, -0.0438, -0.0438, -0.0438],
+                                  [-0.0438, -0.0438, -0.0438, -0.0438],
+                                  [-0.0438, -0.0438, -0.0438, -0.0438]])
+        """
+        return self.grad
+
+    @property
+    @method_treelize(return_type=Object)
+    def requires_grad(self):
+        """
+        Return whether each tensor in this tree requires grad.
+
+        Examples::
+
+            >>> import torch
+            >>> import treetensor.torch as ttorch
+            >>> tt = ttorch.randn({
+            ...     'a': (2, 3),
+            ...     'b': {'x': (3, 4)},
+            ... })
+            >>> tt.requires_grad_(True)
+            >>> tt.requires_grad
+
+            ├── a --> True
+            └── b -->
+                └── x --> True
+
+            >>> tt.a.requires_grad_(False)
+            >>> tt.requires_grad
+
+            ├── a --> False
+            └── b -->
+                └── x --> True
+        """
+        return self.requires_grad
+
     @doc_from_base()
     @return_self
     @method_treelize()
@@ -181,9 +247,44 @@ def requires_grad_(self, requires_grad=True):
         """
         Change if autograd should record operations on this tensor: sets this
         tensor’s ``requires_grad`` attribute in-place. Returns this tensor.
+
+        Examples::
+
+            >>> import torch
+            >>> import treetensor.torch as ttorch
+            >>> tt = ttorch.randn({
+            ...     'a': (2, 3),
+            ...     'b': {'x': (3, 4)},
+            ... })
+            >>> tt.requires_grad_(True)
+            >>> tt
+
+            ├── a --> tensor([[ 1.4754,  1.1167,  1.5431],
+            │                 [-0.5816,  0.4746,  0.8392]], requires_grad=True)
+            └── b -->
+                └── x --> tensor([[ 0.3361,  0.8194,  0.1297, -0.5547],
+                                  [ 0.2531, -0.0637,  0.9822,  2.1618],
+                                  [ 2.0140, -0.0929,  0.9304,  1.5430]], requires_grad=True)
         """
         return self.requires_grad_(requires_grad)
+
+    @doc_from_base()
+    @method_treelize()
+    def detach(self):
+        """
+        See :func:`treetensor.torch.detach`.
+        """
+        return self.detach()
+
+    @doc_from_base()
+    @return_self
+    @method_treelize()
+    def detach_(self):
+        """
+        In-place version of :meth:`Tensor.detach`.
+        """
+        return self.detach_()
 
 # noinspection PyShadowingBuiltins,PyUnusedLocal
 @post_reduce(torch.all)
 @method_treelize(return_type=Object)
@@ -715,8 +816,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
 
     @doc_from_base()
     @post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r))
-    @method_treelize(return_type=TreeValue, rise=dict(template=[None]))
-    @post_process(lambda r: list(r))
+    @method_treelize(return_type=TreeValue, rise=True)
     def split(self, split_size, *args, **kwargs):
         """
         See :func:`treetensor.torch.split`.
@@ -725,8 +825,7 @@ def split(self, split_size, *args, **kwargs):
 
     @doc_from_base()
     @post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r))
-    @method_treelize(return_type=TreeValue, rise=dict(template=[None]))
-    @post_process(lambda r: list(r))
+    @method_treelize(return_type=TreeValue, rise=True)
    def chunk(self, chunks, *args, **kwargs):
         """
         See :func:`treetensor.torch.chunk`.
-- 
GitLab