From 5aa19f3d4a7c38d8deff3b1e6432b8b68e686cf1 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 26 Aug 2020 15:11:01 +0800
Subject: [PATCH] test(mge/imperative): add more testcases for function

GitOrigin-RevId: e0675994f132447081c4cab6171045f62505c6b1
---
 .../python/megengine/core/tensor/function.py |  17 +-
 imperative/python/test/unit/test_function.py | 169 ++++++++++++++++++
 2 files changed, 182 insertions(+), 4 deletions(-)

diff --git a/imperative/python/megengine/core/tensor/function.py b/imperative/python/megengine/core/tensor/function.py
index 6b51c6675..de1986df9 100644
--- a/imperative/python/megengine/core/tensor/function.py
+++ b/imperative/python/megengine/core/tensor/function.py
@@ -87,13 +87,21 @@ class Function:
 
         def _backward(*output_grads):
             if type(output_grads) is tuple:
-                _output_grads = map(TensorWrapper, output_grads)
+                _output_grads = [
+                    TensorWrapper(i) if i is not None else i for i in output_grads
+                ]
             else:
-                _output_grads = (TensorWrapper(output_grads),)
+                _output_grads = (
+                    TensorWrapper(output_grads)
+                    if output_grads is not None
+                    else output_grads,
+                )
             ret = self.backward(*_output_grads)
             if type(ret) is not tuple:
                 ret = (ret,)
-            ret = tuple([i.__wrapped__ for i in ret])
+            ret = tuple(
+                i.__wrapped__ if isinstance(i, TensorWrapper) else i for i in ret
+            )
             return ret
 
         return _backward
@@ -127,7 +135,8 @@ def _(op: Function, *args: TensorWrapperBase):
     )
 
     for output in outputs:
-        output._extra_data = {}
+        if output not in inputs:
+            output._extra_data = {}
 
     with push_context() as ctx:
         ctx.inputs = inputs
diff --git a/imperative/python/test/unit/test_function.py b/imperative/python/test/unit/test_function.py
index 8d46e26e2..8a690ea6f 100644
--- a/imperative/python/test/unit/test_function.py
+++ b/imperative/python/test/unit/test_function.py
@@ -5,8 +5,11 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import copy
+
 import numpy as np
 
+import megengine.functional as F
 import megengine.optimizer as optimizer
 from megengine import Parameter
 from megengine import Tensor as tensor
@@ -126,3 +129,169 @@ def test_multi_output():
     np.testing.assert_almost_equal(loss.numpy(), (av * bv + av + bv), decimal=6)
     np.testing.assert_almost_equal(net.a.numpy(), (av - bv - 1), decimal=6)
     np.testing.assert_almost_equal(net.b.numpy(), (bv - av - 1), decimal=6)
+
+
+def test_skip_invalid_grad():
+    data_shape = (1, 9, 2, 6)
+    av = np.random.random(data_shape).astype(np.float32)
+    bv = np.random.random(data_shape).astype(np.float32)
+    c = np.random.random(data_shape).astype(np.float32)
+    cookie = tensor(c)
+
+    class EqWithFakeGrad(Function):
+        def forward(self, a, b):
+            return a + b
+
+        def backward(self, grad_o):
+            _ = grad_o
+            return cookie, cookie
+
+    class Simple(Module):
+        def __init__(self, a, b):
+            super().__init__()
+            self.a = Parameter(a, dtype=np.float32)
+            self.b = Parameter(b, dtype=np.float32)
+            self.layer1 = EqWithFakeGrad()
+
+        def forward(self):
+            x = self.layer1(self.a, self.b)
+            return x
+
+    net = Simple(av, bv)
+    optim = optimizer.SGD(net.parameters(), lr=1.0)
+    optim.zero_grad()
+    with optim.record():
+        loss = net().sum()
+        optim.backward(loss)
+    optim.step()
+    np.testing.assert_almost_equal(net.a.numpy(), av - c)
+    np.testing.assert_almost_equal(net.b.numpy(), bv - c)
+
+
+def test_ste():
+    class STE(Function):
+        def forward(self, x):
+            maxv, minv = x.max(), x.min()
+            scale = F.maximum(maxv, -minv) / 127
+            return F.round(x / scale) * scale
+
+        def backward(self, grad_y):
+            return grad_y
+
+    class Simple(Module):
+        def __init__(self, a):
+            super().__init__()
+            self.a = Parameter(a, dtype=np.float32)
+            self.layer1 = STE()
+
+        def forward(self):
+            x = self.layer1(self.a)
+            x = (x * 2.0).sum()
+            return x
+
+    data_shape = (1, 9, 2, 6)
+    av = np.random.random(data_shape).astype(np.float32)
+    net = Simple(av)
+    optim = optimizer.SGD(net.parameters(), lr=1.0)
+    optim.zero_grad()
+
+    with optim.record():
+        loss = net()
+        optim.backward(loss.sum())
+    optim.step()
+
+    np.testing.assert_almost_equal(
+        net.a.numpy(),
+        av - np.broadcast_to(np.array([2.0], dtype=np.float32), data_shape),
+    )
+
+
+def test_deepcopy():
+    class Sigmoid(Function):
+        def __init__(self, param):
+            super().__init__()
+            self.param = param
+
+        def forward(self, x):
+            y = 1 / (1 + F.exp(-x))
+            self.save_for_backward(y)
+            return y
+
+        def backward(self, grad_y):
+            (y,) = self.saved_tensors
+            return grad_y * y * (1 - y)
+
+    origin = Sigmoid(0)
+    new = copy.deepcopy(Sigmoid(0))
+    assert new.param == origin.param
+
+
+def test_none_in_out_grad():
+    class Test(Function):
+        def forward(self, a, b):
+            return a, b
+
+        def backward(self, grad_a, grad_b):
+            assert grad_b is None
+            return (grad_a, 0.0)
+
+    class Simple(Module):
+        def __init__(self, a, b):
+            super().__init__()
+            self.a = Parameter(a, dtype=np.float32)
+            self.b = Parameter(b, dtype=np.float32)
+            self.layer = Test()
+
+        def forward(self):
+            aa, bb = self.layer(self.a, self.b)
+            return aa, bb
+
+    a = tensor(np.array([1.0], dtype=np.float32))
+    b = tensor(np.array([2.0], dtype=np.float32))
+    net = Simple(a, b)
+    optim = optimizer.SGD(net.parameters(), lr=1.0)
+    optim.zero_grad()
+    with optim.record():
+        loss, _ = net()
+        optim.backward(loss)
+    optim.step()
+
+    np.testing.assert_almost_equal(
+        net.a.numpy(), np.array([1.0 - 1.0], dtype=np.float32)
+    )
+    np.testing.assert_almost_equal(
+        net.b.numpy(), np.array([2.0 - 0.0], dtype=np.float32)
+    )
+
+
+def test_zero_grad():
+    class StopGradient(Function):
+        def forward(self, a):
+            return a
+
+        def backward(self, *_):
+            return None
+
+    class Simple(Module):
+        def __init__(self, a):
+            super().__init__()
+            self.a = Parameter(a, dtype=np.float32)
+            self.layer = StopGradient()
+
+        def forward(self):
+            b = self.a * 3.0
+            c = self.a * 4.0
+            return self.layer(b) + c
+
+    a = tensor(np.array([1.0], dtype=np.float32))
+    net = Simple(a)
+    optim = optimizer.SGD(net.parameters(), lr=1.0)
+    optim.zero_grad()
+
+    with optim.record():
+        loss = net()
+        optim.backward(loss.sum())
+    optim.step()
+    np.testing.assert_almost_equal(
+        net.a.numpy(), np.array([1.0 - 4.0], dtype=np.float32),
+    )
--
GitLab
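
For orientation, the behaviour change in _backward can also be exercised outside the test suite. The sketch below is hypothetical and not part of the commit; it assumes the Function class patched above (imperative/python/megengine/core/tensor/function.py) and shows a backward that returns None for one input, which the patched _backward now passes through unchanged instead of wrapping it in a TensorWrapper:

    import numpy as np

    from megengine import Tensor
    from megengine.core.tensor.function import Function

    class AddSkipSecond(Function):  # hypothetical example, not in the patch
        def forward(self, a, b):
            return a + b

        def backward(self, grad_o):
            # Gradient flows back to the first input; returning None for
            # the second skips its gradient, the case the patch tolerates.
            return grad_o, None

test_none_in_out_grad and test_zero_grad above drive this same path end to end through an optimizer step.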