提交 de8ffe0c 编写于 作者: M Megvii Engine Team

refactor(imperative): unify interpreter option setting

GitOrigin-RevId: 53510445cc38866a4bae1059aec56e4c91e9ca4d
上级 8b60bdfa
...@@ -79,6 +79,7 @@ from .core._imperative_rt.core2 import close as _close ...@@ -79,6 +79,7 @@ from .core._imperative_rt.core2 import close as _close
from .core._imperative_rt.core2 import full_sync as sync from .core._imperative_rt.core2 import full_sync as sync
from .core._imperative_rt.core2 import sync as _sync from .core._imperative_rt.core2 import sync as _sync
from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
from .config import *
from .device import * from .device import *
from .logger import enable_debug_log, get_logger, set_log_file, set_log_level from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from .serialization import load, save from .serialization import load, save
......
...@@ -9,12 +9,12 @@ ...@@ -9,12 +9,12 @@
import os import os
from contextlib import contextmanager from contextlib import contextmanager
from ._imperative_rt.core2 import get_option, set_option
__compute_mode = "default" __compute_mode = "default"
__conv_format = "default" __conv_format = "default"
_benchmark_kernel = False _benchmark_kernel = False
_deterministic_kernel = False _deterministic_kernel = False
_async_level = os.getenv("MEGENGINE_INTERP_ASYNC_LEVEL", 2)
__all__ = [ __all__ = [
"benchmark_kernel", "benchmark_kernel",
...@@ -77,13 +77,13 @@ def async_level(mod) -> int: ...@@ -77,13 +77,13 @@ def async_level(mod) -> int:
import megengine as mge import megengine as mge
mge.config.async_level = 2 mge.config.async_level = 2
""" """
return _async_level return get_option("async_level")
@async_level.setter @async_level.setter
def async_level(mod, level: int): def async_level(mod, level: int):
global _async_level assert level >= 0 and level <= 2, "async_level should be 0, 1 or 2"
_async_level = level set_option("async_level", level)
@property @property
...@@ -148,7 +148,7 @@ def _reset_execution_config( ...@@ -148,7 +148,7 @@ def _reset_execution_config(
orig_flags = ( orig_flags = (
_benchmark_kernel, _benchmark_kernel,
_deterministic_kernel, _deterministic_kernel,
_async_level, get_option("async_level"),
__compute_mode, __compute_mode,
__conv_format, __conv_format,
) )
...@@ -157,7 +157,7 @@ def _reset_execution_config( ...@@ -157,7 +157,7 @@ def _reset_execution_config(
if deterministic_kernel is not None: if deterministic_kernel is not None:
_deterministic_kernel = deterministic_kernel _deterministic_kernel = deterministic_kernel
if async_level is not None: if async_level is not None:
_async_level = async_level set_option("async_level", async_level)
if compute_mode is not None: if compute_mode is not None:
__compute_mode = compute_mode __compute_mode = compute_mode
if conv_format is not None: if conv_format is not None:
......
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
import re import re
from typing import Union from typing import Union
from ..core import set_option as _set_option
from ..core._imperative_rt.core2 import clear_candidates as _clear_candidates from ..core._imperative_rt.core2 import clear_candidates as _clear_candidates
from ..core._imperative_rt.core2 import set_option as _set_option
_eviction_threshold = 0 _eviction_threshold = 0
_evictee_minimum_size = 1024 ** 2 _evictee_minimum_size = 1024 ** 2
......
...@@ -599,7 +599,7 @@ void init_tensor(py::module m) { ...@@ -599,7 +599,7 @@ void init_tensor(py::module m) {
auto val2 = py_async_error.py::object::operator()( auto val2 = py_async_error.py::object::operator()(
"An async error is reported. See above for the actual cause." "An async error is reported. See above for the actual cause."
" Hint: This is where it is reported, not where it happened." " Hint: This is where it is reported, not where it happened."
" You may call `megengine.core.set_option('async_level', 0)` " " You may call `megengine.config.async_level = 0 "
"to get better error reporting."); "to get better error reporting.");
PyException_SetCause( PyException_SetCause(
val2.ptr(), val); // PyException_SetCause steals reference val2.ptr(), val); // PyException_SetCause steals reference
...@@ -698,20 +698,12 @@ void init_tensor(py::module m) { ...@@ -698,20 +698,12 @@ void init_tensor(py::module m) {
py_task_q.wait_all_task_finish(); py_task_q.wait_all_task_finish();
}; };
m.def("clear_candidates", [channel]() { channel->clear_candidates(); });
m.def("set_option", [channel](std::string name, size_t value) { m.def("set_option", [channel](std::string name, size_t value) {
channel->set_option(name, value); channel->set_option(name, value);
}); });
m.def("clear_candidates", [channel]() { channel->clear_candidates(); });
m.def("get_option", m.def("get_option",
[channel](std::string name) { return channel->get_option(name); }); [channel](std::string name) { return channel->get_option(name); });
m.def("_set_drop_flag",
[channel](bool flag) { channel->set_option("enable_drop", flag); });
m.def("config_async_level", [channel](int level) {
mgb_assert(level >= 0 and level <= 2, "async_level should be 0, 1 or 2");
channel->set_option("async_level", level);
});
m.def("get_async_level",
[channel]() { return channel->get_option("async_level"); });
m.def("set_buffer_length", [channel](int length) { m.def("set_buffer_length", [channel](int length) {
mgb_assert(length >= 0 and length < 100, "buffer_length should be in [0, 100)"); mgb_assert(length >= 0 and length < 100, "buffer_length should be in [0, 100)");
channel->set_option("buffer_length", length); channel->set_option("buffer_length", length);
......
...@@ -14,7 +14,7 @@ import megengine as mge ...@@ -14,7 +14,7 @@ import megengine as mge
import megengine.autodiff as ad import megengine.autodiff as ad
import megengine.functional as F import megengine.functional as F
from megengine import Tensor from megengine import Tensor
from megengine.core._imperative_rt.core2 import _set_drop_flag, get_option, set_option from megengine.core import get_option, set_option
from megengine.module import Linear, Module from megengine.module import Linear, Module
from megengine.optimizer import SGD from megengine.optimizer import SGD
...@@ -75,7 +75,7 @@ class XORNet(Module): ...@@ -75,7 +75,7 @@ class XORNet(Module):
def test_training_converge_with_drop(): def test_training_converge_with_drop():
_set_drop_flag(True) set_option("enable_drop", 1)
old_buffer_length = get_option("buffer_length") old_buffer_length = get_option("buffer_length")
set_option("buffer_length", 0) set_option("buffer_length", 0)
net = XORNet() net = XORNet()
...@@ -118,5 +118,5 @@ def test_training_converge_with_drop(): ...@@ -118,5 +118,5 @@ def test_training_converge_with_drop():
precision precision
) )
_set_drop_flag(False) set_option("enable_drop", 0)
set_option("buffer_length", old_buffer_length) set_option("buffer_length", old_buffer_length)
...@@ -6,23 +6,19 @@ import pytest ...@@ -6,23 +6,19 @@ import pytest
import megengine as mge import megengine as mge
import megengine.functional as F import megengine.functional as F
from megengine.core._imperative_rt.core2 import ( from megengine.core import set_option
AsyncError, from megengine.core._imperative_rt.core2 import AsyncError
_set_drop_flag,
config_async_level,
get_async_level,
)
def test_basic(): def test_basic():
config_async_level(2) mge.config.async_level = 2
assert get_async_level() == 2 assert mge.config.async_level == 2
with pytest.raises(RuntimeError): with pytest.raises(AssertionError):
config_async_level(3) mge.config.async_level = 3
def test_level1_infer_value(): def test_level1_infer_value():
config_async_level(1) mge.config.async_level = 1
a = mge.tensor([[1, 2], [2, 3], [3, 4]], dtype="float32") a = mge.tensor([[1, 2], [2, 3], [3, 4]], dtype="float32")
b = mge.tensor([1, 1], dtype="float32") b = mge.tensor([1, 1], dtype="float32")
identity = mge.tensor(np.array([[1, 0], [0, 1]]), dtype="float32") identity = mge.tensor(np.array([[1, 0], [0, 1]]), dtype="float32")
...@@ -30,11 +26,11 @@ def test_level1_infer_value(): ...@@ -30,11 +26,11 @@ def test_level1_infer_value():
c = F.matmul(b, identity) c = F.matmul(b, identity)
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
d = F.reshape(a, c) d = F.reshape(a, c)
config_async_level(2) mge.config.async_level = 2
def test_level1_infer_shape_with_unknown(): def test_level1_infer_shape_with_unknown():
config_async_level(2) mge.config.async_level = 2
a = mge.tensor([[1, 2, 2, 3]], dtype="float32") a = mge.tensor([[1, 2, 2, 3]], dtype="float32")
b = mge.tensor([1, 1], dtype="float32") b = mge.tensor([1, 1], dtype="float32")
multi2 = mge.tensor(np.array([[2, 0], [0, 2]]), dtype="float32") multi2 = mge.tensor(np.array([[2, 0], [0, 2]]), dtype="float32")
...@@ -42,13 +38,13 @@ def test_level1_infer_shape_with_unknown(): ...@@ -42,13 +38,13 @@ def test_level1_infer_shape_with_unknown():
# make DepType::SHAPE unknown # make DepType::SHAPE unknown
d = F.reshape(a, c) d = F.reshape(a, c)
e = mge.tensor([[1, 2]], dtype="float32") e = mge.tensor([[1, 2]], dtype="float32")
config_async_level(1) mge.config.async_level = 1
# test src no shape, throw in level1 # test src no shape, throw in level1
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
f = F.reshape(d, b) f = F.reshape(d, b)
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
g = F.matmul(d, e) g = F.matmul(d, e)
config_async_level(2) mge.config.async_level = 2
def test_host_compute_elemwise(): def test_host_compute_elemwise():
...@@ -61,7 +57,7 @@ def test_host_compute_elemwise(): ...@@ -61,7 +57,7 @@ def test_host_compute_elemwise():
def test_drop_basic(): def test_drop_basic():
_set_drop_flag(True) set_option("enable_drop", True)
# test xpu compute # test xpu compute
x = mge.tensor(np.ones((3, 3)), dtype=np.float32) x = mge.tensor(np.ones((3, 3)), dtype=np.float32)
y = mge.tensor(np.ones((3, 3)), dtype=np.float32) y = mge.tensor(np.ones((3, 3)), dtype=np.float32)
...@@ -74,7 +70,7 @@ def test_drop_basic(): ...@@ -74,7 +70,7 @@ def test_drop_basic():
z = x + y z = x + y
z._drop() z._drop()
z.numpy() z.numpy()
_set_drop_flag(False) set_option("enable_drop", False)
def test_finalize(): def test_finalize():
...@@ -107,21 +103,21 @@ def test_async_error_check(): ...@@ -107,21 +103,21 @@ def test_async_error_check():
# NOTE: DO NOT REMOVE THIS TEST # NOTE: DO NOT REMOVE THIS TEST
# This is also a compatibility test for # This is also a compatibility test for
# mge.core.set_option('async_level', 0). # mge.config.async_level = 0.
# If you change the canonical API to set async level, # If you change the canonical API to set async level,
# update the error message of AsyncError as well. # update the error message of AsyncError as well.
def test_async_error(): def test_async_error():
orig_lvl = mge.core.get_option("async_level") orig_lvl = mge.config.async_level
try: try:
mge.core.set_option("async_level", 1) mge.config.async_level = 1
x = F.utils._simulate_error() x = F.utils._simulate_error()
try: try:
x.numpy() x.numpy()
except AsyncError as e: except AsyncError as e:
assert isinstance(e.__cause__, RuntimeError) assert isinstance(e.__cause__, RuntimeError)
mge.core.set_option("async_level", 0) mge.config.async_level = 0
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
F.utils._simulate_error() F.utils._simulate_error()
finally: finally:
mge.core.set_option("async_level", orig_lvl) mge.config.async_level = orig_lvl
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册