Commit 35bc0e1f authored by Megvii Engine Team, committed by Xu Xinran

fix(mge/function): do not deeply copy saved tensor in Function

GitOrigin-RevId: 3c89d1ceaa9f7338b167f80936549a4c364061de
Parent 47377c7b
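Background: the default `copy.deepcopy` walks an object's `__dict__` and duplicates everything reachable, so a `Function` holding saved tensors would drag the tensor data into every copy. Below is a minimal, self-contained sketch of the idiom this commit uses to avoid that: temporarily detaching the attribute during the copy. The `Holder`/`payload` names are illustrative, not part of MegEngine.

import copy

class Holder:
    def __init__(self, payload):
        self.payload = payload  # stands in for saved_tensors
        self.tag = "holder"

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        tmp, self.payload = self.payload, None  # detach before copying
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, copy.deepcopy(v, memo))
        self.payload = tmp  # restore on the original
        return result

h = Holder([1.0] * 1024)
h2 = copy.deepcopy(h)
assert h2.tag == "holder" and h2.payload is None and h.payload is not None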
@@ -6,6 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
from abc import ABCMeta, abstractmethod
from typing import Iterable, Tuple, Union
@@ -142,6 +143,20 @@ class Function(metaclass=ABCMeta):
"""
self.saved_tensors = tensors
def __deepcopy__(self, memo):
"""
Defines how the operator is deeply copied
"""
cls = self.__class__
result = cls.__new__(cls)
tmp = self.saved_tensors
self.saved_tensors = None
memo[id(self)] = result
for k, v in self.__dict__.items():
setattr(result, k, copy.deepcopy(v, memo))
self.saved_tensors = tmp
return result
def __call__(self, *inputs):
assert (
not self._has_saved_state
......
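A side note on the `memo[id(self)] = result` line above: registering the half-built copy before recursing into the attributes is what lets `copy.deepcopy` terminate on cyclic references. A self-contained illustration of the mechanism (not from the commit):

import copy

class Node:
    def __init__(self):
        self.ref = None

n = Node()
n.ref = n  # self-referential cycle
m = copy.deepcopy(n)  # the memo table breaks the recursion
assert m.ref is m and m is not n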
@@ -495,7 +495,12 @@ class Tensor:
)
    def __getstate__(self):
-       assert (self.__val is not None) and (self.__sym is None)
+       r""" __getstate__ will be called for pickle serialization or deep copy
+       """
+       assert (self.__val is not None) and (
+           self.__sym is None
+       ), "Only SharedND initialized Tensor can be serialized or deep copied"
        metadata = {"requires_grad": self.requires_grad}
        state = {
            "data": self.numpy(),
......
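Per the new docstring, `__getstate__` backs both pickling and deep copy, so a value-initialized tensor should round-trip through either path. A hedged sketch against this revision's API (the `tensor` constructor and `.numpy()` come from the tests below; that pickling works end-to-end is an assumption suggested by the assert message, not verified here):

import copy
import pickle

import numpy as np
from megengine.core import tensor

t = tensor(np.array([1.0, 2.0], dtype=np.float32))
t2 = copy.deepcopy(t)               # routed through __getstate__
t3 = pickle.loads(pickle.dumps(t))  # same serialization path
assert np.allclose(t.numpy(), t2.numpy())
assert np.allclose(t.numpy(), t3.numpy())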
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -6,10 +5,13 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy

import numpy as np

import megengine.functional as F
from megengine.core import Function, tensor
from megengine.jit import trace
from megengine.test import assertTensorClose
@@ -76,6 +78,27 @@ def test_ste():
)
def test_deepcopy():
    class Sigmoid(Function):
        def __init__(self, param):
            super().__init__()
            self.param = param

        def forward(self, x):
            y = 1 / (1 + F.exp(-x))
            self.save_for_backward(y)
            return y

        def backward(self, grad_y):
            (y,) = self.saved_tensors
            return grad_y * y * (1 - y)

    origin = Sigmoid(0)
    new = copy.deepcopy(origin)
    # Ordinary attributes survive the copy; saved tensors do not.
    assert new.param == origin.param
    assert new.saved_tensors is None
def test_save_context():
    class Sigmoid(Function):
        def forward(self, x):
@@ -87,14 +110,26 @@ def test_save_context():
            (y,) = self.saved_tensors
            return grad_y * y * (1 - y)
-    a = tensor(np.array([1926.0817], dtype=np.float32))
-    s = Sigmoid()(a)
-    s2 = F.sigmoid(a)
-    assertTensorClose(s.numpy(), s2.numpy())
-    assertTensorClose(
-        F.grad(s, a, use_virtual_grad=False).numpy(),
-        F.grad(s2, a, use_virtual_grad=False).numpy(),
-    )
+    def run_saved_context(a, net=None):
+        return net(a)
+
+    def run(use_trace, symbolic):
+        a = tensor(np.array([1926.0817], dtype=np.float32))
+        net = Sigmoid()
+        func_run = run_saved_context
+        if use_trace:
+            func_run = trace(run_saved_context, symbolic=symbolic)
+        s = func_run(a, net=net)
+        s2 = F.sigmoid(a)
+        assertTensorClose(s.numpy(), s2.numpy())
+        assertTensorClose(
+            F.grad(s, a, use_virtual_grad=False).numpy(),
+            F.grad(s2, a, use_virtual_grad=False).numpy(),
+        )
+
+    run(False, False)  # eager execution
+    run(True, False)   # traced, imperative mode
+    run(True, True)    # traced, symbolic graph mode
def test_none_in_out_grad():
......