diff --git a/test/compare/deepmind/test_dm_tree.py b/test/compare/deepmind/test_dm_tree.py index dd76fb2aea125e00274699c8e936dec1d21a4a71..d9686d13c94e012e0175c0e2e08658c7da681d78 100644 --- a/test/compare/deepmind/test_dm_tree.py +++ b/test/compare/deepmind/test_dm_tree.py @@ -9,23 +9,40 @@ _TREE_1 = FastTreeValue(_TREE_DATA_1) @pytest.mark.benchmark(group='dm-tree') class TestCompareWithDMTree: - def test_dm_flatten_with_path(self, benchmark): - benchmark(tree.flatten_with_path, _TREE_DATA_1) + N = 5 - def test_dm_flatten_without_path(self, benchmark): - benchmark(tree.flatten, _TREE_DATA_1) + def __create_nested_tree_data(self, n): + return { + ('no_%04d' % (i + 1,)): _TREE_DATA_1 for i in range(n) + } - def test_tv_flatten(self, benchmark): - benchmark(flatten, _TREE_1) + def __create_nested_tree(self, n): + return FastTreeValue(self.__create_nested_tree_data(n)) - def test_dm_map_structure(self, benchmark): - benchmark(tree.map_structure, lambda x: x ** 2, _TREE_DATA_1) + @pytest.mark.parametrize('n', [2 ** i for i in range(N)]) + def test_dm_flatten_with_path(self, benchmark, n): + benchmark(tree.flatten_with_path, self.__create_nested_tree_data(n)) - def test_tv_mapping(self, benchmark): - benchmark(mapping, _TREE_1, lambda x: x ** 2) + @pytest.mark.parametrize('n', [2 ** i for i in range(N)]) + def test_dm_flatten_without_path(self, benchmark, n): + benchmark(tree.flatten, self.__create_nested_tree_data(n)) - def test_dm_map_structure_with_path(self, benchmark): - benchmark(tree.map_structure_with_path, lambda p, x: x * len(p), _TREE_DATA_1) + @pytest.mark.parametrize('n', [2 ** i for i in range(N)]) + def test_tv_flatten(self, benchmark, n): + benchmark(flatten, self.__create_nested_tree(n)) - def test_tv_mapping_with_path(self, benchmark): - benchmark(mapping, _TREE_1, lambda x, p: x * len(p)) + @pytest.mark.parametrize('n', [2 ** i for i in range(N)]) + def test_dm_map_structure(self, benchmark, n): + benchmark(tree.map_structure, lambda x: x ** 2, self.__create_nested_tree_data(n)) + + @pytest.mark.parametrize('n', [2 ** i for i in range(N)]) + def test_tv_mapping(self, benchmark, n): + benchmark(mapping, self.__create_nested_tree(n), lambda x: x ** 2) + + @pytest.mark.parametrize('n', [2 ** i for i in range(N)]) + def test_dm_map_structure_with_path(self, benchmark, n): + benchmark(tree.map_structure_with_path, lambda p, x: x * len(p), self.__create_nested_tree_data(n)) + + @pytest.mark.parametrize('n', [2 ** i for i in range(N)]) + def test_tv_mapping_with_path(self, benchmark, n): + benchmark(mapping, self.__create_nested_tree(n), lambda x, p: x * len(p)) diff --git a/test/compare/jax/test_jax.py b/test/compare/jax/test_jax.py index f735457bcfba108dd4e814543b5900ed52dd388a..6f1729935714d86147beb1fe770aaf537c8ab57d 100644 --- a/test/compare/jax/test_jax.py +++ b/test/compare/jax/test_jax.py @@ -11,20 +11,35 @@ _TREE_1 = FastTreeValue(_TREE_DATA_1) @pytest.mark.benchmark(group='jax-pytree') class TestCompareWithJaxPytree: - def test_jax_tree_map(self, benchmark): - benchmark(pytree.tree_map, lambda x: x ** 2, _TREE_DATA_1) + N = 5 - def test_tv_mapping(self, benchmark): - benchmark(mapping, _TREE_1, lambda x: x ** 2) + def __create_nested_tree_data(self, n): + return { + ('no_%04d' % (i + 1,)): _TREE_DATA_1 for i in range(n) + } + + def __create_nested_tree(self, n): + return FastTreeValue(self.__create_nested_tree_data(n)) + + @pytest.mark.parametrize('n', [2 ** i for i in range(N)]) + def test_jax_tree_map(self, benchmark, n): + benchmark(pytree.tree_map, lambda x: x ** 2, self.__create_nested_tree_data(n)) + + @pytest.mark.parametrize('n', [2 ** i for i in range(N)]) + def test_tv_mapping(self, benchmark, n): + benchmark(mapping, self.__create_nested_tree(n), lambda x: x ** 2) - def test_tv_mapping_with_path(self, benchmark): - benchmark(mapping, _TREE_1, lambda x, p: x ** 2) + @pytest.mark.parametrize('n', [2 ** i for i in range(N)]) + def test_tv_mapping_with_path(self, benchmark, n): + benchmark(mapping, self.__create_nested_tree(n), lambda x, p: x ** 2) - def test_jax_tree_flatten(self, benchmark): - benchmark(pytree.tree_flatten, _TREE_DATA_1) + @pytest.mark.parametrize('n', [2 ** i for i in range(N)]) + def test_jax_tree_flatten(self, benchmark, n): + benchmark(pytree.tree_flatten, self.__create_nested_tree_data(n)) - def test_tv_flatten(self, benchmark): - benchmark(flatten, _TREE_1) + @pytest.mark.parametrize('n', [2 ** i for i in range(N)]) + def test_tv_flatten(self, benchmark, n): + benchmark(flatten, self.__create_nested_tree(n)) def test_jax_tree_unflatten(self, benchmark): leaves, treedef = pytree.tree_flatten(_TREE_DATA_1)