From e283663a0287ac9b6efacebf72f0c9c492931505 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 7 Sep 2020 19:59:53 +0800 Subject: [PATCH] fix(mge/imperative): update tests to new optimizer api GitOrigin-RevId: 3d06e3db3c6e057505cfc6df1f2b03c01b6d1470 --- .../python/megengine/autodiff/grad_manager.py | 1 + .../python/megengine/core/autodiff/grad.py | 7 ++++ .../megengine/distributed/functional.py | 2 +- .../python/megengine/optimizer/adadelta.py | 15 ++------ .../python/megengine/optimizer/adagrad.py | 14 ++------ imperative/python/megengine/optimizer/adam.py | 15 ++------ imperative/python/megengine/optimizer/sgd.py | 4 +-- .../test/integration/test_advance_indexing.py | 15 ++++---- imperative/python/test/integration/test_ai.py | 8 +++-- imperative/python/test/integration/test_bn.py | 22 +++++++----- .../python/test/integration/test_converge.py | 8 +++-- .../test/integration/test_correctness.py | 17 ++++++---- .../python/test/integration/test_detach.py | 8 +++-- .../test/integration/test_dp_correctness.py | 31 ++++++++++------- .../test/integration/test_hello_world.py | 8 +++-- .../python/test/integration/test_optimizer.py | 8 +++-- .../python/test/integration/test_save_load.py | 12 ++++--- .../test/integration/test_sgd_momentum.py | 14 ++++---- imperative/python/test/unit/test_function.py | 34 +++++++++++-------- 19 files changed, 127 insertions(+), 116 deletions(-) diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py index 07d716c1..e34cec46 100644 --- a/imperative/python/megengine/autodiff/grad_manager.py +++ b/imperative/python/megengine/autodiff/grad_manager.py @@ -29,6 +29,7 @@ class GradManager: def register_after_backward_callback(self, callback): self._after_backward_callback.append(callback) + return self def backward(self, ys, dys=None): global backwarding_grad_manager diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py index c30e4113..d2120937 100644 --- a/imperative/python/megengine/core/autodiff/grad.py +++ b/imperative/python/megengine/core/autodiff/grad.py @@ -177,6 +177,13 @@ class Grad: dys = aslist(dys) assert len(ys) == len(dys) + ids = [i for i, y in enumerate(ys) if self in y._extra_data.keys()] + if len(ids) == 0: + return + + ys = [y for i, y in enumerate(ys) if i in ids] + dys = [dy for i, dy in enumerate(dys) if i in ids] + # ys is changed to a list of VariableNode which contains more information # such as OpNode, callback, etc. 
ys = [i._extra_data[self].node for i in ys] diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py index c6162c53..097ecb49 100644 --- a/imperative/python/megengine/distributed/functional.py +++ b/imperative/python/megengine/distributed/functional.py @@ -20,8 +20,8 @@ from ..core.autodiff.grad import ( from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend from ..core.tensor.core import apply from ..core.tensor.tensor import Tensor, tensor_apply -from ..tensor import tensor from ..device import get_default_device +from ..tensor import tensor from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank __all__ = [ diff --git a/imperative/python/megengine/optimizer/adadelta.py b/imperative/python/megengine/optimizer/adadelta.py index 2eae5184..0491cff9 100644 --- a/imperative/python/megengine/optimizer/adadelta.py +++ b/imperative/python/megengine/optimizer/adadelta.py @@ -11,7 +11,7 @@ from typing import Iterable, Union import numpy as np from ..functional import sqrt -from ..tensor_nn import Buffer, Parameter +from ..tensor_nn import Parameter from .optimizer import Optimizer @@ -63,16 +63,7 @@ class Adadelta(Optimizer): for param in param_group["params"]: - if param.__wrapped__ in self._grad_skip: - self._grad_skip.remove(param.__wrapped__) - continue - - if not isinstance(param.grad, Buffer): - raise TypeError( - "grad must be a Buffer, maybe you forget to call backward()?" - ) - - if not param.requires_grad: + if not param.requires_grad or "grad" not in param.__dict__: continue states = self._state[param] @@ -91,5 +82,3 @@ class Adadelta(Optimizer): acc_delta = rho * acc_delta + (1 - rho) * delta ** 2 states["square_avg"]._reset(square_avg) states["acc_delta"]._reset(acc_delta) - - assert len(self._grad_skip) == 0 diff --git a/imperative/python/megengine/optimizer/adagrad.py b/imperative/python/megengine/optimizer/adagrad.py index da1ad46b..d0fe8728 100644 --- a/imperative/python/megengine/optimizer/adagrad.py +++ b/imperative/python/megengine/optimizer/adagrad.py @@ -11,7 +11,7 @@ from typing import Iterable, Union import numpy as np from ..functional import sqrt -from ..tensor_nn import Buffer, Parameter +from ..tensor_nn import Parameter from .optimizer import Optimizer @@ -62,16 +62,7 @@ class Adagrad(Optimizer): for param in param_group["params"]: - if param.__wrapped__ in self._grad_skip: - self._grad_skip.remove(param.__wrapped__) - continue - - if not isinstance(param.grad, Buffer): - raise TypeError( - "grad must be a Buffer, maybe you forget to call backward()?" - ) - - if not param.requires_grad: + if not param.requires_grad or "grad" not in param.__dict__: continue states = self._state[param] @@ -87,4 +78,3 @@ class Adagrad(Optimizer): clr = lr / (1 + (step - 1) * lr_decay) param -= clr * delta - assert len(self._grad_skip) == 0 diff --git a/imperative/python/megengine/optimizer/adam.py b/imperative/python/megengine/optimizer/adam.py index 7901adb7..d411945e 100644 --- a/imperative/python/megengine/optimizer/adam.py +++ b/imperative/python/megengine/optimizer/adam.py @@ -8,7 +8,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
from typing import Iterable, Tuple, Union -from ..tensor_nn import Buffer, Parameter +from ..tensor_nn import Parameter from .optimizer import Optimizer @@ -59,18 +59,9 @@ class Adam(Optimizer): for param in param_group["params"]: - if param.__wrapped__ in self._grad_skip: - self._grad_skip.remove(param.__wrapped__) + if not param.requires_grad or "grad" not in param.__dict__: continue - if not param.requires_grad: - continue - - if not isinstance(param.grad, Buffer): - raise TypeError( - "grad must be a Buffer, maybe you forget to call backward()?" - ) - grad = param.grad if weight_decay != 0.0: grad += param * weight_decay @@ -91,5 +82,3 @@ class Adam(Optimizer): # not inplace change, need to update underlying tensor handler in state states["exp_avg"]._reset(exp_avg) states["exp_avg_sq"]._reset(exp_avg_sq) - - assert len(self._grad_skip) == 0 diff --git a/imperative/python/megengine/optimizer/sgd.py b/imperative/python/megengine/optimizer/sgd.py index 4e4dafb8..9215ef48 100644 --- a/imperative/python/megengine/optimizer/sgd.py +++ b/imperative/python/megengine/optimizer/sgd.py @@ -8,7 +8,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from typing import Iterable, Union -from ..tensor_nn import Buffer, Parameter +from ..tensor_nn import Parameter from .optimizer import Optimizer @@ -52,7 +52,7 @@ class SGD(Optimizer): momentum = param_group["momentum"] for param in param_group["params"]: - if not param.requires_grad: + if not param.requires_grad or "grad" not in param.__dict__: continue grad = param.grad diff --git a/imperative/python/test/integration/test_advance_indexing.py b/imperative/python/test/integration/test_advance_indexing.py index 261f6daf..6267a3d0 100644 --- a/imperative/python/test/integration/test_advance_indexing.py +++ b/imperative/python/test/integration/test_advance_indexing.py @@ -9,6 +9,7 @@ import numpy as np import megengine +import megengine.autodiff as ad import megengine.optimizer as optimizer from megengine import Parameter, tensor from megengine.module import Module @@ -37,8 +38,9 @@ class Simple2(Module): def test_advance_indexing(): net = Simple() + gm = ad.GradManager().register(net.parameters()) optim = optimizer.SGD(net.parameters(), lr=1.0) - optim.zero_grad() + optim.clear_grad() dshape = (10, 10) raw_data = np.arange(100).reshape(dshape).astype(np.float32) @@ -46,9 +48,9 @@ def test_advance_indexing(): data = tensor(raw_data) mask = tensor(raw_mask) answer = 1.0 - raw_data[raw_mask].sum() - with optim.record(): + with gm.record(): loss = net(data, mask).sum() - optim.backward(loss) + gm.backward(loss) optim.step() np.testing.assert_almost_equal(net.a.numpy(), np.array([answer]).astype(np.float32)) @@ -56,15 +58,16 @@ def test_advance_indexing(): def test_advance_indexing_with_subtensor(): net = Simple2() + gm = ad.GradManager().register(net.parameters()) optim = optimizer.SGD(net.parameters(), lr=1.0) - optim.zero_grad() + optim.clear_grad() dshape = (2, 3, 4, 3, 4, 2) raw_data = np.arange(576).reshape(dshape).astype(np.float32) data = tensor(raw_data) answer = 1.0 - raw_data[1, ..., :, 0:4:2, 0:2].sum() - with optim.record(): + with gm.record(): loss = net(data).sum() - optim.backward(loss) + gm.backward(loss) optim.step() np.testing.assert_almost_equal(net.a.numpy(), np.array([answer]).astype(np.float32)) diff --git a/imperative/python/test/integration/test_ai.py b/imperative/python/test/integration/test_ai.py index 3e40bac9..fdf54fa9 100644 --- a/imperative/python/test/integration/test_ai.py +++ 
b/imperative/python/test/integration/test_ai.py @@ -9,6 +9,7 @@ import numpy as np import megengine +import megengine.autodiff as ad import megengine.optimizer as optimizer from megengine import Parameter, tensor from megengine.module import Module @@ -27,14 +28,15 @@ class Simple(Module): def test_ai(): net = Simple() + gm = ad.GradManager().register(net.parameters()) optim = optimizer.SGD(net.parameters(), lr=1.0) - optim.zero_grad() + optim.clear_grad() dshape = (10, 10) data = tensor(np.ones(dshape).astype(np.float32)) - with optim.record(): + with gm.record(): loss = net(data).sum() - optim.backward(loss) + gm.backward(loss) optim.step() np.testing.assert_almost_equal( net.a.numpy(), np.array([1.0 - dshape[0]]).astype(np.float32) diff --git a/imperative/python/test/integration/test_bn.py b/imperative/python/test/integration/test_bn.py index 779b2ef9..27955232 100644 --- a/imperative/python/test/integration/test_bn.py +++ b/imperative/python/test/integration/test_bn.py @@ -10,6 +10,7 @@ import numpy as np import pytest import megengine +import megengine.autodiff as ad import megengine.optimizer as optimizer from megengine import Parameter, tensor from megengine.module import BatchNorm2d @@ -24,13 +25,14 @@ def test_frozen_bn(): saved_wt = m.weight.numpy() saved_bias = m.bias.numpy() + gm = ad.GradManager().register(m.parameters()) optim = optimizer.SGD(m.parameters(), lr=1.0) - optim.zero_grad() + optim.clear_grad() data = np.random.random((6, nchannel, 2, 2)).astype("float32") - with optim.record(): + with gm.record(): loss = m(data).mean() - optim.backward(loss) + gm.backward(loss) optim.step() np.testing.assert_equal(m.running_var.numpy(), saved_var) @@ -44,13 +46,14 @@ def test_bn_no_track_stat(): nchannel = 3 m = BatchNorm2d(nchannel, track_running_stats=False) + gm = ad.GradManager().register(m.parameters()) optim = optimizer.SGD(m.parameters(), lr=1.0) - optim.zero_grad() + optim.clear_grad() data = np.random.random((6, nchannel, 2, 2)).astype("float32") - with optim.record(): + with gm.record(): loss = m(data).sum() - optim.backward(loss) + gm.backward(loss) optim.step() @@ -65,13 +68,14 @@ def test_bn_no_track_stat2(): saved_mean = m.running_mean.numpy() assert saved_mean is not None + gm = ad.GradManager().register(m.parameters()) optim = optimizer.SGD(m.parameters(), lr=1.0) - optim.zero_grad() + optim.clear_grad() data = np.random.random((6, nchannel, 2, 2)).astype("float32") - with optim.record(): + with gm.record(): loss = m(data).sum() - optim.backward(loss) + gm.backward(loss) optim.step() np.testing.assert_equal(m.running_var.numpy(), saved_var) diff --git a/imperative/python/test/integration/test_converge.py b/imperative/python/test/integration/test_converge.py index 7778c6a9..42267e2c 100644 --- a/imperative/python/test/integration/test_converge.py +++ b/imperative/python/test/integration/test_converge.py @@ -12,6 +12,7 @@ import numpy as np import pytest import megengine as mge +import megengine.autodiff as ad import megengine.functional as F from megengine import Tensor from megengine.module import Linear, Module @@ -76,12 +77,13 @@ def test_training_converge(): opt = SGD( net.parameters(requires_grad=True), lr=0.01, momentum=0.9, weight_decay=5e-4 ) + gm = ad.GradManager().register(net.parameters()) def train(data, label): - with opt.record(): + with gm.record(): pred = net(data) loss = F.cross_entropy_with_softmax(pred, label) - opt.backward(loss) + gm.backward(loss) return loss def infer(data): @@ -93,7 +95,7 @@ def test_training_converge(): for data, label in 
itertools.islice(train_dataset, 2000): data = Tensor(data, dtype=np.float32) label = Tensor(label, dtype=np.int32) - opt.zero_grad() + opt.clear_grad() loss = train(data, label) opt.step() losses.append(loss.numpy()) diff --git a/imperative/python/test/integration/test_correctness.py b/imperative/python/test/integration/test_correctness.py index 31cfecbf..8a672607 100644 --- a/imperative/python/test/integration/test_correctness.py +++ b/imperative/python/test/integration/test_correctness.py @@ -15,6 +15,7 @@ import numpy as np import pytest import megengine as mge +import megengine.autodiff as ad import megengine.functional as F from megengine import jit from megengine.core._trace_option import set_tensor_shape @@ -89,11 +90,11 @@ class MnistNet(Module): return x -def train(data, label, net, opt): - with opt.record(): +def train(data, label, net, opt, gm): + with gm.record(): pred = net(data) loss = F.cross_entropy_with_softmax(pred, label) - opt.backward(loss) + gm.backward(loss) return loss @@ -116,12 +117,13 @@ def update_model(model_path): net.load_state_dict(checkpoint["net_init"]) lr = checkpoint["sgd_lr"] opt = SGD(net.parameters(), lr=lr) + gm = ad.GradManager().register(net.parameters()) data = Tensor(checkpoint["data"], dtype=np.float32) label = Tensor(checkpoint["label"], dtype=np.int32) - opt.zero_grad() - loss = train(data, label, net=net, opt=opt) + opt.clear_grad() + loss = train(data, label, net, opt, gm) opt.step() xpu_name = get_xpu_name() @@ -150,6 +152,7 @@ def run_train( net.load_state_dict(checkpoint["net_init"]) lr = checkpoint["sgd_lr"] opt = SGD(net.parameters(), lr=lr) + gm = ad.GradManager().register(net.parameters()) data = Tensor(checkpoint["data"], dtype=np.float32) label = Tensor(checkpoint["label"], dtype=np.int32) @@ -165,8 +168,8 @@ def run_train( sublinear_memory_config=sublinear_memory_config, ) - opt.zero_grad() - loss = train_func(data, label, net=net, opt=opt) + opt.clear_grad() + loss = train_func(data, label, net, opt, gm) opt.step() assertTensorClose(loss.numpy(), checkpoint["loss"], max_err=max_err) diff --git a/imperative/python/test/integration/test_detach.py b/imperative/python/test/integration/test_detach.py index 0d0b3d5c..6a2f7037 100644 --- a/imperative/python/test/integration/test_detach.py +++ b/imperative/python/test/integration/test_detach.py @@ -9,6 +9,7 @@ import numpy as np import megengine +import megengine.autodiff as ad import megengine.optimizer as optimizer from megengine import Parameter, tensor from megengine.module import Module @@ -30,13 +31,14 @@ def test_detach(): net = Simple() optim = optimizer.SGD(net.parameters(), lr=1.0) - optim.zero_grad() + optim.clear_grad() + gm = ad.GradManager().register(net.parameters()) dshape = (10, 10) data = tensor(np.ones(dshape).astype(np.float32)) - with optim.record(): + with gm.record(): loss = net(data).sum() - optim.backward(loss) + gm.backward(loss) optim.step() np.testing.assert_equal(net.a.numpy(), np.array([1.0]).astype(np.float32)) np.testing.assert_equal( diff --git a/imperative/python/test/integration/test_dp_correctness.py b/imperative/python/test/integration/test_dp_correctness.py index 1494cbd7..df94ea4e 100644 --- a/imperative/python/test/integration/test_dp_correctness.py +++ b/imperative/python/test/integration/test_dp_correctness.py @@ -18,6 +18,7 @@ import numpy as np import pytest import megengine as mge +import megengine.autodiff as ad import megengine.distributed as dist import megengine.functional as F from megengine.device import get_default_device, 
set_default_device @@ -94,11 +95,13 @@ class MnistNet(Module): return x -def train(data, label, net, opt): - with opt.record(): +def train(data, label, net, opt, gm): + opt.clear_grad() + with gm.record(): pred = net(data) loss = F.cross_entropy_with_softmax(pred, label) - opt.backward(loss) + gm.backward(loss) + opt.step() return loss @@ -111,7 +114,7 @@ def update_model(model_path): .. code-block:: python - from test_correctness import update_model + from test_dp_correctness import update_model update_model('mnist_model_with_test.mge') # for gpu update_model('mnist_model_with_test_cpu.mge') # for cpu @@ -122,6 +125,11 @@ def update_model(model_path): lr = checkpoint["sgd_lr"] opt = SGD(net.parameters(), lr=lr) + gm = ad.GradManager() + gm.register( + net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)] + ) + data = Tensor(checkpoint["data"], dtype=np.float32) label = Tensor(checkpoint["label"], dtype=np.int32) @@ -158,24 +166,23 @@ def run_test( def worker(rank, max_err): dist.init_process_group("localhost", port, p_num, rank, rank) - set_default_device(device="gpu{}".format(dist.get_rank())) net = MnistNet(has_bn=True) net.load_state_dict(checkpoint["net_init"]) lr = checkpoint["sgd_lr"] - opt = SGD(net.parameters(), reduce_method="mean", lr=lr) + opt = SGD(net.parameters(), lr=lr) + + gm = ad.GradManager() + gm.register( + net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)] + ) # use same data and label for all gpu's # such that the result does not depend on number of gpu data_train = Tensor(data) label_train = Tensor(label) - train_func = train - - opt.zero_grad() - loss = train_func(data_train, label_train, net=net, opt=opt) - opt.step() + loss = train(data_train, label_train, net, opt, gm) - print("{} loss {}".format(get_default_device(), loss.numpy()[0])) assertTensorClose(loss.numpy(), checkpoint["loss"], max_err=max_err) if dist.get_rank(): diff --git a/imperative/python/test/integration/test_hello_world.py b/imperative/python/test/integration/test_hello_world.py index 033d2854..af01c3f1 100644 --- a/imperative/python/test/integration/test_hello_world.py +++ b/imperative/python/test/integration/test_hello_world.py @@ -12,6 +12,7 @@ import numpy as np import pytest import megengine +import megengine.autodiff as ad import megengine.optimizer as optimizer from megengine import Parameter, tensor from megengine.module import Module @@ -31,12 +32,13 @@ def test_hello_world(): net = Simple() optim = optimizer.SGD(net.parameters(), lr=1.0) - optim.zero_grad() + optim.clear_grad() + gm = ad.GradManager().register(net.parameters()) data = tensor([2.34]) - with optim.record(): + with gm.record(): loss = net(data) - optim.backward(loss) + gm.backward(loss) optim.step() np.testing.assert_almost_equal( net.a.numpy(), np.array([1.23 - 2.34]).astype(np.float32) diff --git a/imperative/python/test/integration/test_optimizer.py b/imperative/python/test/integration/test_optimizer.py index 36ea5e95..bd7b1798 100644 --- a/imperative/python/test/integration/test_optimizer.py +++ b/imperative/python/test/integration/test_optimizer.py @@ -8,6 +8,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
import numpy as np +import megengine.autodiff as ad import megengine.functional as F from megengine import Parameter, optimizer from megengine.jit import trace @@ -43,6 +44,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): net = Simple() opt = getattr(optimizer, opt_str)(net.parameters(), **test_case) check_func = check_class(net, **test_case) + gm = ad.GradManager().register(net.parameters()) step = 0 data_shape = (2, 28) @@ -54,11 +56,11 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): check_func.lr += 0.01 data = tensor(np.random.random(data_shape).astype(np.float32)) - opt.zero_grad() - with opt.record(): + opt.clear_grad() + with gm.record(): pred = net(data) loss = pred.sum() - opt.backward(loss) + gm.backward(loss) ori_params = TensorDict() for param in net.parameters(): diff --git a/imperative/python/test/integration/test_save_load.py b/imperative/python/test/integration/test_save_load.py index 0664e416..2008a211 100644 --- a/imperative/python/test/integration/test_save_load.py +++ b/imperative/python/test/integration/test_save_load.py @@ -1,6 +1,7 @@ import numpy as np import megengine as mge +import megengine.autodiff as ad import megengine.optimizer as optimizer from megengine import Parameter, tensor from megengine.core.tensor.raw_tensor import RawTensor @@ -21,13 +22,14 @@ def test_save_load(): net = Simple() optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9) - optim.zero_grad() + optim.clear_grad() + gm = ad.GradManager().register(net.parameters()) data = tensor([2.34]) - with optim.record(): + with gm.record(): loss = net(data) - optim.backward(loss) + gm.backward(loss) optim.step() @@ -53,9 +55,9 @@ def test_save_load(): optim.load_state_dict(checkpoint["opt_state"]) print("load done") - with optim.record(): + with gm.record(): loss = net([1.23]) - optim.backward(loss) + gm.backward(loss) optim.step() # Restore device diff --git a/imperative/python/test/integration/test_sgd_momentum.py b/imperative/python/test/integration/test_sgd_momentum.py index da60e003..f1d75a79 100644 --- a/imperative/python/test/integration/test_sgd_momentum.py +++ b/imperative/python/test/integration/test_sgd_momentum.py @@ -9,6 +9,7 @@ import numpy as np import megengine +import megengine.autodiff as ad import megengine.optimizer as optimizer from megengine import Parameter, tensor from megengine.jit import trace @@ -29,14 +30,15 @@ def test_sgd_momentum(): net = Simple() optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9) - optim.zero_grad() + optim.clear_grad() + gm = ad.GradManager().register(net.parameters()) data = tensor([2.34]) # do a step of train - with optim.record(): + with gm.record(): loss = net(data) - optim.backward(loss) + gm.backward(loss) optim.step() np.testing.assert_almost_equal(optim._state[net.a]["momentum_buffer"].numpy(), 2.34) @@ -48,10 +50,10 @@ def test_sgd_momentum(): np.testing.assert_almost_equal(optim._state[net.a]["momentum_buffer"].numpy(), 2.34) # do a step of train - optim.zero_grad() - with optim.record(): + optim.clear_grad() + with gm.record(): loss = net(data) - optim.backward(loss) + gm.backward(loss) optim.step() np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5) diff --git a/imperative/python/test/unit/test_function.py b/imperative/python/test/unit/test_function.py index 8a690ea6..990ced26 100644 --- a/imperative/python/test/unit/test_function.py +++ b/imperative/python/test/unit/test_function.py @@ -9,6 +9,7 @@ import copy import numpy as np +import 
megengine.autodiff as ad import megengine.functional as F import megengine.optimizer as optimizer from megengine import Parameter @@ -41,13 +42,14 @@ def test_single_input(): return x net = Simple(av) - optim = optimizer.SGD(net.parameters(), lr=1.0) - optim.zero_grad() + gm = ad.GradManager().register(net.parameters()) + opt = optimizer.SGD(net.parameters(), lr=1.0) - with optim.record(): + opt.clear_grad() + with gm.record(): loss = net() - optim.backward(loss.sum()) - optim.step() + gm.backward(loss.sum()) + opt.step() np.testing.assert_almost_equal(loss.numpy(), (av * 10)) np.testing.assert_almost_equal(net.a.numpy(), (av - 10)) @@ -79,13 +81,14 @@ def test_multi_input(): return x net = Simple(av, bv) - optim = optimizer.SGD(net.parameters(), lr=1.0) - optim.zero_grad() + gm = ad.GradManager().register(net.parameters()) + opt = optimizer.SGD(net.parameters(), lr=1.0) - with optim.record(): + opt.clear_grad() + with gm.record(): loss = net() - optim.backward(loss.sum()) - optim.step() + gm.backward(loss.sum()) + opt.step() np.testing.assert_almost_equal(loss.numpy(), (av * bv)) np.testing.assert_almost_equal(net.a.numpy(), (av - 2 * bv)) @@ -118,13 +121,14 @@ def test_multi_output(): return x + y net = Simple(av, bv) - optim = optimizer.SGD(net.parameters(), lr=1.0) - optim.zero_grad() + gm = ad.GradManager().register(net.parameters()) + opt = optimizer.SGD(net.parameters(), lr=1.0) - with optim.record(): + opt.clear_grad() + with gm.record(): loss = net() - optim.backward(loss.sum()) - optim.step() + gm.backward(loss.sum()) + opt.step() np.testing.assert_almost_equal(loss.numpy(), (av * bv + av + bv), decimal=6) np.testing.assert_almost_equal(net.a.numpy(), (av - bv - 1), decimal=6) -- GitLab
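Below is a minimal, hedged sketch (not part of the patch itself) of the training-step pattern the tests above are migrated to: gradient recording moves from the optimizer onto autodiff.GradManager, and Optimizer.zero_grad() is replaced by clear_grad(). The Simple module here is an illustrative stand-in modeled on the test modules; only the GradManager/optimizer calls are taken directly from the diff.

    import numpy as np

    import megengine.autodiff as ad
    import megengine.optimizer as optimizer
    from megengine import Parameter, tensor
    from megengine.module import Module


    class Simple(Module):
        # Illustrative module, similar to the Simple modules used in these tests.
        def __init__(self):
            super().__init__()
            self.a = Parameter([1.23], dtype=np.float32)

        def forward(self, x):
            return x * self.a


    net = Simple()

    # New API: a GradManager owns recording and backward; the optimizer only updates params.
    gm = ad.GradManager().register(net.parameters())
    optim = optimizer.SGD(net.parameters(), lr=1.0)

    data = tensor(np.ones((10, 10), dtype=np.float32))

    optim.clear_grad()              # was: optim.zero_grad()
    with gm.record():               # was: with optim.record()
        loss = net(data).sum()
        gm.backward(loss)           # was: optim.backward(loss)
    optim.step()

The data-parallel tests follow the same pattern but register the parameters with an all-reduce callback, e.g. gm.register(net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)]), as shown in test_dp_correctness.py.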