From ac56ff7ed3682db13035063dd4b8e00055e3a976 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 15 Apr 2021 18:57:56 +0800
Subject: [PATCH] feat(imperative/autodiff): add grad clip

GitOrigin-RevId: f344f4a2330ca2f560f45241d5aebcbab0f959b6
---
 .../python/megengine/optimizer/__init__.py    |   1 +
 .../python/megengine/optimizer/clip_grad.py   |  72 +++++++++++
 .../test_converge_with_gradient_clip.py       | 120 ++++++++++++++++++
 .../test/unit/optimizer/test_clip_grad.py     |  80 ++++++++++++
 .../unit/optimizer/test_clip_grad_torch.py    |  58 +++++++++
 5 files changed, 331 insertions(+)
 create mode 100644 imperative/python/megengine/optimizer/clip_grad.py
 create mode 100644 imperative/python/test/integration/test_converge_with_gradient_clip.py
 create mode 100644 imperative/python/test/unit/optimizer/test_clip_grad.py
 create mode 100644 imperative/python/test/unit/optimizer/test_clip_grad_torch.py

diff --git a/imperative/python/megengine/optimizer/__init__.py b/imperative/python/megengine/optimizer/__init__.py
index f121ff9a0..50740e73b 100644
--- a/imperative/python/megengine/optimizer/__init__.py
+++ b/imperative/python/megengine/optimizer/__init__.py
@@ -10,6 +10,7 @@ from .adadelta import Adadelta
 from .adagrad import Adagrad
 from .adam import Adam
 from .adamw import AdamW
+from .clip_grad import *
 from .lr_scheduler import LRScheduler
 from .multi_step_lr import MultiStepLR
 from .optimizer import Optimizer
diff --git a/imperative/python/megengine/optimizer/clip_grad.py b/imperative/python/megengine/optimizer/clip_grad.py
new file mode 100644
index 000000000..84492ad67
--- /dev/null
+++ b/imperative/python/megengine/optimizer/clip_grad.py
@@ -0,0 +1,72 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# pylint: disable=redefined-builtin
+from typing import Iterable, Union
+
+from ..core._imperative_rt.core2 import pop_scope, push_scope
+from ..functional import clip, concat, minimum, norm
+from ..tensor import Tensor
+
+__all__ = ["clip_grad_norm", "clip_grad_value"]
+
+
+def clip_grad_norm(
+    tensors: Union[Tensor, Iterable[Tensor]], max_norm: float, ord: float = 2.0,
+):
+    r"""Clips gradient norm of an iterable of parameters.
+    The norm is computed over all gradients together, as if they were
+    concatenated into a single vector. Gradients are modified in-place.
+
+    :param tensors: an iterable of Tensors or a single Tensor.
+    :param max_norm: max norm of the gradients.
+    :param ord: type of the used p-norm. Can be ``float("inf")`` for the infinity norm.
+    :return: total norm of the gradients (viewed as a single vector).
+    """
+    push_scope("clip_grad_norm")
+    if isinstance(tensors, Tensor):
+        tensors = [tensors]
+    tensors = [t for t in tensors if t.grad is not None]
+    if len(tensors) == 0:
+        pop_scope("clip_grad_norm")
+        return Tensor(0.0)
+    norm_ = [norm(t.grad.flatten(), ord=ord) for t in tensors]
+    if len(norm_) > 1:
+        norm_ = norm(concat(norm_), ord=ord)
+    else:
+        norm_ = norm_[0]
+    scale = max_norm / (norm_ + 1e-6)
+    scale = minimum(scale, 1)
+    for tensor in tensors:
+        tensor.grad._reset(tensor.grad * scale)
+    pop_scope("clip_grad_norm")
+    return norm_
+
+
+def clip_grad_value(
+    tensors: Union[Tensor, Iterable[Tensor]], lower: float, upper: float
+):
+    r"""Clips gradient of an iterable of parameters between the specified
+    lower and upper bounds. Gradients are modified in-place.
+
+    The gradients are clipped in the range:
+
+    .. math:: \left[\text{lower}, \text{upper}\right]
+
+    :param tensors: an iterable of Tensors or a single Tensor.
+    :param lower: minimum allowed value of the gradients.
+    :param upper: maximum allowed value of the gradients.
+    """
+    push_scope("clip_grad_value")
+    if isinstance(tensors, Tensor):
+        tensors = [tensors]
+    for tensor in tensors:
+        if tensor.grad is None:
+            continue
+        tensor.grad._reset(clip(tensor.grad, lower, upper))
+    pop_scope("clip_grad_value")
diff --git a/imperative/python/test/integration/test_converge_with_gradient_clip.py b/imperative/python/test/integration/test_converge_with_gradient_clip.py
new file mode 100644
index 000000000..1c0f3ac05
--- /dev/null
+++ b/imperative/python/test/integration/test_converge_with_gradient_clip.py
@@ -0,0 +1,120 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import itertools
+
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine.autodiff as ad
+import megengine.functional as F
+import megengine.optimizer as optim
+from megengine import Tensor
+from megengine.jit import trace
+from megengine.module import Linear, Module
+from megengine.optimizer import SGD
+
+batch_size = 64
+data_shape = (batch_size, 2)
+label_shape = (batch_size,)
+
+
+def minibatch_generator():
+    while True:
+        inp_data = np.zeros((batch_size, 2))
+        label = np.zeros(batch_size, dtype=np.int32)
+        for i in range(batch_size):
+            # [x0, x1], sampled from U[-1, 1]
+            inp_data[i, :] = np.random.rand(2) * 2 - 1
+            label[i] = 0 if np.prod(inp_data[i]) < 0 else 1
+        yield inp_data.astype(np.float32), label.astype(np.int32)
+
+
+def calculate_precision(data: np.ndarray, pred: np.ndarray) -> float:
+    """ Calculate precision for given data and prediction.
+
+    :type data: [[x, y], ...]
+    :param data: Input data
+    :type pred: [[x_pred, y_pred], ...]
+    :param pred: Network output data
+    """
+    correct = 0
+    assert len(data) == len(pred)
+    for inp_data, pred_output in zip(data, pred):
+        label = 0 if np.prod(inp_data) < 0 else 1
+        pred_label = np.argmax(pred_output)
+        if pred_label == label:
+            correct += 1
+    return float(correct) / len(data)
+
+
+class XORNet(Module):
+    def __init__(self):
+        self.mid_layers = 14
+        self.num_class = 2
+        super().__init__()
+
+        self.fc0 = Linear(self.num_class, self.mid_layers, bias=True)
+        self.fc1 = Linear(self.mid_layers, self.mid_layers, bias=True)
+
+        self.fc2 = Linear(self.mid_layers, self.num_class, bias=True)
+
+    def forward(self, x):
+        x = self.fc0(x)
+        x = F.tanh(x)
+        x = self.fc1(x)
+        x = F.tanh(x)
+        x = self.fc2(x)
+        return x
+
+
+def test_training_converge():
+    net = XORNet()
+    opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
+    gm = ad.GradManager().attach(net.parameters())
+
+    @trace(symbolic=False)
+    def train(data, label):
+        with gm:
+            pred = net(data)
+            loss = F.nn.cross_entropy(pred, label)
+            gm.backward(loss)
+        optim.clip_grad_norm(net.parameters(), max_norm=0.2, ord=2.0)
+        return loss
+
+    def infer(data):
+        return net(data)
+
+    train_dataset = minibatch_generator()
+    losses = []
+
+    for data, label in itertools.islice(train_dataset, 2000):
+        data = Tensor(data, dtype=np.float32)
+        label = Tensor(label, dtype=np.int32)
+        opt.clear_grad()
+        loss = train(data, label)
+        optim.clip_grad_value(net.parameters(), lower=-0.1, upper=0.1)
+        opt.step()
+        losses.append(loss.numpy())
+    print(np.mean(losses[-100:]))
+    assert np.mean(losses[-100:]) < 0.1, "Final training loss must be low enough"
+
+    ngrid = 10
+    x = np.linspace(-1.0, 1.0, ngrid)
+    xx, yy = np.meshgrid(x, x)
+    xx = xx.reshape((ngrid * ngrid, 1))
+    yy = yy.reshape((ngrid * ngrid, 1))
+    data = np.concatenate((xx, yy), axis=1).astype(np.float32)
+
+    pred = infer(data).numpy()
+    precision = calculate_precision(data, pred)
+    print("precision=", precision)
+    assert precision == 1.0, "Test precision must be high enough, got {}".format(
+        precision
+    )
diff --git a/imperative/python/test/unit/optimizer/test_clip_grad.py b/imperative/python/test/unit/optimizer/test_clip_grad.py
new file mode 100644
index 000000000..63dca8323
--- /dev/null
+++ b/imperative/python/test/unit/optimizer/test_clip_grad.py
@@ -0,0 +1,80 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import platform
+import weakref
+
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine.autodiff as ad
+import megengine.functional as F
+import megengine.module as M
+import megengine.optimizer as optim
+
+
+class Net(M.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = M.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = M.BatchNorm2d(64)
+        self.avgpool = M.AvgPool2d(kernel_size=5, stride=5, padding=0)
+        self.fc = M.Linear(64, 10)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = F.relu(x)
+        x = self.avgpool(x)
+        x = F.avg_pool2d(x, 22)
+        x = F.flatten(x, 1)
+        x = self.fc(x)
+        return x
+
+
+def save_grad_value(net):
+    for param in net.parameters():
+        param.grad_backup = param.grad.numpy().copy()
+
+
+def test_clip_grad_norm():
+    net = Net()
+    x = mge.tensor(np.random.randn(10, 3, 224, 224))
+    gm = ad.GradManager().attach(net.parameters())
+    opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
+    with gm:
+        loss = net(x).sum()
+        gm.backward(loss)
+    save_grad_value(net)
+    max_norm = 1.0
+    original_norm = optim.clip_grad_norm(net.parameters(), max_norm=max_norm, ord=2)
+    scale = max_norm / original_norm
+    for param in net.parameters():
+        np.testing.assert_almost_equal(param.grad.numpy(), param.grad_backup * scale)
+    opt.step().clear_grad()
+
+
+def test_clip_grad_value():
+    net = Net()
+    x = np.random.randn(10, 3, 224, 224).astype("float32")
+    gm = ad.GradManager().attach(net.parameters())
+    opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
+    with gm:
+        y = net(x)
+        y = y.mean()
+        gm.backward(y)
+    save_grad_value(net)
+    max_val = 5
+    min_val = -2
+    optim.clip_grad_value(net.parameters(), lower=min_val, upper=max_val)
+    for param in net.parameters():
+        np.testing.assert_almost_equal(
+            param.grad.numpy(),
+            np.maximum(np.minimum(param.grad_backup, max_val), min_val),
+        )
+    opt.step().clear_grad()
diff --git a/imperative/python/test/unit/optimizer/test_clip_grad_torch.py b/imperative/python/test/unit/optimizer/test_clip_grad_torch.py
new file mode 100644
index 000000000..95a3db18f
--- /dev/null
+++ b/imperative/python/test/unit/optimizer/test_clip_grad_torch.py
@@ -0,0 +1,58 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import platform
+import weakref
+
+import numpy as np
+import pytest
+import torch
+
+import megengine as mge
+import megengine.functional as F
+import megengine.module as M
+import megengine.optimizer as optim
+
+
+def make_fake_params():
+    shapes = [(1,), (3, 3), (5, 5, 5), (6, 7, 8, 9)]
+    params = [np.random.randn(*shape).astype("float32") for shape in shapes]
+    params_mge = []
+    params_torch = []
+    for param in params:
+        t = torch.ones(param.shape)
+        t.grad = torch.Tensor(param.copy())
+        params_torch.append(t)
+
+        t = mge.functional.ones(param.shape)
+        t.grad = mge.tensor(param.copy())
+        params_mge.append(t)
+    return params_mge, params_torch
+
+
+def test_clip_grad_norm_torch():
+    max_norm = 1.0
+    params_mge, params_torch = make_fake_params()
+    norm_torch = torch.nn.utils.clip_grad_norm_(params_torch, max_norm, norm_type=2.0)
+    norm_mge = optim.clip_grad_norm(params_mge, max_norm=max_norm, ord=2.0)
+    np.testing.assert_allclose(norm_mge.numpy(), norm_torch.numpy(), atol=1e-4)
+    for i in range(len(params_mge)):
+        np.testing.assert_allclose(
+            params_mge[i].grad.numpy(), params_torch[i].grad.numpy(), atol=1e-7
+        )
+
+
+def test_clip_grad_value_torch():
+    max_val = 0.5
+    min_val = -0.5
+    params_mge, params_torch = make_fake_params()
+    torch.nn.utils.clip_grad_value_(params_torch, clip_value=max_val)
+    optim.clip_grad_value(params_mge, lower=min_val, upper=max_val)
+    for i in range(len(params_mge)):
+        np.testing.assert_allclose(
+            params_mge[i].grad.numpy(), params_torch[i].grad.numpy(), atol=1e-7
+        )
-- 
GitLab
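
Usage sketch (not part of the patch): a minimal end-to-end example of the two helpers this patch adds, optim.clip_grad_norm and optim.clip_grad_value, written against the same MegEngine APIs the tests above use. The toy module, shapes, and hyper-parameters below are illustrative assumptions rather than anything taken from the patch itself.

    import numpy as np

    import megengine as mge
    import megengine.autodiff as ad
    import megengine.functional as F
    import megengine.module as M
    import megengine.optimizer as optim

    # toy setup, mirroring the structure of the tests in this patch
    net = M.Linear(4, 2)
    gm = ad.GradManager().attach(net.parameters())
    opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

    data = mge.tensor(np.random.randn(8, 4).astype("float32"))
    label = mge.tensor(np.random.randint(0, 2, size=(8,)).astype("int32"))

    with gm:
        loss = F.nn.cross_entropy(net(data), label)
        gm.backward(loss)

    # rescale all gradients together so their global 2-norm is at most 0.2;
    # the norm measured before clipping is returned
    total_norm = optim.clip_grad_norm(net.parameters(), max_norm=0.2, ord=2.0)
    # then clamp every gradient element into [-0.1, 0.1]
    # (the integration test above applies both forms in one step)
    optim.clip_grad_value(net.parameters(), lower=-0.1, upper=0.1)

    opt.step().clear_grad()

Both calls run between gm.backward() and opt.step(), the same ordering the integration test uses, because they overwrite each param.grad in place before the optimizer consumes it.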