diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp
index 3bc859aac746ee74701165931d60fba56ebeb1a7..03d5b0c70706c60603189d4630bb024787493555 100644
--- a/imperative/python/src/ops.cpp
+++ b/imperative/python/src/ops.cpp
@@ -557,7 +557,14 @@ void init_ops(py::module m) {
     m.def(
             "delete_rng_handle",
             [](size_t handle) {
+                if (mgb::imperative::python::interpreter_for_py->check_available()) {
+                    mgb::imperative::python::interpreter_for_py->sync();
+                }
                 mgb::CompNode::sync_all();
+                mgb::CompNode::foreach ([](mgb::CompNode cn) {
+                    auto err = cn.check_async_error();
+                    mgb_assert(!err, "%s", err->what());
+                });
                 py_task_q.wait_all_task_finish();
                 rng::delete_handle(handle);
             },
diff --git a/imperative/python/test/conftest.py b/imperative/python/test/conftest.py
index 4424d5b3c7297ca15dfb3d5707be7ad9eacf7e6b..465ede6bcf9f808a181df5ef5e2d97a4eba89212 100644
--- a/imperative/python/test/conftest.py
+++ b/imperative/python/test/conftest.py
@@ -11,13 +11,17 @@ import sys
 
 import pytest
 
-import megengine.functional
-import megengine.module
-from megengine import Parameter
-from megengine.core._imperative_rt.core2 import sync
+from megengine.core import _config as config
+from megengine.core import _trace_option as trace_option
+from megengine.core import get_option
+from megengine.core._imperative_rt.core2 import (
+    _get_amp_dtype_autocast,
+    _get_amp_high_prec_dtype,
+    _get_amp_low_prec_dtype,
+    _get_convert_inputs,
+)
+from megengine.core.tensor import amp
 from megengine.device import get_device_count
-from megengine.jit import trace as _trace
-from megengine.module import Linear, Module
 
 sys.path.append(os.path.join(os.path.dirname(__file__), "helpers"))
 
@@ -41,3 +45,58 @@ def skip_distributed(request):
                 platform.system()
             )
         )
+
+
+@pytest.fixture(autouse=True)
+def run_around_tests():
+    env_vars1 = {
+        "symbolic_shape": trace_option.use_symbolic_shape(),
+        "async_level": get_option("async_level"),
+        "enable_drop": get_option("enable_drop"),
+        "max_recompute_time": get_option("max_recompute_time"),
+        "catch_worker_execption": get_option("catch_worker_execption"),
+        "enable_host_compute": get_option("enable_host_compute"),
+        # "record_computing_path": get_option("record_computing_path"),
+        "disable_memory_forwarding": get_option("disable_memory_forwarding"),
+        "enable_dtr_auto_drop": get_option("enable_dtr_auto_drop"),
+        "enable_dtr_sqrt_sampling": get_option("enable_dtr_sqrt_sampling"),
+        "dtr_eviction_threshold": get_option("dtr_eviction_threshold"),
+        "dtr_evictee_minimum_size": get_option("dtr_evictee_minimum_size"),
+        "benchmark_kernel": config.benchmark_kernel,
+        "deterministic_kernel": config.deterministic_kernel,
+        "compute_mode": config._compute_mode,
+        "conv_format": config._conv_format,
+        "amp_enabled": amp.enabled,
+        "convert_inputs": _get_convert_inputs(),
+        "amp_dtype_autocast": _get_amp_dtype_autocast(),
+        "amp_high_prec_dtype": _get_amp_high_prec_dtype(),
+        "amp_low_prec_dtype": _get_amp_low_prec_dtype(),
+    }
+    yield
+    env_vars2 = {
+        "symbolic_shape": trace_option.use_symbolic_shape(),
+        "async_level": get_option("async_level"),
+        "enable_drop": get_option("enable_drop"),
+        "max_recompute_time": get_option("max_recompute_time"),
+        "catch_worker_execption": get_option("catch_worker_execption"),
+        "enable_host_compute": get_option("enable_host_compute"),
+        # "record_computing_path": get_option("record_computing_path"),
+        "disable_memory_forwarding": get_option("disable_memory_forwarding"),
+        "enable_dtr_auto_drop": get_option("enable_dtr_auto_drop"),
+        "enable_dtr_sqrt_sampling": get_option("enable_dtr_sqrt_sampling"),
+        "dtr_eviction_threshold": get_option("dtr_eviction_threshold"),
+        "dtr_evictee_minimum_size": get_option("dtr_evictee_minimum_size"),
+        "benchmark_kernel": config.benchmark_kernel,
+        "deterministic_kernel": config.deterministic_kernel,
+        "compute_mode": config._compute_mode,
+        "conv_format": config._conv_format,
+        "amp_enabled": amp.enabled,
+        "convert_inputs": _get_convert_inputs(),
+        "amp_dtype_autocast": _get_amp_dtype_autocast(),
+        "amp_high_prec_dtype": _get_amp_high_prec_dtype(),
+        "amp_low_prec_dtype": _get_amp_low_prec_dtype(),
+    }
+    for key in env_vars1:
+        assert (
+            env_vars1[key] == env_vars2[key]
+        ), "{} have been changed after test".format(key)
diff --git a/imperative/python/test/run.sh b/imperative/python/test/run.sh
index bd42e951ac5157b09025e9e9d4a43182ca04c943..5b052290659222daa426c9a1430bdff8d3e867d6 100755
--- a/imperative/python/test/run.sh
+++ b/imperative/python/test/run.sh
@@ -37,7 +37,7 @@ if [[ "$TEST_PLAT" =~ "local" ]]; then
     PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -s -v $test_dirs -m 'not isolated_distributed'
     if [[ "$TEST_PLAT" =~ "cuda" ]]; then
         echo "test GPU pytest now"
-        PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -s -v $test_dirs -m 'isolated_distributed'
+        PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -s -v $test_dirs -m 'isolated_distributed' --ignore=./integration/test_dtr.py
     fi
 else
     cd $(dirname "${BASH_SOURCE[0]}")/..
diff --git a/imperative/python/test/unit/random/test_rng.py b/imperative/python/test/unit/random/test_rng.py
index 3dd89b20d49ff0fae7d2eb30618a4296c403e6e9..504631c1c47f7b3341586b018080c621090fe74f 100644
--- a/imperative/python/test/unit/random/test_rng.py
+++ b/imperative/python/test/unit/random/test_rng.py
@@ -39,8 +39,6 @@ from megengine.random import uniform
     get_device_count("xpu") <= 2, reason="xpu counts need > 2",
 )
 def test_gaussian_op():
-    # FIXME: remove this sync
-    mge.core.set_option("async_level", 0)
     set_global_seed(1024)
     shape = (
         8,
@@ -516,4 +514,3 @@ def test_rng_empty_tensor(is_symbolic):
         np.testing.assert_equal(out.numpy().shape, (0,))
         if is_symbolic is None:
             break
-    mge.core.set_option("async_level", 2)
diff --git a/imperative/src/impl/ops/reduce.cpp b/imperative/src/impl/ops/reduce.cpp
index ffd2400e0f72e9ca5573fa7ccc392ac28655a5c8..bc94769299d7afd1899bad7ca548497090253f54 100644
--- a/imperative/src/impl/ops/reduce.cpp
+++ b/imperative/src/impl/ops/reduce.cpp
@@ -175,10 +175,9 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
 
     megdnn::Workspace dnn_wk;
     auto wk_size = dnn_op.op->get_workspace_in_bytes(src, layout);
-    if (wk_size != 0) {
-        auto wk = Blob::make(comp_node, wk_size);
-        dnn_wk.raw_ptr = wk->storage().get();
-        dnn_wk.size = wk_size;
+    if (wk_size) {
+        TensorLayout w_layout({wk_size}, dtype::Byte());
+        dnn_wk = dnn_op.create_workspace(w_layout);
     }
 
     DeviceTensorND out =