diff --git a/imperative/python/megengine/core/ops/special.py b/imperative/python/megengine/core/ops/special.py
deleted file mode 100644
index a4a8d693426f3fdb34eef259bf1b44c692145bc7..0000000000000000000000000000000000000000
--- a/imperative/python/megengine/core/ops/special.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import numpy as np
-
-from .._imperative_rt import make_const
-from .._imperative_rt.core2 import SymbolVar, Tensor
-
-
-class Const:
-    def __init__(self, value=None, *, dtype=None, device=None):
-        self.value = np.asarray(value, dtype=dtype)
-        self.dtype = dtype
-        self.device = device
-
-    def __call__(self, *reference):
-        from ...tensor import Tensor
-
-        device = self.device
-
-        if len(reference) != 0:
-            reference = reference[0]
-            assert isinstance(
-                reference, (SymbolVar, Tensor)
-            ), "Reference should be Tensor or VarNode"
-
-            if device is None:
-                device = reference.device
-
-        if isinstance(reference, SymbolVar):
-            cls = type(reference)
-            rst = cls(make_const(reference.graph, self.value, device, self.dtype))
-            return (rst,)
-
-        return (Tensor(self.value, self.dtype, self.device, True),)
diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py
index 43c4bad55899b8bcb04632396326e0d7e051523f..9b14d227a8dd43a49445a9917e720d92d1d48f6c 100644
--- a/imperative/python/megengine/core/tensor/utils.py
+++ b/imperative/python/megengine/core/tensor/utils.py
@@ -14,6 +14,7 @@ import numpy as np
 
 from .._imperative_rt import make_const
 from .._imperative_rt.core2 import (
+    Const,
     SymbolVar,
     Tensor,
     _get_convert_inputs,
@@ -28,7 +29,6 @@ from .._imperative_rt.ops import jit_supported
 from .._wrap import as_device
 from ..autodiff.grad import Function
 from ..ops import builtin
-from ..ops.special import Const
 from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype
 from .dtype import is_dtype_equal, is_quantize
 
@@ -67,7 +67,7 @@ def convert_single_value(v, *, dtype=None, device=None):
         if not is_quantize(v.dtype):
             v = astype(v, dtype)
     else:
-        (v,) = Const(v, dtype=dtype, device=device)()
+        v = Const(v, dtype, device, None)
     return v
 
 
@@ -155,7 +155,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
         if ndim != 0 and ndim != 1:
             raise ValueError("ndim != 1 or 0, get : %d" % ndim)
         if not isinstance(x, (Tensor, SymbolVar)):
-            (x,) = Const(x, dtype=dtype, device=device)(*reference)
+            x = Const(x, dtype, device, reference)
         return x
 
     if not isinstance(x, collections.abc.Sequence):
@@ -166,7 +166,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
         if dtype is not None:
             x = astype(x, dtype)
         return x
-    (x,) = Const(x, dtype=dtype, device=device)(*reference)
+    x = Const(x, dtype, device, reference)
     return x
 
 
@@ -337,7 +337,7 @@ def interpret_subgraph(func, dtype, device):
         return results
 
     def apply_const(value, dtype=dtype, device=device):
-        return Const(value, dtype=dtype, device=device)()[0]
+        return Const(value, dtype, device, None)
 
     outputs, outputs_has_grad = func(args, apply_expr, apply_const)
     outputs = [
diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py
index f1cc41c706b1f2ee5359ceb6e5de96e7f95fd716..8ea91c5eae5da8ae1da562c4c9d87013665aebb0 100644
--- a/imperative/python/megengine/functional/math.py
+++ b/imperative/python/megengine/functional/math.py
@@ -10,10 +10,9 @@ import collections
 import math
 from typing import Iterable, Optional, Sequence, Tuple, Union
 
-from ..core._imperative_rt.core2 import apply, dtype_promotion
+from ..core._imperative_rt.core2 import Const, apply, dtype_promotion
 from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from ..core.ops import builtin
-from ..core.ops.special import Const
 from ..core.tensor.array_method import _matmul
 from ..core.tensor.utils import _normalize_axis
 from ..tensor import Tensor
@@ -729,7 +728,7 @@ def topk(
     op = builtin.TopK(mode=mode)
 
     if not isinstance(k, Tensor):
-        (k,) = Const(k, dtype="int32", device=inp.device)()
+        k = Const(k, "int32", inp.device, None)
 
     if len(inp.shape) == 1:
         if kth_only:
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index f0d151054e70b0026c24efea3d3e1d298bb5d2fb..3110aa08a5c4634557d2801b0568bf6b5674f01d 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -11,7 +11,7 @@ from functools import lru_cache
 from typing import NamedTuple, Optional, Sequence, Tuple, Union
 
 from ..core import _config
-from ..core._imperative_rt.core2 import apply, dtype_promotion
+from ..core._imperative_rt.core2 import Const, apply, dtype_promotion
 from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
 from ..core.ops import builtin
@@ -26,7 +26,6 @@ from ..core.ops.builtin import (
     Reshape,
     TypeCvt,
 )
-from ..core.ops.special import Const
 from ..core.tensor import amp, megbrain_graph
 from ..core.tensor.array_method import _elwise_apply
 from ..core.tensor.utils import (
@@ -1317,7 +1316,7 @@ def batch_norm(
             raise ValueError("Invalid param_dim {}".format(param_dim))
 
         if x is None:
-            (x,) = Const(value, dtype=inp.dtype, device=inp.device)()
+            x = Const(value, inp.dtype, inp.device, None)
             shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Broadcast(), x, shape)
             return result
@@ -1541,7 +1540,7 @@ def sync_batch_norm(
 
     def _make_full_if_none(x, value):
         if x is None:
-            (x,) = Const(value, dtype=inp.dtype, device=_device)()
+            x = Const(value, inp.dtype, _device, None)
             (result,) = apply(builtin.Broadcast(), x, reduce_shape)
             return result
         elif x.ndim == 1:
diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py
index 9f4134d7f19f83bb03515ab651a22303a13e123a..011e55a54b2c7f0c9c4f3fe78ffd54a41da2ea97 100755
--- a/imperative/python/megengine/functional/tensor.py
+++ b/imperative/python/megengine/functional/tensor.py
@@ -13,6 +13,7 @@ import numpy as np
 
 from ..core._imperative_rt import CompNode
 from ..core._imperative_rt.core2 import (
+    Const,
     SymbolVar,
     apply,
     broadcast_cpp,
@@ -24,7 +25,6 @@ from ..core._imperative_rt.core2 import (
 from ..core._wrap import as_device
 from ..core.ops import builtin
 from ..core.ops.builtin import Copy, Identity
-from ..core.ops.special import Const
 from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn
 from ..device import get_default_device
 from ..tensor import Tensor
@@ -177,7 +177,7 @@ def full(
         shape = (shape,)
     if device is None:
         device = get_default_device()
-    (x,) = Const(value, dtype=dtype, device=device)()
+    x = Const(value, dtype, device, None)
     if type(shape) in (list, tuple) and len(shape) == 0:
         return x
     return broadcast_to(x, shape)
@@ -325,7 +325,7 @@ def full_like(
          [2 2 2]]
 
     """
-    (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
+    x = Const(value, inp.dtype, inp.device, inp)
     if inp.ndim == 0:
         return x
     return broadcast_to(x, inp.shape)
diff --git a/imperative/python/megengine/functional/tensor_cache.py b/imperative/python/megengine/functional/tensor_cache.py
index 582be4ad458915668e90ef778afe33681b50aed5..50846839573254bf31169f879411255211705d17 100644
--- a/imperative/python/megengine/functional/tensor_cache.py
+++ b/imperative/python/megengine/functional/tensor_cache.py
@@ -1,4 +1,4 @@
-from ..core.ops.special import Const
+from ..core._imperative_rt.core2 import Const
 from ..jit.tracing import is_tracing
 
 small_tensor_cache = {}
@@ -7,11 +7,11 @@ small_tensor_cache = {}
 def _get_scalar_tensor_with_value(value, dtype=None, device=None):
     global small_tensor_cache
     if is_tracing():
-        (ret,) = Const(value, dtype=dtype, device=device)()
+        ret = Const(value, dtype, device, None)
     else:
         cache_key = (value, dtype, device)
         if cache_key not in small_tensor_cache:
-            (ret,) = Const(value, dtype=dtype, device=device)()
+            ret = Const(value, dtype, device, None)
             small_tensor_cache[cache_key] = ret
         else:
             ret = small_tensor_cache[cache_key]
diff --git a/imperative/python/megengine/traced_module/expr.py b/imperative/python/megengine/traced_module/expr.py
index 19624482d9f585a3d8fbc1276612cb46b1cc27cb..84ff4317ab67ddbf05a0db4c1730815571eddc8f 100644
--- a/imperative/python/megengine/traced_module/expr.py
+++ b/imperative/python/megengine/traced_module/expr.py
@@ -16,6 +16,7 @@ from importlib import import_module
 from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union
 
 from ..core._imperative_rt import OpDef
+from ..core._imperative_rt.core2 import Const
 from ..core._imperative_rt.core2 import Tensor as RawTensor
 from ..core._imperative_rt.core2 import (
     apply,
@@ -25,7 +26,6 @@ from ..core._imperative_rt.core2 import (
     unset_module_tracing,
 )
 from ..core.ops.builtin import FakeQuant
-from ..core.ops.special import Const
 from ..module import Module
 from ..tensor import Parameter, Tensor
 from ..version import __version__
@@ -764,7 +764,7 @@ class Constant(Expr):
 
     def interpret(self, *inputs):
         if isinstance(self.value, RawTensor):
-            return Const(self.value.numpy())()
+            return (Const(self.value.numpy(), None, None, None),)
         return (self.value,)
 
     def __repr__(self):
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index 86f004aeac2dfa2f193b5ea20a06a5107311d757..e231dbbeb516c6e1a974db6cbec2799ceee5715f 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -639,6 +639,7 @@ WRAP_FUNC_PY35(squeeze_cpp);
 WRAP_FUNC_PY35(transpose_cpp);
 WRAP_FUNC_PY35(broadcast_cpp);
 WRAP_FUNC_PY35(reshape_cpp);
+WRAP_FUNC_PY35(Const);
 #undef WRAP_FUNC_PY35
 #define MGE_PY_INTERFACE(NAME, FUNC) \
     { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
@@ -777,6 +778,7 @@ void init_tensor(py::module m) {
             MGE_PY_INTERFACE(transpose_cpp, transpose_cpp),
             MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
             MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
+            MGE_PY_INTERFACE(Const, Const),
             {nullptr, nullptr, 0, nullptr}};
     for (auto&& def : method_defs) {
         if (def.ml_meth != nullptr) {
diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp
index 2b0e3687b1b3af3c3ba955e7dfb06eb1444ebdd3..17f2f391a7c4911b586f8cfd4a273976b3420acb 100644
--- a/imperative/python/src/tensor_utils.cpp
+++ b/imperative/python/src/tensor_utils.cpp
@@ -94,7 +94,7 @@ bool is_bool_dtype(PyObject* args) {
 }
 
 py::object _Const(
-        py::handle value, py::handle dtype, py::handle device, py::handle ref) {
+        py::handle value, py::handle dtype, py::handle device, py::handle ref_hdl) {
     py::object val = py::reinterpret_borrow<py::object>(value);
     if (PyArray_Check(value.ptr())) {
         py::tuple strides =
@@ -107,21 +107,56 @@ py::object _Const(
         }
         if (need_squeeze) {
             val = py::reinterpret_borrow<py::object>(value);
+            py::object orig_shp = val.attr("shape");
             val = val.attr("squeeze")();
-            val = val.attr("reshape")(val.attr("shape"));
+            val = val.attr("reshape")(orig_shp);
         }
     }
+    py::object ref;
+    if (py::isinstance<py::tuple>(ref_hdl)) {
+        py::tuple tup = py::reinterpret_borrow<py::tuple>(ref_hdl);
+        if (tup.size()) {
+            ref = tup[0];
+        } else {
+            ref = py::none();
+        }
+    } else {
+        ref = py::reinterpret_borrow<py::object>(ref_hdl);
+    }
     if (py::isinstance<PySymbolVar>(ref)) {
         auto ref_var = ref.cast<PySymbolVar*>();
         auto* graph = ref_var->m_node->owner_graph();
-        auto cn = device.cast<CompNode>();
+        CompNode cn;
+        if (device.ptr() == Py_None) {
+            cn = ref_var->m_node->comp_node();
+        } else {
+            cn = device.cast<CompNode>();
+        }
         OperatorNodeConfig config(cn);
         auto hv = npy::np2tensor(
                 val.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
         auto typeobj = ref.get_type();
         return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
     }
-    py::tuple tup = py::make_tuple(val, dtype, device, true, false, py::none());
+    py::object device_obj;
+    if (device.ptr() == Py_None) {
+        device_obj = py::cast(CompNode::load(get_default_device()));
+    } else if (py::isinstance<py::str>(device)) {
+        py::object dmap =
+                getattr(py::reinterpret_borrow<py::object>((PyObject*)py_tensor_type),
+                        "dmap_callback");
+        if (dmap.ptr() != Py_None) {
+            device_obj = dmap(device);
+            py::print(device_obj);
+        } else {
+            device_obj = py::cast(CompNode::load(device.cast<std::string>()));
+        }
+    } else if (py::isinstance<CompNode>(device)) {
+        device_obj = py::reinterpret_borrow<py::object>(device);
+    } else {
+        device_obj = getattr(device, "_cn");
+    }
+    py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none());
     return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
 }
@@ -1107,4 +1142,14 @@ PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
     PYEXT17_TRANSLATE_EXC_RET(nullptr)
 }
 
+PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) {
+    try {
+        return _Const(py::handle(args[0]), py::handle(args[1]), py::handle(args[2]),
+                      py::handle(args[3]))
+                .release()
+                .ptr();
+    }
+    PYEXT17_TRANSLATE_EXC_RET(nullptr)
+}
+
 } // namespace mgb::imperative::python
diff --git a/imperative/python/src/tensor_utils.h b/imperative/python/src/tensor_utils.h
index 4c721ff18d205703745edf77dbc457b26033274c..906080043d9aeef9294093d6482068086a4b02c4 100644
--- a/imperative/python/src/tensor_utils.h
+++ b/imperative/python/src/tensor_utils.h
@@ -20,4 +20,6 @@ PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs);
 
 PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs);
 
+PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs);
+
 } // namespace mgb::imperative::python
\ No newline at end of file
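Note on the calling-convention change visible throughout this patch: the deleted Python helper `Const` in `megengine/core/ops/special.py` was a class that was first constructed with keyword arguments and then called with an optional reference tensor, returning a one-element tuple, whereas the replacement `Const` exported from `megengine.core._imperative_rt.core2` (backed by `_Const` in `tensor_utils.cpp`) is a plain function taking `(value, dtype, device, ref)` positionally and returning the tensor directly. A minimal before/after sketch of a call site, for illustration only (`value` stands for any Python scalar or array, `inp` for an existing reference Tensor; neither is defined by this patch):

    # old helper (megengine.core.ops.special): construct, call, unpack the 1-tuple
    (x,) = Const(value, dtype="float32", device=inp.device)(inp)

    # new C++-backed Const (megengine.core._imperative_rt.core2): positional
    # (value, dtype, device, ref); ref may be None, a Tensor/SymbolVar, or a tuple of them
    x = Const(value, "float32", inp.device, inp)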