未验证 提交 e756aaba 编写于 作者: HansBug's avatar HansBug 😆 提交者: GitHub

Merge pull request #6 from opendilab/dev/np2tensor

dev(hansbug): add tensor method for treetensor.numpy.ndarray
import numpy as np
import pytest
import torch
import treetensor.numpy as tnp
import treetensor.torch as ttorch
from treetensor.common import Object
......@@ -233,3 +235,22 @@ class TestNumpyArray:
'd': [0, 0, 0.0],
}
})
def test_tensor(self):
assert (self._DEMO_1.tensor() == ttorch.Tensor({
'a': ttorch.Tensor([[1, 2, 3], [4, 5, 6]]),
'b': ttorch.Tensor([1, 3, 5, 7]),
'x': {
'c': ttorch.Tensor([[11], [23]]),
'd': ttorch.Tensor([3, 9, 11.0])
}
})).all()
assert (self._DEMO_1.tensor(dtype=torch.float64) == ttorch.Tensor({
'a': ttorch.Tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float64),
'b': ttorch.Tensor([1, 3, 5, 7], dtype=torch.float64),
'x': {
'c': ttorch.Tensor([[11], [23]], dtype=torch.float64),
'd': ttorch.Tensor([3, 9, 11.0], dtype=torch.float64),
}
})).all()
from functools import lru_cache
import numpy
import torch
from treevalue import method_treelize
from .base import TreeNumpy
......@@ -12,6 +15,12 @@ __all__ = [
_ArrayProxy, _InstanceArrayProxy = get_tree_proxy(numpy.ndarray)
@lru_cache()
def _get_tensor_class(args0):
from ..torch import Tensor
return Tensor(args0)
class _BaseArrayMeta(clsmeta(numpy.asarray, allow_dict=True)):
pass
......@@ -92,6 +101,13 @@ class ndarray(TreeNumpy, metaclass=_ArrayMeta):
def any(self: numpy.ndarray, *args, **kwargs):
return self.any(*args, **kwargs)
@method_treelize(return_type=_get_tensor_class)
def tensor(self: numpy.ndarray, *args, **kwargs):
tensor_: torch.Tensor = torch.from_numpy(self)
if args or kwargs:
tensor_ = tensor_.to(*args, **kwargs)
return tensor_
@method_treelize()
def __eq__(self, other):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册