diff --git a/python_module/megengine/core/function.py b/python_module/megengine/core/function.py
index 35d2f7705c62a9898e2a96e6103c9189fe69e369..da37ed3d29a2b97402ce88bde941c2fec055ad81 100644
--- a/python_module/megengine/core/function.py
+++ b/python_module/megengine/core/function.py
@@ -6,6 +6,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import copy
 from abc import ABCMeta, abstractmethod
 from typing import Iterable, Tuple, Union
 
@@ -142,6 +143,20 @@ class Function(metaclass=ABCMeta):
         """
         self.saved_tensors = tensors
 
+    def __deepcopy__(self, memo):
+        """
+        Defines how the operator is deeply copied
+        """
+        cls = self.__class__
+        result = cls.__new__(cls)
+        tmp = self.saved_tensors
+        self.saved_tensors = None
+        memo[id(self)] = result
+        for k, v in self.__dict__.items():
+            setattr(result, k, copy.deepcopy(v, memo))
+        self.saved_tensors = tmp
+        return result
+
     def __call__(self, *inputs):
         assert (
             not self._has_saved_state
diff --git a/python_module/megengine/core/tensor.py b/python_module/megengine/core/tensor.py
index b4aa4dda981c0499fbd2376f8351e0f368a61898..575ed7d5307fbfea25cf6e13df30b663ab930125 100644
--- a/python_module/megengine/core/tensor.py
+++ b/python_module/megengine/core/tensor.py
@@ -495,7 +495,12 @@ class Tensor:
         )
 
     def __getstate__(self):
-        assert (self.__val is not None) and (self.__sym is None)
+        r""" __getstate__ will be called for pickle serialization or deep copy
+        """
+
+        assert (self.__val is not None) and (
+            self.__sym is None
+        ), "Only SharedND initialized Tensor can be serialized or deep copied"
         metadata = {"requires_grad": self.requires_grad}
         state = {
             "data": self.numpy(),
diff --git a/python_module/test/unit/core/test_function.py b/python_module/test/unit/core/test_function.py
index 8766388fabe429ba784352270525415159d021da..8333c5fe5fca7b6b66efaf9eaa0f103346dfd0f5 100644
--- a/python_module/test/unit/core/test_function.py
+++ b/python_module/test/unit/core/test_function.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 #
 # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -6,10 +5,13 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import copy
+
 import numpy as np
 
 import megengine.functional as F
 from megengine.core import Function, tensor
+from megengine.jit import trace
 from megengine.test import assertTensorClose
 
 
@@ -76,6 +78,27 @@ def test_ste():
     )
 
 
+def test_deepcopy():
+    class Sigmoid(Function):
+        def __init__(self, param):
+            super().__init__()
+            self.param = param
+
+        def forward(self, x):
+            y = 1 / (1 + F.exp(-x))
+            self.save_for_backward(y)
+            return y
+
+        def backward(self, grad_y):
+            (y,) = self.saved_tensors
+            return grad_y * y * (1 - y)
+
+    origin = Sigmoid(0)
+    new = copy.deepcopy(Sigmoid(0))
+    assert new.param == origin.param
+    assert new.saved_tensors == None
+
+
 def test_save_context():
     class Sigmoid(Function):
         def forward(self, x):
@@ -87,14 +110,26 @@ def test_save_context():
         (y,) = self.saved_tensors
         return grad_y * y * (1 - y)
 
-    a = tensor(np.array([1926.0817], dtype=np.float32))
-    s = Sigmoid()(a)
-    s2 = F.sigmoid(a)
-    assertTensorClose(s.numpy(), s2.numpy())
-    assertTensorClose(
-        F.grad(s, a, use_virtual_grad=False).numpy(),
-        F.grad(s2, a, use_virtual_grad=False).numpy(),
-    )
+    def run_saved_context(a, net=None):
+        return net(a)
+
+    def run(use_trace, symbolic):
+        a = tensor(np.array([1926.0817], dtype=np.float32))
+        net = Sigmoid()
+        func_run = run_saved_context
+        if use_trace:
+            func_run = trace(run_saved_context, symbolic=symbolic)
+        s = func_run(a, net=net)
+        s2 = F.sigmoid(a)
+        assertTensorClose(s.numpy(), s2.numpy())
+        assertTensorClose(
+            F.grad(s, a, use_virtual_grad=False).numpy(),
+            F.grad(s2, a, use_virtual_grad=False).numpy(),
+        )
+
+    run(False, False)
+    run(True, False)
+    run(True, True)
 
 
 def test_none_in_out_grad():