提交 1696c8b8 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): use stream_call for most of the functions

上级 ee354088
......@@ -3,6 +3,7 @@ import builtins
import torch
from .base import doc_from_base, func_treelize
from ..stream import stream_call
from ...common import ireduce
__all__ = [
......@@ -42,7 +43,7 @@ def equal(input, other):
... )
True
"""
return torch.equal(input, other)
return stream_call(torch.equal, input, other)
# noinspection PyShadowingBuiltins
......@@ -70,7 +71,7 @@ def isfinite(input):
└── x --> tensor([[ True, False, True],
[False, True, False]])
"""
return torch.isfinite(input)
return stream_call(torch.isfinite, input)
# noinspection PyShadowingBuiltins
......@@ -98,7 +99,7 @@ def isinf(input):
└── x --> tensor([[False, True, False],
[ True, False, False]])
"""
return torch.isinf(input)
return stream_call(torch.isinf, input)
# noinspection PyShadowingBuiltins
......@@ -126,7 +127,7 @@ def isnan(input):
└── x --> tensor([[False, False, False],
[False, False, True]])
"""
return torch.isnan(input)
return stream_call(torch.isnan, input)
# noinspection PyShadowingBuiltins
......@@ -166,7 +167,7 @@ def isclose(input, other, *args, **kwargs):
└── x --> tensor([[ True, False, True],
[ True, True, False]])
"""
return torch.isclose(input, other, *args, **kwargs)
return stream_call(torch.isclose, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -202,7 +203,7 @@ def eq(input, other, *args, **kwargs):
│ [False, True]])
└── b --> tensor([False, False, True])
"""
return torch.eq(input, other, *args, **kwargs)
return stream_call(torch.eq, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -238,7 +239,7 @@ def ne(input, other, *args, **kwargs):
│ [ True, False]])
└── b --> tensor([ True, True, False])
"""
return torch.ne(input, other, *args, **kwargs)
return stream_call(torch.ne, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -274,7 +275,7 @@ def lt(input, other, *args, **kwargs):
│ [ True, False]])
└── b --> tensor([ True, False, False])
"""
return torch.lt(input, other, *args, **kwargs)
return stream_call(torch.lt, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -310,7 +311,7 @@ def le(input, other, *args, **kwargs):
│ [ True, True]])
└── b --> tensor([ True, False, True])
"""
return torch.le(input, other, *args, **kwargs)
return stream_call(torch.le, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -346,7 +347,7 @@ def gt(input, other, *args, **kwargs):
│ [False, False]])
└── b --> tensor([False, True, False])
"""
return torch.gt(input, other, *args, **kwargs)
return stream_call(torch.gt, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -382,4 +383,4 @@ def ge(input, other, *args, **kwargs):
│ [False, True]])
└── b --> tensor([False, True, True])
"""
return torch.ge(input, other, *args, **kwargs)
return stream_call(torch.ge, input, other, *args, **kwargs)
......@@ -3,6 +3,7 @@ from treevalue import TreeValue
from treevalue.tree.common import TreeStorage
from .base import doc_from_base, func_treelize
from ..stream import stream_call
from ...utils import args_mapping
__all__ = [
......@@ -42,7 +43,7 @@ def tensor(data, *args, **kwargs):
└── c --> tensor([[ True, False],
[False, True]])
"""
return torch.tensor(data, *args, **kwargs)
return stream_call(torch.tensor, data, *args, **kwargs)
@doc_from_base()
......@@ -72,7 +73,7 @@ def as_tensor(data, *args, **kwargs):
└── x --> tensor([[4., 5.],
[6., 7.]])
"""
return torch.as_tensor(data, *args, **kwargs)
return stream_call(torch.as_tensor, data, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -102,7 +103,7 @@ def clone(input, *args, **kwargs):
[6],
[7]])
"""
return torch.clone(input, *args, **kwargs)
return stream_call(torch.clone, input, *args, **kwargs)
@doc_from_base()
......@@ -127,7 +128,7 @@ def zeros(*args, **kwargs):
└── b --> <Tensor 0x7f5fe0107208>
└── x --> tensor([0., 0., 0., 0.])
"""
return torch.zeros(*args, **kwargs)
return stream_call(torch.zeros, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -156,7 +157,7 @@ def zeros_like(input, *args, **kwargs):
└── b --> <Tensor 0x7ff363bb6080>
└── x --> tensor([0., 0., 0., 0.])
"""
return torch.zeros_like(input, *args, **kwargs)
return stream_call(torch.zeros_like, input, *args, **kwargs)
@doc_from_base()
......@@ -182,7 +183,7 @@ def randn(*args, **kwargs):
└── b --> <Tensor 0x7ff363bb6438>
└── x --> tensor([-0.7181, 0.1670, -1.3587, -1.5129])
"""
return torch.randn(*args, **kwargs)
return stream_call(torch.randn, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -212,7 +213,7 @@ def randn_like(input, *args, **kwargs):
└── b --> <Tensor 0x7ff3d6f420b8>
└── x --> tensor([ 0.1730, 1.6085, 0.6487, -1.1022])
"""
return torch.randn_like(input, *args, **kwargs)
return stream_call(torch.randn_like, input, *args, **kwargs)
@doc_from_base()
......@@ -238,7 +239,7 @@ def randint(*args, **kwargs):
└── b --> <Tensor 0x7ff363bb6240>
└── x --> tensor([8, 8, 2, 4])
"""
return torch.randint(*args, **kwargs)
return stream_call(torch.randint, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -268,7 +269,7 @@ def randint_like(input, *args, **kwargs):
└── b --> <Tensor 0x7ff363bb6898>
└── x --> tensor([4., 4., 7., 1.])
"""
return torch.randint_like(input, *args, **kwargs)
return stream_call(torch.randint_like, input, *args, **kwargs)
@doc_from_base()
......@@ -293,7 +294,7 @@ def ones(*args, **kwargs):
└── b --> <Tensor 0x7ff363bb6dd8>
└── x --> tensor([1., 1., 1., 1.])
"""
return torch.ones(*args, **kwargs)
return stream_call(torch.ones, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -322,7 +323,7 @@ def ones_like(input, *args, **kwargs):
└── b --> <Tensor 0x7ff363bbc240>
└── x --> tensor([1., 1., 1., 1.])
"""
return torch.ones_like(input, *args, **kwargs)
return stream_call(torch.ones_like, input, *args, **kwargs)
@doc_from_base()
......@@ -347,7 +348,7 @@ def full(*args, **kwargs):
└── b --> <Tensor 0x7ff363bbc8d0>
└── x --> tensor([2.3000, 2.3000, 2.3000, 2.3000])
"""
return torch.full(*args, **kwargs)
return stream_call(torch.full, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -377,7 +378,7 @@ def full_like(input, *args, **kwargs):
└── b --> <Tensor 0x7ff363bb69e8>
└── x --> tensor([2.3000, 2.3000, 2.3000, 2.3000])
"""
return torch.full_like(input, *args, **kwargs)
return stream_call(torch.full_like, input, *args, **kwargs)
@doc_from_base()
......@@ -403,7 +404,7 @@ def empty(*args, **kwargs):
└── b --> <Tensor 0x7ff363bb66d8>
└── x --> tensor([-3.6515e+14, 4.5900e-41, -3.8091e-38, 3.0802e-41])
"""
return torch.empty(*args, **kwargs)
return stream_call(torch.empty, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -433,4 +434,4 @@ def empty_like(input, *args, **kwargs):
└── b --> <Tensor 0x7ff3d6f3cb38>
└── x --> tensor([-1.3267e-36, 3.0802e-41, -3.8049e-38, 3.0802e-41])
"""
return torch.empty_like(input, *args, **kwargs)
return stream_call(torch.empty_like, input, *args, **kwargs)
import torch
from .base import doc_from_base, func_treelize
from ..stream import stream_call
from ...common import return_self
__all__ = [
......@@ -37,7 +38,7 @@ def abs(input, *args, **kwargs):
└── x --> tensor([[3, 1],
[0, 2]])
"""
return torch.abs(input, *args, **kwargs)
return stream_call(torch.abs, input, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -69,7 +70,7 @@ def abs_(input):
└── x --> tensor([[3, 1],
[0, 2]])
"""
return torch.abs_(input)
return stream_call(torch.abs_, input)
# noinspection PyShadowingBuiltins
......@@ -96,7 +97,7 @@ def clamp(input, *args, **kwargs):
└── x --> tensor([[-0.5000, 0.5000, -0.3697],
[ 0.0489, -0.5000, -0.5000]])
"""
return torch.clamp(input, *args, **kwargs)
return stream_call(torch.clamp, input, *args, **kwargs)
# noinspection PyShadowingBuiltins,PyUnresolvedReferences
......@@ -128,7 +129,7 @@ def clamp_(input, *args, **kwargs):
└── x --> tensor([[-0.5000, 0.5000, -0.3697],
[ 0.0489, -0.5000, -0.5000]])
"""
return torch.clamp_(input, *args, **kwargs)
return stream_call(torch.clamp_, input, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -155,7 +156,7 @@ def sign(input, *args, **kwargs):
└── x --> tensor([[-1, 1],
[ 0, -1]])
"""
return torch.sign(input, *args, **kwargs)
return stream_call(torch.sign, input, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -185,7 +186,7 @@ def round(input, *args, **kwargs):
└── x --> tensor([[ 1., -4., 1.],
[-5., -2., 3.]])
"""
return torch.round(input, *args, **kwargs)
return stream_call(torch.round, input, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -219,7 +220,7 @@ def round_(input):
└── x --> tensor([[ 1., -4., 1.],
[-5., -2., 3.]])
"""
return torch.round_(input)
return stream_call(torch.round_, input)
# noinspection PyShadowingBuiltins
......@@ -249,7 +250,7 @@ def floor(input, *args, **kwargs):
└── x --> tensor([[ 1., -4., 1.],
[-5., -2., 2.]])
"""
return torch.floor(input, *args, **kwargs)
return stream_call(torch.floor, input, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -283,7 +284,7 @@ def floor_(input):
└── x --> tensor([[ 1., -4., 1.],
[-5., -2., 2.]])
"""
return torch.floor_(input)
return stream_call(torch.floor_, input)
# noinspection PyShadowingBuiltins
......@@ -313,7 +314,7 @@ def ceil(input, *args, **kwargs):
└── x --> tensor([[ 1., -3., 2.],
[-4., -2., 3.]])
"""
return torch.ceil(input, *args, **kwargs)
return stream_call(torch.ceil, input, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -347,7 +348,7 @@ def ceil_(input):
└── x --> tensor([[ 1., -3., 2.],
[-4., -2., 3.]])
"""
return torch.ceil_(input)
return stream_call(torch.ceil_, input)
# noinspection PyShadowingBuiltins
......@@ -374,7 +375,7 @@ def sigmoid(input, *args, **kwargs):
└── x --> tensor([[0.6225, 0.7685],
[0.0759, 0.5622]])
"""
return torch.sigmoid(input, *args, **kwargs)
return stream_call(torch.sigmoid, input, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -406,7 +407,7 @@ def sigmoid_(input):
└── x --> tensor([[0.6225, 0.7685],
[0.0759, 0.5622]])
"""
return torch.sigmoid_(input)
return stream_call(torch.sigmoid_, input)
# noinspection PyShadowingBuiltins
......@@ -443,7 +444,7 @@ def add(input, other, *args, **kwargs):
└── x --> tensor([[ 34, -10],
[ 22, 35]])
"""
return torch.add(input, other, *args, **kwargs)
return stream_call(torch.add, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -479,7 +480,7 @@ def sub(input, other, *args, **kwargs):
└── x --> tensor([[-28, 20],
[ -4, -11]])
"""
return torch.sub(input, other, *args, **kwargs)
return stream_call(torch.sub, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -516,7 +517,7 @@ def mul(input, other, *args, **kwargs):
└── x --> tensor([[ 93, -75],
[117, 276]])
"""
return torch.mul(input, other, *args, **kwargs)
return stream_call(torch.mul, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -571,7 +572,7 @@ def div(input, other, *args, **kwargs):
[[-7.8589, 1.3007, -2.0349],
[ 0.1460, 0.5554, -0.1900]]])
"""
return torch.div(input, other, *args, **kwargs)
return stream_call(torch.div, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -627,7 +628,7 @@ def pow(input, exponent, *args, **kwargs):
[[ 1024, 36, 15625],
[823543, 8, 2401]]])
"""
return torch.pow(input, exponent, *args, **kwargs)
return stream_call(torch.pow, input, exponent, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -666,7 +667,7 @@ def neg(input, *args, **kwargs):
[[-4, -6, -5],
[-7, -2, -7]]])
"""
return torch.neg(input, *args, **kwargs)
return stream_call(torch.neg, input, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -710,7 +711,7 @@ def neg_(input):
[[-4, -6, -5],
[-7, -2, -7]]])
"""
return torch.neg_(input)
return stream_call(torch.neg_, input)
# noinspection PyShadowingBuiltins
......@@ -738,7 +739,7 @@ def exp(input, *args, **kwargs):
└── x --> tensor([[1.3534e-01, 3.3201e+00, 1.2840e+00],
[8.8861e+06, 4.2521e+01, 9.6328e-02]])
"""
return torch.exp(input, *args, **kwargs)
return stream_call(torch.exp, input, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -771,7 +772,7 @@ def exp_(input):
└── x --> tensor([[1.3534e-01, 3.3201e+00, 1.2840e+00],
[8.8861e+06, 4.2521e+01, 9.6328e-02]])
"""
return torch.exp_(input)
return stream_call(torch.exp_, input)
# noinspection PyShadowingBuiltins
......@@ -799,7 +800,7 @@ def exp2(input, *args, **kwargs):
└── x --> tensor([[2.5000e-01, 2.2974e+00, 1.1892e+00],
[6.5536e+04, 1.3454e+01, 1.9751e-01]])
"""
return torch.exp2(input, *args, **kwargs)
return stream_call(torch.exp2, input, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -832,7 +833,7 @@ def exp2_(input):
└── x --> tensor([[2.5000e-01, 2.2974e+00, 1.1892e+00],
[6.5536e+04, 1.3454e+01, 1.9751e-01]])
"""
return torch.exp2_(input)
return stream_call(torch.exp2_, input)
# noinspection PyShadowingBuiltins
......@@ -860,7 +861,7 @@ def sqrt(input, *args, **kwargs):
└── x --> tensor([[ nan, 1.0954, 0.5000],
[4.0000, 1.9365, nan]])
"""
return torch.sqrt(input, *args, **kwargs)
return stream_call(torch.sqrt, input, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -893,7 +894,7 @@ def sqrt_(input):
└── x --> tensor([[ nan, 1.0954, 0.5000],
[4.0000, 1.9365, nan]])
"""
return torch.sqrt_(input)
return stream_call(torch.sqrt_, input)
# noinspection PyShadowingBuiltins
......@@ -921,7 +922,7 @@ def log(input, *args, **kwargs):
└── x --> tensor([[ nan, 0.1823, -1.3863],
[ 2.7726, 1.3218, nan]])
"""
return torch.log(input, *args, **kwargs)
return stream_call(torch.log, input, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -954,7 +955,7 @@ def log_(input):
└── x --> tensor([[ nan, 0.1823, -1.3863],
[ 2.7726, 1.3218, nan]])
"""
return torch.log_(input)
return stream_call(torch.log_, input)
# noinspection PyShadowingBuiltins
......@@ -982,7 +983,7 @@ def log2(input, *args, **kwargs):
└── x --> tensor([[ nan, 0.2630, -2.0000],
[ 4.0000, 1.9069, nan]])
"""
return torch.log2(input, *args, **kwargs)
return stream_call(torch.log2, input, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -1015,7 +1016,7 @@ def log2_(input):
└── x --> tensor([[ nan, 0.2630, -2.0000],
[ 4.0000, 1.9069, nan]])
"""
return torch.log2_(input)
return stream_call(torch.log2_, input)
# noinspection PyShadowingBuiltins
......@@ -1043,7 +1044,7 @@ def log10(input, *args, **kwargs):
└── x --> tensor([[ nan, 0.0792, -0.6021],
[ 1.2041, 0.5740, nan]])
"""
return torch.log10(input, *args, **kwargs)
return stream_call(torch.log10, input, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -1076,7 +1077,7 @@ def log10_(input):
└── x --> tensor([[ nan, 0.0792, -0.6021],
[ 1.2041, 0.5740, nan]])
"""
return torch.log10_(input)
return stream_call(torch.log10_, input)
# noinspection PyShadowingBuiltins
......@@ -1117,7 +1118,7 @@ def dist(input, other, *args, **kwargs):
└── b --> <Tensor 0x7f95f68494a8>
└── x --> tensor(4.1904)
"""
return torch.dist(input, other, *args, **kwargs)
return stream_call(torch.dist, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -1157,4 +1158,4 @@ def norm(input, *args, **kwargs):
└── b --> <Tensor 0x7f95f684f978>
└── x --> tensor(3.2982)
"""
return torch.norm(input, *args, **kwargs)
return stream_call(torch.norm, input, *args, **kwargs)
......@@ -37,7 +37,7 @@ def dot(input, other, *args, **kwargs):
└── b --> <Tensor 0x7feac55c9250>
└── x --> tensor(11)
"""
return torch.dot(input, other, *args, **kwargs)
return stream_call(torch.dot, input, other, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -112,4 +112,4 @@ def mm(input, mat2, *args, **kwargs):
└── x --> tensor([[44, 32],
[80, 59]])
"""
return torch.mm(input, mat2, *args, **kwargs)
return stream_call(torch.mm, input, mat2, *args, **kwargs)
......@@ -3,6 +3,7 @@ from hbutils.reflection import post_process
from treevalue import TreeValue
from .base import doc_from_base, func_treelize, auto_tensor
from ..stream import stream_call
__all__ = [
'cat', 'split', 'chunk', 'stack',
......@@ -112,7 +113,7 @@ def cat(tensors, *args, **kwargs):
└── y --> tensor([[51, 65, 65, 21, 29, 17, 46, 46, 46],
[54, 67, 57, 16, 11, 16, 30, 47, 36]])
"""
return torch.cat(tensors, *args, **kwargs)
return stream_call(torch.cat, tensors, *args, **kwargs)
# noinspection PyShadowingNames
......@@ -201,7 +202,7 @@ def split(tensor, split_size_or_sections, *args, **kwargs):
[58, 54, 78]]])
)
"""
return torch.split(tensor, split_size_or_sections, *args, **kwargs)
return stream_call(torch.split, tensor, split_size_or_sections, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -261,7 +262,7 @@ def chunk(input, chunks, *args, **kwargs):
[29, 65, 17, 72],
[53, 50, 75, 0]]])
"""
return torch.chunk(input, chunks, *args, **kwargs)
return stream_call(torch.chunk, input, chunks, *args, **kwargs)
@doc_from_base()
......@@ -352,7 +353,7 @@ def stack(tensors, *args, **kwargs):
[[18, 21, 17, 12],
[36, 30, 33, 31]]])
"""
return torch.stack(tensors, *args, **kwargs)
return stream_call(torch.stack, tensors, *args, **kwargs)
# noinspection PyShadowingBuiltins
......@@ -399,7 +400,7 @@ def reshape(input, shape):
[11, 13]])
"""
return torch.reshape(input, shape)
return stream_call(torch.reshape, input, shape)
# noinspection PyShadowingBuiltins
......@@ -434,7 +435,7 @@ def squeeze(input, *args, **kwargs):
└── b --> <Size 0x7fa4c1afe710>
└── x --> torch.Size([2, 3])
"""
return torch.squeeze(input, *args, *kwargs)
return stream_call(torch.squeeze, input, *args, *kwargs)
# noinspection PyShadowingBuiltins
......@@ -469,7 +470,7 @@ def unsqueeze(input, dim):
└── b --> <Size 0x7f5d1a5c99b0>
└── x --> torch.Size([2, 1, 3])
"""
return torch.unsqueeze(input, dim)
return stream_call(torch.unsqueeze, input, dim)
@doc_from_base()
......@@ -518,7 +519,7 @@ def where(condition, x, y):
[[ 0, 3, 93, 89],
[ 0, 89, 85, 0]]])
"""
return torch.where(condition, x, y)
return stream_call(torch.where, condition, x, y)
# noinspection PyShadowingBuiltins
......@@ -586,4 +587,4 @@ def index_select(input, dim, index, *args, **kwargs):
[-2.1694, -0.4224, 0.3998],
[ 0.9777, -0.0101, -1.1500]])
"""
return torch.index_select(input, dim, index, *args, **kwargs)
return stream_call(torch.index_select, input, dim, index, *args, **kwargs)
......@@ -7,7 +7,7 @@ _stream_pool: Optional[List[torch.cuda.Stream]] = None
_global_streams: Optional[List[torch.cuda.Stream]] = None
__all__ = [
'stream',
'stream', 'stream_call',
]
......
......@@ -5,6 +5,7 @@ from treevalue import method_treelize, TreeValue, typetrans
from .base import Torch, rmreduce, post_reduce, auto_reduce
from .size import Size
from .stream import stream_call
from ..common import Object, ireduce, clsmeta, return_self, auto_tree, get_tree_proxy
from ..numpy import ndarray
from ..utils import current_names, class_autoremove, replaceable_partial
......@@ -152,7 +153,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
This tensor and the returned :class:`treetensor.numpy.ndarray` share the same underlying storage.
Changes to self tensor will be reflected in the ``ndarray`` and vice versa.
"""
return self.numpy()
return stream_call(self.numpy, )
@doc_from_base()
@method_treelize(return_type=Object)
......@@ -175,7 +176,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
'c': True,
})
"""
return self.tolist()
return stream_call(self.tolist, )
@doc_from_base()
@method_treelize()
......@@ -186,7 +187,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
If this tree tensor is already in CPU memory and on the correct device,
then no copy is performed and the original object is returned.
"""
return self.cpu(*args, **kwargs)
return stream_call(self.cpu, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -197,7 +198,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
If this tree tensor is already in CUDA memory and on the correct device,
then no copy is performed and the original object is returned.
"""
return self.cuda(*args, **kwargs)
return stream_call(self.cuda, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -221,7 +222,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
└── x --> tensor([[4., 5.],
[6., 7.]], dtype=torch.float64)
"""
return self.to(*args, **kwargs)
return stream_call(self.to, *args, **kwargs)
@doc_from_base()
@ireduce(sum)
......@@ -230,7 +231,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.numel`
"""
return self.numel()
return stream_call(self.numel, )
@property
@doc_from_base()
......@@ -346,7 +347,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
[ 0.2531, -0.0637, 0.9822, 2.1618],
[ 2.0140, -0.0929, 0.9304, 1.5430]], requires_grad=True)
"""
return self.requires_grad_(requires_grad)
return stream_call(self.requires_grad_, requires_grad)
@doc_from_base()
@method_treelize()
......@@ -354,7 +355,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.detach`.
"""
return self.detach()
return stream_call(self.detach, )
@doc_from_base()
@return_self
......@@ -363,7 +364,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
In-place version of :meth:`Tensor.detach`.
"""
return self.detach_()
return stream_call(self.detach_, )
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(pytorch.all)
......@@ -513,7 +514,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.clone`.
"""
return self.clone(*args, **kwargs)
return stream_call(self.clone, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -521,7 +522,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.dot`.
"""
return self.dot(other, *args, **kwargs)
return stream_call(self.dot, other, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -529,7 +530,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.mm`.
"""
return self.mm(mat2, *args, **kwargs)
return stream_call(self.mm, mat2, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -537,7 +538,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.matmul`.
"""
return self.matmul(tensor2, *args, **kwargs)
return stream_call(self.matmul, tensor2, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -545,7 +546,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.isfinite`.
"""
return self.isfinite()
return stream_call(self.isfinite, )
@doc_from_base()
@method_treelize()
......@@ -553,7 +554,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.isinf`.
"""
return self.isinf()
return stream_call(self.isinf, )
@doc_from_base()
@method_treelize()
......@@ -561,7 +562,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.isnan`.
"""
return self.isnan()
return stream_call(self.isnan, )
@doc_from_base()
@method_treelize()
......@@ -569,7 +570,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.isclose`.
"""
return self.isclose(other, *args, **kwargs)
return stream_call(self.isclose, other, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -577,7 +578,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.abs`.
"""
return self.abs(*args, **kwargs)
return stream_call(self.abs, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -586,7 +587,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.abs_`.
"""
return self.abs_(*args, **kwargs)
return stream_call(self.abs_, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -594,7 +595,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.clamp`.
"""
return self.clamp(*args, **kwargs)
return stream_call(self.clamp, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -603,7 +604,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.clamp_`.
"""
return self.clamp_(*args, **kwargs)
return stream_call(self.clamp_, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -611,7 +612,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.sign`.
"""
return self.sign(*args, **kwargs)
return stream_call(self.sign, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -620,7 +621,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
In-place version of :meth:`Tensor.sign`.
"""
return self.sign_(*args, **kwargs)
return stream_call(self.sign_, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -628,7 +629,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.sigmoid`.
"""
return self.sigmoid(*args, **kwargs)
return stream_call(self.sigmoid, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -637,7 +638,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.sigmoid_`.
"""
return self.sigmoid_(*args, **kwargs)
return stream_call(self.sigmoid_, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -645,7 +646,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.floor`.
"""
return self.floor(*args, **kwargs)
return stream_call(self.floor, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -654,7 +655,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.floor_`.
"""
return self.floor_(*args, **kwargs)
return stream_call(self.floor_, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -662,7 +663,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.ceil`.
"""
return self.ceil(*args, **kwargs)
return stream_call(self.ceil, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -671,7 +672,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.ceil_`.
"""
return self.ceil_(*args, **kwargs)
return stream_call(self.ceil_, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -679,7 +680,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.round`.
"""
return self.round(*args, **kwargs)
return stream_call(self.round, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -688,7 +689,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.round_`.
"""
return self.round_(*args, **kwargs)
return stream_call(self.round_, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -696,7 +697,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.add`.
"""
return self.add(other, *args, **kwargs)
return stream_call(self.add, other, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -705,7 +706,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
In-place version of :meth:`Tensor.add`.
"""
return self.add_(other, *args, **kwargs)
return stream_call(self.add_, other, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -713,7 +714,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.sub`.
"""
return self.sub(other, *args, **kwargs)
return stream_call(self.sub, other, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -722,7 +723,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
In-place version of :meth:`Tensor.sub`.
"""
return self.sub_(other, *args, **kwargs)
return stream_call(self.sub_, other, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -730,7 +731,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.mul`.
"""
return self.mul(other, *args, **kwargs)
return stream_call(self.mul, other, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -739,7 +740,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
In-place version of :meth:`Tensor.mul`.
"""
return self.mul_(other, *args, **kwargs)
return stream_call(self.mul_, other, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -747,7 +748,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.div`.
"""
return self.div(other, *args, **kwargs)
return stream_call(self.div, other, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -756,7 +757,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
In-place version of :meth:`Tensor.div`.
"""
return self.div_(other, *args, **kwargs)
return stream_call(self.div_, other, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -764,7 +765,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.pow`.
"""
return self.pow(exponent, *args, **kwargs)
return stream_call(self.pow, exponent, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -773,7 +774,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
In-place version of :meth:`Tensor.pow`.
"""
return self.pow_(exponent, *args, **kwargs)
return stream_call(self.pow_, exponent, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -781,7 +782,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.neg`.
"""
return self.neg(*args, **kwargs)
return stream_call(self.neg, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -790,7 +791,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
In-place version of :meth:`Tensor.neg`.
"""
return self.neg_(*args, **kwargs)
return stream_call(self.neg_, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -798,7 +799,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.exp`.
"""
return self.exp(*args, **kwargs)
return stream_call(self.exp, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -807,7 +808,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
In-place version of :meth:`Tensor.exp`.
"""
return self.exp_(*args, **kwargs)
return stream_call(self.exp_, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -815,7 +816,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.exp2`.
"""
return self.exp2(*args, **kwargs)
return stream_call(self.exp2, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -824,7 +825,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
In-place version of :meth:`Tensor.exp2`.
"""
return self.exp2_(*args, **kwargs)
return stream_call(self.exp2_, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -832,7 +833,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.sqrt`.
"""
return self.sqrt(*args, **kwargs)
return stream_call(self.sqrt, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -841,7 +842,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
In-place version of :meth:`Tensor.sqrt`.
"""
return self.sqrt_(*args, **kwargs)
return stream_call(self.sqrt_, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -849,7 +850,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.log`.
"""
return self.log(*args, **kwargs)
return stream_call(self.log, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -858,7 +859,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
In-place version of :meth:`Tensor.log`.
"""
return self.log_(*args, **kwargs)
return stream_call(self.log_, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -866,7 +867,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.log2`.
"""
return self.log2(*args, **kwargs)
return stream_call(self.log2, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -875,7 +876,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
In-place version of :meth:`Tensor.log2`.
"""
return self.log2_(*args, **kwargs)
return stream_call(self.log2_, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -883,7 +884,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.log10`.
"""
return self.log10(*args, **kwargs)
return stream_call(self.log10, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -892,7 +893,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
In-place version of :meth:`Tensor.log10`.
"""
return self.log10_(*args, **kwargs)
return stream_call(self.log10_, *args, **kwargs)
@doc_from_base()
@post_process(__auto_tensor)
......@@ -901,7 +902,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.split`.
"""
return self.split(split_size, *args, **kwargs)
return stream_call(self.split, split_size, *args, **kwargs)
@doc_from_base()
@post_process(__auto_tensor)
......@@ -910,7 +911,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.chunk`.
"""
return self.chunk(chunks, *args, **kwargs)
return stream_call(self.chunk, chunks, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -918,7 +919,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.reshape`.
"""
return self.reshape(*args, **kwargs)
return stream_call(self.reshape, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -926,7 +927,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.squeeze`.
"""
return self.squeeze(*args, **kwargs)
return stream_call(self.squeeze, *args, **kwargs)
@doc_from_base()
@return_self
......@@ -935,7 +936,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
In-place version of :meth:`Tensor.squeeze'.
"""
return self.squeeze_(*args, **kwargs)
return stream_call(self.squeeze_, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -943,7 +944,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.unsqueeze`.
"""
return self.unsqueeze(dim)
return stream_call(self.unsqueeze, dim)
@doc_from_base()
@return_self
......@@ -952,7 +953,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
In-place version of :meth:`Tensor.unsqueeze'.
"""
return self.unsqueeze_(dim)
return stream_call(self.unsqueeze_, dim)
@doc_from_base()
@method_treelize()
......@@ -962,7 +963,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
``treetensor.torch.where(condition, self, y)``.
See :func:`treetensor.torch.where`.
"""
return self.where(condition, y, *args, **kwargs)
return stream_call(self.where, condition, y, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -970,7 +971,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.index_select`.
"""
return self.index_select(dim, index)
return stream_call(self.index_select, dim, index)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@rmreduce()
......@@ -1044,7 +1045,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.dist`.
"""
return self.dist(other, *args, **kwargs)
return stream_call(self.dist, other, *args, **kwargs)
@doc_from_base()
@method_treelize()
......@@ -1052,4 +1053,4 @@ class Tensor(Torch, metaclass=_TensorMeta):
"""
See :func:`treetensor.torch.norm`.
"""
return self.norm(*args, **kwargs)
return stream_call(self.norm, *args, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册