From 9451a96147f15e5caeeed547d5d58820a507a1c9 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 23 Apr 2021 17:11:13 +0800
Subject: [PATCH] test(mge/optimizer): update optimizer test to make sure grad
 not change

GitOrigin-RevId: e207672116dcd53dbfefa89ab9b1dcf7301abbea
---
 .../python/test/integration/test_optimizer.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/imperative/python/test/integration/test_optimizer.py b/imperative/python/test/integration/test_optimizer.py
index da71816b6..d5ca99a9b 100644
--- a/imperative/python/test/integration/test_optimizer.py
+++ b/imperative/python/test/integration/test_optimizer.py
@@ -66,10 +66,17 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
             gm.backward(loss)
 
         ori_params = {}
+        ori_grads = {}
         for param in net.parameters():
             assert param._tuple_shape is ()
             ori_params[param] = np.copy(param.numpy())
+            ori_grads[param] = np.copy(param.grad.numpy())
         opt.step()
+        # check grad not change
+        for param in net.parameters():
+            assert np.equal(
+                ori_grads[param], param.grad.numpy()
+            ), "step should not change param.grad"
         step += 1
         check_func(ori_params, net.parameters(), step)
 
@@ -135,6 +142,8 @@ def test_sgd(monkeypatch, case, update_lr, inplace_mode):
         def __call__(self, ori_params, new_params, step):
             for param in new_params:
                 grad = param.grad.numpy()
+                if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
+                    grad = grad + ori_params[param] * self.weight_decay
                 if hasattr(self, "momentum"):
                     self.slots[param] = grad + self.slots[param] * self.momentum
                     delta = -self.lr * self.slots[param]
@@ -177,6 +186,8 @@ def test_adam(monkeypatch, case, update_lr, inplace_mode):
         def __call__(self, ori_params, new_params, step):
             for param in new_params:
                 grad = param.grad.numpy()
+                if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
+                    grad = grad + ori_params[param] * self.weight_decay
                 m = self.m_slots[param]
                 v = self.v_slots[param]
                 m *= self.betas[0]
@@ -222,6 +233,8 @@ def test_adagrad(monkeypatch, case, update_lr, inplace_mode):
         def __call__(self, ori_params, new_params, step):
             for param in new_params:
                 grad = param.grad.numpy()
+                if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
+                    grad = grad + ori_params[param] * self.weight_decay
                 self.s_slots[param] += grad ** 2
                 delta = grad / (self.s_slots[param] + self.eps) ** 0.5
                 delta *= -(self.lr / (1 + (step - 1) * self.lr_decay))
@@ -257,6 +270,8 @@ def test_adadelta(monkeypatch, case, update_lr, inplace_mode):
         def __call__(self, ori_params, new_params, step):
             for param in new_params:
                 grad = param.grad.numpy()
+                if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
+                    grad = grad + ori_params[param] * self.weight_decay
                 self.s_slots[param] = self.s_slots[param] * self.rho + grad ** 2 * (
                     1 - self.rho
                 )
-- 
GitLab
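
The pattern the first hunk adds is framework-independent: snapshot every parameter's gradient before opt.step(), then assert the step left the gradients untouched. The sketch below reproduces that pattern with plain numpy so it can be run without MegEngine; ToyParam and toy_sgd_step are hypothetical names invented for illustration, not MegEngine APIs.

    import numpy as np

    class ToyParam:
        # Minimal stand-in for a framework parameter carrying a .grad
        # buffer; hypothetical, not a MegEngine class.
        def __init__(self, value, grad):
            self.value = np.float32(value)
            self.grad = np.float32(grad)

    def toy_sgd_step(params, lr=0.01):
        # A well-behaved step reads .grad but never writes to it.
        for p in params:
            p.value = p.value - lr * p.grad

    params = [ToyParam(1.0, 0.5), ToyParam(-2.0, 0.25)]

    # Same pattern as the patch: copy grads, step, then compare.
    ori_grads = [np.copy(p.grad) for p in params]
    toy_sgd_step(params)
    for before, p in zip(ori_grads, params):
        assert np.array_equal(before, p.grad), "step should not change param.grad"

One detail worth noting: the patch can assert on np.equal(...) directly only because the test network's parameters are scalars (the existing assert param._tuple_shape is () guarantees it), so np.equal returns a 0-d result with an unambiguous truth value. For non-scalar parameters, the np.array_equal form used in the sketch would be required.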
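The four remaining hunks all make the same change to the hand-written reference optimizers: when the test case sets a non-zero weight_decay, classic coupled L2 decay is folded into the gradient using the parameter value captured before the step (ori_params[param]), ahead of the optimizer-specific update. The sketch below spells out the resulting SGD-with-momentum reference rule as a standalone function; sgd_reference_step is a hypothetical name, and the real test keeps its state in per-parameter slot dicts rather than passing it around.

    import numpy as np

    def sgd_reference_step(param, grad, slot, *, lr, momentum=0.0, weight_decay=0.0):
        # One reference SGD update, mirroring the patched CheckValue.__call__.
        # param, grad, slot are numpy scalars/arrays; returns (new_param, new_slot).
        if weight_decay != 0.0:
            # Coupled L2 decay: folded into the gradient before the momentum
            # update, so the decay term also flows through the momentum buffer.
            grad = grad + param * weight_decay
        slot = grad + slot * momentum  # momentum accumulator
        return param - lr * slot, slot

    # One step starting from a zero momentum buffer:
    p, g = np.float32(1.0), np.float32(0.5)
    new_p, new_slot = sgd_reference_step(p, g, np.float32(0.0),
                                         lr=0.1, momentum=0.9, weight_decay=0.01)
    # effective grad = 0.5 + 1.0 * 0.01 = 0.51; new_p = 1.0 - 0.1 * 0.51 = 0.949
    assert np.allclose(new_p, 0.949)

Applying the decay before the momentum step is the classic coupled L2 formulation (the same one PyTorch's SGD uses), as opposed to AdamW-style decoupled decay, and it matches what the patched CheckValue classes compute for all four optimizers.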