Commit d98be080 authored by Megvii Engine Team

perf(mge): move Const into C++

GitOrigin-RevId: 31a443cffdc1b6d1470b5e0fd5ed49ab350cb4ff
Parent 1709b394
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
from .._imperative_rt import make_const
from .._imperative_rt.core2 import SymbolVar, Tensor
class Const:
    def __init__(self, value=None, *, dtype=None, device=None):
        self.value = np.asarray(value, dtype=dtype)
        self.dtype = dtype
        self.device = device

    def __call__(self, *reference):
        from ...tensor import Tensor

        device = self.device

        if len(reference) != 0:
            reference = reference[0]
            assert isinstance(
                reference, (SymbolVar, Tensor)
            ), "Reference should be Tensor or VarNode"

            if device is None:
                device = reference.device

            if isinstance(reference, SymbolVar):
                cls = type(reference)
                rst = cls(make_const(reference.graph, self.value, device, self.dtype))
                return (rst,)

        return (Tensor(self.value, self.dtype, self.device, True),)
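For orientation, the only behavioural change call sites need to absorb is the calling convention: the class above is instantiated and then called, returning a one-element tuple, whereas the replacement exported from `_imperative_rt.core2` (the C++ side appears near the end of this diff) is a single positional call that returns the tensor directly. A minimal sketch of the two conventions; `value` and `ref` are placeholders:

    # old convention (this file): build a Const op, call it, unpack a 1-tuple
    (x,) = Const(value, dtype="float32", device=None)(ref)

    # new convention (megengine.core._imperative_rt.core2.Const): positional call,
    # returns the tensor itself; the last argument is the reference (or None)
    x = Const(value, "float32", None, ref)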
@@ -14,6 +14,7 @@ import numpy as np
from .._imperative_rt import make_const
from .._imperative_rt.core2 import (
+    Const,
    SymbolVar,
    Tensor,
    _get_convert_inputs,
@@ -28,7 +29,6 @@ from .._imperative_rt.ops import jit_supported
from .._wrap import as_device
from ..autodiff.grad import Function
from ..ops import builtin
-from ..ops.special import Const
from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype
from .dtype import is_dtype_equal, is_quantize
@@ -67,7 +67,7 @@ def convert_single_value(v, *, dtype=None, device=None):
        if not is_quantize(v.dtype):
            v = astype(v, dtype)
    else:
-        (v,) = Const(v, dtype=dtype, device=device)()
+        v = Const(v, dtype, device, None)
    return v
@@ -155,7 +155,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
        if ndim != 0 and ndim != 1:
            raise ValueError("ndim != 1 or 0, get : %d" % ndim)
        if not isinstance(x, (Tensor, SymbolVar)):
-            (x,) = Const(x, dtype=dtype, device=device)(*reference)
+            x = Const(x, dtype, device, reference)
        return x
    if not isinstance(x, collections.abc.Sequence):
@@ -166,7 +166,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
        if dtype is not None:
            x = astype(x, dtype)
        return x
-    (x,) = Const(x, dtype=dtype, device=device)(*reference)
+    x = Const(x, dtype, device, reference)
    return x
@@ -337,7 +337,7 @@ def interpret_subgraph(func, dtype, device):
        return results
    def apply_const(value, dtype=dtype, device=device):
-        return Const(value, dtype=dtype, device=device)()[0]
+        return Const(value, dtype, device, None)
    outputs, outputs_has_grad = func(args, apply_expr, apply_const)
    outputs = [
...
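Worth noting in the hunks above: `astensor1d` forwards its raw `*reference` varargs tuple as the fourth argument, while `convert_single_value` and `interpret_subgraph` pass `None`. The new C++ entry point (shown near the end of this diff) unwraps a tuple to its first element and treats an empty tuple as no reference, so all of the call shapes below are accepted; a hedged sketch, with `some_tensor` as a placeholder:

    from megengine.core._imperative_rt.core2 import Const

    x = Const([1, 2, 3], "int32", None, (some_tensor,))  # tuple: first element used as the reference
    y = Const([1, 2, 3], "int32", None, ())              # empty tuple: treated as no reference
    z = Const(0.0, "float32", None, None)                # plain None, as convert_single_value passes it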
@@ -10,10 +10,9 @@ import collections
import math
from typing import Iterable, Optional, Sequence, Tuple, Union
-from ..core._imperative_rt.core2 import apply, dtype_promotion
+from ..core._imperative_rt.core2 import Const, apply, dtype_promotion
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core.ops import builtin
-from ..core.ops.special import Const
from ..core.tensor.array_method import _matmul
from ..core.tensor.utils import _normalize_axis
from ..tensor import Tensor
@@ -729,7 +728,7 @@ def topk(
    op = builtin.TopK(mode=mode)
    if not isinstance(k, Tensor):
-        (k,) = Const(k, dtype="int32", device=inp.device)()
+        k = Const(k, "int32", inp.device, None)
    if len(inp.shape) == 1:
        if kth_only:
...
@@ -11,7 +11,7 @@ from functools import lru_cache
from typing import NamedTuple, Optional, Sequence, Tuple, Union
from ..core import _config
-from ..core._imperative_rt.core2 import apply, dtype_promotion
+from ..core._imperative_rt.core2 import Const, apply, dtype_promotion
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from ..core.ops import builtin
@@ -26,7 +26,6 @@ from ..core.ops.builtin import (
    Reshape,
    TypeCvt,
)
-from ..core.ops.special import Const
from ..core.tensor import amp, megbrain_graph
from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.utils import (
@@ -1317,7 +1316,7 @@ def batch_norm(
        raise ValueError("Invalid param_dim {}".format(param_dim))
    if x is None:
-        (x,) = Const(value, dtype=inp.dtype, device=inp.device)()
+        x = Const(value, inp.dtype, inp.device, None)
        shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
        (result,) = apply(builtin.Broadcast(), x, shape)
        return result
@@ -1541,7 +1540,7 @@ def sync_batch_norm(
    def _make_full_if_none(x, value):
        if x is None:
-            (x,) = Const(value, dtype=inp.dtype, device=_device)()
+            x = Const(value, inp.dtype, _device, None)
            (result,) = apply(builtin.Broadcast(), x, reduce_shape)
            return result
        elif x.ndim == 1:
...
@@ -13,6 +13,7 @@ import numpy as np
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import (
+    Const,
    SymbolVar,
    apply,
    broadcast_cpp,
@@ -24,7 +25,6 @@ from ..core._imperative_rt.core2 import (
from ..core._wrap import as_device
from ..core.ops import builtin
from ..core.ops.builtin import Copy, Identity
-from ..core.ops.special import Const
from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn
from ..device import get_default_device
from ..tensor import Tensor
@@ -177,7 +177,7 @@ def full(
        shape = (shape,)
    if device is None:
        device = get_default_device()
-    (x,) = Const(value, dtype=dtype, device=device)()
+    x = Const(value, dtype, device, None)
    if type(shape) in (list, tuple) and len(shape) == 0:
        return x
    return broadcast_to(x, shape)
@@ -325,7 +325,7 @@ def full_like(
         [2 2 2]]
    """
-    (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
+    x = Const(value, inp.dtype, inp.device, inp)
    if inp.ndim == 0:
        return x
    return broadcast_to(x, inp.shape)
...
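The two hunks above also show that the reference argument need not be a tuple: `full` passes `None`, while `full_like` passes the input tensor itself, which the C++ side borrows as-is when it is not a tuple. A short usage sketch of the public wrappers touched here, with behaviour as documented in the existing docstring:

    import numpy as np
    import megengine.functional as F
    from megengine import Tensor

    inp = Tensor(np.arange(6, dtype="int32").reshape(2, 3))
    a = F.full((2, 3), 2)    # internally Const(2, dtype, device, None), then broadcast_to
    b = F.full_like(inp, 2)  # internally Const(2, inp.dtype, inp.device, inp)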
-from ..core.ops.special import Const
+from ..core._imperative_rt.core2 import Const
from ..jit.tracing import is_tracing
small_tensor_cache = {}
@@ -7,11 +7,11 @@ small_tensor_cache = {}
def _get_scalar_tensor_with_value(value, dtype=None, device=None):
    global small_tensor_cache
    if is_tracing():
-        (ret,) = Const(value, dtype=dtype, device=device)()
+        ret = Const(value, dtype, device, None)
    else:
        cache_key = (value, dtype, device)
        if cache_key not in small_tensor_cache:
-            (ret,) = Const(value, dtype=dtype, device=device)()
+            ret = Const(value, dtype, device, None)
            small_tensor_cache[cache_key] = ret
        else:
            ret = small_tensor_cache[cache_key]
...
@@ -16,6 +16,7 @@ from importlib import import_module
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union
from ..core._imperative_rt import OpDef
+from ..core._imperative_rt.core2 import Const
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
    apply,
@@ -25,7 +26,6 @@ from ..core._imperative_rt.core2 import (
    unset_module_tracing,
)
from ..core.ops.builtin import FakeQuant
-from ..core.ops.special import Const
from ..module import Module
from ..tensor import Parameter, Tensor
from ..version import __version__
@@ -764,7 +764,7 @@ class Constant(Expr):
    def interpret(self, *inputs):
        if isinstance(self.value, RawTensor):
-            return Const(self.value.numpy())()
+            return (Const(self.value.numpy(), None, None, None),)
        return (self.value,)
    def __repr__(self):
...
@@ -639,6 +639,7 @@ WRAP_FUNC_PY35(squeeze_cpp);
WRAP_FUNC_PY35(transpose_cpp);
WRAP_FUNC_PY35(broadcast_cpp);
WRAP_FUNC_PY35(reshape_cpp);
+WRAP_FUNC_PY35(Const);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
@@ -777,6 +778,7 @@ void init_tensor(py::module m) {
            MGE_PY_INTERFACE(transpose_cpp, transpose_cpp),
            MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
            MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
+            MGE_PY_INTERFACE(Const, Const),
            {nullptr, nullptr, 0, nullptr}};
    for (auto&& def : method_defs) {
        if (def.ml_meth != nullptr) {
...
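With the two registrations above, the C++ wrapper becomes a module-level function of `core2`, which is exactly the import the Python hunks earlier in this diff switch to. Note that the wrapper (defined below) reads `args[0]` through `args[3]` unconditionally, so callers are expected to supply all four positional arguments. A minimal hedged check; the dtype and device forms mirror the call sites above:

    from megengine.core._imperative_rt.core2 import Const

    # four positional arguments: value, dtype, device, reference
    k = Const(3, "int32", None, None)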
@@ -94,7 +94,7 @@ bool is_bool_dtype(PyObject* args) {
}
py::object _Const(
-        py::handle value, py::handle dtype, py::handle device, py::handle ref) {
+        py::handle value, py::handle dtype, py::handle device, py::handle ref_hdl) {
    py::object val = py::reinterpret_borrow<py::object>(value);
    if (PyArray_Check(value.ptr())) {
        py::tuple strides =
@@ -107,21 +107,56 @@ py::object _Const(
        }
        if (need_squeeze) {
            val = py::reinterpret_borrow<py::array>(value);
+            py::object orig_shp = val.attr("shape");
            val = val.attr("squeeze")();
-            val = val.attr("reshape")(val.attr("shape"));
+            val = val.attr("reshape")(orig_shp);
        }
    }
+    py::object ref;
+    if (py::isinstance<py::tuple>(ref_hdl)) {
+        py::tuple tup = py::reinterpret_borrow<py::tuple>(ref_hdl);
+        if (tup.size()) {
+            ref = tup[0];
+        } else {
+            ref = py::none();
+        }
+    } else {
+        ref = py::reinterpret_borrow<py::object>(ref_hdl);
+    }
    if (py::isinstance<PySymbolVar>(ref)) {
        auto ref_var = ref.cast<PySymbolVar*>();
        auto* graph = ref_var->m_node->owner_graph();
-        auto cn = device.cast<CompNode>();
+        CompNode cn;
+        if (device.ptr() == Py_None) {
+            cn = ref_var->m_node->comp_node();
+        } else {
+            cn = device.cast<CompNode>();
+        }
        OperatorNodeConfig config(cn);
        auto hv = npy::np2tensor(
                val.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
        auto typeobj = ref.get_type();
        return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
    }
-    py::tuple tup = py::make_tuple(val, dtype, device, true, false, py::none());
+    py::object device_obj;
+    if (device.ptr() == Py_None) {
+        device_obj = py::cast(CompNode::load(get_default_device()));
+    } else if (py::isinstance<py::str>(device)) {
+        py::object dmap =
+                getattr(py::reinterpret_borrow<py::object>((PyObject*)py_tensor_type),
+                        "dmap_callback");
+        if (dmap.ptr() != Py_None) {
+            device_obj = dmap(device);
+            py::print(device_obj);
+        } else {
+            device_obj = py::cast(CompNode::load(device.cast<std::string>()));
+        }
+    } else if (py::isinstance<CompNode>(device)) {
+        device_obj = py::reinterpret_borrow<py::object>(device);
+    } else {
+        device_obj = getattr(device, "_cn");
+    }
+    py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none());
    return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
}
@@ -1107,4 +1142,14 @@ PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
+PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) {
+    try {
+        return _Const(py::handle(args[0]), py::handle(args[1]), py::handle(args[2]),
+                      py::handle(args[3]))
+                .release()
+                .ptr();
+    }
+    PYEXT17_TRANSLATE_EXC_RET(nullptr)
+}
} // namespace mgb::imperative::python
@@ -20,4 +20,6 @@ PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs);
+PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs);
} // namespace mgb::imperative::python
\ No newline at end of file