diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp
index 3bc859aac746ee74701165931d60fba56ebeb1a7..03d5b0c70706c60603189d4630bb024787493555 100644
--- a/imperative/python/src/ops.cpp
+++ b/imperative/python/src/ops.cpp
@@ -557,7 +557,14 @@ void init_ops(py::module m) {
     m.def(
             "delete_rng_handle",
             [](size_t handle) {
+                if (mgb::imperative::python::interpreter_for_py->check_available()) {
+                    mgb::imperative::python::interpreter_for_py->sync();
+                }
                 mgb::CompNode::sync_all();
+                mgb::CompNode::foreach ([](mgb::CompNode cn) {
+                    auto err = cn.check_async_error();
+                    mgb_assert(!err, "%s", err->what());
+                });
                 py_task_q.wait_all_task_finish();
                 rng::delete_handle(handle);
             },
diff --git a/imperative/python/test/conftest.py b/imperative/python/test/conftest.py
index 4424d5b3c7297ca15dfb3d5707be7ad9eacf7e6b..465ede6bcf9f808a181df5ef5e2d97a4eba89212 100644
--- a/imperative/python/test/conftest.py
+++ b/imperative/python/test/conftest.py
@@ -11,13 +11,17 @@ import sys
 
 import pytest
 
-import megengine.functional
-import megengine.module
-from megengine import Parameter
-from megengine.core._imperative_rt.core2 import sync
+from megengine.core import _config as config
+from megengine.core import _trace_option as trace_option
+from megengine.core import get_option
+from megengine.core._imperative_rt.core2 import (
+    _get_amp_dtype_autocast,
+    _get_amp_high_prec_dtype,
+    _get_amp_low_prec_dtype,
+    _get_convert_inputs,
+)
+from megengine.core.tensor import amp
 from megengine.device import get_device_count
-from megengine.jit import trace as _trace
-from megengine.module import Linear, Module
 
 sys.path.append(os.path.join(os.path.dirname(__file__), "helpers"))
 
@@ -41,3 +45,58 @@ def skip_distributed(request):
             platform.system()
         )
     )
+
+
+@pytest.fixture(autouse=True)
+def run_around_tests():
+    env_vars1 = {
+        "symbolic_shape": trace_option.use_symbolic_shape(),
+        "async_level": get_option("async_level"),
+        "enable_drop": get_option("enable_drop"),
+        "max_recompute_time": get_option("max_recompute_time"),
+        "catch_worker_execption": get_option("catch_worker_execption"),
+        "enable_host_compute": get_option("enable_host_compute"),
+        # "record_computing_path": get_option("record_computing_path"),
+        "disable_memory_forwarding": get_option("disable_memory_forwarding"),
+        "enable_dtr_auto_drop": get_option("enable_dtr_auto_drop"),
+        "enable_dtr_sqrt_sampling": get_option("enable_dtr_sqrt_sampling"),
+        "dtr_eviction_threshold": get_option("dtr_eviction_threshold"),
+        "dtr_evictee_minimum_size": get_option("dtr_evictee_minimum_size"),
+        "benchmark_kernel": config.benchmark_kernel,
+        "deterministic_kernel": config.deterministic_kernel,
+        "compute_mode": config._compute_mode,
+        "conv_format": config._conv_format,
+        "amp_enabled": amp.enabled,
+        "convert_inputs": _get_convert_inputs(),
+        "amp_dtype_autocast": _get_amp_dtype_autocast(),
+        "amp_high_prec_dtype": _get_amp_high_prec_dtype(),
+        "amp_low_prec_dtype": _get_amp_low_prec_dtype(),
+    }
+    yield
+    env_vars2 = {
+        "symbolic_shape": trace_option.use_symbolic_shape(),
+        "async_level": get_option("async_level"),
+        "enable_drop": get_option("enable_drop"),
+        "max_recompute_time": get_option("max_recompute_time"),
+        "catch_worker_execption": get_option("catch_worker_execption"),
+        "enable_host_compute": get_option("enable_host_compute"),
+        # "record_computing_path": get_option("record_computing_path"),
+        "disable_memory_forwarding": get_option("disable_memory_forwarding"),
+        "enable_dtr_auto_drop": get_option("enable_dtr_auto_drop"),
+        "enable_dtr_sqrt_sampling": get_option("enable_dtr_sqrt_sampling"),
+        "dtr_eviction_threshold": get_option("dtr_eviction_threshold"),
+        "dtr_evictee_minimum_size": get_option("dtr_evictee_minimum_size"),
+        "benchmark_kernel": config.benchmark_kernel,
+        "deterministic_kernel": config.deterministic_kernel,
+        "compute_mode": config._compute_mode,
+        "conv_format": config._conv_format,
+        "amp_enabled": amp.enabled,
+        "convert_inputs": _get_convert_inputs(),
+        "amp_dtype_autocast": _get_amp_dtype_autocast(),
+        "amp_high_prec_dtype": _get_amp_high_prec_dtype(),
+        "amp_low_prec_dtype": _get_amp_low_prec_dtype(),
+    }
+    for key in env_vars1:
+        assert (
+            env_vars1[key] == env_vars2[key]
+        ), "{} have been changed after test".format(key)
diff --git a/imperative/python/test/run.sh b/imperative/python/test/run.sh
index bd42e951ac5157b09025e9e9d4a43182ca04c943..5b052290659222daa426c9a1430bdff8d3e867d6 100755
--- a/imperative/python/test/run.sh
+++ b/imperative/python/test/run.sh
@@ -37,7 +37,7 @@ if [[ "$TEST_PLAT" =~ "local" ]]; then
     PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -s -v $test_dirs -m 'not isolated_distributed'
     if [[ "$TEST_PLAT" =~ "cuda" ]]; then
         echo "test GPU pytest now"
-        PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -s -v $test_dirs -m 'isolated_distributed'
+        PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -s -v $test_dirs -m 'isolated_distributed' --ignore=./integration/test_dtr.py
     fi
 else
     cd $(dirname "${BASH_SOURCE[0]}")/..
diff --git a/imperative/python/test/unit/random/test_rng.py b/imperative/python/test/unit/random/test_rng.py
index 3dd89b20d49ff0fae7d2eb30618a4296c403e6e9..504631c1c47f7b3341586b018080c621090fe74f 100644
--- a/imperative/python/test/unit/random/test_rng.py
+++ b/imperative/python/test/unit/random/test_rng.py
@@ -39,8 +39,6 @@ from megengine.random import uniform
     get_device_count("xpu") <= 2, reason="xpu counts need > 2",
 )
 def test_gaussian_op():
-    # FIXME: remove this sync
-    mge.core.set_option("async_level", 0)
     set_global_seed(1024)
     shape = (
         8,
@@ -516,4 +514,3 @@ def test_rng_empty_tensor(is_symbolic):
         np.testing.assert_equal(out.numpy().shape, (0,))
         if is_symbolic is None:
             break
-    mge.core.set_option("async_level", 2)
diff --git a/imperative/src/impl/ops/matmul.cpp b/imperative/src/impl/ops/matmul.cpp
index 634fee2b6002ee0d05ba2c669968a4bb2738d009..145dabfd1c9bd5ce3911c811771ee0f0603303a7 100644
--- a/imperative/src/impl/ops/matmul.cpp
+++ b/imperative/src/impl/ops/matmul.cpp
@@ -227,6 +227,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, dst_dtype);
     dst_layout.init_contiguous_stride();
 
+    if (matmul.transposeA)
+        std::swap(layout_a.shape[0], layout_a.shape[1]);
+    if (matmul.transposeB)
+        std::swap(layout_b.shape[0], layout_b.shape[1]);
+
     DeviceTensorND out =
             BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout);
     size_t sz = setup_algo(
diff --git a/imperative/src/impl/ops/pooling.cpp b/imperative/src/impl/ops/pooling.cpp
index 98958da78cbc66dfc548e91845ac5170293036c4..c7b4c59eac9c5e2f9b5cae2e85d7ab4ff22a5475 100644
--- a/imperative/src/impl/ops/pooling.cpp
+++ b/imperative/src/impl/ops/pooling.cpp
@@ -80,13 +80,12 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
             op_def.policy(), false);
 
     megdnn::Workspace dnn_wk;
-    if (wk_size != 0) {
-        auto wk = Blob::make(cn, wk_size);
-        dnn_wk.raw_ptr = wk->storage().get();
-        dnn_wk.size = wk_size;
+    if (wk_size) {
+        TensorLayout w_layout({wk_size}, dtype::Byte());
+        dnn_wk = caller.create_workspace(w_layout);
     }
 
-    dnn_opr->exec(inp_tensornd, out_devtensor.as_megdnn(), {});
+    dnn_opr->exec(inp_tensornd, out_devtensor.as_megdnn(), dnn_wk);
 
     return {Tensor::make(out_devtensor)};
 }
diff --git a/imperative/src/impl/ops/reduce.cpp b/imperative/src/impl/ops/reduce.cpp
index 35a49ef4de801623df13428049ee21d35a072ae5..a6f334f20246b1f3375f09a9afb2409eafcb088c 100644
--- a/imperative/src/impl/ops/reduce.cpp
+++ b/imperative/src/impl/ops/reduce.cpp
@@ -174,10 +174,9 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
 
     megdnn::Workspace dnn_wk;
     auto wk_size = dnn_op.op->get_workspace_in_bytes(src, layout);
-    if (wk_size != 0) {
-        auto wk = Blob::make(comp_node, wk_size);
-        dnn_wk.raw_ptr = wk->storage().get();
-        dnn_wk.size = wk_size;
+    if (wk_size) {
+        TensorLayout w_layout({wk_size}, dtype::Byte());
+        dnn_wk = dnn_op.create_workspace(w_layout);
     }
 
     DeviceTensorND out =
diff --git a/src/core/include/megbrain/version.h b/src/core/include/megbrain/version.h
index 0a5f4a839377a4fb1118f21d11560bf5921aaedb..8dc7645c952f427600192894ce655f733147034d 100644
--- a/src/core/include/megbrain/version.h
+++ b/src/core/include/megbrain/version.h
@@ -14,7 +14,7 @@
 #include "megbrain_build_config.h"
 
 #define MGE_MAJOR 1
-#define MGE_MINOR 8
+#define MGE_MINOR 9999
 #define MGE_PATCH 0
 
 // for rc version, could be like "rc1", "rc2", etc