# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Union

import numpy as np

from .core._imperative_rt import CompNode
from .core._imperative_rt.core2 import Tensor as _Tensor
from .core._imperative_rt.core2 import apply
from .core._trace_option import use_symbolic_shape
from .core._wrap import device as as_device
from .core.ops.builtin import Copy, GetVarShape
from .core.tensor.array_method import ArrayMethodMixin
from .device import _valid_device, get_default_device
from .logger import get_logger
from .utils.deprecation import deprecated
from .utils.naming import AutoNaming

logger = get_logger(__name__)


class Tensor(_Tensor, ArrayMethodMixin):
    r"""
    A tensor object represents a multidimensional, homogeneous array of fixed-size items.

    :param data: The value of the returned Tensor.
    :param dtype: The dtype of the returned Tensor. Uses data's dtype if not specified.
    :param device: The desired device of the returned Tensor. Uses :func:`get_default_device` if not specified.
    :param is_const: Whether to make it an ``ImmutableTensor`` in tracing mode.
    :param no_cache: Whether to cache it for memory sharing.
    :param name: Used to improve convenience in graph operations on a dumped model.
    """

    grad = None
    dmap_callback = None
    _qparams = None

    def __new__(
        cls,
        data: Union["Tensor", np.ndarray, list, "scalar"] = None,
        dtype: np.dtype = None,
        device: str = None,
        is_const: bool = False,
        no_cache: bool = False,
        name: str = None,
    ):
        if data is None:
            data = []
        if device is None:
            cn = get_default_device()
        elif isinstance(device, str):
            if cls.dmap_callback is not None:
                cn = CompNode(cls.dmap_callback(device))
            else:
                cn = CompNode(device)
        else:
            if isinstance(device, CompNode):
                cn = device
            else:
                cn = device._cn

        if isinstance(data, _Tensor):
            obj = _Tensor.__new__(cls, data)
        else:
            if isinstance(data, np.ndarray):
                if 0 in data.strides:
                    data = data.squeeze().reshape(data.shape)
            obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache, name)
        return obj

    def __init__(
        self,
        data: Union["Tensor", np.ndarray, list, "scalar"],
        dtype: np.dtype = None,
        device: str = None,
        is_const: bool = False,
        no_cache: bool = False,
        name: str = None,
    ):
        pass
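    # Illustrative constructor usage -- a sketch, not part of the library source.
    # It assumes numpy is installed and that the device string "cpux" is valid on
    # the running machine:
    #
    #     import numpy as np
    #     import megengine as mge
    #
    #     a = mge.Tensor([1, 2, 3], dtype="int32")          # from a Python list
    #     b = mge.Tensor(np.zeros((2, 3)), device="cpux")   # from a numpy array
    #     c = mge.Tensor(b)                                 # from an existing tensor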

    @property
    def shape(self) -> Union[tuple, "Tensor"]:
        r"""
        Returns a :class:`tuple` or a :class:`~.Tensor` representing the tensor dimensions.

        .. note::

           The shape of a tensor is usually represented by a :class:`tuple`.
           But if a tensor is treated as a symbolic placeholder during tracing,
           its shape could also be a :class:`~.Tensor`. See :class:`~.trace` for more details.

        The shape property is usually used to get the current shape of a tensor,
        but may also be used to reshape the tensor in-place by assigning a tuple of tensor dimensions to it.
        As with :func:`~.reshape`, one of the new shape dimensions can be -1,
        in which case its value is inferred from the size of the tensor and the remaining dimensions.
        """
        shape = super().shape
        if shape == () or not use_symbolic_shape():
            return shape
        return apply(GetVarShape(), self)[0]
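    # Sketch of the behaviors described in the docstring above (a plain tuple in
    # eager mode, in-place reshape by assignment); purely illustrative:
    #
    #     import megengine as mge
    #
    #     x = mge.Tensor([[1, 2, 3], [4, 5, 6]])
    #     print(x.shape)     # (2, 3) -- a tuple outside of symbolic tracing
    #     x.shape = (3, -1)  # in-place reshape; the -1 is inferred as 2
    #     print(x.shape)     # (3, 2)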

    @property
    def _tuple_shape(self):
        return super().shape

    @property
    def device(self) -> CompNode:
        r"""
        Returns the device that a :class:`~.Tensor` is stored on.
        """
        return super().device

    @property
    def dtype(self) -> np.dtype:
        r"""
        Returns a :class:`numpy.dtype` object representing the data type of a :class:`~.Tensor`.
        """
        return super().dtype

    @property
    def qparams(self):
        r"""
        Returns a :class:`~.QParams` object containing quantization params of a :class:`~.Tensor`.
        """
        from .quantization.utils import create_qparams  # pylint: disable=all

        if self._qparams is None:
            self._qparams = create_qparams()
        return self._qparams

    def numpy(self) -> np.ndarray:
        r"""
        Returns this :class:`~.Tensor` as a :class:`numpy.ndarray`.
        """
        return super().numpy()

    def detach(self):
        r"""
        Returns a new :class:`~.Tensor`, detached from the current graph.
        """
        return super().detach()

    def _reset(self, other):
        if not isinstance(other, _Tensor):
            other = Tensor(other, dtype=self.dtype, device=self.device)
        super()._reset(other)

    def __repr__(self):
        piece = "{}(".format(self.__class__.__name__)
        with np.printoptions(precision=4, suppress=True):
            piece += "{}".format(str(self.numpy()))
        if self.dtype != np.float32:
            piece += ", dtype={}".format(np.dtype(self.dtype).name)
        piece += ", device={}".format(self.device) + ")"
        return piece

    @property
    def name(self):
        return self.c_name

    @name.setter
    def name(self, name):
        self.c_name = name
        AutoNaming.record_var_name(self._mixin_handle, name)

    @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
    def set_value(self, value):
        self._reset(value)

    @deprecated(version="1.0", reason="use *= 0 instead")
    def reset_zero(self):
        self *= 0

    def to(self, device):
        r"""
        Copies this :class:`~.Tensor` to the specified device. See :func:`~.copy`.
        """
        if isinstance(device, str) and not _valid_device(device):
            raise ValueError(
                "invalid device name {}. For the correct format of the device name, please refer to the documentation of megengine.device.set_default_device()".format(
                    device
                )
            )
        cn = as_device(device).to_c()
        return apply(Copy(comp_node=cn), self)[0]
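    # Minimal sketch of a device-to-device copy; it assumes a device named "cpu0"
    # exists on the running machine:
    #
    #     import megengine as mge
    #
    #     x = mge.Tensor([1.0, 2.0])
    #     y = x.to("cpu0")    # y is a copy of x placed on cpu0
    #     print(y.device)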

    @property
    def requires_grad(self):
        raise AttributeError("requires_grad is reserved for future use")

    @requires_grad.setter
    def requires_grad(self, value):
        raise AttributeError("requires_grad is reserved for future use")

    @requires_grad.deleter
    def requires_grad(self):
        raise AttributeError("requires_grad is reserved for future use")

    def __hash__(self):
        return id(self)

    def __getnewargs__(self):
        r""" __getnewargs__ will be called for pickle serialization or deep copy
        """
        return (self.numpy(), self.dtype, self.device.logical_name)

    def __getstate__(self):
        r""" __getstate__ will be called for pickle serialization or deep copy
        """
        state = {}
        if self._qparams is not None:
            state["qparams"] = self._qparams
        return state

    def __setstate__(self, state):
        # for compatibility with old version not using fastcore
        if "data" in state:
            data = state.pop("data")
            device = state.pop("device")
            dtype = state.pop("dtype")
            self._reset(Tensor(data, dtype=dtype, device=device))

        # quantize related state for deepcopy
        if "qdict" in state:
            qparams = state.pop("qdict")
            logger.warning(
                "Tensor's 'qdict' state is deprecated. Use 'qparams' instead"
            )
        elif "qparams" in state:
            qparams = state.pop("qparams")
        else:
            qparams = None
        self._qparams = qparams
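    # Sketch of the pickle round trip these hooks support: __getnewargs__ rebuilds
    # the tensor from (data, dtype, device) and __getstate__/__setstate__ carry the
    # optional quantization params. Illustrative only:
    #
    #     import pickle
    #     import megengine as mge
    #
    #     x = mge.Tensor([1, 2, 3])
    #     y = pickle.loads(pickle.dumps(x))  # same value, dtype and device as x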


tensor = Tensor


class Parameter(Tensor):
    r"""
    A kind of Tensor that is to be considered a module parameter.
    """