# size.py (688 bytes)
import torch
from treevalue import func_treelize as original_func_treelize

from ..common import TreeObject
from ..utils import replaceable_partial

# Module-local ``func_treelize``: treevalue's ``func_treelize`` wrapped by the
# project helper ``replaceable_partial`` (from ..utils). Going by the name, this
# presumably allows its keyword defaults to be pre-bound and later overridden
# -- NOTE(review): confirm against ..utils.replaceable_partial.
func_treelize = replaceable_partial(original_func_treelize)


# Suppress PyCharm type-checker warnings -- presumably because the methods below
# annotate ``self`` as ``torch.Size`` (the per-leaf value) rather than TreeSize.
# noinspection PyTypeChecker
class TreeSize(TreeObject):
    """
    Tree whose leaves are ``torch.Size`` objects.

    Every method is decorated with ``func_treelize(return_type=TreeObject)``.
    Each method body is written for a single leaf, which is why ``self`` is
    annotated as ``torch.Size``. treevalue's ``func_treelize`` applies that body
    leaf-by-leaf and returns the per-leaf results as a ``TreeObject``.
    """

    @func_treelize(return_type=TreeObject)
    def numel(self: torch.Size) -> TreeObject:
        """
        Return, per leaf, ``torch.Size.numel()``.

        This is the number of elements a tensor of that size would hold.
        """
        return self.numel()

    @func_treelize(return_type=TreeObject)
    def index(self: torch.Size, *args, **kwargs) -> TreeObject:
        """
        Return, per leaf, ``torch.Size.index(*args, **kwargs)``.

        All arguments are forwarded unchanged to the leaf's ``index`` method.
        """
        return self.index(*args, **kwargs)

    @func_treelize(return_type=TreeObject)
    def count(self: torch.Size, *args, **kwargs) -> TreeObject:
        """
        Return, per leaf, ``torch.Size.count(*args, **kwargs)``.

        All arguments are forwarded unchanged to the leaf's ``count`` method.
        """
        return self.count(*args, **kwargs)