提交 455fd137 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): add pshape

上级 7d50bf4e
......@@ -22,6 +22,17 @@ class TestCommonConstraintsShape:
assert c1.prefix == (2, 3, 4)
assert repr(c1) == '<ShapePrefixConstraint (2, 3, 4)>'
assert len(c1) == 3
assert c1[0] == 2
assert c1[1] == 3
assert c1[2] == 4
with pytest.raises(IndexError):
_ = c1[3]
assert c1[-1] == 4
assert c1[-2] == 3
assert c1[-3] == 2
assert c1[1:] == (3, 4)
c1.validate(np.random.rand(2, 3, 4))
c1.validate(np.random.rand(2, 3, 4, 5))
with pytest.raises(ValueError):
......
import numpy as np
import pytest
import torch
import treetensor.torch as ttorch
from treetensor.torch import TensorShapePrefixConstraint, shape_prefix
# noinspection DuplicatedCode
@pytest.mark.unittest
class TestCommonConstraintsShape:
def test_shape_prefix(self):
c1 = shape_prefix(2, 3, 4)
assert isinstance(c1, TensorShapePrefixConstraint)
assert c1.prefix == (2, 3, 4)
assert repr(c1) == '<TensorShapePrefixConstraint (2, 3, 4)>'
assert len(c1) == 3
assert c1[0] == 2
assert c1[1] == 3
assert c1[2] == 4
with pytest.raises(IndexError):
_ = c1[3]
assert c1[-1] == 4
assert c1[-2] == 3
assert c1[-3] == 2
assert c1[1:] == (3, 4)
with pytest.raises(TypeError):
c1.validate(np.random.rand(2, 3, 4))
with pytest.raises(TypeError):
c1.validate(np.random.rand(2, 3, 4, 5))
with pytest.raises(TypeError):
c1.validate(np.random.rand(2, 3))
with pytest.raises(TypeError):
c1.validate(np.random.rand(2, 3, 3))
with pytest.raises(TypeError):
c1.validate(np.random.rand(2, 3, 3, 4))
with pytest.raises(TypeError):
c1.validate([2, 3, 4, 5])
c1.validate(torch.randn(2, 3, 4))
c1.validate(torch.randn(2, 3, 4, 5))
with pytest.raises(ValueError):
c1.validate(torch.randn(2, 3))
with pytest.raises(ValueError):
c1.validate(torch.randn(2, 3, 3))
with pytest.raises(ValueError):
c1.validate(torch.randn(2, 3, 3, 4))
with pytest.raises(TypeError):
c1.validate([2, 3, 4, 5])
assert c1 == shape_prefix(2, 3, 4)
assert not c1 != shape_prefix(2, 3, 4)
assert c1 >= shape_prefix(2, 3, 4)
assert c1 <= shape_prefix(2, 3, 4)
assert not c1 > shape_prefix(2, 3, 4)
assert not c1 < shape_prefix(2, 3, 4)
assert not c1 == shape_prefix(2, 3)
assert c1 != shape_prefix(2, 3)
assert c1 >= shape_prefix(2, 3)
assert not c1 <= shape_prefix(2, 3)
assert c1 > shape_prefix(2, 3)
assert not c1 < shape_prefix(2, 3)
assert not c1 == shape_prefix(2, 3, 4, 5)
assert c1 != shape_prefix(2, 3, 4, 5)
assert not c1 >= shape_prefix(2, 3, 4, 5)
assert c1 <= shape_prefix(2, 3, 4, 5)
assert not c1 > shape_prefix(2, 3, 4, 5)
assert c1 < shape_prefix(2, 3, 4, 5)
assert not c1 == shape_prefix(2, 3, 3)
assert c1 != shape_prefix(2, 3, 3)
assert not c1 >= shape_prefix(2, 3, 3)
assert not c1 <= shape_prefix(2, 3, 3)
assert not c1 > shape_prefix(2, 3, 3)
assert not c1 < shape_prefix(2, 3, 3)
assert not c1 >= np.ndarray
assert not c1 > np.ndarray
assert c1 >= torch.Tensor
assert c1 > torch.Tensor
def test_pshape(self):
tt = ttorch.tensor({
'a': [[0.8479, 1.0074, 0.2725],
[1.1674, 1.0784, 0.0655]],
'b': {'x': [[0.2644, 0.7268, 0.2781, 0.6469],
[2.0015, 0.4448, 0.8814, 1.0063],
[0.1847, 0.5864, 0.4417, 0.2117]]},
})
assert tt.pshape is None
tt2 = tt.with_constraints(shape_prefix(2, 3), clear=False)
assert tt2.pshape == (2, 3)
from collections.abc import Sequence
from typing import Type, TypeVar, Optional
from treevalue.tree import ValueConstraint
......@@ -9,7 +10,7 @@ __all__ = [
]
class ShapePrefixConstraint(ValueConstraint):
class ShapePrefixConstraint(ValueConstraint, Sequence):
__type__: Optional[type] = None
def __init__(self, *prefix):
......@@ -20,6 +21,12 @@ class ShapePrefixConstraint(ValueConstraint):
def prefix(self):
return self.__prefix
def __getitem__(self, index):
return self.__prefix[index]
def __len__(self) -> int:
return len(self.__prefix)
def _validate_value(self, instance):
if self.__type__ and not isinstance(instance, self.__type__):
raise TypeError(f'Invalid type, {self.__type__.__name__!r} expected but {instance!r} found.')
......
......@@ -5,6 +5,8 @@ from typing import Iterable
import torch
from .constraints import *
from .constraints import __all__ as _constraints_all
from .funcs import *
from .funcs import __all__ as _funcs_all
from .funcs.base import get_func_from_torch
......@@ -17,6 +19,7 @@ from .tensor import __all__ as _tensor_all
from ..config.meta import __VERSION__
__all__ = [
*_constraints_all,
*_funcs_all,
*_size_all,
*_tensor_all,
......
from .shape import *
from .shape import __all__ as _shape_all
__all__ = [
*_shape_all
]
import torch
from ...common.constraints import ShapePrefixConstraint
from ...common.constraints import shape_prefix as _origin_shape_prefix
__all__ = [
'TensorShapePrefixConstraint', 'shape_prefix',
]
class TensorShapePrefixConstraint(ShapePrefixConstraint):
__type__ = torch.Tensor
def shape_prefix(*shape):
return _origin_shape_prefix(*shape, type_=TensorShapePrefixConstraint)
from typing import Tuple, Optional
import numpy as np
import torch as pytorch
from hbutils.reflection import post_process
from treevalue import method_treelize, TreeValue, typetrans
from .base import Torch, rmreduce, post_reduce, auto_reduce
from .constraints import TensorShapePrefixConstraint
from .size import Size
from .stream import stream_call
from ..common import Object, ireduce, clsmeta, return_self, auto_tree, get_tree_proxy
......@@ -116,6 +119,14 @@ class Tensor(Torch, metaclass=_TensorMeta):
else:
return tree
@property
def pshape(self) -> Optional[Tuple[int, ...]]:
constraint = self.constraint.access_first(TensorShapePrefixConstraint)
if constraint:
return constraint.prefix
else:
return None
@property
def torch(self):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册