Commit de8ffe0c authored by Megvii Engine Team

refactor(imperative): unify interpreter option setting

GitOrigin-RevId: 53510445cc38866a4bae1059aec56e4c91e9ca4d
Parent 8b60bdfa
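In short, this commit replaces the dedicated interpreter bindings (`_set_drop_flag`, `config_async_level`, `get_async_level`) with the generic `get_option`/`set_option` interface and exposes `async_level` as a property on `megengine.config`. A minimal usage sketch of the unified API after this change (names taken from the diff below; exact module paths may differ by version):

```python
import megengine as mge
from megengine.core import get_option, set_option

# async_level is now a config property backed by the interpreter option
mge.config.async_level = 1
assert mge.config.async_level == 1

# other interpreter options go through the generic getter/setter
set_option("enable_drop", 1)              # formerly _set_drop_flag(True)
buffer_length = get_option("buffer_length")
set_option("enable_drop", 0)
```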
......@@ -79,6 +79,7 @@ from .core._imperative_rt.core2 import close as _close
from .core._imperative_rt.core2 import full_sync as sync
from .core._imperative_rt.core2 import sync as _sync
from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
from .config import *
from .device import *
from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from .serialization import load, save
......
......@@ -9,12 +9,12 @@
import os
from contextlib import contextmanager
from ._imperative_rt.core2 import get_option, set_option
__compute_mode = "default"
__conv_format = "default"
_benchmark_kernel = False
_deterministic_kernel = False
_async_level = os.getenv("MEGENGINE_INTERP_ASYNC_LEVEL", 2)
__all__ = [
"benchmark_kernel",
......@@ -77,13 +77,13 @@ def async_level(mod) -> int:
import megengine as mge
mge.config.async_level = 2
"""
return _async_level
return get_option("async_level")
@async_level.setter
def async_level(mod, level: int):
global _async_level
_async_level = level
assert level >= 0 and level <= 2, "async_level should be 0, 1 or 2"
set_option("async_level", level)
@property
......@@ -148,7 +148,7 @@ def _reset_execution_config(
orig_flags = (
_benchmark_kernel,
_deterministic_kernel,
_async_level,
get_option("async_level"),
__compute_mode,
__conv_format,
)
......@@ -157,7 +157,7 @@ def _reset_execution_config(
if deterministic_kernel is not None:
_deterministic_kernel = deterministic_kernel
if async_level is not None:
_async_level = async_level
set_option("async_level", async_level)
if compute_mode is not None:
__compute_mode = compute_mode
if conv_format is not None:
......
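`_reset_execution_config` above snapshots the current option values (now reading `async_level` directly from the interpreter) so callers can restore them afterwards. The same save/restore pattern, sketched with the public config property and using only APIs visible in this diff:

```python
import megengine as mge

orig_level = mge.config.async_level   # backed by get_option("async_level")
try:
    mge.config.async_level = 0        # run synchronously for precise error reporting
    # ... code whose errors should surface immediately ...
finally:
    mge.config.async_level = orig_level
```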
......@@ -9,8 +9,8 @@
import re
from typing import Union
from ..core import set_option as _set_option
from ..core._imperative_rt.core2 import clear_candidates as _clear_candidates
from ..core._imperative_rt.core2 import set_option as _set_option
_eviction_threshold = 0
_evictee_minimum_size = 1024 ** 2
......
......@@ -599,7 +599,7 @@ void init_tensor(py::module m) {
auto val2 = py_async_error.py::object::operator()(
"An async error is reported. See above for the actual cause."
" Hint: This is where it is reported, not where it happened."
" You may call `megengine.core.set_option('async_level', 0)` "
" You may call `megengine.config.async_level = 0 "
"to get better error reporting.");
PyException_SetCause(
val2.ptr(), val); // PyException_SetCause steals reference
......@@ -698,20 +698,12 @@ void init_tensor(py::module m) {
py_task_q.wait_all_task_finish();
};
m.def("clear_candidates", [channel]() { channel->clear_candidates(); });
m.def("set_option", [channel](std::string name, size_t value) {
channel->set_option(name, value);
});
m.def("clear_candidates", [channel]() { channel->clear_candidates(); });
m.def("get_option",
[channel](std::string name) { return channel->get_option(name); });
m.def("_set_drop_flag",
[channel](bool flag) { channel->set_option("enable_drop", flag); });
m.def("config_async_level", [channel](int level) {
mgb_assert(level >= 0 and level <= 2, "async_level should be 0, 1 or 2");
channel->set_option("async_level", level);
});
m.def("get_async_level",
[channel]() { return channel->get_option("async_level"); });
m.def("set_buffer_length", [channel](int length) {
mgb_assert(length >= 0 and length < 100, "buffer_length should be in [0, 100)");
channel->set_option("buffer_length", length);
......
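For callers that relied on the removed bindings, the equivalent calls through the generic option interface would look roughly like this (a sketch based on the bindings above; the option names are the ones the old helpers forwarded to):

```python
from megengine.core._imperative_rt.core2 import get_option, set_option

set_option("enable_drop", True)       # was _set_drop_flag(True)
set_option("async_level", 2)          # was config_async_level(2)
level = get_option("async_level")     # was get_async_level()
```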
......@@ -14,7 +14,7 @@ import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
from megengine import Tensor
from megengine.core._imperative_rt.core2 import _set_drop_flag, get_option, set_option
from megengine.core import get_option, set_option
from megengine.module import Linear, Module
from megengine.optimizer import SGD
......@@ -75,7 +75,7 @@ class XORNet(Module):
def test_training_converge_with_drop():
_set_drop_flag(True)
set_option("enable_drop", 1)
old_buffer_length = get_option("buffer_length")
set_option("buffer_length", 0)
net = XORNet()
......@@ -118,5 +118,5 @@ def test_training_converge_with_drop():
precision
)
_set_drop_flag(False)
set_option("enable_drop", 0)
set_option("buffer_length", old_buffer_length)
......@@ -6,23 +6,19 @@ import pytest
import megengine as mge
import megengine.functional as F
from megengine.core._imperative_rt.core2 import (
AsyncError,
_set_drop_flag,
config_async_level,
get_async_level,
)
from megengine.core import set_option
from megengine.core._imperative_rt.core2 import AsyncError
def test_basic():
config_async_level(2)
assert get_async_level() == 2
with pytest.raises(RuntimeError):
config_async_level(3)
mge.config.async_level = 2
assert mge.config.async_level == 2
with pytest.raises(AssertionError):
mge.config.async_level = 3
def test_level1_infer_value():
config_async_level(1)
mge.config.async_level = 1
a = mge.tensor([[1, 2], [2, 3], [3, 4]], dtype="float32")
b = mge.tensor([1, 1], dtype="float32")
identity = mge.tensor(np.array([[1, 0], [0, 1]]), dtype="float32")
......@@ -30,11 +26,11 @@ def test_level1_infer_value():
c = F.matmul(b, identity)
with pytest.raises(RuntimeError):
d = F.reshape(a, c)
config_async_level(2)
mge.config.async_level = 2
def test_level1_infer_shape_with_unknown():
config_async_level(2)
mge.config.async_level = 2
a = mge.tensor([[1, 2, 2, 3]], dtype="float32")
b = mge.tensor([1, 1], dtype="float32")
multi2 = mge.tensor(np.array([[2, 0], [0, 2]]), dtype="float32")
......@@ -42,13 +38,13 @@ def test_level1_infer_shape_with_unknown():
# make DepType::SHAPE unknown
d = F.reshape(a, c)
e = mge.tensor([[1, 2]], dtype="float32")
config_async_level(1)
mge.config.async_level = 1
# test src no shape, throw in level1
with pytest.raises(RuntimeError):
f = F.reshape(d, b)
with pytest.raises(RuntimeError):
g = F.matmul(d, e)
config_async_level(2)
mge.config.async_level = 2
def test_host_compute_elemwise():
......@@ -61,7 +57,7 @@ def test_host_compute_elemwise():
def test_drop_basic():
_set_drop_flag(True)
set_option("enable_drop", True)
# test xpu compute
x = mge.tensor(np.ones((3, 3)), dtype=np.float32)
y = mge.tensor(np.ones((3, 3)), dtype=np.float32)
......@@ -74,7 +70,7 @@ def test_drop_basic():
z = x + y
z._drop()
z.numpy()
_set_drop_flag(False)
set_option("enable_drop", False)
def test_finalize():
......@@ -107,21 +103,21 @@ def test_async_error_check():
# NOTE: DO NOT REMOVE THIS TEST
# This is also a compatibility test for
# mge.core.set_option('async_level', 0).
# mge.config.async_level = 0.
# If you change the canonical API to set async level,
# update the error message of AsyncError as well.
def test_async_error():
orig_lvl = mge.core.get_option("async_level")
orig_lvl = mge.config.async_level
try:
mge.core.set_option("async_level", 1)
mge.config.async_level = 1
x = F.utils._simulate_error()
try:
x.numpy()
except AsyncError as e:
assert isinstance(e.__cause__, RuntimeError)
mge.core.set_option("async_level", 0)
mge.config.async_level = 0
with pytest.raises(RuntimeError):
F.utils._simulate_error()
finally:
mge.core.set_option("async_level", orig_lvl)
mge.config.async_level = orig_lvl