...
 
Commits (5)
    https://gitcode.net/opendilab/treevalue/-/commit/f83db2f404ebe7c2cb83235842dd248bbace2f98 dev(hansbug): fix bug of jax integration 2023-03-15T20:02:20+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/4e5db0162ded35a12ff547ef2260b7363e4ca6ea dev(hansbug): add lower version of torch for test 2023-03-16T14:14:29+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/36a2c767f3e4ec2400c591b76b4a9b8b64458038 dev(hansbug): fix test 2023-03-16T14:20:32+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/5be0032d8dca37f2179b76037ee4ab4e2b4619ab dev(hansbug): add test for 1.13.1, for torch2.0.0 will be used when matrix.to... 2023-03-16T14:26:40+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/131d60c1029b5142cef0f38fcc2163eb0e8d23aa Merge pull request #84 from opendilab/fix/integration 2023-03-16T15:33:35+08:00 Hankson Bradley hansbug@buaa.edu.cn dev(hansbug): fix bug of jax integration
......@@ -26,6 +26,10 @@ jobs:
- '3.9'
- '3.10'
- '3.11'
torch-version:
- '1.7.1'
- '1.13.1'
- 'latest'
steps:
- name: Get system version for Linux
......@@ -94,6 +98,15 @@ jobs:
pip install -r requirements.txt
pip install -r requirements-build.txt
pip install -r requirements-test.txt
- name: Install Torch
shell: bash
if: ${{ matrix.torch-version != 'latest' }}
continue-on-error: true
run: |
pip install torch==${{ matrix.torch-version }}
- name: Install Extra Test Requirements
shell: bash
run: |
pip install -r requirements-test-extra.txt
- name: Test the basic environment
shell: bash
......
......@@ -7,6 +7,7 @@ from treevalue import FastTreeValue, register_for_jax
try:
import jax
from jax.tree_util import register_pytree_node
except (ModuleNotFoundError, ImportError):
jax = None
......@@ -20,10 +21,10 @@ class TestTreeTreeIntegration:
return x * 2 + 1.5
t1 = FastTreeValue({
'a': np.random.randint(0, 10, (2, 3)),
'a': np.random.randint(0, 10, (2, 3)) + 1,
'b': {
'x': np.asarray(233.0),
'y': np.random.randn(2, 3) + 1,
'y': np.random.randn(2, 3) + 100,
}
})
r1 = double(t1)
......@@ -37,10 +38,10 @@ class TestTreeTreeIntegration:
register_for_jax(MyTreeValue)
t2 = MyTreeValue({
'a': np.random.randint(0, 10, (2, 3)),
'a': np.random.randint(0, 10, (2, 3)) + 1,
'b': {
'x': np.asarray(233.0),
'y': np.random.randn(2, 3) + 1,
'y': np.random.randn(2, 3) + 100,
}
})
r2 = double(t2)
......
......@@ -6,6 +6,7 @@ from treevalue import FastTreeValue, register_for_torch
try:
import torch
from torch.utils._pytree import _register_pytree_node
except (ImportError, ModuleNotFoundError):
torch = None
......
......@@ -3,13 +3,16 @@ from functools import wraps
try:
import jax
from jax.tree_util import register_pytree_node
except (ModuleNotFoundError, ImportError):
from .cjax import register_for_jax as _original_register_for_jax
@wraps(_original_register_for_jax)
def register_for_jax(cls):
warnings.warn(f'Jax is not installed, registration of {cls!r} will be ignored.')
warnings.warn(f'Jax doesn\'t have tree_util module due to either not installed '
f'or the installed version is too low, '
f'so the registration of {cls!r} will be ignored.')
else:
from .cjax import register_for_jax
from ..tree import TreeValue
......
......@@ -3,13 +3,16 @@ from functools import wraps
try:
import torch
from torch.utils._pytree import _register_pytree_node
except (ModuleNotFoundError, ImportError):
from .ctorch import register_for_torch as _original_register_for_torch
@wraps(_original_register_for_torch)
def register_for_torch(cls):
warnings.warn(f'Torch is not installed, registration of {cls!r} will be ignored.')
warnings.warn(f'Pytree module is not included in the Torch installation '
f'or the installed version is too low, '
f'so the registration of {cls!r} will be ignored.')
else:
from .ctorch import register_for_torch
from ..tree import TreeValue
......