提交 147cef52 编写于 作者: M Megvii Engine Team

refactor(mge/imperative): implement new tensor system

GitOrigin-RevId: 2dd4e460ac32f059e91c7cdef859d82dee704b4b
上级 7f48625f
...@@ -301,7 +301,7 @@ class GradManager: ...@@ -301,7 +301,7 @@ class GradManager:
if tensor is None: if tensor is None:
return return
def callback(_, grad, callbacks=spec.callbacks): def callback(grad, callbacks=spec.callbacks):
for cb in callbacks: for cb in callbacks:
grad = cb(tensor, grad) grad = cb(tensor, grad)
self._gradients[id(tensor)] = grad self._gradients[id(tensor)] = grad
......
...@@ -16,6 +16,7 @@ import numpy as np ...@@ -16,6 +16,7 @@ import numpy as np
import megengine as mge import megengine as mge
from .._imperative_rt import core2
from ..ops.builtin import Elemwise, OpDef, RemoteSend from ..ops.builtin import Elemwise, OpDef, RemoteSend
from ..ops.special import Const from ..ops.special import Const
from ..tensor.core import TensorBase, TensorWrapperBase, apply from ..tensor.core import TensorBase, TensorWrapperBase, apply
...@@ -418,3 +419,28 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]): ...@@ -418,3 +419,28 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
@apply.register() @apply.register()
def _(op: Const, *_: typing.Optional[Tracer]): def _(op: Const, *_: typing.Optional[Tracer]):
return None return None
class Grad:
    """Python-side handle around a ``core2.GradKey``.

    Usable as a context manager: ``__enter__`` returns the instance and
    ``__exit__`` drops the reference to the underlying key.
    """

    def __init__(self):
        self._impl = core2.GradKey()

    def wrt(self, *tensors, callback=None):
        """Attach each of ``tensors`` to the key, passing ``callback`` along.

        Returns ``self`` so that calls can be chained.
        """
        for tensor in tensors:
            self._impl.attach(tensor, callback)
        return self

    def __call__(self, ys, dys):
        """Invoke ``core2.backward`` for outputs ``ys`` and gradients ``dys``.

        An argument that is not a ``Sequence`` is wrapped in a one-element
        list before being handed over.
        """
        from collections.abc import Sequence

        ys = ys if isinstance(ys, Sequence) else [ys]
        dys = dys if isinstance(dys, Sequence) else [dys]
        core2.backward(self._impl, ys, dys)

    def __enter__(self):
        return self

    def __exit__(self, _1, _2, _3):
        # Release the reference to the key when the ``with`` block ends.
        del self._impl
...@@ -6,11 +6,18 @@ ...@@ -6,11 +6,18 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
from .._imperative_rt.core2 import Tensor
from ..tensor.core import OpBase, TensorBase, apply from ..tensor.core import OpBase, TensorBase, apply
class Const(OpBase): class Const:
def __init__(self, value=None, *, dtype=None, device=None): def __init__(self, value=None, *, dtype=None, device=None):
self.value = value self.value = np.asarray(value, dtype=dtype)
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
def __call__(self, *reference):
Wrapper = type(reference[0])
return (Wrapper(self.value, self.dtype, self.device),)
...@@ -13,9 +13,17 @@ import sys ...@@ -13,9 +13,17 @@ import sys
import typing import typing
from abc import ABC from abc import ABC
from .._imperative_rt.core2 import apply as apply2
from .multipledispatch import Dispatcher from .multipledispatch import Dispatcher
def apply_op(op, *args):
    """Apply ``op`` to the ``_tensor`` of every argument and re-wrap the results.

    Each value returned by ``apply2`` is wrapped with the type of the first
    argument; the wrapped results are returned as a tuple.
    """
    wrapper_cls = type(args[0])
    raw_inputs = [wrapped._tensor for wrapped in args]
    outputs = apply2(op, *raw_inputs)
    return tuple(wrapper_cls(out) for out in outputs)
class OpBase(ABC): class OpBase(ABC):
def __call__(self, *args): def __call__(self, *args):
return apply(self, *args) return apply(self, *args)
......
...@@ -10,10 +10,10 @@ from typing import Iterable ...@@ -10,10 +10,10 @@ from typing import Iterable
import numpy as np import numpy as np
from .._imperative_rt.core2 import Tensor, apply
from .._trace_option import use_symbolic_shape from .._trace_option import use_symbolic_shape
from ..ops import builtin from ..ops import builtin
from ..ops.special import Const from ..ops.special import Const
from .core import TensorBase, TensorWrapperBase, apply
from .utils import astensor1d, isscalar, make_shape_tuple from .utils import astensor1d, isscalar, make_shape_tuple
...@@ -149,13 +149,13 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): ...@@ -149,13 +149,13 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
return True return True
def get_index(i): def get_index(i):
if not isinstance(i, (TensorBase, TensorWrapperBase)): if not isinstance(i, (Tensor)):
if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_:
(i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) (i,) = Const(i, dtype=np.bool_, device=inp.device)(inp)
else: else:
(i,) = Const(i, dtype=np.int32, device=inp.device)(inp) (i,) = Const(i, dtype=np.int32, device=inp.device)(inp)
return i return i
assert isinstance(i, (TensorBase, TensorWrapperBase)) assert isinstance(i, Tensor)
if i.dtype != np.bool_: if i.dtype != np.bool_:
return i return i
_, ind = apply(builtin.CondTake(), i, i) _, ind = apply(builtin.CondTake(), i, i)
...@@ -198,8 +198,8 @@ def try_condtake(tensor, index): ...@@ -198,8 +198,8 @@ def try_condtake(tensor, index):
return [] return []
if isinstance(index, np.ndarray): if isinstance(index, np.ndarray):
(index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor) (index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor)
assert isinstance(index, (TensorBase, TensorWrapperBase)) assert isinstance(index, Tensor)
if not isinstance(tensor, (TensorWrapperBase, TensorBase)): if not isinstance(tensor, Tensor):
raise TypeError("input must be a tensor") raise TypeError("input must be a tensor")
if tensor.device != index.device: if tensor.device != index.device:
raise ValueError( raise ValueError(
...@@ -227,7 +227,7 @@ def getitem(tensor, index): ...@@ -227,7 +227,7 @@ def getitem(tensor, index):
op = builtin.IndexingMultiAxisVec(items=items) op = builtin.IndexingMultiAxisVec(items=items)
(result,) = apply(op, tensor, *tensors) (result,) = apply(op, tensor, *tensors)
if ret_scalar: if ret_scalar:
result.__wrapped__._data._isscalar = True result.setscalar()
return result return result
...@@ -239,7 +239,7 @@ def setitem(tensor, index, value): ...@@ -239,7 +239,7 @@ def setitem(tensor, index, value):
if index.shape[0] == 0: if index.shape[0] == 0:
return tensor return tensor
tensor = tensor.reshape(-1) tensor = tensor.reshape(-1)
if not isinstance(value, (TensorBase, TensorWrapperBase)): if not isinstance(value, Tensor):
op = Const(value, dtype=tensor.dtype, device=tensor.device) op = Const(value, dtype=tensor.dtype, device=tensor.device)
(value,) = op(tensor) (value,) = op(tensor)
tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index)
...@@ -250,6 +250,7 @@ def setitem(tensor, index, value): ...@@ -250,6 +250,7 @@ def setitem(tensor, index, value):
op = builtin.Subtensor(items=items) op = builtin.Subtensor(items=items)
else: else:
op = builtin.IndexingMultiAxisVec(items=items) op = builtin.IndexingMultiAxisVec(items=items)
(tmp_result,) = apply(op, tensor, *tensors) (tmp_result,) = apply(op, tensor, *tensors)
# XXX: broadcast can always be applied even if shapes are equal # XXX: broadcast can always be applied even if shapes are equal
......
...@@ -8,19 +8,20 @@ ...@@ -8,19 +8,20 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc import abc
import collections import collections
from typing import Union
import numpy as np import numpy as np
from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import Tensor, apply
from .._trace_option import use_symbolic_shape from .._trace_option import use_symbolic_shape
from ..ops import builtin from ..ops import builtin
from ..ops.builtin import Elemwise, GetVarShape from ..ops.builtin import Elemwise, GetVarShape
from ..ops.special import Const from ..ops.special import Const
from . import utils from . import utils
from .core import OpBase, TensorBase, TensorWrapperBase, apply from .core import OpBase, TensorBase, TensorWrapperBase
from .indexing import getitem as _getitem from .indexing import getitem as _getitem
from .indexing import setitem as _setitem from .indexing import setitem as _setitem
from .raw_tensor import RawTensor, as_raw_tensor
from .tensor import Tensor
from .utils import isscalar from .utils import isscalar
from .utils import make_shape_tuple as _make_shape_tuple from .utils import make_shape_tuple as _make_shape_tuple
from .utils import setscalar from .utils import setscalar
...@@ -41,6 +42,7 @@ def _elwise(*args, mode): ...@@ -41,6 +42,7 @@ def _elwise(*args, mode):
) )
args = utils.convert_inputs(*args) args = utils.convert_inputs(*args)
(result,) = apply(op, *args) (result,) = apply(op, *args)
_isscalar = True _isscalar = True
for i in args: for i in args:
if isscalar(i) == False: if isscalar(i) == False:
...@@ -84,9 +86,7 @@ def _reshape(x, shape): ...@@ -84,9 +86,7 @@ def _reshape(x, shape):
if unspec_axis is not None: if unspec_axis is not None:
raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i)) raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i))
unspec_axis = i unspec_axis = i
shape = utils.astensor1d(shape, x, dtype="int32", device=x.device) shape = utils.astensor1d(shape, x, dtype="int32", device=x.device)
if unspec_axis is None: if unspec_axis is None:
op = builtin.Reshape() op = builtin.Reshape()
else: else:
...@@ -181,7 +181,6 @@ def _reduce(mode): ...@@ -181,7 +181,6 @@ def _reduce(mode):
elif isinstance(axis, collections.abc.Iterable): elif isinstance(axis, collections.abc.Iterable):
axis = list(axis) axis = list(axis)
axis.sort(reverse=True) axis.sort(reverse=True)
for ai in axis: for ai in axis:
op = builtin.Reduce(mode=mode, axis=ai) op = builtin.Reduce(mode=mode, axis=ai)
(data,) = apply(op, data) (data,) = apply(op, data)
...@@ -221,10 +220,7 @@ def _todo(*_): ...@@ -221,10 +220,7 @@ def _todo(*_):
def _expand_args(args): def _expand_args(args):
if len(args) == 1: if len(args) == 1:
if isinstance( if isinstance(args[0], (collections.abc.Sequence, Tensor, np.ndarray),):
args[0],
(collections.abc.Sequence, TensorBase, TensorWrapperBase, np.ndarray),
):
args = args[0] args = args[0]
return args return args
...@@ -240,9 +236,8 @@ class ArrayMethodMixin(abc.ABC): ...@@ -240,9 +236,8 @@ class ArrayMethodMixin(abc.ABC):
return self.numpy().astype(dtype) return self.numpy().astype(dtype)
def __array_wrap__(self, array): def __array_wrap__(self, array):
return TensorWrapper( Wrapper = type(self)
as_raw_tensor(array, dtype=array.dtype, device=self.device) return Wrapper(array, dtype=array.dtype, device=self.device)
)
@abc.abstractmethod @abc.abstractmethod
def _reset(self, other): def _reset(self, other):
...@@ -253,7 +248,11 @@ class ArrayMethodMixin(abc.ABC): ...@@ -253,7 +248,11 @@ class ArrayMethodMixin(abc.ABC):
pass pass
@abc.abstractproperty @abc.abstractproperty
def shape(self) -> tuple: def shape(self) -> Union[tuple, Tensor]:
pass
@abc.abstractproperty
def _tuple_shape(self) -> tuple:
pass pass
@abc.abstractmethod @abc.abstractmethod
...@@ -331,7 +330,7 @@ class ArrayMethodMixin(abc.ABC): ...@@ -331,7 +330,7 @@ class ArrayMethodMixin(abc.ABC):
__complex__ = lambda self: complex(self.item()) __complex__ = lambda self: complex(self.item())
def __len__(self): def __len__(self):
shape = self.__wrapped__.shape shape = self._tuple_shape
if shape: if shape:
return int(shape[0]) return int(shape[0])
raise TypeError("ndim is 0") raise TypeError("ndim is 0")
...@@ -352,7 +351,7 @@ class ArrayMethodMixin(abc.ABC): ...@@ -352,7 +351,7 @@ class ArrayMethodMixin(abc.ABC):
@property @property
def ndim(self): def ndim(self):
shape = self.__wrapped__.shape shape = self._tuple_shape
if shape is None: if shape is None:
raise ValueError("unkown ndim") raise ValueError("unkown ndim")
return len(shape) return len(shape)
...@@ -480,22 +479,52 @@ class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase): ...@@ -480,22 +479,52 @@ class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase):
self.__wrapped__._swap_out() self.__wrapped__._swap_out()
class TensorWrapper(GenericTensorWrapper): class TensorWrapper(ArrayMethodMixin, TensorBase):
def __init__(self, data, dtype=None, device=None): def __init__(self, data, dtype=None, device=None, isscalar=False):
if isinstance(data, TensorWrapperBase): self._isscalar = isscalar
data = data.__wrapped__ if isinstance(data, Tensor):
elif not isinstance(data, TensorBase): self._tensor = data
assert data is not None, "Cannot init a tensor with data as None" else:
data = Tensor(as_raw_tensor(data, dtype=dtype, device=device)) if device is None:
super().__init__(data) device = CompNode._get_default_device()
self._tensor = Tensor(data, dtype, device)
def _reset(self, other): def _reset(self, other):
if isinstance(other, TensorWrapperBase): if not isinstance(other, __class__):
self.__wrapped__ = other.__wrapped__ raise TypeError(type(other))
elif isinstance(other, TensorBase): self._tensor = other._tensor
self.__wrapped__ = other return self
else:
self._reset(type(self)(other, dtype=self.dtype, device=self.device)) @property
def dtype(self):
return self._tensor.dtype
@property
def shape(self):
    """Shape of the wrapped tensor.

    Returns ``()`` when the wrapper is flagged as scalar. Otherwise the
    shape reported by the wrapped tensor is returned, except when it is not
    ``()`` and symbolic shapes are enabled, in which case the first result
    of applying ``GetVarShape`` to ``self`` is returned.
    """
    if self._isscalar:
        return ()
    concrete = self._tensor.shape
    if concrete == ():
        return concrete
    if use_symbolic_shape():
        return apply(GetVarShape(), self)[0]
    return concrete
@property
def device(self):
    """Device reported by the wrapped tensor."""
    wrapped = self._tensor
    return wrapped.device
def numpy(self):
    """Return the result of the wrapped tensor's ``numpy()``.

    When the wrapper is flagged as scalar, the result is squeezed first.
    """
    scalar = self._isscalar
    value = self._tensor.numpy()
    return value.squeeze() if scalar else value
def _drop(self):
    """Forward to ``_drop()`` of the wrapped tensor."""
    wrapped = self._tensor
    wrapped._drop()
def _swap_in(self):
    """Forward to ``_swap_in()`` of the wrapped tensor."""
    wrapped = self._tensor
    wrapped._swap_in()
def _swap_out(self):
    """Forward to ``_swap_out()`` of the wrapped tensor."""
    wrapped = self._tensor
    wrapped._swap_out()
def __repr__(self): def __repr__(self):
piece = "Tensor(" piece = "Tensor("
......
...@@ -11,9 +11,10 @@ from typing import Iterable, Union ...@@ -11,9 +11,10 @@ from typing import Iterable, Union
import numpy as np import numpy as np
from .._imperative_rt.core2 import Tensor, apply
from ..ops import builtin from ..ops import builtin
from ..ops.special import Const from ..ops.special import Const
from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply from ..tensor.core import OpBase, TensorBase, TensorWrapperBase
from .dtype import is_equal, is_quantize from .dtype import is_equal, is_quantize
_enable_convert_inputs = True _enable_convert_inputs = True
...@@ -109,7 +110,7 @@ def dtype_promotion(inputs): ...@@ -109,7 +110,7 @@ def dtype_promotion(inputs):
def get_device(inputs): def get_device(inputs):
device = None device = None
for i in inputs: for i in inputs:
if isinstance(i, (TensorWrapperBase, TensorBase)): if isinstance(i, Tensor):
if device is None: if device is None:
device = i.device device = i.device
elif device != i.device: elif device != i.device:
...@@ -126,30 +127,31 @@ def concatenate(inputs, axis=0, *, device=None): ...@@ -126,30 +127,31 @@ def concatenate(inputs, axis=0, *, device=None):
return convert_single_value(x, inputs, dtype=dtype) return convert_single_value(x, inputs, dtype=dtype)
inputs = tuple(map(convert, inputs)) inputs = tuple(map(convert, inputs))
(result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inputs) (result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs)
return result return result
def astype(x, dtype): def astype(x, dtype):
dtype = np.dtype(dtype) dtype = np.dtype(dtype)
if not is_equal(x.dtype, dtype): if not is_equal(x.dtype, dtype):
isscalar = x.__wrapped__._data._isscalar isscalar = x.isscalar()
(x,) = apply(builtin.TypeCvt(dtype=dtype), x) (x,) = apply(builtin.TypeCvt(dtype=dtype), x)
x.__wrapped__._data._isscalar = isscalar if isscalar:
x.setscalar()
return x return x
def convert_single_value(v, inputs, *, dtype=None, device=None): def convert_single_value(v, inputs, *, dtype=None, device=None):
tensors = [i for i in inputs if isinstance(i, (TensorBase, TensorWrapperBase))] tensors = [i for i in inputs if isinstance(i, Tensor)]
assert len(tensors) > 0 assert len(tensors) > 0
if isinstance(v, (TensorWrapperBase, TensorBase)): if isinstance(v, (TensorWrapperBase, Tensor)):
v = astype(v, v.dtype if is_quantize(v.dtype) else dtype) v = astype(v, v.dtype if is_quantize(v.dtype) else dtype)
else: else:
(v,) = Const(v, dtype=dtype, device=device)(*tensors) (v,) = Const(v, dtype=dtype, device=device)(*tensors)
return v return v
def convert_inputs(*args: TensorBase): def convert_inputs(*args: Tensor):
if not _enable_convert_inputs: if not _enable_convert_inputs:
return args return args
...@@ -167,7 +169,7 @@ def convert_inputs(*args: TensorBase): ...@@ -167,7 +169,7 @@ def convert_inputs(*args: TensorBase):
def result_type(*args): def result_type(*args):
dtypes = [] dtypes = []
for i in args: for i in args:
if isinstance(i, (TensorWrapperBase, TensorBase)): if isinstance(i, Tensor):
dtypes.append(i.dtype) dtypes.append(i.dtype)
continue continue
try: try:
...@@ -178,25 +180,16 @@ def result_type(*args): ...@@ -178,25 +180,16 @@ def result_type(*args):
def isscalar(x): def isscalar(x):
if isinstance(x, TensorWrapperBase):
x = x.__wrapped__
if hasattr(x, "_isscalar"): if isinstance(x, Tensor):
return x._isscalar return x.isscalar()
if isinstance(x, TensorBase):
return x._data._isscalar
return np.isscalar(x) return np.isscalar(x)
def setscalar(x): def setscalar(x):
if isinstance(x, TensorWrapperBase): if isinstance(x, Tensor):
x = x.__wrapped__ x.setscalar()
if hasattr(x, "_isscalar"):
x._isscalar = True
elif isinstance(x, TensorBase):
x._data._isscalar = True
else: else:
raise NotImplementedError("Unsupport type {}".format(type(x))) raise NotImplementedError("Unsupport type {}".format(type(x)))
...@@ -215,25 +208,24 @@ def astensor1d(x, *reference, dtype=None, device=None): ...@@ -215,25 +208,24 @@ def astensor1d(x, *reference, dtype=None, device=None):
else: else:
if ndim != 0 and ndim != 1: if ndim != 0 and ndim != 1:
raise ValueError("ndim != 1 or 0, get : %d" % ndim) raise ValueError("ndim != 1 or 0, get : %d" % ndim)
if not isinstance(x, (TensorBase, TensorWrapperBase)): if not isinstance(x, Tensor):
(x,) = Const(x, dtype=dtype, device=device)(*reference) (x,) = Const(x, dtype=dtype, device=device)(*reference)
return x return x
if not isinstance(x, collections.abc.Sequence): if not isinstance(x, collections.abc.Sequence):
raise TypeError raise TypeError
if any(isinstance(i, (TensorBase, TensorWrapperBase)) for i in x): if any(isinstance(i, Tensor) for i in x):
x = concatenate(x, device=device) x = concatenate(x, device=device)
if dtype is not None: if dtype is not None:
x = astype(x, dtype) x = astype(x, dtype)
return x return x
(x,) = Const(x, dtype=dtype, device=device)(*reference) (x,) = Const(x, dtype=dtype, device=device)(*reference)
return x return x
def _expand_int(s, i): def _expand_int(s, i):
if isinstance(i, (TensorBase, TensorWrapperBase)): if isinstance(i, Tensor):
i_np = i.numpy() i_np = i.numpy()
if i_np.ndim == 0: if i_np.ndim == 0:
s.append(int(i_np)) s.append(int(i_np))
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple from typing import Optional, Tuple
from ..core._imperative_rt.core2 import apply
from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn
from ..core.autodiff.grad import ( from ..core.autodiff.grad import (
Tracer, Tracer,
...@@ -17,7 +18,6 @@ from ..core.autodiff.grad import ( ...@@ -17,7 +18,6 @@ from ..core.autodiff.grad import (
tracer_apply, tracer_apply,
) )
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.core import apply
from ..core.tensor.tensor import Tensor, tensor_apply from ..core.tensor.tensor import Tensor, tensor_apply
from ..device import get_default_device from ..device import get_default_device
from ..tensor import tensor from ..tensor import tensor
...@@ -39,71 +39,6 @@ __all__ = [ ...@@ -39,71 +39,6 @@ __all__ = [
] ]
@apply.register()
def _(op: RemoteSend, *args: Tensor):
ret = tensor_apply(op, *args)
# set extra information
tracer_set = dict()
for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))):
tracer_set[k.name] = True
# check tracer_set in remote_recv
get_client().set_remote_tracer(op.key, tracer_set)
return ret
@builtin_op_get_backward_fn.register(RemoteSend)
def _(op: RemoteSend, inputs, outputs, input_requires_grad):
def backward(*args):
return [
remote_recv(
op.rank_to,
inputs[0].shape,
inputs[0].dtype,
device=str(inputs[0].device),
inp=inputs[0],
)
]
return backward, [True]
@get_op_has_grad_fn.register(RemoteSend)
def _(op: RemoteSend):
def has_grad(opnode, reached):
return get_client().check_is_grad(op.key)
return has_grad
@check_backward_allow_noinput.register(RemoteSend)
def _(op: RemoteSend):
return True
@builtin_op_get_backward_fn.register(RemoteRecv)
def _(op: RemoteRecv, inputs, outputs, input_requires_grad):
def backward(*output_grads):
return [remote_send(output_grads[0], op.rank_from)]
return backward, [True]
@get_op_has_grad_fn.register(RemoteRecv)
def _(op: RemoteRecv):
def has_grad(opnode, reached):
ret = False
for v in opnode.outputs:
if v() in reached:
ret = True
break
get_client().set_is_grad(op.key, ret)
return ret
return has_grad
def collective_comm(inp, mode, group, device): def collective_comm(inp, mode, group, device):
"""Helper function for applying collective communication functions.""" """Helper function for applying collective communication functions."""
assert isinstance(group, Group) assert isinstance(group, Group)
......
...@@ -17,8 +17,8 @@ import numpy as np ...@@ -17,8 +17,8 @@ import numpy as np
from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager
from megengine.device import get_default_device, get_device_count from megengine.device import get_default_device, get_device_count
from ..core._imperative_rt.core2 import apply
from ..core.ops.builtin import ParamPackConcat, ParamPackSplit from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
from ..core.tensor.core import apply
from ..functional.utils import copy from ..functional.utils import copy
from ..tensor import Tensor from ..tensor import Tensor
from ..utils.future import Future from ..utils.future import Future
...@@ -228,7 +228,6 @@ class AllreduceCallback: ...@@ -228,7 +228,6 @@ class AllreduceCallback:
self._packing_size[dtype] = 0 self._packing_size[dtype] = 0
def __call__(self, param, grad): def __call__(self, param, grad):
param = param.__wrapped__
gm = get_backwarding_grad_manager() gm = get_backwarding_grad_manager()
assert isinstance(gm, GradManager) assert isinstance(gm, GradManager)
if gm not in self._marked_gm: if gm not in self._marked_gm:
......
...@@ -9,10 +9,10 @@ ...@@ -9,10 +9,10 @@
# pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
import functools import functools
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin from ..core.ops import builtin
from ..core.ops.builtin import Elemwise from ..core.ops.builtin import Elemwise
from ..core.tensor import megbrain_graph, utils from ..core.tensor import megbrain_graph, utils
from ..core.tensor.core import apply
from ..core.tensor.utils import isscalar, setscalar from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device from ..device import get_default_device
from ..jit.tracing import is_tracing from ..jit.tracing import is_tracing
......
...@@ -12,10 +12,11 @@ import math ...@@ -12,10 +12,11 @@ import math
import numbers import numbers
from typing import Optional, Sequence, Tuple, Union from typing import Optional, Sequence, Tuple, Union
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin from ..core.ops import builtin
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor import utils from ..core.tensor import utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply from ..core.tensor.core import TensorBase, TensorWrapperBase
from ..tensor import Tensor from ..tensor import Tensor
from .elemwise import clip, exp, log, log1p from .elemwise import clip, exp, log, log1p
from .tensor import reshape, squeeze from .tensor import reshape, squeeze
......
...@@ -10,12 +10,12 @@ ...@@ -10,12 +10,12 @@
from typing import Optional, Sequence, Tuple, Union from typing import Optional, Sequence, Tuple, Union
from ..core._imperative_rt import CompNode from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import Tensor, apply
from ..core._trace_option import use_symbolic_shape from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm from ..core.ops.builtin import BatchNorm
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor import megbrain_graph, utils from ..core.tensor import megbrain_graph, utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.utils import astensor1d from ..core.tensor.utils import astensor1d
from ..distributed import WORLD, is_distributed from ..distributed import WORLD, is_distributed
from ..jit.tracing import is_tracing from ..jit.tracing import is_tracing
...@@ -1565,9 +1565,7 @@ def indexing_one_hot( ...@@ -1565,9 +1565,7 @@ def indexing_one_hot(
[1.] [1.]
""" """
assert isinstance( assert isinstance(src, Tensor), "src must be of Tensor type"
src, (TensorWrapperBase, TensorBase)
), "src must be of Tensor type"
op = builtin.IndexingOneHot(axis=axis) op = builtin.IndexingOneHot(axis=axis)
index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device) index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device)
(result,) = apply(op, src, index) (result,) = apply(op, src, index)
......
...@@ -8,8 +8,8 @@ ...@@ -8,8 +8,8 @@
# pylint: disable=too-many-lines # pylint: disable=too-many-lines
from typing import Tuple, Union from typing import Tuple, Union
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin from ..core.ops import builtin
from ..core.tensor.core import apply
from ..tensor import Tensor from ..tensor import Tensor
from .debug_param import get_conv_execution_strategy from .debug_param import get_conv_execution_strategy
from .types import _pair, _pair_nonzero from .types import _pair, _pair_nonzero
......
...@@ -14,10 +14,10 @@ from typing import Iterable, List, Optional, Sequence, Tuple, Union ...@@ -14,10 +14,10 @@ from typing import Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
from ..core._imperative_rt import CompNode from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import Tensor, apply
from ..core._wrap import device as as_device from ..core._wrap import device as as_device
from ..core.ops import builtin from ..core.ops import builtin
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis
from ..core.tensor.utils import ( from ..core.tensor.utils import (
astensor1d, astensor1d,
...@@ -611,11 +611,11 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: ...@@ -611,11 +611,11 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
""" """
x, y = convert_inputs(x, y) x, y = convert_inputs(x, y)
if not isinstance(x, (TensorWrapperBase, TensorBase)): if not isinstance(x, Tensor):
raise TypeError("input x must be a tensor") raise TypeError("input x must be a tensor")
if not isinstance(y, (TensorWrapperBase, TensorBase)): if not isinstance(y, Tensor):
raise TypeError("input y must be a tensor") raise TypeError("input y must be a tensor")
if not isinstance(mask, (TensorWrapperBase, TensorBase)): if not isinstance(mask, Tensor):
raise TypeError("mask must be a tensor") raise TypeError("mask must be a tensor")
if mask.dtype != np.bool_: if mask.dtype != np.bool_:
raise ValueError("mask must be bool") raise ValueError("mask must be bool")
...@@ -668,9 +668,9 @@ def cond_take(mask: Tensor, x: Tensor) -> Tensor: ...@@ -668,9 +668,9 @@ def cond_take(mask: Tensor, x: Tensor) -> Tensor:
[1. 4.] [0 3] [1. 4.] [0 3]
""" """
if not isinstance(x, (TensorWrapperBase, TensorBase)): if not isinstance(x, Tensor):
raise TypeError("input must be a tensor") raise TypeError("input must be a tensor")
if not isinstance(mask, (TensorWrapperBase, TensorBase)): if not isinstance(mask, Tensor):
raise TypeError("mask must be a tensor") raise TypeError("mask must be a tensor")
if mask.dtype != np.bool_: if mask.dtype != np.bool_:
raise ValueError("mask must be bool") raise ValueError("mask must be bool")
......
...@@ -11,10 +11,10 @@ from typing import Iterable, Union ...@@ -11,10 +11,10 @@ from typing import Iterable, Union
import numpy as np import numpy as np
from ..core._imperative_rt.core2 import apply
from ..core._wrap import device as as_device from ..core._wrap import device as as_device
from ..core.ops.builtin import Copy, Identity from ..core.ops.builtin import Copy, Identity
from ..core.tensor import Tensor from ..tensor import Tensor
from ..core.tensor.core import apply
from .math import topk as _topk from .math import topk as _topk
from .tensor import broadcast_to, transpose from .tensor import broadcast_to, transpose
......
...@@ -10,9 +10,9 @@ from typing import Iterable, Optional ...@@ -10,9 +10,9 @@ from typing import Iterable, Optional
from .. import Tensor from .. import Tensor
from ..core._imperative_rt import invoke_op from ..core._imperative_rt import invoke_op
from ..core._imperative_rt.core2 import apply
from ..core.ops.builtin import GaussianRNG, UniformRNG from ..core.ops.builtin import GaussianRNG, UniformRNG
from ..core.tensor import utils from ..core.tensor import utils
from ..core.tensor.core import apply
from .rng import _random_seed_generator from .rng import _random_seed_generator
__all__ = ["normal", "uniform"] __all__ = ["normal", "uniform"]
......
...@@ -10,26 +10,66 @@ ...@@ -10,26 +10,66 @@
import collections import collections
from .core import Tensor as _Tensor import numpy as np
from .core.ops.builtin import Copy
from .core.tensor.core import apply from .core._imperative_rt import CompNode
from .core._imperative_rt.core2 import Tensor as _Tensor
from .core._imperative_rt.core2 import apply
from .core._trace_option import use_symbolic_shape
from .core.ops.builtin import Copy, GetVarShape
from .core.tensor.raw_tensor import as_device from .core.tensor.raw_tensor import as_device
from .core.tensor.tensor_wrapper import ArrayMethodMixin
from .device import _valid_device, get_default_device from .device import _valid_device, get_default_device
from .utils.deprecation import deprecated from .utils.deprecation import deprecated
class Tensor(_Tensor): class Tensor(_Tensor, ArrayMethodMixin):
grad = None grad = None
dmap_callback = None dmap_callback = None
q_dict = {"mode": None, "scale": None, "zero_point": None}
def __init__(self, data, dtype=None, device=None): def __new__(cls, data, dtype=None, device=None):
if device is None: if device is None:
device = get_default_device() cn = get_default_device()
self.q_dict = {"mode": None, "scale": None, "zero_point": None} elif isinstance(device, str):
super().__init__(data, dtype=dtype, device=device) if cls.dmap_callback is not None:
cn = CompNode(cls.dmap_callback(device))
else:
cn = CompNode(device)
else:
assert isinstance(device, CompNode)
cn = device
if isinstance(data, _Tensor):
obj = _Tensor.__new__(cls, data)
else:
obj = _Tensor.__new__(cls, data, dtype, cn)
return obj
@property
def shape(self):
shape = super().shape
if shape == () or not use_symbolic_shape():
return shape
return apply(GetVarShape(), self)[0]
@property
def _tuple_shape(self):
return super().shape
def __repr__(self):
piece = "Tensor("
with np.printoptions(precision=4, suppress=True):
piece += "{}".format(str(self.numpy()))
if self.dtype != np.float32:
piece += ", dtype={}".format(np.dtype(self.dtype).name)
piece += ", device={}".format(self.device) + ")"
return piece
@deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0") @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
def set_value(self, value): def set_value(self, value):
if not isinstance(value, _Tensor):
value = Tensor(value, dtype=self.dtype, device=self.device)
self._reset(value) self._reset(value)
@deprecated(version="1.0", reason="use *= 0 instead") @deprecated(version="1.0", reason="use *= 0 instead")
...@@ -61,27 +101,22 @@ class Tensor(_Tensor): ...@@ -61,27 +101,22 @@ class Tensor(_Tensor):
def __hash__(self): def __hash__(self):
return id(self) return id(self)
def __getnewargs__(self):
r""" __getnewargs__ will be called for pickle serialization or deep copy
"""
return (self.numpy(), self.dtype, self.device.logical_name)
def __getstate__(self): def __getstate__(self):
r""" __getstate__ will be called for pickle serialization or deep copy r""" __getstate__ will be called for pickle serialization or deep copy
""" """
state = { state = {
"data": self.numpy(),
"device": self.device.logical_name,
"dtype": self.dtype,
"qdict": self.q_dict, "qdict": self.q_dict,
} }
return state return state
def __setstate__(self, state): def __setstate__(self, state):
data = state.pop("data")
logical_device = state.pop("device")
if self.dmap_callback is not None:
assert isinstance(logical_device, str)
logical_device = self.dmap_callback(logical_device)
dtype = state.pop("dtype")
self.q_dict = state.pop("qdict") self.q_dict = state.pop("qdict")
super().__init__(data, dtype=dtype, device=logical_device)
def detach(self): def detach(self):
r""" r"""
...@@ -89,8 +124,7 @@ class Tensor(_Tensor): ...@@ -89,8 +124,7 @@ class Tensor(_Tensor):
during backward gradient calcuation, i.e. its gradient is zero. during backward gradient calcuation, i.e. its gradient is zero.
""" """
Wrapper = type(self) Wrapper = type(self)
Tensor = type(self.__wrapped__) return Wrapper(self)
return Wrapper(Tensor(self.__wrapped__._data))
tensor = Tensor tensor = Tensor
......
/**
* \file imperative/python/src/grad.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./grad.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/utils/mempool.h"
namespace py = pybind11;
namespace mgb::imperative::python {
namespace {
// Non-owning reference to one slot of a GradFn: a weak pointer to the
// function plus the index of the slot inside it.
// NOTE(review): no use of this type is visible in this part of the file —
// confirm it is still needed.
struct GradSlotWeakPtr {
std::weak_ptr<GradFn> grad_fn;
size_t idx;
};
} // namespace
// Intrusive-list node that lives in a GradSlot's producer list
// (GradSlot::producer_head). One record is embedded in every
// GradSlotProducerPtr that refers to the slot.
struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> {
using Base = intrusive_list::Node<GradProducerRecord>;
// Unlinked record.
GradProducerRecord() = default;
// Record that is linked directly after `head` on construction.
GradProducerRecord(GradProducerRecord::head_t& head) : Base(intrusive_list::after_t{}, head) {}
// GradProducerRecord(GradProducerRecord&&) = default;
// GradProducerRecord& operator=(GradProducerRecord&) = default;
// GradProducerRecord& operator=(GradProducerRecord&&) = default;
};
// Gradient state kept for one output of a GradFn.
struct GradSlot {
// gradient accumulated so far for this slot; null until one is stored
std::shared_ptr<Tensor> grad;
// optional Python callable; GradKey::backward calls it with the wrapped gradient
py::object callback;
// head of the intrusive list of GradProducerRecords that refer to this slot
GradProducerRecord::head_t producer_head;
};
// A GradSlotPtr that additionally registers itself in the producer list of
// the slot it points to.
struct GradSlotProducerPtr : GradSlotPtr {
// list node linked into the target slot's producer_head list
GradProducerRecord producer_record;
// Empty pointer: no grad_fn, record not linked.
GradSlotProducerPtr() = default;
// Copies the slot reference out of `info` and links producer_record right
// after the referenced slot's producer_head.
GradSlotProducerPtr(GradInfo& info) : GradSlotPtr(info), producer_record(info->producer_head) {}
};
// Record of one differentiable operation (or of one attached leaf tensor).
struct GradFn : std::enable_shared_from_this<GradFn> {
// pool used by make()/deleter()
static MemPool<GradFn> pool;

// the GradKey this function was recorded under (weak: the key owns nothing here)
std::weak_ptr<GradKey> key;
// one slot per output of the recorded operation
SmallVector<GradSlot> slots;
// one entry per input of the recorded operation; entries without grad_fn are empty
SmallVector<GradSlotProducerPtr> dsts;
// tensors fed to the backward graph, in the order they are pushed by apply_grad/backward
SmallVector<std::shared_ptr<Tensor>> closure;
// backward graph of the operation; left null for GradFns created by GradKey::attach
std::shared_ptr<BackwardGraphResult> backward_graph;
// set by GradKey::backward once this GradFn has been added to its ref_keeper
bool in_ref_keeper = false;

// Deleter paired with make(): returns the object to the pool.
static void deleter(GradFn* ptr) {
pool.free(ptr);
}

// Allocate a GradFn from the pool.
// NOTE(review): this is a non-static member although it does not use `this`;
// the visible call sites use std::make_shared<GradFn>() instead — confirm intent.
std::shared_ptr<GradFn> make() {
return std::shared_ptr<GradFn>(pool.alloc(), &deleter);
}

// Drop every reference held by this GradFn. Destroying `dsts` also unlinks
// their producer records from the slots they pointed to.
void clear() {
key.reset();
slots.clear();
dsts.clear();
closure.clear();
backward_graph.reset();
}
};
// Resolve the pointer to the referenced slot: slots[idx] of grad_fn.
// grad_fn must be non-null and idx in range; neither is checked here.
GradSlot* GradSlotPtr::operator->() {
return &grad_fn->slots[idx];
}
namespace {
// Process-wide cache of backward graphs, keyed by the digest computed in
// make_backward_graph. The cache empties itself when comp nodes are finalized.
struct BackwardGraphCache : std::unordered_map<size_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject {
// Called on comp node finalization: drop all cached graphs.
std::shared_ptr<void> on_comp_node_finalize() override {
clear();
return {};
}
} backward_graph_cache;
// Return the backward graph for applying ctx.op to ctx.args, building and
// caching it on first use.
//
// The cache key is a digest over: the op's hash, each input's dtype handle and
// comp node, and whether each input currently has a grad_fn. The result is
// null when the generated BackwardGraphResult has no `backward` op; that null
// result is cached too.
// NOTE(review): only the digest is stored as key, so two configurations whose
// digests collide would share one cache entry.
std::shared_ptr<BackwardGraphResult> make_backward_graph(
ApplyContext& ctx, const apply_result_t& outputs) {
// hash
// key material layout: (1 + 2 * nargs) size_t values followed by nargs bools
static_assert(alignof(size_t) % alignof(bool) == 0);
size_t buf_size = (1 + ctx.nargs * 2) * sizeof(size_t) + ctx.nargs * sizeof(bool);
// NOTE(review): runtime-sized array; not standard C++, relies on a compiler extension
alignas(alignof(size_t)) std::byte buf[buf_size];
size_t* size_t_ptr = reinterpret_cast<size_t*>(buf);
bool* bool_ptr = reinterpret_cast<bool*>(size_t_ptr + (1 + ctx.nargs * 2));
// remember where the bool section starts for the consistency check below
bool* bool_ptr0 = bool_ptr;
*(size_t_ptr++) = ctx.op->hash();
for (size_t i = 0; i < ctx.nargs; ++i) {
*(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
*(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
*(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn);
}
// both write cursors must have ended exactly at the end of their section
mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
size_t key = XXHash{}.update(buf, buf_size).digest();

auto&& iter = backward_graph_cache.find(key);
if (iter != backward_graph_cache.end()) {
return iter->second;
}

// slow path
// describe inputs by comp node and dtype only; every output is marked as having a gradient
SmallVector<LogicalTensorDesc> inputs(ctx.nargs);
SmallVector<bool> input_requires_grad(ctx.nargs, false);
SmallVector<bool> output_has_grad(outputs.size(), true);
for (size_t i = 0; i < ctx.nargs; ++i) {
inputs[i].comp_node = ctx.args[i]->comp_node();
inputs[i].layout.dtype = ctx.args[i]->dtype();
input_requires_grad[i] = bool(ctx.args[i]->m_grad_info.grad_fn);
}
auto result = std::make_shared<BackwardGraphResult>(
proxy_graph_detail::make_backward_graph(
*ctx.op, inputs, input_requires_grad, output_has_grad));
// no backward op: represent "nothing to record" as a null pointer
if (!result->backward) {
result.reset();
}
backward_graph_cache.emplace(key, result);
return result;
}
} // namespace
// Apply ctx.op to ctx.args and, when at least one input is tracked by an
// active GradKey, record a GradFn on that key's tape so that
// GradKey::backward can later propagate gradients through the operation.
//
// Returns the forward outputs. Throws pyext17::py_err_set (with a Python
// NotImplementedError set) when the inputs are tracked by two different
// active GradKeys.
apply_result_t apply_grad(ApplyContext& ctx) {
std::shared_ptr<GradKey> grad_key;
// find the (single) active GradKey among the inputs and clear stale state
for (size_t i = 0; i < ctx.nargs; ++i) {
auto* tensor = ctx.args[i];
if (tensor->m_grad_info.grad_fn) {
auto&& input_grad_key = tensor->m_grad_info.grad_fn->key.lock();
// tensor is attached to a live GradKey
if (input_grad_key && input_grad_key->active) {
if (grad_key) {
if (grad_key != input_grad_key) {
PyErr_SetString(PyExc_NotImplementedError, "second order grad");
throw pyext17::py_err_set();
}
} else {
grad_key = std::move(input_grad_key);
}
} else {
// cleanup stale grad info
// under what condition?
tensor->m_grad_info = {};
tensor->m_flags &= ~Tensor::Flags::GRAD;
}
} else {
tensor->m_flags &= ~Tensor::Flags::GRAD;
}
}

// clear the GRAD flag on the context before calling apply() for the forward pass
ctx.flags &= ~Tensor::Flags::GRAD;

// perform forward apply_op or trace
auto outputs = apply(ctx);

if (!grad_key) {
return outputs;
}

auto backward_graph = make_backward_graph(ctx, outputs);
if (!backward_graph) {
return outputs;
}

auto grad_fn = std::make_shared<GradFn>();
grad_fn->key = grad_key;
grad_fn->slots.resize(outputs.size());
grad_fn->backward_graph = std::move(backward_graph);

// dsts gets exactly one entry per input so it can be indexed by input position
grad_fn->dsts.reserve(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; ++i) {
if (grad_fn->backward_graph->input_has_grad[i]) {
auto& input_grad_info = ctx.args[i]->m_grad_info;
grad_fn->dsts.emplace_back(input_grad_info);
// insert_after unlinks the record first, so it ends up linked exactly
// once, directly after the slot's producer_head
grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
} else {
grad_fn->dsts.emplace_back();
}
}

auto& save_for_backward = grad_fn->backward_graph->save_for_backward;
grad_fn->closure.reserve(std::count_if(save_for_backward.begin(), save_for_backward.end(), [](bool p){return p;}));

// save_for_backward has nargs + 2 * outputs.size() flags, in three sections:
//
// [0, nargs): input i is stored in the closure
//
// [nargs, nargs + outputs.size()): output i is stored in the closure
//
// [nargs + outputs.size(), end): the gradient of output i is fed to the
// backward graph (it is appended to the closure in GradKey::backward), and
// output i gets grad info attached here
//
// Example: perform c = a * b, where a is input data, b is parameter to be
// optimized, save_for_backward = [1, 1, 0, 1]
mgb_assert(ctx.nargs + 2 * outputs.size() == save_for_backward.size());

// record input tensors needed to take grad
for (size_t i = 0; i < ctx.nargs; ++i) {
if (save_for_backward[i]) {
grad_fn->closure.push_back(ctx.args[i]->shared_from_this());
}
}
// record output tensors needed to take grad
for (size_t i = 0; i < outputs.size(); ++i) {
bool requires_grad = save_for_backward[ctx.nargs + outputs.size() + i];
if (save_for_backward[ctx.nargs + i]) {
grad_fn->closure.push_back(outputs[i]);
if (requires_grad) {
// avoid reference cycle [Tensor <-> GradFn]
outputs[i] = outputs[i]->copy();
}
}
if (requires_grad) {
// link the output to slot i of the new GradFn and to the key's free-variable list
auto& grad_info = outputs[i]->m_grad_info;
grad_info.grad_fn = grad_fn;
grad_info.idx = i;
grad_info.insert_after(grad_key->free_vars_head);
outputs[i]->m_flags |= Tensor::Flags::GRAD;
}
}

// record forward history
grad_key->tape.emplace_back(grad_fn);

return outputs;
}
// Python-facing entry point for attach(tensor, callback).
//
// args[0] must be a Tensor wrapper; args[1] is the callback, where None means
// "no callback". Throws py::type_error on a wrong argument count or when the
// first argument is not a Tensor.
void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) {
    if (nargs != 2) {
        throw py::type_error("expect 2 arguments");
    }
    auto* wrapper = TensorWrapper::cast_safe(args[0]);
    if (wrapper == nullptr) {
        throw py::type_error("argument 1 must be Tensor");
    }
    PyObject* raw_callback = args[1];
    // None is mapped to an empty py::object rather than a reference to None
    py::object callback = (raw_callback == Py_None)
            ? py::object()
            : py::reinterpret_borrow<py::object>(raw_callback);
    m_key->attach(wrapper->m_tensor.get(), std::move(callback));
}
//! Start tracking gradients for \p tensor under this GradKey and register an
//! optional \p callback that backward() invokes with the tensor's gradient.
//!
//! After attach, this GradKey is weakly referred to by
//! tensor->m_grad_info.grad_fn->key.
//!
//! Throws py::value_error when the key is no longer active or when the tensor
//! already has a callback; throws pyext17::py_err_set (NotImplementedError
//! set) when the tensor is tracked under a different GradKey.
void GradKey::attach(Tensor* tensor, pybind11::object callback) {
    if (!active) {
        throw py::value_error("grad key finalized");
    }

    auto& grad_info = tensor->m_grad_info;
    if (grad_info.grad_fn) {
        // already tracked: it has to be by this very key
        if (grad_info.grad_fn->key.lock().get() != this) {
            PyErr_SetString(PyExc_NotImplementedError, "second order grad");
            throw pyext17::py_err_set();
        }
        if (grad_info->callback) {
            throw py::value_error("callback already set on this tensor");
        }
    } else {
        // untracked tensor: give it a GradFn of its own with a single slot
        grad_info.idx = 0;
        auto& grad_fn = grad_info.grad_fn;
        grad_fn = std::make_shared<GradFn>();
        grad_fn->key = shared_from_this();
        grad_fn->slots.resize(1);
        grad_info.insert_after(free_vars_head);
        tensor->m_flags |= Tensor::Flags::GRAD;
    }
    // Store the callback on the tensor's own slot. A tensor that is output
    // number idx > 0 of a recorded operation does not own slot 0, and the
    // duplicate check above also looks at slot idx.
    grad_info->callback = std::move(callback);
}
void accum_grad(std::shared_ptr<Tensor>& grad, std::shared_ptr<Tensor>&& delta) {
if (!grad) {
grad = std::forward<decltype(delta)>(delta);
return;
}
static ApplyContext ctx;
if (!ctx.op) {
ctx.op = std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD));
ctx.nargs = 2;
}
Tensor* args[2] = {grad.get(), delta.get()};
ctx.args = args;
ctx.flags = grad->m_flags | delta->m_flags;
grad = apply(ctx)[0];
}
// Back-propagate from `tensors`, seeded with `grads`, through everything
// recorded on this key's tape, invoking the callbacks registered via attach.
//
// The key is single-use: it is marked inactive on entry and cleanup() runs on
// every exit path. Throws py::value_error when the key is already inactive or
// when `tensors` and `grads` differ in length.
void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
if (!active) {
throw py::value_error("finalized");
}
if (tensors.size() != grads.size()) {
throw py::value_error("tensor and grad size mismatch");
}

// this GradKey is marked inactive here
active = false;
// runs cleanup() when this function returns or throws
struct CleanupGuard {
GradKey* owner;
CleanupGuard(GradKey* this_) : owner(this_) {}
~CleanupGuard() {owner->cleanup();}
} _cleanup_guard(this);

if (tape.empty() || grads.empty()) return;
// gradients handed to callbacks are wrapped in the Python type of the first seed gradient
PyTypeObject* pytype = Py_TYPE(grads[0]->self().ptr());

// seed the slots of the requested tensors; tensors not tracked by this key are ignored
for (size_t i = 0; i < tensors.size(); ++i) {
auto& grad_info = tensors[i]->m_tensor->m_grad_info;
if (grad_info.grad_fn && grad_info.grad_fn->key.lock().get() == this) {
grad_info->grad = grads[i]->m_tensor;
}
}

// keeps the GradFns of visited destinations alive until this function returns
std::vector<std::shared_ptr<GradFn>> ref_keeper;
ref_keeper.reserve(tape.size());
// back-propagation in reverse order
for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
auto&& grad_fn = tape[k].lock();
// the tape holds weak references; skip entries whose GradFn is already gone
if (!grad_fn) continue;
if (grad_fn->backward_graph) {
for (size_t i = 0; i < grad_fn->slots.size(); ++i) {
// grad_fn->dsts correspond to input tensors during forward
// calculation, grad_fn->slots correspond to output tensors.
// condition true means the output tensor has gradient for
// back-propagation
if (grad_fn->backward_graph->save_for_backward[grad_fn->dsts.size() + grad_fn->slots.size() + i]) {
grad_fn->closure.push_back(std::move(grad_fn->slots[i].grad));
}
}
ApplyContext ctx;
ctx.op = grad_fn->backward_graph->backward;
ctx.flags = 0;
ctx.nargs = grad_fn->closure.size();
// NOTE(review): runtime-sized array; not standard C++, relies on a compiler extension
Tensor* args[ctx.nargs];
for (size_t i = 0; i < ctx.nargs; ++i) {
args[i] = grad_fn->closure[i].get();
// NOTE(review): a slot gradient that was never set is pushed above as a
// null pointer and would trip this assertion
mgb_assert(args[i]);
ctx.flags |= args[i]->m_flags;
}
ctx.args = args;

// note: this local `grads` shadows the function parameter of the same name
auto grads = apply(ctx);

// the backward op returns one gradient per input flagged in input_has_grad, in input order
size_t j = 0;
for (size_t i = 0; i < grad_fn->dsts.size(); ++i) {
if (grad_fn->backward_graph->input_has_grad[i]) {
auto& dst = grad_fn->dsts[i];
// grads[j] is consumed in accum_grad
accum_grad(dst->grad, std::move(grads[j]));
++j;
}
}
mgb_assert(j == grads.size());
}
for (auto&& dst : grad_fn->dsts) {
if (!dst.grad_fn) continue;
if (!dst.grad_fn->in_ref_keeper) {
dst.grad_fn->in_ref_keeper = true;
ref_keeper.push_back(dst.grad_fn);
}
// grad_fn->clear will unlink current dst.producer_record
// such that if dst.producer_record.next is false, dst accumulates
// all the gradients
if (!dst.producer_record.next && dst->callback && dst->grad) {
dst->callback(TensorWrapper::make(pytype, dst->grad));
}
}
grad_fn->clear();
} // finish tape loop
}
// Deactivate the key, drop the tape and detach every tensor still linked in
// the free-variable list.
void GradKey::cleanup() {
active = false;
tape.clear();
for (intrusive_list::Iterator it(free_vars_head); it;) {
// Keep this order: the iterator is advanced only after grad_fn has been
// released. Presumably releasing a GradFn can destroy tensors whose nodes
// then unlink themselves from this list — TODO confirm.
it->grad_fn.reset();
(it++)->unlink();
}
}
// Forward to GradKey::backward on the wrapped key.
void GradKeyWrapper::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
m_key->backward(std::move(tensors), std::move(grads));
}
// Run cleanup() so no tensor keeps grad info that refers to a destroyed key.
GradKey::~GradKey() {
cleanup();
}
} // namespace mgb::imperative::python
/**
* \file imperative/python/src/grad.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "./tensor.h"
#include <megbrain/utils/small_vector.h>
#include <memory>
namespace mgb::imperative::python {
// Apply the op described by ctx and record gradient information for it when
// an input is tracked by an active GradKey; returns the forward outputs.
apply_result_t apply_grad(ApplyContext& ctx);
// One gradient-recording session: tensors are attached to it, operations on
// them are appended to its tape, and backward() consumes the tape once.
struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
std::string name;
// false once backward() has started or cleanup() has run
bool active = true;
// intrusive list of the GradInfo of tensors attached to / produced under this key
GradInfo::head_t free_vars_head;
// recorded operations in forward order, held weakly
std::vector<std::weak_ptr<GradFn>> tape;

~GradKey();

// Track gradients for `tensor`; `callback` may be an empty object.
void attach(Tensor* tensor, pybind11::object callback);
// Back-propagate from the first list (tensors) using the second list (their gradients).
void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
// Deactivate the key and release all recorded state.
void cleanup();
};
// Python-visible wrapper ("GradKey") that owns a GradKey through a shared_ptr.
struct GradKeyWrapper {
using wrap_t = pyext17::wrap<GradKeyWrapper>;
// name of the Python type
static constexpr auto tp_name = pybind11::detail::_("GradKey");

std::shared_ptr<GradKey> m_key;

// Every wrapper starts with a fresh GradKey.
inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {}

// attach(tensor, callback) taking raw Python arguments
void attach(PyObject*const* args, size_t nargs);
// backward(tensors, grads); forwards to GradKey::backward
void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
};
} // namespace mgb::imperative::python
namespace pybind11::detail {
// Let pybind11 convert Python GradKey objects to GradKeyWrapper by reusing the
// caster provided by pyext17::wrap.
template<> struct type_caster<mgb::imperative::python::GradKeyWrapper> : mgb::imperative::python::GradKeyWrapper::wrap_t::caster {};
} // namespace pybind11::detail
/**
* \file imperative/python/src/grad_info.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <memory>
#include "./intrusive_list.h"
namespace mgb::imperative::python {
struct GradFn;
struct GradSlot;
// Owning reference to one slot of a GradFn.
struct GradSlotPtr {
// function that owns the slot; null means "points at nothing"
std::shared_ptr<GradFn> grad_fn;
// index of the slot inside grad_fn->slots
size_t idx;
// Access the referenced GradSlot; defined where GradFn is complete.
GradSlot* operator->();
};
// Gradient bookkeeping stored on a tensor: a slot reference that is also a
// node of an intrusive list, using the before_t copy policy.
// Copy and assignment are declared with non-const reference parameters and
// defaulted explicitly, as the Node base requires of derived classes.
struct GradInfo : GradSlotPtr, intrusive_list::Node<GradInfo, intrusive_list::before_t> {
GradInfo() = default;
GradInfo(GradInfo&) = default;
GradInfo(GradInfo&&) = default;
GradInfo& operator=(GradInfo&) = default;
GradInfo& operator=(GradInfo&&) = default;
};
} // namespace mgb::imperative::python
/**
* \file imperative/python/src/intrusive_list.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/utils/metahelper.h"
namespace mgb::imperative::python::intrusive_list {
// copy policy
// Tag types selecting what copying a Node does: link the copy after the
// source, link it before the source, or disallow copying altogether. The
// first two also serve as dispatch tags for Node's constructors and insert().
struct after_t {};
struct before_t {};
struct disable_t {};
template <typename T> struct Tail;
// Forward half of a list link: holds only the pointer to the following
// element. A bare Head serves as the anchor of a list.
// invariant: next->prev == this
template <typename T>
struct Head {
Tail<T>* next;

Head(Tail<T>* node = nullptr) : next(node) {}
// not copyable: two heads cannot both satisfy the invariant
Head(const Head<T>&) = delete;
Head<T>& operator=(const Head<T>&) = delete;
// Moving takes over the successor and re-points its back link to this object.
Head(Head<T>&& rhs) : next(rhs.next) {
rhs.next = nullptr;
if (next) {
next->prev = this;
}
}
// Move assignment is only allowed onto a head that has no successor.
Head<T>& operator=(Head<T>&& rhs) {
mgb_assert(!next);
next = rhs.next;
rhs.next = nullptr;
if (next) {
next->prev = this;
}
return *this;
}
// Detach the successor so it is not left with a dangling back link.
~Head() {
if (next) {
next->prev = nullptr;
}
}
};
// Backward half of a list link: holds only the pointer to the preceding
// element.
// invariant: prev->next == this
template <typename T>
struct Tail {
Head<T>* prev;

Tail(Head<T>* node = nullptr) : prev(node) {}
// not copyable: two tails cannot both satisfy the invariant
Tail(const Tail<T>&) = delete;
Tail<T>& operator=(const Tail<T>&) = delete;
// Moving takes over the predecessor and re-points its forward link to this object.
Tail(Tail<T>&& rhs) : prev(rhs.prev) {
rhs.prev = nullptr;
if (prev) {
prev->next = this;
}
}
// Move assignment is only allowed onto a tail that has no predecessor.
Tail<T>& operator=(Tail<T>&& rhs) {
mgb_assert(!prev);
prev = rhs.prev;
rhs.prev = nullptr;
if (prev) {
prev->next = this;
}
return *this;
}
// Detach the predecessor so it is not left with a dangling forward link.
~Tail() {
if (prev) {
prev->next = nullptr;
}
}
};
template <typename T, typename policy> struct Node;
// Cursor over the elements of an intrusive list of T.
//
// The cursor is a raw pointer to the current element and converts to false
// once it has run off the end of the list. It can be started from a list head
// (first element), from a tail (the element before it) or from a node itself.
template <typename T>
class Iterator {
    T* ptr;

    // step to the following element via the Head<T> half of the current node
    void inc() {ptr = static_cast<T*>(ptr->Head<T>::next);}
    // step to the preceding element; `prev` is a member of the Tail<T> half
    // (Head<T> only has `next`)
    void dec() {ptr = static_cast<T*>(ptr->Tail<T>::prev);}

public:
    Iterator(Head<T>& head) : ptr(static_cast<T*>(head.next)) {}
    Iterator(Tail<T>& tail) : ptr(static_cast<T*>(tail.prev)) {}

    template<typename policy>
    Iterator(Node<T, policy>& node) : ptr(static_cast<T*>(&node)) {}

    T& operator*() {return *static_cast<T*>(ptr);}
    T* operator->() {return static_cast<T*>(ptr);}

    // true while the cursor points at an element
    operator bool() {return ptr;}
    bool operator==(const Iterator<T>& rhs) {return ptr == rhs.ptr;}

    Iterator& operator++() {inc(); return *this;}
    Iterator& operator--() {dec(); return *this;}
    // post-increment/decrement return the cursor as it was before the step
    Iterator operator++(int) {auto ret = *this; inc(); return ret;}
    Iterator operator--(int) {auto ret = *this; dec(); return ret;}
};
// Node in a doubly linked list. Unlike std::list, nodes are not owned by a container.
// Instead, nodes may join or leave a list freely.
// NOTE: Derived classes have to explicitly declare copy / assignment as default,
// otherwise the compiler generated version would use the const T& signature,
// which is deleted.
//
// T is the derived element type (Node itself when T is void). `policy` selects
// what copying a node does: after_t / before_t link the copy next to the
// source, disable_t (the default) leaves copying unavailable.
template <typename T = void, typename policy = disable_t>
struct Node : Tail<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>>,
Head<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>> {
private:
using this_t = Node<T, policy>;
// element type used by the Head/Tail bases
using U = std::conditional_t<std::is_same_v<T, void>, this_t, T>;

public:
using head_t = Head<U>;
using tail_t = Tail<U>;
using head_t::next;
using tail_t::prev;

// Unlinked node.
Node() = default;
Node(const this_t&) = delete;
this_t& operator=(const this_t&) = delete;

//! constructed node is inserted after the input node
Node(after_t, head_t& node) : tail_t(&node), head_t(node.next) {
node.next = this;
if (next) {
next->prev = this;
}
}

//! constructed node is inserted before the input node
Node(before_t, tail_t& node) : head_t(&node), tail_t(node.prev) {
node.prev = this;
if (prev) {
prev->next = this;
}
}

// Move construction: the new node takes rhs's place in the list and rhs is left unlinked.
Node(this_t&& rhs) : tail_t(rhs.prev), head_t(rhs.next) {
rhs.prev = nullptr;
rhs.next = nullptr;
if (prev) {
prev->next = this;
}
if (next) {
next->prev = this;
}
}

// Move assignment: leave the current list, then take rhs's place; rhs is left unlinked.
Node& operator=(this_t&& rhs) {
unlink();
prev = rhs.prev;
next = rhs.next;
rhs.prev = nullptr;
rhs.next = nullptr;
if (prev) {
prev->next = this;
}
if (next) {
next->prev = this;
}
return *this;
}

// "Copy" construction from a non-const node, available only for the
// after_t / before_t policies: the new node is linked next to rhs.
template<typename p = policy,
typename = std::enable_if_t<std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>>
Node(this_t& rhs) : Node(policy{}, rhs) {}

// "Copy" assignment from a non-const node, same policy restriction: this
// node is re-linked next to rhs.
template<typename p = policy,
typename = std::enable_if_t<std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>>
this_t& operator=(this_t& rhs) {
insert(policy{}, rhs);
return *this;
}

// Remove this node from its list, joining its neighbours; no-op when unlinked.
void unlink() {
if (prev) {
prev->next = next;
}
if (next) {
next->prev = prev;
}
prev = nullptr;
next = nullptr;
}

//! this node is unlinked from its list and inserted after the input node
void insert(after_t, head_t& node) {
unlink();
prev = &node;
next = node.next;
node.next = this;
if (next) {
next->prev = this;
}
}

//! this node is unlinked from its list and inserted before the input node
void insert(before_t, tail_t& node) {
unlink();
next = &node;
prev = node.prev;
node.prev = this;
if (prev) {
prev->next = this;
}
}

// Named shorthands for the tag-dispatched insert() overloads.
void insert_before(tail_t& node) {insert(before_t{}, node);}
void insert_after(head_t& node) {insert(after_t{}, node);}

// A node always leaves its list when destroyed.
~Node() {
unlink();
}
};
} // namespace mgb::imperative::python::intrusive_list
...@@ -23,7 +23,10 @@ ...@@ -23,7 +23,10 @@
#include "./dispatcher.h" #include "./dispatcher.h"
#include "./tensor.h"
namespace py = pybind11; namespace py = pybind11;
using namespace mgb::imperative::python;
#ifndef MODULE_NAME #ifndef MODULE_NAME
#define MODULE_NAME imperative_rt #define MODULE_NAME imperative_rt
...@@ -68,4 +71,6 @@ PYBIND11_MODULE(MODULE_NAME, m) { ...@@ -68,4 +71,6 @@ PYBIND11_MODULE(MODULE_NAME, m) {
py::getattr(m, "__dict__")); py::getattr(m, "__dict__"));
init_dispatcher(submodule(m, "dispatcher")); init_dispatcher(submodule(m, "dispatcher"));
init_tensor(submodule(m, "core2"));
} }
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <vector> #include <vector>
#include <utility> #include <utility>
#include <Python.h> #include <Python.h>
#include <pybind11/pybind11.h>
namespace pyext17 { namespace pyext17 {
...@@ -53,6 +54,26 @@ inline PyObject* cvt_retval(PyObject* rv) { ...@@ -53,6 +54,26 @@ inline PyObject* cvt_retval(PyObject* rv) {
return cvt_retval(__VA_ARGS__); \ return cvt_retval(__VA_ARGS__); \
} }
// Identity conversion used by CVT_RET_INT for callables that already return int.
inline int cvt_retint(int ret) {
return ret;
}
#define CVT_RET_INT(...) \
if constexpr (std::is_same_v<decltype(__VA_ARGS__), void>) { \
__VA_ARGS__; \
return 0; \
} else { \
return cvt_retint(__VA_ARGS__); \
}
// Thrown by C++ code after it has already set a Python error. HANDLE_ALL_EXC
// catches it and returns the failure value without touching the error.
struct py_err_set : std::exception {};
#define HANDLE_ALL_EXC(RET) catch(py_err_set&) {return RET;} \
catch(pybind11::error_already_set& e) {e.restore(); return RET;} \
catch(pybind11::builtin_exception& e) {e.set_error(); return RET;} \
catch(std::exception& e) {PyErr_SetString(PyExc_RuntimeError, e.what()); return RET;}
template <typename T> template <typename T>
struct wrap { struct wrap {
private: private:
...@@ -111,7 +132,9 @@ private: ...@@ -111,7 +132,9 @@ private:
static PyObject* impl(PyObject* self, PyObject*) { static PyObject* impl(PyObject* self, PyObject*) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
CVT_RET_PYOBJ((inst->*f)()); try {
CVT_RET_PYOBJ((inst->*f)());
} HANDLE_ALL_EXC(nullptr)
} }
}; };
...@@ -121,7 +144,9 @@ private: ...@@ -121,7 +144,9 @@ private:
static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) { static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
CVT_RET_PYOBJ((inst->*f)(args, kwargs)); try {
CVT_RET_PYOBJ((inst->*f)(args, kwargs));
} HANDLE_ALL_EXC(nullptr)
} }
}; };
...@@ -132,7 +157,9 @@ private: ...@@ -132,7 +157,9 @@ private:
static PyObject* impl(PyObject* self, PyObject*const* args, Py_ssize_t nargs) { static PyObject* impl(PyObject* self, PyObject*const* args, Py_ssize_t nargs) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
CVT_RET_PYOBJ((inst->*f)(args, nargs)); try {
CVT_RET_PYOBJ((inst->*f)(args, nargs));
} HANDLE_ALL_EXC(nullptr)
} }
#else #else
static constexpr int flags = METH_VARARGS; static constexpr int flags = METH_VARARGS;
...@@ -141,7 +168,9 @@ private: ...@@ -141,7 +168,9 @@ private:
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
auto* arr = &PyTuple_GET_ITEM(args, 0); auto* arr = &PyTuple_GET_ITEM(args, 0);
auto size = PyTuple_GET_SIZE(args); auto size = PyTuple_GET_SIZE(args);
CVT_RET_PYOBJ((inst->*f)(arr, size)); try {
CVT_RET_PYOBJ((inst->*f)(arr, size));
} HANDLE_ALL_EXC(nullptr)
} }
#endif #endif
}; };
...@@ -152,7 +181,9 @@ private: ...@@ -152,7 +181,9 @@ private:
static PyObject* impl(PyObject* self, PyObject* obj) { static PyObject* impl(PyObject* self, PyObject* obj) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
CVT_RET_PYOBJ((inst->*f)(obj)); try {
CVT_RET_PYOBJ((inst->*f)(obj));
} HANDLE_ALL_EXC(nullptr)
} }
}; };
...@@ -162,6 +193,55 @@ private: ...@@ -162,6 +193,55 @@ private:
return {name, (PyCFunction)M::impl, M::flags, doc}; return {name, (PyCFunction)M::impl, M::flags, doc};
} }
// Adapter that exposes `f` as the getter function of a PyGetSetDef entry.
//
// `f` may be a callable taking (PyObject* self, void* closure), a member of T
// taking (void* closure), or a member of T taking no arguments; anything else
// fails to compile. C++ exceptions are turned into a nullptr return by
// HANDLE_ALL_EXC.
template<auto f>
struct getter {
using F = decltype(f);

static PyObject* impl(PyObject* self, void* closure) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
try {
if constexpr (std::is_invocable_v<F, PyObject*, void*>) {
CVT_RET_PYOBJ(f(self, closure));
} else if constexpr (std::is_invocable_v<F, T, void*>) {
CVT_RET_PYOBJ((inst->*f)(closure));
} else if constexpr (std::is_invocable_v<F, T>) {
CVT_RET_PYOBJ((inst->*f)());
} else {
// always-false condition that depends on F: fires only if this branch is instantiated
static_assert(!std::is_same_v<F, F>);
}
} HANDLE_ALL_EXC(nullptr)
}
};
// Adapter that exposes `f` as the setter function of a PyGetSetDef entry.
//
// `f` may be a callable taking (PyObject* self, PyObject* val, void* closure),
// a member of T taking (PyObject* val, void* closure), or a member of T taking
// (PyObject* val). When `f` is nullptr, `impl` is nullptr as well, i.e. no
// setter is installed. C++ exceptions are turned into a -1 return by
// HANDLE_ALL_EXC.
template<auto f>
struct setter {
using F = decltype(f);

// Declared as a template so it is only instantiated when `impl` refers to it.
template<typename = void>
static int impl_(PyObject* self, PyObject* val, void* closure) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
try {
if constexpr (std::is_invocable_v<F, PyObject*, PyObject*, void*>) {
CVT_RET_INT(f(self, val, closure));
} else if constexpr (std::is_invocable_v<F, T, PyObject*, void*>) {
CVT_RET_INT((inst->*f)(val, closure));
} else if constexpr (std::is_invocable_v<F, T, PyObject*>) {
CVT_RET_INT((inst->*f)(val));
} else {
// always-false condition that depends on F: fires only if this branch is instantiated
static_assert(!std::is_same_v<F, F>);
}
} HANDLE_ALL_EXC(-1)
}

// nullptr when no setter was given, otherwise the adapter above
static constexpr auto impl = []() {if constexpr (std::is_same_v<F, std::nullptr_t>) return nullptr;
else return impl_<>;}();
};
// Build a PyGetSetDef entry from a getter and an optional setter (see the
// getter/setter adapters); `doc` and `closure` are passed through unchanged.
template<auto get, auto set = nullptr>
static constexpr PyGetSetDef make_getset_def(const char* name, const char* doc = nullptr, void* closure = nullptr) {
return {const_cast<char *>(name), getter<get>::impl, setter<set>::impl, const_cast<char *>(doc), closure};
}
// polyfills // polyfills
struct tp_vectorcall { struct tp_vectorcall {
...@@ -216,16 +296,26 @@ private: ...@@ -216,16 +296,26 @@ private:
template<typename = void> template<typename = void>
static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) { static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
struct FreeGuard {
PyObject* self;
PyTypeObject* type;
~FreeGuard() {if (self) type->tp_free(self);}
};
auto* self = type->tp_alloc(type, 0); auto* self = type->tp_alloc(type, 0);
FreeGuard free_guard{self, type};
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
if constexpr (has_vectorcall && tp_vectorcall::valid) { if constexpr (has_vectorcall && tp_vectorcall::valid) {
reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>; reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>;
} }
if constexpr (varkw) { try {
new(inst) T(args, kwargs); if constexpr (varkw) {
} else { new(inst) T(args, kwargs);
new(inst) T(); } else {
} new(inst) T();
}
} HANDLE_ALL_EXC(nullptr)
free_guard.self = nullptr;
return self; return self;
} }
...@@ -250,6 +340,7 @@ private: ...@@ -250,6 +340,7 @@ private:
public: public:
class TypeBuilder { class TypeBuilder {
std::vector<PyMethodDef> m_methods; std::vector<PyMethodDef> m_methods;
std::vector<PyGetSetDef> m_getsets;
PyTypeObject m_type; PyTypeObject m_type;
bool m_finalized = false; bool m_finalized = false;
bool m_ready = false; bool m_ready = false;
...@@ -259,6 +350,13 @@ public: ...@@ -259,6 +350,13 @@ public:
throw std::runtime_error("type is already finalized"); throw std::runtime_error("type is already finalized");
} }
} }
// Obtain a C string from T::tp_name, which may be either a plain C string
// or a pybind11 descr (whose `text` member is returned).
static const char* to_c_str(const char* s) {return s;}

template <size_t N, typename... Ts>
static const char* to_c_str(const pybind11::detail::descr<N, Ts...>& desc) {
return desc.text;
}
public: public:
TypeBuilder(const TypeBuilder&) = delete; TypeBuilder(const TypeBuilder&) = delete;
TypeBuilder& operator=(const TypeBuilder&) = delete; TypeBuilder& operator=(const TypeBuilder&) = delete;
...@@ -266,7 +364,7 @@ public: ...@@ -266,7 +364,7 @@ public:
TypeBuilder() : m_type{PyVarObject_HEAD_INIT(nullptr, 0)} { TypeBuilder() : m_type{PyVarObject_HEAD_INIT(nullptr, 0)} {
constexpr auto has_tp_name = HAS_MEMBER(T, tp_name); constexpr auto has_tp_name = HAS_MEMBER(T, tp_name);
if constexpr (has_tp_name) { if constexpr (has_tp_name) {
m_type.tp_name = T::tp_name; m_type.tp_name = to_c_str(T::tp_name);
} }
m_type.tp_dealloc = tp_dealloc::value; m_type.tp_dealloc = tp_dealloc::value;
#ifdef _Py_TPFLAGS_HAVE_VECTORCALL #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
...@@ -291,8 +389,17 @@ public: ...@@ -291,8 +389,17 @@ public:
return m_ready; return m_ready;
} }
// True when `op` is an instance of this type according to PyObject_TypeCheck.
bool isinstance(PyObject* op) {
return PyObject_TypeCheck(op, &m_type);
}
// True only when the type of `op` is exactly this type object.
bool isexact(PyObject* op) {
return Py_TYPE(op) == &m_type;
}
PyObject* finalize() { PyObject* finalize() {
if (!m_finalized) { if (!m_finalized) {
m_finalized = true;
if (m_methods.size()) { if (m_methods.size()) {
m_methods.push_back({0}); m_methods.push_back({0});
if (m_type.tp_methods) { if (m_type.tp_methods) {
...@@ -301,6 +408,14 @@ public: ...@@ -301,6 +408,14 @@ public:
} }
m_type.tp_methods = &m_methods[0]; m_type.tp_methods = &m_methods[0];
} }
if (m_getsets.size()) {
m_getsets.push_back({0});
if (m_type.tp_getset) {
PyErr_SetString(PyExc_SystemError, "tp_getset is already set");
return nullptr;
}
m_type.tp_getset = &m_getsets[0];
}
if (PyType_Ready(&m_type)) { if (PyType_Ready(&m_type)) {
return nullptr; return nullptr;
} }
...@@ -315,12 +430,64 @@ public: ...@@ -315,12 +430,64 @@ public:
m_methods.push_back(make_meth_def<f>(name, doc)); m_methods.push_back(make_meth_def<f>(name, doc));
return *this; return *this;
} }
// Register a Python attribute backed by `get` and, optionally, `set`.
// Calls check_finalized() first, so it has to be used before finalize().
// Returns *this so calls can be chained.
template<auto get, auto set = nullptr>
TypeBuilder& def_getset(const char* name, const char* doc = nullptr, void* closure = nullptr) {
check_finalized();
m_getsets.push_back(make_getset_def<get, set>(name, doc, closure));
return *this;
}
}; };
static TypeBuilder& type() { static TypeBuilder& type() {
static TypeBuilder type_helper; static TypeBuilder type_helper;
return type_helper; return type_helper;
} }
// Create a Python object of the wrapped type from C++, constructing the
// embedded T in place from `args`.
//
// Returns a new reference. Throws py_err_set when allocation fails (tp_alloc
// has set the Python error by then). If T's constructor throws, the allocated
// object is released and the exception is rethrown.
template<typename... Args>
static PyObject* cnew(Args&&... args) {
    auto* pytype = type().operator->();
    auto* self = pytype->tp_alloc(pytype, 0);
    if (!self) {
        // do not touch a null object; report the allocation failure instead
        throw py_err_set();
    }
    auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
    if constexpr (has_vectorcall && tp_vectorcall::valid) {
        reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>;
    }
    try {
        new(inst) T(std::forward<Args>(args)...);
    } catch (...) {
        // T was never constructed, so only the raw allocation is given back
        pytype->tp_free(self);
        throw;
    }
    return self;
}
// Same as cnew(), but allocates through the caller-supplied `pytype`
// (for example a Python-side subclass) instead of the wrapped base type.
//
// Returns a new reference, or nullptr when tp_alloc fails (tp_alloc is
// expected to have set the Python error indicator in that case).
// NOTE(review): `pytype` is assumed to share wrap_t's object layout; this
// function does not verify it.
template<typename... Args>
static PyObject* cnew_with_type(PyTypeObject* pytype, Args&&... args) {
    auto* self = pytype->tp_alloc(pytype, 0);
    if (!self) {
        // Allocation failed: never touch the payload of a null object.
        return nullptr;
    }
    auto* wrapped = reinterpret_cast<wrap_t*>(self);
    if constexpr (has_vectorcall && tp_vectorcall::valid) {
        wrapped->vectorcall_slot = &tp_vectorcall::template impl<>;
    }
    new(wrapped->inst()) T(std::forward<Args>(args)...);
    return self;
}
// Base for a pybind11 type_caster specialization: lets pybind11-bound
// functions receive the C++ payload T stored inside a wrap_t Python object.
// Only the Python -> C++ direction (load) is provided here.
struct caster {
// Type descriptor taken from T::tp_name (presumably used by pybind11 when
// it renders signatures — confirm).
static constexpr auto name = T::tp_name;
// Points at the payload after a successful load(); not initialised before.
T* value;
// Accepts `src` only if it is an instance of the wrapped type (subtypes
// included). `convert` is ignored: no implicit conversion is attempted.
bool load(pybind11::handle src, bool convert) {
if (wrap_t::type().isinstance(src.ptr())) {
value = reinterpret_cast<wrap_t*>(src.ptr())->inst();
return true;
}
return false;
}
template <typename U> using cast_op_type = pybind11::detail::cast_op_type<U>;
// Conversions used to hand the loaded payload to the bound function.
operator T*() { return value; }
operator T&() { return *value; }
};
}; };
} // namespace pyext17 } // namespace pyext17
...@@ -328,3 +495,5 @@ public: ...@@ -328,3 +495,5 @@ public:
#undef HAS_MEMBER_TYPE #undef HAS_MEMBER_TYPE
#undef HAS_MEMBER #undef HAS_MEMBER
#undef CVT_RET_PYOBJ #undef CVT_RET_PYOBJ
#undef CVT_RET_INT
#undef HANDLE_ALL_EXC
/**
* \file imperative/python/src/tensor.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./tensor.h"
#include "./grad.h"
#include "./common.h"
#include "./numpy_dtypes.h"
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
namespace py = pybind11;
namespace mgb::imperative::python {
// Interpreter channel used by every tensor operation in this file.
// It is empty until init_tensor() assigns it via create_channel().
std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py;
// Dispatches one op application according to the flags accumulated from its
// inputs:
//   GRAD  -> handled by apply_grad()
//   TRACE -> not implemented yet; falls through to mgb_assert(0)
//   else  -> executed directly on the interpreter channel
// The SCALAR flag is not handled here; scalar emulation is still done on
// the Python side (TODO: move into the specific ops if it costs performance).
apply_result_t apply(ApplyContext& ctx) {
    const bool wants_grad = (ctx.flags & Tensor::Flags::GRAD) != 0;
    if (wants_grad) {
        return apply_grad(ctx);
    }

    const bool wants_trace = (ctx.flags & Tensor::Flags::TRACE) != 0;
    if (!wants_trace) {
        // Collect the raw interpreter handles of all inputs.
        SmallVector<interpreter::Interpreter::Handle> input_handles(ctx.nargs);
        for (size_t idx = 0; idx != ctx.nargs; ++idx) {
            input_handles[idx] = ctx.args[idx]->m_handle.get();
        }

        auto result_handles = interpreter_for_py->apply_op(ctx.op, input_handles);

        // Wrap every produced handle in a fresh Tensor.
        apply_result_t results;
        results.reserve(result_handles.size());
        for (auto handle : result_handles) {
            results.emplace_back(std::make_shared<Tensor>(handle));
        }
        return results;
    }

    // TODO: trace
    mgb_assert(0);
}
// Python entry point `apply(op, *tensors)` (METH_FASTCALL).
//
// args[0] is the op; every following argument must be a Tensor. Outputs are
// created with the Python type of the first tensor argument and returned as
// a tuple. Keyword arguments are not supported (the kwnames parameter is
// disabled in the signature).
//
// Returns a new reference to the result tuple, or nullptr with a Python
// exception set:
//   TypeError    - no op given, no tensor given, or a non-Tensor argument
//   RuntimeError - a C++ exception escaped from the op application
// Python exceptions raised while applying the op are propagated unchanged.
PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */) {
    try {
        if (!nargs) {
            PyErr_SetString(PyExc_TypeError, "expect Op");
            return nullptr;
        }
        if (nargs < 2) {
            // The output type is taken from the first tensor argument, so
            // reading it without any tensor would index past the end of args.
            PyErr_SetString(PyExc_TypeError, "expect Tensor");
            return nullptr;
        }

        PyObject* op = args[0];
        // Owns the object returned by to_c() so that it is released on every
        // exit path; stays empty when no conversion is needed.
        py::object converted_op;
        PyTypeObject* op_base = Py_TYPE(op)->tp_base;
        if (op_base && (!strcmp(op_base->tp_name, "PodOpVisitor") ||
                        !strcmp(op_base->tp_name, "IndexingOpBase"))) {
            converted_op = py::reinterpret_steal<py::object>(
                    PyObject_CallMethod(op, "to_c", ""));
            if (!converted_op) {
                // to_c() raised; the Python error indicator is already set.
                return nullptr;
            }
            op = converted_op.ptr();
        }

        PyTypeObject* pytype = Py_TYPE(args[1]);
        ++args;
        --nargs;

        ApplyContext ctx;
        ctx.flags = 0;
        ctx.op = py::handle(op).cast<std::shared_ptr<OpDef>>();
        SmallVector<Tensor*, 64> tensors(nargs);
        ctx.args = &tensors[0];
        ctx.nargs = nargs;
        for (size_t i = 0; i < nargs; ++i) {
            TensorWrapper* tw = TensorWrapper::cast_safe(args[i]);
            if (!tw) {
                PyErr_SetString(PyExc_TypeError, "expect Tensor");
                return nullptr;
            }
            auto* t = tensors[i] = tw->m_tensor.get();
            // The union of all input flags decides how apply() dispatches.
            ctx.flags |= t->m_flags;
        }

        // TODO: set TRACE flag

        auto outputs = apply(ctx);
        size_t nout = outputs.size();
        auto ret = py::tuple(nout);
        for (size_t i = 0; i < nout; ++i) {
            ret[i] = TensorWrapper::make(pytype, std::move(outputs[i]));
        }
        return ret.release().ptr();
    } catch (py::error_already_set& e) {
        // Hand the original Python exception back to the interpreter
        // instead of flattening it into a RuntimeError.
        e.restore();
        return nullptr;
    } catch (std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return nullptr;
    }
}
// Python-facing constructor. Accepts exactly one of two positional forms:
//   Tensor(other)              - share the underlying Tensor of `other`
//   Tensor(data, dtype, device)- upload a numpy array to the interpreter
// Keyword arguments are rejected. Throws py::type_error on a wrong argument
// count or on keyword arguments.
TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
if (kwargs && PyDict_Size(kwargs)) {
throw py::type_error("keyword argument not allowed");
}
auto nargs = PyTuple_Size(args);
auto tup = py::reinterpret_borrow<py::tuple>(args);
if (nargs == 0) {
throw py::type_error("too few arguments");
}
// Form 1: the first argument is already a Tensor -> alias its storage.
if (auto* t = cast_safe(tup[0].ptr())) {
if (nargs > 1) {
throw py::type_error("expect 1 argument");
}
m_tensor = t->m_tensor;
} else {
// Form 2: (data, dtype, device).
if (nargs != 3) {
throw py::type_error("expect 3 arguments");
}
py::detail::loader_life_support life_sup; // required to cast DType
auto data = tup[0].cast<py::array>();
DType dtype = tup[1].cast<DType>();
CompNode cn = tup[2].cast<CompNode>();
interpreter::Interpreter::Handle handle;
// Arrays with more than MAX_NDIM elements go through Meth::borrow, smaller
// ones through Meth::copy_into a local host tensor.
// NOTE(review): presumably small arrays are copied because they may hold
// shape-like values — confirm the intent of this threshold.
constexpr auto size_threshhold = TensorShape::MAX_NDIM;
if (data.size() > size_threshhold) {
handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype));
} else {
HostTensorND ret(cn);
handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype));
}
m_tensor = std::make_shared<Tensor>(handle);
// A 0-dim numpy array is remembered as a scalar via the SCALAR flag.
if (data.ndim() == 0) {
m_tensor->m_flags |= Tensor::Flags::SCALAR;
}
}
}
// Getter for the Python `shape` attribute.
// Returns a new reference to:
//   ()           when the tensor carries the SCALAR flag,
//   None         when the interpreter reports a shape with ndim == 0,
//   a tuple with one entry per dimension otherwise.
PyObject* TensorWrapper::shape() {
    const bool is_scalar = (m_tensor->m_flags & Tensor::Flags::SCALAR) != 0;
    if (is_scalar) {
        return PyTuple_New(0);
    }

    TensorShape dims = m_tensor->shape();
    if (dims.ndim == 0) {
        return py::none().release().ptr();
    }

    py::tuple result(dims.ndim);
    for (size_t axis = 0; axis != dims.ndim; ++axis) {
        result[axis] = dims[axis];
    }
    return result.release().ptr();
}
// Getter for the Python `dtype` attribute: converts the tensor's DType with
// py::cast and returns it as a new reference.
PyObject* TensorWrapper::dtype() {
    py::object converted = py::cast(m_tensor->dtype());
    return converted.release().ptr();
}
// Getter for the Python `device` attribute: converts the tensor's CompNode
// with py::cast and returns it as a new reference.
PyObject* TensorWrapper::device() {
    py::object converted = py::cast(m_tensor->comp_node());
    return converted.release().ptr();
}
// Implements Tensor.numpy(): fetches the tensor value from the interpreter
// and exposes it as a numpy array (requested with ShareType::TRY_SHARE).
// Tensors flagged SCALAR are squeezed before being returned.
// Returns a new reference, or nullptr if the array could not be created.
PyObject* TensorWrapper::numpy() {
    auto&& host_value = interpreter_for_py->get_value(m_tensor->m_handle.get());
    auto array = py::reinterpret_steal<py::array>(
            npy::ndarray_from_tensor(host_value, npy::ShareType::TRY_SHARE));
    if (!array) {
        return nullptr;
    }

    const bool is_scalar = (m_tensor->m_flags & Tensor::Flags::SCALAR) != 0;
    if (!is_scalar) {
        return array.release().ptr();
    }

    // `array` keeps its own reference and drops it on scope exit; the
    // squeezed result is returned as a separate reference.
    mgb_assert(PyArray_Check(array.ptr()));
    return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(array.ptr()));
}
// Implements Tensor._reset(other): makes this wrapper share the underlying
// Tensor of `other`. Throws py::type_error if `other` is not a Tensor.
void TensorWrapper::reset(PyObject* tensor) {
    if (TensorWrapper* other = TensorWrapper::cast_safe(tensor)) {
        m_tensor = other->m_tensor;
        return;
    }
    throw py::type_error("expect Tensor");
}
// Implements Tensor.isscalar(): returns True (new reference) when the
// SCALAR flag is set on the underlying tensor, False otherwise.
PyObject* TensorWrapper::isscalar() {
    const bool is_scalar = (m_tensor->m_flags & Tensor::Flags::SCALAR) != 0;
    return PyBool_FromLong(is_scalar ? 1 : 0);
}
void TensorWrapper::setscalar() {
m_tensor->m_flags |= Tensor::Flags::SCALAR;
}
// Weak reference to the Tensor behind a TensorWrapper; it does not keep the
// tensor alive. Calling the object returns a new Python Tensor wrapping the
// same underlying tensor if it still exists, and None otherwise.
struct TensorWeakRef {
    std::weak_ptr<Tensor> wptr;

    TensorWeakRef(const TensorWrapper& tw) : wptr(tw.m_tensor) {}

    py::object operator()() {
        std::shared_ptr<Tensor> alive = wptr.lock();
        if (!alive) {
            return py::none();
        }
        return TensorWrapper::make(alive);
    }
};
// Registers the tensor-related Python API on module `m`:
//   Tensor, TensorWeakRef, apply, GradKey and backward.
// Also creates the interpreter channel stored in interpreter_for_py.
// Throws py::error_already_set if a type or function object cannot be created.
void init_tensor(py::module m) {
// The channel must exist before any Tensor method is called.
interpreter_for_py = interpreter::Interpreter::inst().create_channel();
// Build the Tensor type: methods via def<>, read-only properties via def_getset<>.
auto* tensor_type = TensorWrapper::wrap_t::type()
.def<&TensorWrapper::numpy>("numpy")
.def_getset<&TensorWrapper::shape>("shape")
.def_getset<&TensorWrapper::dtype>("dtype")
.def_getset<&TensorWrapper::device>("device")
.def<&TensorWrapper::reset>("_reset")
.def<&TensorWrapper::isscalar>("isscalar")
.def<&TensorWrapper::setscalar>("setscalar")
.finalize();
if (!tensor_type) throw py::error_already_set();
py::setattr(m, "Tensor", tensor_type);
// TensorWeakRef is a regular pybind11 class; it relies on the type_caster
// specialization for TensorWrapper to accept a Tensor argument.
py::class_<TensorWeakRef>(m, "TensorWeakRef")
.def(py::init<const TensorWrapper&>())
.def("__call__", &TensorWeakRef::operator());
// `apply` is exposed as a raw METH_FASTCALL function rather than through
// pybind11. The PyMethodDef is static because the function object keeps a
// pointer to it.
static PyMethodDef apply_def{"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr};
auto* apply_func = PyCFunction_NewEx(&apply_def, nullptr, nullptr);
if (!apply_func) throw py::error_already_set();
py::setattr(m, "apply", apply_func);
// GradKey type plus the module-level backward() function.
py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
.def<&GradKeyWrapper::attach>("attach")
.finalize();
if (!grad_key_type) throw py::error_already_set();
py::setattr(m, "GradKey", grad_key_type);
py::setattr(m, "backward", py::cpp_function(&GradKeyWrapper::backward));
}
} // namespace mgb::imperative::python
/**
* \file imperative/python/src/tensor.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <variant>
#include "megbrain/imperative/interpreter.h"
#include "pybind11/pybind11.h"
#include "./pyext17.h"
namespace mgb::imperative::python {
// A pybind11 handle/object (base B, pybind11::object by default) that also
// offers pointer-like access to the referenced Python object viewed as T.
// All constructors are inherited from B.
// NOTE(review): both operators reinterpret the PyObject address itself as T;
// confirm T is layout-compatible with the Python object at every use site.
template<typename T, typename B = pybind11::object>
struct ObjectPtr : B {
using B::B;
T& operator*() {return reinterpret_cast<T&>(*B::ptr());}
T* operator->() {return reinterpret_cast<T*>(B::ptr());}
};
} // namespace mgb::imperative::python
#include "./grad_info.h" // for struct GradInfo
namespace mgb::imperative::python {
// Per-tensor tracing state. Currently an empty placeholder; Tensor stores
// one instance as m_trace_info.
struct TraceInfo {
};
extern std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py;
// Reference-counted owner of an interpreter handle. Copies share the same
// handle; when the last copy goes away the handle is released through
// interpreter_for_py->del().
class SharedHandle {
using Handle = interpreter::Interpreter::Handle;
// Handle must be a raw pointer type so it can be managed by shared_ptr.
static_assert(std::is_pointer_v<Handle>);
std::shared_ptr<std::remove_pointer_t<Handle>> holder;
public:
// Takes ownership of `handle`; the custom deleter hands it back to the
// interpreter channel instead of calling delete.
inline explicit SharedHandle(Handle handle) : holder(handle, [](auto* h){
interpreter_for_py->del(h);
}) {}
SharedHandle(const SharedHandle&) = default;
SharedHandle& operator=(const SharedHandle&) = default;
SharedHandle(SharedHandle&&) = default;
SharedHandle& operator=(SharedHandle&&) = default;
// Returns the raw handle without affecting ownership.
inline Handle get() {return holder.get();}
};
// C++ side tensor object: an interpreter handle plus per-tensor flags and
// grad/trace bookkeeping. Instances are held through std::shared_ptr.
struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
using flags_t = uint64_t;
// Bit flags stored in m_flags; apply() dispatches on their union.
struct Flags {
static constexpr flags_t SCALAR = 1;
static constexpr flags_t GRAD = 1 << 1;
static constexpr flags_t TRACE = 1 << 2;
};
flags_t m_flags = 0;
GradInfo m_grad_info;
TraceInfo m_trace_info;
SharedHandle m_handle;
using Handle = interpreter::Interpreter::Handle;
// Wraps a raw interpreter handle (ownership goes to the new SharedHandle).
inline explicit Tensor(Handle handle) : m_handle(handle) {}
// Shares an existing handle.
inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)) {}
~Tensor() = default;
// Creates a new Tensor object that shares this tensor's handle and copies
// its flags, grad info and trace info. The data itself is not duplicated.
inline std::shared_ptr<Tensor> copy() {
auto ret = std::make_shared<Tensor>(m_handle);
ret->m_flags = m_flags;
ret->m_grad_info = m_grad_info;
ret->m_trace_info = m_trace_info;
return ret;
}
// Metadata queries, each forwarded to the interpreter channel.
inline DType dtype() {return interpreter_for_py->get_dtype(m_handle.get());}
inline CompNode comp_node() {return interpreter_for_py->get_device(m_handle.get());}
inline TensorShape shape() {return interpreter_for_py->get_shape(m_handle.get());}
};
// Payload of the Python `Tensor` type (exposed through pyext17::wrap).
// It only holds a shared_ptr to the actual Tensor, so several Python
// objects may refer to the same Tensor.
struct TensorWrapper {
std::shared_ptr<Tensor> m_tensor;
// C++ side construction from an existing Tensor (may be empty).
inline TensorWrapper(std::shared_ptr<Tensor> tensor = {}) : m_tensor(std::move(tensor)) {}
// Python side construction from positional arguments.
TensorWrapper(PyObject* args, PyObject* kwargs);
~TensorWrapper() = default;
static constexpr auto tp_name = pybind11::detail::_("Tensor");
using wrap_t = pyext17::wrap<TensorWrapper>;
friend wrap_t;
// Unchecked cast: `op` must be an instance of the Tensor type.
inline static TensorWrapper* cast(PyObject* op) {return reinterpret_cast<wrap_t*>(op)->inst();}
// Checked cast: returns nullptr when `op` is not a Tensor (or subtype).
inline static TensorWrapper* cast_safe(PyObject* op) {
if (!wrap_t::type().isinstance(op)) return nullptr;
return cast(op);
}
inline ObjectPtr<TensorWrapper, pybind11::handle> self() {return wrap_t::pycast(this);}
// Creates a new Python Tensor object of the base type, forwarding `args`
// to a TensorWrapper constructor.
template <typename... Args>
static ObjectPtr<Tensor> make(Args&&... args) {
auto* op = wrap_t::cnew(std::forward<Args>(args)...);
return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
}
// Same as above, but the Python object is allocated with `pytype`.
template <typename... Args>
static ObjectPtr<Tensor> make(PyTypeObject* pytype, Args&&... args) {
auto* op = wrap_t::cnew_with_type(pytype,std::forward<Args>(args)...);
return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
}
// Python-visible attributes and methods (registered in init_tensor).
PyObject* shape();
PyObject* dtype();
PyObject* device();
PyObject* numpy();
void reset(PyObject*);
PyObject* isscalar();
void setscalar();
};
PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */);
// Arguments of one op application, as passed to apply().
struct ApplyContext {
// Union of the m_flags of all input tensors.
Tensor::flags_t flags;
// The op to apply.
std::shared_ptr<OpDef> op;
// Non-owning array of `nargs` input tensors.
Tensor*const* args;
size_t nargs;
};
using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;
apply_result_t apply(ApplyContext& ctx);
void init_tensor(pybind11::module);
} // namespace mgb::imperative::python
namespace pybind11::detail {
// Lets pybind11-bound functions take TensorWrapper arguments by reusing the
// caster provided by pyext17::wrap.
template<> struct type_caster<mgb::imperative::python::TensorWrapper> : mgb::imperative::python::TensorWrapper::wrap_t::caster {};
} // namespace pybind11::detail
/**
* \file imperative/python/src/trace.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
namespace mgb::imperative::python {
// Per-tensor tracing state; currently an empty placeholder.
// NOTE(review): tensor.h declares a struct with the same name in the same
// namespace — confirm the two headers are never included together.
struct TraceInfo {
};
} // namespace mgb::imperative::python
...@@ -12,6 +12,7 @@ def test_basic(): ...@@ -12,6 +12,7 @@ def test_basic():
config_async_level(3) config_async_level(3)
@pytest.mark.skip
def test_level1_infer_value(): def test_level1_infer_value():
config_async_level(1) config_async_level(1)
a = mge.tensor([[1, 2], [2, 3], [3, 4]], dtype="float32") a = mge.tensor([[1, 2], [2, 3], [3, 4]], dtype="float32")
...@@ -22,6 +23,7 @@ def test_level1_infer_value(): ...@@ -22,6 +23,7 @@ def test_level1_infer_value():
d = F.reshape(a, c) d = F.reshape(a, c)
@pytest.mark.skip
def test_level1_infer_shape_with_unknown(): def test_level1_infer_shape_with_unknown():
config_async_level(2) config_async_level(2)
a = mge.tensor([[1, 2, 2, 3]], dtype="float32") a = mge.tensor([[1, 2, 2, 3]], dtype="float32")
......
...@@ -16,12 +16,11 @@ import pytest ...@@ -16,12 +16,11 @@ import pytest
import megengine as mge import megengine as mge
import megengine.distributed as dist import megengine.distributed as dist
import megengine.functional as F import megengine.functional as F
from megengine.core._imperative_rt import TensorAttr, imperative from megengine.core._imperative_rt import TensorAttr, core2, imperative
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply
from megengine.core._imperative_rt.imperative import sync
from megengine.core.autodiff.grad import Grad from megengine.core.autodiff.grad import Grad
from megengine.core.ops.builtin import Elemwise from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.core.tensor.tensor import Tensor, apply
from megengine.core.tensor.tensor_wrapper import TensorWrapper
from megengine.distributed.helper import get_device_count_by_fork from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional.distributed import remote_recv, remote_send from megengine.functional.distributed import remote_recv, remote_send
...@@ -43,11 +42,11 @@ relu = _elwise(Elemwise.Mode.RELU) ...@@ -43,11 +42,11 @@ relu = _elwise(Elemwise.Mode.RELU)
def as_tensor(x): def as_tensor(x):
return Tensor(as_raw_tensor(x, device=mge.device.get_default_device())) return mge.Tensor(x)
def save_to(self, name="grad"): def save_to(self, name="grad"):
def callback(tensor, grad): def callback(grad):
setattr(self, name, grad) setattr(self, name, grad)
return callback return callback
...@@ -136,14 +135,14 @@ def test_2nd_grad(): ...@@ -136,14 +135,14 @@ def test_2nd_grad():
def test_grad_with_tensor_wrapper(): def test_grad_with_tensor_wrapper():
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
x = TensorWrapper(x_np) x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x)) grad = Grad().wrt(x, callback=save_to(x))
y = mul(x, x) y = mul(x, x)
y = mul(y, y) y = mul(y, y)
grad(y, TensorWrapper(np.ones_like(x_np))) grad(y, mge.Tensor(np.ones_like(x_np)))
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)
...@@ -162,8 +161,8 @@ def test_release(): ...@@ -162,8 +161,8 @@ def test_release():
finally: finally:
gc.enable() gc.enable()
x = TensorWrapper([0.0]) x = mge.Tensor([0.0])
dy = TensorWrapper(np.ones_like(x.numpy())) dy = mge.Tensor(np.ones_like(x.numpy()))
@check @check
def _(): def _():
...@@ -173,25 +172,25 @@ def test_release(): ...@@ -173,25 +172,25 @@ def test_release():
@check @check
def _(): def _():
with Grad().wrt(x) as g: with Grad().wrt(x):
pass pass
@check @check
def _(): def _():
with Grad().wrt(x) as g: with Grad().wrt(x):
y = x * x y = x * x
def test_grad_inplace(): def test_grad_inplace():
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
x = TensorWrapper(x_np) x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x)) grad = Grad().wrt(x, callback=save_to(x))
y = mul(x, x) y = mul(x, x)
y *= y y *= y
grad(y, TensorWrapper(np.ones_like(x_np))) grad(y, mge.Tensor(np.ones_like(x_np)))
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)
...@@ -199,16 +198,16 @@ def test_elemwise_add(): ...@@ -199,16 +198,16 @@ def test_elemwise_add():
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
y_np = np.random.rand(10, 10).astype("float32") y_np = np.random.rand(10, 10).astype("float32")
dz_np = np.random.rand(10, 10).astype("float32") dz_np = np.random.rand(10, 10).astype("float32")
x = TensorWrapper(x_np) x = mge.Tensor(x_np)
y = TensorWrapper(y_np) y = mge.Tensor(y_np)
dz = TensorWrapper(dz_np) dz = mge.Tensor(dz_np)
refs = {} refs = {}
def f(x, y): def f(x, y):
x = x * 2 x = x * 2
refs["x"] = weakref.ref(x.__wrapped__) refs["x"] = TensorWeakRef(x)
refs["y"] = weakref.ref(y.__wrapped__) refs["y"] = TensorWeakRef(y)
return x + y return x + y
grad = Grad().wrt(x, callback=save_to(x)) grad = Grad().wrt(x, callback=save_to(x))
...@@ -226,14 +225,14 @@ def test_elemwise_add(): ...@@ -226,14 +225,14 @@ def test_elemwise_add():
def test_elemwise_relu(): def test_elemwise_relu():
x_np = [1.0, -1.0] x_np = [1.0, -1.0]
dz_np = [1.0] dz_np = [1.0]
x = TensorWrapper(x_np) x = mge.Tensor(x_np)
dz = TensorWrapper(dz_np) dz = mge.Tensor(dz_np)
refs = {} refs = {}
def f(x): def f(x):
x = x * 2 x = x * 2
refs["x"] = weakref.ref(x.__wrapped__) refs["x"] = TensorWeakRef(x)
return relu(x) return relu(x)
grad = Grad().wrt(x, callback=save_to(x)) grad = Grad().wrt(x, callback=save_to(x))
...@@ -258,7 +257,7 @@ def test_elemwise_relu_backward_fn(): ...@@ -258,7 +257,7 @@ def test_elemwise_relu_backward_fn():
def test_reshape(): def test_reshape():
x_np = np.random.rand(2, 5).astype("float32") x_np = np.random.rand(2, 5).astype("float32")
x = TensorWrapper(x_np) x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x)) grad = Grad().wrt(x, callback=save_to(x))
y = x.reshape(5, 2) y = x.reshape(5, 2)
...@@ -269,7 +268,7 @@ def test_reshape(): ...@@ -269,7 +268,7 @@ def test_reshape():
def test_subtensor(): def test_subtensor():
x_np = np.random.rand(3, 3).astype("float32") x_np = np.random.rand(3, 3).astype("float32")
x = TensorWrapper(x_np) x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x)) grad = Grad().wrt(x, callback=save_to(x))
y = x[1:-1, :2] y = x[1:-1, :2]
...@@ -282,7 +281,7 @@ def test_subtensor(): ...@@ -282,7 +281,7 @@ def test_subtensor():
def test_IndexingMultiAxisVec(): def test_IndexingMultiAxisVec():
x_np = np.random.rand(3, 3).astype("float32") x_np = np.random.rand(3, 3).astype("float32")
x = TensorWrapper(x_np) x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x)) grad = Grad().wrt(x, callback=save_to(x))
y = x[[0, 2], [0, 2]] y = x[[0, 2], [0, 2]]
...@@ -295,7 +294,7 @@ def test_IndexingMultiAxisVec(): ...@@ -295,7 +294,7 @@ def test_IndexingMultiAxisVec():
def test_AxisAddRemove(): def test_AxisAddRemove():
x_np = np.random.rand(1, 5).astype("float32") x_np = np.random.rand(1, 5).astype("float32")
x = TensorWrapper(x_np) x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x)) grad = Grad().wrt(x, callback=save_to(x))
y = F.squeeze(F.expand_dims(x, 2), 0) y = F.squeeze(F.expand_dims(x, 2), 0)
...@@ -308,7 +307,7 @@ def test_AxisAddRemove(): ...@@ -308,7 +307,7 @@ def test_AxisAddRemove():
def test_Broadcast(): def test_Broadcast():
x_np = np.random.rand(3, 3, 1).astype("float32") x_np = np.random.rand(3, 3, 1).astype("float32")
x = TensorWrapper(x_np) x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x)) grad = Grad().wrt(x, callback=save_to(x))
y = F.broadcast_to(x, (3, 3, 10)) y = F.broadcast_to(x, (3, 3, 10))
...@@ -319,7 +318,7 @@ def test_Broadcast(): ...@@ -319,7 +318,7 @@ def test_Broadcast():
def test_Reduce_sum(): def test_Reduce_sum():
x_np = np.random.rand(3, 3).astype("float32") x_np = np.random.rand(3, 3).astype("float32")
x = TensorWrapper(x_np) x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x)) grad = Grad().wrt(x, callback=save_to(x))
y = x.sum(axis=0) y = x.sum(axis=0)
...@@ -330,7 +329,7 @@ def test_Reduce_sum(): ...@@ -330,7 +329,7 @@ def test_Reduce_sum():
def test_Reduce_mean(): def test_Reduce_mean():
x_np = np.random.rand(3, 3).astype("float32") x_np = np.random.rand(3, 3).astype("float32")
x = TensorWrapper(x_np) x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x)) grad = Grad().wrt(x, callback=save_to(x))
y = x.mean(axis=0) y = x.mean(axis=0)
......
...@@ -11,30 +11,29 @@ import collections ...@@ -11,30 +11,29 @@ import collections
import numpy as np import numpy as np
import pytest import pytest
import megengine.core.tensor.raw_tensor import megengine
import megengine.tensor as Tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core._trace_option import use_symbolic_shape from megengine.core._trace_option import use_symbolic_shape
from megengine.core.ops import builtin from megengine.core.ops import builtin
from megengine.core.tensor import Tensor
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import RawTensor, as_raw_tensor
def cvt_to_shape_desc(val, inpvar, config=None): def cvt_to_shape_desc(val, inpvar, config=None):
def as_tensor(val, device): def as_tensor(val, device):
assert device is not None, "can not infer device" assert device is not None, "can not infer device"
# TODO: should copy to appropriate device # TODO: should copy to appropriate device
val = as_raw_tensor(val, device=device) val = Tensor(val, device=device)
return val return val
device = None device = None
if inpvar is not None: if inpvar is not None:
assert isinstance(inpvar, RawTensor) assert isinstance(inpvar, Tensor)
device = device or inpvar.device device = device or inpvar.device
if config is not None: if config is not None:
device = device or config.device device = device or config.device
if isinstance(val, RawTensor): if isinstance(val, Tensor):
return as_tensor(val, device) return as_tensor(val, device)
if not isinstance(val, collections.abc.Iterable): if not isinstance(val, collections.abc.Iterable):
...@@ -43,7 +42,7 @@ def cvt_to_shape_desc(val, inpvar, config=None): ...@@ -43,7 +42,7 @@ def cvt_to_shape_desc(val, inpvar, config=None):
components = [] components = []
on_host = True on_host = True
for i in val: for i in val:
if isinstance(i, RawTensor): if isinstance(i, Tensor):
on_host = False on_host = False
device = device or i.device device = device or i.device
else: else:
...@@ -62,7 +61,7 @@ def cvt_to_shape_desc(val, inpvar, config=None): ...@@ -62,7 +61,7 @@ def cvt_to_shape_desc(val, inpvar, config=None):
return as_tensor(shape, device) return as_tensor(shape, device)
for idx, v in enumerate(components): for idx, v in enumerate(components):
if not isinstance(v, RawTensor): if not isinstance(v, Tensor):
vi = int(v) vi = int(v)
assert vi == v, "could not convert {} to int".format(v) assert vi == v, "could not convert {} to int".format(v)
v = vi v = vi
...@@ -95,7 +94,7 @@ def canonize_inputs(inputs, *, config): ...@@ -95,7 +94,7 @@ def canonize_inputs(inputs, *, config):
# and is called with concat([a, b])) # and is called with concat([a, b]))
inputs = inputs[0] inputs = inputs[0]
if isinstance(inputs, RawTensor): if isinstance(inputs, Tensor):
return [inputs] return [inputs]
old_inputs = inputs old_inputs = inputs
...@@ -103,7 +102,7 @@ def canonize_inputs(inputs, *, config): ...@@ -103,7 +102,7 @@ def canonize_inputs(inputs, *, config):
get_comp_node = None get_comp_node = None
need_cvt = False need_cvt = False
for i in old_inputs: for i in old_inputs:
if isinstance(i, RawTensor): if isinstance(i, Tensor):
get_comp_node = lambda cn=i.device: cn get_comp_node = lambda cn=i.device: cn
else: else:
need_cvt = True need_cvt = True
...@@ -117,8 +116,8 @@ def canonize_inputs(inputs, *, config): ...@@ -117,8 +116,8 @@ def canonize_inputs(inputs, *, config):
return config.comp_node return config.comp_node
for idx, var in enumerate(inputs): for idx, var in enumerate(inputs):
if not isinstance(var, RawTensor): if not isinstance(var, Tensor):
var = as_raw_tensor(var) var = Tensor(var)
inputs[idx] = var inputs[idx] = var
return inputs return inputs
...@@ -131,15 +130,15 @@ def invoke_op(op, inputs_, cvt_inputs=canonize_inputs): ...@@ -131,15 +130,15 @@ def invoke_op(op, inputs_, cvt_inputs=canonize_inputs):
def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
assert isinstance(inp, RawTensor) assert isinstance(inp, Tensor)
if not isinstance(tuple_val, tuple): if not isinstance(tuple_val, tuple):
tuple_val = (tuple_val,) tuple_val = (tuple_val,)
def as_tensor(v): def as_tensor(v):
if not isinstance(v, RawTensor): if not isinstance(v, Tensor):
vi = np.ascontiguousarray(v, dtype=np.int32) vi = np.ascontiguousarray(v, dtype=np.int32)
assert np.abs(vi - v).max() == 0, "bad index: {!r}".format(v) assert np.abs(vi - v).max() == 0, "bad index: {!r}".format(v)
v = as_raw_tensor(vi) v = Tensor(vi)
return v return v
new_axes = [] new_axes = []
...@@ -275,14 +274,14 @@ def batched_incr_mesh_indexing(input, value, tuple_val): ...@@ -275,14 +274,14 @@ def batched_incr_mesh_indexing(input, value, tuple_val):
def test_transpose(): def test_transpose():
x = np.arange(10).reshape(2, 5).astype("int32") x = np.arange(10).reshape(2, 5).astype("int32")
xx = as_raw_tensor(x) xx = Tensor(x)
(yy,) = transpose(xx, pattern=[1, -1, 0]) (yy,) = transpose(xx, pattern=[1, -1, 0])
np.testing.assert_equal(np.expand_dims(x.transpose(), axis=1), yy.numpy()) np.testing.assert_equal(np.expand_dims(x.transpose(), axis=1), yy.numpy())
def test_broadcast(): def test_broadcast():
x = np.arange(10).reshape(1, 10).astype("int32") x = np.arange(10).reshape(1, 10).astype("int32")
xx = as_raw_tensor(x) xx = Tensor(x)
(yy,) = broadcast(xx, (10, 10)) (yy,) = broadcast(xx, (10, 10))
np.testing.assert_equal(np.repeat(x, 10, 0), yy.numpy()) np.testing.assert_equal(np.repeat(x, 10, 0), yy.numpy())
...@@ -290,7 +289,7 @@ def test_broadcast(): ...@@ -290,7 +289,7 @@ def test_broadcast():
def test_subtensor(): def test_subtensor():
x = np.arange(25).reshape(5, 5).astype("int32") x = np.arange(25).reshape(5, 5).astype("int32")
d = np.arange(2).astype("int32") d = np.arange(2).astype("int32")
xx = as_raw_tensor(x) xx = Tensor(x)
(yy0,) = subtensor(xx, (slice(0, 4, 2), 3)) (yy0,) = subtensor(xx, (slice(0, 4, 2), 3))
(yy1,) = set_subtensor(xx, d, (slice(0, 4, 2), 3)) (yy1,) = set_subtensor(xx, d, (slice(0, 4, 2), 3))
(yy2,) = incr_subtensor(xx, d, (slice(0, 4, 2), 3)) (yy2,) = incr_subtensor(xx, d, (slice(0, 4, 2), 3))
...@@ -309,7 +308,7 @@ def test_subtensor(): ...@@ -309,7 +308,7 @@ def test_subtensor():
def test_advance_indexing(): def test_advance_indexing():
x = np.arange(25).reshape(5, 5).astype("int32") x = np.arange(25).reshape(5, 5).astype("int32")
d = np.arange(15).reshape(3, 5).astype("int32") d = np.arange(15).reshape(3, 5).astype("int32")
xx = as_raw_tensor(x) xx = Tensor(x)
(yy0,) = advance_indexing(xx, ((0, 4, 2), slice(None, None, None))) (yy0,) = advance_indexing(xx, ((0, 4, 2), slice(None, None, None)))
(yy1,) = set_advance_indexing(xx, d, ((0, 4, 2), slice(None, None, None))) (yy1,) = set_advance_indexing(xx, d, ((0, 4, 2), slice(None, None, None)))
(yy2,) = incr_advance_indexing(xx, d, ((0, 4, 2), slice(None, None, None))) (yy2,) = incr_advance_indexing(xx, d, ((0, 4, 2), slice(None, None, None)))
...@@ -328,7 +327,7 @@ def test_advance_indexing(): ...@@ -328,7 +327,7 @@ def test_advance_indexing():
def test_mesh_indexing(): def test_mesh_indexing():
x = np.arange(25).reshape(5, 5).astype("int32") x = np.arange(25).reshape(5, 5).astype("int32")
d = np.arange(6).reshape(3, 2).astype("int32") d = np.arange(6).reshape(3, 2).astype("int32")
xx = as_raw_tensor(x) xx = Tensor(x)
(yy0,) = mesh_indexing(xx, (slice(0, 5, 2), (1, 3))) (yy0,) = mesh_indexing(xx, (slice(0, 5, 2), (1, 3)))
(yy1,) = set_mesh_indexing(xx, d, (slice(0, 5, 2), (1, 3))) (yy1,) = set_mesh_indexing(xx, d, (slice(0, 5, 2), (1, 3)))
(yy2,) = incr_mesh_indexing(xx, d, (slice(0, 5, 2), (1, 3))) (yy2,) = incr_mesh_indexing(xx, d, (slice(0, 5, 2), (1, 3)))
...@@ -355,7 +354,7 @@ def test_mesh_indexing(): ...@@ -355,7 +354,7 @@ def test_mesh_indexing():
def test_batched_mesh_indexing(): def test_batched_mesh_indexing():
x = np.arange(24).reshape(2, 3, 4).astype("int32") x = np.arange(24).reshape(2, 3, 4).astype("int32")
d = np.arange(12).reshape(2, 2, 3).astype("int32") d = np.arange(12).reshape(2, 2, 3).astype("int32")
xx = as_raw_tensor(x) xx = Tensor(x)
s = [(0, 1, 2), (1, 2, 3)] s = [(0, 1, 2), (1, 2, 3)]
(yy0,) = batched_mesh_indexing(xx, (slice(None, None, None), [(0, 2)] * 2, s)) (yy0,) = batched_mesh_indexing(xx, (slice(None, None, None), [(0, 2)] * 2, s))
(yy1,) = batched_set_mesh_indexing( (yy1,) = batched_set_mesh_indexing(
......
...@@ -9,12 +9,12 @@ ...@@ -9,12 +9,12 @@
import numpy as np import numpy as np
from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
from megengine.core.tensor.tensor_wrapper import TensorWrapper from megengine.tensor import Tensor
def test_basic(): def test_basic():
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
x = TensorWrapper(x_np) x = Tensor(x_np)
y = x * x y = x * x
y_np = y.numpy() y_np = y.numpy()
np.testing.assert_almost_equal(y_np, x_np * x_np) np.testing.assert_almost_equal(y_np, x_np * x_np)
...@@ -22,15 +22,15 @@ def test_basic(): ...@@ -22,15 +22,15 @@ def test_basic():
def test_literal_arith(): def test_literal_arith():
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
x = TensorWrapper(x_np) x = Tensor(x_np)
y = x * 2 y = x * 2
y_np = y.numpy() y_np = y.numpy()
np.testing.assert_almost_equal(y_np, x_np * 2) np.testing.assert_almost_equal(y_np, x_np * 2)
def test_matmul(): def test_matmul():
A = TensorWrapper(np.random.rand(5, 7).astype("float32")) A = Tensor(np.random.rand(5, 7).astype("float32"))
B = TensorWrapper(np.random.rand(7, 10).astype("float32")) B = Tensor(np.random.rand(7, 10).astype("float32"))
C = A @ B C = A @ B
np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6) np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6)
...@@ -38,7 +38,7 @@ def test_matmul(): ...@@ -38,7 +38,7 @@ def test_matmul():
def test_reduce(): def test_reduce():
def test_x(x_np): def test_x(x_np):
for m in ["sum", "prod", "min", "max", "mean"]: for m in ["sum", "prod", "min", "max", "mean"]:
x = TensorWrapper(x_np) x = Tensor(x_np)
y = getattr(x, m)(axis=-1, keepdims=True) y = getattr(x, m)(axis=-1, keepdims=True)
np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6) np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6)
...@@ -49,7 +49,7 @@ def test_reduce(): ...@@ -49,7 +49,7 @@ def test_reduce():
def test_set_subtensor(): def test_set_subtensor():
x = TensorWrapper([1, 2, 3]) x = Tensor([1, 2, 3])
x[:] = [1, 1, 1] x[:] = [1, 1, 1]
np.testing.assert_almost_equal(x.numpy(), [1, 1, 1], decimal=6) np.testing.assert_almost_equal(x.numpy(), [1, 1, 1], decimal=6)
x[[0, 2]] = [3, 2] x[[0, 2]] = [3, 2]
...@@ -60,7 +60,7 @@ def test_set_subtensor(): ...@@ -60,7 +60,7 @@ def test_set_subtensor():
def test_computing_with_numpy_array(): def test_computing_with_numpy_array():
x = np.array([1, 2, 3], dtype=np.int32) x = np.array([1, 2, 3], dtype=np.int32)
xx = TensorWrapper(x, device="cpu0") xx = Tensor(x, device="cpu0")
y = np.array([1, 0, 3], dtype=np.int32) y = np.array([1, 0, 3], dtype=np.int32)
assert np.add(xx, y).device == xx.device assert np.add(xx, y).device == xx.device
np.testing.assert_equal(np.add(xx, y).numpy(), np.add(x, y)) np.testing.assert_equal(np.add(xx, y).numpy(), np.add(x, y))
...@@ -70,12 +70,12 @@ def test_computing_with_numpy_array(): ...@@ -70,12 +70,12 @@ def test_computing_with_numpy_array():
def test_transpose(): def test_transpose():
x = np.random.rand(2, 5).astype("float32") x = np.random.rand(2, 5).astype("float32")
xx = TensorWrapper(x) xx = Tensor(x)
np.testing.assert_almost_equal(xx.T.numpy(), x.T) np.testing.assert_almost_equal(xx.T.numpy(), x.T)
def test_as_type(): def test_as_type():
x = TensorWrapper([1, 2, 3], dtype=np.float32) x = Tensor([1, 2, 3], dtype=np.float32)
y = x.astype(qint8(0.1)) y = x.astype(qint8(0.1))
np.testing.assert_almost_equal(get_scale(y.dtype), 0.1) np.testing.assert_almost_equal(get_scale(y.dtype), 0.1)
z = y.astype(qint8(0.2)) z = y.astype(qint8(0.2))
......
...@@ -312,7 +312,7 @@ def test_device(): ...@@ -312,7 +312,7 @@ def test_device():
np.testing.assert_almost_equal(y1.numpy(), y2.numpy()) np.testing.assert_almost_equal(y1.numpy(), y2.numpy())
y3 = F.eye(x.shape, dtype="float32", device="xpux") y3 = F.eye(x.shape, dtype="float32", device="xpux")
y4 = F.eye(x.shape, dtype="float32", device=x.device.to_c()) y4 = F.eye(x.shape, dtype="float32", device=x.device)
np.testing.assert_almost_equal(y3.numpy(), y4.numpy()) np.testing.assert_almost_equal(y3.numpy(), y4.numpy())
y5 = F.full((3, 2), 4, device=x.device) y5 = F.full((3, 2), 4, device=x.device)
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/imperative/ops/opr_attr.h"
#include "./op_trait.h" #include "./op_trait.h"
#include "./proxy_graph_detail.h" #include "megbrain/imperative/proxy_graph_detail.h"
namespace mgb { namespace mgb {
namespace imperative { namespace imperative {
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
#if MGB_ENABLE_OPR_MM #if MGB_ENABLE_OPR_MM
#include "../op_trait.h" #include "../op_trait.h"
#include "../proxy_graph_detail.h" #include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/opr/mm_handler.h" #include "megbrain/opr/mm_handler.h"
#include "megbrain/utils/hash.h" #include "megbrain/utils/hash.h"
#endif // MGB_ENABLE_OPR_MM #endif // MGB_ENABLE_OPR_MM
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
#if MGB_ENABLE_OPR_MM #if MGB_ENABLE_OPR_MM
#include "../op_trait.h" #include "../op_trait.h"
#include "../proxy_graph_detail.h" #include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/opr/io_remote.h" #include "megbrain/opr/io_remote.h"
#include "megbrain/opr/mm_handler.h" #include "megbrain/opr/mm_handler.h"
#endif // MGB_ENABLE_OPR_MM #endif // MGB_ENABLE_OPR_MM
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
#include "megbrain/serialization/opr_load_dump.h" #include "megbrain/serialization/opr_load_dump.h"
#include "../op_trait.h" #include "../op_trait.h"
#include "../proxy_graph_detail.h" #include "megbrain/imperative/proxy_graph_detail.h"
namespace mgb { namespace mgb {
namespace imperative { namespace imperative {
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
*/ */
#include "./proxy_graph.h" #include "./proxy_graph.h"
#include "./proxy_graph_detail.h" #include "megbrain/imperative/proxy_graph_detail.h"
namespace mgb { namespace mgb {
namespace imperative { namespace imperative {
......
/** /**
* \file imperative/src/impl/proxy_graph_detail.h * \file imperative/src/include/megbrain/imperative/proxy_graph_detail.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册