未验证 提交 b61551a6 编写于 作者: HansBug's avatar HansBug 😆 提交者: GitHub

Merge pull request #78 from opendilab/dev/jax

dev(hansbug): add register support for treevalue
......@@ -94,10 +94,6 @@ jobs:
pip install -r requirements.txt
pip install -r requirements-build.txt
pip install -r requirements-test.txt
- name: Install extra PyPI dependencies
continue-on-error: true
shell: bash
run: |
pip install -r requirements-test-extra.txt
- name: Test the basic environment
shell: bash
......
......@@ -8,3 +8,4 @@ treevalue.tree
tree
func
general
integration
treevalue.tree.integration
=======================================
.. py:currentmodule:: treevalue.tree.integration
.. _apidoc_tree_integration_register_for_jax:
register_for_jax
------------------------
.. autofunction:: register_for_jax
torch>=1.1.0
jax[cpu]>=0.3.25; platform_system != 'Windows'
torch>=1.1.0; python_version < '3.11'
from unittest import skipUnless
import numpy as np
import pytest
from treevalue import FastTreeValue
try:
import jax
except (ModuleNotFoundError, ImportError):
jax = None
@pytest.mark.unittest
class TestTreeTreeIntegration:
@skipUnless(jax, 'Jax required.')
def test_jax_double(self):
@jax.jit
def double(x):
return x * 2 + 1.5
t1 = FastTreeValue({
'a': np.random.randint(0, 10, (2, 3)),
'b': {
'x': np.asarray(233.0),
'y': np.random.randn(2, 3)
}
})
assert FastTreeValue.func()(np.isclose)(double(t1), t1 * 2 + 1.5).all() == \
FastTreeValue({'a': True, 'b': {'x': True, 'y': True}})
from .common import raw
from .func import *
from .general import *
from .integration import *
from .tree import *
from .jax import register_for_jax
# distutils:language=c++
# cython:language_level=3
cdef tuple _c_flatten_for_jax(object tv)
cdef object _c_unflatten_for_jax(tuple aux, tuple values)
cpdef void register_for_jax(object cls) except*
# distutils:language=c++
# cython:language_level=3
import cython
from ..tree.flatten cimport _c_flatten, _c_unflatten
from ..tree.tree cimport TreeValue
cdef inline tuple _c_flatten_for_jax(object tv):
cdef list result = []
_c_flatten(tv._detach(), (), result)
cdef list paths = []
cdef list values = []
for path, value in result:
paths.append(path)
values.append(value)
return values, (type(tv), paths)
cdef inline object _c_unflatten_for_jax(tuple aux, tuple values):
cdef object type_
cdef list paths
type_, paths = aux
return type_(_c_unflatten(zip(paths, values)))
@cython.binding(True)
cpdef void register_for_jax(object cls) except*:
"""
Overview:
Register treevalue class for jax.
:param cls: TreeValue class.
Examples::
>>> from treevalue import FastTreeValue, TreeValue, register_for_jax
>>> register_for_jax(TreeValue)
>>> register_for_jax(FastTreeValue)
.. warning::
This method will put a warning message and then do nothing when jax is not installed.
"""
if isinstance(cls, type) and issubclass(cls, TreeValue):
import jax
jax.tree_util.register_pytree_node(cls, _c_flatten_for_jax, _c_unflatten_for_jax)
else:
raise TypeError(f'Registered class should be a subclass of TreeValue, but {cls!r} found.')
import warnings
from functools import wraps
try:
import jax
except (ModuleNotFoundError, ImportError):
from .cjax import register_for_jax as _original_register_for_jax
@wraps(_original_register_for_jax)
def register_for_jax(cls):
warnings.warn(f'Jax is not installed, registration of {cls!r} will be ignored.')
else:
from .cjax import register_for_jax
from ..tree import TreeValue
from ..general import FastTreeValue
register_for_jax(TreeValue)
register_for_jax(FastTreeValue)
......@@ -8,7 +8,7 @@ import cython
from .tree cimport TreeValue
from ..common.storage cimport TreeStorage, _c_undelay_data
cdef void _c_flatten(TreeStorage st, tuple path, list res) except *:
cdef inline void _c_flatten(TreeStorage st, tuple path, list res) except *:
cdef dict data = st.detach()
cdef tuple curpath
......@@ -44,7 +44,7 @@ cpdef list flatten(TreeValue tree):
_c_flatten(tree._detach(), (), result)
return result
cdef void _c_flatten_values(TreeStorage st, list res) except *:
cdef inline void _c_flatten_values(TreeStorage st, list res) except *:
cdef dict data = st.detach()
cdef str k
......@@ -72,7 +72,7 @@ cpdef list flatten_values(TreeValue tree):
_c_flatten_values(tree._detach(), result)
return result
cdef void _c_flatten_keys(TreeStorage st, tuple path, list res) except *:
cdef inline void _c_flatten_keys(TreeStorage st, tuple path, list res) except *:
cdef dict data = st.detach()
cdef tuple curpath
......@@ -102,7 +102,7 @@ cpdef list flatten_keys(TreeValue tree):
_c_flatten_keys(tree._detach(), (), result)
return result
cdef TreeStorage _c_unflatten(object pairs):
cdef inline TreeStorage _c_unflatten(object pairs):
cdef dict raw_data = {}
cdef TreeStorage result = TreeStorage(raw_data)
cdef list stack = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册