Commit d98be080 authored by Megvii Engine Team

perf(mge): move Const into C++

GitOrigin-RevId: 31a443cffdc1b6d1470b5e0fd5ed49ab350cb4ff
Parent 1709b394
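This commit replaces the Python-level Const helper in core.ops.special (the class shown below) with a Const builtin implemented in C++ and exported from _imperative_rt.core2. As the call-site hunks show, the old helper was constructed with keyword arguments and then called, returning a one-element tuple, while the new builtin is a single positional call that returns the tensor directly. A minimal sketch of the two calling conventions, with illustrative values:

# before: keyword-argument helper, two-step call, one-element tuple result
(x,) = Const(1.0, dtype="float32", device=None)()
# after: C++ builtin taking positional (value, dtype, device, ref), returning the tensor
x = Const(1.0, "float32", None, None)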
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
from .._imperative_rt import make_const
from .._imperative_rt.core2 import SymbolVar, Tensor

class Const:
    def __init__(self, value=None, *, dtype=None, device=None):
        self.value = np.asarray(value, dtype=dtype)
        self.dtype = dtype
        self.device = device

    def __call__(self, *reference):
        from ...tensor import Tensor

        device = self.device

        if len(reference) != 0:
            reference = reference[0]
            assert isinstance(
                reference, (SymbolVar, Tensor)
            ), "Reference should be Tensor or VarNode"

            if device is None:
                device = reference.device

            if isinstance(reference, SymbolVar):
                cls = type(reference)
                rst = cls(make_const(reference.graph, self.value, device, self.dtype))
                return (rst,)

        return (Tensor(self.value, self.dtype, self.device, True),)
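For comparison with the C++ _Const further down, the helper above resolves its device from an optional reference tensor when device is None and, for a SymbolVar reference, builds the constant in the reference's owner graph via make_const. A hedged usage sketch, where other is an illustrative existing tensor:

# device=None, so the device is taken from `other`; the result is a one-element tuple
(x,) = Const([1, 2, 3], dtype="int32", device=None)(other)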
......@@ -14,6 +14,7 @@ import numpy as np
from .._imperative_rt import make_const
from .._imperative_rt.core2 import (
Const,
SymbolVar,
Tensor,
_get_convert_inputs,
......@@ -28,7 +29,6 @@ from .._imperative_rt.ops import jit_supported
from .._wrap import as_device
from ..autodiff.grad import Function
from ..ops import builtin
from ..ops.special import Const
from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype
from .dtype import is_dtype_equal, is_quantize
......@@ -67,7 +67,7 @@ def convert_single_value(v, *, dtype=None, device=None):
if not is_quantize(v.dtype):
v = astype(v, dtype)
else:
(v,) = Const(v, dtype=dtype, device=device)()
v = Const(v, dtype, device, None)
return v
......@@ -155,7 +155,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
if ndim != 0 and ndim != 1:
raise ValueError("ndim != 1 or 0, get : %d" % ndim)
if not isinstance(x, (Tensor, SymbolVar)):
(x,) = Const(x, dtype=dtype, device=device)(*reference)
x = Const(x, dtype, device, reference)
return x
if not isinstance(x, collections.abc.Sequence):
......@@ -166,7 +166,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
if dtype is not None:
x = astype(x, dtype)
return x
(x,) = Const(x, dtype=dtype, device=device)(*reference)
x = Const(x, dtype, device, reference)
return x
......@@ -337,7 +337,7 @@ def interpret_subgraph(func, dtype, device):
return results
def apply_const(value, dtype=dtype, device=device):
return Const(value, dtype=dtype, device=device)()[0]
return Const(value, dtype, device, None)
outputs, outputs_has_grad = func(args, apply_expr, apply_const)
outputs = [
......
......@@ -10,10 +10,9 @@ import collections
import math
from typing import Iterable, Optional, Sequence, Tuple, Union
from ..core._imperative_rt.core2 import apply, dtype_promotion
from ..core._imperative_rt.core2 import Const, apply, dtype_promotion
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core.ops import builtin
from ..core.ops.special import Const
from ..core.tensor.array_method import _matmul
from ..core.tensor.utils import _normalize_axis
from ..tensor import Tensor
......@@ -729,7 +728,7 @@ def topk(
op = builtin.TopK(mode=mode)
if not isinstance(k, Tensor):
(k,) = Const(k, dtype="int32", device=inp.device)()
k = Const(k, "int32", inp.device, None)
if len(inp.shape) == 1:
if kth_only:
......
......@@ -11,7 +11,7 @@ from functools import lru_cache
from typing import NamedTuple, Optional, Sequence, Tuple, Union
from ..core import _config
from ..core._imperative_rt.core2 import apply, dtype_promotion
from ..core._imperative_rt.core2 import Const, apply, dtype_promotion
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from ..core.ops import builtin
......@@ -26,7 +26,6 @@ from ..core.ops.builtin import (
Reshape,
TypeCvt,
)
from ..core.ops.special import Const
from ..core.tensor import amp, megbrain_graph
from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.utils import (
......@@ -1317,7 +1316,7 @@ def batch_norm(
raise ValueError("Invalid param_dim {}".format(param_dim))
if x is None:
(x,) = Const(value, dtype=inp.dtype, device=inp.device)()
x = Const(value, inp.dtype, inp.device, None)
shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
(result,) = apply(builtin.Broadcast(), x, shape)
return result
......@@ -1541,7 +1540,7 @@ def sync_batch_norm(
def _make_full_if_none(x, value):
if x is None:
(x,) = Const(value, dtype=inp.dtype, device=_device)()
x = Const(value, inp.dtype, _device, None)
(result,) = apply(builtin.Broadcast(), x, reduce_shape)
return result
elif x.ndim == 1:
......
......@@ -13,6 +13,7 @@ import numpy as np
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import (
Const,
SymbolVar,
apply,
broadcast_cpp,
......@@ -24,7 +25,6 @@ from ..core._imperative_rt.core2 import (
from ..core._wrap import as_device
from ..core.ops import builtin
from ..core.ops.builtin import Copy, Identity
from ..core.ops.special import Const
from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn
from ..device import get_default_device
from ..tensor import Tensor
......@@ -177,7 +177,7 @@ def full(
shape = (shape,)
if device is None:
device = get_default_device()
(x,) = Const(value, dtype=dtype, device=device)()
x = Const(value, dtype, device, None)
if type(shape) in (list, tuple) and len(shape) == 0:
return x
return broadcast_to(x, shape)
......@@ -325,7 +325,7 @@ def full_like(
[2 2 2]]
"""
(x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
x = Const(value, inp.dtype, inp.device, inp)
if inp.ndim == 0:
return x
return broadcast_to(x, inp.shape)
......
from ..core.ops.special import Const
from ..core._imperative_rt.core2 import Const
from ..jit.tracing import is_tracing
small_tensor_cache = {}
......@@ -7,11 +7,11 @@ small_tensor_cache = {}
def _get_scalar_tensor_with_value(value, dtype=None, device=None):
global small_tensor_cache
if is_tracing():
(ret,) = Const(value, dtype=dtype, device=device)()
ret = Const(value, dtype, device, None)
else:
cache_key = (value, dtype, device)
if cache_key not in small_tensor_cache:
(ret,) = Const(value, dtype=dtype, device=device)()
ret = Const(value, dtype, device, None)
small_tensor_cache[cache_key] = ret
else:
ret = small_tensor_cache[cache_key]
......
......@@ -16,6 +16,7 @@ from importlib import import_module
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union
from ..core._imperative_rt import OpDef
from ..core._imperative_rt.core2 import Const
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
apply,
......@@ -25,7 +26,6 @@ from ..core._imperative_rt.core2 import (
unset_module_tracing,
)
from ..core.ops.builtin import FakeQuant
from ..core.ops.special import Const
from ..module import Module
from ..tensor import Parameter, Tensor
from ..version import __version__
......@@ -764,7 +764,7 @@ class Constant(Expr):
def interpret(self, *inputs):
if isinstance(self.value, RawTensor):
return Const(self.value.numpy())()
return (Const(self.value.numpy(), None, None, None),)
return (self.value,)
def __repr__(self):
......
......@@ -639,6 +639,7 @@ WRAP_FUNC_PY35(squeeze_cpp);
WRAP_FUNC_PY35(transpose_cpp);
WRAP_FUNC_PY35(broadcast_cpp);
WRAP_FUNC_PY35(reshape_cpp);
WRAP_FUNC_PY35(Const);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
......@@ -777,6 +778,7 @@ void init_tensor(py::module m) {
MGE_PY_INTERFACE(transpose_cpp, transpose_cpp),
MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
MGE_PY_INTERFACE(Const, Const),
{nullptr, nullptr, 0, nullptr}};
for (auto&& def : method_defs) {
if (def.ml_meth != nullptr) {
......
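With the WRAP_FUNC_PY35 and MGE_PY_INTERFACE entries above, the new builtin is exposed on the core2 extension module, which is what the Python-side hunks switch their imports to. An illustrative import, assuming the standard megengine package layout:

from megengine.core._imperative_rt.core2 import Const  # replaces megengine.core.ops.special.Const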
......@@ -94,7 +94,7 @@ bool is_bool_dtype(PyObject* args) {
}
py::object _Const(
py::handle value, py::handle dtype, py::handle device, py::handle ref) {
py::handle value, py::handle dtype, py::handle device, py::handle ref_hdl) {
py::object val = py::reinterpret_borrow<py::object>(value);
if (PyArray_Check(value.ptr())) {
py::tuple strides =
......@@ -107,21 +107,56 @@ py::object _Const(
}
if (need_squeeze) {
val = py::reinterpret_borrow<py::array>(value);
py::object orig_shp = val.attr("shape");
val = val.attr("squeeze")();
val = val.attr("reshape")(val.attr("shape"));
val = val.attr("reshape")(orig_shp);
}
}
py::object ref;
if (py::isinstance<py::tuple>(ref_hdl)) {
py::tuple tup = py::reinterpret_borrow<py::tuple>(ref_hdl);
if (tup.size()) {
ref = tup[0];
} else {
ref = py::none();
}
} else {
ref = py::reinterpret_borrow<py::object>(ref_hdl);
}
if (py::isinstance<PySymbolVar>(ref)) {
auto ref_var = ref.cast<PySymbolVar*>();
auto* graph = ref_var->m_node->owner_graph();
auto cn = device.cast<CompNode>();
CompNode cn;
if (device.ptr() == Py_None) {
cn = ref_var->m_node->comp_node();
} else {
cn = device.cast<CompNode>();
}
OperatorNodeConfig config(cn);
auto hv = npy::np2tensor(
val.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
auto typeobj = ref.get_type();
return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
}
py::tuple tup = py::make_tuple(val, dtype, device, true, false, py::none());
py::object device_obj;
if (device.ptr() == Py_None) {
device_obj = py::cast(CompNode::load(get_default_device()));
} else if (py::isinstance<py::str>(device)) {
py::object dmap =
getattr(py::reinterpret_borrow<py::object>((PyObject*)py_tensor_type),
"dmap_callback");
if (dmap.ptr() != Py_None) {
device_obj = dmap(device);
py::print(device_obj);
} else {
device_obj = py::cast(CompNode::load(device.cast<std::string>()));
}
} else if (py::isinstance<CompNode>(device)) {
device_obj = py::reinterpret_borrow<py::object>(device);
} else {
device_obj = getattr(device, "_cn");
}
py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none());
return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
}
......@@ -1107,4 +1142,14 @@ PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _Const(py::handle(args[0]), py::handle(args[1]), py::handle(args[2]),
py::handle(args[3]))
.release()
.ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
} // namespace mgb::imperative::python
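The device handling added to _Const above accepts several forms for the device argument. A hedged sketch of the corresponding Python-side calls (the device string is illustrative):

x = Const(1.0, "float32", None, None)    # None falls back to CompNode::load(get_default_device())
y = Const(1.0, "float32", "xpux", None)  # a str goes through dmap_callback if set, else CompNode::load
# A CompNode instance is used as-is; any other device object must expose a _cn attribute.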
......@@ -20,4 +20,6 @@ PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs);
} // namespace mgb::imperative::python
\ No newline at end of file