From 9440842e27decf8bee55bd260f94a95b7952ee26 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 12 Aug 2020 10:49:15 +0800 Subject: [PATCH] fix(mge/core): fix Tensor deepcopy issue GitOrigin-RevId: 6bea7970b8319fc54458c6b1d3cd2545cca3c0c0 --- python_module/megengine/core/tensor.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/python_module/megengine/core/tensor.py b/python_module/megengine/core/tensor.py index a5ef8a0ad..b823544bf 100644 --- a/python_module/megengine/core/tensor.py +++ b/python_module/megengine/core/tensor.py @@ -6,6 +6,7 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import collections +import copy import functools import itertools import weakref @@ -674,6 +675,22 @@ class Tensor: snd = mgb.make_shared(device, value=data, dtype=dtype) self._reset(snd, requires_grad=requires_grad) + def __deepcopy__(self, memo): + """ + Since Tensor have __getstate__ and __setstate__ method, + deepcopy only process the that and ignore the attribute of Parameter. + So we need to add __deepcopy__ method to deepcopy correct attribute. + """ + assert (self.__val is not None) and ( + self.__sym is None + ), "Only SharedND initialized Tensor can be serialized or deep copied" + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + setattr(result, k, copy.deepcopy(v, memo)) + return result + def tensor( data: Union[list, np.ndarray] = None, -- GitLab