提交 35bc0e1f 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(mge/function): do not deeply copy saved tensor in Function

GitOrigin-RevId: 3c89d1ceaa9f7338b167f80936549a4c364061de
上级 47377c7b
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import Iterable, Tuple, Union from typing import Iterable, Tuple, Union
...@@ -142,6 +143,20 @@ class Function(metaclass=ABCMeta): ...@@ -142,6 +143,20 @@ class Function(metaclass=ABCMeta):
""" """
self.saved_tensors = tensors self.saved_tensors = tensors
def __deepcopy__(self, memo):
    """
    Defines how the operator is deeply copied.

    Every instance attribute is deep-copied except ``saved_tensors``, which
    is set to ``None`` on the copy instead of being duplicated. The object
    being copied is never modified, so its ``saved_tensors`` stay intact
    even when copying one of the other attributes raises.

    ``memo`` is the memo dictionary passed in by ``copy.deepcopy``; the new
    instance is returned.
    """
    cls = self.__class__
    # Bypass ``__init__``: the attributes are filled in one by one below.
    result = cls.__new__(cls)
    # Register the copy before recursing so that attributes referring back
    # to ``self`` resolve to ``result`` instead of recursing forever.
    memo[id(self)] = result
    for k, v in self.__dict__.items():
        if k == "saved_tensors":
            # Deliberately not copied; reset on the copy after the loop.
            continue
        setattr(result, k, copy.deepcopy(v, memo))
    result.saved_tensors = None
    return result
def __call__(self, *inputs): def __call__(self, *inputs):
assert ( assert (
not self._has_saved_state not self._has_saved_state
......
...@@ -495,7 +495,12 @@ class Tensor: ...@@ -495,7 +495,12 @@ class Tensor:
) )
def __getstate__(self): def __getstate__(self):
assert (self.__val is not None) and (self.__sym is None) r""" __getstate__ will be called for pickle serialization or deep copy
"""
assert (self.__val is not None) and (
self.__sym is None
), "Only SharedND initialized Tensor can be serialized or deep copied"
metadata = {"requires_grad": self.requires_grad} metadata = {"requires_grad": self.requires_grad}
state = { state = {
"data": self.numpy(), "data": self.numpy(),
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
# #
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...@@ -6,10 +5,13 @@ ...@@ -6,10 +5,13 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
import numpy as np import numpy as np
import megengine.functional as F import megengine.functional as F
from megengine.core import Function, tensor from megengine.core import Function, tensor
from megengine.jit import trace
from megengine.test import assertTensorClose from megengine.test import assertTensorClose
...@@ -76,6 +78,27 @@ def test_ste(): ...@@ -76,6 +78,27 @@ def test_ste():
) )
def test_deepcopy():
    """Deep-copying a Function keeps its attributes and leaves saved_tensors as None."""

    class Sigmoid(Function):
        def __init__(self, param):
            super().__init__()
            self.param = param

        def forward(self, x):
            y = 1 / (1 + F.exp(-x))
            self.save_for_backward(y)
            return y

        def backward(self, grad_y):
            (y,) = self.saved_tensors
            return grad_y * y * (1 - y)

    origin = Sigmoid(0)
    # Copy ``origin`` itself (not a second fresh instance) so the assertions
    # below really compare a copy against the object it was made from.
    new = copy.deepcopy(origin)
    assert new is not origin
    assert new.param == origin.param
    # Identity check: ``==`` could be overloaded by whatever is stored here.
    assert new.saved_tensors is None
def test_save_context(): def test_save_context():
class Sigmoid(Function): class Sigmoid(Function):
def forward(self, x): def forward(self, x):
...@@ -87,14 +110,26 @@ def test_save_context(): ...@@ -87,14 +110,26 @@ def test_save_context():
(y,) = self.saved_tensors (y,) = self.saved_tensors
return grad_y * y * (1 - y) return grad_y * y * (1 - y)
a = tensor(np.array([1926.0817], dtype=np.float32)) def run_saved_context(a, net=None):
s = Sigmoid()(a) return net(a)
s2 = F.sigmoid(a)
assertTensorClose(s.numpy(), s2.numpy()) def run(use_trace, symbolic):
assertTensorClose( a = tensor(np.array([1926.0817], dtype=np.float32))
F.grad(s, a, use_virtual_grad=False).numpy(), net = Sigmoid()
F.grad(s2, a, use_virtual_grad=False).numpy(), func_run = run_saved_context
) if use_trace:
func_run = trace(run_saved_context, symbolic=symbolic)
s = func_run(a, net=net)
s2 = F.sigmoid(a)
assertTensorClose(s.numpy(), s2.numpy())
assertTensorClose(
F.grad(s, a, use_virtual_grad=False).numpy(),
F.grad(s2, a, use_virtual_grad=False).numpy(),
)
run(False, False)
run(True, False)
run(True, True)
def test_none_in_out_grad(): def test_none_in_out_grad():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册