提交 dd9f54cd 编写于 作者: M Megvii Engine Team

refactor(mge): migrate to new core implementation

* remove dispatcher/interpreter python wrapper
* rename tensor_wrapper to array_method

GitOrigin-RevId: b8a402c2be58a3e1f990802d772e8fc80ce23006
上级 b9762d71
...@@ -18,7 +18,7 @@ from ..core._imperative_rt.core2 import apply ...@@ -18,7 +18,7 @@ from ..core._imperative_rt.core2 import apply
from ..core._wrap import device as as_device from ..core._wrap import device as as_device
from ..core.ops import builtin from ..core.ops import builtin
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis from ..core.tensor.array_method import _broadcast, _remove_axis
from ..core.tensor.utils import ( from ..core.tensor.utils import (
astensor1d, astensor1d,
convert_inputs, convert_inputs,
......
...@@ -18,7 +18,7 @@ import weakref ...@@ -18,7 +18,7 @@ import weakref
import numpy as np import numpy as np
from ..core._imperative_rt import GraphProfiler, common, put from ..core._imperative_rt import GraphProfiler, common
from ..core._imperative_rt.core2 import Tensor as RawTensor from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import TensorWeakRef from ..core._imperative_rt.core2 import TensorWeakRef
from ..core._imperative_rt.core2 import __make_empty_tensor as make_empty_tensor from ..core._imperative_rt.core2 import __make_empty_tensor as make_empty_tensor
......
...@@ -18,7 +18,7 @@ from .core._imperative_rt.core2 import apply ...@@ -18,7 +18,7 @@ from .core._imperative_rt.core2 import apply
from .core._trace_option import use_symbolic_shape from .core._trace_option import use_symbolic_shape
from .core._wrap import device as as_device from .core._wrap import device as as_device
from .core.ops.builtin import Copy, GetVarShape from .core.ops.builtin import Copy, GetVarShape
from .core.tensor.tensor_wrapper import ArrayMethodMixin from .core.tensor.array_method import ArrayMethodMixin
from .device import _valid_device, get_default_device from .device import _valid_device, get_default_device
from .utils.deprecation import deprecated from .utils.deprecation import deprecated
...@@ -42,7 +42,6 @@ class Tensor(_Tensor, ArrayMethodMixin): ...@@ -42,7 +42,6 @@ class Tensor(_Tensor, ArrayMethodMixin):
else: else:
cn = device._cn cn = device._cn
# import pdb; pdb.set_trace()
if isinstance(data, _Tensor): if isinstance(data, _Tensor):
obj = _Tensor.__new__(cls, data) obj = _Tensor.__new__(cls, data)
else: else:
......
...@@ -14,7 +14,7 @@ from typing import Iterable, List, Optional ...@@ -14,7 +14,7 @@ from typing import Iterable, List, Optional
from ..core._imperative_rt import OperatorNodeConfig, ProfileEntry from ..core._imperative_rt import OperatorNodeConfig, ProfileEntry
from ..core._imperative_rt import ProfilerImpl as _Profiler from ..core._imperative_rt import ProfilerImpl as _Profiler
from ..core._imperative_rt.imperative import sync from ..core._imperative_rt.core2 import sync
from ..core._imperative_rt.ops import CollectiveComm from ..core._imperative_rt.ops import CollectiveComm
......
from ..core._imperative_rt import TensorSanityCheckImpl from ..core._imperative_rt import TensorSanityCheckImpl
from ..core._imperative_rt.imperative import sync from ..core._imperative_rt.core2 import sync
class TensorSanityCheck: class TensorSanityCheck:
......
/**
* \file imperative/python/src/dispatcher.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./dispatcher.h"
#include "./pyext17.h"
#include "megbrain/exception.h"
#include "megbrain/utils/hash.h"
#include "megbrain/utils/small_vector.h"
#include <unordered_map>
#include <structmember.h>
namespace py = pybind11;
namespace pyx = pyext17;
namespace {
struct Handler {
PyObject* func; // borrowed
bool enabled;
Handler() = default;
Handler(PyObject* func_, bool enable = true) : func(func_), enabled(enable) {}
};
using FastSig = mgb::SmallVector<void*, 8>;
using MRO = std::vector<Handler*>;
struct Frame {
MRO* mro;
size_t mro_offset;
Frame() = default;
Frame(MRO* mro_, size_t mro_offset_ = 0) : mro(mro_), mro_offset(mro_offset_) {}
};
struct FastSigHash {
size_t operator()(const FastSig& sig) const {
auto* ptr = &sig.front();
return mgb::XXHash()
.update(ptr, sig.size() * sizeof(FastSig::value_type))
.digest();
}
};
struct ObjectIdHash : std::hash<void*> {
size_t operator()(const py::handle& h) const {
return std::hash<void*>::operator()(h.ptr());
}
};
namespace {
using Container = std::vector<Frame>;
struct DispatcherStack: Container {
constexpr static size_t MAX_RECURSIVE_DEPTH = 1024u;
DispatcherStack() { reserve(MAX_RECURSIVE_DEPTH); }
template<typename... Args>
auto&& emplace_back_safely(Args&& ...args) {
mgb_throw_if(size() >= MAX_RECURSIVE_DEPTH, mgb::MegBrainError,
"recursion depth %zu is greater than the MAX_RECURSIVE_DEPTH(%zu)",
size(), MAX_RECURSIVE_DEPTH);
return emplace_back(std::forward<Args>(args)...);
}
};
} // anonymous namespace
struct Dispatcher {
std::unordered_map<FastSig, std::unique_ptr<MRO>, FastSigHash> cache;
DispatcherStack stack;
std::unordered_map<py::object, std::unique_ptr<Handler>, ObjectIdHash> registry;
inline py::handle self() {
return pyx::wrap<Dispatcher>::pycast(this);
}
bool prepare_call(PyObject*const* args, Py_ssize_t nargs) {
FastSig sig(nargs);
for (Py_ssize_t i = 0; i < nargs; ++i) {
sig[i] = Py_TYPE(args[i]);
}
auto it = cache.find(sig);
if (it == cache.end()) {
if (auto mro = resolve(sig)) {
it = cache.emplace(std::move(sig), std::move(mro)).first;
} else {
return false;
}
}
stack.emplace_back_safely(it->second.get());
return true;
}
template<typename T>
PyObject* do_call(T&& caller) {
auto& frame = stack.back();
auto& mro = *frame.mro;
auto& i = frame.mro_offset;
if (!mro.size()) {
PyErr_SetString(PyExc_NotImplementedError, "function not registered in dispatcher");
return nullptr;
}
for (; i < mro.size(); ++i) {
if (mro[i]->enabled) {
auto ret = caller(mro[i]->func);
if (ret != Py_NotImplemented) {
stack.pop_back();
return ret;
}
Py_DECREF(ret);
}
}
PyErr_SetString(PyExc_NotImplementedError, "mro exhausted");
stack.pop_back();
return nullptr;
}
std::unique_ptr<MRO> resolve(const FastSig& sig) {
try {
py::tuple args(sig.size());
for (size_t i = 0; i < sig.size(); ++i) {
args[i] = (PyObject*)sig[i];
}
auto mro_iter = self().attr("dispatch_iter")(*args);
auto ret = std::make_unique<MRO>();
for (auto i : mro_iter) {
auto it = registry.find(py::reinterpret_borrow<py::object>(i));
if (it == registry.end()) {
PyErr_SetString(PyExc_RuntimeError, "resolved to unregistered function");
return nullptr;
}
ret->push_back(it->second.get());
}
return ret;
} catch (py::error_already_set& e) {
e.restore();
} catch (std::runtime_error& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
}
return nullptr;
}
public:
static constexpr auto tp_name = "Dispatcher";
PyObject* tp_call(PyObject* args, PyObject* kwargs) {
if (!prepare_call(&PyTuple_GET_ITEM(args, 0), PyTuple_GET_SIZE(args))) return nullptr;
return do_call([=](PyObject* func){return PyObject_Call(func, args, kwargs);});
}
#if PY_MINOR_VERSION >= 6
PyObject* tp_vectorcall(PyObject*const* args, Py_ssize_t nargs) {
if (!prepare_call(args, nargs)) return nullptr;
return do_call([=](PyObject* func){return _PyObject_FastCall(func, const_cast<PyObject**>(args), nargs);});
}
#endif
#if PY_MINOR_VERSION >= 6
PyObject* super(PyObject*const* args, Py_ssize_t nargs) {
if (stack.empty()) {
PyErr_SetString(PyExc_RuntimeError, "super called at top level");
return nullptr;
}
stack.emplace_back_safely(stack.back()).mro_offset++;
return do_call([=](PyObject* func){return _PyObject_FastCall(func, const_cast<PyObject**>(args), nargs);});
}
#else
PyObject* super(PyObject* args, PyObject* kwargs) {
if (stack.empty()) {
PyErr_SetString(PyExc_RuntimeError, "super called at top level");
return nullptr;
}
stack.emplace_back_safely(stack.back()).mro_offset++;
return do_call([=](PyObject* func){return PyObject_Call(func, args, kwargs);});
}
#endif
void enable(PyObject* func) {
auto obj = py::reinterpret_borrow<py::object>(func);
auto it = registry.find(obj);
if (it != registry.end()) {
it->second->enabled = true;
} else {
registry.emplace(std::move(obj), std::make_unique<Handler>(func));
}
}
PyObject* disable(PyObject* func) {
auto obj = py::reinterpret_borrow<py::object>(func);
auto it = registry.find(obj);
if (it == registry.end()) {
PyErr_SetString(PyExc_ValueError, "function not registered");
return nullptr;
} else {
it->second->enabled = false;
}
Py_RETURN_NONE;
}
void clear_cache() {
cache.clear();
}
};
} // namespace
void init_dispatcher(py::module m) {
auto* dispatcher_type = pyx::wrap<Dispatcher>::type()
.def<&Dispatcher::enable>("enable")
.def<&Dispatcher::disable>("disable")
.def<&Dispatcher::clear_cache>("clear_cache")
#if PY_MINOR_VERSION >= 6
.def<&Dispatcher::tp_vectorcall>("call")
#else
.def<&Dispatcher::tp_call>("call")
#endif
.def<&Dispatcher::super>("super")
.finalize();
if (!dispatcher_type) throw py::error_already_set();
m.attr("Dispatcher") = dispatcher_type;
}
/**
* \file imperative/python/src/dispatcher.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <pybind11/pybind11.h>
void init_dispatcher(pybind11::module);
...@@ -51,59 +51,5 @@ make_backward_graph( ...@@ -51,59 +51,5 @@ make_backward_graph(
} // namespace } // namespace
void init_imperative_rt(py::module m) { void init_imperative_rt(py::module m) {
py::class_<Interpreter::Channel>(m, "Interpreter")
.def("put", [](Interpreter::Channel& self, py::array data, DType dtype, CompNode cn) {
if (!cn.valid()) {
cn = CompNode::load(get_default_device());
}
constexpr int size_threshhold = TensorShape::MAX_NDIM;
if (data.size() > size_threshhold) {
return self.put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype));
} else {
HostTensorND ret(cn);
return self.put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype));
}
}, py::arg(), py::arg("dtype") = py::none(), py::arg("device") = py::none())
.def("put", py::overload_cast<const DeviceTensorND&>(&Interpreter::Channel::put))
.def("delete", [](Interpreter::Channel& self, Interpreter::Handle handle) {
return self.del(handle);
})
.def("_swap_in", [](Interpreter::Channel& self, Interpreter::Handle handle) {
self.swap_in(handle);
})
.def("_swap_out", [](Interpreter::Channel& self, Interpreter::Handle handle) {
self.swap_out(handle);
})
.def("_drop", [](Interpreter::Channel& self, Interpreter::Handle handle) {
self.drop(handle);
})
.def("get_value", [](Interpreter::Channel& self, Interpreter::Handle handle) {
PyObject* optr = npy::ndarray_from_tensor(self.get_value(handle), npy::ShareType::TRY_SHARE);
return py::reinterpret_steal<py::object>(optr);
})
.def("get_dtype", &Interpreter::Channel::get_dtype)
.def("get_device", &Interpreter::Channel::get_device)
.def("get_shape", &Interpreter::Channel::get_shape)
.def("_get_dev_tensor", &Interpreter::Channel::get_dev_tensor)
.def("_set_swap_flag", &Interpreter::Channel::set_swap_flag)
.def("_set_drop_flag", &Interpreter::Channel::set_drop_flag)
.def("apply_op", &Interpreter::Channel::apply_op)
.def("config_async_level", &Interpreter::Channel::config_async_level)
.def("get_async_level", &Interpreter::Channel::get_async_level)
.def("sync", &Interpreter::Channel::sync, py::call_guard<py::gil_scoped_release>());
std::unique_ptr<Interpreter::Channel> ch = Interpreter::inst().create_channel();
m.attr("interpreter") = py::detail::make_caster<decltype(ch)>::cast(
std::move(ch), py::return_value_policy::move, {});
for (auto name : {"put", "delete", "get_value", "get_dtype", "get_device", "get_shape", "_get_dev_tensor", "apply_op", "config_async_level", "get_async_level", "_drop", "_swap_in", "_swap_out", "_set_drop_flag", "_set_swap_flag"}) {
m.attr(name) = m.attr("interpreter").attr(name);
}
m.def("sync", [m]() {
m.attr("interpreter").attr("sync")();
py::gil_scoped_release _;
py_task_q.wait_all_task_finish();
});
m.def("make_backward_graph", &make_backward_graph); m.def("make_backward_graph", &make_backward_graph);
} }
...@@ -21,8 +21,6 @@ ...@@ -21,8 +21,6 @@
#include "./graph_rt.h" #include "./graph_rt.h"
#include "./ops.h" #include "./ops.h"
#include "./dispatcher.h"
#include "./tensor.h" #include "./tensor.h"
namespace py = pybind11; namespace py = pybind11;
...@@ -70,7 +68,5 @@ PYBIND11_MODULE(MODULE_NAME, m) { ...@@ -70,7 +68,5 @@ PYBIND11_MODULE(MODULE_NAME, m) {
)", )",
py::getattr(m, "__dict__")); py::getattr(m, "__dict__"));
init_dispatcher(submodule(m, "dispatcher"));
init_tensor(submodule(m, "core2")); init_tensor(submodule(m, "core2"));
} }
...@@ -16,7 +16,7 @@ import pytest ...@@ -16,7 +16,7 @@ import pytest
import megengine as mge import megengine as mge
import megengine.distributed as dist import megengine.distributed as dist
import megengine.functional as F import megengine.functional as F
from megengine.core._imperative_rt import CompNode, TensorAttr, core2, imperative from megengine.core._imperative_rt import CompNode, TensorAttr, imperative
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync
from megengine.core.autodiff.grad import Grad from megengine.core.autodiff.grad import Grad
from megengine.core.ops.builtin import Elemwise from megengine.core.ops.builtin import Elemwise
......
...@@ -54,10 +54,10 @@ def test_simple_arith(): ...@@ -54,10 +54,10 @@ def test_simple_arith():
def test_tensor_on_device(): def test_tensor_on_device():
device = megengine.core._imperative_rt.CompNode("cpu0:1") device = megengine.core._imperative_rt.CompNode("cpu0:1")
x = np.random.rand(10).astype("float32") x = np.random.rand(10).astype("float32")
xx = megengine.core._imperative_rt.put(x, device=device) xx = megengine.tensor(x, device=device)
assert str(megengine.core._imperative_rt.get_device(xx)) == "cpu0:1" assert str(xx.device) == "cpu0:1"
np.testing.assert_equal(x, megengine.core._imperative_rt.get_value(xx)) np.testing.assert_equal(x, xx.numpy())
megengine.core._imperative_rt.delete(xx) del xx
def test_raw_tensor(): def test_raw_tensor():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册