# Source code for treetensor.torch.funcs.comparison

import builtins

import torch

from .base import doc_from_base, func_treelize
from ...common import ireduce

# Public names of this module; these are the only names a
# ``from <module> import *`` will bring in.
__all__ = [
    # whole-tree equality (single bool)
    'equal',
    # element-wise predicates
    'isfinite',
    'isinf',
    'isnan',
    'isclose',
    # element-wise comparison operators
    'eq',
    'ne',
    'lt',
    'le',
    'gt',
    'ge',
]


# noinspection PyShadowingBuiltins
@doc_from_base()
@ireduce(builtins.all)
@func_treelize()
def equal(input, other):
    """
    In ``treetensor``, you can check whether two tree tensors are equal.
    :func:`torch.equal` is evaluated for each pair of matching leaves, and the
    per-leaf booleans are folded with the built-in :func:`all` (through
    ``ireduce(builtins.all)``), so the final answer is a single boolean.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> ttorch.equal(
        ...     torch.tensor([1, 2, 3]),
        ...     torch.tensor([1, 2, 3]),
        ... )  # the same as torch.equal
        True

        >>> ttorch.equal(
        ...     ttorch.tensor({
        ...         'a': torch.tensor([1, 2, 3]),
        ...         'b': torch.tensor([[4, 5], [6, 7]]),
        ...     }),
        ...     ttorch.tensor({
        ...         'a': torch.tensor([1, 2, 3]),
        ...         'b': torch.tensor([[4, 5], [6, 7]]),
        ...     }),
        ... )
        True
    """
    same = torch.equal(input, other)
    return same


# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def isfinite(input):
    """
    In ``treetensor``, you can get a tree of new boolean tensors telling,
    element by element, whether ``input`` is finite (neither infinite nor NaN).

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> ttorch.isfinite(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
        tensor([ True, False,  True, False, False])

        >>> ttorch.isfinite(ttorch.tensor({
        ...     'a': [1, float('inf'), 2, float('-inf'), float('nan')],
        ...     'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]}
        ... }))
        <Tensor 0x7fb782a15970>
        ├── a --> tensor([ True, False,  True, False, False])
        └── b --> <Tensor 0x7fb782a1e040>
            └── x --> tensor([[ True, False,  True],
                              [False,  True, False]])
    """
    finite_mask = torch.isfinite(input)
    return finite_mask


# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def isinf(input):
    """
    In ``treetensor``, you can get a tree of new boolean tensors telling,
    element by element, whether ``input`` is positive or negative infinity.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> ttorch.isinf(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
        tensor([False,  True, False,  True, False])

        >>> ttorch.isinf(ttorch.tensor({
        ...     'a': [1, float('inf'), 2, float('-inf'), float('nan')],
        ...     'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]}
        ... }))
        <Tensor 0x7fb782a29b80>
        ├── a --> tensor([False,  True, False,  True, False])
        └── b --> <Tensor 0x7fb782a2d1f0>
            └── x --> tensor([[False,  True, False],
                              [ True, False, False]])
    """
    inf_mask = torch.isinf(input)
    return inf_mask


# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def isnan(input):
    """
    In ``treetensor``, you can get a tree of new tensors with boolean elements
    representing whether each element of ``input`` is NaN or not.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> ttorch.isnan(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
        tensor([False, False, False, False,  True])

        >>> ttorch.isnan(ttorch.tensor({
        ...     'a': [1, float('inf'), 2, float('-inf'), float('nan')],
        ...     'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]}
        ... }))
        <Tensor 0x7fb782a2d0a0>
        ├── a --> tensor([False, False, False, False,  True])
        └── b --> <Tensor 0x7fb782a29d90>
            └── x --> tensor([[False, False, False],
                              [False, False,  True]])
    """
    return torch.isnan(input)


# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def isclose(input, other, *args, **kwargs):
    """
    In ``treetensor``, you can get a tree of new tensors with boolean elements
    representing whether each element of ``input`` is close to the
    corresponding element of ``other``. Any extra positional or keyword
    arguments (e.g. ``rtol``, ``atol``, ``equal_nan``) are passed straight
    through to :func:`torch.isclose`.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> import math
        >>> ttorch.isclose(
        ...     ttorch.tensor((1., 2, 3)),
        ...     ttorch.tensor((1 + 1e-10, 3, 4))
        ... )
        tensor([ True, False, False])

        >>> ttorch.isclose(
        ...     ttorch.tensor({
        ...         'a': [1., 2, 3],
        ...         'b': {'x': [[float('inf'), 4, 1e20],
        ...                     [-math.inf, 2.2943, 9483.32]]},
        ...     }),
        ...     ttorch.tensor({
        ...         'a': [1 + 1e-10, 3, 4],
        ...         'b': {'x': [[math.inf, 6, 1e20+1],
        ...                     [-float('inf'), 2.294300000001, 9484.32]]},
        ...     }),
        ... )
        <Tensor 0x7f5b3219f370>
        ├── a --> tensor([ True, False, False])
        └── b --> <Tensor 0x7f5b3219f550>
            └── x --> tensor([[ True, False,  True],
                              [ True,  True, False]])
    """
    close_mask = torch.isclose(input, other, *args, **kwargs)
    return close_mask


# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def eq(input, other, *args, **kwargs):
    """
    In ``treetensor``, you can compare two tree tensors element-wise for
    equality with :func:`eq`. Extra arguments are forwarded to
    :func:`torch.eq` unchanged.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> ttorch.eq(
        ...     torch.tensor([[1, 2], [3, 4]]),
        ...     torch.tensor([[1, 1], [4, 4]]),
        ... )
        tensor([[ True, False],
                [False,  True]])

        >>> ttorch.eq(
        ...     ttorch.tensor({
        ...         'a': [[1, 2], [3, 4]],
        ...         'b': [1.0, 1.5, 2.0],
        ...     }),
        ...     ttorch.tensor({
        ...         'a': [[1, 1], [4, 4]],
        ...         'b': [1.3, 1.2, 2.0],
        ...     }),
        ... )
        <Tensor 0x7ff363bbce10>
        ├── a --> tensor([[ True, False],
        │                 [False,  True]])
        └── b --> tensor([False, False,  True])
    """
    mask = torch.eq(input, other, *args, **kwargs)
    return mask


# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def ne(input, other, *args, **kwargs):
    """
    In ``treetensor``, you can compare two tree tensors element-wise for
    inequality with :func:`ne`. Extra arguments are forwarded to
    :func:`torch.ne` unchanged.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> ttorch.ne(
        ...     torch.tensor([[1, 2], [3, 4]]),
        ...     torch.tensor([[1, 1], [4, 4]]),
        ... )
        tensor([[False,  True],
                [ True, False]])

        >>> ttorch.ne(
        ...     ttorch.tensor({
        ...         'a': [[1, 2], [3, 4]],
        ...         'b': [1.0, 1.5, 2.0],
        ...     }),
        ...     ttorch.tensor({
        ...         'a': [[1, 1], [4, 4]],
        ...         'b': [1.3, 1.2, 2.0],
        ...     }),
        ... )
        <Tensor 0x7ff363bb6cf8>
        ├── a --> tensor([[False,  True],
        │                 [ True, False]])
        └── b --> tensor([ True,  True, False])
    """
    mask = torch.ne(input, other, *args, **kwargs)
    return mask


# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def lt(input, other, *args, **kwargs):
    """
    In ``treetensor``, you can test element-wise whether ``input`` is strictly
    less than ``other`` with :func:`lt`. Extra arguments are forwarded to
    :func:`torch.lt` unchanged.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> ttorch.lt(
        ...     torch.tensor([[1, 2], [3, 4]]),
        ...     torch.tensor([[1, 1], [4, 4]]),
        ... )
        tensor([[False, False],
                [ True, False]])

        >>> ttorch.lt(
        ...     ttorch.tensor({
        ...         'a': [[1, 2], [3, 4]],
        ...         'b': [1.0, 1.5, 2.0],
        ...     }),
        ...     ttorch.tensor({
        ...         'a': [[1, 1], [4, 4]],
        ...         'b': [1.3, 1.2, 2.0],
        ...     }),
        ... )
        <Tensor 0x7ff363bc67f0>
        ├── a --> tensor([[False, False],
        │                 [ True, False]])
        └── b --> tensor([ True, False, False])
    """
    mask = torch.lt(input, other, *args, **kwargs)
    return mask


# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def le(input, other, *args, **kwargs):
    """
    In ``treetensor``, you can test element-wise whether ``input`` is less than
    or equal to ``other`` with :func:`le`. Extra arguments are forwarded to
    :func:`torch.le` unchanged.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> ttorch.le(
        ...     torch.tensor([[1, 2], [3, 4]]),
        ...     torch.tensor([[1, 1], [4, 4]]),
        ... )
        tensor([[ True, False],
                [ True,  True]])

        >>> ttorch.le(
        ...     ttorch.tensor({
        ...         'a': [[1, 2], [3, 4]],
        ...         'b': [1.0, 1.5, 2.0],
        ...     }),
        ...     ttorch.tensor({
        ...         'a': [[1, 1], [4, 4]],
        ...         'b': [1.3, 1.2, 2.0],
        ...     }),
        ... )
        <Tensor 0x7ff363bc6198>
        ├── a --> tensor([[ True, False],
        │                 [ True,  True]])
        └── b --> tensor([ True, False,  True])
    """
    mask = torch.le(input, other, *args, **kwargs)
    return mask


# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def gt(input, other, *args, **kwargs):
    """
    In ``treetensor``, you can test element-wise whether ``input`` is strictly
    greater than ``other`` with :func:`gt`. Extra arguments are forwarded to
    :func:`torch.gt` unchanged.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> ttorch.gt(
        ...     torch.tensor([[1, 2], [3, 4]]),
        ...     torch.tensor([[1, 1], [4, 4]]),
        ... )
        tensor([[False,  True],
                [False, False]])

        >>> ttorch.gt(
        ...     ttorch.tensor({
        ...         'a': [[1, 2], [3, 4]],
        ...         'b': [1.0, 1.5, 2.0],
        ...     }),
        ...     ttorch.tensor({
        ...         'a': [[1, 1], [4, 4]],
        ...         'b': [1.3, 1.2, 2.0],
        ...     }),
        ... )
        <Tensor 0x7ff363bc6f28>
        ├── a --> tensor([[False,  True],
        │                 [False, False]])
        └── b --> tensor([False,  True, False])
    """
    mask = torch.gt(input, other, *args, **kwargs)
    return mask


# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def ge(input, other, *args, **kwargs):
    """
    In ``treetensor``, you can test element-wise whether ``input`` is greater
    than or equal to ``other`` with :func:`ge`. Extra arguments are forwarded
    to :func:`torch.ge` unchanged.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> ttorch.ge(
        ...     torch.tensor([[1, 2], [3, 4]]),
        ...     torch.tensor([[1, 1], [4, 4]]),
        ... )
        tensor([[ True,  True],
                [False,  True]])

        >>> ttorch.ge(
        ...     ttorch.tensor({
        ...         'a': [[1, 2], [3, 4]],
        ...         'b': [1.0, 1.5, 2.0],
        ...     }),
        ...     ttorch.tensor({
        ...         'a': [[1, 1], [4, 4]],
        ...         'b': [1.3, 1.2, 2.0],
        ...     }),
        ... )
        <Tensor 0x7ff363bc6f28>
        ├── a --> tensor([[ True,  True],
        │                 [False,  True]])
        └── b --> tensor([False,  True,  True])
    """
    mask = torch.ge(input, other, *args, **kwargs)
    return mask