提交 596fa782 编写于 作者: HansBug's avatar HansBug 😆

refactor(hansbug): Refactor the huge source files in treetensor and test

上级 8217663f
from .test_funcs import TestTorchFuncs
from .funcs import *
from .tensor import *
from .test_size import TestTorchSize
from .test_tensor import TestTorchTensor
from .test_comparison import TestTorchFuncsComparison
from .test_construct import TestTorchFuncsConstruct
from .test_math import TestTorchFuncsMath
from .test_matrix import TestTorchFuncsMatrix
from .test_operation import TestTorchFuncsOperation
from .test_reduction import TestTorchFuncsReduction
import treetensor.torch as ttorch
from treetensor.utils import replaceable_partial
from ...tests import choose_mark_with_existence_check
choose_mark = replaceable_partial(choose_mark_with_existence_check, base=ttorch)
import math
import torch
import treetensor.torch as ttorch
from .base import choose_mark
# noinspection DuplicatedCode,PyUnresolvedReferences
class TestTorchFuncsComparison:
@choose_mark()
def test_equal(self):
p1 = ttorch.equal(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3]))
assert isinstance(p1, bool)
assert p1
p2 = ttorch.equal(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 4]))
assert isinstance(p2, bool)
assert not p2
p3 = ttorch.equal({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}, ({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}))
assert isinstance(p3, bool)
assert p3
p4 = ttorch.equal({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}, ({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 5]),
}))
assert isinstance(p4, bool)
assert not p4
@choose_mark()
def test_eq(self):
assert ttorch.eq(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3])).all()
assert not ttorch.eq(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 2])).all()
assert ttorch.eq(torch.tensor([1, 1, 1]), 1).all()
assert not ttorch.eq(torch.tensor([1, 1, 2]), 1).all()
assert ttorch.eq({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}, ({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
})).all()
assert not ttorch.eq({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}, ({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 5]),
})).all()
@choose_mark()
def test_ne(self):
assert (ttorch.ne(
torch.tensor([[1, 2], [3, 4]]),
torch.tensor([[1, 1], [4, 4]]),
) == torch.tensor([[False, True],
[True, False]])).all()
assert (ttorch.ne(
ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': [1.0, 1.5, 2.0],
}),
ttorch.tensor({
'a': [[1, 1], [4, 4]],
'b': [1.3, 1.2, 2.0],
}),
) == ttorch.tensor({
'a': [[False, True], [True, False]],
'b': [True, True, False],
})).all()
@choose_mark()
def test_lt(self):
assert (ttorch.lt(
torch.tensor([[1, 2], [3, 4]]),
torch.tensor([[1, 1], [4, 4]]),
) == torch.tensor([[False, False],
[True, False]])).all()
assert (ttorch.lt(
ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': [1.0, 1.5, 2.0],
}),
ttorch.tensor({
'a': [[1, 1], [4, 4]],
'b': [1.3, 1.2, 2.0],
}),
) == ttorch.tensor({
'a': [[False, False], [True, False]],
'b': [True, False, False],
})).all()
@choose_mark()
def test_le(self):
assert (ttorch.le(
torch.tensor([[1, 2], [3, 4]]),
torch.tensor([[1, 1], [4, 4]]),
) == torch.tensor([[True, False],
[True, True]])).all()
assert (ttorch.le(
ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': [1.0, 1.5, 2.0],
}),
ttorch.tensor({
'a': [[1, 1], [4, 4]],
'b': [1.3, 1.2, 2.0],
}),
) == ttorch.tensor({
'a': [[True, False], [True, True]],
'b': [True, False, True],
})).all()
@choose_mark()
def test_gt(self):
assert (ttorch.gt(
torch.tensor([[1, 2], [3, 4]]),
torch.tensor([[1, 1], [4, 4]]),
) == torch.tensor([[False, True],
[False, False]])).all()
assert (ttorch.gt(
ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': [1.0, 1.5, 2.0],
}),
ttorch.tensor({
'a': [[1, 1], [4, 4]],
'b': [1.3, 1.2, 2.0],
}),
) == ttorch.tensor({
'a': [[False, True], [False, False]],
'b': [False, True, False],
})).all()
@choose_mark()
def test_ge(self):
assert (ttorch.ge(
torch.tensor([[1, 2], [3, 4]]),
torch.tensor([[1, 1], [4, 4]]),
) == torch.tensor([[True, True],
[False, True]])).all()
assert (ttorch.ge(
ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': [1.0, 1.5, 2.0],
}),
ttorch.tensor({
'a': [[1, 1], [4, 4]],
'b': [1.3, 1.2, 2.0],
}),
) == ttorch.tensor({
'a': [[True, True], [False, True]],
'b': [False, True, True],
})).all()
@choose_mark()
def test_isfinite(self):
t1 = ttorch.isfinite(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([True, False, True, False, False])).all()
t2 = ttorch.isfinite(ttorch.tensor({
'a': [1, float('inf'), 2, float('-inf'), float('nan')],
'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]}
}))
assert (t2 == ttorch.tensor({
'a': [True, False, True, False, False],
'b': {'x': [[True, False, True], [False, True, False]]},
}))
@choose_mark()
def test_isinf(self):
t1 = ttorch.isinf(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([False, True, False, True, False])).all()
t2 = ttorch.isinf(ttorch.tensor({
'a': [1, float('inf'), 2, float('-inf'), float('nan')],
'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]}
}))
assert (t2 == ttorch.tensor({
'a': [False, True, False, True, False],
'b': {'x': [[False, True, False], [True, False, False]]},
}))
@choose_mark()
def test_isnan(self):
t1 = ttorch.isnan(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([False, False, False, False, True])).all()
t2 = ttorch.isnan(ttorch.tensor({
'a': [1, float('inf'), 2, float('-inf'), float('nan')],
'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]}
}))
assert (t2 == ttorch.tensor({
'a': [False, False, False, False, True],
'b': {'x': [[False, False, False], [False, False, True]]},
})).all()
@choose_mark()
def test_isclose(self):
t1 = ttorch.isclose(
ttorch.tensor((1., 2, 3)),
ttorch.tensor((1 + 1e-10, 3, 4))
)
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([True, False, False])).all()
t2 = ttorch.isclose(
ttorch.tensor({
'a': [1., 2, 3],
'b': {'x': [[float('inf'), 4, 1e20],
[-math.inf, 2.2943, 9483.32]]},
}),
ttorch.tensor({
'a': [1 + 1e-10, 3, 4],
'b': {'x': [[math.inf, 6, 1e20 + 1],
[-float('inf'), 2.294300000001, 9484.32]]},
}),
)
assert (t2 == ttorch.tensor({
'a': [True, False, False],
'b': {'x': [[True, False, True],
[True, True, False]]},
})).all()
import torch
import treetensor.torch as ttorch
from .base import choose_mark
# noinspection DuplicatedCode,PyUnresolvedReferences
class TestTorchFuncsConstruct:
@choose_mark()
def test_tensor(self):
t1 = ttorch.tensor(True)
assert isinstance(t1, torch.Tensor)
assert t1
t2 = ttorch.tensor([[1, 2, 3], [4, 5, 6]])
assert isinstance(t2, torch.Tensor)
assert (t2 == torch.tensor([[1, 2, 3], [4, 5, 6]])).all()
t3 = ttorch.tensor({
'a': [1, 2],
'b': [[3, 4], [5, 6.2]],
'x': {
'c': True,
'd': [False, True],
}
})
assert isinstance(t3, ttorch.Tensor)
assert (t3 == ttorch.Tensor({
'a': torch.tensor([1, 2]),
'b': torch.tensor([[3, 4], [5, 6.2]]),
'x': {
'c': torch.tensor(True),
'd': torch.tensor([False, True]),
}
})).all()
@choose_mark()
def test_clone(self):
t1 = ttorch.clone(torch.tensor([1.0, 2.0, 1.5]))
assert isinstance(t1, torch.Tensor)
assert (t1 == torch.tensor([1.0, 2.0, 1.5])).all()
t2 = ttorch.clone(ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
}))
assert (t2 == ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})).all()
@choose_mark()
def test_zeros(self):
assert ttorch.all(ttorch.zeros(2, 3) == torch.zeros(2, 3))
assert ttorch.all(ttorch.zeros({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}) == ttorch.Tensor({
'a': torch.zeros(2, 3),
'b': torch.zeros(5, 6),
'x': {
'c': torch.zeros(2, 3, 4),
}
}))
@choose_mark()
def test_zeros_like(self):
assert ttorch.all(
ttorch.zeros_like(torch.tensor([[1, 2, 3], [4, 5, 6]])) ==
torch.tensor([[0, 0, 0], [0, 0, 0]]),
)
assert ttorch.all(
ttorch.zeros_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
}) == ttorch.Tensor({
'a': torch.tensor([[0, 0, 0], [0, 0, 0]]),
'b': torch.tensor([0, 0, 0, 0]),
'x': {
'c': torch.tensor([0, 0, 0]),
'd': torch.tensor([[[0, 0]]]),
}
})
)
@choose_mark()
def test_ones(self):
assert ttorch.all(ttorch.ones(2, 3) == torch.ones(2, 3))
assert ttorch.all(ttorch.ones({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}) == ttorch.Tensor({
'a': torch.ones(2, 3),
'b': torch.ones(5, 6),
'x': {
'c': torch.ones(2, 3, 4),
}
}))
@choose_mark()
def test_ones_like(self):
assert ttorch.all(
ttorch.ones_like(torch.tensor([[1, 2, 3], [4, 5, 6]])) ==
torch.tensor([[1, 1, 1], [1, 1, 1]])
)
assert ttorch.all(
ttorch.ones_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
}) == ttorch.Tensor({
'a': torch.tensor([[1, 1, 1], [1, 1, 1]]),
'b': torch.tensor([1, 1, 1, 1]),
'x': {
'c': torch.tensor([1, 1, 1]),
'd': torch.tensor([[[1, 1]]]),
}
})
)
@choose_mark()
def test_randn(self):
_target = ttorch.randn(200, 300)
assert -0.02 <= _target.view(60000).mean().tolist() <= 0.02
assert 0.98 <= _target.view(60000).std().tolist() <= 1.02
assert _target.shape == torch.Size([200, 300])
_target = ttorch.randn({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
})
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
'c': torch.Size([2, 3, 4]),
}
})
@choose_mark()
def test_randn_like(self):
_target = ttorch.randn_like(torch.ones(200, 300))
assert -0.02 <= _target.view(60000).mean().tolist() <= 0.02
assert 0.98 <= _target.view(60000).std().tolist() <= 1.02
assert _target.shape == torch.Size([200, 300])
_target = ttorch.randn_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32),
'b': torch.tensor([1, 2, 3, 4], dtype=torch.float32),
'x': {
'c': torch.tensor([5, 6, 7], dtype=torch.float32),
'd': torch.tensor([[[8, 9]]], dtype=torch.float32),
}
})
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
'c': torch.Size([3]),
'd': torch.Size([1, 1, 2]),
}
})
@choose_mark()
def test_randint(self):
_target = ttorch.randint(-10, 10, {
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
})
assert ttorch.all(_target < 10)
assert ttorch.all(-10 <= _target)
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
'c': torch.Size([2, 3, 4]),
}
})
_target = ttorch.randint(10, {
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
})
assert ttorch.all(_target < 10)
assert ttorch.all(0 <= _target)
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
'c': torch.Size([2, 3, 4]),
}
})
@choose_mark()
def test_randint_like(self):
_target = ttorch.randint_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
}, -10, 10)
assert ttorch.all(_target < 10)
assert ttorch.all(-10 <= _target)
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
'c': torch.Size([3]),
'd': torch.Size([1, 1, 2]),
}
})
_target = ttorch.randint_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
}, 10)
assert ttorch.all(_target < 10)
assert ttorch.all(0 <= _target)
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
'c': torch.Size([3]),
'd': torch.Size([1, 1, 2]),
}
})
@choose_mark()
def test_full(self):
_target = ttorch.full({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}, 233)
assert ttorch.all(_target == 233)
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
'c': torch.Size([2, 3, 4]),
}
})
@choose_mark()
def test_full_like(self):
_target = ttorch.full_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
}, 233)
assert ttorch.all(_target == 233)
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
'c': torch.Size([3]),
'd': torch.Size([1, 1, 2]),
}
})
@choose_mark()
def test_empty(self):
_target = ttorch.empty({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
})
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
'c': torch.Size([2, 3, 4]),
}
})
@choose_mark()
def test_empty_like(self):
_target = ttorch.empty_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
})
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
'c': torch.Size([3]),
'd': torch.Size([1, 1, 2]),
}
})
import torch
import treetensor.torch as ttorch
from .base import choose_mark
# noinspection DuplicatedCode,PyUnresolvedReferences
class TestTorchFuncsMatrix:
@choose_mark()
def test_dot(self):
t1 = ttorch.dot(torch.tensor([1, 2]), torch.tensor([2, 3]))
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == 8
t2 = ttorch.dot(
ttorch.tensor({
'a': [1, 2, 3],
'b': {'x': [3, 4]},
}),
ttorch.tensor({
'a': [5, 6, 7],
'b': {'x': [1, 2]},
})
)
assert (t2 == ttorch.tensor({'a': 38, 'b': {'x': 11}})).all()
@choose_mark()
def test_matmul(self):
t1 = ttorch.matmul(
torch.tensor([[1, 2], [3, 4]]),
torch.tensor([[5, 6], [7, 2]]),
)
assert isinstance(t1, torch.Tensor)
assert (t1 == torch.tensor([[19, 10], [43, 26]])).all()
t2 = ttorch.matmul(
ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': {'x': [3, 4, 5, 6]},
}),
ttorch.tensor({
'a': [[5, 6], [7, 2]],
'b': {'x': [4, 3, 2, 1]},
}),
)
assert (t2 == ttorch.tensor({
'a': [[19, 10], [43, 26]],
'b': {'x': 40}
})).all()
@choose_mark()
def test_mm(self):
t1 = ttorch.mm(
torch.tensor([[1, 2], [3, 4]]),
torch.tensor([[5, 6], [7, 2]]),
)
assert isinstance(t1, torch.Tensor)
assert (t1 == torch.tensor([[19, 10], [43, 26]])).all()
t2 = ttorch.mm(
ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': {'x': [[3, 4, 5], [6, 7, 8]]},
}),
ttorch.tensor({
'a': [[5, 6], [7, 2]],
'b': {'x': [[6, 5], [4, 3], [2, 1]]},
}),
)
assert (t2 == ttorch.tensor({
'a': [[19, 10], [43, 26]],
'b': {'x': [[44, 32], [80, 59]]},
})).all()
import torch
import treetensor.torch as ttorch
from .base import choose_mark
# noinspection DuplicatedCode,PyUnresolvedReferences
class TestTorchFuncsOperation:
@choose_mark()
def test_cat(self):
t1 = torch.tensor([[21, 29, 17],
[16, 11, 16]])
t2 = torch.tensor([[46, 46, 46],
[30, 47, 36]])
t3 = torch.tensor([[51, 65, 65],
[54, 67, 57]])
assert (ttorch.cat((t1, t2, t3)) == ttorch.tensor([[21, 29, 17],
[16, 11, 16],
[46, 46, 46],
[30, 47, 36],
[51, 65, 65],
[54, 67, 57]])).all()
tt1 = ttorch.Tensor({
'a': t1,
'b': {'x': t2, 'y': t3},
})
tt2 = ttorch.Tensor({
'a': t2,
'b': {'x': t3, 'y': t1},
})
tt3 = ttorch.Tensor({
'a': t3,
'b': {'x': t1, 'y': t2},
})
assert (ttorch.cat((tt1, tt2, tt3)) == ttorch.tensor({
'a': [[21, 29, 17],
[16, 11, 16],
[46, 46, 46],
[30, 47, 36],
[51, 65, 65],
[54, 67, 57]],
'b': {
'x': [[46, 46, 46],
[30, 47, 36],
[51, 65, 65],
[54, 67, 57],
[21, 29, 17],
[16, 11, 16]],
'y': [[51, 65, 65],
[54, 67, 57],
[21, 29, 17],
[16, 11, 16],
[46, 46, 46],
[30, 47, 36]],
}})).all()
assert (ttorch.cat((tt1, tt2, tt3), dim=1) == ttorch.tensor({
'a': [[21, 29, 17, 46, 46, 46, 51, 65, 65],
[16, 11, 16, 30, 47, 36, 54, 67, 57]],
'b': {
'x': [[46, 46, 46, 51, 65, 65, 21, 29, 17],
[30, 47, 36, 54, 67, 57, 16, 11, 16]],
'y': [[51, 65, 65, 21, 29, 17, 46, 46, 46],
[54, 67, 57, 16, 11, 16, 30, 47, 36]],
}})).all()
@choose_mark()
def test_split(self):
t1 = torch.tensor([[59, 82],
[86, 42],
[71, 84],
[61, 58],
[82, 37],
[14, 31]])
t1_a, t1_b, t1_c = ttorch.split(t1, (1, 2, 3))
assert (t1_a == torch.tensor([[59, 82]])).all()
assert (t1_b == torch.tensor([[86, 42],
[71, 84]])).all()
assert (t1_c == torch.tensor([[61, 58],
[82, 37],
[14, 31]])).all()
tt1 = ttorch.tensor({
'a': [[1, 65],
[68, 31],
[76, 73],
[74, 76],
[90, 0],
[95, 89]],
'b': {'x': [[[11, 20, 74],
[17, 85, 44]],
[[67, 37, 89],
[76, 28, 0]],
[[56, 12, 7],
[17, 63, 32]],
[[81, 75, 19],
[89, 21, 55]],
[[71, 53, 0],
[66, 82, 57]],
[[73, 81, 11],
[58, 54, 78]]]},
})
tt1_a, tt1_b, tt1_c = ttorch.split(tt1, (1, 2, 3))
assert (tt1_a == ttorch.tensor({
'a': [[1, 65]],
'b': [[[11, 20, 74],
[17, 85, 44]]]
})).all()
assert (tt1_b == ttorch.tensor({
'a': [[68, 31],
[76, 73]],
'b': [[[67, 37, 89],
[76, 28, 0]],
[[56, 12, 7],
[17, 63, 32]]]
})).all()
assert (tt1_c == ttorch.tensor({
'a': [[74, 76],
[90, 0],
[95, 89]],
'b': [[[81, 75, 19],
[89, 21, 55]],
[[71, 53, 0],
[66, 82, 57]],
[[73, 81, 11],
[58, 54, 78]]]
})).all()
@choose_mark()
def test_stack(self):
t1 = torch.tensor([[17, 15, 27],
[12, 17, 29]])
t2 = torch.tensor([[45, 41, 47],
[37, 37, 36]])
t3 = torch.tensor([[60, 50, 55],
[69, 54, 58]])
assert (ttorch.stack((t1, t2, t3)) == torch.tensor([[[17, 15, 27],
[12, 17, 29]],
[[45, 41, 47],
[37, 37, 36]],
[[60, 50, 55],
[69, 54, 58]]])).all()
tt1 = ttorch.tensor({
'a': [[25, 22, 29],
[19, 21, 27]],
'b': {'x': [[20, 17, 28, 10],
[28, 16, 27, 27],
[18, 21, 17, 12]]},
})
tt2 = ttorch.tensor({
'a': [[40, 44, 41],
[39, 44, 40]],
'b': {'x': [[44, 42, 38, 44],
[30, 44, 42, 31],
[36, 30, 33, 31]]}
})
assert (ttorch.stack((tt1, tt2)) == ttorch.tensor({
'a': [[[25, 22, 29],
[19, 21, 27]],
[[40, 44, 41],
[39, 44, 40]]],
'b': {'x': [[[20, 17, 28, 10],
[28, 16, 27, 27],
[18, 21, 17, 12]],
[[44, 42, 38, 44],
[30, 44, 42, 31],
[36, 30, 33, 31]]]},
})).all()
assert (ttorch.stack((tt1, tt2), dim=1) == ttorch.tensor({
'a': [[[25, 22, 29],
[40, 44, 41]],
[[19, 21, 27],
[39, 44, 40]]],
'b': {'x': [[[20, 17, 28, 10],
[44, 42, 38, 44]],
[[28, 16, 27, 27],
[30, 44, 42, 31]],
[[18, 21, 17, 12],
[36, 30, 33, 31]]]},
})).all()
@choose_mark()
def test_reshape(self):
t1 = ttorch.reshape(torch.tensor([[1, 2], [3, 4]]), (-1,))
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([1, 2, 3, 4])).all()
t2 = ttorch.reshape(ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': {'x': [[2], [3], [5], [7], [11], [13]]},
}), (-1,))
assert (t2 == ttorch.tensor({
'a': [1, 2, 3, 4],
'b': {'x': [2, 3, 5, 7, 11, 13]},
})).all()
@choose_mark()
def test_squeeze(self):
t1 = torch.randint(100, (2, 1, 2, 1, 2))
assert t1.shape == torch.Size([2, 1, 2, 1, 2])
assert ttorch.squeeze(t1).shape == torch.Size([2, 2, 2])
t2 = ttorch.randint(100, {
'a': (2, 1, 2, 1, 2),
'b': {'x': (2, 1, 1, 3)},
})
assert t2.shape == ttorch.Size({
'a': (2, 1, 2, 1, 2),
'b': {'x': (2, 1, 1, 3)},
})
assert ttorch.squeeze(t2).shape == ttorch.Size({
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
@choose_mark()
def test_unsqueeze(self):
t1 = torch.randint(100, (100,))
assert t1.shape == torch.Size([100])
assert ttorch.unsqueeze(t1, 0).shape == torch.Size([1, 100])
tt1 = ttorch.randint(100, {
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
assert tt1.shape == ttorch.Size({
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
assert ttorch.unsqueeze(tt1, 1).shape == ttorch.Size({
'a': (2, 1, 2, 2),
'b': {'x': (2, 1, 3)},
})
@choose_mark()
def test_where(self):
t1 = ttorch.where(
torch.tensor([[True, False], [False, True]]),
torch.tensor([[2, 8], [16, 4]]),
torch.tensor([[3, 11], [5, 7]]),
)
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([[2, 11],
[5, 4]])).all()
t2 = ttorch.tensor({
'a': [[27, 90, 80],
[12, 59, 5]],
'b': {'x': [[[71, 52, 92, 79],
[48, 4, 13, 96]],
[[72, 89, 44, 62],
[32, 4, 29, 76]],
[[6, 3, 93, 89],
[44, 89, 85, 90]]]},
})
assert (ttorch.where(t2 % 2 == 1, t2,
ttorch.zeros({'a': (2, 3), 'b': {'x': (3, 2, 4)}}, dtype=torch.long)) ==
ttorch.tensor({
'a': [[27, 0, 0],
[0, 59, 5]],
'b': {'x': [[[71, 0, 0, 79],
[0, 0, 13, 0]],
[[0, 89, 0, 0],
[0, 0, 29, 0]],
[[0, 3, 93, 89],
[0, 89, 85, 0]]]},
})).all()
import torch
import treetensor.torch as ttorch
from .base import choose_mark
# noinspection DuplicatedCode,PyUnresolvedReferences
class TestTorchFuncsReduction:
@choose_mark()
def test_all(self):
r1 = ttorch.all(torch.tensor([True, True, True]))
assert torch.is_tensor(r1)
assert r1 == torch.tensor(True)
assert r1
r2 = ttorch.all(torch.tensor([True, True, False]))
assert torch.is_tensor(r2)
assert r2 == torch.tensor(False)
assert not r2
r3 = ttorch.all(torch.tensor([False, False, False]))
assert torch.is_tensor(r3)
assert r3 == torch.tensor(False)
assert not r3
r4 = ttorch.all({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, True]),
}).all()
assert torch.is_tensor(r4)
assert r4 == torch.tensor(True)
assert r4
r5 = ttorch.all({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]),
}).all()
assert torch.is_tensor(r5)
assert r5 == torch.tensor(False)
assert not r5
r6 = ttorch.all({
'a': torch.tensor([False, False, False]),
'b': torch.tensor([False, False, False]),
}).all()
assert torch.is_tensor(r6)
assert r6 == torch.tensor(False)
assert not r6
@choose_mark()
def test_any(self):
r1 = ttorch.any(torch.tensor([True, True, True]))
assert torch.is_tensor(r1)
assert r1 == torch.tensor(True)
assert r1
r2 = ttorch.any(torch.tensor([True, True, False]))
assert torch.is_tensor(r2)
assert r2 == torch.tensor(True)
assert r2
r3 = ttorch.any(torch.tensor([False, False, False]))
assert torch.is_tensor(r3)
assert r3 == torch.tensor(False)
assert not r3
r4 = ttorch.any({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, True]),
}).all()
assert torch.is_tensor(r4)
assert r4 == torch.tensor(True)
assert r4
r5 = ttorch.any({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]),
}).all()
assert torch.is_tensor(r5)
assert r5 == torch.tensor(True)
assert r5
r6 = ttorch.any({
'a': torch.tensor([False, False, False]),
'b': torch.tensor([False, False, False]),
}).all()
assert torch.is_tensor(r6)
assert r6 == torch.tensor(False)
assert not r6
@choose_mark()
def test_min(self):
t1 = ttorch.min(torch.tensor([1.0, 2.0, 1.5]))
assert isinstance(t1, torch.Tensor)
assert t1 == torch.tensor(1.0)
assert ttorch.min(ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})) == ttorch.tensor({
'a': 1.0,
'b': {'x': 0.9},
})
@choose_mark()
def test_max(self):
t1 = ttorch.max(torch.tensor([1.0, 2.0, 1.5]))
assert isinstance(t1, torch.Tensor)
assert t1 == torch.tensor(2.0)
assert ttorch.max(ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})) == ttorch.tensor({
'a': 2.0,
'b': {'x': 2.5, }
})
@choose_mark()
def test_sum(self):
assert ttorch.sum(torch.tensor([1.0, 2.0, 1.5])) == torch.tensor(4.5)
assert ttorch.sum(ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})) == torch.tensor(11.0)
from .test_clazz import TestTorchTensorClass
from .test_comparison import TestTorchTensorComparison
from .test_math import TestTorchTensorMath
from .test_matrix import TestTorchTensorMatrix
from .test_operation import TestTorchTensorOperation
from .test_reduction import TestTorchTensorReduction
import treetensor.torch as ttorch
from treetensor.utils import replaceable_partial
from ...tests import choose_mark_with_existence_check
choose_mark = replaceable_partial(choose_mark_with_existence_check, base=ttorch.Tensor)
import numpy as np
import torch
from treevalue import typetrans, TreeValue, func_treelize
import treetensor.numpy as tnp
import treetensor.torch as ttorch
from treetensor.common import Object
from .base import choose_mark
_all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y)
# noinspection DuplicatedCode,PyUnresolvedReferences
class TestTorchTensorClass:
_DEMO_1 = ttorch.Tensor({
'a': [[1, 2, 3], [4, 5, 6]],
'b': [[1, 2], [5, 6]],
'x': {
'c': [3, 5, 6, 7],
'd': [[[1, 2], [8, 9]]],
}
})
_DEMO_2 = ttorch.Tensor({
'a': [[1, 2, 3], [4, 5, 6]],
'b': [[1, 2], [5, 60]],
'x': {
'c': [3, 5, 6, 7],
'd': [[[1, 2], [8, 9]]],
}
})
@choose_mark()
def test___init__(self):
assert (ttorch.Tensor([1, 2, 3]) == torch.tensor([1, 2, 3])).all()
assert (ttorch.Tensor([1, 2, 3], dtype=torch.float32) == torch.FloatTensor([1, 2, 3])).all()
assert (self._DEMO_1 == typetrans(TreeValue({
'a': ttorch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': ttorch.tensor([[1, 2], [5, 6]]),
'x': {
'c': ttorch.tensor([3, 5, 6, 7]),
'd': ttorch.tensor([[[1, 2], [8, 9]]]),
}
}), ttorch.Tensor)).all()
@choose_mark()
def test_numel(self):
assert self._DEMO_1.numel() == 18
@choose_mark()
def test_numpy(self):
assert tnp.all(self._DEMO_1.numpy() == tnp.ndarray({
'a': np.array([[1, 2, 3], [4, 5, 6]]),
'b': np.array([[1, 2], [5, 6]]),
'x': {
'c': np.array([3, 5, 6, 7]),
'd': np.array([[[1, 2], [8, 9]]]),
}
}))
@choose_mark()
def test_cpu(self):
assert ttorch.all(self._DEMO_1.cpu() == self._DEMO_1)
assert _all_is(self._DEMO_1.cpu(), self._DEMO_1).reduce(lambda **kws: all(kws.values()))
@choose_mark()
def test_to(self):
assert ttorch.all(self._DEMO_1.to(torch.float32) == ttorch.Tensor({
'a': torch.FloatTensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.FloatTensor([[1, 2], [5, 6]]),
'x': {
'c': torch.FloatTensor([3, 5, 6, 7]),
'd': torch.FloatTensor([[[1, 2], [8, 9]]]),
}
}))
@choose_mark()
def test_tolist(self):
assert self._DEMO_1.tolist() == Object({
'a': [[1, 2, 3], [4, 5, 6]],
'b': [[1, 2], [5, 6]],
'x': {
'c': [3, 5, 6, 7],
'd': [[[1, 2], [8, 9]]],
}
})
@choose_mark()
def test___eq__(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}) == ttorch.Tensor({
'a': [1, 21],
'b': {'x': [[-1, 3], [12, -10]]}
})) == ttorch.Tensor({
'a': [True, False],
'b': {'x': [[False, True], [False, False]]}
})).all()
@choose_mark()
def test___ne__(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}) != ttorch.Tensor({
'a': [1, 21],
'b': {'x': [[-1, 3], [12, -10]]}
})) == ttorch.Tensor({
'a': [False, True],
'b': {'x': [[True, False], [True, True]]}
})).all()
@choose_mark()
def test___lt__(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}) < ttorch.Tensor({
'a': [1, 21],
'b': {'x': [[-1, 3], [12, -10]]}
})) == ttorch.Tensor({
'a': [False, True],
'b': {'x': [[False, False], [True, False]]}
})).all()
@choose_mark()
def test___le__(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}) <= ttorch.Tensor({
'a': [1, 21],
'b': {'x': [[-1, 3], [12, -10]]}
})) == ttorch.Tensor({
'a': [True, True],
'b': {'x': [[False, True], [True, False]]}
})).all()
@choose_mark()
def test___gt__(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}) > ttorch.Tensor({
'a': [1, 21],
'b': {'x': [[-1, 3], [12, -10]]}
})) == ttorch.Tensor({
'a': [False, False],
'b': {'x': [[True, False], [False, True]]}
})).all()
@choose_mark()
def test___ge__(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}) >= ttorch.Tensor({
'a': [1, 21],
'b': {'x': [[-1, 3], [12, -10]]}
})) == ttorch.Tensor({
'a': [True, False],
'b': {'x': [[True, True], [False, True]]}
})).all()
@choose_mark()
def test_clone(self):
t1 = ttorch.tensor([1.0, 2.0, 1.5]).clone()
assert isinstance(t1, torch.Tensor)
assert (t1 == torch.tensor([1.0, 2.0, 1.5])).all()
t2 = ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
}).clone()
assert (t2 == ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})).all()
import math
import torch
import treetensor.torch as ttorch
from .base import choose_mark
# noinspection DuplicatedCode,PyUnresolvedReferences
class TestTorchTensorComparison:
@choose_mark()
def test_isfinite(self):
t1 = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isfinite()
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([True, False, True, False, False])).all()
t2 = ttorch.tensor({
'a': [1, float('inf'), 2, float('-inf'), float('nan')],
'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]}
}).isfinite()
assert (t2 == ttorch.tensor({
'a': [True, False, True, False, False],
'b': {'x': [[True, False, True], [False, True, False]]},
}))
@choose_mark()
def test_isinf(self):
t1 = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isinf()
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([False, True, False, True, False])).all()
t2 = ttorch.tensor({
'a': [1, float('inf'), 2, float('-inf'), float('nan')],
'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]}
}).isinf()
assert (t2 == ttorch.tensor({
'a': [False, True, False, True, False],
'b': {'x': [[False, True, False], [True, False, False]]},
}))
@choose_mark()
def test_isnan(self):
t1 = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isnan()
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([False, False, False, False, True])).all()
t2 = ttorch.tensor({
'a': [1, float('inf'), 2, float('-inf'), float('nan')],
'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]}
}).isnan()
assert (t2 == ttorch.tensor({
'a': [False, False, False, False, True],
'b': {'x': [[False, False, False], [False, False, True]]},
})).all()
@choose_mark()
def test_isclose(self):
t1 = ttorch.tensor((1., 2, 3)).isclose(ttorch.tensor((1 + 1e-10, 3, 4)))
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([True, False, False])).all()
t2 = ttorch.tensor({
'a': [1., 2, 3],
'b': {'x': [[float('inf'), 4, 1e20],
[-math.inf, 2.2943, 9483.32]]},
}).isclose(ttorch.tensor({
'a': [1 + 1e-10, 3, 4],
'b': {'x': [[math.inf, 6, 1e20 + 1],
[-float('inf'), 2.294300000001, 9484.32]]},
}))
assert (t2 == ttorch.tensor({
'a': [True, False, False],
'b': {'x': [[True, False, True],
[True, True, False]]},
})).all()
import math
import numpy as np
import torch
from treevalue import func_treelize, typetrans, TreeValue
import treetensor.numpy as tnp
import treetensor.torch as ttorch
from treetensor.common import Object
from treetensor.utils import replaceable_partial
from ..tests import choose_mark_with_existence_check
_all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y)
choose_mark = replaceable_partial(choose_mark_with_existence_check, base=ttorch.Tensor)
# noinspection PyUnresolvedReferences,DuplicatedCode
class TestTorchTensor:
_DEMO_1 = ttorch.Tensor({
'a': [[1, 2, 3], [4, 5, 6]],
'b': [[1, 2], [5, 6]],
'x': {
'c': [3, 5, 6, 7],
'd': [[[1, 2], [8, 9]]],
}
})
_DEMO_2 = ttorch.Tensor({
'a': [[1, 2, 3], [4, 5, 6]],
'b': [[1, 2], [5, 60]],
'x': {
'c': [3, 5, 6, 7],
'd': [[[1, 2], [8, 9]]],
}
})
from .base import choose_mark
@choose_mark()
def test___init__(self):
assert (ttorch.Tensor([1, 2, 3]) == torch.tensor([1, 2, 3])).all()
assert (ttorch.Tensor([1, 2, 3], dtype=torch.float32) == torch.FloatTensor([1, 2, 3])).all()
assert (self._DEMO_1 == typetrans(TreeValue({
'a': ttorch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': ttorch.tensor([[1, 2], [5, 6]]),
'x': {
'c': ttorch.tensor([3, 5, 6, 7]),
'd': ttorch.tensor([[[1, 2], [8, 9]]]),
}
}), ttorch.Tensor)).all()
@choose_mark()
def test_numel(self):
assert self._DEMO_1.numel() == 18
@choose_mark()
def test_numpy(self):
assert tnp.all(self._DEMO_1.numpy() == tnp.ndarray({
'a': np.array([[1, 2, 3], [4, 5, 6]]),
'b': np.array([[1, 2], [5, 6]]),
'x': {
'c': np.array([3, 5, 6, 7]),
'd': np.array([[[1, 2], [8, 9]]]),
}
}))
@choose_mark()
def test_cpu(self):
assert ttorch.all(self._DEMO_1.cpu() == self._DEMO_1)
assert _all_is(self._DEMO_1.cpu(), self._DEMO_1).reduce(lambda **kws: all(kws.values()))
@choose_mark()
def test_to(self):
assert ttorch.all(self._DEMO_1.to(torch.float32) == ttorch.Tensor({
'a': torch.FloatTensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.FloatTensor([[1, 2], [5, 6]]),
'x': {
'c': torch.FloatTensor([3, 5, 6, 7]),
'd': torch.FloatTensor([[[1, 2], [8, 9]]]),
}
}))
@choose_mark()
def test_all(self):
t1 = ttorch.Tensor({
'a': [True, True],
'b': {'x': [[True, True, ], [True, True, ]]}
}).all()
assert isinstance(t1, torch.Tensor)
assert t1.dtype == torch.bool
assert t1
t2 = ttorch.Tensor({
'a': [True, False],
'b': {'x': [[True, True, ], [True, True, ]]}
}).all()
assert isinstance(t2, torch.Tensor)
assert t2.dtype == torch.bool
assert not t2
@choose_mark()
def test_tolist(self):
assert self._DEMO_1.tolist() == Object({
'a': [[1, 2, 3], [4, 5, 6]],
'b': [[1, 2], [5, 6]],
'x': {
'c': [3, 5, 6, 7],
'd': [[[1, 2], [8, 9]]],
}
})
@choose_mark()
def test_any(self):
t1 = ttorch.Tensor({
'a': [True, False],
'b': {'x': [[False, False, ], [False, False, ]]}
}).any()
assert isinstance(t1, torch.Tensor)
assert t1.dtype == torch.bool
assert t1
t2 = ttorch.Tensor({
'a': [False, False],
'b': {'x': [[False, False, ], [False, False, ]]}
}).any()
assert isinstance(t2, torch.Tensor)
assert t2.dtype == torch.bool
assert not t2
@choose_mark()
def test_max(self):
t1 = ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}).max()
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == 3
@choose_mark()
def test_min(self):
t1 = ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}).min()
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == -1
@choose_mark()
def test_sum(self):
t1 = ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}).sum()
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == 7
@choose_mark()
def test___eq__(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}) == ttorch.Tensor({
'a': [1, 21],
'b': {'x': [[-1, 3], [12, -10]]}
})) == ttorch.Tensor({
'a': [True, False],
'b': {'x': [[False, True], [False, False]]}
})).all()
@choose_mark()
def test___ne__(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}) != ttorch.Tensor({
'a': [1, 21],
'b': {'x': [[-1, 3], [12, -10]]}
})) == ttorch.Tensor({
'a': [False, True],
'b': {'x': [[True, False], [True, True]]}
})).all()
@choose_mark()
def test___lt__(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}) < ttorch.Tensor({
'a': [1, 21],
'b': {'x': [[-1, 3], [12, -10]]}
})) == ttorch.Tensor({
'a': [False, True],
'b': {'x': [[False, False], [True, False]]}
})).all()
@choose_mark()
def test___le__(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}) <= ttorch.Tensor({
'a': [1, 21],
'b': {'x': [[-1, 3], [12, -10]]}
})) == ttorch.Tensor({
'a': [True, True],
'b': {'x': [[False, True], [True, False]]}
})).all()
@choose_mark()
def test___gt__(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}) > ttorch.Tensor({
'a': [1, 21],
'b': {'x': [[-1, 3], [12, -10]]}
})) == ttorch.Tensor({
'a': [False, False],
'b': {'x': [[True, False], [False, True]]}
})).all()
@choose_mark()
def test___ge__(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}) >= ttorch.Tensor({
'a': [1, 21],
'b': {'x': [[-1, 3], [12, -10]]}
})) == ttorch.Tensor({
'a': [True, False],
'b': {'x': [[True, True], [False, True]]}
})).all()
@choose_mark()
def test_clone(self):
t1 = ttorch.tensor([1.0, 2.0, 1.5]).clone()
assert isinstance(t1, torch.Tensor)
assert (t1 == torch.tensor([1.0, 2.0, 1.5])).all()
t2 = ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
}).clone()
assert (t2 == ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})).all()
@choose_mark()
def test_dot(self):
t1 = torch.tensor([1, 2]).dot(torch.tensor([2, 3]))
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == 8
t2 = ttorch.tensor({
'a': [1, 2, 3],
'b': {'x': [3, 4]},
}).dot(
ttorch.tensor({
'a': [5, 6, 7],
'b': {'x': [1, 2]},
})
)
assert (t2 == ttorch.tensor({'a': 38, 'b': {'x': 11}})).all()
@choose_mark()
def test_matmul(self):
t1 = torch.tensor([[1, 2], [3, 4]]).matmul(
torch.tensor([[5, 6], [7, 2]]),
)
assert isinstance(t1, torch.Tensor)
assert (t1 == torch.tensor([[19, 10], [43, 26]])).all()
t2 = ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': {'x': [3, 4, 5, 6]},
}).matmul(
ttorch.tensor({
'a': [[5, 6], [7, 2]],
'b': {'x': [4, 3, 2, 1]},
}),
)
assert (t2 == ttorch.tensor({
'a': [[19, 10], [43, 26]],
'b': {'x': 40}
})).all()
@choose_mark()
def test_mm(self):
t1 = torch.tensor([[1, 2], [3, 4]]).mm(
torch.tensor([[5, 6], [7, 2]]),
)
assert isinstance(t1, torch.Tensor)
assert (t1 == torch.tensor([[19, 10], [43, 26]])).all()
t2 = ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': {'x': [[3, 4, 5], [6, 7, 8]]},
}).mm(
ttorch.tensor({
'a': [[5, 6], [7, 2]],
'b': {'x': [[6, 5], [4, 3], [2, 1]]},
}),
)
assert (t2 == ttorch.tensor({
'a': [[19, 10], [43, 26]],
'b': {'x': [[44, 32], [80, 59]]},
})).all()
@choose_mark()
def test_isfinite(self):
t1 = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isfinite()
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([True, False, True, False, False])).all()
t2 = ttorch.tensor({
'a': [1, float('inf'), 2, float('-inf'), float('nan')],
'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]}
}).isfinite()
assert (t2 == ttorch.tensor({
'a': [True, False, True, False, False],
'b': {'x': [[True, False, True], [False, True, False]]},
}))
@choose_mark()
def test_isinf(self):
t1 = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isinf()
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([False, True, False, True, False])).all()
t2 = ttorch.tensor({
'a': [1, float('inf'), 2, float('-inf'), float('nan')],
'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]}
}).isinf()
assert (t2 == ttorch.tensor({
'a': [False, True, False, True, False],
'b': {'x': [[False, True, False], [True, False, False]]},
}))
@choose_mark()
def test_isnan(self):
t1 = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isnan()
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([False, False, False, False, True])).all()
t2 = ttorch.tensor({
'a': [1, float('inf'), 2, float('-inf'), float('nan')],
'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]}
}).isnan()
assert (t2 == ttorch.tensor({
'a': [False, False, False, False, True],
'b': {'x': [[False, False, False], [False, False, True]]},
})).all()
@choose_mark()
def test_isclose(self):
t1 = ttorch.tensor((1., 2, 3)).isclose(ttorch.tensor((1 + 1e-10, 3, 4)))
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([True, False, False])).all()
t2 = ttorch.tensor({
'a': [1., 2, 3],
'b': {'x': [[float('inf'), 4, 1e20],
[-math.inf, 2.2943, 9483.32]]},
}).isclose(ttorch.tensor({
'a': [1 + 1e-10, 3, 4],
'b': {'x': [[math.inf, 6, 1e20 + 1],
[-float('inf'), 2.294300000001, 9484.32]]},
}))
assert (t2 == ttorch.tensor({
'a': [True, False, False],
'b': {'x': [[True, False, True],
[True, True, False]]},
})).all()
# noinspection PyUnresolvedReferences
class TestTorchTensorMath:
@choose_mark()
def test_abs(self):
t1 = ttorch.tensor([12, 0, -3]).abs()
......@@ -1253,210 +889,3 @@ class TestTorchTensor:
'b': {'x': [[math.nan, 0.0792, -0.6021],
[1.2041, 0.5740, math.nan]]},
}), rtol=1e-4, atol=1e-4, equal_nan=True).all()
@choose_mark()
def test_split(self):
t1 = torch.tensor([[59, 82],
[86, 42],
[71, 84],
[61, 58],
[82, 37],
[14, 31]])
t1_a, t1_b, t1_c = t1.split((1, 2, 3))
assert (t1_a == torch.tensor([[59, 82]])).all()
assert (t1_b == torch.tensor([[86, 42],
[71, 84]])).all()
assert (t1_c == torch.tensor([[61, 58],
[82, 37],
[14, 31]])).all()
tt1 = ttorch.tensor({
'a': [[1, 65],
[68, 31],
[76, 73],
[74, 76],
[90, 0],
[95, 89]],
'b': {'x': [[[11, 20, 74],
[17, 85, 44]],
[[67, 37, 89],
[76, 28, 0]],
[[56, 12, 7],
[17, 63, 32]],
[[81, 75, 19],
[89, 21, 55]],
[[71, 53, 0],
[66, 82, 57]],
[[73, 81, 11],
[58, 54, 78]]]},
})
tt1_a, tt1_b, tt1_c = tt1.split((1, 2, 3))
assert (tt1_a == ttorch.tensor({
'a': [[1, 65]],
'b': [[[11, 20, 74],
[17, 85, 44]]]
})).all()
assert (tt1_b == ttorch.tensor({
'a': [[68, 31],
[76, 73]],
'b': [[[67, 37, 89],
[76, 28, 0]],
[[56, 12, 7],
[17, 63, 32]]]
})).all()
assert (tt1_c == ttorch.tensor({
'a': [[74, 76],
[90, 0],
[95, 89]],
'b': [[[81, 75, 19],
[89, 21, 55]],
[[71, 53, 0],
[66, 82, 57]],
[[73, 81, 11],
[58, 54, 78]]]
})).all()
@choose_mark()
def test_reshape(self):
t1 = torch.tensor([[1, 2], [3, 4]]).reshape((-1,))
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([1, 2, 3, 4])).all()
t2 = ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': {'x': [[2], [3], [5], [7], [11], [13]]},
}).reshape((-1,))
assert (t2 == ttorch.tensor({
'a': [1, 2, 3, 4],
'b': {'x': [2, 3, 5, 7, 11, 13]},
})).all()
@choose_mark()
def test_squeeze(self):
t1 = torch.randint(100, (2, 1, 2, 1, 2))
assert t1.shape == torch.Size([2, 1, 2, 1, 2])
assert t1.squeeze().shape == torch.Size([2, 2, 2])
t2 = ttorch.randint(100, {
'a': (2, 1, 2, 1, 2),
'b': {'x': (2, 1, 1, 3)},
})
assert t2.shape == ttorch.Size({
'a': (2, 1, 2, 1, 2),
'b': {'x': (2, 1, 1, 3)},
})
assert t2.squeeze().shape == ttorch.Size({
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
@choose_mark()
def test_squeeze_(self):
t1 = torch.randint(100, (2, 1, 2, 1, 2))
assert t1.shape == torch.Size([2, 1, 2, 1, 2])
t1r = t1.squeeze_()
assert t1r is t1
assert t1.shape == torch.Size([2, 2, 2])
t2 = ttorch.randint(100, {
'a': (2, 1, 2, 1, 2),
'b': {'x': (2, 1, 1, 3)},
})
assert t2.shape == ttorch.Size({
'a': (2, 1, 2, 1, 2),
'b': {'x': (2, 1, 1, 3)},
})
t2r = t2.squeeze_()
assert t2r is t2
assert t2.shape == ttorch.Size({
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
@choose_mark()
def test_unsqueeze(self):
t1 = torch.randint(100, (100,))
assert t1.shape == torch.Size([100])
assert t1.unsqueeze(0).shape == torch.Size([1, 100])
tt1 = ttorch.randint(100, {
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
assert tt1.shape == ttorch.Size({
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
assert tt1.unsqueeze(1).shape == ttorch.Size({
'a': (2, 1, 2, 2),
'b': {'x': (2, 1, 3)},
})
@choose_mark()
def test_unsqueeze_(self):
t1 = torch.randint(100, (100,))
assert t1.shape == torch.Size([100])
t1r = t1.unsqueeze_(0)
assert t1r is t1
assert t1.shape == torch.Size([1, 100])
tt1 = ttorch.randint(100, {
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
assert tt1.shape == ttorch.Size({
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
tt1r = tt1.unsqueeze_(1)
assert tt1r is tt1
assert tt1.shape == ttorch.Size({
'a': (2, 1, 2, 2),
'b': {'x': (2, 1, 3)},
})
@choose_mark()
def test_where(self):
t1 = torch.tensor([[2, 8], [16, 4]]).where(
torch.tensor([[True, False], [False, True]]),
torch.tensor([[3, 11], [5, 7]]),
)
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([[2, 11],
[5, 4]])).all()
t2 = ttorch.tensor({
'a': [[27, 90, 80],
[12, 59, 5]],
'b': {'x': [[[71, 52, 92, 79],
[48, 4, 13, 96]],
[[72, 89, 44, 62],
[32, 4, 29, 76]],
[[6, 3, 93, 89],
[44, 89, 85, 90]]]},
})
assert (t2.where(t2 % 2 == 1,
ttorch.zeros({'a': (2, 3), 'b': {'x': (3, 2, 4)}}, dtype=torch.long)) ==
ttorch.tensor({
'a': [[27, 0, 0],
[0, 59, 5]],
'b': {'x': [[[71, 0, 0, 79],
[0, 0, 13, 0]],
[[0, 89, 0, 0],
[0, 0, 29, 0]],
[[0, 3, 93, 89],
[0, 89, 85, 0]]]},
})).all()
import torch
import treetensor.torch as ttorch
from .base import choose_mark
# noinspection DuplicatedCode,PyUnresolvedReferences
class TestTorchTensorMatrix:
@choose_mark()
def test_dot(self):
t1 = torch.tensor([1, 2]).dot(torch.tensor([2, 3]))
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == 8
t2 = ttorch.tensor({
'a': [1, 2, 3],
'b': {'x': [3, 4]},
}).dot(
ttorch.tensor({
'a': [5, 6, 7],
'b': {'x': [1, 2]},
})
)
assert (t2 == ttorch.tensor({'a': 38, 'b': {'x': 11}})).all()
@choose_mark()
def test_matmul(self):
t1 = torch.tensor([[1, 2], [3, 4]]).matmul(
torch.tensor([[5, 6], [7, 2]]),
)
assert isinstance(t1, torch.Tensor)
assert (t1 == torch.tensor([[19, 10], [43, 26]])).all()
t2 = ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': {'x': [3, 4, 5, 6]},
}).matmul(
ttorch.tensor({
'a': [[5, 6], [7, 2]],
'b': {'x': [4, 3, 2, 1]},
}),
)
assert (t2 == ttorch.tensor({
'a': [[19, 10], [43, 26]],
'b': {'x': 40}
})).all()
@choose_mark()
def test_mm(self):
t1 = torch.tensor([[1, 2], [3, 4]]).mm(
torch.tensor([[5, 6], [7, 2]]),
)
assert isinstance(t1, torch.Tensor)
assert (t1 == torch.tensor([[19, 10], [43, 26]])).all()
t2 = ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': {'x': [[3, 4, 5], [6, 7, 8]]},
}).mm(
ttorch.tensor({
'a': [[5, 6], [7, 2]],
'b': {'x': [[6, 5], [4, 3], [2, 1]]},
}),
)
assert (t2 == ttorch.tensor({
'a': [[19, 10], [43, 26]],
'b': {'x': [[44, 32], [80, 59]]},
})).all()
import torch
import treetensor.torch as ttorch
from .base import choose_mark
# noinspection DuplicatedCode,PyUnresolvedReferences
class TestTorchTensorOperation:
@choose_mark()
def test_split(self):
t1 = torch.tensor([[59, 82],
[86, 42],
[71, 84],
[61, 58],
[82, 37],
[14, 31]])
t1_a, t1_b, t1_c = t1.split((1, 2, 3))
assert (t1_a == torch.tensor([[59, 82]])).all()
assert (t1_b == torch.tensor([[86, 42],
[71, 84]])).all()
assert (t1_c == torch.tensor([[61, 58],
[82, 37],
[14, 31]])).all()
tt1 = ttorch.tensor({
'a': [[1, 65],
[68, 31],
[76, 73],
[74, 76],
[90, 0],
[95, 89]],
'b': {'x': [[[11, 20, 74],
[17, 85, 44]],
[[67, 37, 89],
[76, 28, 0]],
[[56, 12, 7],
[17, 63, 32]],
[[81, 75, 19],
[89, 21, 55]],
[[71, 53, 0],
[66, 82, 57]],
[[73, 81, 11],
[58, 54, 78]]]},
})
tt1_a, tt1_b, tt1_c = tt1.split((1, 2, 3))
assert (tt1_a == ttorch.tensor({
'a': [[1, 65]],
'b': [[[11, 20, 74],
[17, 85, 44]]]
})).all()
assert (tt1_b == ttorch.tensor({
'a': [[68, 31],
[76, 73]],
'b': [[[67, 37, 89],
[76, 28, 0]],
[[56, 12, 7],
[17, 63, 32]]]
})).all()
assert (tt1_c == ttorch.tensor({
'a': [[74, 76],
[90, 0],
[95, 89]],
'b': [[[81, 75, 19],
[89, 21, 55]],
[[71, 53, 0],
[66, 82, 57]],
[[73, 81, 11],
[58, 54, 78]]]
})).all()
@choose_mark()
def test_reshape(self):
t1 = torch.tensor([[1, 2], [3, 4]]).reshape((-1,))
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([1, 2, 3, 4])).all()
t2 = ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': {'x': [[2], [3], [5], [7], [11], [13]]},
}).reshape((-1,))
assert (t2 == ttorch.tensor({
'a': [1, 2, 3, 4],
'b': {'x': [2, 3, 5, 7, 11, 13]},
})).all()
@choose_mark()
def test_squeeze(self):
t1 = torch.randint(100, (2, 1, 2, 1, 2))
assert t1.shape == torch.Size([2, 1, 2, 1, 2])
assert t1.squeeze().shape == torch.Size([2, 2, 2])
t2 = ttorch.randint(100, {
'a': (2, 1, 2, 1, 2),
'b': {'x': (2, 1, 1, 3)},
})
assert t2.shape == ttorch.Size({
'a': (2, 1, 2, 1, 2),
'b': {'x': (2, 1, 1, 3)},
})
assert t2.squeeze().shape == ttorch.Size({
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
@choose_mark()
def test_squeeze_(self):
t1 = torch.randint(100, (2, 1, 2, 1, 2))
assert t1.shape == torch.Size([2, 1, 2, 1, 2])
t1r = t1.squeeze_()
assert t1r is t1
assert t1.shape == torch.Size([2, 2, 2])
t2 = ttorch.randint(100, {
'a': (2, 1, 2, 1, 2),
'b': {'x': (2, 1, 1, 3)},
})
assert t2.shape == ttorch.Size({
'a': (2, 1, 2, 1, 2),
'b': {'x': (2, 1, 1, 3)},
})
t2r = t2.squeeze_()
assert t2r is t2
assert t2.shape == ttorch.Size({
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
@choose_mark()
def test_unsqueeze(self):
t1 = torch.randint(100, (100,))
assert t1.shape == torch.Size([100])
assert t1.unsqueeze(0).shape == torch.Size([1, 100])
tt1 = ttorch.randint(100, {
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
assert tt1.shape == ttorch.Size({
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
assert tt1.unsqueeze(1).shape == ttorch.Size({
'a': (2, 1, 2, 2),
'b': {'x': (2, 1, 3)},
})
@choose_mark()
def test_unsqueeze_(self):
t1 = torch.randint(100, (100,))
assert t1.shape == torch.Size([100])
t1r = t1.unsqueeze_(0)
assert t1r is t1
assert t1.shape == torch.Size([1, 100])
tt1 = ttorch.randint(100, {
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
assert tt1.shape == ttorch.Size({
'a': (2, 2, 2),
'b': {'x': (2, 3)},
})
tt1r = tt1.unsqueeze_(1)
assert tt1r is tt1
assert tt1.shape == ttorch.Size({
'a': (2, 1, 2, 2),
'b': {'x': (2, 1, 3)},
})
@choose_mark()
def test_where(self):
t1 = torch.tensor([[2, 8], [16, 4]]).where(
torch.tensor([[True, False], [False, True]]),
torch.tensor([[3, 11], [5, 7]]),
)
assert isinstance(t1, torch.Tensor)
assert (t1 == ttorch.tensor([[2, 11],
[5, 4]])).all()
t2 = ttorch.tensor({
'a': [[27, 90, 80],
[12, 59, 5]],
'b': {'x': [[[71, 52, 92, 79],
[48, 4, 13, 96]],
[[72, 89, 44, 62],
[32, 4, 29, 76]],
[[6, 3, 93, 89],
[44, 89, 85, 90]]]},
})
assert (t2.where(t2 % 2 == 1,
ttorch.zeros({'a': (2, 3), 'b': {'x': (3, 2, 4)}}, dtype=torch.long)) ==
ttorch.tensor({
'a': [[27, 0, 0],
[0, 59, 5]],
'b': {'x': [[[71, 0, 0, 79],
[0, 0, 13, 0]],
[[0, 89, 0, 0],
[0, 0, 29, 0]],
[[0, 3, 93, 89],
[0, 89, 85, 0]]]},
})).all()
import torch
import treetensor.torch as ttorch
from .base import choose_mark
# noinspection DuplicatedCode,PyUnresolvedReferences
class TestTorchTensorReduction:
@choose_mark()
def test_all(self):
t1 = ttorch.Tensor({
'a': [True, True],
'b': {'x': [[True, True, ], [True, True, ]]}
}).all()
assert isinstance(t1, torch.Tensor)
assert t1.dtype == torch.bool
assert t1
t2 = ttorch.Tensor({
'a': [True, False],
'b': {'x': [[True, True, ], [True, True, ]]}
}).all()
assert isinstance(t2, torch.Tensor)
assert t2.dtype == torch.bool
assert not t2
@choose_mark()
def test_any(self):
t1 = ttorch.Tensor({
'a': [True, False],
'b': {'x': [[False, False, ], [False, False, ]]}
}).any()
assert isinstance(t1, torch.Tensor)
assert t1.dtype == torch.bool
assert t1
t2 = ttorch.Tensor({
'a': [False, False],
'b': {'x': [[False, False, ], [False, False, ]]}
}).any()
assert isinstance(t2, torch.Tensor)
assert t2.dtype == torch.bool
assert not t2
@choose_mark()
def test_max(self):
t1 = ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}).max()
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == 3
@choose_mark()
def test_min(self):
t1 = ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}).min()
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == -1
@choose_mark()
def test_sum(self):
t1 = ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}).sum()
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == 7
import sys
from .comparison import *
from .comparison import __all__ as _comparison_all
from .construct import *
from .construct import __all__ as _construct_all
from .math import *
from .math import __all__ as _math_all
from .matrix import *
from .matrix import __all__ as _matrix_all
from .operation import *
from .operation import __all__ as _operation_all
from .reduction import *
from .reduction import __all__ as _reduction_all
from ...utils import module_autoremove
__all__ = [
*_comparison_all,
*_construct_all,
*_math_all,
*_matrix_all,
*_operation_all,
*_reduction_all,
]
_current_module = sys.modules[__name__]
_current_module = module_autoremove(_current_module)
sys.modules[__name__] = _current_module
import torch
from treevalue import TreeValue
from treevalue import func_treelize as original_func_treelize
from treevalue.tree.common import BaseTree
from treevalue.utils import post_process
from ..base import _auto_torch
from ..tensor import Tensor
from ...utils import doc_from_base as original_doc_from_base
from ...utils import replaceable_partial, args_mapping
func_treelize = post_process(post_process(args_mapping(
lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))(
replaceable_partial(original_func_treelize, return_type=Tensor)
)
doc_from_base = replaceable_partial(original_doc_from_base, base=torch)
auto_tensor = replaceable_partial(_auto_torch, cls=Tensor)
import builtins
import torch
from .base import doc_from_base, func_treelize
from ...common import ireduce
__all__ = [
'equal',
'isfinite', 'isinf', 'isnan', 'isclose',
'eq', 'ne', 'lt', 'le', 'gt', 'ge',
]
# noinspection PyShadowingBuiltins
@doc_from_base()
@ireduce(builtins.all)
@func_treelize()
def equal(input, other):
"""
In ``treetensor``, you can get the equality of the two tree tensors.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.equal(
... torch.tensor([1, 2, 3]),
... torch.tensor([1, 2, 3]),
... ) # the same as torch.equal
True
>>> ttorch.equal(
... ttorch.tensor({
... 'a': torch.tensor([1, 2, 3]),
... 'b': torch.tensor([[4, 5], [6, 7]]),
... }),
... ttorch.tensor({
... 'a': torch.tensor([1, 2, 3]),
... 'b': torch.tensor([[4, 5], [6, 7]]),
... }),
... )
True
"""
return torch.equal(input, other)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def isfinite(input):
"""
In ``treetensor``, you can get a tree of new tensors with boolean elements
representing if each element is `finite` or not.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.isfinite(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
tensor([ True, False, True, False, False])
>>> ttorch.isfinite(ttorch.tensor({
... 'a': [1, float('inf'), 2, float('-inf'), float('nan')],
... 'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]}
... }))
<Tensor 0x7fb782a15970>
├── a --> tensor([ True, False, True, False, False])
└── b --> <Tensor 0x7fb782a1e040>
└── x --> tensor([[ True, False, True],
[False, True, False]])
"""
return torch.isfinite(input)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def isinf(input):
"""
In ``treetensor``, you can test if each element of ``input``
is infinite (positive or negative infinity) or not.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.isinf(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
tensor([False, True, False, True, False])
>>> ttorch.isinf(ttorch.tensor({
... 'a': [1, float('inf'), 2, float('-inf'), float('nan')],
... 'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]}
... }))
<Tensor 0x7fb782a29b80>
├── a --> tensor([False, True, False, True, False])
└── b --> <Tensor 0x7fb782a2d1f0>
└── x --> tensor([[False, True, False],
[ True, False, False]])
"""
return torch.isinf(input)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def isnan(input):
"""
In ``treetensor``, you get a tree of new tensors with boolean elements representing
if each element of ``input`` is NaN or not
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.isnan(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
tensor([False, False, False, False, True])
>>> ttorch.isnan(ttorch.tensor({
... 'a': [1, float('inf'), 2, float('-inf'), float('nan')],
... 'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]}
... }))
<Tensor 0x7fb782a2d0a0>
├── a --> tensor([False, False, False, False, True])
└── b --> <Tensor 0x7fb782a29d90>
└── x --> tensor([[False, False, False],
[False, False, True]])
"""
return torch.isnan(input)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def isclose(input, other, *args, **kwargs):
"""
Returns a new tensor with boolean elements representing
if each element of ``input`` is “close” to the corresponding element of ``other``.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> import math
>>> ttorch.isclose(
... ttorch.tensor((1., 2, 3)),
... ttorch.tensor((1 + 1e-10, 3, 4))
... )
tensor([ True, False, False])
>>> ttorch.isclose(
... ttorch.tensor({
... 'a': [1., 2, 3],
... 'b': {'x': [[float('inf'), 4, 1e20],
... [-math.inf, 2.2943, 9483.32]]},
... }),
... ttorch.tensor({
... 'a': [1 + 1e-10, 3, 4],
... 'b': {'x': [[math.inf, 6, 1e20+1],
... [-float('inf'), 2.294300000001, 9484.32]]},
... }),
... )
<Tensor 0x7f5b3219f370>
├── a --> tensor([ True, False, False])
└── b --> <Tensor 0x7f5b3219f550>
└── x --> tensor([[ True, False, True],
[ True, True, False]])
"""
return torch.isclose(input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def eq(input, other, *args, **kwargs):
"""
In ``treetensor``, you can get the equality of the two tree tensors with :func:`eq`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.eq(
... torch.tensor([[1, 2], [3, 4]]),
... torch.tensor([[1, 1], [4, 4]]),
... )
tensor([[ True, False],
[False, True]])
>>> ttorch.eq(
... ttorch.tensor({
... 'a': [[1, 2], [3, 4]],
... 'b': [1.0, 1.5, 2.0],
... }),
... ttorch.tensor({
... 'a': [[1, 1], [4, 4]],
... 'b': [1.3, 1.2, 2.0],
... }),
... )
<Tensor 0x7ff363bbce10>
├── a --> tensor([[ True, False],
│ [False, True]])
└── b --> tensor([False, False, True])
"""
return torch.eq(input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def ne(input, other, *args, **kwargs):
"""
In ``treetensor``, you can get the non-equality of the two tree tensors with :func:`ne`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.ne(
... torch.tensor([[1, 2], [3, 4]]),
... torch.tensor([[1, 1], [4, 4]]),
... )
tensor([[False, True],
[ True, False]])
>>> ttorch.ne(
... ttorch.tensor({
... 'a': [[1, 2], [3, 4]],
... 'b': [1.0, 1.5, 2.0],
... }),
... ttorch.tensor({
... 'a': [[1, 1], [4, 4]],
... 'b': [1.3, 1.2, 2.0],
... }),
... )
<Tensor 0x7ff363bb6cf8>
├── a --> tensor([[False, True],
│ [ True, False]])
└── b --> tensor([ True, True, False])
"""
return torch.ne(input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def lt(input, other, *args, **kwargs):
"""
In ``treetensor``, you can get less-than situation of the two tree tensors with :func:`lt`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.lt(
... torch.tensor([[1, 2], [3, 4]]),
... torch.tensor([[1, 1], [4, 4]]),
... )
tensor([[False, False],
[ True, False]])
>>> ttorch.lt(
... ttorch.tensor({
... 'a': [[1, 2], [3, 4]],
... 'b': [1.0, 1.5, 2.0],
... }),
... ttorch.tensor({
... 'a': [[1, 1], [4, 4]],
... 'b': [1.3, 1.2, 2.0],
... }),
... )
<Tensor 0x7ff363bc67f0>
├── a --> tensor([[False, False],
│ [ True, False]])
└── b --> tensor([ True, False, False])
"""
return torch.lt(input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def le(input, other, *args, **kwargs):
"""
In ``treetensor``, you can get less-than-or-equal situation of the two tree tensors with :func:`le`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.le(
... torch.tensor([[1, 2], [3, 4]]),
... torch.tensor([[1, 1], [4, 4]]),
... )
tensor([[ True, False],
[ True, True]])
>>> ttorch.le(
... ttorch.tensor({
... 'a': [[1, 2], [3, 4]],
... 'b': [1.0, 1.5, 2.0],
... }),
... ttorch.tensor({
... 'a': [[1, 1], [4, 4]],
... 'b': [1.3, 1.2, 2.0],
... }),
... )
<Tensor 0x7ff363bc6198>
├── a --> tensor([[ True, False],
│ [ True, True]])
└── b --> tensor([ True, False, True])
"""
return torch.le(input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def gt(input, other, *args, **kwargs):
"""
In ``treetensor``, you can get greater-than situation of the two tree tensors with :func:`gt`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.gt(
... torch.tensor([[1, 2], [3, 4]]),
... torch.tensor([[1, 1], [4, 4]]),
... )
tensor([[False, True],
[False, False]])
>>> ttorch.gt(
... ttorch.tensor({
... 'a': [[1, 2], [3, 4]],
... 'b': [1.0, 1.5, 2.0],
... }),
... ttorch.tensor({
... 'a': [[1, 1], [4, 4]],
... 'b': [1.3, 1.2, 2.0],
... }),
... )
<Tensor 0x7ff363bc6f28>
├── a --> tensor([[False, True],
│ [False, False]])
└── b --> tensor([False, True, False])
"""
return torch.gt(input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def ge(input, other, *args, **kwargs):
"""
In ``treetensor``, you can get greater-than-or-equal situation of the two tree tensors with :func:`ge`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.ge(
... torch.tensor([[1, 2], [3, 4]]),
... torch.tensor([[1, 1], [4, 4]]),
... )
tensor([[ True, True],
[False, True]])
>>> ttorch.ge(
... ttorch.tensor({
... 'a': [[1, 2], [3, 4]],
... 'b': [1.0, 1.5, 2.0],
... }),
... ttorch.tensor({
... 'a': [[1, 1], [4, 4]],
... 'b': [1.3, 1.2, 2.0],
... }),
... )
<Tensor 0x7ff363bc6f28>
├── a --> tensor([[ True, True],
│ [False, True]])
└── b --> tensor([False, True, True])
"""
return torch.ge(input, other, *args, **kwargs)
import torch
from .base import doc_from_base, func_treelize
__all__ = [
'tensor', 'clone',
'zeros', 'zeros_like',
'randn', 'randn_like',
'randint', 'randint_like',
'ones', 'ones_like',
'full', 'full_like',
'empty', 'empty_like',
]
@doc_from_base()
@func_treelize()
def tensor(*args, **kwargs):
"""
In ``treetensor``, you can create a tree tensor with simple data structure.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.tensor(True) # the same as torch.tensor(True)
tensor(True)
>>> ttorch.tensor([1, 2, 3]) # the same as torch.tensor([1, 2, 3])
tensor([1, 2, 3])
>>> ttorch.tensor({'a': 1, 'b': [1, 2, 3], 'c': [[True, False], [False, True]]})
<Tensor 0x7ff363bbcc50>
├── a --> tensor(1)
├── b --> tensor([1, 2, 3])
└── c --> tensor([[ True, False],
[False, True]])
"""
return torch.tensor(*args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def clone(input, *args, **kwargs):
"""
In ``treetensor``, you can create a clone of the original tree with :func:`treetensor.torch.clone`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.clone(torch.tensor([[1, 2], [3, 4]]))
tensor([[1, 2],
[3, 4]])
>>> ttorch.clone(ttorch.tensor({
... 'a': [[1, 2], [3, 4]],
... 'b': {'x': [[5], [6], [7]]},
... }))
<Tensor 0x7f2a820ba5e0>
├── a --> tensor([[1, 2],
│ [3, 4]])
└── b --> <Tensor 0x7f2a820aaf70>
└── x --> tensor([[5],
[6],
[7]])
"""
return torch.clone(input, *args, **kwargs)
@doc_from_base()
@func_treelize()
def zeros(*args, **kwargs):
"""
In ``treetensor``, you can use ``zeros`` to create a tree of tensors with all zeros.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.zeros(2, 3) # the same as torch.zeros(2, 3)
tensor([[0., 0., 0.],
[0., 0., 0.]])
>>> ttorch.zeros({'a': (2, 3), 'b': {'x': (4, )}})
<Tensor 0x7f5f6ccf1ef0>
├── a --> tensor([[0., 0., 0.],
│ [0., 0., 0.]])
└── b --> <Tensor 0x7f5fe0107208>
└── x --> tensor([0., 0., 0., 0.])
"""
return torch.zeros(*args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def zeros_like(input, *args, **kwargs):
"""
In ``treetensor``, you can use ``zeros_like`` to create a tree of tensors with all zeros like another tree.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.zeros_like(torch.randn(2, 3)) # the same as torch.zeros_like(torch.randn(2, 3))
tensor([[0., 0., 0.],
[0., 0., 0.]])
>>> ttorch.zeros_like({
... 'a': torch.randn(2, 3),
... 'b': {'x': torch.randn(4, )},
... })
<Tensor 0x7ff363bb6128>
├── a --> tensor([[0., 0., 0.],
│ [0., 0., 0.]])
└── b --> <Tensor 0x7ff363bb6080>
└── x --> tensor([0., 0., 0., 0.])
"""
return torch.zeros_like(input, *args, **kwargs)
@doc_from_base()
@func_treelize()
def randn(*args, **kwargs):
"""
In ``treetensor``, you can use ``randn`` to create a tree of tensors with numbers
obey standard normal distribution.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.randn(2, 3) # the same as torch.randn(2, 3)
tensor([[-0.8534, -0.5754, -0.2507],
[ 0.0826, -1.4110, 0.9748]])
>>> ttorch.randn({'a': (2, 3), 'b': {'x': (4, )}})
<Tensor 0x7ff363bb6518>
├── a --> tensor([[ 0.5398, 0.7529, -2.0339],
│ [-0.5722, -1.1900, 0.7945]])
└── b --> <Tensor 0x7ff363bb6438>
└── x --> tensor([-0.7181, 0.1670, -1.3587, -1.5129])
"""
return torch.randn(*args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def randn_like(input, *args, **kwargs):
"""
In ``treetensor``, you can use ``randn_like`` to create a tree of tensors with numbers
obey standard normal distribution like another tree.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.randn_like(torch.ones(2, 3)) # the same as torch.randn_like(torch.ones(2, 3))
tensor([[ 1.8436, 0.2601, 0.9687],
[ 1.6430, -0.1765, -1.1732]])
>>> ttorch.randn_like({
... 'a': torch.ones(2, 3),
... 'b': {'x': torch.ones(4, )},
... })
<Tensor 0x7ff3d6f3cb38>
├── a --> tensor([[-0.1532, 1.3965, -1.2956],
│ [-0.0750, 0.6475, 1.1421]])
└── b --> <Tensor 0x7ff3d6f420b8>
└── x --> tensor([ 0.1730, 1.6085, 0.6487, -1.1022])
"""
return torch.randn_like(input, *args, **kwargs)
@doc_from_base()
@func_treelize()
def randint(*args, **kwargs):
"""
In ``treetensor``, you can use ``randint`` to create a tree of tensors with numbers
in an integer range.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.randint(10, (2, 3)) # the same as torch.randint(10, (2, 3))
tensor([[3, 4, 5],
[4, 5, 5]])
>>> ttorch.randint(10, {'a': (2, 3), 'b': {'x': (4, )}})
<Tensor 0x7ff363bb6438>
├── a --> tensor([[5, 3, 7],
│ [8, 1, 8]])
└── b --> <Tensor 0x7ff363bb6240>
└── x --> tensor([8, 8, 2, 4])
"""
return torch.randint(*args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def randint_like(input, *args, **kwargs):
"""
In ``treetensor``, you can use ``randint_like`` to create a tree of tensors with numbers
in an integer range.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.randint_like(torch.ones(2, 3), 10) # the same as torch.randint_like(torch.ones(2, 3), 10)
tensor([[0., 5., 0.],
[2., 0., 9.]])
>>> ttorch.randint_like({
... 'a': torch.ones(2, 3),
... 'b': {'x': torch.ones(4, )},
... }, 10)
<Tensor 0x7ff363bb6748>
├── a --> tensor([[3., 6., 1.],
│ [8., 9., 5.]])
└── b --> <Tensor 0x7ff363bb6898>
└── x --> tensor([4., 4., 7., 1.])
"""
return torch.randint_like(input, *args, **kwargs)
@doc_from_base()
@func_treelize()
def ones(*args, **kwargs):
"""
In ``treetensor``, you can use ``ones`` to create a tree of tensors with all ones.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.ones(2, 3) # the same as torch.ones(2, 3)
tensor([[1., 1., 1.],
[1., 1., 1.]])
>>> ttorch.ones({'a': (2, 3), 'b': {'x': (4, )}})
<Tensor 0x7ff363bb6eb8>
├── a --> tensor([[1., 1., 1.],
│ [1., 1., 1.]])
└── b --> <Tensor 0x7ff363bb6dd8>
└── x --> tensor([1., 1., 1., 1.])
"""
return torch.ones(*args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def ones_like(input, *args, **kwargs):
"""
In ``treetensor``, you can use ``ones_like`` to create a tree of tensors with all ones like another tree.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.ones_like(torch.randn(2, 3)) # the same as torch.ones_like(torch.randn(2, 3))
tensor([[1., 1., 1.],
[1., 1., 1.]])
>>> ttorch.ones_like({
... 'a': torch.randn(2, 3),
... 'b': {'x': torch.randn(4, )},
... })
<Tensor 0x7ff363bbc320>
├── a --> tensor([[1., 1., 1.],
│ [1., 1., 1.]])
└── b --> <Tensor 0x7ff363bbc240>
└── x --> tensor([1., 1., 1., 1.])
"""
return torch.ones_like(input, *args, **kwargs)
@doc_from_base()
@func_treelize()
def full(*args, **kwargs):
"""
In ``treetensor``, you can use ``ones`` to create a tree of tensors with the same value.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.full((2, 3), 2.3) # the same as torch.full((2, 3), 2.3)
tensor([[2.3000, 2.3000, 2.3000],
[2.3000, 2.3000, 2.3000]])
>>> ttorch.full({'a': (2, 3), 'b': {'x': (4, )}}, 2.3)
<Tensor 0x7ff363bbc7f0>
├── a --> tensor([[2.3000, 2.3000, 2.3000],
│ [2.3000, 2.3000, 2.3000]])
└── b --> <Tensor 0x7ff363bbc8d0>
└── x --> tensor([2.3000, 2.3000, 2.3000, 2.3000])
"""
return torch.full(*args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def full_like(input, *args, **kwargs):
"""
In ``treetensor``, you can use ``ones_like`` to create a tree of tensors with
all the same value of like another tree.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.full_like(torch.randn(2, 3), 2.3) # the same as torch.full_like(torch.randn(2, 3), 2.3)
tensor([[2.3000, 2.3000, 2.3000],
[2.3000, 2.3000, 2.3000]])
>>> ttorch.full_like({
... 'a': torch.randn(2, 3),
... 'b': {'x': torch.randn(4, )},
... }, 2.3)
<Tensor 0x7ff363bb6cf8>
├── a --> tensor([[2.3000, 2.3000, 2.3000],
│ [2.3000, 2.3000, 2.3000]])
└── b --> <Tensor 0x7ff363bb69e8>
└── x --> tensor([2.3000, 2.3000, 2.3000, 2.3000])
"""
return torch.full_like(input, *args, **kwargs)
@doc_from_base()
@func_treelize()
def empty(*args, **kwargs):
"""
In ``treetensor``, you can use ``ones`` to create a tree of tensors with
the uninitialized values.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.empty(2, 3) # the same as torch.empty(2, 3)
tensor([[-1.3267e-36, 3.0802e-41, 2.3000e+00],
[ 2.3000e+00, 2.3000e+00, 2.3000e+00]])
>>> ttorch.empty({'a': (2, 3), 'b': {'x': (4, )}})
<Tensor 0x7ff363bb6080>
├── a --> tensor([[-3.6515e+14, 4.5900e-41, -1.3253e-36],
│ [ 3.0802e-41, 2.3000e+00, 2.3000e+00]])
└── b --> <Tensor 0x7ff363bb66d8>
└── x --> tensor([-3.6515e+14, 4.5900e-41, -3.8091e-38, 3.0802e-41])
"""
return torch.empty(*args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def empty_like(input, *args, **kwargs):
"""
In ``treetensor``, you can use ``ones_like`` to create a tree of tensors with
all the uninitialized values of like another tree.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.empty_like(torch.randn(2, 3)) # the same as torch.empty_like(torch.randn(2, 3), 2.3)
tensor([[-3.6515e+14, 4.5900e-41, -1.3266e-36],
[ 3.0802e-41, 4.4842e-44, 0.0000e+00]])
>>> ttorch.empty_like({
... 'a': torch.randn(2, 3),
... 'b': {'x': torch.randn(4, )},
... })
<Tensor 0x7ff363bbc780>
├── a --> tensor([[-3.6515e+14, 4.5900e-41, -3.6515e+14],
│ [ 4.5900e-41, 1.1592e-41, 0.0000e+00]])
└── b --> <Tensor 0x7ff3d6f3cb38>
└── x --> tensor([-1.3267e-36, 3.0802e-41, -3.8049e-38, 3.0802e-41])
"""
return torch.empty_like(input, *args, **kwargs)
import torch
from .base import doc_from_base, func_treelize
__all__ = [
'dot', 'matmul', 'mm',
]
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def dot(input, other, *args, **kwargs):
"""
In ``treetensor``, you can get the dot product of 2 tree tensors with :func:`treetensor.torch.dot`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.dot(torch.tensor([1, 2]), torch.tensor([2, 3]))
tensor(8)
>>> ttorch.dot(
... ttorch.tensor({
... 'a': [1, 2, 3],
... 'b': {'x': [3, 4]},
... }),
... ttorch.tensor({
... 'a': [5, 6, 7],
... 'b': {'x': [1, 2]},
... })
... )
<Tensor 0x7feac55bde50>
├── a --> tensor(38)
└── b --> <Tensor 0x7feac55c9250>
└── x --> tensor(11)
"""
return torch.dot(input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def matmul(input, other, *args, **kwargs):
"""
In ``treetensor``, you can create a matrix product with :func:`treetensor.torch.matmul`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.matmul(
... torch.tensor([[1, 2], [3, 4]]),
... torch.tensor([[5, 6], [7, 2]]),
... )
tensor([[19, 10],
[43, 26]])
>>> ttorch.matmul(
... ttorch.tensor({
... 'a': [[1, 2], [3, 4]],
... 'b': {'x': [3, 4, 5, 6]},
... }),
... ttorch.tensor({
... 'a': [[5, 6], [7, 2]],
... 'b': {'x': [4, 3, 2, 1]},
... }),
... )
<Tensor 0x7f2e74883f40>
├── a --> tensor([[19, 10],
│ [43, 26]])
└── b --> <Tensor 0x7f2e74886430>
└── x --> tensor(40)
"""
return torch.matmul(input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def mm(input, mat2, *args, **kwargs):
"""
In ``treetensor``, you can create a matrix multiplication with :func:`treetensor.torch.mm`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.mm(
... torch.tensor([[1, 2], [3, 4]]),
... torch.tensor([[5, 6], [7, 2]]),
... )
tensor([[19, 10],
[43, 26]])
>>> ttorch.mm(
... ttorch.tensor({
... 'a': [[1, 2], [3, 4]],
... 'b': {'x': [[3, 4, 5], [6, 7, 8]]},
... }),
... ttorch.tensor({
... 'a': [[5, 6], [7, 2]],
... 'b': {'x': [[6, 5], [4, 3], [2, 1]]},
... }),
... )
<Tensor 0x7f2e7489f340>
├── a --> tensor([[19, 10],
│ [43, 26]])
└── b --> <Tensor 0x7f2e74896e50>
└── x --> tensor([[44, 32],
[80, 59]])
"""
return torch.mm(input, mat2, *args, **kwargs)
import torch
from treevalue import TreeValue
from treevalue.utils import post_process
from .base import doc_from_base, func_treelize, auto_tensor
__all__ = [
'cat', 'split', 'stack', 'reshape', 'where', 'squeeze', 'unsqueeze',
]
@doc_from_base()
@func_treelize(subside=dict(return_type=TreeValue))
def cat(tensors, *args, **kwargs):
"""
Concatenates the given sequence of ``seq`` tensors in the given dimension.
All tensors must either have the same shape (except in the concatenating dimension) or be empty.
Examples:
>>> import torch
>>> import treetensor.torch as ttorch
>>> t1 = torch.randint(10, 30, (2, 3))
>>> t1
tensor([[21, 29, 17],
[16, 11, 16]])
>>> t2 = torch.randint(30, 50, (2, 3))
tensor([[46, 46, 46],
[30, 47, 36]])
>>> t2
>>> t3 = torch.randint(50, 70, (2, 3))
tensor([[51, 65, 65],
[54, 67, 57]])
>>> t3
>>> ttorch.cat((t1, t2, t3))
tensor([[21, 29, 17],
[16, 11, 16],
[46, 46, 46],
[30, 47, 36],
[51, 65, 65],
[54, 67, 57]])
>>> tt1 = ttorch.Tensor({
... 'a': t1,
... 'b': {'x': t2, 'y': t3},
... })
>>> tt1
<Tensor 0x7fed579acf60>
├── a --> tensor([[21, 29, 17],
│ [16, 11, 16]])
└── b --> <Tensor 0x7fed579acf28>
├── x --> tensor([[46, 46, 46],
│ [30, 47, 36]])
└── y --> tensor([[51, 65, 65],
[54, 67, 57]])
>>> tt2 = ttorch.Tensor({
... 'a': t2,
... 'b': {'x': t3, 'y': t1},
... })
>>> tt2
<Tensor 0x7fed579d62e8>
├── a --> tensor([[46, 46, 46],
│ [30, 47, 36]])
└── b --> <Tensor 0x7fed579d62b0>
├── x --> tensor([[51, 65, 65],
│ [54, 67, 57]])
└── y --> tensor([[21, 29, 17],
[16, 11, 16]])
>>> tt3 = ttorch.Tensor({
... 'a': t3,
... 'b': {'x': t1, 'y': t2},
... })
>>> tt3
<Tensor 0x7fed579d66a0>
├── a --> tensor([[51, 65, 65],
│ [54, 67, 57]])
└── b --> <Tensor 0x7fed579d65f8>
├── x --> tensor([[21, 29, 17],
│ [16, 11, 16]])
└── y --> tensor([[46, 46, 46],
[30, 47, 36]]
>>> ttorch.cat((tt1, tt2, tt3))
<Tensor 0x7fed579d6ac8>
├── a --> tensor([[21, 29, 17],
│ [16, 11, 16],
│ [46, 46, 46],
│ [30, 47, 36],
│ [51, 65, 65],
│ [54, 67, 57]])
└── b --> <Tensor 0x7fed579d6a90>
├── x --> tensor([[46, 46, 46],
│ [30, 47, 36],
│ [51, 65, 65],
│ [54, 67, 57],
│ [21, 29, 17],
│ [16, 11, 16]])
└── y --> tensor([[51, 65, 65],
[54, 67, 57],
[21, 29, 17],
[16, 11, 16],
[46, 46, 46],
[30, 47, 36]])
>>> ttorch.cat((tt1, tt2, tt3), dim=1)
<Tensor 0x7fed579644a8>
├── a --> tensor([[21, 29, 17, 46, 46, 46, 51, 65, 65],
│ [16, 11, 16, 30, 47, 36, 54, 67, 57]])
└── b --> <Tensor 0x7fed57964438>
├── x --> tensor([[46, 46, 46, 51, 65, 65, 21, 29, 17],
│ [30, 47, 36, 54, 67, 57, 16, 11, 16]])
└── y --> tensor([[51, 65, 65, 21, 29, 17, 46, 46, 46],
[54, 67, 57, 16, 11, 16, 30, 47, 36]])
"""
return torch.cat(tensors, *args, **kwargs)
# noinspection PyShadowingNames
@doc_from_base()
@post_process(lambda r: tuple(map(auto_tensor, r)))
@func_treelize(return_type=TreeValue, rise=dict(template=[None]))
@post_process(lambda r: list(r))
def split(tensor, split_size_or_sections, *args, **kwargs):
"""
Splits the tensor into chunks. Each chunk is a view of the original tensor.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> t1 = torch.randint(100, (6, 2))
>>> t1
tensor([[59, 82],
[86, 42],
[71, 84],
[61, 58],
[82, 37],
[14, 31]])
>>> ttorch.split(t1, (1, 2, 3))
(tensor([[59, 82]]), tensor([[86, 42],
[71, 84]]), tensor([[61, 58],
[82, 37],
[14, 31]]))
>>> tt1 = ttorch.randint(100, {
... 'a': (6, 2),
... 'b': {'x': (6, 2, 3)},
... })
>>> tt1
<Tensor 0x7f4c8d786400>
├── a --> tensor([[ 1, 65],
│ [68, 31],
│ [76, 73],
│ [74, 76],
│ [90, 0],
│ [95, 89]])
└── b --> <Tensor 0x7f4c8d786320>
└── x --> tensor([[[11, 20, 74],
[17, 85, 44]],
[[67, 37, 89],
[76, 28, 0]],
[[56, 12, 7],
[17, 63, 32]],
[[81, 75, 19],
[89, 21, 55]],
[[71, 53, 0],
[66, 82, 57]],
[[73, 81, 11],
[58, 54, 78]]])
>>> ttorch.split(tt1, (1, 2, 3))
(<Tensor 0x7f4c8d7861d0>
├── a --> tensor([[ 1, 65]])
└── b --> <Tensor 0x7f4c8d786128>
└── x --> tensor([[[11, 20, 74],
[17, 85, 44]]])
, <Tensor 0x7f4c8d7860f0>
├── a --> tensor([[68, 31],
│ [76, 73]])
└── b --> <Tensor 0x7f4c8d7860b8>
└── x --> tensor([[[67, 37, 89],
[76, 28, 0]],
[[56, 12, 7],
[17, 63, 32]]])
, <Tensor 0x7f4c8d7866d8>
├── a --> tensor([[74, 76],
│ [90, 0],
│ [95, 89]])
└── b --> <Tensor 0x7f4c8d786668>
└── x --> tensor([[[81, 75, 19],
[89, 21, 55]],
[[71, 53, 0],
[66, 82, 57]],
[[73, 81, 11],
[58, 54, 78]]])
)
"""
return torch.split(tensor, split_size_or_sections, *args, **kwargs)
@doc_from_base()
@func_treelize(subside=dict(return_type=TreeValue))
def stack(tensors, *args, **kwargs):
"""
Concatenates a sequence of tensors along a new dimension.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> t1 = torch.randint(10, 30, (2, 3))
>>> t1
tensor([[17, 15, 27],
[12, 17, 29]])
>>> t2 = torch.randint(30, 50, (2, 3))
>>> t2
tensor([[45, 41, 47],
[37, 37, 36]])
>>> t3 = torch.randint(50, 70, (2, 3))
>>> t3
tensor([[60, 50, 55],
[69, 54, 58]])
>>> ttorch.stack((t1, t2, t3))
tensor([[[17, 15, 27],
[12, 17, 29]],
[[45, 41, 47],
[37, 37, 36]],
[[60, 50, 55],
[69, 54, 58]]])
>>> tt1 = ttorch.randint(10, 30, {
... 'a': (2, 3),
... 'b': {'x': (3, 4)},
... })
>>> tt1
<Tensor 0x7f4c8eba9630>
├── a --> tensor([[25, 22, 29],
│ [19, 21, 27]])
└── b --> <Tensor 0x7f4c8eba9550>
└── x --> tensor([[20, 17, 28, 10],
[28, 16, 27, 27],
[18, 21, 17, 12]])
>>> tt2 = ttorch.randint(30, 50, {
... 'a': (2, 3),
... 'b': {'x': (3, 4)},
... })
>>> tt2
<Tensor 0x7f4c8eba97b8>
├── a --> tensor([[40, 44, 41],
│ [39, 44, 40]])
└── b --> <Tensor 0x7f4c8eba9710>
└── x --> tensor([[44, 42, 38, 44],
[30, 44, 42, 31],
[36, 30, 33, 31]])
>>> ttorch.stack((tt1, tt2))
<Tensor 0x7f4c8eb411d0>
├── a --> tensor([[[25, 22, 29],
│ [19, 21, 27]],
│ [[40, 44, 41],
│ [39, 44, 40]]])
└── b --> <Tensor 0x7f4c8eb410b8>
└── x --> tensor([[[20, 17, 28, 10],
[28, 16, 27, 27],
[18, 21, 17, 12]],
[[44, 42, 38, 44],
[30, 44, 42, 31],
[36, 30, 33, 31]]])
>>> ttorch.stack((tt1, tt2), dim=1)
<Tensor 0x7f4c8eba9da0>
├── a --> tensor([[[25, 22, 29],
│ [40, 44, 41]],
│ [[19, 21, 27],
│ [39, 44, 40]]])
└── b --> <Tensor 0x7f4d01fb4898>
└── x --> tensor([[[20, 17, 28, 10],
[44, 42, 38, 44]],
[[28, 16, 27, 27],
[30, 44, 42, 31]],
[[18, 21, 17, 12],
[36, 30, 33, 31]]])
"""
return torch.stack(tensors, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def reshape(input, shape):
"""
Returns a tensor with the same data and number of elements as ``input``,
but with the specified shape. When possible, the returned tensor will be a view of ``input``.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.reshape(torch.tensor([[1, 2], [3, 4]]), (-1, ))
tensor([1, 2, 3, 4])
>>> ttorch.reshape(ttorch.tensor({
... 'a': [[1, 2], [3, 4]],
... 'b': {'x': [[2], [3], [5], [7], [11], [13]]},
... }), (-1, ))
<Tensor 0x7fc9efa3bda0>
├── a --> tensor([1, 2, 3, 4])
└── b --> <Tensor 0x7fc9efa3bcf8>
└── x --> tensor([ 2, 3, 5, 7, 11, 13])
.. note::
If the given ``shape`` is only one tuple, it should make sure that all the tensors
in this tree can be reshaped to the given ``shape``. Or you can give a tree of tuples
to reshape the tensors to different shapes.
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.reshape(ttorch.tensor({
... 'a': [[1, 2], [3, 4]],
... 'b': {'x': [[2], [3], [5], [7], [11], [13]]},
... }), {'a': (4, ), 'b': {'x': (3, 2)}})
<Tensor 0x7fc9efa3bd68>
├── a --> tensor([1, 2, 3, 4])
└── b --> <Tensor 0x7fc9efa3bf28>
└── x --> tensor([[ 2, 3],
[ 5, 7],
[11, 13]])
"""
return torch.reshape(input, shape)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def squeeze(input, *args, **kwargs):
"""
Returns a tensor with all the dimensions of ``input`` of size 1 removed.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> t1 = torch.randint(100, (2, 1, 2, 1, 2))
>>> t1.shape
torch.Size([2, 1, 2, 1, 2])
>>> ttorch.squeeze(t1).shape
torch.Size([2, 2, 2])
>>> tt1 = ttorch.randint(100, {
... 'a': (2, 1, 2, 1, 2),
... 'b': {'x': (2, 1, 1, 3)},
... })
>>> tt1.shape
<Size 0x7fa4c1b05410>
├── a --> torch.Size([2, 1, 2, 1, 2])
└── b --> <Size 0x7fa4c1b05510>
└── x --> torch.Size([2, 1, 1, 3])
>>> ttorch.squeeze(tt1).shape
<Size 0x7fa4c1b9f3d0>
├── a --> torch.Size([2, 2, 2])
└── b --> <Size 0x7fa4c1afe710>
└── x --> torch.Size([2, 3])
"""
return torch.squeeze(input, *args, *kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def unsqueeze(input, dim):
"""
Returns a new tensor with a dimension of size one inserted at the specified position.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> t1 = torch.randint(100, (100, ))
>>> t1.shape
torch.Size([100])
>>> ttorch.unsqueeze(t1, 0).shape
torch.Size([1, 100])
>>> tt1 = ttorch.randint(100, {
... 'a': (2, 2, 2),
... 'b': {'x': (2, 3)},
... })
>>> tt1.shape
<Size 0x7f5d1a5741d0>
├── a --> torch.Size([2, 2, 2])
└── b --> <Size 0x7f5d1a5740b8>
└── x --> torch.Size([2, 3])
>>> ttorch.unsqueeze(tt1, 1).shape
<Size 0x7f5d1a5c98d0>
├── a --> torch.Size([2, 1, 2, 2])
└── b --> <Size 0x7f5d1a5c99b0>
└── x --> torch.Size([2, 1, 3])
"""
return torch.unsqueeze(input, dim)
@doc_from_base()
@func_treelize()
def where(condition, x, y):
"""
Return a tree of tensors of elements selected from either ``x`` or ``y``, depending on ``condition``.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.where(
... torch.tensor([[True, False], [False, True]]),
... torch.tensor([[2, 8], [16, 4]]),
... torch.tensor([[3, 11], [5, 7]]),
... )
tensor([[ 2, 11],
[ 5, 4]])
>>> tt1 = ttorch.randint(1, 99, {'a': (2, 3), 'b': {'x': (3, 2, 4)}})
>>> tt1
<Tensor 0x7f6760ad9908>
├── a --> tensor([[27, 90, 80],
│ [12, 59, 5]])
└── b --> <Tensor 0x7f6760ad9860>
└── x --> tensor([[[71, 52, 92, 79],
[48, 4, 13, 96]],
[[72, 89, 44, 62],
[32, 4, 29, 76]],
[[ 6, 3, 93, 89],
[44, 89, 85, 90]]])
>>> ttorch.where(tt1 % 2 == 1, tt1, 0)
<Tensor 0x7f6760ad9d30>
├── a --> tensor([[27, 0, 0],
│ [ 0, 59, 5]])
└── b --> <Tensor 0x7f6760ad9f98>
└── x --> tensor([[[71, 0, 0, 79],
[ 0, 0, 13, 0]],
[[ 0, 89, 0, 0],
[ 0, 0, 29, 0]],
[[ 0, 3, 93, 89],
[ 0, 89, 85, 0]]])
"""
return torch.where(condition, x, y)
import torch
from .base import doc_from_base, func_treelize
from ..tensor import tireduce
from ...common import Object
__all__ = [
'all', 'any',
'min', 'max', 'sum',
]
# noinspection PyShadowingBuiltins
@doc_from_base()
@tireduce(torch.all)
@func_treelize(return_type=Object)
def all(input, *args, **kwargs):
"""
In ``treetensor``, you can get the ``all`` result of a whole tree with this function.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.all(torch.tensor([True, True])) # the same as torch.all
tensor(True)
>>> ttorch.all(ttorch.tensor({'a': [True, True], 'b': {'x': [True, True]}}))
tensor(True)
>>> ttorch.all(ttorch.tensor({'a': [True, True], 'b': {'x': [True, False]}}))
tensor(False)
.. note::
In this ``all`` function, the return value should be a tensor with single boolean value.
If what you need is a tree of boolean tensors, you should do like this
>>> ttorch.tensor({
... 'a': [True, True],
... 'b': {'x': [True, False]},
... }).map(lambda x: torch.all(x))
<Tensor 0x7ff363bbc588>
├── a --> tensor(True)
└── b --> <Tensor 0x7ff363bb6438>
└── x --> tensor(False)
"""
return torch.all(input, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@tireduce(torch.any)
@func_treelize(return_type=Object)
def any(input, *args, **kwargs):
"""
In ``treetensor``, you can get the ``any`` result of a whole tree with this function.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.any(torch.tensor([False, False])) # the same as torch.any
tensor(False)
>>> ttorch.any(ttorch.tensor({'a': [True, False], 'b': {'x': [False, False]}}))
tensor(True)
>>> ttorch.any(ttorch.tensor({'a': [False, False], 'b': {'x': [False, False]}}))
tensor(False)
.. note::
In this ``any`` function, the return value should be a tensor with single boolean value.
If what you need is a tree of boolean tensors, you should do like this
>>> ttorch.tensor({
>>> 'a': [True, False],
>>> 'b': {'x': [False, False]},
>>> }).map(lambda x: torch.any(x))
<Tensor 0x7ff363bc6898>
├── a --> tensor(True)
└── b --> <Tensor 0x7ff363bc67f0>
└── x --> tensor(False)
"""
return torch.any(input, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@tireduce(torch.min)
@func_treelize(return_type=Object)
def min(input, *args, **kwargs):
"""
In ``treetensor``, you can get the ``min`` result of a whole tree with this function.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.min(torch.tensor([1.0, 2.0, 1.5])) # the same as torch.min
tensor(1.)
>>> ttorch.min(ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }))
tensor(0.9000)
.. note::
In this ``min`` function, the return value should be a tensor with single value.
If what you need is a tree of tensors, you should do like this
>>> ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }).map(lambda x: torch.min(x))
<Tensor 0x7ff363bbb2b0>
├── a --> tensor(1.)
└── b --> <Tensor 0x7ff363bbb0b8>
└── x --> tensor(0.9000)
"""
return torch.min(input, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@tireduce(torch.max)
@func_treelize(return_type=Object)
def max(input, *args, **kwargs):
"""
In ``treetensor``, you can get the ``max`` result of a whole tree with this function.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.max(torch.tensor([1.0, 2.0, 1.5])) # the same as torch.max
tensor(2.)
>>> ttorch.max(ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }))
tensor(2.5000)
.. note::
In this ``max`` function, the return value should be a tensor with single value.
If what you need is a tree of tensors, you should do like this
>>> ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }).map(lambda x: torch.max(x))
<Tensor 0x7ff363bc6b00>
├── a --> tensor(2.)
└── b --> <Tensor 0x7ff363bc6c18>
└── x --> tensor(2.5000)
"""
return torch.max(input, *args, **kwargs)
# noinspection PyShadowingBuiltins
@doc_from_base()
@tireduce(torch.sum)
@func_treelize(return_type=Object)
def sum(input, *args, **kwargs):
"""
In ``treetensor``, you can get the ``sum`` result of a whole tree with this function.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.sum(torch.tensor([1.0, 2.0, 1.5])) # the same as torch.sum
tensor(4.5000)
>>> ttorch.sum(ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }))
tensor(11.)
.. note::
In this ``sum`` function, the return value should be a tensor with single value.
If what you need is a tree of tensors, you should do like this
>>> ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }).map(lambda x: torch.sum(x))
<Tensor 0x7ff363bbbda0>
├── a --> tensor(4.5000)
└── b --> <Tensor 0x7ff363bbbcf8>
└── x --> tensor(6.5000)
"""
return torch.sum(input, *args, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册