Commit 4ae9dd00 authored by Megvii Engine Team

feat(imperative): add external transform

GitOrigin-RevId: e8e3ebe9c86afc9fb97900b5b5af9778cc1354e5
Parent 9914129a
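In short, the commit threads an "external convert" transformation through the imperative runtime so that a foreign array type (XLA's ArrayImpl) can be handed to megengine Tensor construction and kept as a lazily converted external value until an operator actually needs its contents. Below is a rough sketch of the registration flow that xla_trace performs further down; the hook body here is only illustrative, the real one converts through DLPack as shown in apply_external_convert_hook.

from mge_xlalib.xla_extension import ArrayImpl

from megengine.core._imperative_rt.core2 import (
    set_external_convert,
    set_external_convert_hook,
    set_py_external_type,
)


def convert_hook(obj, comp_node):
    # illustrative only: turn the foreign array into a megengine Tensor on comp_node;
    # the real hook goes through DLPack (see apply_external_convert_hook below)
    raise NotImplementedError


set_external_convert_hook(convert_hook)  # 1. register the python-side converter
set_py_external_type(ArrayImpl)          # 2. let Tensor(...) recognize ArrayImpl inputs
set_external_convert()                   # 3. enable the transformation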
from collections import OrderedDict, defaultdict

from .. import tensor
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
    is_external_convert,
    set_external_convert,
    set_external_convert_hook,
    set_py_external_type,
    unset_external_convert,
)
from ..core._trace_option import set_use_xla_backend
from ..device import get_default_device
from ..utils.dlpack import from_dlpack, to_dlpack
from .tracing import trace

# try:
#     from mge_xlalib.xla_extension import ArrayImpl
#     from ..xla.lib import xla_client as xc
# except ImportError:
#     pass
from mge_xlalib.xla_extension import ArrayImpl
from ..xla.lib import xla_client as xc

xla_client_compute_stream = None


def apply_external_convert_hook(input, cn):
    # called by the ExternalConvertTransformation when an ArrayImpl value is first
    # consumed by an operator: move the XLA buffer to `cn` through DLPack
    stream = xla_client_compute_stream
    assert isinstance(input, ArrayImpl)
    dlpack_capsule = xc._xla.buffer_to_dlpack_managed_tensor(input, take_ownership=True)
    output = from_dlpack(dlpack_capsule, stream).to(cn, _borrow=True)
    return output
class xla_trace(trace):
    r"""Wraps a callable, and provides accelerated evaluation compiled by xla.

...@@ -48,6 +77,12 @@ class xla_trace(trace):

    def __init__(self, function, *, without_host=True, symbolic_shape=False, **kwargs):
        assert without_host, "xla trace only support without host mode"
        assert not symbolic_shape, "xla doesn't support dynamic shape currently"

        set_external_convert_hook(apply_external_convert_hook)
        set_py_external_type(ArrayImpl)
        set_external_convert()

        super().__init__(
            function, without_host=without_host, symbolic_shape=symbolic_shape, **kwargs
        )
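For context, a minimal usage sketch of the decorated interface; the wrapped function and its input are illustrative, and capture_as_const follows the existing trace option also used by the tests at the end of this diff.

import numpy as np

from megengine import tensor
from megengine.jit.xla_backend import xla_trace


@xla_trace(capture_as_const=True)
def forward(x):
    return (x * 2.0).sum()


x = tensor(np.ones((4, 4), dtype="float32"))
y = forward(x)  # first call traces the function and builds the xla executable
y = forward(x)  # later calls run execute() below on the compiled program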
...@@ -142,8 +177,8 @@ class xla_trace(trace):

        return xc._xla.buffer_to_dlpack_managed_tensor(x, take_ownership=take_ownership)

    def execute(self, *args, **kwargs):
        from ..tensor import Tensor
        from ..traced_module.pytree import tree_flatten
        from ..utils.module_utils import get_expand_structure

        inputs, _ = tree_flatten((args, kwargs))
...@@ -161,6 +196,8 @@ class xla_trace(trace):
        arrays = self.prepare_xla_inputs(arrays)
        outputs = self.xla_exec(*arrays)

        global xla_client_compute_stream
        xla_client_compute_stream = xla_stream

        return_vals = []
        for i in self.out_list:
            if i == -1:
...@@ -170,28 +207,25 @@ class xla_trace(trace):
                return_vals.append(outputs[self.outkey2idx[i]])
        keeped_features = []
        for i in self.keeped_activation:
            keeped_features.append(outputs[self.outkey2idx[i]])
        out_tensors = []
        for array in return_vals:
            if array is not None:
                t = tensor(array, device=cn)
                out_tensors.append(t)
            else:
                out_tensors.append(array)
        if self.overall:
            for attr, key in self.update_param_dict.items():
                param = get_expand_structure(attr[0], attr[1])
                xla_array = outputs[self.outkey2idx[key]]
                t = tensor(xla_array, device=cn)
                param._reset(t)
            for state, key in self.update_opt_param_dict.items():
                xla_array = outputs[self.outkey2idx[key]]
                t = tensor(xla_array, device=cn)
                state._reset(t)
        rst = (
            self.outdef.unflatten(out_tensors)
            if hasattr(self, "outdef")
......
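Continuing the usage sketch above, a hedged illustration of what the rewritten execute() hands back: Tensors that wrap the XLA buffers directly instead of going through a DLPack copy, observable through the TensorWrapper methods added later in this commit.

out = forward(x)                 # a Tensor backed by the XLA ArrayImpl, no DLPack copy yet
assert out._is_external_value()  # stored as an ExternalValue by the transformation
buf = out._external_obj()        # the original ArrayImpl stays reachable
print(out.numpy())               # reading the data is what fires apply_external_convert_hook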
...@@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Set,

import numpy as np

from .. import tensor
from ..distributed import is_distributed
from ..utils.dlpack import from_dlpack, to_dlpack
from . import ir_utils

...@@ -68,10 +69,20 @@ class InputsHandler:

    def __call__(self, input_buffers):
        rst = []
        for idx, i in enumerate(input_buffers):
            if i._is_external_value():
                rst.append([i._external_obj()])
            else:
                if "gpu" in i.device.physical_name:
                    capsule = to_dlpack(i)
                    xla_array = self.from_dlpack(capsule)
                    rst.append([xla_array])
                else:
                    r = self.handler(
                        self.local_devices, [self.input_indices[idx],], [i,]
                    )[0]
                    rst.append(r)
                    i._reset(tensor(r[0]))
        return rst

    def __str__(self):
......
#pragma once

#include <list>

#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/utils/map.h"

#include "./tensor.h"

namespace mgb::imperative::python {

namespace py = pybind11;

class CreateExternalWrapper final : public OperatorImpl<CreateExternalWrapper> {
private:
    py::object m_object;
    CompNode m_device;

public:
    CreateExternalWrapper(py::object obj, CompNode device)
            : m_object(obj), m_device(device) {}

    py::object object() const { return m_object; }

    CompNode device() const { return m_device; }

    std::string raw_type() const { return "CreateExternalWrapper"; }

    std::string to_string() const { return "CreateExternalWrapper"; }
};

class GetExternalVal final
        : public OperatorImpl<GetExternalVal, Operator::GetAttrLike> {
public:
    std::string to_string() const { return "GetExternalVal"; }

    std::string raw_type() const { return "GetExternalVal"; }
};

class PyobjectStorage {
private:
    py::object m_object;

public:
    PyobjectStorage() = default;

    PyobjectStorage(py::object object) : m_object(object) {}

    py::object object() const { return m_object; }

    std::string to_string() const { return "PyobjectStorage"; }
};

class PyobjectValue final : public PrimitiveValue<PyobjectValue, PyobjectStorage> {
public:
    using PrimitiveValue::PrimitiveValue;

    std::string to_string() const override { return PyobjectStorage::to_string(); }
};

class ExternalValue final : public ObjectValue<ExternalValue> {
private:
    py::object m_obj;
    mutable CompNodeValue::ref_t m_device;

public:
    ExternalValue(py::object obj, CompNode device)
            : m_obj(obj), m_device(CompNodeValue::make(device)) {}

    py::object object() const { return m_obj; }

    CompNodeValue::ref_t device() const { return m_device; }

    std::string to_string() const override { return "ExternalValue"; }

    void clear() override {}
};

class ExternalConvertTransformation final : public Transformation {
private:
    py::function m_hook_fn;
    int m_enabled = 0;
    ObjectType<ExternalValue> m_value_type{"ExternalValue"};

public:
    ValueRefList apply_external_imperative_hook(
            const Operator& op, Span<ValueRef> input_values) {
        for (int i = 0; i < input_values.size(); i++) {
            if (auto* val = input_values[i].as(m_value_type)) {
                CompNode cn = *(val->device());
                py::object fn_res = m_hook_fn(val->object(), cn);
                auto* tw = TensorWrapper::try_cast(fn_res.ptr());
                mgb_assert(tw, "expect Tensor");
                auto external_input = input_values[i].as_ref(m_value_type);
                external_input.reset(tw->m_tensor->data());
            }
        }
        auto outputs = imperative::apply(op, input_values);
        return outputs;
    }

    ExternalConvertTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {}

    ValueRefList apply_transformation(
            const Operator& op, Span<ValueRef> inputs) override {
        if (!m_enabled) {
            return imperative::apply(op, inputs);
        }
        bool has_external_inp = false;
        if (auto* obj_value = op.as<CreateExternalWrapper>()) {
            return m_value_type.make(obj_value->object(), obj_value->device());
        }
        for (auto&& input : inputs) {
            if (input.is(m_value_type)) {
                has_external_inp = true;
                break;
            }
        }
        if (!has_external_inp) {
            return imperative::apply(op, inputs);
        } else if (op.is<GetExternalVal>()) {
            py::object m_object = inputs.item().cast(m_value_type).object();
            PyobjectStorage inp_obj = PyobjectStorage(m_object);
            return {PyobjectValue::make(inp_obj)};
        } else if (op.is<RenameValue>()) {
            return {inputs[0]};
        } else if (auto* get_attr = op.as<GetAttr>()) {
            auto& input = inputs.item().cast(m_value_type);
            ValueRefList outputs;
            switch (get_attr->attr()) {
                case GetAttr::Device:
                    outputs = {input.device()};
                    break;
                default:
                    outputs = apply_external_imperative_hook(op, inputs);
                    break;
            }
            return outputs;
        } else {
            auto outputs = apply_external_imperative_hook(op, inputs);
            return outputs;
        }
    }

    void enable() { m_enabled = 1; }

    void disable() { m_enabled = 0; }

    bool enabled() const { return m_enabled; }

    ValueRef unwrap(ValueRef value) override { return value; }

    const Type<ExternalValue>& value_type() const { return m_value_type; }

    std::string name() const override { return "ExternalConvertTransformation"; }
};

}  // namespace mgb::imperative::python
\ No newline at end of file
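A hedged, Python-level reading of the dispatch rules that ExternalConvertTransformation::apply_transformation implements above, assuming an xla_trace has already registered the hook and enabled the transformation; xla_out stands for an ArrayImpl produced by a compiled executable.

t = tensor(xla_out)       # CreateExternalWrapper: wrapped as an ExternalValue, no copy
dev = t.device            # GetAttr::Device is answered from the wrapper itself
raw = t._external_obj()   # GetExternalVal: hands back the original ArrayImpl
y = t * 2.0               # any other op first runs the registered hook on the input,
                          # replacing the ExternalValue with a real device tensor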
...@@ -28,6 +28,7 @@

#include "./common.h"
#include "./dlpack.h"
#include "./dlpack_convertor.h"
#include "./external_convert.h"
#include "./grad.h"
#include "./graph_rt.h"
#include "./helper.h"

...@@ -61,6 +62,7 @@ namespace mgb::imperative::python {

interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
PyTypeObject* py_tensor_type = nullptr;
PyTypeObject* py_varnode_type = nullptr;
PyTypeObject* py_external_type = nullptr;
pybind11::handle py_device_type = nullptr;
PyObject* cpp_use_symbolic_shape;
...@@ -589,7 +591,13 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {

                : no_cache ? CreateTensor::Unique
                           : CreateTensor::Common;
        ValueRef val;
        bool use_external_inp = py_external_type != nullptr;
        if (use_external_inp &&
            PyObject_TypeCheck(py::handle(data).ptr(), py_external_type)) {
            val = imperative::apply(
                    CreateExternalWrapper(data, cn),
                    Span<ValueRef>(nullptr, nullptr))[0];
        } else if (py::isinstance(data, Py_Varnode)) {
            cg::VarNode* m_node = py::handle(data).cast<cg::VarNode*>();
            val = imperative::apply(
                    CreateNode(m_node), Span<ValueRef>(nullptr, nullptr))[0];

...@@ -750,6 +758,27 @@ PyObject* TensorWrapper::_graph() {

    return py::cast(graph).release().ptr();
}
PyObject* TensorWrapper::_external_obj() {
    TypedValueRef<PyobjectValue> value =
            imperative::apply(GetExternalVal(), m_tensor->data())[0]
                    .as_ref<PyobjectValue>();
    return value->object().release().ptr();
}

PyObject* TensorWrapper::_is_external_value() {
    auto&& external_tsf =
            TransformationManager::get_instance()
                    .segments[TransformationManager::Segment::ExternalConvert];
    auto* tsf = reinterpret_cast<ExternalConvertTransformation*>(external_tsf[0].get());
    mgb_assert(tsf->enabled());
    auto valueref = m_tensor->data();
    if (valueref.is(tsf->value_type())) {
        Py_RETURN_TRUE;
    } else {
        Py_RETURN_FALSE;
    }
}
void dlpack_capsule_destructor(PyObject* data) {
    if (!PyCapsule_IsValid(data, "dltensor")) {
        // early out, see DLPack spec: if a consuming library sets the capsule

...@@ -931,6 +960,8 @@ void init_tensor(py::module m) {

            .def<&TensorWrapper::_var>("var")
            .def<&TensorWrapper::_graph>("graph")
            .def<&TensorWrapper::value_id>("value_id")
            .def<&TensorWrapper::_is_external_value>("_is_external_value")
            .def<&TensorWrapper::_external_obj>("_external_obj")
            .def_getset<
                    &TensorWrapper::module_trace_info,
                    &TensorWrapper::set_module_trace_info>("_NodeMixin__node")

...@@ -1150,6 +1181,10 @@ void init_tensor(py::module m) {

        py_varnode_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
    });

    m.def("set_py_external_type", [](py::object type_obj) {
        py_external_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
    });

    m.def("set_py_device_type",
          [](py::object type_obj) { py_device_type = type_obj.inc_ref(); });

...@@ -1705,6 +1740,24 @@ void init_tensor(py::module m) {
        return module_trace_transformation;
    };

    static py::function external_convert_hook;

    static auto get_external_convert = [] {
        static std::shared_ptr<ExternalConvertTransformation>
                external_convert_transformation;
        if (!external_convert_transformation) {
            mgb_assert(external_convert_hook);
            external_convert_transformation =
                    std::make_shared<ExternalConvertTransformation>(
                            external_convert_hook);
            MGB_MARK_USED_VAR(transformations
                                      .register_at<Segment::ExternalConvert>(
                                              external_convert_transformation)
                                      .release());
        }
        return external_convert_transformation;
    };

    m.def("set_cpp_use_symbolic_shape", &set_cpp_use_symbolic_shape);
    m.def("set_module_tracing", [=] { get_module_trace()->enable(); });

...@@ -1712,6 +1765,12 @@

    m.def("unset_module_tracing", [=] { get_module_trace()->disable(); });
    m.def("is_tracing_module", [=] { return get_module_trace()->enabled(); });

    m.def("set_external_convert", [=] { get_external_convert()->enable(); });
    m.def("unset_external_convert", [=] { get_external_convert()->disable(); });
    m.def("is_external_convert", [=] { return get_external_convert()->enabled(); });

    m.def("set_python_backtrace_enabled", &set_python_backtrace_enabled);
    m.def("set_transformation_backtrace_enabled",
          &set_transformation_backtrace_enabled);

...@@ -1723,8 +1782,16 @@

        module_trace_hook.inc_ref();
    });

    m.def("set_external_convert_hook", [](py::function function) {
        external_convert_hook = function;
        external_convert_hook.inc_ref();
    });

    auto atexit = py::module::import("atexit");
    atexit.attr("register")(py::cpp_function([]() {
        module_trace_hook = {};
        external_convert_hook = {};
    }));

    m.def("begin_record_values", [] { Value::begin_record_values(); });
    m.def("end_record_values", [] {
......
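One ordering detail worth noting from get_external_convert above: the transformation object is built lazily on first use and asserts that a hook has already been registered. A small sketch of the resulting call order on the Python side; convert_hook is a stand-in for the hook that xla_trace installs.

from megengine.core._imperative_rt.core2 import (
    is_external_convert,
    set_external_convert,
    set_external_convert_hook,
    unset_external_convert,
)


def convert_hook(obj, comp_node):  # stand-in for apply_external_convert_hook
    raise NotImplementedError


set_external_convert_hook(convert_hook)  # must come first: get_external_convert asserts
                                         # a hook exists when it lazily builds the object
set_external_convert()                   # registers the transformation and enables it
assert is_external_convert()
unset_external_convert()                 # disables it again; the hook stays registered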
...@@ -31,6 +31,7 @@ namespace mgb::imperative::python {

extern interpreter::Interpreter::Channel* interpreter_for_py;
extern PyTypeObject* py_tensor_type;
extern PyTypeObject* py_varnode_type;
extern PyTypeObject* py_external_type;
extern pybind11::handle py_device_type;
extern PyObject* cpp_use_symbolic_shape;
extern PyObject* cpp_astensor1d;

...@@ -142,6 +143,8 @@ public:

    PyObject* _detail();
    PyObject* _var();
    PyObject* _graph();
    PyObject* _is_external_value();
    PyObject* _external_obj();
    void _watch();
};
......
...@@ -22,6 +22,7 @@ public:

    Complex,
    Format,
    Grad,
    ExternalConvert,
    Scalar,
    Symbol,
    Trace,
......
import platform

import numpy as np
import pytest

import megengine.functional as F
import megengine.jit as jit
from megengine import autodiff, is_cuda_available, tensor
from megengine.autodiff.grad_manager import GradManager
from megengine.jit.xla_backend import xla_trace
from megengine.module import Conv2d
from mge_xlalib.xla_extension import ArrayImpl


def test_external_flag_set():
    from megengine.core._imperative_rt.core2 import is_external_convert

    @xla_trace(capture_as_const=True)
    def test_fun():
        pass

    # constructing an xla_trace registers the hook and enables the external convert flag
    assert is_external_convert()


def test_external_value():
    m = Conv2d(9, 9, 3, groups=9)
    gm = GradManager()
    gm.attach(m.parameters())

    @xla_trace(capture_as_const=True)
    def conv_grad(inp, model):
        with gm:
            gm.attach(inp)
            rst = model(inp)
            gm.backward(rst.mean())
        ig = inp.grad
        wg = model.weight.grad
        inp.grad = None
        model.weight.grad = None
        return ig, wg

    inp = tensor(np.random.random((9, 9, 32, 32))) * 100
    a, b = conv_grad(inp, m)
    a1, b1 = conv_grad(inp, m)
    np.testing.assert_allclose(a.numpy(), a1.numpy())
\ No newline at end of file