Commit a1957a0a authored by HansBug 😆

dev(hansbug): complete autograd part

Parent fbf5cb25
@@ -7,5 +7,5 @@ Object
-----------------
.. autoclass:: Object
:members: __init__
:members: __init__, all, any
@@ -14,3 +14,13 @@ class TestCommonObject:
assert Object({'a': 1, 'b': 2}) == typetrans(TreeValue({
'a': 1, 'b': 2
}), Object)
def test_all(self):
assert not Object({'a': False, 'b': {'x': False}}).all()
assert not Object({'a': True, 'b': {'x': False}}).all()
assert Object({'a': True, 'b': {'x': True}}).all()
def test_any(self):
assert not Object({'a': False, 'b': {'x': False}}).any()
assert Object({'a': True, 'b': {'x': False}}).any()
assert Object({'a': True, 'b': {'x': True}}).any()
from .test_autograd import TestTorchFuncsAutograd
from .test_comparison import TestTorchFuncsComparison
from .test_construct import TestTorchFuncsConstruct
from .test_math import TestTorchFuncsMath
......
import treetensor.torch as ttorch
from .base import choose_mark
# noinspection DuplicatedCode,PyUnresolvedReferences
class TestTorchFuncsAutograd:
@choose_mark()
def test_detach(self):
tt1 = ttorch.tensor({
'a': [2, 3, 4.0],
'b': {'x': [[5, 6], [7, 8.0]]}
}, requires_grad=True)
assert tt1.requires_grad.all()
tt1r = ttorch.detach(tt1)
assert tt1.requires_grad.all()
assert tt1r is not tt1
assert not tt1r.requires_grad.any()
@choose_mark()
def test_detach_(self):
tt1 = ttorch.tensor({
'a': [2, 3, 4.0],
'b': {'x': [[5, 6], [7, 8.0]]}
}, requires_grad=True)
assert tt1.requires_grad.all()
tt1r = ttorch.detach_(tt1)
assert tt1r is tt1
assert not tt1.requires_grad.any()
from .test_autograd import TestTorchTensorAutograd
from .test_clazz import TestTorchTensorClass
from .test_comparison import TestTorchTensorComparison
from .test_math import TestTorchTensorMath
......
import treetensor.torch as ttorch
from .base import choose_mark
# noinspection DuplicatedCode,PyUnresolvedReferences
class TestTorchTensorAutograd:
@choose_mark()
def test_requires_grad(self):
tt1 = ttorch.tensor({
'a': [2, 3, 4.0],
'b': {'x': [[5, 6], [7, 8.0]]}
}, requires_grad=True)
assert tt1.requires_grad.all()
tt1.a.requires_grad_(False)
assert not tt1.requires_grad.all()
assert tt1.requires_grad.any()
tt1.b.x.requires_grad_(False)
assert not tt1.requires_grad.all()
assert not tt1.requires_grad.any()
@choose_mark()
def test_requires_grad_(self):
tt1 = ttorch.tensor({
'a': [2, 3, 4.0],
'b': {'x': [[5, 6], [7, 8.0]]}
})
assert not tt1.requires_grad.any()
tt1.requires_grad_(True)
assert tt1.requires_grad.all()
tt1.a.requires_grad_(False)
assert not tt1.requires_grad.all()
assert tt1.requires_grad.any()
tt1.b.x.requires_grad_(False)
assert not tt1.requires_grad.all()
assert not tt1.requires_grad.any()
@choose_mark()
def test_grad(self):
tt1 = ttorch.tensor({
'a': [2, 3, 4.0],
'b': {'x': [[5, 6], [7, 8.0]]}
}, requires_grad=True)
mq = tt1.mean() ** 2
mq.backward()
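# each element's gradient of mean()**2 is 2 * mean / n = 2 * 5 / 7 ≈ 1.4286 (7 elements, mean 5)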
assert ttorch.isclose(tt1.grad, ttorch.tensor({
'a': [1.4286, 1.4286, 1.4286],
'b': {'x': [[1.4286, 1.4286],
[1.4286, 1.4286]]},
}), atol=1e-4).all()
@choose_mark()
def test_detach(self):
tt1 = ttorch.tensor({
'a': [2, 3, 4.0],
'b': {'x': [[5, 6], [7, 8.0]]}
}, requires_grad=True)
assert tt1.requires_grad.all()
tt1r = tt1.detach()
assert tt1.requires_grad.all()
assert tt1r is not tt1
assert not tt1r.requires_grad.any()
@choose_mark()
def test_detach_(self):
tt1 = ttorch.tensor({
'a': [2, 3, 4.0],
'b': {'x': [[5, 6], [7, 8.0]]}
}, requires_grad=True)
assert tt1.requires_grad.all()
tt1r = tt1.detach_()
assert tt1r is tt1
assert not tt1.requires_grad.any()
import builtins
from treevalue import method_treelize
from .trees import BaseTreeStruct, clsmeta
from .wrappers import ireduce
__all__ = [
"Object",
@@ -33,3 +38,41 @@ class Object(BaseTreeStruct, metaclass=clsmeta(_object, allow_dict=True)):
└── c --> 233
"""
super(BaseTreeStruct, self).__init__(data)
@ireduce(builtins.all, piter=list)
@method_treelize()
def all(self):
"""
Return whether all the values in this tree are true.
Examples::
>>> from treetensor.common import Object
>>> Object({'a': False, 'b': {'x': False}}).all()
False
>>> Object({'a': True, 'b': {'x': False}}).all()
False
>>> Object({'a': True, 'b': {'x': True}}).all()
True
"""
return not not self
@ireduce(builtins.any, piter=list)
@method_treelize()
def any(self):
"""
Return whether any of the values in this tree is true.
Examples::
>>> from treetensor.common import Object
>>> Object({'a': False, 'b': {'x': False}}).any()
False
>>> Object({'a': True, 'b': {'x': False}}).any()
True
>>> Object({'a': True, 'b': {'x': True}}).any()
True
"""
return not not self
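Read together, the decorator stacks make ``all``/``any`` behave like the Python builtins applied over every leaf: ``method_treelize`` maps ``not not self`` across the leaves, and ``ireduce`` flattens and reduces the resulting booleans. A quick usage sketch with hypothetical non-boolean leaves, which are coerced the same way:

from treetensor.common import Object

obj = Object({'a': 1, 'b': {'x': 0}})
assert obj.any()        # the leaf 1 is truthy
assert not obj.all()    # the leaf 0 is falsy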
import sys
from .autograd import *
from .autograd import __all__ as _autograd_all
from .comparison import *
from .comparison import __all__ as _comparison_all
from .construct import *
@@ -15,6 +17,7 @@ from .reduction import __all__ as _reduction_all
from ...utils import module_autoremove
__all__ = [
*_autograd_all,
*_comparison_all,
*_construct_all,
*_math_all,
......
import torch
from .base import doc_from_base, func_treelize
from ...common import return_self
__all__ = [
'detach', 'detach_'
]
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def detach(input):
"""
Detach tensor from the computation graph.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> tt = ttorch.randn({
... 'a': (2, 3),
... 'b': {'x': (3, 4)},
... })
>>> tt.requires_grad_(True)
>>> tt
<Tensor 0x7f5881338eb8>
├── a --> tensor([[ 2.5262, 0.7398, 0.7966],
│ [ 1.3164, 1.2248, -2.2494]], requires_grad=True)
└── b --> <Tensor 0x7f5881338e10>
└── x --> tensor([[ 0.3578, 0.4611, -0.6668, 0.5356],
[-1.4392, -1.2899, -0.0394, 0.8457],
[ 0.4492, -0.5188, -0.2375, -1.2649]], requires_grad=True)
>>> ttorch.detach(tt)
<Tensor 0x7f588133a588>
├── a --> tensor([[ 2.5262, 0.7398, 0.7966],
│ [ 1.3164, 1.2248, -2.2494]])
└── b --> <Tensor 0x7f588133a4e0>
└── x --> tensor([[ 0.3578, 0.4611, -0.6668, 0.5356],
[-1.4392, -1.2899, -0.0394, 0.8457],
[ 0.4492, -0.5188, -0.2375, -1.2649]])
"""
return torch.detach(input)
# noinspection PyShadowingBuiltins
@doc_from_base()
@return_self
@func_treelize()
def detach_(input):
"""
In-place version of :func:`treetensor.torch.detach`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> tt = ttorch.randn({
... 'a': (2, 3),
... 'b': {'x': (3, 4)},
... })
>>> tt.requires_grad_(True)
>>> tt
<Tensor 0x7f588133aba8>
├── a --> tensor([[-0.1631, -1.1573, 1.3109],
│ [ 2.7277, -0.0745, -1.2577]], requires_grad=True)
└── b --> <Tensor 0x7f588133ab00>
└── x --> tensor([[-0.5876, 0.9836, 1.9584, -0.1513],
[ 0.5369, -1.3986, 0.9361, 0.6765],
[ 0.6465, -0.2212, 1.5499, -1.2156]], requires_grad=True)
>>> ttorch.detach_(tt)
<Tensor 0x7f588133aba8>
├── a --> tensor([[-0.1631, -1.1573, 1.3109],
│ [ 2.7277, -0.0745, -1.2577]])
└── b --> <Tensor 0x7f588133ab00>
└── x --> tensor([[-0.5876, 0.9836, 1.9584, -0.1513],
[ 0.5369, -1.3986, 0.9361, 0.6765],
[ 0.6465, -0.2212, 1.5499, -1.2156]])
"""
return torch.detach_(input)
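The ``@return_self`` wrapper is what makes the in-place variant hand back the original tree, which is why ``tt1r is tt1`` holds in the tests above. A minimal sketch of such a decorator, assuming it simply discards the per-leaf results and returns the first argument (the real ``treetensor.common.return_self`` may differ in detail):

import functools

def return_self(func):
    @functools.wraps(func)
    def _wrapper(self, *args, **kwargs):
        func(self, *args, **kwargs)  # run the in-place operation on every leaf
        return self                  # hand back the original tree, not the per-leaf results
    return _wrapper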
@@ -117,10 +117,8 @@ def cat(tensors, *args, **kwargs):
# noinspection PyShadowingNames
@doc_from_base()
@post_process(lambda r: tuple(r))
@post_process(auto_tensor)
@func_treelize(return_type=TreeValue, rise=dict(template=[None]))
@post_process(lambda r: list(r))
@func_treelize(return_type=TreeValue, rise=True)
def split(tensor, split_size_or_sections, *args, **kwargs):
"""
Splits the tensor into chunks. Each chunk is a view of the original tensor.
@@ -208,10 +206,8 @@ def split(tensor, split_size_or_sections, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from_base()
@post_process(lambda r: tuple(r))
@post_process(auto_tensor)
@func_treelize(return_type=TreeValue, rise=dict(template=[None]))
@post_process(lambda r: list(r))
@func_treelize(return_type=TreeValue, rise=True)
def chunk(input, chunks, *args, **kwargs):
"""
Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor.
......
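With the reworked decorator stack, the per-leaf tuples returned by ``torch.split``/``torch.chunk`` are risen into a tuple of tree tensors, and the outer ``post_process(lambda r: tuple(r))`` keeps the tuple type. A hypothetical usage sketch, with illustrative shapes and assuming ``torch.chunk``'s default ``dim=0``:

import treetensor.torch as ttorch

t = ttorch.randn({'a': (4, 3), 'b': {'x': (4, 2)}})
halves = ttorch.chunk(t, 2)           # a tuple of two tree tensors
assert isinstance(halves, tuple)
assert halves[0].a.shape == (2, 3)    # each half keeps the tree structure
assert halves[1].b.x.shape == (2, 2)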
@@ -174,6 +174,72 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.shape
@property
@method_treelize()
def grad(self):
"""
Return the gradient data of the whole tree.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> tt = ttorch.randn({
... 'a': (2, 3),
... 'b': {'x': (3, 4)},
... })
>>> tt.requires_grad_(True)
>>> tt
<Tensor 0x7feec3bcce80>
├── a --> tensor([[-1.4375, 0.0988, 1.2198],
│ [-0.7627, -0.8797, -0.9299]], requires_grad=True)
└── b --> <Tensor 0x7feec3bccdd8>
└── x --> tensor([[ 0.2149, -0.5839, -0.6049, -0.9151],
[ 1.5381, -1.4386, 0.1831, 0.2018],
[-0.0725, -0.9062, -2.6212, 0.5929]], requires_grad=True)
>>> mq = tt.mean() ** 2
>>> mq.backward()
>>> tt.grad
<Tensor 0x7feec3c0fa90>
├── a --> tensor([[-0.0438, -0.0438, -0.0438],
│ [-0.0438, -0.0438, -0.0438]])
└── b --> <Tensor 0x7feec3c0f9e8>
└── x --> tensor([[-0.0438, -0.0438, -0.0438, -0.0438],
[-0.0438, -0.0438, -0.0438, -0.0438],
[-0.0438, -0.0438, -0.0438, -0.0438]])
"""
return self.grad
@property
@method_treelize(return_type=Object)
def requires_grad(self):
"""
Return whether gradients are recorded for each tensor in the tree.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> tt = ttorch.randn({
... 'a': (2, 3),
... 'b': {'x': (3, 4)},
... })
>>> tt.requires_grad_(True)
>>> tt.requires_grad
<Object 0x7feec3c229e8>
├── a --> True
└── b --> <Object 0x7feec3c22940>
└── x --> True
>>> tt.a.requires_grad_(False)
>>> tt.requires_grad
<Object 0x7feec3c0fa58>
├── a --> False
└── b --> <Object 0x7feec3c0f5f8>
└── x --> True
"""
return self.requires_grad
@doc_from_base()
@return_self
@method_treelize()
@@ -181,9 +247,44 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
Change if autograd should record operations on this tensor:
sets this tensor’s ``requires_grad`` attribute in-place. Returns this tensor.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> tt = ttorch.randn({
... 'a': (2, 3),
... 'b': {'x': (3, 4)},
... })
>>> tt.requires_grad_(True)
>>> tt
<Tensor 0x7feec3c22240>
├── a --> tensor([[ 1.4754, 1.1167, 1.5431],
│ [-0.5816, 0.4746, 0.8392]], requires_grad=True)
└── b --> <Tensor 0x7feec3c22128>
└── x --> tensor([[ 0.3361, 0.8194, 0.1297, -0.5547],
[ 0.2531, -0.0637, 0.9822, 2.1618],
[ 2.0140, -0.0929, 0.9304, 1.5430]], requires_grad=True)
"""
return self.requires_grad_(requires_grad)
@doc_from_base()
@method_treelize()
def detach(self):
"""
See :func:`treetensor.torch.detach`.
"""
return self.detach()
@doc_from_base()
@return_self
@method_treelize()
def detach_(self):
"""
In-place version of :meth:`Tensor.detach`.
"""
return self.detach_()
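The method forms mirror the module-level functions; a short usage sketch based on the tests above (flat tree for brevity):

import treetensor.torch as ttorch

tt = ttorch.tensor({'a': [1.0, 2.0]}, requires_grad=True)
d = tt.detach()                   # a new tree, no longer tracked by autograd
assert d is not tt and not d.requires_grad.any()
tt.detach_()                      # the in-place variant returns the same tree
assert not tt.requires_grad.any()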
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.all)
@method_treelize(return_type=Object)
@@ -715,8 +816,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
@doc_from_base()
@post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r))
@method_treelize(return_type=TreeValue, rise=dict(template=[None]))
@post_process(lambda r: list(r))
@method_treelize(return_type=TreeValue, rise=True)
def split(self, split_size, *args, **kwargs):
"""
See :func:`treetensor.torch.split`.
@@ -725,8 +825,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
@doc_from_base()
@post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r))
@method_treelize(return_type=TreeValue, rise=dict(template=[None]))
@post_process(lambda r: list(r))
@method_treelize(return_type=TreeValue, rise=True)
def chunk(self, chunks, *args, **kwargs):
"""
See :func:`treetensor.torch.chunk`.
......