Commit 147cef52 authored by Megvii Engine Team

refactor(mge/imperative): implement new tensor system

GitOrigin-RevId: 2dd4e460ac32f059e91c7cdef859d82dee704b4b
Parent 7f48625f
......@@ -301,7 +301,7 @@ class GradManager:
if tensor is None:
return
def callback(_, grad, callbacks=spec.callbacks):
def callback(grad, callbacks=spec.callbacks):
for cb in callbacks:
grad = cb(tensor, grad)
self._gradients[id(tensor)] = grad
......
......@@ -16,6 +16,7 @@ import numpy as np
import megengine as mge
from .._imperative_rt import core2
from ..ops.builtin import Elemwise, OpDef, RemoteSend
from ..ops.special import Const
from ..tensor.core import TensorBase, TensorWrapperBase, apply
......@@ -418,3 +419,28 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
@apply.register()
def _(op: Const, *_: typing.Optional[Tracer]):
return None
class Grad:
def __init__(self):
self._impl = core2.GradKey()
def wrt(self, *tensors, callback=None):
for x in tensors:
self._impl.attach(x, callback)
return self
def __call__(self, ys, dys):
from collections.abc import Sequence
if not isinstance(ys, Sequence):
ys = [ys]
if not isinstance(dys, Sequence):
dys = [dys]
core2.backward(self._impl, ys, dys)
def __enter__(self):
return self
def __exit__(self, _1, _2, _3):
del self._impl
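For reference, the new Grad wrapper maps almost one-to-one onto core2.GradKey. A minimal usage sketch (the tensor values and the list-append callback are illustrative):

x = mge.tensor([1.0, 2.0])
grads = []
with Grad() as grad:
    grad.wrt(x, callback=grads.append)   # callback receives the accumulated grad of x
    y = x * x
    grad(y, mge.tensor([1.0, 1.0]))      # i.e. core2.backward(key, [y], [dy])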
......@@ -6,11 +6,18 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
from .._imperative_rt.core2 import Tensor
from ..tensor.core import OpBase, TensorBase, apply
class Const(OpBase):
class Const:
def __init__(self, value=None, *, dtype=None, device=None):
self.value = value
self.value = np.asarray(value, dtype=dtype)
self.dtype = dtype
self.device = device
def __call__(self, *reference):
Wrapper = type(reference[0])
return (Wrapper(self.value, self.dtype, self.device),)
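Const is now a plain callable rather than an OpBase: given a reference tensor, it lifts the stored value into the reference's wrapper type. A sketch, with x standing for any wrapped tensor:

(one,) = Const(1.0, dtype=x.dtype, device=x.device)(x)
assert type(one) is type(x)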
......@@ -13,9 +13,17 @@ import sys
import typing
from abc import ABC
from .._imperative_rt.core2 import apply as apply2
from .multipledispatch import Dispatcher
def apply_op(op, *args):
Wrapper = type(args[0])
args = [arg._tensor for arg in args]
results = apply2(op, *args)
return tuple(map(Wrapper, results))
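apply_op assumes every argument shares one wrapper type holding the raw tensor in ._tensor; it unwraps, dispatches through the C++ apply, and rewraps the results. A sketch, with op any OpDef instance and x a wrapped tensor:

(y,) = apply_op(op, x)   # type(y) is type(x); y._tensor is a core2 Tensor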
class OpBase(ABC):
def __call__(self, *args):
return apply(self, *args)
......
......@@ -10,10 +10,10 @@ from typing import Iterable
import numpy as np
from .._imperative_rt.core2 import Tensor, apply
from .._trace_option import use_symbolic_shape
from ..ops import builtin
from ..ops.special import Const
from .core import TensorBase, TensorWrapperBase, apply
from .utils import astensor1d, isscalar, make_shape_tuple
......@@ -149,13 +149,13 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
return True
def get_index(i):
if not isinstance(i, (TensorBase, TensorWrapperBase)):
if not isinstance(i, Tensor):
if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_:
(i,) = Const(i, dtype=np.bool_, device=inp.device)(inp)
else:
(i,) = Const(i, dtype=np.int32, device=inp.device)(inp)
return i
assert isinstance(i, (TensorBase, TensorWrapperBase))
assert isinstance(i, Tensor)
if i.dtype != np.bool_:
return i
_, ind = apply(builtin.CondTake(), i, i)
......@@ -198,8 +198,8 @@ def try_condtake(tensor, index):
return []
if isinstance(index, np.ndarray):
(index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor)
assert isinstance(index, (TensorBase, TensorWrapperBase))
if not isinstance(tensor, (TensorWrapperBase, TensorBase)):
assert isinstance(index, Tensor)
if not isinstance(tensor, Tensor):
raise TypeError("input must be a tensor")
if tensor.device != index.device:
raise ValueError(
......@@ -227,7 +227,7 @@ def getitem(tensor, index):
op = builtin.IndexingMultiAxisVec(items=items)
(result,) = apply(op, tensor, *tensors)
if ret_scalar:
result.__wrapped__._data._isscalar = True
result.setscalar()
return result
......@@ -239,7 +239,7 @@ def setitem(tensor, index, value):
if index.shape[0] == 0:
return tensor
tensor = tensor.reshape(-1)
if not isinstance(value, (TensorBase, TensorWrapperBase)):
if not isinstance(value, Tensor):
op = Const(value, dtype=tensor.dtype, device=tensor.device)
(value,) = op(tensor)
tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index)
......@@ -250,6 +250,7 @@ def setitem(tensor, index, value):
op = builtin.Subtensor(items=items)
else:
op = builtin.IndexingMultiAxisVec(items=items)
(tmp_result,) = apply(op, tensor, *tensors)
# XXX: broadcast can always be applied even if shapes are equal
......
......@@ -8,19 +8,20 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
import collections
from typing import Union
import numpy as np
from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import Tensor, apply
from .._trace_option import use_symbolic_shape
from ..ops import builtin
from ..ops.builtin import Elemwise, GetVarShape
from ..ops.special import Const
from . import utils
from .core import OpBase, TensorBase, TensorWrapperBase, apply
from .core import OpBase, TensorBase, TensorWrapperBase
from .indexing import getitem as _getitem
from .indexing import setitem as _setitem
from .raw_tensor import RawTensor, as_raw_tensor
from .tensor import Tensor
from .utils import isscalar
from .utils import make_shape_tuple as _make_shape_tuple
from .utils import setscalar
......@@ -41,6 +42,7 @@ def _elwise(*args, mode):
)
args = utils.convert_inputs(*args)
(result,) = apply(op, *args)
_isscalar = True
for i in args:
if not isscalar(i):
......@@ -84,9 +86,7 @@ def _reshape(x, shape):
if unspec_axis is not None:
raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i))
unspec_axis = i
shape = utils.astensor1d(shape, x, dtype="int32", device=x.device)
if unspec_axis is None:
op = builtin.Reshape()
else:
......@@ -181,7 +181,6 @@ def _reduce(mode):
elif isinstance(axis, collections.abc.Iterable):
axis = list(axis)
axis.sort(reverse=True)
for ai in axis:
op = builtin.Reduce(mode=mode, axis=ai)
(data,) = apply(op, data)
......@@ -221,10 +220,7 @@ def _todo(*_):
def _expand_args(args):
if len(args) == 1:
if isinstance(
args[0],
(collections.abc.Sequence, TensorBase, TensorWrapperBase, np.ndarray),
):
if isinstance(args[0], (collections.abc.Sequence, Tensor, np.ndarray)):
args = args[0]
return args
......@@ -240,9 +236,8 @@ class ArrayMethodMixin(abc.ABC):
return self.numpy().astype(dtype)
def __array_wrap__(self, array):
return TensorWrapper(
as_raw_tensor(array, dtype=array.dtype, device=self.device)
)
Wrapper = type(self)
return Wrapper(array, dtype=array.dtype, device=self.device)
@abc.abstractmethod
def _reset(self, other):
......@@ -253,7 +248,11 @@ class ArrayMethodMixin(abc.ABC):
pass
@abc.abstractproperty
def shape(self) -> tuple:
def shape(self) -> Union[tuple, Tensor]:
pass
@abc.abstractproperty
def _tuple_shape(self) -> tuple:
pass
@abc.abstractmethod
......@@ -331,7 +330,7 @@ class ArrayMethodMixin(abc.ABC):
__complex__ = lambda self: complex(self.item())
def __len__(self):
shape = self.__wrapped__.shape
shape = self._tuple_shape
if shape:
return int(shape[0])
raise TypeError("ndim is 0")
......@@ -352,7 +351,7 @@ class ArrayMethodMixin(abc.ABC):
@property
def ndim(self):
shape = self.__wrapped__.shape
shape = self._tuple_shape
if shape is None:
raise ValueError("unkown ndim")
return len(shape)
......@@ -480,22 +479,52 @@ class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase):
self.__wrapped__._swap_out()
class TensorWrapper(GenericTensorWrapper):
def __init__(self, data, dtype=None, device=None):
if isinstance(data, TensorWrapperBase):
data = data.__wrapped__
elif not isinstance(data, TensorBase):
assert data is not None, "Cannot init a tensor with data as None"
data = Tensor(as_raw_tensor(data, dtype=dtype, device=device))
super().__init__(data)
class TensorWrapper(ArrayMethodMixin, TensorBase):
def __init__(self, data, dtype=None, device=None, isscalar=False):
self._isscalar = isscalar
if isinstance(data, Tensor):
self._tensor = data
else:
if device is None:
device = CompNode._get_default_device()
self._tensor = Tensor(data, dtype, device)
def _reset(self, other):
if isinstance(other, TensorWrapperBase):
self.__wrapped__ = other.__wrapped__
elif isinstance(other, TensorBase):
self.__wrapped__ = other
else:
self._reset(type(self)(other, dtype=self.dtype, device=self.device))
if not isinstance(other, __class__):
raise TypeError(type(other))
self._tensor = other._tensor
return self
@property
def dtype(self):
return self._tensor.dtype
@property
def shape(self):
if self._isscalar:
return ()
shape = self._tensor.shape
if shape == () or not use_symbolic_shape():
return shape
return apply(GetVarShape(), self)[0]
@property
def device(self):
return self._tensor.device
def numpy(self):
if self._isscalar:
return self._tensor.numpy().squeeze()
return self._tensor.numpy()
def _drop(self):
self._tensor._drop()
def _swap_in(self):
self._tensor._swap_in()
def _swap_out(self):
self._tensor._swap_out()
def __repr__(self):
piece = "Tensor("
......
......@@ -11,9 +11,10 @@ from typing import Iterable, Union
import numpy as np
from .._imperative_rt.core2 import Tensor, apply
from ..ops import builtin
from ..ops.special import Const
from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
from ..tensor.core import OpBase, TensorBase, TensorWrapperBase
from .dtype import is_equal, is_quantize
_enable_convert_inputs = True
......@@ -109,7 +110,7 @@ def dtype_promotion(inputs):
def get_device(inputs):
device = None
for i in inputs:
if isinstance(i, (TensorWrapperBase, TensorBase)):
if isinstance(i, Tensor):
if device is None:
device = i.device
elif device != i.device:
......@@ -126,30 +127,31 @@ def concatenate(inputs, axis=0, *, device=None):
return convert_single_value(x, inputs, dtype=dtype)
inputs = tuple(map(convert, inputs))
(result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inputs)
(result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs)
return result
def astype(x, dtype):
dtype = np.dtype(dtype)
if not is_equal(x.dtype, dtype):
isscalar = x.__wrapped__._data._isscalar
isscalar = x.isscalar()
(x,) = apply(builtin.TypeCvt(dtype=dtype), x)
x.__wrapped__._data._isscalar = isscalar
if isscalar:
x.setscalar()
return x
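Since the scalar flag now lives on the C++ tensor, astype carries it across the TypeCvt via isscalar()/setscalar(). A sketch, with x any tensor:

y = astype(x, "float16")            # returns x unchanged if the dtypes already match
assert isscalar(y) == isscalar(x)   # scalar-ness survives the cast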
def convert_single_value(v, inputs, *, dtype=None, device=None):
tensors = [i for i in inputs if isinstance(i, (TensorBase, TensorWrapperBase))]
tensors = [i for i in inputs if isinstance(i, Tensor)]
assert len(tensors) > 0
if isinstance(v, (TensorWrapperBase, TensorBase)):
if isinstance(v, (TensorWrapperBase, Tensor)):
v = astype(v, v.dtype if is_quantize(v.dtype) else dtype)
else:
(v,) = Const(v, dtype=dtype, device=device)(*tensors)
return v
def convert_inputs(*args: TensorBase):
def convert_inputs(*args: Tensor):
if not _enable_convert_inputs:
return args
......@@ -167,7 +169,7 @@ def convert_inputs(*args: TensorBase):
def result_type(*args):
dtypes = []
for i in args:
if isinstance(i, (TensorWrapperBase, TensorBase)):
if isinstance(i, Tensor):
dtypes.append(i.dtype)
continue
try:
......@@ -178,25 +180,16 @@ def result_type(*args):
def isscalar(x):
if isinstance(x, TensorWrapperBase):
x = x.__wrapped__
if hasattr(x, "_isscalar"):
return x._isscalar
if isinstance(x, TensorBase):
return x._data._isscalar
if isinstance(x, Tensor):
return x.isscalar()
return np.isscalar(x)
def setscalar(x):
if isinstance(x, TensorWrapperBase):
x = x.__wrapped__
if hasattr(x, "_isscalar"):
x._isscalar = True
elif isinstance(x, TensorBase):
x._data._isscalar = True
if isinstance(x, Tensor):
x.setscalar()
else:
raise NotImplementedError("Unsupport type {}".format(type(x)))
......@@ -215,25 +208,24 @@ def astensor1d(x, *reference, dtype=None, device=None):
else:
if ndim != 0 and ndim != 1:
raise ValueError("ndim != 1 or 0, get : %d" % ndim)
if not isinstance(x, (TensorBase, TensorWrapperBase)):
if not isinstance(x, Tensor):
(x,) = Const(x, dtype=dtype, device=device)(*reference)
return x
if not isinstance(x, collections.abc.Sequence):
raise TypeError
if any(isinstance(i, (TensorBase, TensorWrapperBase)) for i in x):
if any(isinstance(i, Tensor) for i in x):
x = concatenate(x, device=device)
if dtype is not None:
x = astype(x, dtype)
return x
(x,) = Const(x, dtype=dtype, device=device)(*reference)
return x
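astensor1d thus accepts scalars, tensors, and sequences (possibly mixing tensors with numbers) and always yields a 0- or 1-d tensor on the reference's device. A sketch, with x the reference tensor and i0 a tensor index:

shape = astensor1d((2, -1), x, dtype="int32", device=x.device)   # Const path
idx = astensor1d([i0, 3], x, dtype="int32", device=x.device)     # concatenate path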
def _expand_int(s, i):
if isinstance(i, (TensorBase, TensorWrapperBase)):
if isinstance(i, Tensor):
i_np = i.numpy()
if i_np.ndim == 0:
s.append(int(i_np))
......
......@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple
from ..core._imperative_rt.core2 import apply
from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn
from ..core.autodiff.grad import (
Tracer,
......@@ -17,7 +18,6 @@ from ..core.autodiff.grad import (
tracer_apply,
)
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.core import apply
from ..core.tensor.tensor import Tensor, tensor_apply
from ..device import get_default_device
from ..tensor import tensor
......@@ -39,71 +39,6 @@ __all__ = [
]
@apply.register()
def _(op: RemoteSend, *args: Tensor):
ret = tensor_apply(op, *args)
# set extra information
tracer_set = dict()
for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))):
tracer_set[k.name] = True
# check tracer_set in remote_recv
get_client().set_remote_tracer(op.key, tracer_set)
return ret
@builtin_op_get_backward_fn.register(RemoteSend)
def _(op: RemoteSend, inputs, outputs, input_requires_grad):
def backward(*args):
return [
remote_recv(
op.rank_to,
inputs[0].shape,
inputs[0].dtype,
device=str(inputs[0].device),
inp=inputs[0],
)
]
return backward, [True]
@get_op_has_grad_fn.register(RemoteSend)
def _(op: RemoteSend):
def has_grad(opnode, reached):
return get_client().check_is_grad(op.key)
return has_grad
@check_backward_allow_noinput.register(RemoteSend)
def _(op: RemoteSend):
return True
@builtin_op_get_backward_fn.register(RemoteRecv)
def _(op: RemoteRecv, inputs, outputs, input_requires_grad):
def backward(*output_grads):
return [remote_send(output_grads[0], op.rank_from)]
return backward, [True]
@get_op_has_grad_fn.register(RemoteRecv)
def _(op: RemoteRecv):
def has_grad(opnode, reached):
ret = False
for v in opnode.outputs:
if v() in reached:
ret = True
break
get_client().set_is_grad(op.key, ret)
return ret
return has_grad
def collective_comm(inp, mode, group, device):
"""Helper function for applying collective communication functions."""
assert isinstance(group, Group)
......
......@@ -17,8 +17,8 @@ import numpy as np
from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager
from megengine.device import get_default_device, get_device_count
from ..core._imperative_rt.core2 import apply
from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
from ..core.tensor.core import apply
from ..functional.utils import copy
from ..tensor import Tensor
from ..utils.future import Future
......@@ -228,7 +228,6 @@ class AllreduceCallback:
self._packing_size[dtype] = 0
def __call__(self, param, grad):
param = param.__wrapped__
gm = get_backwarding_grad_manager()
assert isinstance(gm, GradManager)
if gm not in self._marked_gm:
......
......@@ -9,10 +9,10 @@
# pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
import functools
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.ops.builtin import Elemwise
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.core import apply
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device
from ..jit.tracing import is_tracing
......
......@@ -12,10 +12,11 @@ import math
import numbers
from typing import Optional, Sequence, Tuple, Union
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.ops.special import Const
from ..core.tensor import utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.core import TensorBase, TensorWrapperBase
from ..tensor import Tensor
from .elemwise import clip, exp, log, log1p
from .tensor import reshape, squeeze
......
......@@ -10,12 +10,12 @@
from typing import Optional, Sequence, Tuple, Union
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import Tensor, apply
from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.utils import astensor1d
from ..distributed import WORLD, is_distributed
from ..jit.tracing import is_tracing
......@@ -1565,9 +1565,7 @@ def indexing_one_hot(
[1.]
"""
assert isinstance(
src, (TensorWrapperBase, TensorBase)
), "src must be of Tensor type"
assert isinstance(src, Tensor), "src must be of Tensor type"
op = builtin.IndexingOneHot(axis=axis)
index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device)
(result,) = apply(op, src, index)
......
......@@ -8,8 +8,8 @@
# pylint: disable=too-many-lines
from typing import Tuple, Union
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.tensor.core import apply
from ..tensor import Tensor
from .debug_param import get_conv_execution_strategy
from .types import _pair, _pair_nonzero
......
......@@ -14,10 +14,10 @@ from typing import Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import Tensor, apply
from ..core._wrap import device as as_device
from ..core.ops import builtin
from ..core.ops.special import Const
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis
from ..core.tensor.utils import (
astensor1d,
......@@ -611,11 +611,11 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
"""
x, y = convert_inputs(x, y)
if not isinstance(x, (TensorWrapperBase, TensorBase)):
if not isinstance(x, Tensor):
raise TypeError("input x must be a tensor")
if not isinstance(y, (TensorWrapperBase, TensorBase)):
if not isinstance(y, Tensor):
raise TypeError("input y must be a tensor")
if not isinstance(mask, (TensorWrapperBase, TensorBase)):
if not isinstance(mask, Tensor):
raise TypeError("mask must be a tensor")
if mask.dtype != np.bool_:
raise ValueError("mask must be bool")
......@@ -668,9 +668,9 @@ def cond_take(mask: Tensor, x: Tensor) -> Tensor:
[1. 4.] [0 3]
"""
if not isinstance(x, (TensorWrapperBase, TensorBase)):
if not isinstance(x, Tensor):
raise TypeError("input must be a tensor")
if not isinstance(mask, (TensorWrapperBase, TensorBase)):
if not isinstance(mask, Tensor):
raise TypeError("mask must be a tensor")
if mask.dtype != np.bool_:
raise ValueError("mask must be bool")
......
......@@ -11,10 +11,10 @@ from typing import Iterable, Union
import numpy as np
from ..core._imperative_rt.core2 import apply
from ..core._wrap import device as as_device
from ..core.ops.builtin import Copy, Identity
from ..core.tensor import Tensor
from ..core.tensor.core import apply
from ..tensor import Tensor
from .math import topk as _topk
from .tensor import broadcast_to, transpose
......
......@@ -10,9 +10,9 @@ from typing import Iterable, Optional
from .. import Tensor
from ..core._imperative_rt import invoke_op
from ..core._imperative_rt.core2 import apply
from ..core.ops.builtin import GaussianRNG, UniformRNG
from ..core.tensor import utils
from ..core.tensor.core import apply
from .rng import _random_seed_generator
__all__ = ["normal", "uniform"]
......
......@@ -10,26 +10,66 @@
import collections
from .core import Tensor as _Tensor
from .core.ops.builtin import Copy
from .core.tensor.core import apply
import numpy as np
from .core._imperative_rt import CompNode
from .core._imperative_rt.core2 import Tensor as _Tensor
from .core._imperative_rt.core2 import apply
from .core._trace_option import use_symbolic_shape
from .core.ops.builtin import Copy, GetVarShape
from .core.tensor.raw_tensor import as_device
from .core.tensor.tensor_wrapper import ArrayMethodMixin
from .device import _valid_device, get_default_device
from .utils.deprecation import deprecated
class Tensor(_Tensor):
class Tensor(_Tensor, ArrayMethodMixin):
grad = None
dmap_callback = None
q_dict = {"mode": None, "scale": None, "zero_point": None}
def __init__(self, data, dtype=None, device=None):
def __new__(cls, data, dtype=None, device=None):
if device is None:
device = get_default_device()
self.q_dict = {"mode": None, "scale": None, "zero_point": None}
super().__init__(data, dtype=dtype, device=device)
cn = get_default_device()
elif isinstance(device, str):
if cls.dmap_callback is not None:
cn = CompNode(cls.dmap_callback(device))
else:
cn = CompNode(device)
else:
assert isinstance(device, CompNode)
cn = device
if isinstance(data, _Tensor):
obj = _Tensor.__new__(cls, data)
else:
obj = _Tensor.__new__(cls, data, dtype, cn)
return obj
@property
def shape(self):
shape = super().shape
if shape == () or not use_symbolic_shape():
return shape
return apply(GetVarShape(), self)[0]
@property
def _tuple_shape(self):
return super().shape
def __repr__(self):
piece = "Tensor("
with np.printoptions(precision=4, suppress=True):
piece += "{}".format(str(self.numpy()))
if self.dtype != np.float32:
piece += ", dtype={}".format(np.dtype(self.dtype).name)
piece += ", device={}".format(self.device) + ")"
return piece
@deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
def set_value(self, value):
if not isinstance(value, _Tensor):
value = Tensor(value, dtype=self.dtype, device=self.device)
self._reset(value)
@deprecated(version="1.0", reason="use *= 0 instead")
......@@ -61,27 +101,22 @@ class Tensor(_Tensor):
def __hash__(self):
return id(self)
def __getnewargs__(self):
r""" __getnewargs__ will be called for pickle serialization or deep copy
"""
return (self.numpy(), self.dtype, self.device.logical_name)
def __getstate__(self):
r""" __getstate__ will be called for pickle serialization or deep copy
"""
state = {
"data": self.numpy(),
"device": self.device.logical_name,
"dtype": self.dtype,
"qdict": self.q_dict,
}
return state
def __setstate__(self, state):
data = state.pop("data")
logical_device = state.pop("device")
if self.dmap_callback is not None:
assert isinstance(logical_device, str)
logical_device = self.dmap_callback(logical_device)
dtype = state.pop("dtype")
self.q_dict = state.pop("qdict")
super().__init__(data, dtype=dtype, device=logical_device)
def detach(self):
r"""
......@@ -89,8 +124,7 @@ class Tensor(_Tensor):
during backward gradient calculation, i.e. its gradient is zero.
"""
Wrapper = type(self)
Tensor = type(self.__wrapped__)
return Wrapper(Tensor(self.__wrapped__._data))
return Wrapper(self)
tensor = Tensor
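With __new__ replacing __init__, device resolution happens once up front and an existing tensor is re-wrapped without a host round trip. A sketch; the device string is illustrative:

t0 = Tensor([0, 1, 2], dtype="int32")     # falls back to get_default_device()
t1 = Tensor(t0)                           # re-wraps the same underlying core2 tensor
t2 = Tensor(np.ones(3), device="xpux")    # string -> CompNode (via dmap_callback if set)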
......
/**
* \file imperative/python/src/grad.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./grad.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/utils/mempool.h"
namespace py = pybind11;
namespace mgb::imperative::python {
namespace {
struct GradSlotWeakPtr {
std::weak_ptr<GradFn> grad_fn;
size_t idx;
};
} // namespace
struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> {
using Base = intrusive_list::Node<GradProducerRecord>;
GradProducerRecord() = default;
GradProducerRecord(GradProducerRecord::head_t& head) : Base(intrusive_list::after_t{}, head) {}
// GradProducerRecord(GradProducerRecord&&) = default;
// GradProducerRecord& operator=(GradProducerRecord&) = default;
// GradProducerRecord& operator=(GradProducerRecord&&) = default;
};
struct GradSlot {
std::shared_ptr<Tensor> grad;
py::object callback;
GradProducerRecord::head_t producer_head;
};
struct GradSlotProducerPtr : GradSlotPtr {
GradProducerRecord producer_record;
GradSlotProducerPtr() = default;
GradSlotProducerPtr(GradInfo& info) : GradSlotPtr(info), producer_record(info->producer_head) {}
};
struct GradFn : std::enable_shared_from_this<GradFn> {
static MemPool<GradFn> pool;
std::weak_ptr<GradKey> key;
SmallVector<GradSlot> slots;
SmallVector<GradSlotProducerPtr> dsts;
SmallVector<std::shared_ptr<Tensor>> closure;
std::shared_ptr<BackwardGraphResult> backward_graph;
bool in_ref_keeper = false;
static void deleter(GradFn* ptr) {
pool.free(ptr);
}
static std::shared_ptr<GradFn> make() {
return std::shared_ptr<GradFn>(pool.alloc(), &deleter);
}
void clear() {
key.reset();
slots.clear();
dsts.clear();
closure.clear();
backward_graph.reset();
}
};
GradSlot* GradSlotPtr::operator->() {
return &grad_fn->slots[idx];
}
namespace {
struct BackwardGraphCache : std::unordered_map<size_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject {
std::shared_ptr<void> on_comp_node_finalize() override {
clear();
return {};
}
} backward_graph_cache;
std::shared_ptr<BackwardGraphResult> make_backward_graph(
ApplyContext& ctx, const apply_result_t& outputs) {
// hash
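// buf layout: a size_t region [op hash | dtype hash, comp_node hash per input],
// followed by one bool per input recording whether it carries grad info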
static_assert(alignof(size_t) % alignof(bool) == 0);
size_t buf_size = (1 + ctx.nargs * 2) * sizeof(size_t) + ctx.nargs * sizeof(bool);
alignas(alignof(size_t)) std::byte buf[buf_size];
size_t* size_t_ptr = reinterpret_cast<size_t*>(buf);
bool* bool_ptr = reinterpret_cast<bool*>(size_t_ptr + (1 + ctx.nargs * 2));
bool* bool_ptr0 = bool_ptr;
*(size_t_ptr++) = ctx.op->hash();
for (size_t i = 0; i < ctx.nargs; ++i) {
*(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
*(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
*(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn);
}
mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
size_t key = XXHash{}.update(buf, buf_size).digest();
auto&& iter = backward_graph_cache.find(key);
if (iter != backward_graph_cache.end()) {
return iter->second;
}
// slow path
SmallVector<LogicalTensorDesc> inputs(ctx.nargs);
SmallVector<bool> input_requires_grad(ctx.nargs, false);
SmallVector<bool> output_has_grad(outputs.size(), true);
for (size_t i = 0; i < ctx.nargs; ++i) {
inputs[i].comp_node = ctx.args[i]->comp_node();
inputs[i].layout.dtype = ctx.args[i]->dtype();
input_requires_grad[i] = bool(ctx.args[i]->m_grad_info.grad_fn);
}
auto result = std::make_shared<BackwardGraphResult>(
proxy_graph_detail::make_backward_graph(
*ctx.op, inputs, input_requires_grad, output_has_grad));
if (!result->backward) {
result.reset();
}
backward_graph_cache.emplace(key, result);
return result;
}
} // namespace
apply_result_t apply_grad(ApplyContext& ctx) {
std::shared_ptr<GradKey> grad_key;
for (size_t i = 0; i < ctx.nargs; ++i) {
auto* tensor = ctx.args[i];
if (tensor->m_grad_info.grad_fn) {
auto&& input_grad_key = tensor->m_grad_info.grad_fn->key.lock();
// tensor is attached to a live GradKey
if (input_grad_key && input_grad_key->active) {
if (grad_key) {
if (grad_key != input_grad_key) {
PyErr_SetString(PyExc_NotImplementedError, "second order grad");
throw pyext17::py_err_set();
}
} else {
grad_key = std::move(input_grad_key);
}
} else {
// cleanup stale grad info
// under what condition?
tensor->m_grad_info = {};
tensor->m_flags &= ~Tensor::Flags::GRAD;
}
} else {
tensor->m_flags &= ~Tensor::Flags::GRAD;
}
}
ctx.flags &= ~Tensor::Flags::GRAD;
// perform forward apply_op or trace
auto outputs = apply(ctx);
if (!grad_key) {
return outputs;
}
auto backward_graph = make_backward_graph(ctx, outputs);
if (!backward_graph) {
return outputs;
}
auto grad_fn = std::make_shared<GradFn>();
grad_fn->key = grad_key;
grad_fn->slots.resize(outputs.size());
grad_fn->backward_graph = std::move(backward_graph);
grad_fn->dsts.reserve(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; ++i) {
if (grad_fn->backward_graph->input_has_grad[i]) {
auto& input_grad_info = ctx.args[i]->m_grad_info;
grad_fn->dsts.emplace_back(input_grad_info);
grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
} else {
grad_fn->dsts.emplace_back();
}
}
auto& save_for_backward = grad_fn->backward_graph->save_for_backward;
grad_fn->closure.reserve(std::count_if(save_for_backward.begin(), save_for_backward.end(), [](bool p){return p;}));
// given op, taking gradient of output_tensor_list wrt input_tensor_list:
//
// save_for_backward[0:nargs]: whether input tensor requires gradient,
// i.e., whether it is in input_tensor_list
//
// save_for_backward[nargs:nargs+outputs.size()]: whether output tensor is
// needed to calculate gradients
//
// save_for_backward[-outputs.size():]: whether output tensor is in
// output_tensor_list
//
// Example: perform c = a * b, where a is input data, b is parameter to be
// optimized, save_for_backward = [1, 1, 0, 1]
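// (with nargs = 2 and outputs.size() = 1 the four flags line up as
// [input a, input b | output c needed by backward | c in output_tensor_list])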
mgb_assert(ctx.nargs + 2 * outputs.size() == save_for_backward.size());
// record input tensors needed to take grad
for (size_t i = 0; i < ctx.nargs; ++i) {
if (save_for_backward[i]) {
grad_fn->closure.push_back(ctx.args[i]->shared_from_this());
}
}
// record output tensors needed to take grad
for (size_t i = 0; i < outputs.size(); ++i) {
bool requires_grad = save_for_backward[ctx.nargs + outputs.size() + i];
if (save_for_backward[ctx.nargs + i]) {
grad_fn->closure.push_back(outputs[i]);
if (requires_grad) {
// avoid reference cycle [Tensor <-> GradFn]
outputs[i] = outputs[i]->copy();
}
}
if (requires_grad) {
auto& grad_info = outputs[i]->m_grad_info;
grad_info.grad_fn = grad_fn;
grad_info.idx = i;
grad_info.insert_after(grad_key->free_vars_head);
outputs[i]->m_flags |= Tensor::Flags::GRAD;
}
}
// record forward history
grad_key->tape.emplace_back(grad_fn);
return outputs;
}
void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) {
if (nargs != 2) {
throw py::type_error("expect 2 arguments");
}
auto* tw = TensorWrapper::cast_safe(args[0]);
if (!tw) {
throw py::type_error("argument 1 must be Tensor");
}
auto* tensor = tw->m_tensor.get();
py::object callback;
if (args[1] != Py_None) {
callback = py::reinterpret_borrow<py::object>(args[1]);
}
m_key->attach(tensor, std::move(callback));
}
//! GradKey is weakly referred to by tensor->m_grad_info.grad_fn->key after attach
void GradKey::attach(Tensor* tensor, pybind11::object callback) {
if (!active) {
throw py::value_error("grad key finalized");
}
if (tensor->m_grad_info.grad_fn) {
if (tensor->m_grad_info.grad_fn->key.lock().get() != this) {
PyErr_SetString(PyExc_NotImplementedError, "second order grad");
throw pyext17::py_err_set();
}
if (tensor->m_grad_info->callback) {
throw py::value_error("callback already set on this tensor");
}
} else {
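// first attach under this key: create a leaf GradFn with a single slot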
tensor->m_grad_info.idx = 0;
auto& grad_fn = tensor->m_grad_info.grad_fn;
grad_fn = std::make_shared<GradFn>();
grad_fn->key = shared_from_this();
grad_fn->slots.resize(1);
tensor->m_grad_info.insert_after(free_vars_head);
tensor->m_flags |= Tensor::Flags::GRAD;
}
tensor->m_grad_info.grad_fn->slots[0].callback = std::move(callback);
}
void accum_grad(std::shared_ptr<Tensor>& grad, std::shared_ptr<Tensor>&& delta) {
if (!grad) {
grad = std::forward<decltype(delta)>(delta);
return;
}
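// reuse a single static Elemwise(ADD) context; only args and flags vary per call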
static ApplyContext ctx;
if (!ctx.op) {
ctx.op = std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD));
ctx.nargs = 2;
}
Tensor* args[2] = {grad.get(), delta.get()};
ctx.args = args;
ctx.flags = grad->m_flags | delta->m_flags;
grad = apply(ctx)[0];
}
void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
if (!active) {
throw py::value_error("finalized");
}
if (tensors.size() != grads.size()) {
throw py::value_error("tensor and grad size mismatch");
}
// this GradKey is marked inactive here
active = false;
struct CleanupGuard {
GradKey* owner;
CleanupGuard(GradKey* this_) : owner(this_) {}
~CleanupGuard() {owner->cleanup();}
} _cleanup_guard(this);
if (tape.empty() || grads.empty()) return;
PyTypeObject* pytype = Py_TYPE(grads[0]->self().ptr());
for (size_t i = 0; i < tensors.size(); ++i) {
auto& grad_info = tensors[i]->m_tensor->m_grad_info;
if (grad_info.grad_fn && grad_info.grad_fn->key.lock().get() == this) {
grad_info->grad = grads[i]->m_tensor;
}
}
std::vector<std::shared_ptr<GradFn>> ref_keeper;
ref_keeper.reserve(tape.size());
// back-propagation in reverse order
for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
auto&& grad_fn = tape[k].lock();
if (!grad_fn) continue;
if (grad_fn->backward_graph) {
for (size_t i = 0; i < grad_fn->slots.size(); ++i) {
// grad_fn->dsts correspond to input tensors during forward
// calculation, grad_fn->slots correspond to output tensors.
// condition true means the output tensor has gradient for
// back-propagation
if (grad_fn->backward_graph->save_for_backward[grad_fn->dsts.size() + grad_fn->slots.size() + i]) {
grad_fn->closure.push_back(std::move(grad_fn->slots[i].grad));
}
}
ApplyContext ctx;
ctx.op = grad_fn->backward_graph->backward;
ctx.flags = 0;
ctx.nargs = grad_fn->closure.size();
Tensor* args[ctx.nargs];
for (size_t i = 0; i < ctx.nargs; ++i) {
args[i] = grad_fn->closure[i].get();
mgb_assert(args[i]);
ctx.flags |= args[i]->m_flags;
}
ctx.args = args;
auto grads = apply(ctx);
size_t j = 0;
for (size_t i = 0; i < grad_fn->dsts.size(); ++i) {
if (grad_fn->backward_graph->input_has_grad[i]) {
auto& dst = grad_fn->dsts[i];
// grads[j] is consumed in accum_grad
accum_grad(dst->grad, std::move(grads[j]));
++j;
}
}
mgb_assert(j == grads.size());
}
for (auto&& dst : grad_fn->dsts) {
if (!dst.grad_fn) continue;
if (!dst.grad_fn->in_ref_keeper) {
dst.grad_fn->in_ref_keeper = true;
ref_keeper.push_back(dst.grad_fn);
}
// grad_fn->clear will unlink current dst.producer_record
// such that if dst.producer_record.next is false, dst accumulates
// all the gradients
if (!dst.producer_record.next && dst->callback && dst->grad) {
dst->callback(TensorWrapper::make(pytype, dst->grad));
}
}
grad_fn->clear();
} // finish tape loop
}
void GradKey::cleanup() {
active = false;
tape.clear();
for (intrusive_list::Iterator it(free_vars_head); it;) {
it->grad_fn.reset();
(it++)->unlink();
}
}
void GradKeyWrapper::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
m_key->backward(std::move(tensors), std::move(grads));
}
GradKey::~GradKey() {
cleanup();
}
} // namespace mgb::imperative::python
/**
* \file imperative/python/src/grad.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "./tensor.h"
#include <megbrain/utils/small_vector.h>
#include <memory>
namespace mgb::imperative::python {
apply_result_t apply_grad(ApplyContext& ctx);
struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
std::string name;
bool active = true;
GradInfo::head_t free_vars_head;
std::vector<std::weak_ptr<GradFn>> tape;
~GradKey();
void attach(Tensor* tensor, pybind11::object callback);
void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
void cleanup();
};
struct GradKeyWrapper {
using wrap_t = pyext17::wrap<GradKeyWrapper>;
static constexpr auto tp_name = pybind11::detail::_("GradKey");
std::shared_ptr<GradKey> m_key;
inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {}
void attach(PyObject*const* args, size_t nargs);
void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
};
} // namespace mgb::imperative::python
namespace pybind11::detail {
template<> struct type_caster<mgb::imperative::python::GradKeyWrapper> : mgb::imperative::python::GradKeyWrapper::wrap_t::caster {};
} // namespace pybind11::detail
/**
* \file imperative/python/src/grad_info.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <memory>
#include "./intrusive_list.h"
namespace mgb::imperative::python {
struct GradFn;
struct GradSlot;
struct GradSlotPtr {
std::shared_ptr<GradFn> grad_fn;
size_t idx;
GradSlot* operator->();
};
struct GradInfo : GradSlotPtr, intrusive_list::Node<GradInfo, intrusive_list::before_t> {
GradInfo() = default;
GradInfo(GradInfo&) = default;
GradInfo(GradInfo&&) = default;
GradInfo& operator=(GradInfo&) = default;
GradInfo& operator=(GradInfo&&) = default;
};
} // namespace mgb::imperative::python
/**
* \file imperative/python/src/intrusive_list.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/utils/metahelper.h"
namespace mgb::imperative::python::intrusive_list {
// copy policy
struct after_t {};
struct before_t {};
struct disable_t {};
template <typename T> struct Tail;
// invariant: next->prev == this
template <typename T>
struct Head {
Tail<T>* next;
Head(Tail<T>* node = nullptr) : next(node) {}
Head(const Head<T>&) = delete;
Head<T>& operator=(const Head<T>&) = delete;
Head(Head<T>&& rhs) : next(rhs.next) {
rhs.next = nullptr;
if (next) {
next->prev = this;
}
}
Head<T>& operator=(Head<T>&& rhs) {
mgb_assert(!next);
next = rhs.next;
rhs.next = nullptr;
if (next) {
next->prev = this;
}
return *this;
}
~Head() {
if (next) {
next->prev = nullptr;
}
}
};
// invariant: prev->next == this
template <typename T>
struct Tail {
Head<T>* prev;
Tail(Head<T>* node = nullptr) : prev(node) {}
Tail(const Tail<T>&) = delete;
Tail<T>& operator=(const Tail<T>&) = delete;
Tail(Tail<T>&& rhs) : prev(rhs.prev) {
rhs.prev = nullptr;
if (prev) {
prev->next = this;
}
}
Tail<T>& operator=(Tail<T>&& rhs) {
mgb_assert(!prev);
prev = rhs.prev;
rhs.prev = nullptr;
if (prev) {
prev->next = this;
}
return *this;
}
~Tail() {
if (prev) {
prev->next = nullptr;
}
}
};
template <typename T, typename policy> struct Node;
template <typename T>
class Iterator {
T* ptr;
void inc() {ptr = static_cast<T*>(ptr->Head<T>::next);}
void dec() {ptr = static_cast<T*>(ptr->Tail<T>::prev);}
public:
Iterator(Head<T>& head) : ptr(static_cast<T*>(head.next)) {}
Iterator(Tail<T>& tail) : ptr(static_cast<T*>(tail.prev)) {}
template<typename policy>
Iterator(Node<T, policy>& node) : ptr(static_cast<T*>(&node)) {}
T& operator*() {return *static_cast<T*>(ptr);}
T* operator->() {return static_cast<T*>(ptr);}
operator bool() {return ptr;}
bool operator==(const Iterator<T>& rhs) {return ptr == rhs.ptr;}
Iterator& operator++() {inc(); return *this;}
Iterator& operator--() {dec(); return *this;}
Iterator operator++(int) {auto ret = *this; inc(); return ret;}
Iterator operator--(int) {auto ret = *this; dec(); return ret;}
};
// Node in a doubly linked list. Unlike std::list, nodes are not owned by a container.
// Instead, nodes may join or leave a list freely.
// NOTE: Derived classes have to explicitly declare copy / assignment as default,
// otherwise the compiler generated version would use the const T& signature,
// which is deleted.
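// Example (a sketch): given Head<T> h and nodes a, b, calling a.insert_after(h)
// then b.insert_after(h) yields h -> b -> a; destroying b relinks h -> a.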
template <typename T = void, typename policy = disable_t>
struct Node : Tail<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>>,
Head<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>> {
private:
using this_t = Node<T, policy>;
using U = std::conditional_t<std::is_same_v<T, void>, this_t, T>;
public:
using head_t = Head<U>;
using tail_t = Tail<U>;
using head_t::next;
using tail_t::prev;
Node() = default;
Node(const this_t&) = delete;
this_t& operator=(const this_t&) = delete;
//! constructed node is inserted after the input node
Node(after_t, head_t& node) : tail_t(&node), head_t(node.next) {
node.next = this;
if (next) {
next->prev = this;
}
}
//! constructed node is inserted before the input node
Node(before_t, tail_t& node) : head_t(&node), tail_t(node.prev) {
node.prev = this;
if (prev) {
prev->next = this;
}
}
Node(this_t&& rhs) : tail_t(rhs.prev), head_t(rhs.next) {
rhs.prev = nullptr;
rhs.next = nullptr;
if (prev) {
prev->next = this;
}
if (next) {
next->prev = this;
}
}
Node& operator=(this_t&& rhs) {
unlink();
prev = rhs.prev;
next = rhs.next;
rhs.prev = nullptr;
rhs.next = nullptr;
if (prev) {
prev->next = this;
}
if (next) {
next->prev = this;
}
return *this;
}
template<typename p = policy,
typename = std::enable_if_t<std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>>
Node(this_t& rhs) : Node(policy{}, rhs) {}
template<typename p = policy,
typename = std::enable_if_t<std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>>
this_t& operator=(this_t& rhs) {
insert(policy{}, rhs);
return *this;
}
void unlink() {
if (prev) {
prev->next = next;
}
if (next) {
next->prev = prev;
}
prev = nullptr;
next = nullptr;
}
//! this node is unlinked from its list and inserted after the input node
void insert(after_t, head_t& node) {
unlink();
prev = &node;
next = node.next;
node.next = this;
if (next) {
next->prev = this;
}
}
//! this node is unlinked from its list and inserted before the input node
void insert(before_t, tail_t& node) {
unlink();
next = &node;
prev = node.prev;
node.prev = this;
if (prev) {
prev->next = this;
}
}
void insert_before(tail_t& node) {insert(before_t{}, node);}
void insert_after(head_t& node) {insert(after_t{}, node);}
~Node() {
unlink();
}
};
} // namespace mgb::imperative::python::intrusive_list
......@@ -23,7 +23,10 @@
#include "./dispatcher.h"
#include "./tensor.h"
namespace py = pybind11;
using namespace mgb::imperative::python;
#ifndef MODULE_NAME
#define MODULE_NAME imperative_rt
......@@ -68,4 +71,6 @@ PYBIND11_MODULE(MODULE_NAME, m) {
py::getattr(m, "__dict__"));
init_dispatcher(submodule(m, "dispatcher"));
init_tensor(submodule(m, "core2"));
}
......@@ -15,6 +15,7 @@
#include <vector>
#include <utility>
#include <Python.h>
#include <pybind11/pybind11.h>
namespace pyext17 {
......@@ -53,6 +54,26 @@ inline PyObject* cvt_retval(PyObject* rv) {
return cvt_retval(__VA_ARGS__); \
}
inline int cvt_retint(int ret) {
return ret;
}
#define CVT_RET_INT(...) \
if constexpr (std::is_same_v<decltype(__VA_ARGS__), void>) { \
__VA_ARGS__; \
return 0; \
} else { \
return cvt_retint(__VA_ARGS__); \
}
struct py_err_set : std::exception {};
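// py_err_set signals that the Python error indicator is already set;
// HANDLE_ALL_EXC maps any escaping C++ exception onto a Python exception and returns RET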
#define HANDLE_ALL_EXC(RET) catch(py_err_set&) {return RET;} \
catch(pybind11::error_already_set& e) {e.restore(); return RET;} \
catch(pybind11::builtin_exception& e) {e.set_error(); return RET;} \
catch(std::exception& e) {PyErr_SetString(PyExc_RuntimeError, e.what()); return RET;}
template <typename T>
struct wrap {
private:
......@@ -111,7 +132,9 @@ private:
static PyObject* impl(PyObject* self, PyObject*) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
CVT_RET_PYOBJ((inst->*f)());
try {
CVT_RET_PYOBJ((inst->*f)());
} HANDLE_ALL_EXC(nullptr)
}
};
......@@ -121,7 +144,9 @@ private:
static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
CVT_RET_PYOBJ((inst->*f)(args, kwargs));
try {
CVT_RET_PYOBJ((inst->*f)(args, kwargs));
} HANDLE_ALL_EXC(nullptr)
}
};
......@@ -132,7 +157,9 @@ private:
static PyObject* impl(PyObject* self, PyObject*const* args, Py_ssize_t nargs) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
CVT_RET_PYOBJ((inst->*f)(args, nargs));
try {
CVT_RET_PYOBJ((inst->*f)(args, nargs));
} HANDLE_ALL_EXC(nullptr)
}
#else
static constexpr int flags = METH_VARARGS;
......@@ -141,7 +168,9 @@ private:
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
auto* arr = &PyTuple_GET_ITEM(args, 0);
auto size = PyTuple_GET_SIZE(args);
CVT_RET_PYOBJ((inst->*f)(arr, size));
try {
CVT_RET_PYOBJ((inst->*f)(arr, size));
} HANDLE_ALL_EXC(nullptr)
}
#endif
};
......@@ -152,7 +181,9 @@ private:
static PyObject* impl(PyObject* self, PyObject* obj) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
CVT_RET_PYOBJ((inst->*f)(obj));
try {
CVT_RET_PYOBJ((inst->*f)(obj));
} HANDLE_ALL_EXC(nullptr)
}
};
......@@ -162,6 +193,55 @@ private:
return {name, (PyCFunction)M::impl, M::flags, doc};
}
template<auto f>
struct getter {
using F = decltype(f);
static PyObject* impl(PyObject* self, void* closure) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
try {
if constexpr (std::is_invocable_v<F, PyObject*, void*>) {
CVT_RET_PYOBJ(f(self, closure));
} else if constexpr (std::is_invocable_v<F, T, void*>) {
CVT_RET_PYOBJ((inst->*f)(closure));
} else if constexpr (std::is_invocable_v<F, T>) {
CVT_RET_PYOBJ((inst->*f)());
} else {
static_assert(!std::is_same_v<F, F>);
}
} HANDLE_ALL_EXC(nullptr)
}
};
template<auto f>
struct setter {
using F = decltype(f);
template<typename = void>
static int impl_(PyObject* self, PyObject* val, void* closure) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
try {
if constexpr (std::is_invocable_v<F, PyObject*, PyObject*, void*>) {
CVT_RET_INT(f(self, val, closure));
} else if constexpr (std::is_invocable_v<F, T, PyObject*, void*>) {
CVT_RET_INT((inst->*f)(val, closure));
} else if constexpr (std::is_invocable_v<F, T, PyObject*>) {
CVT_RET_INT((inst->*f)(val));
} else {
static_assert(!std::is_same_v<F, F>);
}
} HANDLE_ALL_EXC(-1)
}
static constexpr auto impl = []() {if constexpr (std::is_same_v<F, std::nullptr_t>) return nullptr;
else return impl_<>;}();
};
template<auto get, auto set = nullptr>
static constexpr PyGetSetDef make_getset_def(const char* name, const char* doc = nullptr, void* closure = nullptr) {
return {const_cast<char *>(name), getter<get>::impl, setter<set>::impl, const_cast<char *>(doc), closure};
}
// polyfills
struct tp_vectorcall {
......@@ -216,16 +296,26 @@ private:
template<typename = void>
static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
struct FreeGuard {
PyObject* self;
PyTypeObject* type;
~FreeGuard() {if (self) type->tp_free(self);}
};
auto* self = type->tp_alloc(type, 0);
FreeGuard free_guard{self, type};
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
if constexpr (has_vectorcall && tp_vectorcall::valid) {
reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>;
}
if constexpr (varkw) {
new(inst) T(args, kwargs);
} else {
new(inst) T();
}
try {
if constexpr (varkw) {
new(inst) T(args, kwargs);
} else {
new(inst) T();
}
} HANDLE_ALL_EXC(nullptr)
free_guard.self = nullptr;
return self;
}
......@@ -250,6 +340,7 @@ private:
public:
class TypeBuilder {
std::vector<PyMethodDef> m_methods;
std::vector<PyGetSetDef> m_getsets;
PyTypeObject m_type;
bool m_finalized = false;
bool m_ready = false;
......@@ -259,6 +350,13 @@ public:
throw std::runtime_error("type is already finalized");
}
}
static const char* to_c_str(const char* s) {return s;}
template <size_t N, typename... Ts>
static const char* to_c_str(const pybind11::detail::descr<N, Ts...>& desc) {
return desc.text;
}
public:
TypeBuilder(const TypeBuilder&) = delete;
TypeBuilder& operator=(const TypeBuilder&) = delete;
......@@ -266,7 +364,7 @@ public:
TypeBuilder() : m_type{PyVarObject_HEAD_INIT(nullptr, 0)} {
constexpr auto has_tp_name = HAS_MEMBER(T, tp_name);
if constexpr (has_tp_name) {
m_type.tp_name = T::tp_name;
m_type.tp_name = to_c_str(T::tp_name);
}
m_type.tp_dealloc = tp_dealloc::value;
#ifdef _Py_TPFLAGS_HAVE_VECTORCALL
......@@ -291,8 +389,17 @@ public:
return m_ready;
}
bool isinstance(PyObject* op) {
return PyObject_TypeCheck(op, &m_type);
}
bool isexact(PyObject* op) {
return Py_TYPE(op) == &m_type;
}
PyObject* finalize() {
if (!m_finalized) {
m_finalized = true;
if (m_methods.size()) {
m_methods.push_back({0});
if (m_type.tp_methods) {
......@@ -301,6 +408,14 @@ public:
}
m_type.tp_methods = &m_methods[0];
}
if (m_getsets.size()) {
m_getsets.push_back({0});
if (m_type.tp_getset) {
PyErr_SetString(PyExc_SystemError, "tp_getset is already set");
return nullptr;
}
m_type.tp_getset = &m_getsets[0];
}
if (PyType_Ready(&m_type)) {
return nullptr;
}
......@@ -315,12 +430,64 @@ public:
m_methods.push_back(make_meth_def<f>(name, doc));
return *this;
}
template<auto get, auto set = nullptr>
TypeBuilder& def_getset(const char* name, const char* doc = nullptr, void* closure = nullptr) {
check_finalized();
m_getsets.push_back(make_getset_def<get, set>(name, doc, closure));
return *this;
}
};
static TypeBuilder& type() {
static TypeBuilder type_helper;
return type_helper;
}
template<typename... Args>
static PyObject* cnew(Args&&... args) {
auto* pytype = type().operator->();
auto* self = pytype->tp_alloc(pytype, 0);
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
if constexpr (has_vectorcall && tp_vectorcall::valid) {
reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>;
}
new(inst) T(std::forward<Args>(args)...);
return self;
}
template<typename... Args>
static PyObject* cnew_with_type(PyTypeObject* pytype, Args&&... args) {
auto* self = pytype->tp_alloc(pytype, 0);
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
if constexpr (has_vectorcall && tp_vectorcall::valid) {
reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>;
}
new(inst) T(std::forward<Args>(args)...);
return self;
}
struct caster {
static constexpr auto name = T::tp_name;
T* value;
bool load(pybind11::handle src, bool convert) {
if (wrap_t::type().isinstance(src.ptr())) {
value = reinterpret_cast<wrap_t*>(src.ptr())->inst();
return true;
}
return false;
}
template <typename U> using cast_op_type = pybind11::detail::cast_op_type<U>;
operator T*() { return value; }
operator T&() { return *value; }
};
};
} // namespace pyext17
......@@ -328,3 +495,5 @@ public:
#undef HAS_MEMBER_TYPE
#undef HAS_MEMBER
#undef CVT_RET_PYOBJ
#undef CVT_RET_INT
#undef HANDLE_ALL_EXC
/**
* \file imperative/python/src/tensor.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./tensor.h"
#include "./grad.h"
#include "./common.h"
#include "./numpy_dtypes.h"
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
namespace py = pybind11;
namespace mgb::imperative::python {
std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py;
apply_result_t apply(ApplyContext& ctx) {
// scalar emulation should live in each specific op's apply, e.g.
// elementwise, reduce, typecvt. Currently it is still handled on the
// Python side; it could be moved to the C++ side if it impacts performance.
if (ctx.flags & Tensor::Flags::SCALAR) {
// TODO: emulate scalar
}
if (ctx.flags & Tensor::Flags::GRAD) {
return apply_grad(ctx);
}
if (ctx.flags & Tensor::Flags::TRACE) {
// TODO: trace
} else {
SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; ++i) {
handles[i] = ctx.args[i]->m_handle.get();
}
auto output_handles = interpreter_for_py->apply_op(ctx.op, handles);
apply_result_t outputs;
outputs.reserve(output_handles.size());
for (auto h : output_handles) {
outputs.emplace_back(std::make_shared<Tensor>(h));
}
return outputs;
}
mgb_assert(0);
}
PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */) {
try {
// if (kwnames && PyTuple_GET_SIZE(kwnames)) {
// PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
// return nullptr;
// }
if (!nargs) {
PyErr_SetString(PyExc_TypeError, "expect Op");
return nullptr;
}
auto* op = args[0];
if (!strcmp(op->ob_type->tp_base->tp_name,"PodOpVisitor") || !strcmp(op->ob_type->tp_base->tp_name,"IndexingOpBase")){
op = PyObject_CallMethod(op,"to_c","");
}
PyTypeObject* pytype = args[1]->ob_type;
++args;
--nargs;
ApplyContext ctx;
ctx.flags = 0;
ctx.op = py::handle(op).cast<std::shared_ptr<OpDef>>();
SmallVector<Tensor*, 64> tensors(nargs);
ctx.args = &tensors[0];
ctx.nargs = nargs;
for (size_t i = 0; i < nargs; ++i) {
TensorWrapper* tw = TensorWrapper::cast_safe(args[i]);
if (!tw) {
PyErr_SetString(PyExc_TypeError, "expect Tensor");
return nullptr;
}
auto* t = tensors[i] = tw->m_tensor.get();
ctx.flags |= t->m_flags;
}
// TODO: set TRACE flag
auto outputs = apply(ctx);
size_t nout = outputs.size();
auto ret = py::tuple(nout);
for (size_t i = 0; i < nout; ++i) {
ret[i] = TensorWrapper::make(pytype, std::move(outputs[i]));
}
return ret.release().ptr();
} catch (std::exception& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return nullptr;
}
}
TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
if (kwargs && PyDict_Size(kwargs)) {
throw py::type_error("keyword argument not allowed");
}
auto nargs = PyTuple_Size(args);
auto tup = py::reinterpret_borrow<py::tuple>(args);
if (nargs == 0) {
throw py::type_error("too few arguments");
}
if (auto* t = cast_safe(tup[0].ptr())) {
if (nargs > 1) {
throw py::type_error("expect 1 argument");
}
m_tensor = t->m_tensor;
} else {
if (nargs != 3) {
throw py::type_error("expect 3 arguments");
}
py::detail::loader_life_support life_sup; // required to cast DType
auto data = tup[0].cast<py::array>();
DType dtype = tup[1].cast<DType>();
CompNode cn = tup[2].cast<CompNode>();
interpreter::Interpreter::Handle handle;
constexpr auto size_threshhold = TensorShape::MAX_NDIM;
if (data.size() > size_threshhold) {
handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype));
} else {
HostTensorND ret(cn);
handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype));
}
m_tensor = std::make_shared<Tensor>(handle);
if (data.ndim() == 0) {
m_tensor->m_flags |= Tensor::Flags::SCALAR;
}
}
}
PyObject* TensorWrapper::shape() {
if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
return PyTuple_New(0);
}
auto&& shape = m_tensor->shape();
if (!shape.ndim) {
Py_RETURN_NONE;
}
py::tuple ret(shape.ndim);
for (size_t i = 0; i < shape.ndim; ++i) {
ret[i] = shape[i];
}
return ret.release().ptr();
}
PyObject* TensorWrapper::dtype() {
return py::cast(m_tensor->dtype()).release().ptr();
}
PyObject* TensorWrapper::device() {
return py::cast(m_tensor->comp_node()).release().ptr();
}
PyObject* TensorWrapper::numpy() {
auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get());
auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
if (!arr) return nullptr;
if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
mgb_assert(PyArray_Check(arr.ptr()));
return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
}
return arr.release().ptr();
}
void TensorWrapper::reset(PyObject* tensor) {
TensorWrapper* t = TensorWrapper::cast_safe(tensor);
if (!t) {
throw py::type_error("expect Tensor");
}
m_tensor = t->m_tensor;
}
PyObject* TensorWrapper::isscalar() {
if(m_tensor->m_flags & Tensor::Flags::SCALAR) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
}
void TensorWrapper::setscalar() {
m_tensor->m_flags |= Tensor::Flags::SCALAR;
}
struct TensorWeakRef {
std::weak_ptr<Tensor> wptr;
TensorWeakRef(const TensorWrapper& tw) : wptr(tw.m_tensor) {}
py::object operator()() {
if (auto p = wptr.lock()) {
return TensorWrapper::make(p);
}
return py::none();
}
};
void init_tensor(py::module m) {
interpreter_for_py = interpreter::Interpreter::inst().create_channel();
auto* tensor_type = TensorWrapper::wrap_t::type()
.def<&TensorWrapper::numpy>("numpy")
.def_getset<&TensorWrapper::shape>("shape")
.def_getset<&TensorWrapper::dtype>("dtype")
.def_getset<&TensorWrapper::device>("device")
.def<&TensorWrapper::reset>("_reset")
.def<&TensorWrapper::isscalar>("isscalar")
.def<&TensorWrapper::setscalar>("setscalar")
.finalize();
if (!tensor_type) throw py::error_already_set();
py::setattr(m, "Tensor", tensor_type);
py::class_<TensorWeakRef>(m, "TensorWeakRef")
.def(py::init<const TensorWrapper&>())
.def("__call__", &TensorWeakRef::operator());
static PyMethodDef apply_def{"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr};
auto* apply_func = PyCFunction_NewEx(&apply_def, nullptr, nullptr);
if (!apply_func) throw py::error_already_set();
py::setattr(m, "apply", apply_func);
py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
.def<&GradKeyWrapper::attach>("attach")
.finalize();
if (!grad_key_type) throw py::error_already_set();
py::setattr(m, "GradKey", grad_key_type);
py::setattr(m, "backward", py::cpp_function(&GradKeyWrapper::backward));
}
} // namespace mgb::imperative::python
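The bindings registered in init_tensor are what the Python layer builds on.
A minimal usage sketch, following the updated tests in this commit (the
Grad wrapper lives in megengine.core.autodiff.grad; this snippet is
illustrative and not part of the diff):

    import numpy as np
    import megengine as mge
    from megengine.core.autodiff.grad import Grad

    def save_to(dest, name="grad"):
        def callback(grad):  # grad callbacks now receive grad only
            setattr(dest, name, grad)
        return callback

    x = mge.Tensor(np.random.rand(10).astype("float32"))
    grad = Grad().wrt(x, callback=save_to(x))
    y = x * x
    grad(y, mge.Tensor(np.ones_like(x.numpy())))  # seed dy with ones
    print(x.grad.numpy())  # approximately 2 * x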
/**
* \file imperative/python/src/tensor.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <variant>
#include "megbrain/imperative/interpreter.h"
#include "pybind11/pybind11.h"
#include "./pyext17.h"
namespace mgb::imperative::python {
template<typename T, typename B = pybind11::object>
struct ObjectPtr : B {
using B::B;
T& operator*() {return reinterpret_cast<T&>(*B::ptr());}
T* operator->() {return reinterpret_cast<T*>(B::ptr());}
};
} // namespace mgb::imperative::python
#include "./grad_info.h" // for struct GradInfo
namespace mgb::imperative::python {
struct TraceInfo {
};
extern std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py;
class SharedHandle {
using Handle = interpreter::Interpreter::Handle;
static_assert(std::is_pointer_v<Handle>);
std::shared_ptr<std::remove_pointer_t<Handle>> holder;
public:
inline explicit SharedHandle(Handle handle) : holder(handle, [](auto* h){
interpreter_for_py->del(h);
}) {}
SharedHandle(const SharedHandle&) = default;
SharedHandle& operator=(const SharedHandle&) = default;
SharedHandle(SharedHandle&&) = default;
SharedHandle& operator=(SharedHandle&&) = default;
inline Handle get() {return holder.get();}
};
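// SharedHandle gives interpreter handles shared-ownership semantics: the
// deleter runs interpreter_for_py->del(h) once the last copy goes away,
// so holders can alias one handle without manual lifetime tracking.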
struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
using flags_t = uint64_t;
struct Flags {
static constexpr flags_t SCALAR = 1;
static constexpr flags_t GRAD = 1 << 1;
static constexpr flags_t TRACE = 1 << 2;
};
flags_t m_flags = 0;
GradInfo m_grad_info;
TraceInfo m_trace_info;
SharedHandle m_handle;
using Handle = interpreter::Interpreter::Handle;
inline explicit Tensor(Handle handle) : m_handle(handle) {}
inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)) {}
~Tensor() = default;
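// copy() is shallow: the returned Tensor shares the interpreter handle
// (and thus device storage) while duplicating flag/grad/trace bookkeeping.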
inline std::shared_ptr<Tensor> copy() {
auto ret = std::make_shared<Tensor>(m_handle);
ret->m_flags = m_flags;
ret->m_grad_info = m_grad_info;
ret->m_trace_info = m_trace_info;
return ret;
}
inline DType dtype() {return interpreter_for_py->get_dtype(m_handle.get());}
inline CompNode comp_node() {return interpreter_for_py->get_device(m_handle.get());}
inline TensorShape shape() {return interpreter_for_py->get_shape(m_handle.get());}
};
struct TensorWrapper {
std::shared_ptr<Tensor> m_tensor;
inline TensorWrapper(std::shared_ptr<Tensor> tensor = {}) : m_tensor(std::move(tensor)) {}
TensorWrapper(PyObject* args, PyObject* kwargs);
~TensorWrapper() = default;
static constexpr auto tp_name = pybind11::detail::_("Tensor");
using wrap_t = pyext17::wrap<TensorWrapper>;
friend wrap_t;
inline static TensorWrapper* cast(PyObject* op) {return reinterpret_cast<wrap_t*>(op)->inst();}
inline static TensorWrapper* cast_safe(PyObject* op) {
if (!wrap_t::type().isinstance(op)) return nullptr;
return cast(op);
}
inline ObjectPtr<TensorWrapper, pybind11::handle> self() {return wrap_t::pycast(this);}
template <typename... Args>
static ObjectPtr<Tensor> make(Args&&... args) {
auto* op = wrap_t::cnew(std::forward<Args>(args)...);
return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
}
template <typename... Args>
static ObjectPtr<Tensor> make(PyTypeObject* pytype, Args&&... args) {
auto* op = wrap_t::cnew_with_type(pytype, std::forward<Args>(args)...);
return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
}
PyObject* shape();
PyObject* dtype();
PyObject* device();
PyObject* numpy();
void reset(PyObject*);
PyObject* isscalar();
void setscalar();
};
PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */);
struct ApplyContext {
Tensor::flags_t flags;
std::shared_ptr<OpDef> op;
Tensor*const* args;
size_t nargs;
};
using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;
apply_result_t apply(ApplyContext& ctx);
void init_tensor(pybind11::module);
} // namespace mgb::imperative::python
namespace pybind11::detail {
template<> struct type_caster<mgb::imperative::python::TensorWrapper> : mgb::imperative::python::TensorWrapper::wrap_t::caster {};
} // namespace pybind11::detail
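The GradKey binding registered in init_tensor, together with the module-level
backward function, forms the low-level autodiff protocol behind these
headers. A hedged sketch of driving it directly (hypothetical driver code;
the signatures of attach and backward are assumed from this commit, and the
Grad wrapper shown earlier is the intended entry point):

    import numpy as np
    import megengine as mge
    from megengine.core._imperative_rt import core2

    x = mge.Tensor(np.ones(3, dtype="float32"))
    key = core2.GradKey()
    key.attach(x, lambda grad: print(grad.numpy()))  # called with dx
    y = x * x
    core2.backward(key, [y], [mge.Tensor(np.ones(3, dtype="float32"))])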
/**
* \file imperative/python/src/trace.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
namespace mgb::imperative::python {
struct TraceInfo {
};
} // namespace mgb::imperative::python
......@@ -12,6 +12,7 @@ def test_basic():
config_async_level(3)
@pytest.mark.skip
def test_level1_infer_value():
config_async_level(1)
a = mge.tensor([[1, 2], [2, 3], [3, 4]], dtype="float32")
......@@ -22,6 +23,7 @@ def test_level1_infer_value():
d = F.reshape(a, c)
@pytest.mark.skip
def test_level1_infer_shape_with_unknown():
config_async_level(2)
a = mge.tensor([[1, 2, 2, 3]], dtype="float32")
......
......@@ -16,12 +16,11 @@ import pytest
import megengine as mge
import megengine.distributed as dist
import megengine.functional as F
from megengine.core._imperative_rt import TensorAttr, imperative
from megengine.core._imperative_rt import TensorAttr, core2, imperative
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply
from megengine.core._imperative_rt.imperative import sync
from megengine.core.autodiff.grad import Grad
from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.core.tensor.tensor import Tensor, apply
from megengine.core.tensor.tensor_wrapper import TensorWrapper
from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional.distributed import remote_recv, remote_send
......@@ -43,11 +42,11 @@ relu = _elwise(Elemwise.Mode.RELU)
def as_tensor(x):
return Tensor(as_raw_tensor(x, device=mge.device.get_default_device()))
return mge.Tensor(x)
def save_to(self, name="grad"):
def callback(tensor, grad):
def callback(grad):
setattr(self, name, grad)
return callback
......@@ -136,14 +135,14 @@ def test_2nd_grad():
def test_grad_with_tensor_wrapper():
x_np = np.random.rand(10).astype("float32")
x = TensorWrapper(x_np)
x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x))
y = mul(x, x)
y = mul(y, y)
grad(y, TensorWrapper(np.ones_like(x_np)))
grad(y, mge.Tensor(np.ones_like(x_np)))
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)
......@@ -162,8 +161,8 @@ def test_release():
finally:
gc.enable()
x = TensorWrapper([0.0])
dy = TensorWrapper(np.ones_like(x.numpy()))
x = mge.Tensor([0.0])
dy = mge.Tensor(np.ones_like(x.numpy()))
@check
def _():
......@@ -173,25 +172,25 @@ def test_release():
@check
def _():
with Grad().wrt(x) as g:
with Grad().wrt(x):
pass
@check
def _():
with Grad().wrt(x) as g:
with Grad().wrt(x):
y = x * x
def test_grad_inplace():
x_np = np.random.rand(10).astype("float32")
x = TensorWrapper(x_np)
x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x))
y = mul(x, x)
y *= y
grad(y, TensorWrapper(np.ones_like(x_np)))
grad(y, mge.Tensor(np.ones_like(x_np)))
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)
......@@ -199,16 +198,16 @@ def test_elemwise_add():
x_np = np.random.rand(10).astype("float32")
y_np = np.random.rand(10, 10).astype("float32")
dz_np = np.random.rand(10, 10).astype("float32")
x = TensorWrapper(x_np)
y = TensorWrapper(y_np)
dz = TensorWrapper(dz_np)
x = mge.Tensor(x_np)
y = mge.Tensor(y_np)
dz = mge.Tensor(dz_np)
refs = {}
def f(x, y):
x = x * 2
refs["x"] = weakref.ref(x.__wrapped__)
refs["y"] = weakref.ref(y.__wrapped__)
refs["x"] = TensorWeakRef(x)
refs["y"] = TensorWeakRef(y)
return x + y
grad = Grad().wrt(x, callback=save_to(x))
......@@ -226,14 +225,14 @@ def test_elemwise_add():
def test_elemwise_relu():
x_np = [1.0, -1.0]
dz_np = [1.0]
x = TensorWrapper(x_np)
dz = TensorWrapper(dz_np)
x = mge.Tensor(x_np)
dz = mge.Tensor(dz_np)
refs = {}
def f(x):
x = x * 2
refs["x"] = weakref.ref(x.__wrapped__)
refs["x"] = TensorWeakRef(x)
return relu(x)
grad = Grad().wrt(x, callback=save_to(x))
......@@ -258,7 +257,7 @@ def test_elemwise_relu_backward_fn():
def test_reshape():
x_np = np.random.rand(2, 5).astype("float32")
x = TensorWrapper(x_np)
x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x))
y = x.reshape(5, 2)
......@@ -269,7 +268,7 @@ def test_reshape():
def test_subtensor():
x_np = np.random.rand(3, 3).astype("float32")
x = TensorWrapper(x_np)
x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x))
y = x[1:-1, :2]
......@@ -282,7 +281,7 @@ def test_subtensor():
def test_IndexingMultiAxisVec():
x_np = np.random.rand(3, 3).astype("float32")
x = TensorWrapper(x_np)
x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x))
y = x[[0, 2], [0, 2]]
......@@ -295,7 +294,7 @@ def test_IndexingMultiAxisVec():
def test_AxisAddRemove():
x_np = np.random.rand(1, 5).astype("float32")
x = TensorWrapper(x_np)
x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x))
y = F.squeeze(F.expand_dims(x, 2), 0)
......@@ -308,7 +307,7 @@ def test_AxisAddRemove():
def test_Broadcast():
x_np = np.random.rand(3, 3, 1).astype("float32")
x = TensorWrapper(x_np)
x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x))
y = F.broadcast_to(x, (3, 3, 10))
......@@ -319,7 +318,7 @@ def test_Broadcast():
def test_Reduce_sum():
x_np = np.random.rand(3, 3).astype("float32")
x = TensorWrapper(x_np)
x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x))
y = x.sum(axis=0)
......@@ -330,7 +329,7 @@ def test_Reduce_sum():
def test_Reduce_mean():
x_np = np.random.rand(3, 3).astype("float32")
x = TensorWrapper(x_np)
x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x))
y = x.mean(axis=0)
......
......@@ -11,30 +11,29 @@ import collections
import numpy as np
import pytest
import megengine.core.tensor.raw_tensor
import megengine
from megengine.tensor import Tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.ops import builtin
from megengine.core.tensor import Tensor
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import RawTensor, as_raw_tensor
def cvt_to_shape_desc(val, inpvar, config=None):
def as_tensor(val, device):
assert device is not None, "can not infer device"
# TODO: should copy to appropriate device
val = as_raw_tensor(val, device=device)
val = Tensor(val, device=device)
return val
device = None
if inpvar is not None:
assert isinstance(inpvar, RawTensor)
assert isinstance(inpvar, Tensor)
device = device or inpvar.device
if config is not None:
device = device or config.device
if isinstance(val, RawTensor):
if isinstance(val, Tensor):
return as_tensor(val, device)
if not isinstance(val, collections.abc.Iterable):
......@@ -43,7 +42,7 @@ def cvt_to_shape_desc(val, inpvar, config=None):
components = []
on_host = True
for i in val:
if isinstance(i, RawTensor):
if isinstance(i, Tensor):
on_host = False
device = device or i.device
else:
......@@ -62,7 +61,7 @@ def cvt_to_shape_desc(val, inpvar, config=None):
return as_tensor(shape, device)
for idx, v in enumerate(components):
if not isinstance(v, RawTensor):
if not isinstance(v, Tensor):
vi = int(v)
assert vi == v, "could not convert {} to int".format(v)
v = vi
......@@ -95,7 +94,7 @@ def canonize_inputs(inputs, *, config):
# and is called with concat([a, b]))
inputs = inputs[0]
if isinstance(inputs, RawTensor):
if isinstance(inputs, Tensor):
return [inputs]
old_inputs = inputs
......@@ -103,7 +102,7 @@ def canonize_inputs(inputs, *, config):
get_comp_node = None
need_cvt = False
for i in old_inputs:
if isinstance(i, RawTensor):
if isinstance(i, Tensor):
get_comp_node = lambda cn=i.device: cn
else:
need_cvt = True
......@@ -117,8 +116,8 @@ def canonize_inputs(inputs, *, config):
return config.comp_node
for idx, var in enumerate(inputs):
if not isinstance(var, RawTensor):
var = as_raw_tensor(var)
if not isinstance(var, Tensor):
var = Tensor(var)
inputs[idx] = var
return inputs
......@@ -131,15 +130,15 @@ def invoke_op(op, inputs_, cvt_inputs=canonize_inputs):
def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
assert isinstance(inp, RawTensor)
assert isinstance(inp, Tensor)
if not isinstance(tuple_val, tuple):
tuple_val = (tuple_val,)
def as_tensor(v):
if not isinstance(v, RawTensor):
if not isinstance(v, Tensor):
vi = np.ascontiguousarray(v, dtype=np.int32)
assert np.abs(vi - v).max() == 0, "bad index: {!r}".format(v)
v = as_raw_tensor(vi)
v = Tensor(vi)
return v
new_axes = []
......@@ -275,14 +274,14 @@ def batched_incr_mesh_indexing(input, value, tuple_val):
def test_transpose():
x = np.arange(10).reshape(2, 5).astype("int32")
xx = as_raw_tensor(x)
xx = Tensor(x)
(yy,) = transpose(xx, pattern=[1, -1, 0])
np.testing.assert_equal(np.expand_dims(x.transpose(), axis=1), yy.numpy())
def test_broadcast():
x = np.arange(10).reshape(1, 10).astype("int32")
xx = as_raw_tensor(x)
xx = Tensor(x)
(yy,) = broadcast(xx, (10, 10))
np.testing.assert_equal(np.repeat(x, 10, 0), yy.numpy())
......@@ -290,7 +289,7 @@ def test_broadcast():
def test_subtensor():
x = np.arange(25).reshape(5, 5).astype("int32")
d = np.arange(2).astype("int32")
xx = as_raw_tensor(x)
xx = Tensor(x)
(yy0,) = subtensor(xx, (slice(0, 4, 2), 3))
(yy1,) = set_subtensor(xx, d, (slice(0, 4, 2), 3))
(yy2,) = incr_subtensor(xx, d, (slice(0, 4, 2), 3))
......@@ -309,7 +308,7 @@ def test_subtensor():
def test_advance_indexing():
x = np.arange(25).reshape(5, 5).astype("int32")
d = np.arange(15).reshape(3, 5).astype("int32")
xx = as_raw_tensor(x)
xx = Tensor(x)
(yy0,) = advance_indexing(xx, ((0, 4, 2), slice(None, None, None)))
(yy1,) = set_advance_indexing(xx, d, ((0, 4, 2), slice(None, None, None)))
(yy2,) = incr_advance_indexing(xx, d, ((0, 4, 2), slice(None, None, None)))
......@@ -328,7 +327,7 @@ def test_advance_indexing():
def test_mesh_indexing():
x = np.arange(25).reshape(5, 5).astype("int32")
d = np.arange(6).reshape(3, 2).astype("int32")
xx = as_raw_tensor(x)
xx = Tensor(x)
(yy0,) = mesh_indexing(xx, (slice(0, 5, 2), (1, 3)))
(yy1,) = set_mesh_indexing(xx, d, (slice(0, 5, 2), (1, 3)))
(yy2,) = incr_mesh_indexing(xx, d, (slice(0, 5, 2), (1, 3)))
......@@ -355,7 +354,7 @@ def test_mesh_indexing():
def test_batched_mesh_indexing():
x = np.arange(24).reshape(2, 3, 4).astype("int32")
d = np.arange(12).reshape(2, 2, 3).astype("int32")
xx = as_raw_tensor(x)
xx = Tensor(x)
s = [(0, 1, 2), (1, 2, 3)]
(yy0,) = batched_mesh_indexing(xx, (slice(None, None, None), [(0, 2)] * 2, s))
(yy1,) = batched_set_mesh_indexing(
......
......@@ -9,12 +9,12 @@
import numpy as np
from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
from megengine.core.tensor.tensor_wrapper import TensorWrapper
from megengine.tensor import Tensor
def test_basic():
x_np = np.random.rand(10).astype("float32")
x = TensorWrapper(x_np)
x = Tensor(x_np)
y = x * x
y_np = y.numpy()
np.testing.assert_almost_equal(y_np, x_np * x_np)
......@@ -22,15 +22,15 @@ def test_basic():
def test_literal_arith():
x_np = np.random.rand(10).astype("float32")
x = TensorWrapper(x_np)
x = Tensor(x_np)
y = x * 2
y_np = y.numpy()
np.testing.assert_almost_equal(y_np, x_np * 2)
def test_matmul():
A = TensorWrapper(np.random.rand(5, 7).astype("float32"))
B = TensorWrapper(np.random.rand(7, 10).astype("float32"))
A = Tensor(np.random.rand(5, 7).astype("float32"))
B = Tensor(np.random.rand(7, 10).astype("float32"))
C = A @ B
np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6)
......@@ -38,7 +38,7 @@ def test_matmul():
def test_reduce():
def test_x(x_np):
for m in ["sum", "prod", "min", "max", "mean"]:
x = TensorWrapper(x_np)
x = Tensor(x_np)
y = getattr(x, m)(axis=-1, keepdims=True)
np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6)
......@@ -49,7 +49,7 @@ def test_reduce():
def test_set_subtensor():
x = TensorWrapper([1, 2, 3])
x = Tensor([1, 2, 3])
x[:] = [1, 1, 1]
np.testing.assert_almost_equal(x.numpy(), [1, 1, 1], decimal=6)
x[[0, 2]] = [3, 2]
......@@ -60,7 +60,7 @@ def test_set_subtensor():
def test_computing_with_numpy_array():
x = np.array([1, 2, 3], dtype=np.int32)
xx = TensorWrapper(x, device="cpu0")
xx = Tensor(x, device="cpu0")
y = np.array([1, 0, 3], dtype=np.int32)
assert np.add(xx, y).device == xx.device
np.testing.assert_equal(np.add(xx, y).numpy(), np.add(x, y))
......@@ -70,12 +70,12 @@ def test_computing_with_numpy_array():
def test_transpose():
x = np.random.rand(2, 5).astype("float32")
xx = TensorWrapper(x)
xx = Tensor(x)
np.testing.assert_almost_equal(xx.T.numpy(), x.T)
def test_as_type():
x = TensorWrapper([1, 2, 3], dtype=np.float32)
x = Tensor([1, 2, 3], dtype=np.float32)
y = x.astype(qint8(0.1))
np.testing.assert_almost_equal(get_scale(y.dtype), 0.1)
z = y.astype(qint8(0.2))
......
......@@ -312,7 +312,7 @@ def test_device():
np.testing.assert_almost_equal(y1.numpy(), y2.numpy())
y3 = F.eye(x.shape, dtype="float32", device="xpux")
y4 = F.eye(x.shape, dtype="float32", device=x.device.to_c())
y4 = F.eye(x.shape, dtype="float32", device=x.device)
np.testing.assert_almost_equal(y3.numpy(), y4.numpy())
y5 = F.full((3, 2), 4, device=x.device)
......
......@@ -14,7 +14,7 @@
#include "megbrain/imperative/ops/opr_attr.h"
#include "./op_trait.h"
#include "./proxy_graph_detail.h"
#include "megbrain/imperative/proxy_graph_detail.h"
namespace mgb {
namespace imperative {
......
......@@ -13,7 +13,7 @@
#if MGB_ENABLE_OPR_MM
#include "../op_trait.h"
#include "../proxy_graph_detail.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/opr/mm_handler.h"
#include "megbrain/utils/hash.h"
#endif // MGB_ENABLE_OPR_MM
......
......@@ -13,7 +13,7 @@
#if MGB_ENABLE_OPR_MM
#include "../op_trait.h"
#include "../proxy_graph_detail.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/opr/io_remote.h"
#include "megbrain/opr/mm_handler.h"
#endif // MGB_ENABLE_OPR_MM
......
......@@ -13,7 +13,7 @@
#include "megbrain/serialization/opr_load_dump.h"
#include "../op_trait.h"
#include "../proxy_graph_detail.h"
#include "megbrain/imperative/proxy_graph_detail.h"
namespace mgb {
namespace imperative {
......
......@@ -10,7 +10,7 @@
*/
#include "./proxy_graph.h"
#include "./proxy_graph_detail.h"
#include "megbrain/imperative/proxy_graph_detail.h"
namespace mgb {
namespace imperative {
......
/**
* \file imperative/src/impl/proxy_graph_detail.h
* \file imperative/src/include/megbrain/imperative/proxy_graph_detail.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......