提交 9440842e 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(mge/core): fix Tensor deepcopy issue

GitOrigin-RevId: 6bea7970b8319fc54458c6b1d3cd2545cca3c0c0
上级 d4b86b84
......@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import copy
import functools
import itertools
import weakref
......@@ -674,6 +675,22 @@ class Tensor:
snd = mgb.make_shared(device, value=data, dtype=dtype)
self._reset(snd, requires_grad=requires_grad)
def __deepcopy__(self, memo):
"""
Since Tensor have __getstate__ and __setstate__ method,
deepcopy only process the that and ignore the attribute of Parameter.
So we need to add __deepcopy__ method to deepcopy correct attribute.
"""
assert (self.__val is not None) and (
self.__sym is None
), "Only SharedND initialized Tensor can be serialized or deep copied"
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
setattr(result, k, copy.deepcopy(v, memo))
return result
def tensor(
data: Union[list, np.ndarray] = None,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册