提交 c294b9d1 编写于 作者: M Megvii Engine Team

refactor(mge/tensor): remove old implementation

remove core.tensor, raw_tensor,TensorWrapper
avoid create tensor with zero-stride numpy ndarray

GitOrigin-RevId: 4fe5c4c5baaea3a05616710b67b48d9341111afb
上级 15e8e7df
......@@ -9,5 +9,4 @@
import os
import sys
from .tensor import Tensor
from .tensor.megbrain_graph import Graph
......@@ -27,8 +27,6 @@ from ..ops.builtin import (
from ..ops.special import Const
from ..tensor.core import apply
from ..tensor.function import Function
from ..tensor.tensor import Tensor
from ..tensor.tensor_wrapper import TensorWrapper
@functools.singledispatch
......
......@@ -21,7 +21,6 @@ from ..ops.builtin import Elemwise, OpDef, RemoteSend
from ..ops.special import Const
from ..tensor.core import TensorBase, TensorWrapperBase, apply
from ..tensor.function import Function
from ..tensor.tensor import Tensor, get_context
from . import builtin_op_utils
""" Some notes:
......@@ -65,238 +64,6 @@ def get_tensor(x):
return get_tensor(x)
class Grad:
def __init__(self, name=None):
if name is None:
global _grad_count
self._name = "grad_" + str(_grad_count)
_grad_count += 1
else:
self._name = name
assert self._name not in _grad_manager_dict, "grad manager name duplicated"
_grad_manager_dict[self._name] = self
# list of all x in partial(y) / partial(x)
self.xs = []
# constains weak reference of all OpNode during forward
# OpNode contains inputs, outputs and its backward
# ops forms the computational graph
self.ops = []
# save remote_send output for backward
self.remote_send_cache = []
self._attached_tensors = weakref.WeakSet()
self._enabled = True
@property
def name(self):
return self._name
def wrt(self, *args: Tensor, callback=None):
""" Indicates the loss is a function of the input tensors (usually the net trainable parameters),
i.e., d (loss) / d (Tensor) != 0
callback is used to perform additional operations after gradient is obtained in backward.
e.g., copy the grad to a particular place
A VariableNode will be created and saved in the tensor/s _extra_data slot.
"""
for x in map(get_tensor, args):
v = self._new_variable(x, callback=callback)
assert self not in x._extra_data
x._extra_data[self] = Tracer(v)
self.xs.append(v)
return self
def _new_variable(self, owner, opnode=None, callback=None):
self._attached_tensors.add(owner)
return VariableNode(self, owner, opnode=opnode, callback=callback)
def _new_opnode(self, inputs, outputs):
inputs = tuple(inputs)
for i in inputs:
assert i is None or isinstance(i, VariableNode)
o = OpNode()
o.inputs = inputs
o.outputs = []
tracers = []
for i in outputs:
assert isinstance(i, Tensor)
v = self._new_variable(i, o)
o.outputs.append(weakref.ref(v))
tracers.append(Tracer(v))
self.ops.append(weakref.ref(o))
return o, tracers
def copy(self):
raise NotImplementedError
def __enter__(self):
return self
def _exit(self):
"""clear all resources"""
self._enabled = False
for o in self.ops:
o = o()
if o:
o.clear()
for i in self._attached_tensors:
i._extra_data.pop(self, None)
self.remote_send_cache = []
def __exit__(self, *_):
self._exit()
def __call__(self, ys, dys):
""" Defines Grad().
:param ys: outputs of forward operators, e.g., the loss tensor
:type ys: list of Tensor or TensorWrapperBase
:param dys: delta of outputs, physically equivalent to sensitivity of outputs to the loss,
e.g., one for the loss itself
:type dys: list of Tensor or TensorWrapperBase
"""
assert self._enabled
self._enabled = False
def check_wrapper():
if isinstance(dys, TensorWrapperBase):
return type(dys)
if isinstance(dys, TensorBase):
return
assert isinstance(dys, (tuple, list))
for i in dys:
if isinstance(i, TensorWrapperBase):
return type(i)
# use Tensor as defualt wrapper
return mge.Tensor
Wrapper = check_wrapper()
def aslist(x):
if isinstance(x, (Tensor, TensorWrapperBase)):
x = [x]
else:
x = list(x)
x = [i.__wrapped__ if isinstance(i, TensorWrapperBase) else i for i in x]
for i in x:
assert isinstance(i, Tensor)
return x
ys = aslist(ys)
dys = aslist(dys)
assert len(ys) == len(dys)
ids = [i for i, y in enumerate(ys) if self in y._extra_data.keys()]
ys = [y for i, y in enumerate(ys) if i in ids]
dys = [dy for i, dy in enumerate(dys) if i in ids]
# ys is changed to a list of VariableNode which contains more information
# such as OpNode, callback, etc.
ys = [i._extra_data[self].node for i in ys]
# NOTE: callback is called only if grad is not None
# the OpNode sequence in backward
op_seq = []
# VariableNode -> (i, j), where i is time stamp in backward, j means jth input
last_written_to = {}
def schedule():
reached = set(ys)
# i is the time stamp in backward
i = 0
for o in self.ops[::-1]:
o = o()
if o is None:
continue
if not o.has_grad_fn(o, reached):
continue
op_seq.append(o)
for j, v in enumerate(o.inputs):
reached.add(v)
last_written_to[v] = i, j
i += 1
schedule()
# VariableNode -> Tensor
cache = {}
def initialize():
for y, dy in zip(ys, dys):
cache[y] = dy
if y not in last_written_to and y.callback:
y.callback(y.owner(), dy)
initialize()
# NOTE: None is used to mark a node has been consumed
for seqno, opnode in enumerate(op_seq):
input_nodes = opnode.inputs
output_nodes = [i() for i in opnode.outputs]
backward = opnode.backward
backward_allow_noinput = opnode.backward_allow_noinput
opnode.clear()
output_grads = []
for i in output_nodes:
if i is not None:
if i in cache:
assert cache[i] is not None
output_grads.append(cache[i])
else:
output_grads.append(None)
# read by backward, mark consumed
cache[i] = None
else:
output_grads.append(None)
if (
any([grad is not None for grad in output_grads])
or backward_allow_noinput
):
input_grads = backward(*output_grads)
else:
input_grads = [None] * len(input_nodes)
assert len(input_nodes) == len(input_grads)
for i, (v, g) in enumerate(zip(input_nodes, input_grads)):
if v is None:
continue
if v in cache:
assert cache[v]
if g is not None:
cache[v] = add(cache[v], g)
elif g is not None:
cache[v] = g
if last_written_to[v] == (seqno, i):
if v.callback:
v.callback(
v.owner(), Wrapper(cache[v]) if Wrapper else cache[v]
)
if v.opnode is None:
# won't read by backward, mark consumed
cache[v] = None
for v in cache.values():
assert v is None
self._exit()
def __del__(self):
self._exit()
class clearable:
__cleared = False
......
......@@ -10,11 +10,6 @@ import warnings
from typing import Union
from ..._imperative_rt import OpDef, ops
from ...tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
# register OpDef as a "virtual subclass" of OpBase, so any of registered
# apply(OpBase, ...) rules could work well on OpDef
OpBase.register(OpDef)
__all__ = ["OpDef"]
......
......@@ -6,4 +6,3 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .tensor_wrapper import TensorWrapper as Tensor
......@@ -13,17 +13,9 @@ import sys
import typing
from abc import ABC
from .._imperative_rt.core2 import apply as apply2
from .multipledispatch import Dispatcher
def apply_op(op, *args):
Wrapper = type(args[0])
args = [arg._tensor for arg in args]
results = apply2(op, *args)
return tuple(map(Wrapper, results))
class OpBase(ABC):
def __call__(self, *args):
return apply(self, *args)
......
......@@ -7,9 +7,6 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..ops.builtin import OpDef
from .core import TensorBase, TensorWrapperBase, apply
from .raw_tensor import RawTensor
from .tensor import Tensor, push_context
from .tensor_wrapper import TensorWrapper
class Function:
......@@ -155,13 +152,3 @@ def _(op: Function, *args: TensorWrapperBase):
t._extra_data[k] = i
return tuple(map(Wrapper, outputs))
@apply.register()
def _(op: Function, *args: Tensor):
raise NotImplementedError
@apply.register()
def _(op: Function, *args: RawTensor):
raise NotImplementedError
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import contextlib
import copy
from .core import Dispatcher, OpBase, TensorBase, apply
class Tensor(TensorBase):
def __init__(self, data: TensorBase):
self._data = data
# _extra_data is set up in Grad.wrt
self._extra_data = {}
self._user_data = {}
def __getattr__(self, name):
if name in self._user_data:
return self._user_data[name]
raise AttributeError(name)
def reset(self, other):
assert isinstance(other, __class__)
self.__dict__.clear()
self._data = other.data
self._extra_data = other._extra_data.copy()
self._user_data = other._user_data.copy()
def copy(self):
other = object.__new__(type(self))
other.reset(self)
return other
# tensor interface
@property
def shape(self):
return self._data.shape
@property
def dtype(self):
return self._data.dtype
@property
def device(self):
return self._data.device
def numpy(self):
return self._data.numpy()
def _drop(self):
self._data._drop()
def _swap_in(self):
self._data._swap_in()
def _swap_out(self):
self._data._swap_out()
class ApplyContext:
__slots__ = ("inputs", "outputs", "key")
def __init__(self):
self.inputs = None
self.outputs = None
self.key = None
_context = None
@contextlib.contextmanager
def push_context():
global _context
backup = _context
try:
_context = ApplyContext()
yield _context
finally:
_context = backup
def get_context():
return _context
@apply.register()
def tensor_apply(op: OpBase, *args: Tensor):
data = tuple(i._data for i in args)
# type(Tensor._data) is RawTensor
# dispached to apply.add@RawTensor.py if passed Tensor args
outputs = apply(op, *data)
ret = tuple(map(Tensor, outputs))
with push_context() as ctx:
ctx.inputs = args
ctx.outputs = ret
for k in set().union(*(i._extra_data for i in args)):
ctx.key = k
data = tuple(
i._extra_data.get(k) if isinstance(i, Tensor) else i for i in args
)
# data are instances of Tracer
# dispatched to apply.add@grad.py
outputs = apply(op, *data)
if outputs is not None:
assert len(outputs) == len(ret)
for t, i in zip(ret, outputs):
t._extra_data[k] = i
return ret
......@@ -19,7 +19,6 @@ from ..ops import builtin
from ..ops.builtin import Elemwise, GetVarShape
from ..ops.special import Const
from . import utils
from .core import OpBase, TensorBase, TensorWrapperBase
from .indexing import getitem as _getitem
from .indexing import setitem as _setitem
from .utils import isscalar
......@@ -439,98 +438,3 @@ class ArrayMethodMixin(abc.ABC):
min = _reduce("MIN")
max = _reduce("MAX")
mean = _reduce("MEAN")
class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase):
def __init__(self, data):
self.__wrapped__ = data
def _reset(self, other):
if not isinstance(other, __class__):
raise TypeError(type(other))
self.__wrapped__ = other.__wrapped__
return self
@property
def dtype(self):
return self.__wrapped__.dtype
@property
def shape(self):
shape = self.__wrapped__.shape
if shape == () or not use_symbolic_shape():
return shape
return apply(GetVarShape(), self)[0]
@property
def device(self):
return self.__wrapped__.device
def numpy(self):
return self.__wrapped__.numpy()
def _drop(self):
self.__wrapped__._drop()
def _swap_in(self):
self.__wrapped__._swap_in()
def _swap_out(self):
self.__wrapped__._swap_out()
class TensorWrapper(ArrayMethodMixin, TensorBase):
def __init__(self, data, dtype=None, device=None, isscalar=False):
self._isscalar = isscalar
if isinstance(data, Tensor):
self._tensor = data
else:
if device is None:
device = CompNode._get_default_device()
self._tensor = Tensor(data, dtype, device)
def _reset(self, other):
if not isinstance(other, __class__):
raise TypeError(type(other))
self._tensor = other._tensor
return self
@property
def dtype(self):
return self._tensor.dtype
@property
def shape(self):
if self._isscalar:
return ()
shape = self._tensor.shape
if shape == () or not use_symbolic_shape():
return shape
return apply(GetVarShape(), self)[0]
@property
def device(self):
return self._tensor.device
def numpy(self):
if self._isscalar:
return self._tensor.numpy().squeeze()
return self._tensor.numpy()
def _drop(self):
self._tensor._drop()
def _swap_in(self):
self._tensor._swap_in()
def _swap_out(self):
self._tensor._swap_out()
def __repr__(self):
piece = "Tensor("
with np.printoptions(precision=4, suppress=True):
piece += "{}".format(str(self.numpy()))
if self.dtype != np.float32:
piece += ", dtype={}".format(np.dtype(self.dtype).name)
piece += ", device={}".format(self.device) + ")"
return piece
......@@ -18,9 +18,8 @@ from ..core.autodiff.grad import (
tracer_apply,
)
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.tensor import Tensor, tensor_apply
from ..device import get_default_device
from ..tensor import tensor
from ..tensor import Tensor
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank
__all__ = [
......
......@@ -16,7 +16,6 @@ from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.ops.special import Const
from ..core.tensor import utils
from ..core.tensor.core import TensorBase, TensorWrapperBase
from ..tensor import Tensor
from .elemwise import clip, exp, log, log1p
from .tensor import reshape, squeeze
......@@ -703,7 +702,7 @@ def topk(
mode = "VALUE_IDX_SORTED"
op = builtin.TopK(mode=mode)
if not isinstance(k, (TensorBase, TensorWrapperBase)):
if not isinstance(k, Tensor):
(k,) = Const(k, dtype="int32", device=inp.device)(inp)
if len(inp.shape) == 1:
......
......@@ -14,7 +14,7 @@ from typing import Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import Tensor, apply
from ..core._imperative_rt.core2 import apply
from ..core._wrap import device as as_device
from ..core.ops import builtin
from ..core.ops.special import Const
......
......@@ -19,6 +19,7 @@ import weakref
import numpy as np
from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt.core2 import Tensor
from ..core._imperative_rt.ops import (
CollectiveComm,
GaussianRNG,
......@@ -32,7 +33,6 @@ from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor
from ..core.tensor.tensor import Tensor
from .sublinear_memory_config import SublinearMemoryConfig
......
......@@ -10,7 +10,6 @@ from typing import Iterable, Union
import numpy as np
from ..core.tensor.tensor import Tensor
from ..tensor import Parameter, tensor
from .optimizer import Optimizer
......
......@@ -10,7 +10,6 @@ from typing import Iterable, Union
import numpy as np
from ..core.tensor.tensor import Tensor
from ..tensor import Parameter, tensor
from .optimizer import Optimizer
......
......@@ -8,7 +8,6 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable, Tuple, Union
from ..core.tensor.tensor import Tensor
from ..tensor import Parameter, tensor
from .optimizer import Optimizer
......
......@@ -8,7 +8,6 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable, Union
from ..core.tensor.tensor import Tensor
from ..tensor import Parameter, tensor
from .optimizer import Optimizer
......
......@@ -16,8 +16,8 @@ from .core._imperative_rt import CompNode
from .core._imperative_rt.core2 import Tensor as _Tensor
from .core._imperative_rt.core2 import apply
from .core._trace_option import use_symbolic_shape
from .core._wrap import device as as_device
from .core.ops.builtin import Copy, GetVarShape
from .core.tensor.raw_tensor import as_device
from .core.tensor.tensor_wrapper import ArrayMethodMixin
from .device import _valid_device, get_default_device
from .utils.deprecation import deprecated
......@@ -43,6 +43,10 @@ class Tensor(_Tensor, ArrayMethodMixin):
if isinstance(data, _Tensor):
obj = _Tensor.__new__(cls, data)
else:
if isinstance(data, np.ndarray):
if 0 in data.strides:
data = data.squeeze().reshape(data.shape)
obj = _Tensor.__new__(cls, data, dtype, cn)
return obj
......
......@@ -13,7 +13,7 @@ import numpy
from ..core import _imperative_rt
from ..core._imperative_rt import OperatorNode, VarNode
from ..core.tensor import megbrain_graph as G
from ..core.tensor.raw_tensor import as_raw_tensor
from ..tensor import Tensor
__all__ = [
"get_dep_vars",
......@@ -309,7 +309,7 @@ def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.n
cg = new_out_list[0].graph
func = cg.compile(new_out_list)
for node, value in zip(inp_node_list, inp_data_list):
node.set_value(as_raw_tensor(value)._dev_tensor())
node.set_value(Tensor(value)._dev_tensor())
func.execute()
out_data_list = [o.get_value().numpy() for o in out_node_list]
return out_data_list
......@@ -13,7 +13,7 @@ import megengine.functional as F
from megengine import Parameter, optimizer
from megengine.jit import trace
from megengine.module import Linear, Module
from megengine.tensor import tensor
from megengine.tensor import Tensor
class MLP(Module):
......@@ -54,7 +54,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
for group in opt.param_groups:
group["lr"] += 0.01
check_func.lr += 0.01
data = tensor(np.random.random(data_shape).astype(np.float32))
data = Tensor(np.random.random(data_shape).astype(np.float32))
opt.clear_grad()
with gm:
......@@ -98,7 +98,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
ori_params[param] = np.copy(param.numpy())
train_func(
tensor(np.random.random(data_shape).astype(np.float32)), opt=opt, gm=gm
Tensor(np.random.random(data_shape).astype(np.float32)), opt=opt, gm=gm
)
step += 1
check_func(ori_params, net.parameters(), step)
......
......@@ -11,7 +11,7 @@ import pickle
import numpy as np
from megengine.core.tensor.dtype import bfloat16
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.tensor import Tensor
def test_define():
......@@ -42,14 +42,14 @@ def test_cast():
def test_shared_nd():
data = np.array([-3.4, 1.394683, 2.323497, -7.439948, -5.2397], dtype=bfloat16)
snd = as_raw_tensor(data, dtype=bfloat16, device="xpux")
snd = Tensor(data, dtype=bfloat16, device="xpux")
assert snd.numpy().dtype == bfloat16
np.testing.assert_allclose(
snd.numpy(), [-3.40625, 1.398438, 2.328125, -7.4375, -5.25], atol=1e-6
)
data = np.array([-9.34964, -8.342, 9.4385, 0.18746, 1.48], dtype=bfloat16)
snd = as_raw_tensor(data, dtype=bfloat16, device="xpux")
snd = Tensor(data, dtype=bfloat16, device="xpux")
np.testing.assert_allclose(
snd.numpy(), [-9.375, -8.3125, 9.4375, 0.1875, 1.476562], atol=1e-6
)
......
......@@ -12,7 +12,7 @@ import numpy as np
import pytest
from megengine.core.tensor.dtype import intb1, intb2, intb4
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.tensor import Tensor
def bit_define_test(bit, low_bit_type):
......@@ -78,11 +78,11 @@ def _shared_nd_test(bit, low_bit_type):
min_value = 1 - (1 << bit)
data = np.arange(min_value, max_value + 2, 2, dtype=low_bit_type)
snd = as_raw_tensor(data, dtype=low_bit_type, device="xpux")
snd = Tensor(data, dtype=low_bit_type, device="xpux")
np.testing.assert_allclose(snd.numpy(), range(min_value, max_value + 2, 2))
data = np.arange(min_value, max_value + 2, 4, dtype=low_bit_type)
snd = as_raw_tensor(data, dtype=low_bit_type, device="xpux")
snd = Tensor(data, dtype=low_bit_type, device="xpux")
np.testing.assert_allclose(snd.numpy(), range(min_value, max_value + 2, 4))
......
......@@ -32,8 +32,8 @@ from megengine.core.tensor.dtype import (
quint4,
quint8,
)
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.distributed.helper import get_device_count_by_fork
from megengine.tensor import Tensor
def test_dtype_quint8():
......@@ -71,7 +71,7 @@ def _get_compiled_result(inp, dtype, shape, device, calc_func=None):
temp_rst = calc_func(inp_node.outputs[0])
oup_node = G.OutputNode(temp_rst)
func = graph.compile(oup_node.outputs[0])
inp_node.set_value(as_raw_tensor(inp, dtype=dtype, device=device)._dev_tensor())
inp_node.set_value(Tensor(inp, dtype=dtype, device=device)._dev_tensor())
func.execute()
return oup_node.get_value().numpy()
......
......@@ -9,15 +9,15 @@
import numpy as np
import pytest
import megengine.core.tensor.raw_tensor
from megengine.core.tensor.core import apply
import megengine
from megengine.core._imperative_rt.core2 import apply
from megengine.tensor import Tensor
def elemwise(*args, mode):
from megengine.core._imperative_rt.imperative import apply_op
from megengine.core.ops.builtin import Elemwise
return apply_op(Elemwise(mode), args)
return apply(Elemwise(mode), *args)
def test_basic_interface():
......@@ -44,11 +44,11 @@ def test_simple_arith():
from megengine.core.ops.builtin import Elemwise
x = np.random.rand(10).astype("float32")
xx = megengine.core._imperative_rt.put(x)
xx = Tensor(x)
(yy,) = elemwise(xx, xx, mode=Elemwise.Mode.MUL)
np.testing.assert_allclose(x * x, megengine.core._imperative_rt.get_value(yy))
megengine.core._imperative_rt.delete(xx)
megengine.core._imperative_rt.delete(yy)
np.testing.assert_allclose(x * x, yy.numpy())
del xx
del yy
def test_tensor_on_device():
......@@ -62,10 +62,9 @@ def test_tensor_on_device():
def test_raw_tensor():
from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.raw_tensor import as_raw_tensor
x = np.random.rand(10).astype("float32")
xx = as_raw_tensor(x)
xx = Tensor(x)
(yy,) = apply(Elemwise(Elemwise.Mode.MUL), xx, xx)
np.testing.assert_allclose(x * x, yy.numpy())
(yy,) = apply(Elemwise(Elemwise.Mode.MUL), xx, xx)
......
......@@ -12,10 +12,10 @@ import numpy as np
import pytest
import megengine
import megengine.tensor as Tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.ops import builtin
from megengine.tensor import Tensor
def cvt_to_shape_desc(val, inpvar, config=None):
......
......@@ -8,8 +8,6 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pytest
from megengine.core import Tensor
# from megengine.core.interpreter.hints import function
......
......@@ -11,8 +11,8 @@ from concurrent.futures import Future
import numpy as np
import megengine.functional as F
import megengine.tensor as Tensor
from megengine.core.tensor import megbrain_graph as mgb_graph
from megengine.tensor import Tensor
def test_io():
......
......@@ -9,12 +9,12 @@
import numpy as np
import megengine.functional as F
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.tensor import Tensor
def test_as_raw_tensor():
x = np.arange(6, dtype="float32").reshape(2, 3)
xx = as_raw_tensor(x, device="xpux")
xx = Tensor(x, device="xpux")
yy = F.add(xx, 1).numpy()
assert xx.dtype == np.float32
assert xx.device == "xpux"
......@@ -23,7 +23,7 @@ def test_as_raw_tensor():
def test_as_raw_tensor_from_int64():
x = np.arange(6, dtype="int64").reshape(2, 3)
xx = as_raw_tensor(x, dtype="float32", device="xpux")
xx = Tensor(x, dtype="float32", device="xpux")
yy = F.add(xx, 1).numpy()
assert xx.dtype == np.float32
assert xx.device == "xpux"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册