Commit a926878c authored by Megvii Engine Team

feat(imperative): remove symbolvar of imperative

GitOrigin-RevId: 16da6d1491526b707ea6851fb68e330c02cc788a
Parent 14813d13
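In short: the Python-level `SymbolVar` type is removed from `megengine.core._imperative_rt.core2`; graph variables are now handled through `Tensor` and the `VarNode` subclass registered via `set_py_varnode_type`, and the internal `Const` helper is called as `Const(value, dtype, device)` without the old trailing reference argument. A minimal, hedged sketch of the user-visible effect, assuming a megengine build that contains this commit (illustration only, not part of the diff):

```python
# Illustrative sketch only (assumes megengine with this commit installed).
# isinstance checks and annotations that used to read (Tensor, SymbolVar)
# now use Tensor alone, e.g. zeros_like/full_like are typed Tensor -> Tensor.
import numpy as np
import megengine.functional as F
from megengine import Tensor

x = Tensor(np.arange(6, dtype="float32").reshape(2, 3))
y = F.zeros_like(x)      # (inp: Tensor) -> Tensor after this change
z = F.full_like(x, 2.0)  # Union[Tensor, SymbolVar] no longer appears in the API
print(y.numpy(), z.numpy())
```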
...@@ -7,9 +7,7 @@ from typing import Union ...@@ -7,9 +7,7 @@ from typing import Union
import numpy as np import numpy as np
from .. import _config from .. import _config
from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import ( from .._imperative_rt.core2 import (
SymbolVar,
Tensor, Tensor,
apply, apply,
astype_cpp, astype_cpp,
...@@ -17,9 +15,11 @@ from .._imperative_rt.core2 import ( ...@@ -17,9 +15,11 @@ from .._imperative_rt.core2 import (
broadcast_cpp, broadcast_cpp,
getitem_cpp, getitem_cpp,
matmul_cpp, matmul_cpp,
reshape_cpp,
setitem_cpp,
squeeze_cpp,
transpose_cpp,
) )
from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar
from .._imperative_rt.core2 import reshape_cpp, setitem_cpp, squeeze_cpp, transpose_cpp
from ..ops import builtin from ..ops import builtin
from . import amp from . import amp
from .utils import _normalize_axis, astensor1d, cast_tensors, make_shape_tuple, subgraph from .utils import _normalize_axis, astensor1d, cast_tensors, make_shape_tuple, subgraph
...@@ -189,9 +189,7 @@ def _todo(*_): ...@@ -189,9 +189,7 @@ def _todo(*_):
def _expand_args(args): def _expand_args(args):
if len(args) == 1: if len(args) == 1:
-        if isinstance(
-            args[0], (collections.abc.Sequence, Tensor, SymbolVar, np.ndarray),
-        ):
+        if isinstance(args[0], (collections.abc.Sequence, Tensor, np.ndarray),):
            args = args[0]
    return args
......
...@@ -8,7 +8,6 @@ import numpy as np ...@@ -8,7 +8,6 @@ import numpy as np
from .._imperative_rt import make_const from .._imperative_rt import make_const
from .._imperative_rt.core2 import ( from .._imperative_rt.core2 import (
Const, Const,
SymbolVar,
Tensor, Tensor,
_get_convert_inputs, _get_convert_inputs,
_set_convert_inputs, _set_convert_inputs,
...@@ -77,7 +76,7 @@ def result_type(*args): ...@@ -77,7 +76,7 @@ def result_type(*args):
def isscalar(x): def isscalar(x):
if isinstance(x, (Tensor, SymbolVar)): if isinstance(x, Tensor):
return x._isscalar() return x._isscalar()
return np.isscalar(x) return np.isscalar(x)
...@@ -283,7 +282,7 @@ def interpret_subgraph(func, dtype, device): ...@@ -283,7 +282,7 @@ def interpret_subgraph(func, dtype, device):
return results return results
def apply_const(value, dtype=dtype, device=device): def apply_const(value, dtype=dtype, device=device):
return Const(value, dtype, device, None) return Const(value, dtype, device)
outputs, outputs_has_grad = func(args, apply_expr, apply_const) outputs, outputs_has_grad = func(args, apply_expr, apply_const)
outputs = [ outputs = [
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
import numpy as np import numpy as np
from ..core._imperative_rt.core2 import SymbolVar, apply from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin from ..core.ops import builtin
from ..core.ops.builtin import Elemwise from ..core.ops.builtin import Elemwise
from ..core.tensor.array_method import _elwise from ..core.tensor.array_method import _elwise
......
...@@ -538,7 +538,7 @@ def topk( ...@@ -538,7 +538,7 @@ def topk(
op = builtin.TopK(mode=mode) op = builtin.TopK(mode=mode)
if not isinstance(k, Tensor): if not isinstance(k, Tensor):
k = Const(k, "int32", inp.device, None) k = Const(k, "int32", inp.device)
if len(inp.shape) == 1: if len(inp.shape) == 1:
if kth_only: if kth_only:
......
...@@ -1222,7 +1222,7 @@ def batch_norm( ...@@ -1222,7 +1222,7 @@ def batch_norm(
raise ValueError("Invalid param_dim {}".format(param_dim)) raise ValueError("Invalid param_dim {}".format(param_dim))
if x is None: if x is None:
x = Const(value, inp.dtype, inp.device, None) x = Const(value, inp.dtype, inp.device)
shape = astensor1d(pshape, inp, dtype="int32", device=inp.device) shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
(result,) = apply(builtin.Broadcast(), x, shape) (result,) = apply(builtin.Broadcast(), x, shape)
return result return result
...@@ -1446,7 +1446,7 @@ def sync_batch_norm( ...@@ -1446,7 +1446,7 @@ def sync_batch_norm(
def _make_full_if_none(x, value): def _make_full_if_none(x, value):
if x is None: if x is None:
x = Const(value, inp.dtype, _device, None) x = Const(value, inp.dtype, _device)
(result,) = apply(builtin.Broadcast(), x, reduce_shape) (result,) = apply(builtin.Broadcast(), x, reduce_shape)
return result return result
elif x.ndim == 1: elif x.ndim == 1:
......
...@@ -7,7 +7,6 @@ import numpy as np ...@@ -7,7 +7,6 @@ import numpy as np
from ..core._imperative_rt import CompNode from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import ( from ..core._imperative_rt.core2 import (
Const, Const,
SymbolVar,
apply, apply,
broadcast_cpp, broadcast_cpp,
dtype_promotion, dtype_promotion,
...@@ -151,7 +150,7 @@ def full( ...@@ -151,7 +150,7 @@ def full(
shape = (shape,) shape = (shape,)
if device is None: if device is None:
device = get_default_device() device = get_default_device()
x = Const(value, dtype, device, None) x = Const(value, dtype, device)
if type(shape) in (list, tuple) and len(shape) == 0: if type(shape) in (list, tuple) and len(shape) == 0:
return x return x
return broadcast_to(x, shape) return broadcast_to(x, shape)
...@@ -216,7 +215,7 @@ def zeros( ...@@ -216,7 +215,7 @@ def zeros(
return full(shape, 0.0, dtype=dtype, device=device) return full(shape, 0.0, dtype=dtype, device=device)
def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: def zeros_like(inp: Tensor) -> Tensor:
r"""Returns a tensor filled with zeros with the same shape and data type as input tensor. r"""Returns a tensor filled with zeros with the same shape and data type as input tensor.
Args: Args:
...@@ -235,7 +234,7 @@ def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: ...@@ -235,7 +234,7 @@ def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
return full_like(inp, 0.0) return full_like(inp, 0.0)
def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: def ones_like(inp: Tensor) -> Tensor:
r"""Returns a tensor filled with ones with the same shape and data type as input tensor. r"""Returns a tensor filled with ones with the same shape and data type as input tensor.
Args: Args:
...@@ -253,9 +252,7 @@ def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: ...@@ -253,9 +252,7 @@ def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
return full_like(inp, 1.0) return full_like(inp, 1.0)
-def full_like(
-    inp: Union[Tensor, SymbolVar], value: Union[int, float]
-) -> Union[Tensor, SymbolVar]:
+def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
r"""Returns a tensor filled with given value with the same shape as input tensor. r"""Returns a tensor filled with given value with the same shape as input tensor.
Args: Args:
...@@ -272,7 +269,7 @@ def full_like( ...@@ -272,7 +269,7 @@ def full_like(
Tensor([[2 2 2] Tensor([[2 2 2]
[2 2 2]], dtype=int32, device=xpux:0) [2 2 2]], dtype=int32, device=xpux:0)
""" """
x = Const(value, inp.dtype, inp.device, inp) x = Const(value, inp.dtype, inp.device)
if inp.ndim == 0: if inp.ndim == 0:
return x return x
return broadcast_to(x, inp.shape) return broadcast_to(x, inp.shape)
...@@ -668,9 +665,9 @@ def cond_take(mask: Tensor, x: Tensor) -> Tensor: ...@@ -668,9 +665,9 @@ def cond_take(mask: Tensor, x: Tensor) -> Tensor:
>>> print(v.numpy(), index.numpy()) >>> print(v.numpy(), index.numpy())
[1. 4.] [0 3] [1. 4.] [0 3]
""" """
if not isinstance(x, (Tensor, SymbolVar)): if not isinstance(x, Tensor):
raise TypeError("input must be a tensor") raise TypeError("input must be a tensor")
if not isinstance(mask, (Tensor, SymbolVar)): if not isinstance(mask, Tensor):
raise TypeError("mask must be a tensor") raise TypeError("mask must be a tensor")
if mask.dtype != np.bool_: if mask.dtype != np.bool_:
raise ValueError("mask must be bool") raise ValueError("mask must be bool")
...@@ -843,15 +840,11 @@ def linspace( ...@@ -843,15 +840,11 @@ def linspace(
if not (cur_device is None or device == cur_device): if not (cur_device is None or device == cur_device):
raise ("ambiguous device for linspace opr") raise ("ambiguous device for linspace opr")
-    is_symbolvar = list(isinstance(x, SymbolVar) for x in [start, stop, num])
-    if any(is_symbolvar) and not all(is_symbolvar):
-        raise TypeError("start, stop and num should all be VarNode or none of them")
-    if not isinstance(start, (Tensor, SymbolVar)):
+    if not isinstance(start, Tensor):
        start = Tensor(start, device=device)
-    if not isinstance(stop, (Tensor, SymbolVar)):
+    if not isinstance(stop, Tensor):
        stop = Tensor(stop, device=device)
-    if not isinstance(num, (Tensor, SymbolVar)):
+    if not isinstance(num, Tensor):
        num = Tensor(num, device=device)
op = builtin.Linspace(comp_node=device) op = builtin.Linspace(comp_node=device)
...@@ -901,8 +894,11 @@ def arange( ...@@ -901,8 +894,11 @@ def arange(
if stop is None: if stop is None:
start, stop = 0, start start, stop = 0, start
if not isinstance(start, Tensor):
start = Tensor(start, dtype="float32") start = Tensor(start, dtype="float32")
if not isinstance(stop, Tensor):
stop = Tensor(stop, dtype="float32") stop = Tensor(stop, dtype="float32")
if not isinstance(step, Tensor):
step = Tensor(step, dtype="float32") step = Tensor(step, dtype="float32")
num = ceil((stop - start) / step) num = ceil((stop - start) / step)
......
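The hunks above change `linspace` and `arange` to wrap their arguments into `Tensor` only when they are not tensors already, instead of special-casing `SymbolVar`. A small hedged sketch of the resulting behavior (assumes megengine is installed; illustration only):

```python
# Hedged illustration: plain numbers are still wrapped into float32 Tensors,
# while Tensor arguments pass through the new isinstance guards unchanged.
import megengine.functional as F
from megengine import Tensor

print(F.arange(5).numpy())                  # 0, 1, 2, 3, 4
print(F.arange(Tensor(0.0), 5, 1).numpy())  # a Tensor start is used as-is
```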
...@@ -7,11 +7,11 @@ small_tensor_cache = {} ...@@ -7,11 +7,11 @@ small_tensor_cache = {}
def _get_scalar_tensor_with_value(value, dtype=None, device=None): def _get_scalar_tensor_with_value(value, dtype=None, device=None):
global small_tensor_cache global small_tensor_cache
if is_tracing(): if is_tracing():
ret = Const(value, dtype, device, None) ret = Const(value, dtype, device)
else: else:
cache_key = (value, dtype, device) cache_key = (value, dtype, device)
if cache_key not in small_tensor_cache: if cache_key not in small_tensor_cache:
ret = Const(value, dtype, device, None) ret = Const(value, dtype, device)
small_tensor_cache[cache_key] = ret small_tensor_cache[cache_key] = ret
else: else:
ret = small_tensor_cache[cache_key] ret = small_tensor_cache[cache_key]
......
...@@ -154,6 +154,8 @@ class Tensor(_Tensor, ArrayMethodMixin): ...@@ -154,6 +154,8 @@ class Tensor(_Tensor, ArrayMethodMixin):
@name.setter @name.setter
def name(self, name): def name(self, name):
self._custom_name = name self._custom_name = name
if name == None:
name = ""
self._name = self._prefix + "." + name if self._prefix else name self._name = self._prefix + "." + name if self._prefix else name
self._set_name(self._name) self._set_name(self._name)
......
...@@ -756,7 +756,7 @@ class Constant(Expr): ...@@ -756,7 +756,7 @@ class Constant(Expr):
def interpret(self, *inputs): def interpret(self, *inputs):
if isinstance(self.value, RawTensor): if isinstance(self.value, RawTensor):
return (Const(self.value.numpy(), None, None, None),) return (Const(self.value.numpy(), None, None),)
return (self.value,) return (self.value,)
def __repr__(self): def __repr__(self):
......
...@@ -395,7 +395,7 @@ class Network: ...@@ -395,7 +395,7 @@ class Network:
for ind, var in enumerate(opr.outputs): for ind, var in enumerate(opr.outputs):
var.owner = repl_dict[opr] var.owner = repl_dict[opr]
var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) var.__dict__.update(repl_dict[opr].outputs[ind].__dict__)
var.var = repl_dict[opr].outputs[ind].var var._reset_var(repl_dict[opr].outputs[ind].var)
repl_dict[opr].outputs = opr.outputs repl_dict[opr].outputs = opr.outputs
self._compile() self._compile()
......
...@@ -6,11 +6,11 @@ from typing import Sequence ...@@ -6,11 +6,11 @@ from typing import Sequence
import numpy as np import numpy as np
from ..core import _imperative_rt as rt from ..core import _imperative_rt as rt
from ..core._imperative_rt.core2 import SymbolVar, apply from ..core._imperative_rt.core2 import apply, set_py_varnode_type
from ..core._trace_option import use_symbolic_shape from ..core._trace_option import use_symbolic_shape
from ..core._wrap import Device from ..core._wrap import Device
from ..core.ops import builtin from ..core.ops import builtin
from ..core.tensor.array_method import ArrayMethodMixin from ..tensor import Tensor
from .comp_graph_tools import replace_vars from .comp_graph_tools import replace_vars
from .module_stats import ( from .module_stats import (
preprocess_receptive_field, preprocess_receptive_field,
...@@ -23,26 +23,72 @@ class NetworkNode: ...@@ -23,26 +23,72 @@ class NetworkNode:
pass pass
-class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
-    pass
-
-
-class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
-    def __init__(self, var=None, *, owner_opr=None, name=None):
-        SymbolVar.__init__(self, var)
-        self.users = []  # List[OpNode]
-        self.owner = owner_opr
+class VarNode(NetworkNode, Tensor):
+    _users = None
+    _owner = None
+    _name = None
+    _id = None
+
+    def __new__(cls, var, *, owner_opr=None, name=None):
+        obj = Tensor.__new__(cls, var)
+        return obj
+
+    def __init__(self, var, *, owner_opr=None, name=None):
+        self._owner = owner_opr
        self.name = name
-        self.id = id(self)

    @classmethod
    def load(cls, sym_var, owner_opr):
-        obj = cls()
+        obj = cls(sym_var)
        obj.var = sym_var  # mgb varnode
        obj.name = sym_var.name
        obj.owner = owner_opr
        return obj
@property
def users(self):
if self._users is None:
self._users = []
return self._users
@property
def owner(self):
return self._owner
@owner.setter
def owner(self, owner):
self._owner = owner
@property
def id(self):
if self._id is None:
self._id = id(self)
return self._id
@property
def var(self):
return super().var()
@var.setter
def var(self, var):
self._reset(var)
def _reset(self, other):
if not isinstance(other, Tensor):
other = VarNode(other)
super()._reset(other)
self.owner = None
def _reset_var(self, var):
origin_owner = self.owner
self.var = var
self.var.name = self.name
self.owner = origin_owner
@property
def graph(self):
return super().graph()
def _get_var_shape(self, axis=None): def _get_var_shape(self, axis=None):
opdef = ( opdef = (
builtin.GetVarShape() if axis is None else builtin.GetVarShape(axis=axis) builtin.GetVarShape() if axis is None else builtin.GetVarShape(axis=axis)
...@@ -77,14 +123,6 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): ...@@ -77,14 +123,6 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
return rst return rst
return self._get_var_shape() if self.var else None return self._get_var_shape() if self.var else None
@property
def dtype(self):
return self.var.dtype if self.var else None
@property
def ndim(self):
return super().ndim
def __bool__(self): def __bool__(self):
return False return False
...@@ -92,27 +130,11 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): ...@@ -92,27 +130,11 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
__int__ = None __int__ = None
__float__ = None __float__ = None
__complex__ = None __complex__ = None
__repr__ = lambda self: "VarNode:" + self.name
def __hash__(self): def __hash__(self):
return id(self) return id(self)
def numpy(self):
return super().numpy()
def _reset(self, other):
if not isinstance(other, VarNode):
assert self.graph, "VarNode _reset must have graph"
node = ImmutableTensor(other, graph=self.graph)
node.compile(self.graph)
other = node.outputs[0]
if self.owner is not None:
idx = self.owner.outputs.index(self)
self.owner.outputs[idx] = VarNode(
self.var, owner_opr=self.owner, name=self.var.name
)
self.var = other.var
self.owner = None
def set_owner_opr(self, owner_opr): def set_owner_opr(self, owner_opr):
self.owner = owner_opr self.owner = owner_opr
...@@ -158,8 +180,7 @@ class OpNode(NetworkNode): ...@@ -158,8 +180,7 @@ class OpNode(NetworkNode):
assert len(outputs) == len(self.outputs) assert len(outputs) == len(self.outputs)
self._opr = outputs[0].owner self._opr = outputs[0].owner
for i in range(len(self.outputs)): for i in range(len(self.outputs)):
self.outputs[i].var = outputs[i] self.outputs[i]._reset_var(outputs[i])
self.outputs[i].var.name = self.outputs[i].name
assert self.outputs[i].owner is self assert self.outputs[i].owner is self
def add_inp_var(self, x): def add_inp_var(self, x):
...@@ -214,8 +235,9 @@ class Host2DeviceCopy(OpNode): ...@@ -214,8 +235,9 @@ class Host2DeviceCopy(OpNode):
outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
self._opr = outputs.owner self._opr = outputs.owner
if len(self.outputs) == 0: if len(self.outputs) == 0:
-            self.outputs.append(VarNode(owner_opr=self, name=self.name))
-        self.outputs[0].var = outputs
+            self.outputs.append(VarNode(outputs, owner_opr=self, name=self.name))
+        else:
+            self.outputs[0]._reset_var(outputs)
assert self.outputs[0].owner is self assert self.outputs[0].owner is self
...@@ -262,8 +284,9 @@ class ConstOpBase(OpNode): ...@@ -262,8 +284,9 @@ class ConstOpBase(OpNode):
data = data.astype(np.int32) data = data.astype(np.int32)
varnode = type(self).rt_fun(self.graph, data, cn, data.dtype, self.name) varnode = type(self).rt_fun(self.graph, data, cn, data.dtype, self.name)
if len(self.outputs) == 0: if len(self.outputs) == 0:
-            self.outputs.append(VarNode(owner_opr=self, name=self.name))
-        self.outputs[0].var = varnode
+            self.outputs.append(VarNode(varnode, owner_opr=self, name=self.name))
+        else:
+            self.outputs[0]._reset_var(varnode)
self._opr = varnode.owner self._opr = varnode.owner
@classmethod @classmethod
...@@ -313,7 +336,7 @@ class ReadOnlyOpNode(OpNode): ...@@ -313,7 +336,7 @@ class ReadOnlyOpNode(OpNode):
if bool(repl_dict): if bool(repl_dict):
out_vars = replace_vars(self._opr.outputs, repl_dict) out_vars = replace_vars(self._opr.outputs, repl_dict)
for ind, o in enumerate(self.outputs): for ind, o in enumerate(self.outputs):
o.var = out_vars[ind] o._reset_var(out_vars[ind])
class Elemwise(OpNode): class Elemwise(OpNode):
...@@ -785,3 +808,6 @@ class AssertEqual(OpNode): ...@@ -785,3 +808,6 @@ class AssertEqual(OpNode):
class CvtColorForward(OpNode): class CvtColorForward(OpNode):
type = "CvtColor" type = "CvtColor"
opdef = builtin.CvtColor opdef = builtin.CvtColor
set_py_varnode_type(VarNode)
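Since `VarNode` now derives from `Tensor` and is handed to the C++ side via `set_py_varnode_type(VarNode)`, functional ops and Python operators applied to graph variables go through the ordinary `py_apply` path and return `VarNode` outputs bound to the same graph. A hedged usage sketch, assuming a previously dumped model file `model.mge` containing variables named `a` and `b` (hypothetical file and variable names; the calls mirror the tests further below):

```python
# Hedged sketch: graph surgery on a loaded Network, with VarNode behaving as a Tensor.
# "model.mge", "a" and "b" are assumed/hypothetical; the API calls follow the tests.
import megengine.functional as F
from megengine.utils.network import Network

net = Network.load("model.mge")              # assumed pre-dumped model
vara = net.var_filter.name("a").as_unique()  # a VarNode, which is now also a Tensor
varb = net.var_filter.name("b").as_unique()
out = F.relu(F.mul(vara, varb))              # dispatches through the same py_apply
out = (out + 2) * 3                          # plain scalars are lifted into the graph
print(type(out).__name__, out.var, out.graph)
```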
...@@ -114,6 +114,8 @@ void _set_priority_to_id(const std::vector<mgb::cg::VarNode*>& dest_vars) { ...@@ -114,6 +114,8 @@ void _set_priority_to_id(const std::vector<mgb::cg::VarNode*>& dest_vars) {
} }
} }
py::object Py_Varnode = py::none();
void init_graph_rt(py::module m) { void init_graph_rt(py::module m) {
static const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr{ static const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr{
std::make_unique<mgb::OprFootprint>()}; std::make_unique<mgb::OprFootprint>()};
...@@ -124,6 +126,7 @@ void init_graph_rt(py::module m) { ...@@ -124,6 +126,7 @@ void init_graph_rt(py::module m) {
def_rendezvous<TensorAttr>(m, "TensorAttrRendezvous"); def_rendezvous<TensorAttr>(m, "TensorAttrRendezvous");
Py_Varnode =
py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode") py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode")
.def_property_readonly( .def_property_readonly(
"owner", [](cg::VarNode* v) { return v->owner_opr(); }) "owner", [](cg::VarNode* v) { return v->owner_opr(); })
...@@ -132,7 +135,8 @@ void init_graph_rt(py::module m) { ...@@ -132,7 +135,8 @@ void init_graph_rt(py::module m) {
.def_property( .def_property(
"name", py::overload_cast<>(&VarNode::name, py::const_), "name", py::overload_cast<>(&VarNode::name, py::const_),
py::overload_cast<std::string>(&VarNode::name)) py::overload_cast<std::string>(&VarNode::name))
.def_property_readonly("dtype", [](cg::VarNode* v) { return v->dtype(); }) .def_property_readonly(
"dtype", [](cg::VarNode* v) { return v->dtype(); })
.def_property_readonly( .def_property_readonly(
"comp_node", [](cg::VarNode* v) { return v->comp_node(); }) "comp_node", [](cg::VarNode* v) { return v->comp_node(); })
.def_property_readonly( .def_property_readonly(
...@@ -147,7 +151,8 @@ void init_graph_rt(py::module m) { ...@@ -147,7 +151,8 @@ void init_graph_rt(py::module m) {
auto&& mgr = v->owner_graph()->static_infer_manager(); auto&& mgr = v->owner_graph()->static_infer_manager();
auto&& type = mgr.get_infer_type(v); auto&& type = mgr.get_infer_type(v);
using InferType = cg::static_infer::InferType; using InferType = cg::static_infer::InferType;
if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) { if (!(type.value &
(InferType::CONST | InferType::RT_STATIC))) {
return py::none(); return py::none();
} }
auto* val = mgr.infer_value_fallible(v); auto* val = mgr.infer_value_fallible(v);
...@@ -156,7 +161,8 @@ void init_graph_rt(py::module m) { ...@@ -156,7 +161,8 @@ void init_graph_rt(py::module m) {
} }
return py::cast(*val).attr("numpy")(); return py::cast(*val).attr("numpy")();
}) })
.def_property_readonly("id", [](cg::VarNode* v) { return (v->id()); }) .def_property_readonly(
"id", [](cg::VarNode* v) { return (v->id()); })
.def("__repr__", [](cg::VarNode* v) { return "Var:" + v->name(); }); .def("__repr__", [](cg::VarNode* v) { return "Var:" + v->name(); });
py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>( py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(
......
...@@ -8,6 +8,9 @@ ...@@ -8,6 +8,9 @@
#include "megbrain/graph.h" #include "megbrain/graph.h"
#include "megbrain/plugin/opr_footprint.h" #include "megbrain/plugin/opr_footprint.h"
namespace py = pybind11;
extern py::object Py_Varnode;
template <typename T> template <typename T>
class GraphNodePtr { class GraphNodePtr {
std::shared_ptr<mgb::cg::ComputingGraph> m_graph; std::shared_ptr<mgb::cg::ComputingGraph> m_graph;
......
...@@ -48,58 +48,11 @@ namespace mgb::imperative::python { ...@@ -48,58 +48,11 @@ namespace mgb::imperative::python {
namespace { namespace {
WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map; WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map;
struct SymbolVarContext {
TransformationContext context;
std::shared_ptr<SymbolTransformation> symbol_tsf;
std::shared_ptr<ScalarTransformation> scalar_tsf;
std::shared_ptr<DTypePromoteTransformation> dtype_promote_tsf;
std::shared_ptr<DimExpansionTransformation> dim_expansion_tsf;
SymbolVarContext(cg::ComputingGraph* graph) {
symbol_tsf = std::make_shared<SymbolTransformation>(graph);
scalar_tsf = std::make_shared<ScalarTransformation>();
dtype_promote_tsf = std::make_shared<DTypePromoteTransformation>();
dim_expansion_tsf = std::make_shared<DimExpansionTransformation>();
Transformation::swap_context(context);
}
void init() {
symbol_tsf->register_at(Transformation::top());
scalar_tsf->register_at(Transformation::top());
dtype_promote_tsf->register_at(Transformation::top());
dim_expansion_tsf->register_at(Transformation::top());
}
ValueRef symvar2val(py::handle py_symbol_var) {
auto* symbol_var = py_symbol_var.cast<PySymbolVar*>();
ValueRef value = symbol_tsf->value_type().make(symbol_var->m_node);
if (symbol_var->is_scalar) {
value = scalar_tsf->value_type().make(value);
}
return value;
}
py::object val2symvar(py::handle typeobj, ValueRef value) {
bool is_scalar = false;
if (auto* scalar_value = value.as(scalar_tsf->value_type())) {
value = scalar_value->value();
is_scalar = true;
}
auto* node = value.cast(symbol_tsf->value_type()).node();
auto py_symbol_var =
typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic));
py_symbol_var.cast<PySymbolVar*>()->is_scalar = is_scalar;
return py_symbol_var;
}
~SymbolVarContext() { Transformation::swap_context(context); }
};
} // namespace } // namespace
interpreter::Interpreter::Channel* interpreter_for_py = nullptr; interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
PyTypeObject* py_tensor_type = nullptr; PyTypeObject* py_tensor_type = nullptr;
PyTypeObject* py_varnode_type = nullptr;
pybind11::handle py_device_type = nullptr; pybind11::handle py_device_type = nullptr;
PyObject* cpp_use_symbolic_shape; PyObject* cpp_use_symbolic_shape;
...@@ -136,22 +89,6 @@ PyObject* py_apply( ...@@ -136,22 +89,6 @@ PyObject* py_apply(
auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>(); auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>();
SmallVector<ValueRef, 8> tensors(nargs); SmallVector<ValueRef, 8> tensors(nargs);
SmallVector<bool, 8> is_symbol_var(nargs, false);
ComputingGraph* cg = nullptr;
for (size_t i = 0; i < nargs; ++i) {
if ((!TensorWrapper::try_cast(args[i])) &&
py::isinstance<PySymbolVar>(py::handle(args[i]))) {
is_symbol_var[i] = true;
ComputingGraph* cur_cg =
py::handle(args[i]).cast<PySymbolVar*>()->m_node->owner_graph();
if (cg == nullptr) {
cg = cur_cg;
} else {
mgb_assert(cg == cur_cg);
}
}
}
mgb::CompNode target_cn; mgb::CompNode target_cn;
mgb::DType target_dtype; mgb::DType target_dtype;
...@@ -174,35 +111,11 @@ PyObject* py_apply( ...@@ -174,35 +111,11 @@ PyObject* py_apply(
} }
}; };
-    if (cg != nullptr) {
-        // swap to a special context to reuse scalar handle
-        size_t symbol_var_idx = 8;
-        SymbolVarContext context(cg);
-        context.init();
-        for (size_t i = 0; i < nargs; ++i) {
-            if (is_symbol_var[i]) {
-                symbol_var_idx = i;
-                tensors[i] = context.symvar2val(args[i]);
-            } else if (
-                    DTypePromoteCfg::convert_input_enabled &&
-                    op->same_type<Elemwise>()) {
-                tensors[i] = convert_pyinput_to_tensor(i);
-            } else {
-                PyErr_SetString(
-                        PyExc_TypeError, "py_apply expects tensor as inputs");
-                return nullptr;
-            }
-        }
-        auto outputs = imperative::apply(*op, tensors);
-        auto ret = pybind11::tuple(outputs.size());
-        auto typeobj = py::handle(args[symbol_var_idx]).get_type();
-        for (size_t i = 0; i < outputs.size(); ++i) {
-            ret[i] = context.val2symvar(typeobj, outputs[i]);
-        }
-        return ret.release().ptr();
-    }
-    for (size_t i = 0; i < nargs; ++i) {
+    bool is_varnode_apply = false;
+    for (size_t i = 0; i < nargs; ++i) {
+        if (PyObject_TypeCheck(args[i], py_varnode_type)) {
+            is_varnode_apply = true;
+        }
        if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
            tensors[i] = tw->m_tensor->data();
        } else if (
...@@ -218,8 +131,9 @@ PyObject* py_apply( ...@@ -218,8 +131,9 @@ PyObject* py_apply(
    auto outputs = [&] { return imperative::apply(*op, tensors); }();
    size_t nout = outputs.size();
    auto ret = py::tuple(nout);
+   PyTypeObject* py_type = is_varnode_apply ? py_varnode_type : py_tensor_type;
    for (size_t i = 0; i < nout; ++i) {
-       ret[i] = TensorWrapper::make(py_tensor_type, std::move(outputs[i]));
+       ret[i] = TensorWrapper::make(py_type, std::move(outputs[i]));
    }
    return ret.release().ptr();
}
...@@ -622,9 +536,17 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { ...@@ -622,9 +536,17 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
CreateTensor::Kind kind = is_const ? CreateTensor::Const CreateTensor::Kind kind = is_const ? CreateTensor::Const
: no_cache ? CreateTensor::Unique : no_cache ? CreateTensor::Unique
: CreateTensor::Common; : CreateTensor::Common;
+    ValueRef val;
+    if (py::isinstance(data, Py_Varnode)) {
+        cg::VarNode* m_node = py::handle(data).cast<cg::VarNode*>();
+        val = imperative::apply(
+                CreateNode(m_node), Span<ValueRef>(nullptr, nullptr))[0];
+    } else {
        auto&& hval = pyobj2hval(data, cn, dtype);
-       auto val = imperative::apply(
-               CreateTensor(kind, cn, hval.dtype, hval.shape), hval.storage)[0];
+       val = imperative::apply(
+               CreateTensor(kind, cn, hval.dtype, hval.shape),
+               hval.storage)[0];
+    }
    m_tensor.emplace(val);
} }
...@@ -734,6 +656,20 @@ PyObject* TensorWrapper::isscalar() { ...@@ -734,6 +656,20 @@ PyObject* TensorWrapper::isscalar() {
} }
} }
PyObject* TensorWrapper::_var() {
TypedValueRef<NodeValue> value =
imperative::apply(GetVarVal(), m_tensor->data())[0].as_ref<NodeValue>();
auto* node = value->node();
return py::cast(node).release().ptr();
}
PyObject* TensorWrapper::_graph() {
TypedValueRef<NodeValue> value =
imperative::apply(GetVarVal(), m_tensor->data())[0].as_ref<NodeValue>();
auto* graph = value->graph();
return py::cast(graph).release().ptr();
}
struct TensorWeakRef { struct TensorWeakRef {
ValueWeakRef data; ValueWeakRef data;
...@@ -807,6 +743,10 @@ void init_tensor(py::module m) { ...@@ -807,6 +743,10 @@ void init_tensor(py::module m) {
.register_at<Segment::Scalar>( .register_at<Segment::Scalar>(
std::make_shared<ScalarTransformation>()) std::make_shared<ScalarTransformation>())
.release()); .release());
MGB_MARK_USED_VAR(transformations
.register_at<Segment::Symbol>(
std::make_shared<SymbolTransformation>())
.release());
MGB_MARK_USED_VAR(transformations MGB_MARK_USED_VAR(transformations
.register_at<Segment::DTypePromote>( .register_at<Segment::DTypePromote>(
std::make_shared<DTypePromoteTransformation>()) std::make_shared<DTypePromoteTransformation>())
...@@ -863,6 +803,8 @@ void init_tensor(py::module m) { ...@@ -863,6 +803,8 @@ void init_tensor(py::module m) {
.def<&TensorWrapper::_detail>("_detail") .def<&TensorWrapper::_detail>("_detail")
.def<&TensorWrapper::_set_name>("_set_name") .def<&TensorWrapper::_set_name>("_set_name")
.def<&TensorWrapper::_watch>("_watch") .def<&TensorWrapper::_watch>("_watch")
.def<&TensorWrapper::_var>("var")
.def<&TensorWrapper::_graph>("graph")
.def_getset< .def_getset<
&TensorWrapper::module_trace_info, &TensorWrapper::module_trace_info,
&TensorWrapper::set_module_trace_info>("_NodeMixin__node") &TensorWrapper::set_module_trace_info>("_NodeMixin__node")
...@@ -875,43 +817,6 @@ void init_tensor(py::module m) { ...@@ -875,43 +817,6 @@ void init_tensor(py::module m) {
.def(py::init<const TensorWrapper&>()) .def(py::init<const TensorWrapper&>())
.def("__call__", &TensorWeakRef::operator()); .def("__call__", &TensorWeakRef::operator());
py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar")
.def_property_readonly(
"dtype", [](PySymbolVar* v) { return v->m_node->dtype(); })
.def_property(
"var", [](PySymbolVar* v) { return v->m_node; },
[](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; })
.def_property_readonly(
"device", [](PySymbolVar* v) { return v->m_node->comp_node(); })
.def_property_readonly(
"graph", [](PySymbolVar* v) { return v->m_node->owner_graph(); })
.def_property_readonly(
"shape",
[](PySymbolVar* v) -> const TensorShape* {
auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
return mgr.infer_shape_fallible(v->m_node);
})
.def("numpy",
[](PySymbolVar* v) {
auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
auto&& type = mgr.get_infer_type(v->m_node);
using InferType = cg::static_infer::InferType;
if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
throw py::value_error("value invalid!");
}
auto* val = mgr.infer_value_fallible(v->m_node);
if (!val) {
throw py::value_error("value invalid!");
}
auto np_val = py::cast(*val).attr("numpy")();
return np_val;
})
.def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; })
.def(py::init([](cg::VarNode* node) {
return std::make_shared<PySymbolVar>(node);
}),
py::arg() = nullptr);
static PyMethodDef method_defs[] = { static PyMethodDef method_defs[] = {
MGE_PY_INTERFACE(apply, py_apply), MGE_PY_INTERFACE(apply, py_apply),
MGE_PY_INTERFACE(dtype_promotion, dtype_promotion), MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
...@@ -1027,6 +932,10 @@ void init_tensor(py::module m) { ...@@ -1027,6 +932,10 @@ void init_tensor(py::module m) {
py_tensor_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr()); py_tensor_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
}); });
m.def("set_py_varnode_type", [](py::object type_obj) {
py_varnode_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
});
m.def("set_py_device_type", m.def("set_py_device_type",
[](py::object type_obj) { py_device_type = type_obj.inc_ref(); }); [](py::object type_obj) { py_device_type = type_obj.inc_ref(); });
...@@ -1217,31 +1126,6 @@ void init_tensor(py::module m) { ...@@ -1217,31 +1126,6 @@ void init_tensor(py::module m) {
} }
}); });
m.def("reduce_to_scalar", [](py::object op, py::object tensor) -> py::object {
auto reduce_to_scalar = [](const OpDef& op, const ValueRef& input) {
auto make_scalar_shape = [&](CompNode device) {
return imperative::apply(
CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}),
HostStorage::make(device))[0];
};
return imperative::apply(op, input, make_scalar_shape(*input.device()))[0];
};
if (py::isinstance<PySymbolVar>(tensor)) {
auto* graph = tensor.cast<PySymbolVar*>()->m_node->owner_graph();
SymbolVarContext context(graph);
context.init();
auto output = reduce_to_scalar(
*op.cast<std::shared_ptr<OpDef>>(), context.symvar2val(tensor));
auto typeobj = tensor.get_type();
return context.val2symvar(typeobj, output);
} else {
auto* tw = TensorWrapper::try_cast(tensor.ptr());
auto output = reduce_to_scalar(
*op.cast<std::shared_ptr<OpDef>>(), tw->m_tensor->data());
return TensorWrapper::make(py_tensor_type, output);
}
});
m.def("name_tensor", [](std::string name, py::object tensor) { m.def("name_tensor", [](std::string name, py::object tensor) {
auto* tw = TensorWrapper::try_cast(tensor.ptr()); auto* tw = TensorWrapper::try_cast(tensor.ptr());
auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0]; auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0];
......
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
#include "./pyext17.h" #include "./pyext17.h"
#include "megbrain/imperative/dispatch.h" #include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/transformations/symbol.h"
#include "megbrain/imperative/utils/span.h" #include "megbrain/imperative/utils/span.h"
namespace mgb::imperative::python { namespace mgb::imperative::python {
...@@ -27,6 +29,7 @@ namespace mgb::imperative::python { ...@@ -27,6 +29,7 @@ namespace mgb::imperative::python {
extern interpreter::Interpreter::Channel* interpreter_for_py; extern interpreter::Interpreter::Channel* interpreter_for_py;
extern PyTypeObject* py_tensor_type; extern PyTypeObject* py_tensor_type;
extern PyTypeObject* py_varnode_type;
extern pybind11::handle py_device_type; extern pybind11::handle py_device_type;
extern PyObject* cpp_use_symbolic_shape; extern PyObject* cpp_use_symbolic_shape;
extern PyObject* cpp_astensor1d; extern PyObject* cpp_astensor1d;
...@@ -126,16 +129,11 @@ public: ...@@ -126,16 +129,11 @@ public:
void set_module_trace_info(PyObject*); void set_module_trace_info(PyObject*);
void _set_name(PyObject*); void _set_name(PyObject*);
PyObject* _detail(); PyObject* _detail();
PyObject* _var();
PyObject* _graph();
void _watch(); void _watch();
}; };
struct PySymbolVar {
cg::VarNode* m_node = nullptr;
bool is_scalar = false;
PySymbolVar() = default;
PySymbolVar(VarNode* m) : m_node(m) {}
};
PyObject* py_apply( PyObject* py_apply(
PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */); PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */);
......
...@@ -146,15 +146,6 @@ PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) { ...@@ -146,15 +146,6 @@ PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
continue; continue;
} }
if (py::isinstance<PySymbolVar>(py::handle(handle))) {
auto var = py::handle(handle).cast<PySymbolVar*>();
mgb::DType type = var->m_node->dtype();
auto&& descr = npy::dtype_mgb2np_descr(type);
Py_INCREF(descr.get());
tensors.emplace_back(descr.get());
continue;
}
PyArray_Descr* descr = scalar2dtype(handle); PyArray_Descr* descr = scalar2dtype(handle);
if (descr) { if (descr) {
scalars.emplace_back(descr); scalars.emplace_back(descr);
...@@ -204,17 +195,12 @@ CompNode _get_device(PyObject* const* args, size_t nargs) { ...@@ -204,17 +195,12 @@ CompNode _get_device(PyObject* const* args, size_t nargs) {
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i]; PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
TensorWrapper* tw = TensorWrapper::try_cast(handle); TensorWrapper* tw = TensorWrapper::try_cast(handle);
-        bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
-        if (tw || is_symvar) {
+        if (tw) {
            if (!valid) {
-                cn = tw ? tw->m_tensor->comp_node()
-                        : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
+                cn = tw->m_tensor->comp_node();
                valid = true;
            } else {
-                CompNode cn1 = tw ? tw->m_tensor->comp_node()
-                                  : py::handle(handle)
-                                            .cast<PySymbolVar*>()
-                                            ->m_node->comp_node();
+                CompNode cn1 = tw->m_tensor->comp_node();
if (cn1 != cn) { if (cn1 != cn) {
throw py::value_error(ssprintf( throw py::value_error(ssprintf(
"ambiguous device: %s (from %s) vs %s (from %s)", "ambiguous device: %s (from %s) vs %s (from %s)",
...@@ -258,10 +244,6 @@ PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) { ...@@ -258,10 +244,6 @@ PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
} }
bool is_scalar(PyObject* tensor) { bool is_scalar(PyObject* tensor) {
if (py::isinstance<PySymbolVar>(py::handle(tensor))) {
auto var = py::handle(tensor).cast<PySymbolVar*>();
return var->is_scalar;
}
auto* tw = TensorWrapper::try_cast(tensor); auto* tw = TensorWrapper::try_cast(tensor);
if (tw) { if (tw) {
return tw->m_tensor->is_scalar(); return tw->m_tensor->is_scalar();
...@@ -319,8 +301,7 @@ py::object device2obj(py::handle device, bool mapping = false) { ...@@ -319,8 +301,7 @@ py::object device2obj(py::handle device, bool mapping = false) {
} }
} }
-py::object _Const(
-        py::handle value, py::handle dtype, py::handle device, py::handle ref_hdl) {
+py::object _Const(py::handle value, py::handle dtype, py::handle device) {
py::object val = py::reinterpret_borrow<py::object>(value); py::object val = py::reinterpret_borrow<py::object>(value);
if (PyArray_Check(value.ptr())) { if (PyArray_Check(value.ptr())) {
py::tuple strides = py::tuple strides =
...@@ -338,32 +319,6 @@ py::object _Const( ...@@ -338,32 +319,6 @@ py::object _Const(
val = val.attr("reshape")(orig_shp); val = val.attr("reshape")(orig_shp);
} }
} }
py::object ref;
if (py::isinstance<py::tuple>(ref_hdl)) {
py::tuple tup = py::reinterpret_borrow<py::tuple>(ref_hdl);
if (tup.size()) {
ref = tup[0];
} else {
ref = py::none();
}
} else {
ref = py::reinterpret_borrow<py::object>(ref_hdl);
}
if (py::isinstance<PySymbolVar>(ref)) {
auto ref_var = ref.cast<PySymbolVar*>();
auto* graph = ref_var->m_node->owner_graph();
CompNode cn;
if (device.ptr() == Py_None) {
cn = ref_var->m_node->comp_node();
} else {
cn = device2obj(device).cast<CompNode>();
}
OperatorNodeConfig config(cn);
auto hv = npy::np2tensor(
val.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
auto typeobj = ref.get_type();
return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
}
py::object device_obj = device2obj(device, true); py::object device_obj = device2obj(device, true);
py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none()); py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none());
return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr); return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
...@@ -373,7 +328,7 @@ py::tuple _make_shape_tuple(py::handle shape) { ...@@ -373,7 +328,7 @@ py::tuple _make_shape_tuple(py::handle shape) {
py::list orig; py::list orig;
py::list ret(0); py::list ret(0);
auto solve_one = [&](py::handle val) { auto solve_one = [&](py::handle val) {
if (TensorWrapper::try_cast(val.ptr()) || py::isinstance<PySymbolVar>(val)) { if (TensorWrapper::try_cast(val.ptr())) {
py::object np = getattr(val, "numpy")(); py::object np = getattr(val, "numpy")();
PyArrayObject* arr = (PyArrayObject*)np.ptr(); PyArrayObject* arr = (PyArrayObject*)np.ptr();
PyObject* maybe_list = PyArray_ToList(arr); PyObject* maybe_list = PyArray_ToList(arr);
...@@ -415,25 +370,53 @@ py::tuple _make_shape_tuple(py::handle shape) { ...@@ -415,25 +370,53 @@ py::tuple _make_shape_tuple(py::handle shape) {
return py::reinterpret_steal<py::tuple>(PyList_AsTuple(ret.ptr())); return py::reinterpret_steal<py::tuple>(PyList_AsTuple(ret.ptr()));
} }
bool is_tensor_or_symbolvar(py::handle arg) { bool is_tensor(py::handle arg) {
return bool(TensorWrapper::try_cast(arg.ptr())) || py::isinstance<PySymbolVar>(arg); return bool(TensorWrapper::try_cast(arg.ptr()));
} }
bool is_py_sequence(py::handle arg) { bool is_py_sequence(py::handle arg) {
if (PyArray_Check(arg.ptr()) || TensorWrapper::try_cast(arg.ptr()) || if (PyArray_Check(arg.ptr()) || TensorWrapper::try_cast(arg.ptr())) {
py::isinstance<PySymbolVar>(arg)) {
return false; return false;
} }
return PySequence_Check(arg.ptr()); return PySequence_Check(arg.ptr());
} }
-mgb::DType _get_dtype(py::handle tensor) {
-    if (auto tw = TensorWrapper::try_cast(tensor.ptr())) {
-        return tw->m_tensor->dtype();
-    } else {
-        auto var = tensor.cast<PySymbolVar*>();
-        return var->m_node->dtype();
-    }
-}
+py::object get_res_by_refhdl(
+        py::handle value, py::handle dtype, py::handle device, py::handle ref_hdl) {
+    py::object res = _Const(value, dtype, device);
+    py::object ref;
+    if (py::isinstance<py::tuple>(ref_hdl)) {
+        py::tuple tup = py::reinterpret_borrow<py::tuple>(ref_hdl);
+        if (tup.size()) {
+            ref = tup[0];
+        } else {
+            ref = py::none();
+        }
+    } else {
+        ref = py::reinterpret_borrow<py::object>(ref_hdl);
+    }
+    if (PyObject_TypeCheck(ref.ptr(), py_varnode_type)) {
+        auto temp = dtype.cast<mgb::DType>();
+        ComputingGraph* graph = getattr(ref, "graph").cast<ComputingGraph*>();
+        cg::VarNode* node = getattr(ref, "var").cast<cg::VarNode*>();
+        CompNode cn;
+        if (device.ptr() == Py_None) {
+            cn = node->comp_node();
+        } else {
+            cn = device2obj(device).cast<CompNode>();
+        }
+        OperatorNodeConfig config(cn);
+        auto hv = npy::np2tensor(
+                value.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
+        auto typeobj = ref.get_type();
+        return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
+    }
+    return res;
+}
+
+mgb::DType _get_dtype(py::handle tensor) {
+    auto tw = TensorWrapper::try_cast(tensor.ptr());
+    return tw->m_tensor->dtype();
+}
py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) { py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {
...@@ -457,12 +440,12 @@ py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) { ...@@ -457,12 +440,12 @@ py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {
py::object _convert_single_value_cpp( py::object _convert_single_value_cpp(
py::handle value, py::handle dtype, py::handle device) { py::handle value, py::handle dtype, py::handle device) {
if (is_tensor_or_symbolvar(value)) { if (is_tensor(value)) {
if (_get_dtype(value).category() != DTypeCategory::QUANTIZED) { if (_get_dtype(value).category() != DTypeCategory::QUANTIZED) {
return _astype_cpp(value, dtype); return _astype_cpp(value, dtype);
} }
} else { } else {
return _Const(value, dtype, device, py::none()); return _Const(value, dtype, device);
} }
return py::reinterpret_borrow<py::object>(value); return py::reinterpret_borrow<py::object>(value);
} }
...@@ -475,28 +458,8 @@ py::object _convert_inputs_cpp( ...@@ -475,28 +458,8 @@ py::object _convert_inputs_cpp(
for (size_t i = 0; i < nargs; ++i) { for (size_t i = 0; i < nargs; ++i) {
py::handle h = py::handle(args[i]); py::handle h = py::handle(args[i]);
lis.append(h); lis.append(h);
if (py::isinstance<PySymbolVar>(h)) {
auto var = h.cast<PySymbolVar*>();
auto g = var->m_node->owner_graph();
if (!graph) {
graph = g;
typeobj = h.get_type();
} else {
mgb_assert(graph == g);
}
}
}
if (graph) {
CompNode cn = device2obj(device).cast<CompNode>();
for (size_t i = 0; i < nargs; ++i) {
OperatorNodeConfig config(cn);
auto hv = npy::np2tensor(
lis[i].ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
if (!py::isinstance<PySymbolVar>(lis[i])) {
lis[i] = typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
}
}
} }
auto convert = [&](py::object value) { auto convert = [&](py::object value) {
if (value.is_none()) { if (value.is_none()) {
return value; return value;
...@@ -517,7 +480,8 @@ py::object _astensor1d_cpp( ...@@ -517,7 +480,8 @@ py::object _astensor1d_cpp(
if (device.ptr() != Py_None) { if (device.ptr() != Py_None) {
device_obj = device2obj(device); device_obj = device2obj(device);
} }
if (py::isinstance<PySymbolVar>(value)) {
if (PyObject_TypeCheck(value.ptr(), py_varnode_type)) {
try { try {
getattr(value, "ndim"); getattr(value, "ndim");
} catch (py::error_already_set& err) { } catch (py::error_already_set& err) {
...@@ -537,14 +501,15 @@ py::object _astensor1d_cpp( ...@@ -537,14 +501,15 @@ py::object _astensor1d_cpp(
return ret; return ret;
} }
} }
size_t ndim = 999; size_t ndim = 999;
if (hasattr(value, "ndim")) { if (hasattr(value, "ndim")) {
ndim = getattr(value, "ndim").cast<size_t>(); ndim = getattr(value, "ndim").cast<size_t>();
if (ndim != 0 && ndim != 1) { if (ndim != 0 && ndim != 1) {
throw py::value_error("ndim != 1 or 0, get : " + std::to_string(ndim)); throw py::value_error("ndim != 1 or 0, get : " + std::to_string(ndim));
} }
if (!is_tensor_or_symbolvar(value)) { if (!is_tensor(value)) {
return _Const(value, dtype, device, ref); return get_res_by_refhdl(value, dtype, device, ref);
} else { } else {
return py::reinterpret_borrow<py::object>(value); return py::reinterpret_borrow<py::object>(value);
} }
...@@ -555,13 +520,13 @@ py::object _astensor1d_cpp( ...@@ -555,13 +520,13 @@ py::object _astensor1d_cpp(
py::list lis = py::reinterpret_steal<py::list>(PySequence_List(value.ptr())); py::list lis = py::reinterpret_steal<py::list>(PySequence_List(value.ptr()));
bool need_concat = false; bool need_concat = false;
for (size_t i = 0; i < lis.size(); ++i) { for (size_t i = 0; i < lis.size(); ++i) {
if (is_tensor_or_symbolvar(lis[i])) { if (is_tensor(lis[i])) {
need_concat = true; need_concat = true;
break; break;
} }
} }
if (!need_concat) { if (!need_concat) {
return _Const(value, dtype, device, ref); return get_res_by_refhdl(value, dtype, device, ref);
} }
if (lis.size() > 1) { if (lis.size() > 1) {
std::vector<PyObject*> c_args(lis.size() + 1); std::vector<PyObject*> c_args(lis.size() + 1);
...@@ -600,10 +565,9 @@ py::object _astensor1d_cpp( ...@@ -600,10 +565,9 @@ py::object _astensor1d_cpp(
} }
py::object _get_index(py::object tensor, py::object src) { py::object _get_index(py::object tensor, py::object src) {
-    if (!TensorWrapper::try_cast(tensor.ptr()) &&
-        !py::isinstance<PySymbolVar>(tensor)) {
+    if (!TensorWrapper::try_cast(tensor.ptr())) {
auto get_const = [&](mgb::DType dtype) -> py::object { auto get_const = [&](mgb::DType dtype) -> py::object {
return _Const(tensor, py::cast(dtype), src.attr("device"), src); return _Const(tensor, py::cast(dtype), src.attr("device"));
}; };
if (is_bool_list(tensor.ptr()) || is_bool_dtype(tensor.ptr())) { if (is_bool_list(tensor.ptr()) || is_bool_dtype(tensor.ptr())) {
tensor = get_const(dtype::Bool()); tensor = get_const(dtype::Bool());
...@@ -636,9 +600,8 @@ py::tuple _try_cond_take(py::handle tensor, py::handle index) { ...@@ -636,9 +600,8 @@ py::tuple _try_cond_take(py::handle tensor, py::handle index) {
} }
py::object iobj; py::object iobj;
if (PyArray_Check(index.ptr())) { if (PyArray_Check(index.ptr())) {
-        iobj =
-                _Const(index, py::cast((mgb::DType)dtype::Bool()),
-                       getattr(tensor, "device"), tensor);
+        iobj = _Const(
+                index, py::cast((mgb::DType)dtype::Bool()), getattr(tensor, "device"));
} else { } else {
iobj = py::reinterpret_borrow<py::object>(index); iobj = py::reinterpret_borrow<py::object>(index);
} }
...@@ -920,8 +883,8 @@ py::object _expand_args(py::handle args) { ...@@ -920,8 +883,8 @@ py::object _expand_args(py::handle args) {
return py::reinterpret_borrow<py::object>(args); return py::reinterpret_borrow<py::object>(args);
} }
py::tuple args_tup = py::reinterpret_borrow<py::tuple>(args.ptr()); py::tuple args_tup = py::reinterpret_borrow<py::tuple>(args.ptr());
if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) || if (args_tup.size() == 1 &&
is_tensor_or_symbolvar(args_tup[0].ptr()))) { (PySequence_Check(args_tup[0].ptr()) || is_tensor(args_tup[0].ptr()))) {
return py::reinterpret_borrow<py::object>(args_tup[0]); return py::reinterpret_borrow<py::object>(args_tup[0]);
} else { } else {
return py::reinterpret_steal<py::list>(PySequence_List(args_tup.ptr())); return py::reinterpret_steal<py::list>(PySequence_List(args_tup.ptr()));
...@@ -948,7 +911,8 @@ std::tuple<std::vector<int32_t>, bool> tuple2vector(py::object shape) { ...@@ -948,7 +911,8 @@ std::tuple<std::vector<int32_t>, bool> tuple2vector(py::object shape) {
bool enable_fastpath(py::handle inp) { bool enable_fastpath(py::handle inp) {
auto&& tm_tr = TransformationManager::get_instance() auto&& tm_tr = TransformationManager::get_instance()
.segments[TransformationManager::Segment::ModuleTrace]; .segments[TransformationManager::Segment::ModuleTrace];
if (!TensorWrapper::try_cast(inp.ptr()) || bool is_varnode = PyObject_TypeCheck(inp.ptr(), py_varnode_type);
if (is_varnode ||
TransformationManager::get_instance() TransformationManager::get_instance()
.segments[TransformationManager::Segment::Trace] .segments[TransformationManager::Segment::Trace]
.size() > 0 || .size() > 0 ||
...@@ -1181,10 +1145,8 @@ py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) { ...@@ -1181,10 +1145,8 @@ py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_hdl) { py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_hdl) {
py::object org_shape = getattr(inp_hdl, "shape"); py::object org_shape = getattr(inp_hdl, "shape");
py::object val = py::reinterpret_borrow<py::object>(val_hdl); py::object val = py::reinterpret_borrow<py::object>(val_hdl);
-    if (!TensorWrapper::try_cast(val.ptr()) && !py::isinstance<PySymbolVar>(val)) {
-        val =
-                _Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device"),
-                       inp_hdl);
+    if (!TensorWrapper::try_cast(val.ptr())) {
+        val = _Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device"));
} }
py::tuple up = _unpack_indexes(inp_hdl, idx_hdl); py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
...@@ -1308,12 +1270,12 @@ py::object _split_cpp( ...@@ -1308,12 +1270,12 @@ py::object _split_cpp(
repr(nsplits_or_sections_hdl).cast<std::string>()); repr(nsplits_or_sections_hdl).cast<std::string>());
} }
py::object pos = div_points[i] - div_points[i - 1]; py::object pos = div_points[i] - div_points[i - 1];
if (is_tensor_or_symbolvar(pos)) { if (is_tensor(pos)) {
partitions.append(pos); partitions.append(pos);
} else { } else {
partitions.append( partitions.append(
_Const(pos, py::cast((mgb::DType)dtype::Int32()), _Const(pos, py::cast((mgb::DType)dtype::Int32()),
getattr(inp_hdl, "device"), inp_hdl)); getattr(inp_hdl, "device")));
} }
} }
op = Split::make(axis, 0); op = Split::make(axis, 0);
...@@ -1438,7 +1400,7 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) { ...@@ -1438,7 +1400,7 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
py::object _transpose_cpp(py::handle inp_hdl, py::handle args) { py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
py::object obj = _expand_args(args); py::object obj = _expand_args(args);
py::list lis; py::list lis;
if (!is_tensor_or_symbolvar(obj.ptr()) && PySequence_Check(obj.ptr())) { if (!is_tensor(obj.ptr()) && PySequence_Check(obj.ptr())) {
lis = py::reinterpret_steal<py::list>(PySequence_List(obj.ptr())); lis = py::reinterpret_steal<py::list>(PySequence_List(obj.ptr()));
} else { } else {
py::object np = getattr(obj, "numpy")(); py::object np = getattr(obj, "numpy")();
...@@ -1631,7 +1593,7 @@ PyObject* pixel_shuffle_cpp(PyObject* self, PyObject* const* args, size_t nargs) ...@@ -1631,7 +1593,7 @@ PyObject* pixel_shuffle_cpp(PyObject* self, PyObject* const* args, size_t nargs)
PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) { PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) {
try { try {
return _Const(args[0], args[1], args[2], args[3]).release().ptr(); return _Const(args[0], args[1], args[2]).release().ptr();
} }
PYEXT17_TRANSLATE_EXC_RET(nullptr) PYEXT17_TRANSLATE_EXC_RET(nullptr)
} }
......
...@@ -20,11 +20,12 @@ public: ...@@ -20,11 +20,12 @@ public:
DimExpansion, DimExpansion,
Grad, Grad,
Scalar, Scalar,
Symbol,
Trace, Trace,
Eval, Eval,
}; };
std::array<std::vector<std::shared_ptr<Transformation>>, 7> segments; std::array<std::vector<std::shared_ptr<Transformation>>, 8> segments;
private: private:
template <Segment segment> template <Segment segment>
......
...@@ -11,7 +11,7 @@ from megengine.utils.network_node import VarNode ...@@ -11,7 +11,7 @@ from megengine.utils.network_node import VarNode
def _default_compare_fn(x, y): def _default_compare_fn(x, y):
if isinstance(x, tensor): if isinstance(x, tensor) and not isinstance(x, VarNode):
x = x.numpy() x = x.numpy()
elif not isinstance(x, np.ndarray): elif not isinstance(x, np.ndarray):
x = get_var_value(x) x = get_var_value(x)
......
...@@ -679,6 +679,18 @@ def test_utils_astensor1d(is_varnode): ...@@ -679,6 +679,18 @@ def test_utils_astensor1d(is_varnode):
assert isinstance(xx, type(reference)) assert isinstance(xx, type(reference))
np.testing.assert_equal(xx.numpy(), [1, 2, 3]) np.testing.assert_equal(xx.numpy(), [1, 2, 3])
# varnode
if is_varnode:
a = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32")
b = np.array([[True, False, True], [False, True, True]])
aa = make_tensor(a, network)
bb = make_tensor(b, network)
x, y = F.cond_take(bb, aa)
for dtype in [None, "float32"]:
xx = astensor1d(x, reference, dtype=dtype)
assert isinstance(xx, type(reference))
np.testing.assert_equal(get_var_value(xx), get_var_value(x))
def test_device(): def test_device():
x = tensor([1, 2, 3], dtype="float32") x = tensor([1, 2, 3], dtype="float32")
......
...@@ -114,8 +114,10 @@ def test_replace_opr(): ...@@ -114,8 +114,10 @@ def test_replace_opr():
vara = graph.var_filter.name("a").as_unique() vara = graph.var_filter.name("a").as_unique()
varb = graph.var_filter.name("b").as_unique() varb = graph.var_filter.name("b").as_unique()
out1 = F.sub(vara, varb) out1 = F.mul(vara, varb)
out1 = F.relu(out1) out1 = F.relu(out1)
out1 += 2
out1 *= 3
out1 = graph.add_dep_oprs(out1) out1 = graph.add_dep_oprs(out1)
orig_opr = graph.opr_filter.has_input(vara).as_unique() orig_opr = graph.opr_filter.has_input(vara).as_unique()
...@@ -135,7 +137,7 @@ def test_replace_opr(): ...@@ -135,7 +137,7 @@ def test_replace_opr():
load_graph = GraphInference(modified_model1) load_graph = GraphInference(modified_model1)
out = load_graph.run(a, b) out = load_graph.run(a, b)
np.testing.assert_equal(out["o"], [0, 0]) np.testing.assert_equal(out["o"], [30, 60])
def test_splice_network(): def test_splice_network():
......
...@@ -82,6 +82,10 @@ std::string DTRCommand::to_string() const { ...@@ -82,6 +82,10 @@ std::string DTRCommand::to_string() const {
return ssprintf("DTRCommandValue{kind=%d}", (int)m_kind); return ssprintf("DTRCommandValue{kind=%d}", (int)m_kind);
} }
std::string CreateNode::to_string() const {
return "CreateNode";
}
std::string GetName::to_string() const { std::string GetName::to_string() const {
return "GetName{}"; return "GetName{}";
} }
...@@ -94,5 +98,9 @@ std::string IsScalar::to_string() const { ...@@ -94,5 +98,9 @@ std::string IsScalar::to_string() const {
return "IsScalar"; return "IsScalar";
} }
std::string GetVarVal::to_string() const {
return "GetVarVal";
}
} // namespace imperative } // namespace imperative
} // namespace mgb } // namespace mgb
...@@ -157,5 +157,22 @@ public: ...@@ -157,5 +157,22 @@ public:
std::string to_string() const override; std::string to_string() const override;
}; };
class GetVarVal final : public OperatorImpl<GetVarVal, Operator::GetAttrLike> {
public:
std::string to_string() const override;
};
class CreateNode final : public OperatorImpl<CreateNode> {
private:
cg::VarNode* m_node;
public:
CreateNode(cg::VarNode* node) : m_node(node) {}
cg::VarNode* node() const { return m_node; }
std::string to_string() const override;
};
} // namespace imperative } // namespace imperative
} // namespace mgb } // namespace mgb
...@@ -173,5 +173,24 @@ public: ...@@ -173,5 +173,24 @@ public:
std::string to_string() const override; std::string to_string() const override;
}; };
class NodeStorage {
private:
cg::VarNode* m_node;
public:
NodeStorage() = default;
NodeStorage(VarNode* node) : m_node(node) {}
VarNode* node() const { return m_node; }
ComputingGraph* graph() const { return m_node->owner_graph(); }
std::string to_string() const { return m_node->name(); }
};
class NodeValue final : public PrimitiveValue<NodeValue, NodeStorage> {
public:
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override { return NodeStorage::to_string(); }
};
} // namespace imperative } // namespace imperative
} // namespace mgb } // namespace mgb
...@@ -39,29 +39,49 @@ private: ...@@ -39,29 +39,49 @@ private:
ObjectType<SymbolValue> m_value_type{"SymbolValue"}; ObjectType<SymbolValue> m_value_type{"SymbolValue"};
public: public:
SymbolTransformation(ComputingGraph* graph) : m_graph(graph) {} SymbolTransformation() {}
ValueRefList apply_transformation( ValueRefList apply_transformation(
const Operator& op, Span<ValueRef> inputs) override { const Operator& op, Span<ValueRef> inputs) override {
ComputingGraph* cg = nullptr;
if (auto* node_value = op.as<CreateNode>()) {
return {m_value_type.make(node_value->node())};
}
for (auto&& input : inputs) {
if (auto* val = input.as(m_value_type)) {
auto* node = val->node();
ComputingGraph* cur_cg = node->owner_graph();
if (cg == nullptr) {
cg = cur_cg;
} else {
mgb_assert(cg == cur_cg, "input varnode gragh should be the same");
}
}
}
if (!cg) {
return imperative::apply(op, inputs);
}
if (auto* apply_op = op.as<ApplyOp>()) { if (auto* apply_op = op.as<ApplyOp>()) {
SmallVector<VarNode*> input_nodes; SmallVector<VarNode*> input_nodes;
for (auto&& input : inputs) { for (auto&& input : inputs) {
if (!input.is(m_value_type)) {
auto* node = opr::ImmutableTensor::make(
*cg, input.numpy()->as_nd(true), {})
.node();
input_nodes.push_back(node);
} else {
input_nodes.push_back(input.cast(m_value_type).node()); input_nodes.push_back(input.cast(m_value_type).node());
} }
}
auto output_nodes = OpDef::apply_on_var_node(apply_op->op(), input_nodes); auto output_nodes = OpDef::apply_on_var_node(apply_op->op(), input_nodes);
ValueRefList outputs(output_nodes.size()); ValueRefList outputs(output_nodes.size());
for (size_t i = 0; i < output_nodes.size(); ++i) { for (size_t i = 0; i < output_nodes.size(); ++i) {
outputs[i] = m_value_type.make(output_nodes[i]); outputs[i] = m_value_type.make(output_nodes[i]);
} }
return outputs; return outputs;
} else if (auto* create_tensor = op.as<CreateTensor>()) {
auto&& args = create_tensor->parse(inputs);
mgb_assert(
args.kind == CreateTensor::Const,
"only const value is allowed here");
auto* node = opr::ImmutableTensor::make(*m_graph, *args.host, {}).node();
return {m_value_type.make(node)};
} else if (auto* get_attr = op.as<GetAttr>()) { } else if (auto* get_attr = op.as<GetAttr>()) {
auto* node = inputs.item().cast(m_value_type).node(); auto* node = inputs.item().cast(m_value_type).node();
auto* m_graph = node->owner_graph();
switch (get_attr->attr()) { switch (get_attr->attr()) {
case GetAttr::DType: case GetAttr::DType:
return {DTypeValue::make(node->dtype())}; return {DTypeValue::make(node->dtype())};
...@@ -105,6 +125,10 @@ public: ...@@ -105,6 +125,10 @@ public:
MegBrainError, "Symbol: malformed GetAttr: %s", MegBrainError, "Symbol: malformed GetAttr: %s",
op.to_string().c_str()); op.to_string().c_str());
} }
} else if (auto* get_attr = op.as<GetVarVal>()) {
cg::VarNode* node = inputs.item().cast(m_value_type).node();
NodeStorage inp_var = NodeStorage(node);
return {NodeValue::make(inp_var)};
} else { } else {
return op.fallback(inputs); return op.fallback(inputs);
} }
......
...@@ -33,6 +33,7 @@ class ShapeValue; ...@@ -33,6 +33,7 @@ class ShapeValue;
class DTypeValue; class DTypeValue;
class CompNodeValue; class CompNodeValue;
class StringValue; class StringValue;
class NodeValue;
class Operator; class Operator;
......