提交 741674aa 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): add tensor method for treetensor.numpy.ndarray class

上级 1e922e0e
import numpy as np import numpy as np
import pytest import pytest
import torch
import treetensor.numpy as tnp import treetensor.numpy as tnp
import treetensor.torch as ttorch
from treetensor.common import Object from treetensor.common import Object
...@@ -233,3 +235,22 @@ class TestNumpyArray: ...@@ -233,3 +235,22 @@ class TestNumpyArray:
'd': [0, 0, 0.0], 'd': [0, 0, 0.0],
} }
}) })
def test_tensor(self):
assert self._DEMO_1.tensor() == ttorch.Tensor({
'a': ttorch.Tensor([[1, 2, 3], [4, 5, 6]]),
'b': ttorch.Tensor([1, 3, 5, 7]),
'x': {
'c': ttorch.Tensor([[11], [23]]),
'd': ttorch.Tensor([3, 9, 11.0])
}
})
assert self._DEMO_1.tensor(dtype=torch.float64) == ttorch.Tensor({
'a': ttorch.Tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float64),
'b': ttorch.Tensor([1, 3, 5, 7], dtype=torch.float64),
'x': {
'c': ttorch.Tensor([[11], [23]], dtype=torch.float64),
'd': ttorch.Tensor([3, 9, 11.0], dtype=torch.float64),
}
})
from functools import lru_cache
import numpy import numpy
import torch
from treevalue import method_treelize from treevalue import method_treelize
from .base import TreeNumpy from .base import TreeNumpy
...@@ -12,6 +15,12 @@ __all__ = [ ...@@ -12,6 +15,12 @@ __all__ = [
_ArrayProxy, _InstanceArrayProxy = get_tree_proxy(numpy.ndarray) _ArrayProxy, _InstanceArrayProxy = get_tree_proxy(numpy.ndarray)
@lru_cache()
def _get_tensor_class(args0):
from ..torch import Tensor
return Tensor
class _BaseArrayMeta(clsmeta(numpy.asarray, allow_dict=True)): class _BaseArrayMeta(clsmeta(numpy.asarray, allow_dict=True)):
pass pass
...@@ -92,6 +101,13 @@ class ndarray(TreeNumpy, metaclass=_ArrayMeta): ...@@ -92,6 +101,13 @@ class ndarray(TreeNumpy, metaclass=_ArrayMeta):
def any(self: numpy.ndarray, *args, **kwargs): def any(self: numpy.ndarray, *args, **kwargs):
return self.any(*args, **kwargs) return self.any(*args, **kwargs)
@method_treelize(return_type=_get_tensor_class)
def tensor(self: numpy.ndarray, *args, **kwargs):
tensor_: torch.Tensor = torch.from_numpy(self)
if args or kwargs:
tensor_ = tensor_.to(*args, **kwargs)
return tensor_
@method_treelize() @method_treelize()
def __eq__(self, other): def __eq__(self, other):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册