未验证 提交 554b8cae 编写于 作者: HansBug's avatar HansBug 😆 提交者: GitHub

Merge pull request #10 from opendilab/dev/stream

dev(hansbug): add stream support for paralleling the calculations in tree
import unittest
import pytest
import torch
import torch.cuda
import treetensor.torch as ttorch
_CUDA_OK = torch.cuda.is_available()
N, M, T = 200, 2, 50
S1, S2, S3 = 32, 128, 512
# noinspection DuplicatedCode
@pytest.mark.unittest
class TestTorchStream:
def test_simple(self):
a = ttorch.randn({f'a{i}': (S1, S2) for i in range(N)})
b = ttorch.randn({f'a{i}': (S2, S3) for i in range(N)})
c = ttorch.matmul(a, b)
for i in range(N):
assert torch.isclose(
c[f'a{i}'], torch.matmul(a[f'a{i}'], b[f'a{i}'])
).all(), f'Not match on item {f"a{i}"!r}.'
@unittest.skipUnless(_CUDA_OK, 'CUDA required')
def test_simple_with_cuda(self):
a = ttorch.randn({f'a{i}': (S1, S2) for i in range(N)}, device='cuda')
b = ttorch.randn({f'a{i}': (S2, S3) for i in range(N)}, device='cuda')
torch.cuda.synchronize()
c = ttorch.matmul(a, b)
torch.cuda.synchronize()
for i in range(N):
assert torch.isclose(
c[f'a{i}'], torch.matmul(a[f'a{i}'], b[f'a{i}'])
).all(), f'Not match on item {f"a{i}"!r}.'
@unittest.skipUnless(not _CUDA_OK, 'No CUDA required')
def test_stream_without_cuda(self):
with pytest.raises(AssertionError):
ttorch.stream(10)
@unittest.skipUnless(_CUDA_OK, 'CUDA required')
def test_stream_with_cuda(self):
a = ttorch.randn({f'a{i}': (S1, S2) for i in range(N)}, device='cuda')
b = ttorch.randn({f'a{i}': (S2, S3) for i in range(N)}, device='cuda')
ttorch.stream(4)
torch.cuda.synchronize()
c = ttorch.matmul(a, b)
torch.cuda.synchronize()
for i in range(N):
assert torch.isclose(
c[f'a{i}'], torch.matmul(a[f'a{i}'], b[f'a{i}'])
).all(), f'Not match on item {f"a{i}"!r}.'
...@@ -10,6 +10,8 @@ from .funcs import __all__ as _funcs_all ...@@ -10,6 +10,8 @@ from .funcs import __all__ as _funcs_all
from .funcs.base import get_func_from_torch from .funcs.base import get_func_from_torch
from .size import * from .size import *
from .size import __all__ as _size_all from .size import __all__ as _size_all
from .stream import *
from .stream import __all__ as _stream_all
from .tensor import * from .tensor import *
from .tensor import __all__ as _tensor_all from .tensor import __all__ as _tensor_all
from ..config.meta import __VERSION__ from ..config.meta import __VERSION__
...@@ -18,6 +20,7 @@ __all__ = [ ...@@ -18,6 +20,7 @@ __all__ = [
*_funcs_all, *_funcs_all,
*_size_all, *_size_all,
*_tensor_all, *_tensor_all,
*_stream_all,
] ]
_basic_types = ( _basic_types = (
......
...@@ -3,6 +3,7 @@ import builtins ...@@ -3,6 +3,7 @@ import builtins
import torch import torch
from .base import doc_from_base, func_treelize from .base import doc_from_base, func_treelize
from ..stream import stream_call
from ...common import ireduce from ...common import ireduce
__all__ = [ __all__ = [
...@@ -42,7 +43,7 @@ def equal(input, other): ...@@ -42,7 +43,7 @@ def equal(input, other):
... ) ... )
True True
""" """
return torch.equal(input, other) return stream_call(torch.equal, input, other)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -70,7 +71,7 @@ def isfinite(input): ...@@ -70,7 +71,7 @@ def isfinite(input):
└── x --> tensor([[ True, False, True], └── x --> tensor([[ True, False, True],
[False, True, False]]) [False, True, False]])
""" """
return torch.isfinite(input) return stream_call(torch.isfinite, input)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -98,7 +99,7 @@ def isinf(input): ...@@ -98,7 +99,7 @@ def isinf(input):
└── x --> tensor([[False, True, False], └── x --> tensor([[False, True, False],
[ True, False, False]]) [ True, False, False]])
""" """
return torch.isinf(input) return stream_call(torch.isinf, input)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -126,7 +127,7 @@ def isnan(input): ...@@ -126,7 +127,7 @@ def isnan(input):
└── x --> tensor([[False, False, False], └── x --> tensor([[False, False, False],
[False, False, True]]) [False, False, True]])
""" """
return torch.isnan(input) return stream_call(torch.isnan, input)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -166,7 +167,7 @@ def isclose(input, other, *args, **kwargs): ...@@ -166,7 +167,7 @@ def isclose(input, other, *args, **kwargs):
└── x --> tensor([[ True, False, True], └── x --> tensor([[ True, False, True],
[ True, True, False]]) [ True, True, False]])
""" """
return torch.isclose(input, other, *args, **kwargs) return stream_call(torch.isclose, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -202,7 +203,7 @@ def eq(input, other, *args, **kwargs): ...@@ -202,7 +203,7 @@ def eq(input, other, *args, **kwargs):
│ [False, True]]) │ [False, True]])
└── b --> tensor([False, False, True]) └── b --> tensor([False, False, True])
""" """
return torch.eq(input, other, *args, **kwargs) return stream_call(torch.eq, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -238,7 +239,7 @@ def ne(input, other, *args, **kwargs): ...@@ -238,7 +239,7 @@ def ne(input, other, *args, **kwargs):
│ [ True, False]]) │ [ True, False]])
└── b --> tensor([ True, True, False]) └── b --> tensor([ True, True, False])
""" """
return torch.ne(input, other, *args, **kwargs) return stream_call(torch.ne, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -274,7 +275,7 @@ def lt(input, other, *args, **kwargs): ...@@ -274,7 +275,7 @@ def lt(input, other, *args, **kwargs):
│ [ True, False]]) │ [ True, False]])
└── b --> tensor([ True, False, False]) └── b --> tensor([ True, False, False])
""" """
return torch.lt(input, other, *args, **kwargs) return stream_call(torch.lt, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -310,7 +311,7 @@ def le(input, other, *args, **kwargs): ...@@ -310,7 +311,7 @@ def le(input, other, *args, **kwargs):
│ [ True, True]]) │ [ True, True]])
└── b --> tensor([ True, False, True]) └── b --> tensor([ True, False, True])
""" """
return torch.le(input, other, *args, **kwargs) return stream_call(torch.le, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -346,7 +347,7 @@ def gt(input, other, *args, **kwargs): ...@@ -346,7 +347,7 @@ def gt(input, other, *args, **kwargs):
│ [False, False]]) │ [False, False]])
└── b --> tensor([False, True, False]) └── b --> tensor([False, True, False])
""" """
return torch.gt(input, other, *args, **kwargs) return stream_call(torch.gt, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -382,4 +383,4 @@ def ge(input, other, *args, **kwargs): ...@@ -382,4 +383,4 @@ def ge(input, other, *args, **kwargs):
│ [False, True]]) │ [False, True]])
└── b --> tensor([False, True, True]) └── b --> tensor([False, True, True])
""" """
return torch.ge(input, other, *args, **kwargs) return stream_call(torch.ge, input, other, *args, **kwargs)
...@@ -3,6 +3,7 @@ from treevalue import TreeValue ...@@ -3,6 +3,7 @@ from treevalue import TreeValue
from treevalue.tree.common import TreeStorage from treevalue.tree.common import TreeStorage
from .base import doc_from_base, func_treelize from .base import doc_from_base, func_treelize
from ..stream import stream_call
from ...utils import args_mapping from ...utils import args_mapping
__all__ = [ __all__ = [
...@@ -42,7 +43,7 @@ def tensor(data, *args, **kwargs): ...@@ -42,7 +43,7 @@ def tensor(data, *args, **kwargs):
└── c --> tensor([[ True, False], └── c --> tensor([[ True, False],
[False, True]]) [False, True]])
""" """
return torch.tensor(data, *args, **kwargs) return stream_call(torch.tensor, data, *args, **kwargs)
@doc_from_base() @doc_from_base()
...@@ -72,7 +73,7 @@ def as_tensor(data, *args, **kwargs): ...@@ -72,7 +73,7 @@ def as_tensor(data, *args, **kwargs):
└── x --> tensor([[4., 5.], └── x --> tensor([[4., 5.],
[6., 7.]]) [6., 7.]])
""" """
return torch.as_tensor(data, *args, **kwargs) return stream_call(torch.as_tensor, data, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -102,7 +103,7 @@ def clone(input, *args, **kwargs): ...@@ -102,7 +103,7 @@ def clone(input, *args, **kwargs):
[6], [6],
[7]]) [7]])
""" """
return torch.clone(input, *args, **kwargs) return stream_call(torch.clone, input, *args, **kwargs)
@doc_from_base() @doc_from_base()
...@@ -127,7 +128,7 @@ def zeros(*args, **kwargs): ...@@ -127,7 +128,7 @@ def zeros(*args, **kwargs):
└── b --> <Tensor 0x7f5fe0107208> └── b --> <Tensor 0x7f5fe0107208>
└── x --> tensor([0., 0., 0., 0.]) └── x --> tensor([0., 0., 0., 0.])
""" """
return torch.zeros(*args, **kwargs) return stream_call(torch.zeros, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -156,7 +157,7 @@ def zeros_like(input, *args, **kwargs): ...@@ -156,7 +157,7 @@ def zeros_like(input, *args, **kwargs):
└── b --> <Tensor 0x7ff363bb6080> └── b --> <Tensor 0x7ff363bb6080>
└── x --> tensor([0., 0., 0., 0.]) └── x --> tensor([0., 0., 0., 0.])
""" """
return torch.zeros_like(input, *args, **kwargs) return stream_call(torch.zeros_like, input, *args, **kwargs)
@doc_from_base() @doc_from_base()
...@@ -182,7 +183,7 @@ def randn(*args, **kwargs): ...@@ -182,7 +183,7 @@ def randn(*args, **kwargs):
└── b --> <Tensor 0x7ff363bb6438> └── b --> <Tensor 0x7ff363bb6438>
└── x --> tensor([-0.7181, 0.1670, -1.3587, -1.5129]) └── x --> tensor([-0.7181, 0.1670, -1.3587, -1.5129])
""" """
return torch.randn(*args, **kwargs) return stream_call(torch.randn, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -212,7 +213,7 @@ def randn_like(input, *args, **kwargs): ...@@ -212,7 +213,7 @@ def randn_like(input, *args, **kwargs):
└── b --> <Tensor 0x7ff3d6f420b8> └── b --> <Tensor 0x7ff3d6f420b8>
└── x --> tensor([ 0.1730, 1.6085, 0.6487, -1.1022]) └── x --> tensor([ 0.1730, 1.6085, 0.6487, -1.1022])
""" """
return torch.randn_like(input, *args, **kwargs) return stream_call(torch.randn_like, input, *args, **kwargs)
@doc_from_base() @doc_from_base()
...@@ -238,7 +239,7 @@ def randint(*args, **kwargs): ...@@ -238,7 +239,7 @@ def randint(*args, **kwargs):
└── b --> <Tensor 0x7ff363bb6240> └── b --> <Tensor 0x7ff363bb6240>
└── x --> tensor([8, 8, 2, 4]) └── x --> tensor([8, 8, 2, 4])
""" """
return torch.randint(*args, **kwargs) return stream_call(torch.randint, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -268,7 +269,7 @@ def randint_like(input, *args, **kwargs): ...@@ -268,7 +269,7 @@ def randint_like(input, *args, **kwargs):
└── b --> <Tensor 0x7ff363bb6898> └── b --> <Tensor 0x7ff363bb6898>
└── x --> tensor([4., 4., 7., 1.]) └── x --> tensor([4., 4., 7., 1.])
""" """
return torch.randint_like(input, *args, **kwargs) return stream_call(torch.randint_like, input, *args, **kwargs)
@doc_from_base() @doc_from_base()
...@@ -293,7 +294,7 @@ def ones(*args, **kwargs): ...@@ -293,7 +294,7 @@ def ones(*args, **kwargs):
└── b --> <Tensor 0x7ff363bb6dd8> └── b --> <Tensor 0x7ff363bb6dd8>
└── x --> tensor([1., 1., 1., 1.]) └── x --> tensor([1., 1., 1., 1.])
""" """
return torch.ones(*args, **kwargs) return stream_call(torch.ones, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -322,7 +323,7 @@ def ones_like(input, *args, **kwargs): ...@@ -322,7 +323,7 @@ def ones_like(input, *args, **kwargs):
└── b --> <Tensor 0x7ff363bbc240> └── b --> <Tensor 0x7ff363bbc240>
└── x --> tensor([1., 1., 1., 1.]) └── x --> tensor([1., 1., 1., 1.])
""" """
return torch.ones_like(input, *args, **kwargs) return stream_call(torch.ones_like, input, *args, **kwargs)
@doc_from_base() @doc_from_base()
...@@ -347,7 +348,7 @@ def full(*args, **kwargs): ...@@ -347,7 +348,7 @@ def full(*args, **kwargs):
└── b --> <Tensor 0x7ff363bbc8d0> └── b --> <Tensor 0x7ff363bbc8d0>
└── x --> tensor([2.3000, 2.3000, 2.3000, 2.3000]) └── x --> tensor([2.3000, 2.3000, 2.3000, 2.3000])
""" """
return torch.full(*args, **kwargs) return stream_call(torch.full, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -377,7 +378,7 @@ def full_like(input, *args, **kwargs): ...@@ -377,7 +378,7 @@ def full_like(input, *args, **kwargs):
└── b --> <Tensor 0x7ff363bb69e8> └── b --> <Tensor 0x7ff363bb69e8>
└── x --> tensor([2.3000, 2.3000, 2.3000, 2.3000]) └── x --> tensor([2.3000, 2.3000, 2.3000, 2.3000])
""" """
return torch.full_like(input, *args, **kwargs) return stream_call(torch.full_like, input, *args, **kwargs)
@doc_from_base() @doc_from_base()
...@@ -403,7 +404,7 @@ def empty(*args, **kwargs): ...@@ -403,7 +404,7 @@ def empty(*args, **kwargs):
└── b --> <Tensor 0x7ff363bb66d8> └── b --> <Tensor 0x7ff363bb66d8>
└── x --> tensor([-3.6515e+14, 4.5900e-41, -3.8091e-38, 3.0802e-41]) └── x --> tensor([-3.6515e+14, 4.5900e-41, -3.8091e-38, 3.0802e-41])
""" """
return torch.empty(*args, **kwargs) return stream_call(torch.empty, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -433,4 +434,4 @@ def empty_like(input, *args, **kwargs): ...@@ -433,4 +434,4 @@ def empty_like(input, *args, **kwargs):
└── b --> <Tensor 0x7ff3d6f3cb38> └── b --> <Tensor 0x7ff3d6f3cb38>
└── x --> tensor([-1.3267e-36, 3.0802e-41, -3.8049e-38, 3.0802e-41]) └── x --> tensor([-1.3267e-36, 3.0802e-41, -3.8049e-38, 3.0802e-41])
""" """
return torch.empty_like(input, *args, **kwargs) return stream_call(torch.empty_like, input, *args, **kwargs)
import torch import torch
from .base import doc_from_base, func_treelize from .base import doc_from_base, func_treelize
from ..stream import stream_call
from ...common import return_self from ...common import return_self
__all__ = [ __all__ = [
...@@ -37,7 +38,7 @@ def abs(input, *args, **kwargs): ...@@ -37,7 +38,7 @@ def abs(input, *args, **kwargs):
└── x --> tensor([[3, 1], └── x --> tensor([[3, 1],
[0, 2]]) [0, 2]])
""" """
return torch.abs(input, *args, **kwargs) return stream_call(torch.abs, input, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -69,7 +70,7 @@ def abs_(input): ...@@ -69,7 +70,7 @@ def abs_(input):
└── x --> tensor([[3, 1], └── x --> tensor([[3, 1],
[0, 2]]) [0, 2]])
""" """
return torch.abs_(input) return stream_call(torch.abs_, input)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -96,7 +97,7 @@ def clamp(input, *args, **kwargs): ...@@ -96,7 +97,7 @@ def clamp(input, *args, **kwargs):
└── x --> tensor([[-0.5000, 0.5000, -0.3697], └── x --> tensor([[-0.5000, 0.5000, -0.3697],
[ 0.0489, -0.5000, -0.5000]]) [ 0.0489, -0.5000, -0.5000]])
""" """
return torch.clamp(input, *args, **kwargs) return stream_call(torch.clamp, input, *args, **kwargs)
# noinspection PyShadowingBuiltins,PyUnresolvedReferences # noinspection PyShadowingBuiltins,PyUnresolvedReferences
...@@ -128,7 +129,7 @@ def clamp_(input, *args, **kwargs): ...@@ -128,7 +129,7 @@ def clamp_(input, *args, **kwargs):
└── x --> tensor([[-0.5000, 0.5000, -0.3697], └── x --> tensor([[-0.5000, 0.5000, -0.3697],
[ 0.0489, -0.5000, -0.5000]]) [ 0.0489, -0.5000, -0.5000]])
""" """
return torch.clamp_(input, *args, **kwargs) return stream_call(torch.clamp_, input, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -155,7 +156,7 @@ def sign(input, *args, **kwargs): ...@@ -155,7 +156,7 @@ def sign(input, *args, **kwargs):
└── x --> tensor([[-1, 1], └── x --> tensor([[-1, 1],
[ 0, -1]]) [ 0, -1]])
""" """
return torch.sign(input, *args, **kwargs) return stream_call(torch.sign, input, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -185,7 +186,7 @@ def round(input, *args, **kwargs): ...@@ -185,7 +186,7 @@ def round(input, *args, **kwargs):
└── x --> tensor([[ 1., -4., 1.], └── x --> tensor([[ 1., -4., 1.],
[-5., -2., 3.]]) [-5., -2., 3.]])
""" """
return torch.round(input, *args, **kwargs) return stream_call(torch.round, input, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -219,7 +220,7 @@ def round_(input): ...@@ -219,7 +220,7 @@ def round_(input):
└── x --> tensor([[ 1., -4., 1.], └── x --> tensor([[ 1., -4., 1.],
[-5., -2., 3.]]) [-5., -2., 3.]])
""" """
return torch.round_(input) return stream_call(torch.round_, input)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -249,7 +250,7 @@ def floor(input, *args, **kwargs): ...@@ -249,7 +250,7 @@ def floor(input, *args, **kwargs):
└── x --> tensor([[ 1., -4., 1.], └── x --> tensor([[ 1., -4., 1.],
[-5., -2., 2.]]) [-5., -2., 2.]])
""" """
return torch.floor(input, *args, **kwargs) return stream_call(torch.floor, input, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -283,7 +284,7 @@ def floor_(input): ...@@ -283,7 +284,7 @@ def floor_(input):
└── x --> tensor([[ 1., -4., 1.], └── x --> tensor([[ 1., -4., 1.],
[-5., -2., 2.]]) [-5., -2., 2.]])
""" """
return torch.floor_(input) return stream_call(torch.floor_, input)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -313,7 +314,7 @@ def ceil(input, *args, **kwargs): ...@@ -313,7 +314,7 @@ def ceil(input, *args, **kwargs):
└── x --> tensor([[ 1., -3., 2.], └── x --> tensor([[ 1., -3., 2.],
[-4., -2., 3.]]) [-4., -2., 3.]])
""" """
return torch.ceil(input, *args, **kwargs) return stream_call(torch.ceil, input, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -347,7 +348,7 @@ def ceil_(input): ...@@ -347,7 +348,7 @@ def ceil_(input):
└── x --> tensor([[ 1., -3., 2.], └── x --> tensor([[ 1., -3., 2.],
[-4., -2., 3.]]) [-4., -2., 3.]])
""" """
return torch.ceil_(input) return stream_call(torch.ceil_, input)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -374,7 +375,7 @@ def sigmoid(input, *args, **kwargs): ...@@ -374,7 +375,7 @@ def sigmoid(input, *args, **kwargs):
└── x --> tensor([[0.6225, 0.7685], └── x --> tensor([[0.6225, 0.7685],
[0.0759, 0.5622]]) [0.0759, 0.5622]])
""" """
return torch.sigmoid(input, *args, **kwargs) return stream_call(torch.sigmoid, input, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -406,7 +407,7 @@ def sigmoid_(input): ...@@ -406,7 +407,7 @@ def sigmoid_(input):
└── x --> tensor([[0.6225, 0.7685], └── x --> tensor([[0.6225, 0.7685],
[0.0759, 0.5622]]) [0.0759, 0.5622]])
""" """
return torch.sigmoid_(input) return stream_call(torch.sigmoid_, input)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -443,7 +444,7 @@ def add(input, other, *args, **kwargs): ...@@ -443,7 +444,7 @@ def add(input, other, *args, **kwargs):
└── x --> tensor([[ 34, -10], └── x --> tensor([[ 34, -10],
[ 22, 35]]) [ 22, 35]])
""" """
return torch.add(input, other, *args, **kwargs) return stream_call(torch.add, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -479,7 +480,7 @@ def sub(input, other, *args, **kwargs): ...@@ -479,7 +480,7 @@ def sub(input, other, *args, **kwargs):
└── x --> tensor([[-28, 20], └── x --> tensor([[-28, 20],
[ -4, -11]]) [ -4, -11]])
""" """
return torch.sub(input, other, *args, **kwargs) return stream_call(torch.sub, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -516,7 +517,7 @@ def mul(input, other, *args, **kwargs): ...@@ -516,7 +517,7 @@ def mul(input, other, *args, **kwargs):
└── x --> tensor([[ 93, -75], └── x --> tensor([[ 93, -75],
[117, 276]]) [117, 276]])
""" """
return torch.mul(input, other, *args, **kwargs) return stream_call(torch.mul, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -571,7 +572,7 @@ def div(input, other, *args, **kwargs): ...@@ -571,7 +572,7 @@ def div(input, other, *args, **kwargs):
[[-7.8589, 1.3007, -2.0349], [[-7.8589, 1.3007, -2.0349],
[ 0.1460, 0.5554, -0.1900]]]) [ 0.1460, 0.5554, -0.1900]]])
""" """
return torch.div(input, other, *args, **kwargs) return stream_call(torch.div, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -627,7 +628,7 @@ def pow(input, exponent, *args, **kwargs): ...@@ -627,7 +628,7 @@ def pow(input, exponent, *args, **kwargs):
[[ 1024, 36, 15625], [[ 1024, 36, 15625],
[823543, 8, 2401]]]) [823543, 8, 2401]]])
""" """
return torch.pow(input, exponent, *args, **kwargs) return stream_call(torch.pow, input, exponent, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -666,7 +667,7 @@ def neg(input, *args, **kwargs): ...@@ -666,7 +667,7 @@ def neg(input, *args, **kwargs):
[[-4, -6, -5], [[-4, -6, -5],
[-7, -2, -7]]]) [-7, -2, -7]]])
""" """
return torch.neg(input, *args, **kwargs) return stream_call(torch.neg, input, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -710,7 +711,7 @@ def neg_(input): ...@@ -710,7 +711,7 @@ def neg_(input):
[[-4, -6, -5], [[-4, -6, -5],
[-7, -2, -7]]]) [-7, -2, -7]]])
""" """
return torch.neg_(input) return stream_call(torch.neg_, input)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -738,7 +739,7 @@ def exp(input, *args, **kwargs): ...@@ -738,7 +739,7 @@ def exp(input, *args, **kwargs):
└── x --> tensor([[1.3534e-01, 3.3201e+00, 1.2840e+00], └── x --> tensor([[1.3534e-01, 3.3201e+00, 1.2840e+00],
[8.8861e+06, 4.2521e+01, 9.6328e-02]]) [8.8861e+06, 4.2521e+01, 9.6328e-02]])
""" """
return torch.exp(input, *args, **kwargs) return stream_call(torch.exp, input, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -771,7 +772,7 @@ def exp_(input): ...@@ -771,7 +772,7 @@ def exp_(input):
└── x --> tensor([[1.3534e-01, 3.3201e+00, 1.2840e+00], └── x --> tensor([[1.3534e-01, 3.3201e+00, 1.2840e+00],
[8.8861e+06, 4.2521e+01, 9.6328e-02]]) [8.8861e+06, 4.2521e+01, 9.6328e-02]])
""" """
return torch.exp_(input) return stream_call(torch.exp_, input)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -799,7 +800,7 @@ def exp2(input, *args, **kwargs): ...@@ -799,7 +800,7 @@ def exp2(input, *args, **kwargs):
└── x --> tensor([[2.5000e-01, 2.2974e+00, 1.1892e+00], └── x --> tensor([[2.5000e-01, 2.2974e+00, 1.1892e+00],
[6.5536e+04, 1.3454e+01, 1.9751e-01]]) [6.5536e+04, 1.3454e+01, 1.9751e-01]])
""" """
return torch.exp2(input, *args, **kwargs) return stream_call(torch.exp2, input, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -832,7 +833,7 @@ def exp2_(input): ...@@ -832,7 +833,7 @@ def exp2_(input):
└── x --> tensor([[2.5000e-01, 2.2974e+00, 1.1892e+00], └── x --> tensor([[2.5000e-01, 2.2974e+00, 1.1892e+00],
[6.5536e+04, 1.3454e+01, 1.9751e-01]]) [6.5536e+04, 1.3454e+01, 1.9751e-01]])
""" """
return torch.exp2_(input) return stream_call(torch.exp2_, input)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -860,7 +861,7 @@ def sqrt(input, *args, **kwargs): ...@@ -860,7 +861,7 @@ def sqrt(input, *args, **kwargs):
└── x --> tensor([[ nan, 1.0954, 0.5000], └── x --> tensor([[ nan, 1.0954, 0.5000],
[4.0000, 1.9365, nan]]) [4.0000, 1.9365, nan]])
""" """
return torch.sqrt(input, *args, **kwargs) return stream_call(torch.sqrt, input, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -893,7 +894,7 @@ def sqrt_(input): ...@@ -893,7 +894,7 @@ def sqrt_(input):
└── x --> tensor([[ nan, 1.0954, 0.5000], └── x --> tensor([[ nan, 1.0954, 0.5000],
[4.0000, 1.9365, nan]]) [4.0000, 1.9365, nan]])
""" """
return torch.sqrt_(input) return stream_call(torch.sqrt_, input)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -921,7 +922,7 @@ def log(input, *args, **kwargs): ...@@ -921,7 +922,7 @@ def log(input, *args, **kwargs):
└── x --> tensor([[ nan, 0.1823, -1.3863], └── x --> tensor([[ nan, 0.1823, -1.3863],
[ 2.7726, 1.3218, nan]]) [ 2.7726, 1.3218, nan]])
""" """
return torch.log(input, *args, **kwargs) return stream_call(torch.log, input, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -954,7 +955,7 @@ def log_(input): ...@@ -954,7 +955,7 @@ def log_(input):
└── x --> tensor([[ nan, 0.1823, -1.3863], └── x --> tensor([[ nan, 0.1823, -1.3863],
[ 2.7726, 1.3218, nan]]) [ 2.7726, 1.3218, nan]])
""" """
return torch.log_(input) return stream_call(torch.log_, input)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -982,7 +983,7 @@ def log2(input, *args, **kwargs): ...@@ -982,7 +983,7 @@ def log2(input, *args, **kwargs):
└── x --> tensor([[ nan, 0.2630, -2.0000], └── x --> tensor([[ nan, 0.2630, -2.0000],
[ 4.0000, 1.9069, nan]]) [ 4.0000, 1.9069, nan]])
""" """
return torch.log2(input, *args, **kwargs) return stream_call(torch.log2, input, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -1015,7 +1016,7 @@ def log2_(input): ...@@ -1015,7 +1016,7 @@ def log2_(input):
└── x --> tensor([[ nan, 0.2630, -2.0000], └── x --> tensor([[ nan, 0.2630, -2.0000],
[ 4.0000, 1.9069, nan]]) [ 4.0000, 1.9069, nan]])
""" """
return torch.log2_(input) return stream_call(torch.log2_, input)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -1043,7 +1044,7 @@ def log10(input, *args, **kwargs): ...@@ -1043,7 +1044,7 @@ def log10(input, *args, **kwargs):
└── x --> tensor([[ nan, 0.0792, -0.6021], └── x --> tensor([[ nan, 0.0792, -0.6021],
[ 1.2041, 0.5740, nan]]) [ 1.2041, 0.5740, nan]])
""" """
return torch.log10(input, *args, **kwargs) return stream_call(torch.log10, input, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -1076,7 +1077,7 @@ def log10_(input): ...@@ -1076,7 +1077,7 @@ def log10_(input):
└── x --> tensor([[ nan, 0.0792, -0.6021], └── x --> tensor([[ nan, 0.0792, -0.6021],
[ 1.2041, 0.5740, nan]]) [ 1.2041, 0.5740, nan]])
""" """
return torch.log10_(input) return stream_call(torch.log10_, input)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -1117,7 +1118,7 @@ def dist(input, other, *args, **kwargs): ...@@ -1117,7 +1118,7 @@ def dist(input, other, *args, **kwargs):
└── b --> <Tensor 0x7f95f68494a8> └── b --> <Tensor 0x7f95f68494a8>
└── x --> tensor(4.1904) └── x --> tensor(4.1904)
""" """
return torch.dist(input, other, *args, **kwargs) return stream_call(torch.dist, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -1157,4 +1158,4 @@ def norm(input, *args, **kwargs): ...@@ -1157,4 +1158,4 @@ def norm(input, *args, **kwargs):
└── b --> <Tensor 0x7f95f684f978> └── b --> <Tensor 0x7f95f684f978>
└── x --> tensor(3.2982) └── x --> tensor(3.2982)
""" """
return torch.norm(input, *args, **kwargs) return stream_call(torch.norm, input, *args, **kwargs)
import torch import torch
from .base import doc_from_base, func_treelize from .base import doc_from_base, func_treelize
from ..stream import stream_call
__all__ = [ __all__ = [
'dot', 'matmul', 'mm', 'dot', 'matmul', 'mm',
...@@ -36,7 +37,7 @@ def dot(input, other, *args, **kwargs): ...@@ -36,7 +37,7 @@ def dot(input, other, *args, **kwargs):
└── b --> <Tensor 0x7feac55c9250> └── b --> <Tensor 0x7feac55c9250>
└── x --> tensor(11) └── x --> tensor(11)
""" """
return torch.dot(input, other, *args, **kwargs) return stream_call(torch.dot, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -73,7 +74,7 @@ def matmul(input, other, *args, **kwargs): ...@@ -73,7 +74,7 @@ def matmul(input, other, *args, **kwargs):
└── b --> <Tensor 0x7f2e74886430> └── b --> <Tensor 0x7f2e74886430>
└── x --> tensor(40) └── x --> tensor(40)
""" """
return torch.matmul(input, other, *args, **kwargs) return stream_call(torch.matmul, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -111,4 +112,4 @@ def mm(input, mat2, *args, **kwargs): ...@@ -111,4 +112,4 @@ def mm(input, mat2, *args, **kwargs):
└── x --> tensor([[44, 32], └── x --> tensor([[44, 32],
[80, 59]]) [80, 59]])
""" """
return torch.mm(input, mat2, *args, **kwargs) return stream_call(torch.mm, input, mat2, *args, **kwargs)
...@@ -3,6 +3,7 @@ from hbutils.reflection import post_process ...@@ -3,6 +3,7 @@ from hbutils.reflection import post_process
from treevalue import TreeValue from treevalue import TreeValue
from .base import doc_from_base, func_treelize, auto_tensor from .base import doc_from_base, func_treelize, auto_tensor
from ..stream import stream_call
__all__ = [ __all__ = [
'cat', 'split', 'chunk', 'stack', 'cat', 'split', 'chunk', 'stack',
...@@ -112,7 +113,7 @@ def cat(tensors, *args, **kwargs): ...@@ -112,7 +113,7 @@ def cat(tensors, *args, **kwargs):
└── y --> tensor([[51, 65, 65, 21, 29, 17, 46, 46, 46], └── y --> tensor([[51, 65, 65, 21, 29, 17, 46, 46, 46],
[54, 67, 57, 16, 11, 16, 30, 47, 36]]) [54, 67, 57, 16, 11, 16, 30, 47, 36]])
""" """
return torch.cat(tensors, *args, **kwargs) return stream_call(torch.cat, tensors, *args, **kwargs)
# noinspection PyShadowingNames # noinspection PyShadowingNames
...@@ -201,7 +202,7 @@ def split(tensor, split_size_or_sections, *args, **kwargs): ...@@ -201,7 +202,7 @@ def split(tensor, split_size_or_sections, *args, **kwargs):
[58, 54, 78]]]) [58, 54, 78]]])
) )
""" """
return torch.split(tensor, split_size_or_sections, *args, **kwargs) return stream_call(torch.split, tensor, split_size_or_sections, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -261,7 +262,7 @@ def chunk(input, chunks, *args, **kwargs): ...@@ -261,7 +262,7 @@ def chunk(input, chunks, *args, **kwargs):
[29, 65, 17, 72], [29, 65, 17, 72],
[53, 50, 75, 0]]]) [53, 50, 75, 0]]])
""" """
return torch.chunk(input, chunks, *args, **kwargs) return stream_call(torch.chunk, input, chunks, *args, **kwargs)
@doc_from_base() @doc_from_base()
...@@ -352,7 +353,7 @@ def stack(tensors, *args, **kwargs): ...@@ -352,7 +353,7 @@ def stack(tensors, *args, **kwargs):
[[18, 21, 17, 12], [[18, 21, 17, 12],
[36, 30, 33, 31]]]) [36, 30, 33, 31]]])
""" """
return torch.stack(tensors, *args, **kwargs) return stream_call(torch.stack, tensors, *args, **kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -399,7 +400,7 @@ def reshape(input, shape): ...@@ -399,7 +400,7 @@ def reshape(input, shape):
[11, 13]]) [11, 13]])
""" """
return torch.reshape(input, shape) return stream_call(torch.reshape, input, shape)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -434,7 +435,7 @@ def squeeze(input, *args, **kwargs): ...@@ -434,7 +435,7 @@ def squeeze(input, *args, **kwargs):
└── b --> <Size 0x7fa4c1afe710> └── b --> <Size 0x7fa4c1afe710>
└── x --> torch.Size([2, 3]) └── x --> torch.Size([2, 3])
""" """
return torch.squeeze(input, *args, *kwargs) return stream_call(torch.squeeze, input, *args, *kwargs)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -469,7 +470,7 @@ def unsqueeze(input, dim): ...@@ -469,7 +470,7 @@ def unsqueeze(input, dim):
└── b --> <Size 0x7f5d1a5c99b0> └── b --> <Size 0x7f5d1a5c99b0>
└── x --> torch.Size([2, 1, 3]) └── x --> torch.Size([2, 1, 3])
""" """
return torch.unsqueeze(input, dim) return stream_call(torch.unsqueeze, input, dim)
@doc_from_base() @doc_from_base()
...@@ -518,7 +519,7 @@ def where(condition, x, y): ...@@ -518,7 +519,7 @@ def where(condition, x, y):
[[ 0, 3, 93, 89], [[ 0, 3, 93, 89],
[ 0, 89, 85, 0]]]) [ 0, 89, 85, 0]]])
""" """
return torch.where(condition, x, y) return stream_call(torch.where, condition, x, y)
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
...@@ -586,4 +587,4 @@ def index_select(input, dim, index, *args, **kwargs): ...@@ -586,4 +587,4 @@ def index_select(input, dim, index, *args, **kwargs):
[-2.1694, -0.4224, 0.3998], [-2.1694, -0.4224, 0.3998],
[ 0.9777, -0.0101, -1.1500]]) [ 0.9777, -0.0101, -1.1500]])
""" """
return torch.index_select(input, dim, index, *args, **kwargs) return stream_call(torch.index_select, input, dim, index, *args, **kwargs)
import itertools
from typing import Optional, List
import torch
_stream_pool: Optional[List[torch.cuda.Stream]] = None
_global_streams: Optional[List[torch.cuda.Stream]] = None
__all__ = [
'stream', 'stream_call',
]
def stream(cnt):
assert torch.cuda.is_available(), "CUDA is not supported."
global _stream_pool, _global_streams
if _stream_pool is None:
_stream_pool = [torch.cuda.current_stream()]
if cnt is None: # close stream support by
_global_streams = None
else: # use the given number of streams
while len(_stream_pool) < cnt:
_stream_pool.append(torch.cuda.Stream())
_global_streams = _stream_pool[:cnt]
_stream_count = itertools.count()
def stream_call(func, *args, **kwargs):
if _global_streams is not None:
_stream_index = next(_stream_count) % len(_global_streams)
_stream = _global_streams[_stream_index]
with torch.cuda.stream(_stream):
return func(*args, **kwargs)
else:
return func(*args, **kwargs)
...@@ -5,6 +5,7 @@ from treevalue import method_treelize, TreeValue, typetrans ...@@ -5,6 +5,7 @@ from treevalue import method_treelize, TreeValue, typetrans
from .base import Torch, rmreduce, post_reduce, auto_reduce from .base import Torch, rmreduce, post_reduce, auto_reduce
from .size import Size from .size import Size
from .stream import stream_call
from ..common import Object, ireduce, clsmeta, return_self, auto_tree, get_tree_proxy from ..common import Object, ireduce, clsmeta, return_self, auto_tree, get_tree_proxy
from ..numpy import ndarray from ..numpy import ndarray
from ..utils import current_names, class_autoremove, replaceable_partial from ..utils import current_names, class_autoremove, replaceable_partial
...@@ -152,7 +153,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -152,7 +153,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
This tensor and the returned :class:`treetensor.numpy.ndarray` share the same underlying storage. This tensor and the returned :class:`treetensor.numpy.ndarray` share the same underlying storage.
Changes to self tensor will be reflected in the ``ndarray`` and vice versa. Changes to self tensor will be reflected in the ``ndarray`` and vice versa.
""" """
return self.numpy() return stream_call(self.numpy, )
@doc_from_base() @doc_from_base()
@method_treelize(return_type=Object) @method_treelize(return_type=Object)
...@@ -175,7 +176,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -175,7 +176,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
'c': True, 'c': True,
}) })
""" """
return self.tolist() return stream_call(self.tolist, )
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -186,7 +187,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -186,7 +187,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
If this tree tensor is already in CPU memory and on the correct device, If this tree tensor is already in CPU memory and on the correct device,
then no copy is performed and the original object is returned. then no copy is performed and the original object is returned.
""" """
return self.cpu(*args, **kwargs) return stream_call(self.cpu, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -197,7 +198,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -197,7 +198,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
If this tree tensor is already in CUDA memory and on the correct device, If this tree tensor is already in CUDA memory and on the correct device,
then no copy is performed and the original object is returned. then no copy is performed and the original object is returned.
""" """
return self.cuda(*args, **kwargs) return stream_call(self.cuda, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -221,7 +222,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -221,7 +222,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
└── x --> tensor([[4., 5.], └── x --> tensor([[4., 5.],
[6., 7.]], dtype=torch.float64) [6., 7.]], dtype=torch.float64)
""" """
return self.to(*args, **kwargs) return stream_call(self.to, *args, **kwargs)
@doc_from_base() @doc_from_base()
@ireduce(sum) @ireduce(sum)
...@@ -230,7 +231,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -230,7 +231,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.numel` See :func:`treetensor.torch.numel`
""" """
return self.numel() return stream_call(self.numel, )
@property @property
@doc_from_base() @doc_from_base()
...@@ -346,7 +347,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -346,7 +347,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
[ 0.2531, -0.0637, 0.9822, 2.1618], [ 0.2531, -0.0637, 0.9822, 2.1618],
[ 2.0140, -0.0929, 0.9304, 1.5430]], requires_grad=True) [ 2.0140, -0.0929, 0.9304, 1.5430]], requires_grad=True)
""" """
return self.requires_grad_(requires_grad) return stream_call(self.requires_grad_, requires_grad)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -354,7 +355,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -354,7 +355,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.detach`. See :func:`treetensor.torch.detach`.
""" """
return self.detach() return stream_call(self.detach, )
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -363,7 +364,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -363,7 +364,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
In-place version of :meth:`Tensor.detach`. In-place version of :meth:`Tensor.detach`.
""" """
return self.detach_() return stream_call(self.detach_, )
# noinspection PyShadowingBuiltins,PyUnusedLocal # noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(pytorch.all) @post_reduce(pytorch.all)
...@@ -513,7 +514,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -513,7 +514,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.clone`. See :func:`treetensor.torch.clone`.
""" """
return self.clone(*args, **kwargs) return stream_call(self.clone, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -521,7 +522,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -521,7 +522,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.dot`. See :func:`treetensor.torch.dot`.
""" """
return self.dot(other, *args, **kwargs) return stream_call(self.dot, other, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -529,7 +530,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -529,7 +530,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.mm`. See :func:`treetensor.torch.mm`.
""" """
return self.mm(mat2, *args, **kwargs) return stream_call(self.mm, mat2, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -537,7 +538,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -537,7 +538,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.matmul`. See :func:`treetensor.torch.matmul`.
""" """
return self.matmul(tensor2, *args, **kwargs) return stream_call(self.matmul, tensor2, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -545,7 +546,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -545,7 +546,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.isfinite`. See :func:`treetensor.torch.isfinite`.
""" """
return self.isfinite() return stream_call(self.isfinite, )
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -553,7 +554,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -553,7 +554,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.isinf`. See :func:`treetensor.torch.isinf`.
""" """
return self.isinf() return stream_call(self.isinf, )
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -561,7 +562,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -561,7 +562,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.isnan`. See :func:`treetensor.torch.isnan`.
""" """
return self.isnan() return stream_call(self.isnan, )
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -569,7 +570,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -569,7 +570,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.isclose`. See :func:`treetensor.torch.isclose`.
""" """
return self.isclose(other, *args, **kwargs) return stream_call(self.isclose, other, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -577,7 +578,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -577,7 +578,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.abs`. See :func:`treetensor.torch.abs`.
""" """
return self.abs(*args, **kwargs) return stream_call(self.abs, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -586,7 +587,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -586,7 +587,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.abs_`. See :func:`treetensor.torch.abs_`.
""" """
return self.abs_(*args, **kwargs) return stream_call(self.abs_, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -594,7 +595,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -594,7 +595,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.clamp`. See :func:`treetensor.torch.clamp`.
""" """
return self.clamp(*args, **kwargs) return stream_call(self.clamp, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -603,7 +604,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -603,7 +604,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.clamp_`. See :func:`treetensor.torch.clamp_`.
""" """
return self.clamp_(*args, **kwargs) return stream_call(self.clamp_, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -611,7 +612,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -611,7 +612,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.sign`. See :func:`treetensor.torch.sign`.
""" """
return self.sign(*args, **kwargs) return stream_call(self.sign, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -620,7 +621,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -620,7 +621,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
In-place version of :meth:`Tensor.sign`. In-place version of :meth:`Tensor.sign`.
""" """
return self.sign_(*args, **kwargs) return stream_call(self.sign_, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -628,7 +629,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -628,7 +629,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.sigmoid`. See :func:`treetensor.torch.sigmoid`.
""" """
return self.sigmoid(*args, **kwargs) return stream_call(self.sigmoid, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -637,7 +638,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -637,7 +638,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.sigmoid_`. See :func:`treetensor.torch.sigmoid_`.
""" """
return self.sigmoid_(*args, **kwargs) return stream_call(self.sigmoid_, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -645,7 +646,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -645,7 +646,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.floor`. See :func:`treetensor.torch.floor`.
""" """
return self.floor(*args, **kwargs) return stream_call(self.floor, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -654,7 +655,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -654,7 +655,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.floor_`. See :func:`treetensor.torch.floor_`.
""" """
return self.floor_(*args, **kwargs) return stream_call(self.floor_, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -662,7 +663,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -662,7 +663,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.ceil`. See :func:`treetensor.torch.ceil`.
""" """
return self.ceil(*args, **kwargs) return stream_call(self.ceil, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -671,7 +672,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -671,7 +672,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.ceil_`. See :func:`treetensor.torch.ceil_`.
""" """
return self.ceil_(*args, **kwargs) return stream_call(self.ceil_, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -679,7 +680,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -679,7 +680,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.round`. See :func:`treetensor.torch.round`.
""" """
return self.round(*args, **kwargs) return stream_call(self.round, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -688,7 +689,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -688,7 +689,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.round_`. See :func:`treetensor.torch.round_`.
""" """
return self.round_(*args, **kwargs) return stream_call(self.round_, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -696,7 +697,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -696,7 +697,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.add`. See :func:`treetensor.torch.add`.
""" """
return self.add(other, *args, **kwargs) return stream_call(self.add, other, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -705,7 +706,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -705,7 +706,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
In-place version of :meth:`Tensor.add`. In-place version of :meth:`Tensor.add`.
""" """
return self.add_(other, *args, **kwargs) return stream_call(self.add_, other, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -713,7 +714,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -713,7 +714,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.sub`. See :func:`treetensor.torch.sub`.
""" """
return self.sub(other, *args, **kwargs) return stream_call(self.sub, other, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -722,7 +723,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -722,7 +723,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
In-place version of :meth:`Tensor.sub`. In-place version of :meth:`Tensor.sub`.
""" """
return self.sub_(other, *args, **kwargs) return stream_call(self.sub_, other, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -730,7 +731,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -730,7 +731,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.mul`. See :func:`treetensor.torch.mul`.
""" """
return self.mul(other, *args, **kwargs) return stream_call(self.mul, other, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -739,7 +740,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -739,7 +740,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
In-place version of :meth:`Tensor.mul`. In-place version of :meth:`Tensor.mul`.
""" """
return self.mul_(other, *args, **kwargs) return stream_call(self.mul_, other, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -747,7 +748,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -747,7 +748,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.div`. See :func:`treetensor.torch.div`.
""" """
return self.div(other, *args, **kwargs) return stream_call(self.div, other, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -756,7 +757,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -756,7 +757,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
In-place version of :meth:`Tensor.div`. In-place version of :meth:`Tensor.div`.
""" """
return self.div_(other, *args, **kwargs) return stream_call(self.div_, other, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -764,7 +765,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -764,7 +765,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.pow`. See :func:`treetensor.torch.pow`.
""" """
return self.pow(exponent, *args, **kwargs) return stream_call(self.pow, exponent, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -773,7 +774,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -773,7 +774,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
In-place version of :meth:`Tensor.pow`. In-place version of :meth:`Tensor.pow`.
""" """
return self.pow_(exponent, *args, **kwargs) return stream_call(self.pow_, exponent, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -781,7 +782,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -781,7 +782,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.neg`. See :func:`treetensor.torch.neg`.
""" """
return self.neg(*args, **kwargs) return stream_call(self.neg, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -790,7 +791,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -790,7 +791,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
In-place version of :meth:`Tensor.neg`. In-place version of :meth:`Tensor.neg`.
""" """
return self.neg_(*args, **kwargs) return stream_call(self.neg_, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -798,7 +799,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -798,7 +799,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.exp`. See :func:`treetensor.torch.exp`.
""" """
return self.exp(*args, **kwargs) return stream_call(self.exp, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -807,7 +808,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -807,7 +808,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
In-place version of :meth:`Tensor.exp`. In-place version of :meth:`Tensor.exp`.
""" """
return self.exp_(*args, **kwargs) return stream_call(self.exp_, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -815,7 +816,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -815,7 +816,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.exp2`. See :func:`treetensor.torch.exp2`.
""" """
return self.exp2(*args, **kwargs) return stream_call(self.exp2, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -824,7 +825,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -824,7 +825,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
In-place version of :meth:`Tensor.exp2`. In-place version of :meth:`Tensor.exp2`.
""" """
return self.exp2_(*args, **kwargs) return stream_call(self.exp2_, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -832,7 +833,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -832,7 +833,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.sqrt`. See :func:`treetensor.torch.sqrt`.
""" """
return self.sqrt(*args, **kwargs) return stream_call(self.sqrt, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -841,7 +842,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -841,7 +842,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
In-place version of :meth:`Tensor.sqrt`. In-place version of :meth:`Tensor.sqrt`.
""" """
return self.sqrt_(*args, **kwargs) return stream_call(self.sqrt_, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -849,7 +850,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -849,7 +850,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.log`. See :func:`treetensor.torch.log`.
""" """
return self.log(*args, **kwargs) return stream_call(self.log, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -858,7 +859,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -858,7 +859,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
In-place version of :meth:`Tensor.log`. In-place version of :meth:`Tensor.log`.
""" """
return self.log_(*args, **kwargs) return stream_call(self.log_, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -866,7 +867,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -866,7 +867,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.log2`. See :func:`treetensor.torch.log2`.
""" """
return self.log2(*args, **kwargs) return stream_call(self.log2, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -875,7 +876,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -875,7 +876,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
In-place version of :meth:`Tensor.log2`. In-place version of :meth:`Tensor.log2`.
""" """
return self.log2_(*args, **kwargs) return stream_call(self.log2_, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -883,7 +884,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -883,7 +884,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.log10`. See :func:`treetensor.torch.log10`.
""" """
return self.log10(*args, **kwargs) return stream_call(self.log10, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -892,7 +893,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -892,7 +893,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
In-place version of :meth:`Tensor.log10`. In-place version of :meth:`Tensor.log10`.
""" """
return self.log10_(*args, **kwargs) return stream_call(self.log10_, *args, **kwargs)
@doc_from_base() @doc_from_base()
@post_process(__auto_tensor) @post_process(__auto_tensor)
...@@ -901,7 +902,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -901,7 +902,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.split`. See :func:`treetensor.torch.split`.
""" """
return self.split(split_size, *args, **kwargs) return stream_call(self.split, split_size, *args, **kwargs)
@doc_from_base() @doc_from_base()
@post_process(__auto_tensor) @post_process(__auto_tensor)
...@@ -910,7 +911,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -910,7 +911,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.chunk`. See :func:`treetensor.torch.chunk`.
""" """
return self.chunk(chunks, *args, **kwargs) return stream_call(self.chunk, chunks, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -918,7 +919,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -918,7 +919,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.reshape`. See :func:`treetensor.torch.reshape`.
""" """
return self.reshape(*args, **kwargs) return stream_call(self.reshape, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -926,7 +927,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -926,7 +927,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.squeeze`. See :func:`treetensor.torch.squeeze`.
""" """
return self.squeeze(*args, **kwargs) return stream_call(self.squeeze, *args, **kwargs)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -935,7 +936,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -935,7 +936,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
In-place version of :meth:`Tensor.squeeze'. In-place version of :meth:`Tensor.squeeze'.
""" """
return self.squeeze_(*args, **kwargs) return stream_call(self.squeeze_, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -943,7 +944,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -943,7 +944,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.unsqueeze`. See :func:`treetensor.torch.unsqueeze`.
""" """
return self.unsqueeze(dim) return stream_call(self.unsqueeze, dim)
@doc_from_base() @doc_from_base()
@return_self @return_self
...@@ -952,7 +953,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -952,7 +953,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
In-place version of :meth:`Tensor.unsqueeze'. In-place version of :meth:`Tensor.unsqueeze'.
""" """
return self.unsqueeze_(dim) return stream_call(self.unsqueeze_, dim)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -962,7 +963,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -962,7 +963,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
``treetensor.torch.where(condition, self, y)``. ``treetensor.torch.where(condition, self, y)``.
See :func:`treetensor.torch.where`. See :func:`treetensor.torch.where`.
""" """
return self.where(condition, y, *args, **kwargs) return stream_call(self.where, condition, y, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -970,7 +971,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -970,7 +971,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.index_select`. See :func:`treetensor.torch.index_select`.
""" """
return self.index_select(dim, index) return stream_call(self.index_select, dim, index)
# noinspection PyShadowingBuiltins,PyUnusedLocal # noinspection PyShadowingBuiltins,PyUnusedLocal
@rmreduce() @rmreduce()
...@@ -1044,7 +1045,7 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -1044,7 +1045,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.dist`. See :func:`treetensor.torch.dist`.
""" """
return self.dist(other, *args, **kwargs) return stream_call(self.dist, other, *args, **kwargs)
@doc_from_base() @doc_from_base()
@method_treelize() @method_treelize()
...@@ -1052,4 +1053,4 @@ class Tensor(Torch, metaclass=_TensorMeta): ...@@ -1052,4 +1053,4 @@ class Tensor(Torch, metaclass=_TensorMeta):
""" """
See :func:`treetensor.torch.norm`. See :func:`treetensor.torch.norm`.
""" """
return self.norm(*args, **kwargs) return stream_call(self.norm, *args, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册