提交 42904f2e 编写于 作者: HansBug's avatar HansBug 😆

doc, test, dev(hansbug): complete code, documentation and unittest for clone, mm, matmul and dot

上级 f6e128a6
......@@ -4,7 +4,7 @@ import torch
import treetensor.torch as ttorch
# noinspection DuplicatedCode
# noinspection DuplicatedCode,PyUnresolvedReferences
@pytest.mark.unittest
class TestTorchFuncs:
def test_tensor(self):
......@@ -567,3 +567,80 @@ class TestTorchFuncs:
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})) == torch.tensor(11.0)
def test_clone(self):
t1 = ttorch.clone(torch.tensor([1.0, 2.0, 1.5]))
assert isinstance(t1, torch.Tensor)
assert (t1 == torch.tensor([1.0, 2.0, 1.5])).all()
t2 = ttorch.clone(ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
}))
assert (t2 == ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})).all()
def test_dot(self):
t1 = ttorch.dot(torch.tensor([1, 2]), torch.tensor([2, 3]))
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == 8
t2 = ttorch.dot(
ttorch.tensor({
'a': [1, 2, 3],
'b': {'x': [3, 4]},
}),
ttorch.tensor({
'a': [5, 6, 7],
'b': {'x': [1, 2]},
})
)
assert (t2 == ttorch.tensor({'a': 38, 'b': {'x': 11}})).all()
def test_matmul(self):
t1 = ttorch.matmul(
torch.tensor([[1, 2], [3, 4]]),
torch.tensor([[5, 6], [7, 2]]),
)
assert isinstance(t1, torch.Tensor)
assert (t1 == torch.tensor([[19, 10], [43, 26]])).all()
t2 = ttorch.matmul(
ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': {'x': [3, 4, 5, 6]},
}),
ttorch.tensor({
'a': [[5, 6], [7, 2]],
'b': {'x': [4, 3, 2, 1]},
}),
)
assert (t2 == ttorch.tensor({
'a': [[19, 10], [43, 26]],
'b': {'x': 40}
})).all()
def test_mm(self):
t1 = ttorch.mm(
torch.tensor([[1, 2], [3, 4]]),
torch.tensor([[5, 6], [7, 2]]),
)
assert isinstance(t1, torch.Tensor)
assert (t1 == torch.tensor([[19, 10], [43, 26]])).all()
t2 = ttorch.mm(
ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': {'x': [[3, 4, 5], [6, 7, 8]]},
}),
ttorch.tensor({
'a': [[5, 6], [7, 2]],
'b': {'x': [[6, 5], [4, 3], [2, 1]]},
}),
)
assert (t2 == ttorch.tensor({
'a': [[19, 10], [43, 26]],
'b': {'x': [[44, 32], [80, 59]]},
})).all()
......@@ -10,6 +10,7 @@ from treetensor.common import Object
_all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y)
# noinspection PyUnresolvedReferences
@pytest.mark.unittest
class TestTorchTensor:
_DEMO_1 = ttorch.Tensor({
......@@ -208,3 +209,75 @@ class TestTorchTensor:
'a': [True, False],
'b': {'x': [[True, True], [False, True]]}
})).all()
def test_clone(self):
t1 = ttorch.tensor([1.0, 2.0, 1.5]).clone()
assert isinstance(t1, torch.Tensor)
assert (t1 == torch.tensor([1.0, 2.0, 1.5])).all()
t2 = ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
}).clone()
assert (t2 == ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})).all()
def test_dot(self):
t1 = torch.tensor([1, 2]).dot(torch.tensor([2, 3]))
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == 8
t2 = ttorch.tensor({
'a': [1, 2, 3],
'b': {'x': [3, 4]},
}).dot(
ttorch.tensor({
'a': [5, 6, 7],
'b': {'x': [1, 2]},
})
)
assert (t2 == ttorch.tensor({'a': 38, 'b': {'x': 11}})).all()
def test_matmul(self):
t1 = torch.tensor([[1, 2], [3, 4]]).matmul(
torch.tensor([[5, 6], [7, 2]]),
)
assert isinstance(t1, torch.Tensor)
assert (t1 == torch.tensor([[19, 10], [43, 26]])).all()
t2 = ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': {'x': [3, 4, 5, 6]},
}).matmul(
ttorch.tensor({
'a': [[5, 6], [7, 2]],
'b': {'x': [4, 3, 2, 1]},
}),
)
assert (t2 == ttorch.tensor({
'a': [[19, 10], [43, 26]],
'b': {'x': 40}
})).all()
def test_mm(self):
t1 = torch.tensor([[1, 2], [3, 4]]).mm(
torch.tensor([[5, 6], [7, 2]]),
)
assert isinstance(t1, torch.Tensor)
assert (t1 == torch.tensor([[19, 10], [43, 26]])).all()
t2 = ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': {'x': [[3, 4, 5], [6, 7, 8]]},
}).mm(
ttorch.tensor({
'a': [[5, 6], [7, 2]],
'b': {'x': [[6, 5], [4, 3], [2, 1]]},
}),
)
assert (t2 == ttorch.tensor({
'a': [[19, 10], [43, 26]],
'b': {'x': [[44, 32], [80, 59]]},
})).all()
......@@ -20,7 +20,8 @@ __all__ = [
'all', 'any',
'min', 'max', 'sum',
'eq', 'ne', 'lt', 'le', 'gt', 'ge',
'equal', 'tensor',
'equal', 'tensor', 'clone',
'dot', 'matmul', 'mm',
]
func_treelize = post_process(post_process(args_mapping(
......@@ -816,3 +817,140 @@ def tensor(*args, **kwargs):
[False, True]])
"""
return torch.tensor(*args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from(torch.clone)
@func_treelize()
def clone(input, *args, **kwargs):
"""
In ``treetensor``, you can create a clone of the original tree with :func:`treetensor.torch.clone`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.clone(torch.tensor([[1, 2], [3, 4]]))
tensor([[1, 2],
[3, 4]])
>>> ttorch.clone(ttorch.tensor({
... 'a': [[1, 2], [3, 4]],
... 'b': {'x': [[5], [6], [7]]},
... }))
<Tensor 0x7f2a820ba5e0>
├── a --> tensor([[1, 2],
│ [3, 4]])
└── b --> <Tensor 0x7f2a820aaf70>
└── x --> tensor([[5],
[6],
[7]])
"""
return torch.clone(input, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from(torch.dot)
@func_treelize()
def dot(input, other, *args, **kwargs):
"""
In ``treetensor``, you can get the dot product of 2 tree tensors with :func:`treetensor.torch.dot`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.dot(torch.tensor([1, 2]), torch.tensor([2, 3]))
tensor(8)
>>> ttorch.dot(
... ttorch.tensor({
... 'a': [1, 2, 3],
... 'b': {'x': [3, 4]},
... }),
... ttorch.tensor({
... 'a': [5, 6, 7],
... 'b': {'x': [1, 2]},
... })
... )
<Tensor 0x7feac55bde50>
├── a --> tensor(38)
└── b --> <Tensor 0x7feac55c9250>
└── x --> tensor(11)
"""
return torch.dot(input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from(torch.matmul)
@func_treelize()
def matmul(input, other, *args, **kwargs):
"""
In ``treetensor``, you can create a matrix product with :func:`treetensor.torch.matmul`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.matmul(
... torch.tensor([[1, 2], [3, 4]]),
... torch.tensor([[5, 6], [7, 2]]),
... )
tensor([[19, 10],
[43, 26]])
>>> ttorch.matmul(
... ttorch.tensor({
... 'a': [[1, 2], [3, 4]],
... 'b': {'x': [3, 4, 5, 6]},
... }),
... ttorch.tensor({
... 'a': [[5, 6], [7, 2]],
... 'b': {'x': [4, 3, 2, 1]},
... }),
... )
<Tensor 0x7f2e74883f40>
├── a --> tensor([[19, 10],
│ [43, 26]])
└── b --> <Tensor 0x7f2e74886430>
└── x --> tensor(40)
"""
return torch.matmul(input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from(torch.mm)
@func_treelize()
def mm(input, mat2, *args, **kwargs):
"""
In ``treetensor``, you can create a matrix multiplication with :func:`treetensor.torch.mm`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.mm(
... torch.tensor([[1, 2], [3, 4]]),
... torch.tensor([[5, 6], [7, 2]]),
... )
tensor([[19, 10],
[43, 26]])
>>> ttorch.mm(
... ttorch.tensor({
... 'a': [[1, 2], [3, 4]],
... 'b': {'x': [[3, 4, 5], [6, 7, 8]]},
... }),
... ttorch.tensor({
... 'a': [[5, 6], [7, 2]],
... 'b': {'x': [[6, 5], [4, 3], [2, 1]]},
... }),
... )
<Tensor 0x7f2e7489f340>
├── a --> tensor([[19, 10],
│ [43, 26]])
└── b --> <Tensor 0x7f2e74896e50>
└── x --> tensor([[44, 32],
[80, 59]])
"""
return torch.mm(input, mat2, *args, **kwargs)
......@@ -261,3 +261,35 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
See :func:`treetensor.torch.ge`.
"""
return self >= other
@doc_from(torch.Tensor.clone)
@method_treelize()
def clone(self, *args, **kwargs):
"""
See :func:`treetensor.torch.clone`.
"""
return self.clone(*args, **kwargs)
@doc_from(torch.Tensor.dot)
@method_treelize()
def dot(self, other, *args, **kwargs):
"""
See :func:`treetensor.torch.dot`.
"""
return self.dot(other, *args, **kwargs)
@doc_from(torch.Tensor.mm)
@method_treelize()
def mm(self, mat2, *args, **kwargs):
"""
See :func:`treetensor.torch.mm`.
"""
return self.mm(mat2, *args, **kwargs)
@doc_from(torch.Tensor.matmul)
@method_treelize()
def matmul(self, tensor2, *args, **kwargs):
"""
See :func:`treetensor.torch.matmul`.
"""
return self.matmul(tensor2, *args, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册