提交 15e8e7df 编写于 作者: M Megvii Engine Team

refactor(mge/imperative): migrate to new core implementation

* swap/drop
* config/set_async_level
* _dev_tensor
* sync

GitOrigin-RevId: 850fb988529b0b15a47e1bbacf272ea2b011c784
上级 34c705fc
...@@ -72,7 +72,7 @@ if sys.platform == "win32": ...@@ -72,7 +72,7 @@ if sys.platform == "win32":
kernel32.SetErrorMode(old_error_mode) kernel32.SetErrorMode(old_error_mode)
from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
from .core._imperative_rt.imperative import sync from .core._imperative_rt.core2 import sync
from .device import * from .device import *
from .logger import enable_debug_log, get_logger, set_log_file, set_log_level from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from .serialization import load, save from .serialization import load, save
......
...@@ -14,8 +14,8 @@ import numpy as np ...@@ -14,8 +14,8 @@ import numpy as np
from .._imperative_rt.core2 import Tensor, apply from .._imperative_rt.core2 import Tensor, apply
from ..ops import builtin from ..ops import builtin
from ..ops.special import Const from ..ops.special import Const
from ..tensor.core import OpBase, TensorBase, TensorWrapperBase
from .dtype import is_equal, is_quantize from .dtype import is_equal, is_quantize
from .megbrain_graph import VarNode
_enable_convert_inputs = True _enable_convert_inputs = True
...@@ -110,7 +110,7 @@ def dtype_promotion(inputs): ...@@ -110,7 +110,7 @@ def dtype_promotion(inputs):
def get_device(inputs): def get_device(inputs):
device = None device = None
for i in inputs: for i in inputs:
if isinstance(i, Tensor): if isinstance(i, (Tensor, VarNode)):
if device is None: if device is None:
device = i.device device = i.device
elif device != i.device: elif device != i.device:
...@@ -142,9 +142,9 @@ def astype(x, dtype): ...@@ -142,9 +142,9 @@ def astype(x, dtype):
def convert_single_value(v, inputs, *, dtype=None, device=None): def convert_single_value(v, inputs, *, dtype=None, device=None):
tensors = [i for i in inputs if isinstance(i, Tensor)] tensors = [i for i in inputs if isinstance(i, (Tensor, VarNode))]
assert len(tensors) > 0 assert len(tensors) > 0
if isinstance(v, (TensorWrapperBase, Tensor)): if isinstance(v, (Tensor, VarNode)):
v = astype(v, v.dtype if is_quantize(v.dtype) else dtype) v = astype(v, v.dtype if is_quantize(v.dtype) else dtype)
else: else:
(v,) = Const(v, dtype=dtype, device=device)(*tensors) (v,) = Const(v, dtype=dtype, device=device)(*tensors)
......
...@@ -905,7 +905,6 @@ def linspace( ...@@ -905,7 +905,6 @@ def linspace(
stop = Tensor(stop, device=device) stop = Tensor(stop, device=device)
num = Tensor(num, device=device) num = Tensor(num, device=device)
device = device if device is None else device.to_c()
op = builtin.Linspace(comp_node=device) op = builtin.Linspace(comp_node=device)
(result,) = apply(op, start, stop, num) (result,) = apply(op, start, stop, num)
if np.dtype(dtype) == np.int32: if np.dtype(dtype) == np.int32:
......
...@@ -119,7 +119,6 @@ class Tensor(_Tensor, ArrayMethodMixin): ...@@ -119,7 +119,6 @@ class Tensor(_Tensor, ArrayMethodMixin):
self.q_dict = state.pop("qdict") self.q_dict = state.pop("qdict")
tensor = Tensor tensor = Tensor
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <pybind11/numpy.h> #include <pybind11/numpy.h>
#include <pybind11/operators.h> #include <pybind11/operators.h>
#include "./helper.h"
namespace py = pybind11; namespace py = pybind11;
namespace mgb::imperative::python { namespace mgb::imperative::python {
...@@ -201,6 +201,24 @@ PyObject* TensorWrapper::detach() { ...@@ -201,6 +201,24 @@ PyObject* TensorWrapper::detach() {
} }
PyObject* TensorWrapper::_dev_tensor(){
auto dev_tensor = interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get());
return py::cast(dev_tensor).release().ptr();
}
void TensorWrapper::_swap_out() {
interpreter_for_py->swap_out(m_tensor->m_handle.get());
}
void TensorWrapper::_swap_in() {
interpreter_for_py->swap_in(m_tensor->m_handle.get());
}
void TensorWrapper::_drop() {
interpreter_for_py->drop(m_tensor->m_handle.get());
}
PyObject* TensorWrapper::isscalar() { PyObject* TensorWrapper::isscalar() {
if(m_tensor->m_flags & Tensor::Flags::SCALAR) { if(m_tensor->m_flags & Tensor::Flags::SCALAR) {
Py_RETURN_TRUE; Py_RETURN_TRUE;
...@@ -240,6 +258,10 @@ void init_tensor(py::module m) { ...@@ -240,6 +258,10 @@ void init_tensor(py::module m) {
.def<&TensorWrapper::isscalar>("isscalar") .def<&TensorWrapper::isscalar>("isscalar")
.def<&TensorWrapper::setscalar>("setscalar") .def<&TensorWrapper::setscalar>("setscalar")
.def<&TensorWrapper::detach>("detach") .def<&TensorWrapper::detach>("detach")
.def<&TensorWrapper::_dev_tensor>("_dev_tensor")
.def<&TensorWrapper::_swap_out>("_swap_out")
.def<&TensorWrapper::_swap_in>("_swap_in")
.def<&TensorWrapper::_drop>("_drop")
.finalize(); .finalize();
if (!tensor_type) throw py::error_already_set(); if (!tensor_type) throw py::error_already_set();
py::setattr(m, "Tensor", tensor_type); py::setattr(m, "Tensor", tensor_type);
...@@ -253,6 +275,21 @@ void init_tensor(py::module m) { ...@@ -253,6 +275,21 @@ void init_tensor(py::module m) {
if (!apply_func) throw py::error_already_set(); if (!apply_func) throw py::error_already_set();
py::setattr(m, "apply", apply_func); py::setattr(m, "apply", apply_func);
m.def("_set_swap_flag",
[](bool flag) { interpreter_for_py->set_swap_flag(flag); });
m.def("_set_drop_flag",
[](bool flag) { interpreter_for_py->set_drop_flag(flag); });
m.def("config_async_level",
[](int level) { interpreter_for_py->config_async_level(level); });
m.def("get_async_level",
[]() { return interpreter_for_py->get_async_level(); });
m.def("sync",
[]() {
interpreter_for_py->sync();
py_task_q.wait_all_task_finish();
},
py::call_guard<py::gil_scoped_release>());
py::handle grad_key_type = GradKeyWrapper::wrap_t::type() py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
.def<&GradKeyWrapper::attach>("attach") .def<&GradKeyWrapper::attach>("attach")
.finalize(); .finalize();
......
...@@ -131,6 +131,10 @@ struct TensorWrapper { ...@@ -131,6 +131,10 @@ struct TensorWrapper {
PyObject* detach(); PyObject* detach();
PyObject* isscalar(); PyObject* isscalar();
void setscalar(); void setscalar();
PyObject* _dev_tensor();
void _swap_in();
void _swap_out();
void _drop();
}; };
......
...@@ -15,7 +15,7 @@ import megengine as mge ...@@ -15,7 +15,7 @@ import megengine as mge
import megengine.autodiff as ad import megengine.autodiff as ad
import megengine.functional as F import megengine.functional as F
from megengine import Tensor from megengine import Tensor
from megengine.core._imperative_rt.imperative import _set_drop_flag, _set_swap_flag from megengine.core._imperative_rt.core2 import _set_drop_flag, _set_swap_flag
from megengine.module import Linear, Module from megengine.module import Linear, Module
from megengine.optimizer import SGD from megengine.optimizer import SGD
......
...@@ -2,7 +2,7 @@ import pytest ...@@ -2,7 +2,7 @@ import pytest
import megengine as mge import megengine as mge
import megengine.functional as F import megengine.functional as F
from megengine.core._imperative_rt.imperative import config_async_level, get_async_level from megengine.core._imperative_rt.core2 import config_async_level, get_async_level
def test_basic(): def test_basic():
...@@ -12,7 +12,6 @@ def test_basic(): ...@@ -12,7 +12,6 @@ def test_basic():
config_async_level(3) config_async_level(3)
@pytest.mark.skip
def test_level1_infer_value(): def test_level1_infer_value():
config_async_level(1) config_async_level(1)
a = mge.tensor([[1, 2], [2, 3], [3, 4]], dtype="float32") a = mge.tensor([[1, 2], [2, 3], [3, 4]], dtype="float32")
...@@ -23,7 +22,6 @@ def test_level1_infer_value(): ...@@ -23,7 +22,6 @@ def test_level1_infer_value():
d = F.reshape(a, c) d = F.reshape(a, c)
@pytest.mark.skip
def test_level1_infer_shape_with_unknown(): def test_level1_infer_shape_with_unknown():
config_async_level(2) config_async_level(2)
a = mge.tensor([[1, 2, 2, 3]], dtype="float32") a = mge.tensor([[1, 2, 2, 3]], dtype="float32")
......
...@@ -11,17 +11,13 @@ from concurrent.futures import Future ...@@ -11,17 +11,13 @@ from concurrent.futures import Future
import numpy as np import numpy as np
import megengine.functional as F import megengine.functional as F
import megengine.tensor as Tensor
from megengine.core.tensor import megbrain_graph as mgb_graph from megengine.core.tensor import megbrain_graph as mgb_graph
from megengine.core.tensor.raw_tensor import as_raw_tensor
def make_dev_tensor(value, dtype=None, device=None):
return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor()
def test_io(): def test_io():
g = mgb_graph.Graph() g = mgb_graph.Graph()
x = make_dev_tensor(np.random.randn(3).astype("float32"), device="xpux") x = Tensor(np.random.randn(3).astype("float32"), device="xpux")._dev_tensor()
vx, _ = mgb_graph.input_callback( vx, _ = mgb_graph.input_callback(
lambda: x, device=x.comp_node, dtype=x.dtype, graph=g lambda: x, device=x.comp_node, dtype=x.dtype, graph=g
) )
...@@ -43,7 +39,7 @@ def test_io2(): ...@@ -43,7 +39,7 @@ def test_io2():
for _ in range(3): for _ in range(3):
f.execute() f.execute()
x = make_dev_tensor(np.random.randn(10).astype(dtype), device=device) x = Tensor(np.random.randn(10).astype(dtype), device=device)._dev_tensor()
px.set_value(x) px.set_value(x)
y = py.get_value() y = py.get_value()
np.testing.assert_equal(x.numpy(), y.numpy()) np.testing.assert_equal(x.numpy(), y.numpy())
...@@ -60,7 +56,7 @@ def test_attr_output(): ...@@ -60,7 +56,7 @@ def test_attr_output():
for shape in [(2,), (3,), (5,)]: for shape in [(2,), (3,), (5,)]:
f.execute() f.execute()
x = make_dev_tensor(np.random.randn(*shape).astype(dtype), device=device) x = Tensor(np.random.randn(*shape).astype(dtype), device=device)._dev_tensor()
px.set_value(x) px.set_value(x)
ay = py.get_value() ay = py.get_value()
assert ay.shape == shape assert ay.shape == shape
...@@ -71,7 +67,7 @@ def test_attr_output(): ...@@ -71,7 +67,7 @@ def test_attr_output():
def test_op(): def test_op():
g = mgb_graph.Graph() g = mgb_graph.Graph()
x = make_dev_tensor(np.random.randn(10).astype("float32"), device="xpux") x = Tensor(np.random.randn(10).astype("float32"), device="xpux")._dev_tensor()
v, _ = mgb_graph.input_callback( v, _ = mgb_graph.input_callback(
lambda: x, device=x.comp_node, dtype=x.dtype, graph=g lambda: x, device=x.comp_node, dtype=x.dtype, graph=g
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册