From c294b9d18b4127d8ccdf20e29d0da13a3e41af54 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 18 Dec 2020 18:51:14 +0800 Subject: [PATCH] refactor(mge/tensor): remove old implementation remove core.tensor, raw_tensor,TensorWrapper avoid create tensor with zero-stride numpy ndarray GitOrigin-RevId: 4fe5c4c5baaea3a05616710b67b48d9341111afb --- imperative/python/megengine/core/__init__.py | 1 - .../core/autodiff/builtin_op_utils.py | 2 - .../python/megengine/core/autodiff/grad.py | 233 ------------------ .../megengine/core/ops/builtin/__init__.py | 5 - .../python/megengine/core/tensor/__init__.py | 1 - .../python/megengine/core/tensor/core.py | 8 - .../python/megengine/core/tensor/function.py | 13 - .../python/megengine/core/tensor/tensor.py | 117 --------- .../megengine/core/tensor/tensor_wrapper.py | 96 -------- .../megengine/distributed/functional.py | 3 +- .../python/megengine/functional/math.py | 3 +- .../python/megengine/functional/tensor.py | 2 +- imperative/python/megengine/jit/tracing.py | 2 +- .../python/megengine/optimizer/adadelta.py | 1 - .../python/megengine/optimizer/adagrad.py | 1 - imperative/python/megengine/optimizer/adam.py | 1 - imperative/python/megengine/optimizer/sgd.py | 1 - imperative/python/megengine/tensor.py | 6 +- .../megengine/utils/comp_graph_tools.py | 4 +- .../python/test/integration/test_optimizer.py | 6 +- .../test/unit/core/test_dtype_bfloat16.py | 6 +- .../python/test/unit/core/test_dtype_intbx.py | 6 +- .../python/test/unit/core/test_dtype_quant.py | 4 +- .../test/unit/core/test_imperative_rt.py | 19 +- .../python/test/unit/core/test_indexing_op.py | 2 +- imperative/python/test/unit/core/test_jit.py | 2 - .../test/unit/core/test_megbrain_graph.py | 2 +- .../python/test/unit/core/test_raw_tensor.py | 6 +- 28 files changed, 36 insertions(+), 517 deletions(-) delete mode 100644 imperative/python/megengine/core/tensor/tensor.py diff --git a/imperative/python/megengine/core/__init__.py b/imperative/python/megengine/core/__init__.py index 4fd130bc0..8b2d111ad 100644 --- a/imperative/python/megengine/core/__init__.py +++ b/imperative/python/megengine/core/__init__.py @@ -9,5 +9,4 @@ import os import sys -from .tensor import Tensor from .tensor.megbrain_graph import Graph diff --git a/imperative/python/megengine/core/autodiff/builtin_op_utils.py b/imperative/python/megengine/core/autodiff/builtin_op_utils.py index 8db31ce26..e21f48575 100644 --- a/imperative/python/megengine/core/autodiff/builtin_op_utils.py +++ b/imperative/python/megengine/core/autodiff/builtin_op_utils.py @@ -27,8 +27,6 @@ from ..ops.builtin import ( from ..ops.special import Const from ..tensor.core import apply from ..tensor.function import Function -from ..tensor.tensor import Tensor -from ..tensor.tensor_wrapper import TensorWrapper @functools.singledispatch diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py index 0a6244cb1..2ccfbb4fe 100644 --- a/imperative/python/megengine/core/autodiff/grad.py +++ b/imperative/python/megengine/core/autodiff/grad.py @@ -21,7 +21,6 @@ from ..ops.builtin import Elemwise, OpDef, RemoteSend from ..ops.special import Const from ..tensor.core import TensorBase, TensorWrapperBase, apply from ..tensor.function import Function -from ..tensor.tensor import Tensor, get_context from . import builtin_op_utils """ Some notes: @@ -65,238 +64,6 @@ def get_tensor(x): return get_tensor(x) -class Grad: - def __init__(self, name=None): - - if name is None: - global _grad_count - self._name = "grad_" + str(_grad_count) - _grad_count += 1 - else: - self._name = name - assert self._name not in _grad_manager_dict, "grad manager name duplicated" - _grad_manager_dict[self._name] = self - - # list of all x in partial(y) / partial(x) - self.xs = [] - - # constains weak reference of all OpNode during forward - # OpNode contains inputs, outputs and its backward - # ops forms the computational graph - self.ops = [] - - # save remote_send output for backward - self.remote_send_cache = [] - - self._attached_tensors = weakref.WeakSet() - self._enabled = True - - @property - def name(self): - return self._name - - def wrt(self, *args: Tensor, callback=None): - """ Indicates the loss is a function of the input tensors (usually the net trainable parameters), - i.e., d (loss) / d (Tensor) != 0 - - callback is used to perform additional operations after gradient is obtained in backward. - e.g., copy the grad to a particular place - - A VariableNode will be created and saved in the tensor/s _extra_data slot. - """ - - for x in map(get_tensor, args): - v = self._new_variable(x, callback=callback) - assert self not in x._extra_data - x._extra_data[self] = Tracer(v) - self.xs.append(v) - - return self - - def _new_variable(self, owner, opnode=None, callback=None): - self._attached_tensors.add(owner) - return VariableNode(self, owner, opnode=opnode, callback=callback) - - def _new_opnode(self, inputs, outputs): - inputs = tuple(inputs) - for i in inputs: - assert i is None or isinstance(i, VariableNode) - o = OpNode() - o.inputs = inputs - o.outputs = [] - tracers = [] - for i in outputs: - assert isinstance(i, Tensor) - v = self._new_variable(i, o) - o.outputs.append(weakref.ref(v)) - tracers.append(Tracer(v)) - self.ops.append(weakref.ref(o)) - return o, tracers - - def copy(self): - raise NotImplementedError - - def __enter__(self): - return self - - def _exit(self): - """clear all resources""" - self._enabled = False - for o in self.ops: - o = o() - if o: - o.clear() - for i in self._attached_tensors: - i._extra_data.pop(self, None) - self.remote_send_cache = [] - - def __exit__(self, *_): - self._exit() - - def __call__(self, ys, dys): - """ Defines Grad(). - - :param ys: outputs of forward operators, e.g., the loss tensor - :type ys: list of Tensor or TensorWrapperBase - :param dys: delta of outputs, physically equivalent to sensitivity of outputs to the loss, - e.g., one for the loss itself - :type dys: list of Tensor or TensorWrapperBase - """ - assert self._enabled - self._enabled = False - - def check_wrapper(): - if isinstance(dys, TensorWrapperBase): - return type(dys) - if isinstance(dys, TensorBase): - return - assert isinstance(dys, (tuple, list)) - for i in dys: - if isinstance(i, TensorWrapperBase): - return type(i) - # use Tensor as defualt wrapper - return mge.Tensor - - Wrapper = check_wrapper() - - def aslist(x): - if isinstance(x, (Tensor, TensorWrapperBase)): - x = [x] - else: - x = list(x) - x = [i.__wrapped__ if isinstance(i, TensorWrapperBase) else i for i in x] - for i in x: - assert isinstance(i, Tensor) - return x - - ys = aslist(ys) - dys = aslist(dys) - assert len(ys) == len(dys) - - ids = [i for i, y in enumerate(ys) if self in y._extra_data.keys()] - - ys = [y for i, y in enumerate(ys) if i in ids] - dys = [dy for i, dy in enumerate(dys) if i in ids] - - # ys is changed to a list of VariableNode which contains more information - # such as OpNode, callback, etc. - ys = [i._extra_data[self].node for i in ys] - - # NOTE: callback is called only if grad is not None - - # the OpNode sequence in backward - op_seq = [] - - # VariableNode -> (i, j), where i is time stamp in backward, j means jth input - last_written_to = {} - - def schedule(): - reached = set(ys) - # i is the time stamp in backward - i = 0 - for o in self.ops[::-1]: - o = o() - if o is None: - continue - - if not o.has_grad_fn(o, reached): - continue - op_seq.append(o) - for j, v in enumerate(o.inputs): - reached.add(v) - last_written_to[v] = i, j - i += 1 - - schedule() - - # VariableNode -> Tensor - cache = {} - - def initialize(): - for y, dy in zip(ys, dys): - cache[y] = dy - if y not in last_written_to and y.callback: - y.callback(y.owner(), dy) - - initialize() - - # NOTE: None is used to mark a node has been consumed - - for seqno, opnode in enumerate(op_seq): - input_nodes = opnode.inputs - output_nodes = [i() for i in opnode.outputs] - backward = opnode.backward - backward_allow_noinput = opnode.backward_allow_noinput - opnode.clear() - - output_grads = [] - for i in output_nodes: - if i is not None: - if i in cache: - assert cache[i] is not None - output_grads.append(cache[i]) - else: - output_grads.append(None) - # read by backward, mark consumed - cache[i] = None - else: - output_grads.append(None) - if ( - any([grad is not None for grad in output_grads]) - or backward_allow_noinput - ): - input_grads = backward(*output_grads) - else: - input_grads = [None] * len(input_nodes) - - assert len(input_nodes) == len(input_grads) - for i, (v, g) in enumerate(zip(input_nodes, input_grads)): - if v is None: - continue - if v in cache: - assert cache[v] - if g is not None: - cache[v] = add(cache[v], g) - elif g is not None: - cache[v] = g - if last_written_to[v] == (seqno, i): - if v.callback: - v.callback( - v.owner(), Wrapper(cache[v]) if Wrapper else cache[v] - ) - if v.opnode is None: - # won't read by backward, mark consumed - cache[v] = None - - for v in cache.values(): - assert v is None - - self._exit() - - def __del__(self): - self._exit() - - class clearable: __cleared = False diff --git a/imperative/python/megengine/core/ops/builtin/__init__.py b/imperative/python/megengine/core/ops/builtin/__init__.py index 3d67846d2..386970715 100644 --- a/imperative/python/megengine/core/ops/builtin/__init__.py +++ b/imperative/python/megengine/core/ops/builtin/__init__.py @@ -10,11 +10,6 @@ import warnings from typing import Union from ..._imperative_rt import OpDef, ops -from ...tensor.core import OpBase, TensorBase, TensorWrapperBase, apply - -# register OpDef as a "virtual subclass" of OpBase, so any of registered -# apply(OpBase, ...) rules could work well on OpDef -OpBase.register(OpDef) __all__ = ["OpDef"] diff --git a/imperative/python/megengine/core/tensor/__init__.py b/imperative/python/megengine/core/tensor/__init__.py index e008c110f..1207b5d98 100644 --- a/imperative/python/megengine/core/tensor/__init__.py +++ b/imperative/python/megengine/core/tensor/__init__.py @@ -6,4 +6,3 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from .tensor_wrapper import TensorWrapper as Tensor diff --git a/imperative/python/megengine/core/tensor/core.py b/imperative/python/megengine/core/tensor/core.py index 07d6edca3..0c1bcee79 100644 --- a/imperative/python/megengine/core/tensor/core.py +++ b/imperative/python/megengine/core/tensor/core.py @@ -13,17 +13,9 @@ import sys import typing from abc import ABC -from .._imperative_rt.core2 import apply as apply2 from .multipledispatch import Dispatcher -def apply_op(op, *args): - Wrapper = type(args[0]) - args = [arg._tensor for arg in args] - results = apply2(op, *args) - return tuple(map(Wrapper, results)) - - class OpBase(ABC): def __call__(self, *args): return apply(self, *args) diff --git a/imperative/python/megengine/core/tensor/function.py b/imperative/python/megengine/core/tensor/function.py index 87f734b3c..d7b6b8cf7 100644 --- a/imperative/python/megengine/core/tensor/function.py +++ b/imperative/python/megengine/core/tensor/function.py @@ -7,9 +7,6 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from ..ops.builtin import OpDef from .core import TensorBase, TensorWrapperBase, apply -from .raw_tensor import RawTensor -from .tensor import Tensor, push_context -from .tensor_wrapper import TensorWrapper class Function: @@ -155,13 +152,3 @@ def _(op: Function, *args: TensorWrapperBase): t._extra_data[k] = i return tuple(map(Wrapper, outputs)) - - -@apply.register() -def _(op: Function, *args: Tensor): - raise NotImplementedError - - -@apply.register() -def _(op: Function, *args: RawTensor): - raise NotImplementedError diff --git a/imperative/python/megengine/core/tensor/tensor.py b/imperative/python/megengine/core/tensor/tensor.py deleted file mode 100644 index c6cad3f49..000000000 --- a/imperative/python/megengine/core/tensor/tensor.py +++ /dev/null @@ -1,117 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import contextlib -import copy - -from .core import Dispatcher, OpBase, TensorBase, apply - - -class Tensor(TensorBase): - def __init__(self, data: TensorBase): - self._data = data - # _extra_data is set up in Grad.wrt - self._extra_data = {} - self._user_data = {} - - def __getattr__(self, name): - if name in self._user_data: - return self._user_data[name] - raise AttributeError(name) - - def reset(self, other): - assert isinstance(other, __class__) - self.__dict__.clear() - self._data = other.data - self._extra_data = other._extra_data.copy() - self._user_data = other._user_data.copy() - - def copy(self): - other = object.__new__(type(self)) - other.reset(self) - return other - - # tensor interface - - @property - def shape(self): - return self._data.shape - - @property - def dtype(self): - return self._data.dtype - - @property - def device(self): - return self._data.device - - def numpy(self): - return self._data.numpy() - - def _drop(self): - self._data._drop() - - def _swap_in(self): - self._data._swap_in() - - def _swap_out(self): - self._data._swap_out() - - -class ApplyContext: - __slots__ = ("inputs", "outputs", "key") - - def __init__(self): - self.inputs = None - self.outputs = None - self.key = None - - -_context = None - - -@contextlib.contextmanager -def push_context(): - global _context - backup = _context - try: - _context = ApplyContext() - yield _context - finally: - _context = backup - - -def get_context(): - return _context - - -@apply.register() -def tensor_apply(op: OpBase, *args: Tensor): - data = tuple(i._data for i in args) - # type(Tensor._data) is RawTensor - # dispached to apply.add@RawTensor.py if passed Tensor args - outputs = apply(op, *data) - ret = tuple(map(Tensor, outputs)) - - with push_context() as ctx: - ctx.inputs = args - ctx.outputs = ret - for k in set().union(*(i._extra_data for i in args)): - ctx.key = k - data = tuple( - i._extra_data.get(k) if isinstance(i, Tensor) else i for i in args - ) - # data are instances of Tracer - # dispatched to apply.add@grad.py - outputs = apply(op, *data) - if outputs is not None: - assert len(outputs) == len(ret) - for t, i in zip(ret, outputs): - t._extra_data[k] = i - - return ret diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py index a197f617e..845792c98 100644 --- a/imperative/python/megengine/core/tensor/tensor_wrapper.py +++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py @@ -19,7 +19,6 @@ from ..ops import builtin from ..ops.builtin import Elemwise, GetVarShape from ..ops.special import Const from . import utils -from .core import OpBase, TensorBase, TensorWrapperBase from .indexing import getitem as _getitem from .indexing import setitem as _setitem from .utils import isscalar @@ -439,98 +438,3 @@ class ArrayMethodMixin(abc.ABC): min = _reduce("MIN") max = _reduce("MAX") mean = _reduce("MEAN") - - -class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase): - def __init__(self, data): - self.__wrapped__ = data - - def _reset(self, other): - if not isinstance(other, __class__): - raise TypeError(type(other)) - self.__wrapped__ = other.__wrapped__ - return self - - @property - def dtype(self): - return self.__wrapped__.dtype - - @property - def shape(self): - shape = self.__wrapped__.shape - if shape == () or not use_symbolic_shape(): - return shape - return apply(GetVarShape(), self)[0] - - @property - def device(self): - return self.__wrapped__.device - - def numpy(self): - return self.__wrapped__.numpy() - - def _drop(self): - self.__wrapped__._drop() - - def _swap_in(self): - self.__wrapped__._swap_in() - - def _swap_out(self): - self.__wrapped__._swap_out() - - -class TensorWrapper(ArrayMethodMixin, TensorBase): - def __init__(self, data, dtype=None, device=None, isscalar=False): - self._isscalar = isscalar - if isinstance(data, Tensor): - self._tensor = data - else: - if device is None: - device = CompNode._get_default_device() - self._tensor = Tensor(data, dtype, device) - - def _reset(self, other): - if not isinstance(other, __class__): - raise TypeError(type(other)) - self._tensor = other._tensor - return self - - @property - def dtype(self): - return self._tensor.dtype - - @property - def shape(self): - if self._isscalar: - return () - shape = self._tensor.shape - if shape == () or not use_symbolic_shape(): - return shape - return apply(GetVarShape(), self)[0] - - @property - def device(self): - return self._tensor.device - - def numpy(self): - if self._isscalar: - return self._tensor.numpy().squeeze() - return self._tensor.numpy() - - def _drop(self): - self._tensor._drop() - - def _swap_in(self): - self._tensor._swap_in() - - def _swap_out(self): - self._tensor._swap_out() - - def __repr__(self): - piece = "Tensor(" - with np.printoptions(precision=4, suppress=True): - piece += "{}".format(str(self.numpy())) - if self.dtype != np.float32: - piece += ", dtype={}".format(np.dtype(self.dtype).name) - piece += ", device={}".format(self.device) + ")" - return piece diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py index 2e0457b99..1eb353966 100644 --- a/imperative/python/megengine/distributed/functional.py +++ b/imperative/python/megengine/distributed/functional.py @@ -18,9 +18,8 @@ from ..core.autodiff.grad import ( tracer_apply, ) from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend -from ..core.tensor.tensor import Tensor, tensor_apply from ..device import get_default_device -from ..tensor import tensor +from ..tensor import Tensor from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank __all__ = [ diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index 8b3296f77..0316a636e 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -16,7 +16,6 @@ from ..core._imperative_rt.core2 import apply from ..core.ops import builtin from ..core.ops.special import Const from ..core.tensor import utils -from ..core.tensor.core import TensorBase, TensorWrapperBase from ..tensor import Tensor from .elemwise import clip, exp, log, log1p from .tensor import reshape, squeeze @@ -703,7 +702,7 @@ def topk( mode = "VALUE_IDX_SORTED" op = builtin.TopK(mode=mode) - if not isinstance(k, (TensorBase, TensorWrapperBase)): + if not isinstance(k, Tensor): (k,) = Const(k, dtype="int32", device=inp.device)(inp) if len(inp.shape) == 1: diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 69e0b33e0..8cd60b8ea 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -14,7 +14,7 @@ from typing import Iterable, List, Optional, Sequence, Tuple, Union import numpy as np from ..core._imperative_rt import CompNode -from ..core._imperative_rt.core2 import Tensor, apply +from ..core._imperative_rt.core2 import apply from ..core._wrap import device as as_device from ..core.ops import builtin from ..core.ops.special import Const diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 5ef0fc501..2b5c48353 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -19,6 +19,7 @@ import weakref import numpy as np from ..core._imperative_rt import GraphProfiler +from ..core._imperative_rt.core2 import Tensor from ..core._imperative_rt.ops import ( CollectiveComm, GaussianRNG, @@ -32,7 +33,6 @@ from ..core.ops.special import Const from ..core.tensor import megbrain_graph as G from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor -from ..core.tensor.tensor import Tensor from .sublinear_memory_config import SublinearMemoryConfig diff --git a/imperative/python/megengine/optimizer/adadelta.py b/imperative/python/megengine/optimizer/adadelta.py index cc090cf4a..24f009bcc 100644 --- a/imperative/python/megengine/optimizer/adadelta.py +++ b/imperative/python/megengine/optimizer/adadelta.py @@ -10,7 +10,6 @@ from typing import Iterable, Union import numpy as np -from ..core.tensor.tensor import Tensor from ..tensor import Parameter, tensor from .optimizer import Optimizer diff --git a/imperative/python/megengine/optimizer/adagrad.py b/imperative/python/megengine/optimizer/adagrad.py index 707e9142f..382a048d9 100644 --- a/imperative/python/megengine/optimizer/adagrad.py +++ b/imperative/python/megengine/optimizer/adagrad.py @@ -10,7 +10,6 @@ from typing import Iterable, Union import numpy as np -from ..core.tensor.tensor import Tensor from ..tensor import Parameter, tensor from .optimizer import Optimizer diff --git a/imperative/python/megengine/optimizer/adam.py b/imperative/python/megengine/optimizer/adam.py index 78d963ad6..2d8b9f454 100644 --- a/imperative/python/megengine/optimizer/adam.py +++ b/imperative/python/megengine/optimizer/adam.py @@ -8,7 +8,6 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from typing import Iterable, Tuple, Union -from ..core.tensor.tensor import Tensor from ..tensor import Parameter, tensor from .optimizer import Optimizer diff --git a/imperative/python/megengine/optimizer/sgd.py b/imperative/python/megengine/optimizer/sgd.py index db7875806..5a9e9ac30 100644 --- a/imperative/python/megengine/optimizer/sgd.py +++ b/imperative/python/megengine/optimizer/sgd.py @@ -8,7 +8,6 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from typing import Iterable, Union -from ..core.tensor.tensor import Tensor from ..tensor import Parameter, tensor from .optimizer import Optimizer diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py index 9f63b6880..04f43e1ba 100644 --- a/imperative/python/megengine/tensor.py +++ b/imperative/python/megengine/tensor.py @@ -16,8 +16,8 @@ from .core._imperative_rt import CompNode from .core._imperative_rt.core2 import Tensor as _Tensor from .core._imperative_rt.core2 import apply from .core._trace_option import use_symbolic_shape +from .core._wrap import device as as_device from .core.ops.builtin import Copy, GetVarShape -from .core.tensor.raw_tensor import as_device from .core.tensor.tensor_wrapper import ArrayMethodMixin from .device import _valid_device, get_default_device from .utils.deprecation import deprecated @@ -43,6 +43,10 @@ class Tensor(_Tensor, ArrayMethodMixin): if isinstance(data, _Tensor): obj = _Tensor.__new__(cls, data) else: + if isinstance(data, np.ndarray): + if 0 in data.strides: + data = data.squeeze().reshape(data.shape) + obj = _Tensor.__new__(cls, data, dtype, cn) return obj diff --git a/imperative/python/megengine/utils/comp_graph_tools.py b/imperative/python/megengine/utils/comp_graph_tools.py index 47cbb8fcf..7cff34003 100644 --- a/imperative/python/megengine/utils/comp_graph_tools.py +++ b/imperative/python/megengine/utils/comp_graph_tools.py @@ -13,7 +13,7 @@ import numpy from ..core import _imperative_rt from ..core._imperative_rt import OperatorNode, VarNode from ..core.tensor import megbrain_graph as G -from ..core.tensor.raw_tensor import as_raw_tensor +from ..tensor import Tensor __all__ = [ "get_dep_vars", @@ -309,7 +309,7 @@ def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.n cg = new_out_list[0].graph func = cg.compile(new_out_list) for node, value in zip(inp_node_list, inp_data_list): - node.set_value(as_raw_tensor(value)._dev_tensor()) + node.set_value(Tensor(value)._dev_tensor()) func.execute() out_data_list = [o.get_value().numpy() for o in out_node_list] return out_data_list diff --git a/imperative/python/test/integration/test_optimizer.py b/imperative/python/test/integration/test_optimizer.py index 38b6350e4..c9e63da27 100644 --- a/imperative/python/test/integration/test_optimizer.py +++ b/imperative/python/test/integration/test_optimizer.py @@ -13,7 +13,7 @@ import megengine.functional as F from megengine import Parameter, optimizer from megengine.jit import trace from megengine.module import Linear, Module -from megengine.tensor import tensor +from megengine.tensor import Tensor class MLP(Module): @@ -54,7 +54,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): for group in opt.param_groups: group["lr"] += 0.01 check_func.lr += 0.01 - data = tensor(np.random.random(data_shape).astype(np.float32)) + data = Tensor(np.random.random(data_shape).astype(np.float32)) opt.clear_grad() with gm: @@ -98,7 +98,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): ori_params[param] = np.copy(param.numpy()) train_func( - tensor(np.random.random(data_shape).astype(np.float32)), opt=opt, gm=gm + Tensor(np.random.random(data_shape).astype(np.float32)), opt=opt, gm=gm ) step += 1 check_func(ori_params, net.parameters(), step) diff --git a/imperative/python/test/unit/core/test_dtype_bfloat16.py b/imperative/python/test/unit/core/test_dtype_bfloat16.py index 768ccee2a..510d179ef 100644 --- a/imperative/python/test/unit/core/test_dtype_bfloat16.py +++ b/imperative/python/test/unit/core/test_dtype_bfloat16.py @@ -11,7 +11,7 @@ import pickle import numpy as np from megengine.core.tensor.dtype import bfloat16 -from megengine.core.tensor.raw_tensor import as_raw_tensor +from megengine.tensor import Tensor def test_define(): @@ -42,14 +42,14 @@ def test_cast(): def test_shared_nd(): data = np.array([-3.4, 1.394683, 2.323497, -7.439948, -5.2397], dtype=bfloat16) - snd = as_raw_tensor(data, dtype=bfloat16, device="xpux") + snd = Tensor(data, dtype=bfloat16, device="xpux") assert snd.numpy().dtype == bfloat16 np.testing.assert_allclose( snd.numpy(), [-3.40625, 1.398438, 2.328125, -7.4375, -5.25], atol=1e-6 ) data = np.array([-9.34964, -8.342, 9.4385, 0.18746, 1.48], dtype=bfloat16) - snd = as_raw_tensor(data, dtype=bfloat16, device="xpux") + snd = Tensor(data, dtype=bfloat16, device="xpux") np.testing.assert_allclose( snd.numpy(), [-9.375, -8.3125, 9.4375, 0.1875, 1.476562], atol=1e-6 ) diff --git a/imperative/python/test/unit/core/test_dtype_intbx.py b/imperative/python/test/unit/core/test_dtype_intbx.py index 08cbd3557..1b3449e27 100644 --- a/imperative/python/test/unit/core/test_dtype_intbx.py +++ b/imperative/python/test/unit/core/test_dtype_intbx.py @@ -12,7 +12,7 @@ import numpy as np import pytest from megengine.core.tensor.dtype import intb1, intb2, intb4 -from megengine.core.tensor.raw_tensor import as_raw_tensor +from megengine.tensor import Tensor def bit_define_test(bit, low_bit_type): @@ -78,11 +78,11 @@ def _shared_nd_test(bit, low_bit_type): min_value = 1 - (1 << bit) data = np.arange(min_value, max_value + 2, 2, dtype=low_bit_type) - snd = as_raw_tensor(data, dtype=low_bit_type, device="xpux") + snd = Tensor(data, dtype=low_bit_type, device="xpux") np.testing.assert_allclose(snd.numpy(), range(min_value, max_value + 2, 2)) data = np.arange(min_value, max_value + 2, 4, dtype=low_bit_type) - snd = as_raw_tensor(data, dtype=low_bit_type, device="xpux") + snd = Tensor(data, dtype=low_bit_type, device="xpux") np.testing.assert_allclose(snd.numpy(), range(min_value, max_value + 2, 4)) diff --git a/imperative/python/test/unit/core/test_dtype_quant.py b/imperative/python/test/unit/core/test_dtype_quant.py index 36fdfef72..0ddf01e9d 100644 --- a/imperative/python/test/unit/core/test_dtype_quant.py +++ b/imperative/python/test/unit/core/test_dtype_quant.py @@ -32,8 +32,8 @@ from megengine.core.tensor.dtype import ( quint4, quint8, ) -from megengine.core.tensor.raw_tensor import as_raw_tensor from megengine.distributed.helper import get_device_count_by_fork +from megengine.tensor import Tensor def test_dtype_quint8(): @@ -71,7 +71,7 @@ def _get_compiled_result(inp, dtype, shape, device, calc_func=None): temp_rst = calc_func(inp_node.outputs[0]) oup_node = G.OutputNode(temp_rst) func = graph.compile(oup_node.outputs[0]) - inp_node.set_value(as_raw_tensor(inp, dtype=dtype, device=device)._dev_tensor()) + inp_node.set_value(Tensor(inp, dtype=dtype, device=device)._dev_tensor()) func.execute() return oup_node.get_value().numpy() diff --git a/imperative/python/test/unit/core/test_imperative_rt.py b/imperative/python/test/unit/core/test_imperative_rt.py index 9b8764aa1..b68287b70 100644 --- a/imperative/python/test/unit/core/test_imperative_rt.py +++ b/imperative/python/test/unit/core/test_imperative_rt.py @@ -9,15 +9,15 @@ import numpy as np import pytest -import megengine.core.tensor.raw_tensor -from megengine.core.tensor.core import apply +import megengine +from megengine.core._imperative_rt.core2 import apply +from megengine.tensor import Tensor def elemwise(*args, mode): - from megengine.core._imperative_rt.imperative import apply_op from megengine.core.ops.builtin import Elemwise - return apply_op(Elemwise(mode), args) + return apply(Elemwise(mode), *args) def test_basic_interface(): @@ -44,11 +44,11 @@ def test_simple_arith(): from megengine.core.ops.builtin import Elemwise x = np.random.rand(10).astype("float32") - xx = megengine.core._imperative_rt.put(x) + xx = Tensor(x) (yy,) = elemwise(xx, xx, mode=Elemwise.Mode.MUL) - np.testing.assert_allclose(x * x, megengine.core._imperative_rt.get_value(yy)) - megengine.core._imperative_rt.delete(xx) - megengine.core._imperative_rt.delete(yy) + np.testing.assert_allclose(x * x, yy.numpy()) + del xx + del yy def test_tensor_on_device(): @@ -62,10 +62,9 @@ def test_tensor_on_device(): def test_raw_tensor(): from megengine.core.ops.builtin import Elemwise - from megengine.core.tensor.raw_tensor import as_raw_tensor x = np.random.rand(10).astype("float32") - xx = as_raw_tensor(x) + xx = Tensor(x) (yy,) = apply(Elemwise(Elemwise.Mode.MUL), xx, xx) np.testing.assert_allclose(x * x, yy.numpy()) (yy,) = apply(Elemwise(Elemwise.Mode.MUL), xx, xx) diff --git a/imperative/python/test/unit/core/test_indexing_op.py b/imperative/python/test/unit/core/test_indexing_op.py index a1c5284c4..78df3f177 100644 --- a/imperative/python/test/unit/core/test_indexing_op.py +++ b/imperative/python/test/unit/core/test_indexing_op.py @@ -12,10 +12,10 @@ import numpy as np import pytest import megengine -import megengine.tensor as Tensor from megengine.core._imperative_rt.core2 import apply from megengine.core._trace_option import use_symbolic_shape from megengine.core.ops import builtin +from megengine.tensor import Tensor def cvt_to_shape_desc(val, inpvar, config=None): diff --git a/imperative/python/test/unit/core/test_jit.py b/imperative/python/test/unit/core/test_jit.py index 4bc9c2f17..088f56d6c 100644 --- a/imperative/python/test/unit/core/test_jit.py +++ b/imperative/python/test/unit/core/test_jit.py @@ -8,8 +8,6 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import pytest -from megengine.core import Tensor - # from megengine.core.interpreter.hints import function diff --git a/imperative/python/test/unit/core/test_megbrain_graph.py b/imperative/python/test/unit/core/test_megbrain_graph.py index 2b3f4699d..ee651dd1c 100644 --- a/imperative/python/test/unit/core/test_megbrain_graph.py +++ b/imperative/python/test/unit/core/test_megbrain_graph.py @@ -11,8 +11,8 @@ from concurrent.futures import Future import numpy as np import megengine.functional as F -import megengine.tensor as Tensor from megengine.core.tensor import megbrain_graph as mgb_graph +from megengine.tensor import Tensor def test_io(): diff --git a/imperative/python/test/unit/core/test_raw_tensor.py b/imperative/python/test/unit/core/test_raw_tensor.py index 0f4ae7ec5..5f9aa8875 100644 --- a/imperative/python/test/unit/core/test_raw_tensor.py +++ b/imperative/python/test/unit/core/test_raw_tensor.py @@ -9,12 +9,12 @@ import numpy as np import megengine.functional as F -from megengine.core.tensor.raw_tensor import as_raw_tensor +from megengine.tensor import Tensor def test_as_raw_tensor(): x = np.arange(6, dtype="float32").reshape(2, 3) - xx = as_raw_tensor(x, device="xpux") + xx = Tensor(x, device="xpux") yy = F.add(xx, 1).numpy() assert xx.dtype == np.float32 assert xx.device == "xpux" @@ -23,7 +23,7 @@ def test_as_raw_tensor(): def test_as_raw_tensor_from_int64(): x = np.arange(6, dtype="int64").reshape(2, 3) - xx = as_raw_tensor(x, dtype="float32", device="xpux") + xx = Tensor(x, dtype="float32", device="xpux") yy = F.add(xx, 1).numpy() assert xx.dtype == np.float32 assert xx.device == "xpux" -- GitLab