提交 04cc5e72 编写于 作者: HansBug's avatar HansBug 😆

test(hansbug): add benchmark for speed test

上级 0f1096c7
...@@ -9,23 +9,40 @@ _TREE_1 = FastTreeValue(_TREE_DATA_1) ...@@ -9,23 +9,40 @@ _TREE_1 = FastTreeValue(_TREE_DATA_1)
@pytest.mark.benchmark(group='dm-tree') @pytest.mark.benchmark(group='dm-tree')
class TestCompareWithDMTree: class TestCompareWithDMTree:
def test_dm_flatten_with_path(self, benchmark): N = 5
benchmark(tree.flatten_with_path, _TREE_DATA_1)
def test_dm_flatten_without_path(self, benchmark): def __create_nested_tree_data(self, n):
benchmark(tree.flatten, _TREE_DATA_1) return {
('no_%04d' % (i + 1,)): _TREE_DATA_1 for i in range(n)
}
def test_tv_flatten(self, benchmark): def __create_nested_tree(self, n):
benchmark(flatten, _TREE_1) return FastTreeValue(self.__create_nested_tree_data(n))
def test_dm_map_structure(self, benchmark): @pytest.mark.parametrize('n', [2 ** i for i in range(N)])
benchmark(tree.map_structure, lambda x: x ** 2, _TREE_DATA_1) def test_dm_flatten_with_path(self, benchmark, n):
benchmark(tree.flatten_with_path, self.__create_nested_tree_data(n))
def test_tv_mapping(self, benchmark): @pytest.mark.parametrize('n', [2 ** i for i in range(N)])
benchmark(mapping, _TREE_1, lambda x: x ** 2) def test_dm_flatten_without_path(self, benchmark, n):
benchmark(tree.flatten, self.__create_nested_tree_data(n))
def test_dm_map_structure_with_path(self, benchmark): @pytest.mark.parametrize('n', [2 ** i for i in range(N)])
benchmark(tree.map_structure_with_path, lambda p, x: x * len(p), _TREE_DATA_1) def test_tv_flatten(self, benchmark, n):
benchmark(flatten, self.__create_nested_tree(n))
def test_tv_mapping_with_path(self, benchmark): @pytest.mark.parametrize('n', [2 ** i for i in range(N)])
benchmark(mapping, _TREE_1, lambda x, p: x * len(p)) def test_dm_map_structure(self, benchmark, n):
benchmark(tree.map_structure, lambda x: x ** 2, self.__create_nested_tree_data(n))
@pytest.mark.parametrize('n', [2 ** i for i in range(N)])
def test_tv_mapping(self, benchmark, n):
benchmark(mapping, self.__create_nested_tree(n), lambda x: x ** 2)
@pytest.mark.parametrize('n', [2 ** i for i in range(N)])
def test_dm_map_structure_with_path(self, benchmark, n):
benchmark(tree.map_structure_with_path, lambda p, x: x * len(p), self.__create_nested_tree_data(n))
@pytest.mark.parametrize('n', [2 ** i for i in range(N)])
def test_tv_mapping_with_path(self, benchmark, n):
benchmark(mapping, self.__create_nested_tree(n), lambda x, p: x * len(p))
...@@ -11,20 +11,35 @@ _TREE_1 = FastTreeValue(_TREE_DATA_1) ...@@ -11,20 +11,35 @@ _TREE_1 = FastTreeValue(_TREE_DATA_1)
@pytest.mark.benchmark(group='jax-pytree') @pytest.mark.benchmark(group='jax-pytree')
class TestCompareWithJaxPytree: class TestCompareWithJaxPytree:
def test_jax_tree_map(self, benchmark): N = 5
benchmark(pytree.tree_map, lambda x: x ** 2, _TREE_DATA_1)
def test_tv_mapping(self, benchmark): def __create_nested_tree_data(self, n):
benchmark(mapping, _TREE_1, lambda x: x ** 2) return {
('no_%04d' % (i + 1,)): _TREE_DATA_1 for i in range(n)
}
def __create_nested_tree(self, n):
return FastTreeValue(self.__create_nested_tree_data(n))
@pytest.mark.parametrize('n', [2 ** i for i in range(N)])
def test_jax_tree_map(self, benchmark, n):
benchmark(pytree.tree_map, lambda x: x ** 2, self.__create_nested_tree_data(n))
@pytest.mark.parametrize('n', [2 ** i for i in range(N)])
def test_tv_mapping(self, benchmark, n):
benchmark(mapping, self.__create_nested_tree(n), lambda x: x ** 2)
def test_tv_mapping_with_path(self, benchmark): @pytest.mark.parametrize('n', [2 ** i for i in range(N)])
benchmark(mapping, _TREE_1, lambda x, p: x ** 2) def test_tv_mapping_with_path(self, benchmark, n):
benchmark(mapping, self.__create_nested_tree(n), lambda x, p: x ** 2)
def test_jax_tree_flatten(self, benchmark): @pytest.mark.parametrize('n', [2 ** i for i in range(N)])
benchmark(pytree.tree_flatten, _TREE_DATA_1) def test_jax_tree_flatten(self, benchmark, n):
benchmark(pytree.tree_flatten, self.__create_nested_tree_data(n))
def test_tv_flatten(self, benchmark): @pytest.mark.parametrize('n', [2 ** i for i in range(N)])
benchmark(flatten, _TREE_1) def test_tv_flatten(self, benchmark, n):
benchmark(flatten, self.__create_nested_tree(n))
def test_jax_tree_unflatten(self, benchmark): def test_jax_tree_unflatten(self, benchmark):
leaves, treedef = pytree.tree_flatten(_TREE_DATA_1) leaves, treedef = pytree.tree_flatten(_TREE_DATA_1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册