未验证 提交 38871e4d 编写于 作者: HansBug's avatar HansBug 😆 提交者: GitHub

Merge pull request #24 from opendilab/test/benchmark

doc(hansbug): add new benchmark and graphs of speed performance 
......@@ -158,11 +158,11 @@ Here is the speed performance of all the operations in `FastTreeValue`, the foll
The following 2 tables are the performance comparison result with [jax pytree](https://github.com/google/jax).
| | mapping | mapping(with path) | flatten | unflatten | flatten_values | flatten_keys |
| --------------------------------------------------- | :------------------: | :-------------------: | :------------------: | :------------------: | :------------------: | :------------------: |
| [treevalue](https://github.com/opendilab/treevalue) | **3.2 µs ± 85.5 ns** | **2.16 µs ± 123 ns** | **515 ns ± 7.53 ns** | **601 ns ± 5.99 ns** | **301 ns ± 12.9 ns** | **451 ns ± 17.3 ns** |
| | **tree_map** | **(Not Implemented)** | **tree_flatten** | **tree_unflatten** | **tree_leaves** | **tree_structure** |
| [jax pytree](https://github.com/google/jax) | 4.67 µs ± 184 ns | --- | 1.29 µs ± 27.2 ns | 742 ns ± 5.82 ns | 1.29 µs ± 22 ns | 1.27 µs ± 16.5 ns |
| | mapping | mapping(with path) | flatten | unflatten | flatten_values | flatten_keys |
| --------------------------------------------------- | :-------------------: | :-------------------: | :------------------: | :------------------: | :------------------: | :------------------: |
| [treevalue](https://github.com/opendilab/treevalue) | **2.21 µs ± 32.2 ns** | **2.16 µs ± 123 ns** | **515 ns ± 7.53 ns** | **601 ns ± 5.99 ns** | **301 ns ± 12.9 ns** | **451 ns ± 17.3 ns** |
| | **tree_map** | **(Not Implemented)** | **tree_flatten** | **tree_unflatten** | **tree_leaves** | **tree_structure** |
| [jax pytree](https://github.com/google/jax) | 4.67 µs ± 184 ns | --- | 1.29 µs ± 27.2 ns | 742 ns ± 5.82 ns | 1.29 µs ± 22 ns | 1.27 µs ± 16.5 ns |
| | flatten + all | flatten + reduce | flatten + reduce(with init) | rise(given structure) | rise(automatic structure) |
| --------------------------------------------------- | :------------------: | :------------------: | :-------------------------: | :-------------------: | :-----------------------: |
......@@ -170,6 +170,12 @@ The following 2 tables are the performance comparison result with [jax pytree](h
| | **tree_all** | **tree_reduce** | **tree_reduce(with init)** | **tree_transpose** | **(Not Implemented)** |
| [jax pytree](https://github.com/google/jax) | 1.47 µs ± 37 ns | 1.88 µs ± 27.2 ns | 1.91 µs ± 47.4 ns | 10 µs ± 117 ns | --- |
This is the comparison between dm-tree, jax-libtree and us, with `flatten` and `mapping` operations (**lower value means less time cost and runs faster**)
![Time cost of flatten operation](docs/source/_static/Time%20cost%20of%20flatten%20operation.svg)
![Time cost of mapping operation](docs/source/_static/Time%20cost%20of%20mapping%20operation.svg)
The following table is the performance comparison result with [tianshou Batch](https://github.com/thu-ml/tianshou).
| | get | set | init | deepcopy | stack | cat | split |
......@@ -177,18 +183,19 @@ The following table is the performance comparison result with [tianshou Batch](h
| [treevalue](https://github.com/opendilab/treevalue) | 51.6 ns ± 0.609 ns | **64.4 ns ± 0.564 ns** | **750 ns ± 14.2 ns** | **88.9 µs ± 887 ns** | **50.2 µs ± 771 ns** | **40.3 µs ± 1.08 µs** | **62 µs ± 1.2 µs** |
| [tianshou Batch](https://github.com/thu-ml/tianshou) | **43.2 ns ± 0.698 ns** | 396 ns ± 8.99 ns | 11.1 µs ± 277 ns | 89 µs ± 1.42 µs | 119 µs ± 1.1 µs | 194 µs ± 1.81 µs | 653 µs ± 17.8 µs |
And this is the comparasion between tianshou Batch and us, with `cat` , `stack` and `split` operations
And this is the comparison between tianshou Batch and us, with `cat` , `stack` and `split` operations (**lower value means less time cost and runs faster**)
![Time cost of cat operation](docs/source/_static/Time%20cost%20of%20cat%20operation.png)
![Time cost of cat operation](docs/source/_static/Time%20cost%20of%20cat%20operation.svg)
![Time cost of stack operation](docs/source/_static/Time%20cost%20of%20stack%20operation.png)
![Time cost of stack operation](docs/source/_static/Time%20cost%20of%20stack%20operation.svg)
![Time cost of split operation](docs/source/_static/Time%20cost%20of%20split%20operation.png)
![Time cost of split operation](docs/source/_static/Time%20cost%20of%20split%20operation.svg)
Test benchmark code can be found here:
* [Comparasion with dm-tree](https://github.com/opendilab/treevalue/blob/main/test/compare/test_dm_tree.py)
* [Comparasion with tianshou Batch](https://github.com/opendilab/treevalue/blob/main/test/compare/test_tianshou_batch.py)
* [Comparison with dm-tree](https://github.com/opendilab/treevalue/blob/main/test/compare/deepmind/test_dm_tree.py)
* [Comparison with jax-libtree](https://github.com/opendilab/treevalue/blob/main/test/compare/jax/test_jax.py)
* [Comparison with tianshou Batch](https://github.com/opendilab/treevalue/blob/main/test/compare/tianshou/test_tianshou_batch.py)
## Contribution
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -9,23 +9,40 @@ _TREE_1 = FastTreeValue(_TREE_DATA_1)
@pytest.mark.benchmark(group='dm-tree')
class TestCompareWithDMTree:
def test_dm_flatten_with_path(self, benchmark):
benchmark(tree.flatten_with_path, _TREE_DATA_1)
N = 5
def test_dm_flatten_without_path(self, benchmark):
benchmark(tree.flatten, _TREE_DATA_1)
def __create_nested_tree_data(self, n):
return {
('no_%04d' % (i + 1,)): _TREE_DATA_1 for i in range(n)
}
def test_tv_flatten(self, benchmark):
benchmark(flatten, _TREE_1)
def __create_nested_tree(self, n):
return FastTreeValue(self.__create_nested_tree_data(n))
def test_dm_map_structure(self, benchmark):
benchmark(tree.map_structure, lambda x: x ** 2, _TREE_DATA_1)
@pytest.mark.parametrize('n', [2 ** i for i in range(N)])
def test_dm_flatten_with_path(self, benchmark, n):
benchmark(tree.flatten_with_path, self.__create_nested_tree_data(n))
def test_tv_mapping(self, benchmark):
benchmark(mapping, _TREE_1, lambda x: x ** 2)
@pytest.mark.parametrize('n', [2 ** i for i in range(N)])
def test_dm_flatten_without_path(self, benchmark, n):
benchmark(tree.flatten, self.__create_nested_tree_data(n))
def test_dm_map_structure_with_path(self, benchmark):
benchmark(tree.map_structure_with_path, lambda p, x: x * len(p), _TREE_DATA_1)
@pytest.mark.parametrize('n', [2 ** i for i in range(N)])
def test_tv_flatten(self, benchmark, n):
benchmark(flatten, self.__create_nested_tree(n))
def test_tv_mapping_with_path(self, benchmark):
benchmark(mapping, _TREE_1, lambda x, p: x * len(p))
@pytest.mark.parametrize('n', [2 ** i for i in range(N)])
def test_dm_map_structure(self, benchmark, n):
benchmark(tree.map_structure, lambda x: x ** 2, self.__create_nested_tree_data(n))
@pytest.mark.parametrize('n', [2 ** i for i in range(N)])
def test_tv_mapping(self, benchmark, n):
benchmark(mapping, self.__create_nested_tree(n), lambda x: x ** 2)
@pytest.mark.parametrize('n', [2 ** i for i in range(N)])
def test_dm_map_structure_with_path(self, benchmark, n):
benchmark(tree.map_structure_with_path, lambda p, x: x * len(p), self.__create_nested_tree_data(n))
@pytest.mark.parametrize('n', [2 ** i for i in range(N)])
def test_tv_mapping_with_path(self, benchmark, n):
benchmark(mapping, self.__create_nested_tree(n), lambda x, p: x * len(p))
......@@ -11,20 +11,35 @@ _TREE_1 = FastTreeValue(_TREE_DATA_1)
@pytest.mark.benchmark(group='jax-pytree')
class TestCompareWithJaxPytree:
def test_jax_tree_map(self, benchmark):
benchmark(pytree.tree_map, lambda x: x ** 2, _TREE_DATA_1)
N = 5
def test_tv_mapping(self, benchmark):
benchmark(mapping, _TREE_1, lambda x: x ** 2)
def __create_nested_tree_data(self, n):
return {
('no_%04d' % (i + 1,)): _TREE_DATA_1 for i in range(n)
}
def __create_nested_tree(self, n):
return FastTreeValue(self.__create_nested_tree_data(n))
@pytest.mark.parametrize('n', [2 ** i for i in range(N)])
def test_jax_tree_map(self, benchmark, n):
benchmark(pytree.tree_map, lambda x: x ** 2, self.__create_nested_tree_data(n))
@pytest.mark.parametrize('n', [2 ** i for i in range(N)])
def test_tv_mapping(self, benchmark, n):
benchmark(mapping, self.__create_nested_tree(n), lambda x: x ** 2)
def test_tv_mapping_with_path(self, benchmark):
benchmark(mapping, _TREE_1, lambda x, p: x ** 2)
@pytest.mark.parametrize('n', [2 ** i for i in range(N)])
def test_tv_mapping_with_path(self, benchmark, n):
benchmark(mapping, self.__create_nested_tree(n), lambda x, p: x ** 2)
def test_jax_tree_flatten(self, benchmark):
benchmark(pytree.tree_flatten, _TREE_DATA_1)
@pytest.mark.parametrize('n', [2 ** i for i in range(N)])
def test_jax_tree_flatten(self, benchmark, n):
benchmark(pytree.tree_flatten, self.__create_nested_tree_data(n))
def test_tv_flatten(self, benchmark):
benchmark(flatten, _TREE_1)
@pytest.mark.parametrize('n', [2 ** i for i in range(N)])
def test_tv_flatten(self, benchmark, n):
benchmark(flatten, self.__create_nested_tree(n))
def test_jax_tree_unflatten(self, benchmark):
leaves, treedef = pytree.tree_flatten(_TREE_DATA_1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册