...
 
Commits (15)
    https://gitcode.net/opendilab/treevalue/-/commit/fc627ee47184188d7159ac96712bb5502537cbe1 dev(hansbug): try adapt torch2 2023-04-16T21:42:20+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/1fef8687d557f23d599a0e85bfc0c198f0ad623d dev(hansbug): disable torch.compile on windows, install libomp for macos 2023-04-16T21:54:00+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/63ded6e8831f060e2fe41406cb653023785cf6fd dev(hansbug): linux only for torch.compile 2023-04-16T22:22:03+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/4ef9bad58252ab39d9b20f9bb25b3cda966fa869 dev(narugo): run unittest of treevalue 2023-08-10T10:50:32+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/c7920bde1b814d9ed9ff9c4df59513ad47380caa dev(narugo): upload for fixing platform_system 2023-08-10T11:15:31+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/88705f32b4717078271d3bdc1a2b19c882027e40 Merge branch 'main' into dev/torch2 2023-08-10T16:20:42+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/e0c78e298698e5881781aa12f1879f4287d4237c dev(narugo): torch.compile not support py3.11 yet 2023-08-10T16:58:41+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/6b01608829a4d0310dab6909a8eb282f6f85bc5d dev(narugo): try fix build bug 2023-08-10T19:09:29+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/cd6c5f10979ab3d72d9776d8579e04e350c613ca dev(narugo): try fix cython 2023-08-10T21:08:51+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/8ee58a75683a04ed08b2b1c495d67dbe27092283 dev(narugo): use cython 0.29 only for windows 2023-08-10T21:43:37+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/7dd49832d9a7a8480ef176beeb810206541691b2 dev(narugo): add buggy test 2023-08-11T11:11:00+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/13b20480076f112de9300aefb6458dfc4e8c685b Merge pull request #85 from opendilab/dev/torch2 2023-08-11T16:23:14+08:00 Hankson Bradley hansbug@buaa.edu.cn dev(hansbug): try adapt torch2 https://gitcode.net/opendilab/treevalue/-/commit/44a89b700cd7654ff4fe085b7a92c5218ff0dc61 dev(narugo): torch compile's test 2023-08-12T13:38:35+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/d415944ffb6a4a6d14486ebebc466ce0db4842c4 Merge pull request #87 from opendilab/test/torch2 2023-08-14T12:13:09+08:00 Hankson Bradley hansbug@buaa.edu.cn dev(narugo): torch compile's test https://gitcode.net/opendilab/treevalue/-/commit/bd171e47eddd5b95747dbb64747e18068f6b3274 dev(hansbug): use version v1.4.12 2023-08-14T12:14:21+08:00 HansBug hansbug@buaa.edu.cn
......@@ -7,6 +7,7 @@ on:
- dev/*
- fix/*
- test/*
workflow_dispatch:
jobs:
unittest:
......
......@@ -2,7 +2,8 @@
requires = [
"setuptools>=42",
"wheel",
"Cython",
"cython>=0.29; platform_system != 'Windows'",
"cython>=0.29,<3; platform_system == 'Windows'",
]
[tool.cibuildwheel]
......
cython>=0.29
cython>=0.29; platform_system != 'Windows'
cython>=0.29,<3; platform_system == 'Windows'
build>=0.7.0
auditwheel>=4
\ No newline at end of file
enum_tools
graphviz~=0.17
dill~=0.3.4
graphviz>=0.17
dill>=0.3.4
click>=7.1.0
hbutils>=0.9.1
\ No newline at end of file
from typing import Tuple, Mapping
from unittest import skipUnless
import pytest
from hbutils.testing import vpip, OS, vpython
from treevalue import FastTreeValue, register_for_torch
......@@ -60,3 +62,102 @@ class TestTreeIntegrationTorch:
with pytest.warns(UserWarning):
register_for_torch(MyTreeValueX)
@skipUnless(vpip('torch') >= '2.0.0' and OS.linux and vpython < '3.11', 'Torch 2 on linux platform required')
def test_torch_compile(self):
@torch.compile
def foo(x, y, t):
z = (x + y * 2000) / (t - 100)
return z
a = torch.randn(3, 4)
b = torch.randn(3, 4)
c = torch.randn(3, 4)
assert torch.isclose(foo(a, b, c), (a + b * 2000) / (c - 100)).all()
@skipUnless(vpip('torch') >= '2.0.0' and OS.linux and vpython < '3.11', 'Torch 2 on linux platform required')
def test_torch_compile_buggy(self):
@torch.compile
def foox(x, y):
z = x + y
return z
x = FastTreeValue({
'a': torch.randn(3, 4) + 200,
'b': torch.randn(5) - 300,
})
y = FastTreeValue({
'a': torch.rand(4) + 500,
'b': torch.randn(4, 5) + 1000,
})
_t_isclose = FastTreeValue.func()(torch.isclose)
assert _t_isclose(foox(x, y), x + y).all() == \
FastTreeValue({'a': torch.tensor(True), 'b': torch.tensor(True)})
@skipUnless(vpip('torch') >= '2.0.0' and OS.linux and vpython < '3.11', 'Torch 2 on linux platform required')
def test_with_module(self):
from torch import nn
class MLP(nn.Module):
def __init__(self, in_features: int, out_features: int, layers: Tuple[int, ...] = (1024,)):
nn.Module.__init__(self)
self.in_features = in_features
self.out_features = out_features
self.layers = layers
ios = [self.in_features, *self.layers, self.out_features]
self.mlp = nn.Sequential(
*(
nn.Linear(in_, out_, bias=True)
for in_, out_ in zip(ios[:-1], ios[1:])
)
)
def forward(self, x):
return self.mlp(x)
class MultiHeadMLP(nn.Module):
def __init__(self, in_features: int, out_features: Mapping[str, int], layers: Tuple[int, ...] = (1024,)):
nn.Module.__init__(self)
self.in_features = in_features
self.out_features = out_features
self.layers = layers
_networks = {
o_name: MLP(in_features, o_feat, layers)
for o_name, o_feat in self.out_features.items()
}
self.mlps = nn.ModuleDict(_networks)
self._t_mlps = FastTreeValue(_networks)
def forward(self, x):
return self._t_mlps(x)
net = MultiHeadMLP(
20,
{'a': 10, 'b': 20, 'c': 14, 'd': 3},
)
net = torch.compile(net)
input1 = torch.randn(3, 20)
output1 = net(input1)
assert output1.shape == FastTreeValue({
'a': torch.Size([3, 10]),
'b': torch.Size([3, 20]),
'c': torch.Size([3, 14]),
'd': torch.Size([3, 3]),
})
input2 = FastTreeValue.func()(torch.randn)(FastTreeValue({
'a': (3, 20),
'b': (4, 20),
'c': (20,),
'd': (2, 5, 20),
}))
output2 = net(input2)
assert output2.shape == FastTreeValue({
'a': torch.Size([3, 10]),
'b': torch.Size([4, 20]),
'c': torch.Size([14]),
'd': torch.Size([2, 5, 3]),
})
......@@ -7,7 +7,7 @@ Overview:
__TITLE__ = "treevalue"
#: Version of this project.
__VERSION__ = "1.4.11"
__VERSION__ = "1.4.12"
#: Short description of the project, will be included in ``setup.py``.
__DESCRIPTION__ = 'A flexible, generalized tree-based data structure.'
......
此差异已折叠。