From 4ae9dd0074df6374c4b9e77217b37546d34b486c Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Mon, 26 Jun 2023 20:15:13 +0800
Subject: [PATCH] feat(imperative): add external transform

GitOrigin-RevId: e8e3ebe9c86afc9fb97900b5b5af9778cc1354e5
---
 .../python/megengine/jit/xla_backend.py       |  49 +++++--
 imperative/python/megengine/xla/compile.py    |  19 ++-
 imperative/python/src/external_convert.h      | 153 ++++++++++++++++++
 imperative/python/src/tensor.cpp              |  71 +++++++-
 imperative/python/src/tensor.h                |   3 +
 imperative/python/src/transformation.h        |   1 +
 .../unit/xla/functional/test_xla_convert.py   |  48 ++++++
 7 files changed, 327 insertions(+), 17 deletions(-)
 create mode 100644 imperative/python/src/external_convert.h
 create mode 100644 imperative/python/test/unit/xla/functional/test_xla_convert.py

diff --git a/imperative/python/megengine/jit/xla_backend.py b/imperative/python/megengine/jit/xla_backend.py
index 992caedcb..db1f27425 100644
--- a/imperative/python/megengine/jit/xla_backend.py
+++ b/imperative/python/megengine/jit/xla_backend.py
@@ -1,12 +1,34 @@
 from collections import OrderedDict, defaultdict
 
+from .. import tensor
 from ..core._imperative_rt import CompNode
 from ..core._imperative_rt.core2 import Tensor as RawTensor
+from ..core._imperative_rt.core2 import (
+    is_external_convert,
+    set_external_convert,
+    set_external_convert_hook,
+    set_py_external_type,
+    unset_external_convert,
+)
 from ..core._trace_option import set_use_xla_backend
 from ..device import get_default_device
 from ..utils.dlpack import from_dlpack, to_dlpack
 from .tracing import trace
+from mge_xlalib.xla_extension import ArrayImpl
+
+from ..xla.lib import xla_client as xc
+
+xla_client_compute_stream = None
+
+
+def apply_external_convert_hook(input, cn):
+    stream = xla_client_compute_stream
+    assert isinstance(input, ArrayImpl)
+    dlpack_capsule = xc._xla.buffer_to_dlpack_managed_tensor(input, take_ownership=True)
+    output = from_dlpack(dlpack_capsule, stream).to(cn, _borrow=True)
+    return output
+
 
 class xla_trace(trace):
     r"""Wraps a callable, and provides accelerated evaluation compiled by xla.
@@ -48,6 +70,12 @@ class xla_trace(trace):
     def __init__(self, function, *, without_host=True, symbolic_shape=False, **kwargs):
         assert without_host, "xla trace only support without host mode"
         assert not symbolic_shape, "xla doesn't support dynamic shape currently"
+
+        set_external_convert_hook(apply_external_convert_hook)
+
+        set_py_external_type(ArrayImpl)
+        set_external_convert()
+
         super().__init__(
             function, without_host=without_host, symbolic_shape=symbolic_shape, **kwargs
         )
@@ -142,8 +170,8 @@ class xla_trace(trace):
         return xc._xla.buffer_to_dlpack_managed_tensor(x, take_ownership=take_ownership)
 
     def execute(self, *args, **kwargs):
-        from ..traced_module.pytree import tree_flatten
         from ..tensor import Tensor
+        from ..traced_module.pytree import tree_flatten
         from ..utils.module_utils import get_expand_structure
 
         inputs, _ = tree_flatten((args, kwargs))
@@ -161,6 +189,8 @@ class xla_trace(trace):
 
         arrays = self.prepare_xla_inputs(arrays)
         outputs = self.xla_exec(*arrays)
+        global xla_client_compute_stream
+        xla_client_compute_stream = xla_stream
         return_vals = []
         for i in self.out_list:
             if i == -1:
@@ -170,28 +200,25 @@ class xla_trace(trace):
                 return_vals.append(outputs[self.outkey2idx[i]])
         keeped_features = []
         for i in self.keeped_activation:
-            capsule = self.to_dlpack(outputs[self.outkey2idx[i]])
-            t = from_dlpack(capsule, xla_stream).to(cn, _borrow=True)
-            keeped_features.append(t)
+            keeped_features.append(outputs[self.outkey2idx[i]])
         out_tensors = []
         for array in return_vals:
             if array is not None:
-                capsule = self.to_dlpack(array)
-                t = from_dlpack(capsule, xla_stream)
-                out_tensors.append(t.to(cn, _borrow=True))
+                t = tensor(array, device=cn)
+                out_tensors.append(t)
             else:
                 out_tensors.append(array)
         if self.overall:
             for attr, key in self.update_param_dict.items():
                 param = get_expand_structure(attr[0], attr[1])
                 xla_array = outputs[self.outkey2idx[key]]
-                capsule = self.to_dlpack(xla_array)
-                param._reset(from_dlpack(capsule).to(cn, _borrow=True))
+                t = tensor(xla_array, device=cn)
+                param._reset(t)
 
             for state, key in self.update_opt_param_dict.items():
                 xla_array = outputs[self.outkey2idx[key]]
-                capsule = self.to_dlpack(xla_array)
-                state._reset(from_dlpack(capsule).to(cn, _borrow=True))
+                t = tensor(xla_array, device=cn)
+                state._reset(t)
         rst = (
             self.outdef.unflatten(out_tensors)
             if hasattr(self, "outdef")
diff --git a/imperative/python/megengine/xla/compile.py b/imperative/python/megengine/xla/compile.py
index 2de389e5b..7efa4931b 100644
--- a/imperative/python/megengine/xla/compile.py
+++ b/imperative/python/megengine/xla/compile.py
@@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Set,
 
 import numpy as np
 
+from .. import tensor
 from ..distributed import is_distributed
 from ..utils.dlpack import from_dlpack, to_dlpack
 from . import ir_utils
@@ -68,10 +69,20 @@ class InputsHandler:
 
     def __call__(self, input_buffers):
         rst = []
-        for ibuf in input_buffers:
-            capsule = to_dlpack(ibuf)
-            xla_array = self.from_dlpack(capsule)
-            rst.append([xla_array])
+        for idx, i in enumerate(input_buffers):
+            if i._is_external_value():
+                rst.append([i._external_obj()])
+            else:
+                if "gpu" in i.device.physical_name:
+                    capsule = to_dlpack(i)
+                    xla_array = self.from_dlpack(capsule)
+                    rst.append([xla_array])
+                else:
+                    r = self.handler(
+                        self.local_devices, [self.input_indices[idx],], [i,]
+                    )[0]
+                    rst.append(r)
+                    i._reset(tensor(r[0]))
         return rst
 
     def __str__(self):
diff --git a/imperative/python/src/external_convert.h b/imperative/python/src/external_convert.h
new file mode 100644
index 000000000..ba9e0c19d
--- /dev/null
+++ b/imperative/python/src/external_convert.h
@@ -0,0 +1,153 @@
+#pragma once
+
+#include <pybind11/pybind11.h>
+
+#include "megbrain/imperative/dispatch.h"
+#include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/imperative/utils/map.h"
+
+#include "./tensor.h"
+
+namespace mgb::imperative::python {
+
+namespace py = pybind11;
+
+class CreateExternalWrapper final : public OperatorImpl<CreateExternalWrapper> {
+private:
+    py::object m_object;
+    CompNode m_device;
+
+public:
+    CreateExternalWrapper(py::object obj, CompNode device)
+            : m_object(obj), m_device(device) {}
+
+    py::object object() const { return m_object; }
+
+    CompNode device() const { return m_device; }
+
+    std::string raw_type() const { return "CreateExternalWrapper"; }
+
+    std::string to_string() const { return "CreateExternalWrapper"; }
+};
+
+class GetExternalVal final
+        : public OperatorImpl<GetExternalVal, Operator::GetAttrLike> {
+public:
+    std::string to_string() const { return "GetExternalVal"; }
+    std::string raw_type() const { return "GetExternalVal"; }
+};
+
+class PyobjectStorage {
+private:
+    py::object m_object;
+
+public:
+    PyobjectStorage() = default;
+    PyobjectStorage(py::object object) : m_object(object) {}
+    py::object object() const { return m_object; }
+    std::string to_string() const { return "PyobjectStorage"; }
+};
+
+class PyobjectValue final : public PrimitiveValue<PyobjectValue, PyobjectStorage> {
+public:
+    using PrimitiveValue::PrimitiveValue;
+
+    std::string to_string() const override { return PyobjectStorage::to_string(); }
+};
+
+class ExternalValue final : public ObjectValue<ExternalValue> {
+private:
+    py::object m_obj;
+    mutable CompNodeValue::ref_t m_device;
+
+public:
+    ExternalValue(py::object obj, CompNode device)
+            : m_obj(obj), m_device(CompNodeValue::make(device)) {}
+
+    py::object object() const { return m_obj; }
+
+    CompNodeValue::ref_t device() const { return m_device; }
+
+    std::string to_string() const override { return "ExternalValue"; }
+
+    void clear() override {}
+};
+
+class ExternalConvertTransformation final : public Transformation {
+private:
+    py::function m_hook_fn;
+    int m_enabled = 0;
+    ObjectType<ExternalValue> m_value_type{"ExternalValue"};
+
+public:
+    ValueRefList apply_external_imperative_hook(
+            const Operator& op, Span<ValueRef> input_values) {
+        for (size_t i = 0; i < input_values.size(); i++) {
+            if (auto* val = input_values[i].as(m_value_type)) {
+                CompNode cn = *(val->device());
+                py::object fn_res = m_hook_fn(val->object(), cn);
+                auto* tw = TensorWrapper::try_cast(fn_res.ptr());
+                mgb_assert(tw, "expect Tensor");
+                auto external_input = input_values[i].as_ref(m_value_type);
+                external_input.reset(tw->m_tensor->data());
+            }
+        }
+        auto outputs = imperative::apply(op, input_values);
+        return outputs;
+    }
+
+    ExternalConvertTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {}
+
+    ValueRefList apply_transformation(
+            const Operator& op, Span<ValueRef> inputs) override {
+        if (!m_enabled) {
+            return imperative::apply(op, inputs);
+        }
+        bool has_external_inp = false;
+        if (auto* obj_value = op.as<CreateExternalWrapper>()) {
+            return m_value_type.make(obj_value->object(), obj_value->device());
+        }
+        for (auto&& input : inputs) {
+            if (input.is(m_value_type)) {
+                has_external_inp = true;
+                break;
+            }
+        }
+        if (!has_external_inp) {
+            return imperative::apply(op, inputs);
+        } else if (op.is<GetExternalVal>()) {
+            py::object obj = inputs.item().cast(m_value_type).object();
+            PyobjectStorage inp_obj = PyobjectStorage(obj);
+            return {PyobjectValue::make(inp_obj)};
+        } else if (op.is()) {
+            return {inputs[0]};
+        } else if (auto* get_attr = op.as<GetAttr>()) {
+            auto& input = inputs.item().cast(m_value_type);
+            ValueRefList outputs;
+            switch (get_attr->attr()) {
+                case GetAttr::Device:
+                    outputs = {input.device()};
+                    break;
+                default:
+                    outputs = apply_external_imperative_hook(op, inputs);
+                    break;
+            }
+            return outputs;
+        } else {
+            auto outputs = apply_external_imperative_hook(op, inputs);
+            return outputs;
+        }
+    }
+
+    void enable() { m_enabled = 1; }
+
+    void disable() { m_enabled = 0; }
+
+    bool enabled() const { return m_enabled; }
+
+    ValueRef unwrap(ValueRef value) override { return value; }
+
+    const Type<ExternalValue>& value_type() const { return m_value_type; }
+
+    std::string name() const override { return "ExternalConvertTransformation"; }
+};
+
+}  // namespace mgb::imperative::python
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index 194829cd4..0ec0b3f40 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -28,6 +28,7 @@
 #include "./common.h"
 #include "./dlpack.h"
 #include "./dlpack_convertor.h"
+#include "./external_convert.h"
 #include "./grad.h"
 #include "./graph_rt.h"
 #include "./helper.h"
@@ -61,6 +62,7 @@ namespace mgb::imperative::python {
 interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
 PyTypeObject* py_tensor_type = nullptr;
 PyTypeObject* py_varnode_type = nullptr;
+PyTypeObject* py_external_type = nullptr;
 pybind11::handle py_device_type = nullptr;
 PyObject* cpp_use_symbolic_shape;
 
@@ -589,7 +591,13 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
                                    : no_cache ? CreateTensor::Unique
                                               : CreateTensor::Common;
         ValueRef val;
-        if (py::isinstance(data, Py_Varnode)) {
+        bool use_external_inp = py_external_type != nullptr;
+        if (use_external_inp &&
+            PyObject_TypeCheck(py::handle(data).ptr(), py_external_type)) {
+            val = imperative::apply(
+                    CreateExternalWrapper(data, cn),
+                    Span<ValueRef>(nullptr, nullptr))[0];
+        } else if (py::isinstance(data, Py_Varnode)) {
             cg::VarNode* m_node = py::handle(data).cast<cg::VarNode*>();
             val = imperative::apply(
                     CreateNode(m_node), Span<ValueRef>(nullptr, nullptr))[0];
@@ -750,6 +758,27 @@ PyObject* TensorWrapper::_graph() {
     return py::cast(graph).release().ptr();
 }
 
+PyObject* TensorWrapper::_external_obj() {
+    TypedValueRef<PyobjectValue> value =
+            imperative::apply(GetExternalVal(), m_tensor->data())[0]
+                    .as_ref<PyobjectValue>();
+    return value->object().release().ptr();
+}
+
+PyObject* TensorWrapper::_is_external_value() {
+    auto&& external_tsf =
+            TransformationManager::get_instance()
+                    .segments[TransformationManager::Segment::ExternalConvert];
+    auto* tsf = reinterpret_cast<ExternalConvertTransformation*>(external_tsf[0].get());
+    mgb_assert(tsf->enabled());
+    auto valueref = m_tensor->data();
+    if (valueref.is(tsf->value_type())) {
+        Py_RETURN_TRUE;
+    } else {
+        Py_RETURN_FALSE;
+    }
+}
+
 void dlpack_capsule_destructor(PyObject* data) {
     if (!PyCapsule_IsValid(data, "dltensor")) {
         // early out, see DLPack spec: if a consuming library sets the capsule
@@ -931,6 +960,8 @@ void init_tensor(py::module m) {
                     .def<&TensorWrapper::_var>("var")
                     .def<&TensorWrapper::_graph>("graph")
                     .def<&TensorWrapper::value_id>("value_id")
+                    .def<&TensorWrapper::_is_external_value>("_is_external_value")
+                    .def<&TensorWrapper::_external_obj>("_external_obj")
                     .def_getset<
                             &TensorWrapper::module_trace_info,
                             &TensorWrapper::set_module_trace_info>("_NodeMixin__node")
@@ -1150,6 +1181,10 @@ void init_tensor(py::module m) {
         py_varnode_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
     });
 
+    m.def("set_py_external_type", [](py::object type_obj) {
+        py_external_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
+    });
+
     m.def("set_py_device_type",
           [](py::object type_obj) { py_device_type = type_obj.inc_ref(); });
 
@@ -1705,6 +1740,24 @@ void init_tensor(py::module m) {
         return module_trace_transformation;
     };
 
+    static py::function external_convert_hook;
+
+    static auto get_external_convert = [] {
+        static std::shared_ptr<ExternalConvertTransformation>
+                external_convert_transformation;
+        if (!external_convert_transformation) {
+            mgb_assert(external_convert_hook);
+            external_convert_transformation =
+                    std::make_shared<ExternalConvertTransformation>(
+                            external_convert_hook);
+            MGB_MARK_USED_VAR(transformations
+                                      .register_at<TransformationManager::Segment::ExternalConvert>(
+                                              external_convert_transformation)
+                                      .release());
+        }
+        return external_convert_transformation;
+    };
+
     m.def("set_cpp_use_symbolic_shape", &set_cpp_use_symbolic_shape);
 
     m.def("set_module_tracing", [=] { get_module_trace()->enable(); });
@@ -1712,6 +1765,12 @@ void init_tensor(py::module m) {
     m.def("unset_module_tracing", [=] { get_module_trace()->disable(); });
 
     m.def("is_tracing_module", [=] { return get_module_trace()->enabled(); });
+
+    m.def("set_external_convert", [=] { get_external_convert()->enable(); });
+
+    m.def("unset_external_convert", [=] { get_external_convert()->disable(); });
+
+    m.def("is_external_convert", [=] { return get_external_convert()->enabled(); });
 
     m.def("set_python_backtrace_enabled", &set_python_backtrace_enabled);
     m.def("set_transformation_backtrace_enabled",
           &set_transformation_backtrace_enabled);
@@ -1723,8 +1782,16 @@ void init_tensor(py::module m) {
         module_trace_hook.inc_ref();
     });
 
+    m.def("set_external_convert_hook", [](py::function function) {
+        external_convert_hook = function;
+        external_convert_hook.inc_ref();
+    });
+
     auto atexit = py::module::import("atexit");
-    atexit.attr("register")(py::cpp_function([]() { module_trace_hook = {}; }));
+    atexit.attr("register")(py::cpp_function([]() {
+        module_trace_hook = {};
+        external_convert_hook = {};
+    }));
     m.def("begin_record_values", [] { Value::begin_record_values(); });
 
     m.def("end_record_values", [] {
diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h
index 945c63219..5af89a85f 100644
--- a/imperative/python/src/tensor.h
+++ b/imperative/python/src/tensor.h
@@ -31,6 +31,7 @@ namespace mgb::imperative::python {
 extern interpreter::Interpreter::Channel* interpreter_for_py;
 extern PyTypeObject* py_tensor_type;
 extern PyTypeObject* py_varnode_type;
+extern PyTypeObject* py_external_type;
 extern pybind11::handle py_device_type;
 extern PyObject* cpp_use_symbolic_shape;
 extern PyObject* cpp_astensor1d;
@@ -142,6 +143,8 @@ public:
     PyObject* _detail();
     PyObject* _var();
    PyObject* _graph();
+    PyObject* _is_external_value();
+    PyObject* _external_obj();
     void _watch();
 };
 
diff --git a/imperative/python/src/transformation.h b/imperative/python/src/transformation.h
index 8f27b9d45..efd6b891f 100644
--- a/imperative/python/src/transformation.h
+++ b/imperative/python/src/transformation.h
@@ -22,6 +22,7 @@ public:
         Complex,
         Format,
         Grad,
+        ExternalConvert,
         Scalar,
         Symbol,
         Trace,
diff --git a/imperative/python/test/unit/xla/functional/test_xla_convert.py b/imperative/python/test/unit/xla/functional/test_xla_convert.py
new file mode 100644
index 000000000..50e4b9a4d
--- /dev/null
+++ b/imperative/python/test/unit/xla/functional/test_xla_convert.py
@@ -0,0 +1,48 @@
+import platform
+
+import numpy as np
+import pytest
+
+from megengine import is_cuda_available, tensor
+from megengine.autodiff.grad_manager import GradManager
+from megengine.core._imperative_rt.core2 import is_external_convert
+from megengine.jit import xla_trace
+from megengine.module import Conv2d
+
+
+@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
+@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
+def test_external_flag_set():
+    @xla_trace(capture_as_const=True)
+    def test_fun():
+        pass
+
+    # constructing an xla_trace turns the external-convert transformation on
+    assert is_external_convert()
+
+
+@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
+@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
+def test_external_value():
+    m = Conv2d(9, 9, 3, groups=9)
+    gm = GradManager()
+    gm.attach(m.parameters())
+
+    @xla_trace(capture_as_const=True)
+    def conv_grad(inp, model):
+        with gm:
+            gm.attach(inp)
+            rst = model(inp)
+            gm.backward(rst.mean())
+        ig = inp.grad
+        wg = model.weight.grad
+        inp.grad = None
+        model.weight.grad = None
+        return ig, wg
+
+    inp = tensor(np.random.random((9, 9, 32, 32)).astype("float32")) * 100
+
+    a, b = conv_grad(inp, m)
+    a1, b1 = conv_grad(inp, m)
+    np.testing.assert_allclose(a.numpy(), a1.numpy())
+    np.testing.assert_allclose(b.numpy(), b1.numpy())
-- 
GitLab
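
For reviewers, a minimal sketch of how the pieces above compose from Python. Only the `set_py_external_type`/`set_external_convert_hook`/`set_external_convert`/`is_external_convert` bindings and the lazy-conversion behavior come from this patch; the `DummyArray` class and its `convert_hook` are hypothetical stand-ins for `ArrayImpl` and the DLPack-based hook in `xla_backend.py`, and the sketch assumes a MegEngine build that includes this patch.

import numpy as np

from megengine import Tensor
from megengine.core._imperative_rt.core2 import (
    is_external_convert,
    set_external_convert,
    set_external_convert_hook,
    set_py_external_type,
    unset_external_convert,
)


class DummyArray:
    # hypothetical stand-in for mge_xlalib's ArrayImpl
    def __init__(self, data):
        self.data = np.asarray(data, dtype=np.float32)


def convert_hook(obj, cn):
    # invoked lazily, the first time an op actually consumes the wrapper;
    # it must return a megengine Tensor placed on device `cn`
    return Tensor(obj.data, device=cn)


set_external_convert_hook(convert_hook)
set_py_external_type(DummyArray)  # Tensor(...) now wraps this type unconverted
set_external_convert()            # enable ExternalConvertTransformation
assert is_external_convert()

t = Tensor(DummyArray([1.0, 2.0, 3.0]))  # wrapped as an ExternalValue, no copy yet
print(t.device)         # served by the GetAttr::Device fast path, hook not called
print((t + 1).numpy())  # hook fires here: [2. 3. 4.]

unset_external_convert()

This is the same wiring `xla_trace.__init__` performs with `ArrayImpl`, which is why `InputsHandler.__call__` can hand `_external_obj()` straight back to XLA without a DLPack round trip.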