未验证 提交 8a78ab28 编写于 作者: HansBug's avatar HansBug 😆 提交者: GitHub

Merge pull request #20 from opendilab/dev/compare

dev(hansbug): add comparison benchmark to google jax
......@@ -148,14 +148,29 @@ For more quick start explanation and further usage, take a look at:
## Speed Performance
Here is the speed performance of all the operations in `FastTreeValue`
Here is the speed performance of all the operations in `FastTreeValue`, the following table is the performance comparison result with [dm-tree](https://github.com/deepmind/tree).
| | flatten | flatten(with path) | map | map(with path) |
| --------------------------------------------------- | :--------------: | :------------------: | :-------------------: | :-----------------: |
| [treevalue](https://github.com/opendilab/treevalue) | --- | **511 ns ± 6.92 ns** | **3.16 µs ± 42.8 ns** | **1.58 µs ± 30 ns** |
| [dm-tree](https://github.com/deepmind/tree) | 830 ns ± 8.53 ns | 11.9 µs ± 358 ns | 13.3 µs ± 87.2 ns | 62.9 µs ± 2.26 µs |
| | flatten | flatten(with path) | mapping | mapping(with path) |
| --------------------------------------------------- | :--------------: | :-------------------: | :-------------------: | :-------------------------: |
| [treevalue](https://github.com/opendilab/treevalue) | --- | **511 ns ± 6.92 ns** | **3.16 µs ± 42.8 ns** | **1.58 µs ± 30 ns** |
| | **flatten** | **flatten_with_path** | **map_structure** | **map_structure_with_path** |
| [dm-tree](https://github.com/deepmind/tree) | 830 ns ± 8.53 ns | 11.9 µs ± 358 ns | 13.3 µs ± 87.2 ns | 62.9 µs ± 2.26 µs |
The following 2 tables are the performance comparison result with [jax pytree](https://github.com/google/jax).
| | mapping | mapping(with path) | flatten | unflatten | flatten_values | flatten_keys |
| --------------------------------------------------- | :------------------: | :-------------------: | :------------------: | :------------------: | :------------------: | :------------------: |
| [treevalue](https://github.com/opendilab/treevalue) | **3.2 µs ± 85.5 ns** | **2.16 µs ± 123 ns** | **515 ns ± 7.53 ns** | **601 ns ± 5.99 ns** | **301 ns ± 12.9 ns** | **451 ns ± 17.3 ns** |
| | **tree_map** | **(Not Implemented)** | **tree_flatten** | **tree_unflatten** | **tree_leaves** | **tree_structure** |
| [jax pytree](https://github.com/google/jax) | 4.67 µs ± 184 ns | --- | 1.29 µs ± 27.2 ns | 742 ns ± 5.82 ns | 1.29 µs ± 22 ns | 1.27 µs ± 16.5 ns |
| | flatten + all | flatten + reduce | flatten + reduce(with init) | rise(given structure) | rise(automatic structure) |
| --------------------------------------------------- | :------------------: | :------------------: | :-------------------------: | :-------------------: | :-----------------------: |
| [treevalue](https://github.com/opendilab/treevalue) | **425 ns ± 9.33 ns** | **702 ns ± 5.93 ns** | **793 ns ± 13.4 ns** | **9.14 µs ± 129 ns** | **11.5 µs ± 182 ns** |
| | **tree_all** | **tree_reduce** | **tree_reduce(with init)** | **tree_transpose** | **(Not Implemented)** |
| [jax pytree](https://github.com/google/jax) | 1.47 µs ± 37 ns | 1.88 µs ± 27.2 ns | 1.91 µs ± 47.4 ns | 10 µs ± 117 ns | --- |
The following table is the performance comparison result with [tianshou Batch](https://github.com/thu-ml/tianshou).
| | get | set | init | deepcopy | stack | cat | split |
| ---------------------------------------------------- | :--------------------: | :--------------------: | :------------------: | :------------------: | :------------------: | :-------------------: | :----------------: |
......
......@@ -15,3 +15,4 @@ easydict>=1.7,<2
testtools>=2
dm-tree>=0.1.6
tianshou>=0.4.5
jax[cpu]>=0.2.17
\ No newline at end of file
from functools import reduce
import jax.tree_util as pytree
import pytest
from treevalue import FastTreeValue, mapping, flatten, unflatten, flatten_values, flatten_keys, subside, rise, raw
_TREE_DATA_1 = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}
_TREE_1 = FastTreeValue(_TREE_DATA_1)
@pytest.mark.benchmark(group='jax-pytree')
class TestCompareWithJaxPytree:
def test_jax_tree_map(self, benchmark):
benchmark(pytree.tree_map, lambda x: x ** 2, _TREE_DATA_1)
def test_tv_mapping(self, benchmark):
benchmark(mapping, _TREE_1, lambda x: x ** 2)
def test_tv_mapping_with_path(self, benchmark):
benchmark(mapping, _TREE_1, lambda x, p: x ** 2)
def test_jax_tree_flatten(self, benchmark):
benchmark(pytree.tree_flatten, _TREE_DATA_1)
def test_tv_flatten(self, benchmark):
benchmark(flatten, _TREE_1)
def test_jax_tree_unflatten(self, benchmark):
leaves, treedef = pytree.tree_flatten(_TREE_DATA_1)
benchmark(pytree.tree_unflatten, treedef, leaves)
def test_tv_unflatten(self, benchmark):
flatted = flatten(_TREE_1)
benchmark(unflatten, flatted)
def test_jax_tree_all(self, benchmark):
benchmark(pytree.tree_all, _TREE_DATA_1)
def test_tv_flatten_all(self, benchmark):
def _flatten_all(tree):
return all(flatten_values(tree))
benchmark(_flatten_all, _TREE_1)
def test_jax_tree_leaves(self, benchmark):
benchmark(pytree.tree_leaves, _TREE_DATA_1)
def test_tv_flatten_values(self, benchmark):
benchmark(flatten_values, _TREE_1)
def test_jax_tree_reduce(self, benchmark):
benchmark(pytree.tree_reduce, lambda x, y: x + y, _TREE_DATA_1)
def test_jax_tree_reduce_with_init(self, benchmark):
benchmark(pytree.tree_reduce, lambda x, y: x + y, _TREE_DATA_1, 0)
def test_tv_flatten_reduce(self, benchmark):
def _flatten_reduce(tree):
values = flatten_values(tree)
return reduce(lambda x, y: x + y, values)
return benchmark(_flatten_reduce, _TREE_1)
def test_tv_flatten_reduce_with_init(self, benchmark):
def _flatten_reduce_with_init(tree):
values = flatten_values(tree)
return reduce(lambda x, y: x + y, values, 0)
return benchmark(_flatten_reduce_with_init, _TREE_1)
def test_jax_tree_structure(self, benchmark):
benchmark(pytree.tree_structure, _TREE_DATA_1)
def test_tv_flatten_keys(self, benchmark):
benchmark(flatten_keys, _TREE_1)
def test_jax_tree_transpose(self, benchmark):
sto = pytree.tree_structure({'a': 1, 'b': 2, 'c': {'x': 3, 'y': 4}})
sti = pytree.tree_structure({'a': 1, 'b': {'x': 2, 'y': 3}})
value = (
{'a': 1, 'b': {'x': 2, 'y': 3}},
{
'a': {'a': 10, 'b': {'x': 20, 'y': 30}},
'b': [
{'a': 100, 'b': {'x': 200, 'y': 300}},
{'a': 400, 'b': {'x': 500, 'y': 600}},
],
}
)
benchmark(pytree.tree_transpose, sto, sti, value)
def test_tv_subside(self, benchmark):
value = {
'a': FastTreeValue({'a': 1, 'b': {'x': 2, 'y': 3}}),
'b': FastTreeValue({'a': 10, 'b': {'x': 20, 'y': 30}}),
'c': {
'x': FastTreeValue({'a': 100, 'b': {'x': 200, 'y': 300}}),
'y': FastTreeValue({'a': 400, 'b': {'x': 500, 'y': 600}}),
},
}
benchmark(subside, value)
def test_tv_rise(self, benchmark):
value = FastTreeValue({
'a': raw({'a': 1, 'b': {'x': 2, 'y': 3}}),
'b': raw({'a': 10, 'b': {'x': 20, 'y': 30}}),
'c': {
'x': raw({'a': 100, 'b': {'x': 200, 'y': 300}}),
'y': raw({'a': 400, 'b': {'x': 500, 'y': 600}}),
},
})
benchmark(rise, value)
def test_tv_rise_with_template(self, benchmark):
value = FastTreeValue({
'a': raw({'a': 1, 'b': {'x': 2, 'y': 3}}),
'b': raw({'a': 10, 'b': {'x': 20, 'y': 30}}),
'c': {
'x': raw({'a': 100, 'b': {'x': 200, 'y': 300}}),
'y': raw({'a': 400, 'b': {'x': 500, 'y': 600}}),
},
})
benchmark(rise, value, template={'a': None, 'b': {'x': None, 'y': None}})
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册