提交 84984ec1 编写于 作者: HansBug's avatar HansBug 😆

test, doc(hansbug): split treetensor.common.Object out of the original position

上级 745b5b8e
......@@ -4,5 +4,6 @@ treetensor.common
.. toctree::
:maxdepth: 3
trees.auto
wrappers.auto
object
trees
wrappers
treetensor.common.object
===============================
.. currentmodule:: treetensor.common
Object
-----------------
.. autoclass:: Object
:members: __init__
treetensor.common.trees
==============================
.. py:currentmodule:: treetensor.common.trees
print_tree
------------------
.. autofunction:: print_tree
BaseTreeStruct
-----------------------
.. autoclass:: BaseTreeStruct
:members:
treetensor.common.trees
\ No newline at end of file
treetensor.common.wrappers
===============================
.. py:currentmodule:: treetensor.common
kwreduce
-------------------
.. autofunction:: kwreduce
vreduce
-------------------
.. autofunction:: vreduce
ireduce
-------------------
.. autofunction:: ireduce
treetensor.common.wrappers
\ No newline at end of file
from .test_object import TestCommonObject
from .test_trees import TestCommonTrees
import pytest
from treevalue import typetrans, TreeValue
from treetensor.common import Object
@pytest.mark.unittest
class TestCommonObject:
def test_object(self):
t = Object(1)
assert isinstance(t, int)
assert t == 1
assert Object({'a': 1, 'b': 2}) == typetrans(TreeValue({
'a': 1, 'b': 2
}), Object)
......@@ -2,9 +2,9 @@ import io
import pytest
import torch
from treevalue import typetrans, TreeValue, general_tree_value
from treevalue import general_tree_value
from treetensor.common import Object, print_tree
from treetensor.common import print_tree
def text_compares(expected, actual):
......@@ -24,15 +24,6 @@ Expected: {e}
@pytest.mark.unittest
class TestCommonTrees:
def test_object(self):
t = Object(1)
assert isinstance(t, int)
assert t == 1
assert Object({'a': 1, 'b': 2}) == typetrans(TreeValue({
'a': 1, 'b': 2
}), Object)
def test_print_tree(self):
class _TempTree(general_tree_value()):
def __repr__(self):
......
from .object import *
from .trees import *
from .wrappers import *
from .trees import BaseTreeStruct, clsmeta
__all__ = [
"Object",
]
def _object(obj):
return obj
class Object(BaseTreeStruct, metaclass=clsmeta(_object, allow_dict=True)):
def __init__(self, data):
"""
In :class:`treetensor.common.Object`, object or object tree can be initialized.
Examples::
>>> from treetensor.common import Object
>>> Object(1)
1
>>> Object({'a': 1, 'b': 2, 'x': {'c': 233}})
<Object 0x7fe00b1153a0>
├── a --> 1
├── b --> 2
└── x --> <Object 0x7fe00b115ee0>
└── c --> 233
"""
super(BaseTreeStruct, self).__init__(data)
......@@ -13,7 +13,7 @@ from treevalue.utils import post_process
from ..utils import replaceable_partial, args_mapping
__all__ = [
'BaseTreeStruct', "Object",
'BaseTreeStruct',
'print_tree', 'clsmeta',
]
......@@ -118,11 +118,3 @@ def clsmeta(func, allow_dict: bool = False):
return _result
return _MetaClass
def _object(obj):
return obj
class Object(BaseTreeStruct, metaclass=clsmeta(_object, allow_dict=True)):
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册