提交 83aff31b 编写于 作者: HansBug's avatar HansBug 😆

test(hansbug): add compare with facebook nest library

上级 6f2f5b2d
......@@ -42,6 +42,7 @@ jobs:
pip install -r requirements.txt
pip install -r requirements-build.txt
pip install -r requirements-test.txt
./install_test.sh
- name: Test the basic environment
run: |
python -V
......@@ -103,6 +104,7 @@ jobs:
pip install -r requirements.txt
pip install -r requirements-build.txt
pip install -r requirements-test.txt
./install_test.sh
- name: Test the basic environment
run: |
python -V
......
......@@ -1197,6 +1197,7 @@ fabric.properties
.python-version
/docs/build
/public
/.installs
/docs/source/**/*.puml.svg
/docs/source/**/*.puml.png
/docs/source/**/*.gv.svg
......@@ -1208,4 +1209,4 @@ fabric.properties
/docs/source/**/*.sh.err
/docs/source/**/*.sh.exitcode
/docs/source/**/*.dat.*
!/docs/source/_static/**/*
\ No newline at end of file
!/docs/source/_static/**/*
mkdir -p .installs
git clone --depth=1 https://github.com/facebookresearch/torchbeast.git .installs/torchbeast
cd .installs/torchbeast/nest
CXX=c++ pip install . -vv
cd ../../..
try:
import nest
except ImportError:
nest = None
import pytest
from treevalue import FastTreeValue, flatten, mapping, func_treelize, unflatten, union
_TREE_DATA_1 = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}
_TREE_1 = FastTreeValue(_TREE_DATA_1)
_UMARK = pytest.mark.benchmark(group='facebook-nest') if nest is not None else pytest.mark.ignore
@_UMARK
class TestCompareFacebookNest:
def test_nest_flatten(self, benchmark):
benchmark(nest.flatten, _TREE_DATA_1)
def test_tv_flatten(self, benchmark):
benchmark(flatten, _TREE_1)
def test_nest_pack_as(self, benchmark):
benchmark(nest.pack_as, _TREE_DATA_1, nest.flatten(_TREE_DATA_1))
def test_tv_unflatten(self, benchmark):
benchmark(unflatten, flatten(_TREE_1))
def test_nest_map(self, benchmark):
benchmark(nest.map, lambda x: x ** 2, _TREE_DATA_1)
def test_tv_map(self, benchmark):
benchmark(mapping, _TREE_1, lambda x: x ** 2)
def test_nest_map_many2(self, benchmark):
def f(a, b):
return a ** b + a * b
benchmark(nest.map_many2, f, _TREE_DATA_1, _TREE_DATA_1)
def test_nest_map_many(self, benchmark):
def f(a):
return a[0] ** a[1] + a[0] * a[1]
benchmark(nest.map_many, f, _TREE_DATA_1, _TREE_DATA_1)
def test_tv_treelize_call(self, benchmark):
@func_treelize()
def f(a, b):
return a ** b + a * b
benchmark(f, _TREE_1, _TREE_1)
def test_tv_mapping_union(self, benchmark):
def f(a):
return a[0] ** a[1]
def _my_func(fx, *v):
return mapping(union(*v), fx)
benchmark(_my_func, f, _TREE_1, _TREE_1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册