# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.


import collections

import numpy as np

from .core._imperative_rt import CompNode
from .core._imperative_rt.core2 import Tensor as _Tensor
from .core._imperative_rt.core2 import apply
from .core._trace_option import use_symbolic_shape
from .core._wrap import device as as_device
from .core.ops.builtin import Copy, GetVarShape
from .core.tensor.array_method import ArrayMethodMixin
from .device import _valid_device, get_default_device
from .utils.deprecation import deprecated


class Tensor(_Tensor, ArrayMethodMixin):
    r"""
    Python-level tensor type.

    Combines the core ``_Tensor`` with ``ArrayMethodMixin`` and adds:
    device resolution on construction, a (possibly symbolic) ``shape``,
    device copies via :meth:`to`, pickle/deepcopy support, and per-tensor
    quantization metadata (:attr:`q_dict`).
    """

    # Gradient slot. Defaults to ``None``; it is not assigned anywhere in
    # this class.
    grad = None
    # Optional callable used by ``__new__`` to remap a device-name string
    # before it is turned into a ``CompNode``.
    dmap_callback = None
    # Backing storage for the ``q_dict`` property. It stays ``None`` until
    # first access, so that every tensor gets its own dict. A class-level
    # dict would be shared, and mutated, by all tensors.
    _q_dict = None

    def __new__(cls, data, dtype=None, device=None, is_const=False, no_cache=False):
        r"""
        Create a tensor.

        :param data: either an existing ``_Tensor``, which is wrapped directly
            (``dtype``, ``is_const``, ``no_cache`` and the resolved device are
            then not forwarded), or data handed to the core constructor,
            e.g. a ``numpy.ndarray``.
        :param dtype: forwarded to the core constructor.
        :param device: ``None`` for ``get_default_device()``, or a device-name
            string (mapped through :attr:`dmap_callback` when one is set).
            It may also be a ``CompNode``, or an object exposing a ``_cn``
            attribute.
        :param is_const: forwarded to the core constructor.
        :param no_cache: forwarded to the core constructor.
        :return: a new instance of ``cls``.
        """
        # The device is resolved first, even when ``data`` is already a
        # ``_Tensor``, so an invalid device still raises here.
        if device is None:
            cn = get_default_device()
        elif isinstance(device, str):
            if cls.dmap_callback is not None:
                cn = CompNode(cls.dmap_callback(device))
            else:
                cn = CompNode(device)
        elif isinstance(device, CompNode):
            cn = device
        else:
            cn = device._cn

        if isinstance(data, _Tensor):
            obj = _Tensor.__new__(cls, data)
        else:
            if isinstance(data, np.ndarray) and 0 in data.strides:
                # NOTE(review): presumably meant to get rid of zero strides
                # (from broadcasting / inserted axes) before the array reaches
                # the core constructor; confirm intent.
                data = data.squeeze().reshape(data.shape)
            obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache)
        return obj

    @property
    def shape(self):
        r"""
        Shape of the tensor.

        Returns the core tuple shape when it is ``()`` or when
        ``use_symbolic_shape()`` is false. Otherwise it returns the output of
        a ``GetVarShape`` op applied to this tensor.
        """
        shape = super().shape
        if shape == () or not use_symbolic_shape():
            return shape
        return apply(GetVarShape(), self)[0]

    @property
    def _tuple_shape(self):
        r"""Shape as reported by the core ``_Tensor``, bypassing the symbolic-shape logic of :attr:`shape`."""
        return super().shape

    @property
    def q_dict(self):
        r"""
        Quantization metadata of this tensor.

        A dict with keys ``"mode"``, ``"scale"`` and ``"zero_point"``, all
        ``None`` by default. It is created lazily and owned by this tensor
        only, so mutating it does not affect other tensors.
        """
        if self._q_dict is None:
            self._q_dict = {"mode": None, "scale": None, "zero_point": None}
        return self._q_dict

    @q_dict.setter
    def q_dict(self, value):
        # Kept assignable so ``tensor.q_dict = {...}`` (and ``__setstate__``)
        # keep working as they did with the former plain attribute.
        self._q_dict = value

    def __repr__(self):
        r"""
        Return ``Tensor(<values>[, dtype=<name>], device=<device>)``.

        Values are printed with ``precision=4`` and ``suppress=True``. The
        dtype is shown only when it is not float32.
        """
        piece = "Tensor("
        with np.printoptions(precision=4, suppress=True):
            piece += "{}".format(str(self.numpy()))
        if self.dtype != np.float32:
            piece += ", dtype={}".format(np.dtype(self.dtype).name)
        piece += ", device={}".format(self.device) + ")"
        return piece

    @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
    def set_value(self, value):
        r"""
        Replace this tensor's content with ``value`` via ``_reset``.

        A non-``_Tensor`` ``value`` is first converted to a :class:`Tensor`
        with this tensor's dtype and device.
        """
        if not isinstance(value, _Tensor):
            value = Tensor(value, dtype=self.dtype, device=self.device)
        self._reset(value)

    @deprecated(version="1.0", reason="use *= 0 instead")
    def reset_zero(self):
        r"""Zero the tensor by executing ``self *= 0``."""
        # NOTE(review): relies on ``__imul__`` updating ``self`` in place;
        # rebinding the local ``self`` alone would not affect the caller.
        self *= 0

    def to(self, device):
        r"""
        Return a copy of this tensor placed on ``device``.

        :param device: a device-name string or any value accepted by
            ``as_device``.
        :raises ValueError: if ``device`` is a string rejected by
            ``_valid_device``.
        """
        if isinstance(device, str) and not _valid_device(device):
            raise ValueError(
                "invalid device name {}. For the correct format of the device name, please refer to the instruction of megengine.device.set_default_device()".format(
                    device
                )
            )
        cn = as_device(device).to_c()
        return apply(Copy(comp_node=cn), self)[0]

    @property
    def requires_grad(self):
        r"""Reserved: reading, setting or deleting it raises ``AttributeError``."""
        raise AttributeError("requires_grad is reserved for future use")

    @requires_grad.setter
    def requires_grad(self, value):
        raise AttributeError("requires_grad is reserved for future use")

    @requires_grad.deleter
    def requires_grad(self):
        raise AttributeError("requires_grad is reserved for future use")

    def __hash__(self):
        r"""Identity-based hash, so tensors can be used as dict keys / set members."""
        return id(self)

    def __getnewargs__(self):
        r"""
        Arguments given to ``__new__`` when unpickling or deep-copying:
        the data as a numpy array, the dtype, and the device's logical name.
        """
        return (self.numpy(), self.dtype, self.device.logical_name)

    def __getstate__(self):
        r"""Extra state saved for pickling / deep copy: the quantization dict."""
        state = {
            "qdict": self.q_dict,
        }
        return state

    def __setstate__(self, state):
        r"""Restore the quantization dict saved by :meth:`__getstate__` (pops ``"qdict"`` from ``state``)."""
        self.q_dict = state.pop("qdict")


# Lowercase alias: ``tensor`` and ``Tensor`` name the same class.
tensor = Tensor


class Parameter(Tensor):
    r"""
    A kind of Tensor that is to be considered a module parameter.

    It adds no members of its own. The distinct type is what marks a tensor
    as a parameter.
    """