# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.


import collections

from .core import Tensor as _Tensor
from .core.ops.builtin import Copy
from .core.tensor.core import apply
from .device import get_default_device
from .utils.deprecation import deprecated


class Tensor(_Tensor):
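    # ``grad`` caches the gradient tensor once one has been computed;
    # ``dmap_callback`` may be set to a callable that remaps a stored device
    # string when a pickled tensor is restored (see ``__setstate__`` below).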
    grad = None
    dmap_callback = None

    def __init__(self, data, dtype=None, device=None):
        if device is None:
            device = get_default_device()
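        # quantization metadata (mode / scale / zero_point) carried with the tensor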
        self.q_dict = {"mode": None, "scale": None, "zero_point": None}
        super().__init__(data, dtype=dtype, device=device)

    @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
    def set_value(self, value):
        self._reset(value)

    @deprecated(version="1.0", reason="use *= 0 instead")
    def reset_zero(self):
        self *= 0

    def to(self, cn):
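        r"""Copies this tensor to the computing node ``cn`` and returns the copy."""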
        return apply(Copy(comp_node=cn), self)[0]

    @property
    def requires_grad(self):
        raise AttributeError("requires_grad is reserved for future use")

    @requires_grad.setter
    def requires_grad(self, value):
        raise AttributeError("requires_grad is reserved for future use")

    @requires_grad.deleter
    def requires_grad(self):
        raise AttributeError("requires_grad is reserved for future use")

    def __hash__(self):
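        # hash by object identity so that tensors can be used as dict keys and
        # set members (tensor equality is elementwise, not a single bool)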
        return id(self)

    def __getstate__(self):
        r""" __getstate__ will be called for pickle serialization or deep copy
        """

        state = {
            "data": self.numpy(),
            "device": str(self.device),
            "dtype": self.dtype,
            "qdict": self.q_dict,
        }
        return state

    def __setstate__(self, state):
        data = state.pop("data")
        device = state.pop("device")
        if self.dmap_callback is not None:
            assert isinstance(device, str)
            device = self.dmap_callback(device)
        dtype = state.pop("dtype")
        self.q_dict = state.pop("qdict")
        super().__init__(data, dtype=dtype, device=device)

    def detach(self):
        r"""
        Returns a new tensor which is treated as constant during backward gradient calcuation,
        i.e. its gradient is zero.

        :param inp: input tensor

        """
        Wrapper = type(self)
        Tensor = type(self.__wrapped__)
        return Wrapper(Tensor(self.__wrapped__._data))


tensor = Tensor
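
# A minimal usage sketch (not part of the original module) showing how the
# pieces above fit together; it assumes NumPy-style input data and that the
# device name "cpu0" is available:
#
#     import pickle
#     t = Tensor([1.0, 2.0, 3.0])
#     t2 = pickle.loads(pickle.dumps(t))   # round-trips via __getstate__/__setstate__
#     assert (t2.numpy() == t.numpy()).all()
#     c = t.detach()                       # constant w.r.t. gradient computation
#     d = t.to("cpu0")                     # copy onto another computing node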


class Parameter(Tensor):
    r"""A kind of Tensor that is to be considered a module parameter.
    """