Commit ac56ff7e authored by Megvii Engine Team

feat(imperative/autodiff): add grad clip

GitOrigin-RevId: f344f4a2330ca2f560f45241d5aebcbab0f959b6
Parent 4cc585b7
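This commit adds two gradient-clipping utilities, ``clip_grad_norm`` and ``clip_grad_value``, exported through ``megengine.optimizer``. A minimal usage sketch (assuming a module ``net``, a ``GradManager`` ``gm``, and an optimizer ``opt``, as in the tests below):

    gm.backward(loss)
    # scale all gradients together so their total 2-norm is at most 1.0
    optim.clip_grad_norm(net.parameters(), max_norm=1.0, ord=2.0)
    # or clamp each gradient elementwise into [-0.1, 0.1]
    optim.clip_grad_value(net.parameters(), lower=-0.1, upper=0.1)
    opt.step().clear_grad()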
@@ -10,6 +10,7 @@ from .adadelta import Adadelta
 from .adagrad import Adagrad
 from .adam import Adam
 from .adamw import AdamW
+from .clip_grad import *
 from .lr_scheduler import LRScheduler
 from .multi_step_lr import MultiStepLR
 from .optimizer import Optimizer
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=redefined-builtin
from typing import Iterable, Union

from ..core._imperative_rt.core2 import pop_scope, push_scope
from ..functional import clip, concat, minimum, norm
from ..tensor import Tensor

__all__ = ["clip_grad_norm", "clip_grad_value"]


def clip_grad_norm(
    tensors: Union[Tensor, Iterable[Tensor]], max_norm: float, ord: float = 2.0,
):
    r"""Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    :param tensors: an iterable of Tensors or a single Tensor.
    :param max_norm: max norm of the gradients.
    :param ord: type of the used p-norm. Can be ``'inf'`` for infinity norm.
    :return: total norm of the parameters (viewed as a single vector).
    """
    push_scope("clip_grad_norm")
    if isinstance(tensors, Tensor):
        tensors = [tensors]
    tensors = [t for t in tensors if t.grad is not None]
    if len(tensors) == 0:
        pop_scope("clip_grad_norm")
        return Tensor(0.0)
    norm_ = [norm(t.grad.flatten(), ord=ord) for t in tensors]
    if len(norm_) > 1:
        norm_ = norm(concat(norm_), ord=ord)
    else:
        norm_ = norm_[0]
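    # If the total norm exceeds max_norm, every gradient is scaled by
    # max_norm / total_norm (the 1e-6 term guards against division by zero);
    # otherwise the scale is capped at 1 and the gradients are left unchanged.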
    scale = max_norm / (norm_ + 1e-6)
    scale = minimum(scale, 1)
    for tensor in tensors:
        tensor.grad._reset(tensor.grad * scale)
    pop_scope("clip_grad_norm")
    return norm_


def clip_grad_value(
    tensors: Union[Tensor, Iterable[Tensor]], lower: float, upper: float
):
    r"""Clips gradient of an iterable of parameters to a specified lower and
    upper. Gradients are modified in-place.

    The gradients are clipped in the range:

    .. math:: \left[\text{lower}, \text{upper}\right]

    :param tensors: an iterable of Tensors or a single Tensor.
    :param lower: minimum allowed value of the gradients.
    :param upper: maximum allowed value of the gradients.
    """
    push_scope("clip_grad_value")
    if isinstance(tensors, Tensor):
        tensors = [tensors]
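    # Clamp each gradient elementwise into [lower, upper], in place; parameters
    # that have no gradient are skipped.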
    for tensor in tensors:
        if tensor.grad is None:
            continue
        tensor.grad._reset(clip(tensor.grad, lower, upper))
    pop_scope("clip_grad_value")
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import itertools

import numpy as np
import pytest

import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
import megengine.optimizer as optim
from megengine import Tensor
from megengine.jit import trace
from megengine.module import Linear, Module
from megengine.optimizer import SGD

batch_size = 64
data_shape = (batch_size, 2)
label_shape = (batch_size,)


def minibatch_generator():
    while True:
        inp_data = np.zeros((batch_size, 2))
        label = np.zeros(batch_size, dtype=np.int32)
        for i in range(batch_size):
            # [x0, x1], sampled from U[-1, 1]
            inp_data[i, :] = np.random.rand(2) * 2 - 1
            label[i] = 0 if np.prod(inp_data[i]) < 0 else 1
        yield inp_data.astype(np.float32), label.astype(np.int32)


def calculate_precision(data: np.ndarray, pred: np.ndarray) -> float:
    """Calculate precision for given data and prediction.

    :type data: [[x, y], ...]
    :param data: input data
    :type pred: [[x_pred, y_pred], ...]
    :param pred: network output data
    :return: fraction of correctly classified samples
    """
    correct = 0
    assert len(data) == len(pred)
    for inp_data, pred_output in zip(data, pred):
        label = 0 if np.prod(inp_data) < 0 else 1
        pred_label = np.argmax(pred_output)
        if pred_label == label:
            correct += 1
    return float(correct) / len(data)


class XORNet(Module):
    def __init__(self):
        self.mid_layers = 14
        self.num_class = 2
        super().__init__()

        self.fc0 = Linear(self.num_class, self.mid_layers, bias=True)
        self.fc1 = Linear(self.mid_layers, self.mid_layers, bias=True)
        self.fc2 = Linear(self.mid_layers, self.num_class, bias=True)

    def forward(self, x):
        x = self.fc0(x)
        x = F.tanh(x)
        x = self.fc1(x)
        x = F.tanh(x)
        x = self.fc2(x)
        return x


def test_training_converge():
    net = XORNet()
    opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    gm = ad.GradManager().attach(net.parameters())

    @trace(symbolic=False)
    def train(data, label):
        with gm:
            pred = net(data)
            loss = F.nn.cross_entropy(pred, label)
            gm.backward(loss)
        optim.clip_grad_norm(net.parameters(), max_norm=0.2, ord=2.0)
        return loss

    def infer(data):
        return net(data)

    train_dataset = minibatch_generator()
    losses = []

    for data, label in itertools.islice(train_dataset, 2000):
        data = Tensor(data, dtype=np.float32)
        label = Tensor(label, dtype=np.int32)
        opt.clear_grad()
        loss = train(data, label)
        optim.clip_grad_value(net.parameters(), lower=-0.1, upper=0.1)
        opt.step()
        losses.append(loss.numpy())

    print(np.mean(losses[-100:]))
    assert np.mean(losses[-100:]) < 0.1, "Final training loss must be low enough"

    ngrid = 10
    x = np.linspace(-1.0, 1.0, ngrid)
    xx, yy = np.meshgrid(x, x)
    xx = xx.reshape((ngrid * ngrid, 1))
    yy = yy.reshape((ngrid * ngrid, 1))
    data = np.concatenate((xx, yy), axis=1).astype(np.float32)
    pred = infer(Tensor(data)).numpy()
    precision = calculate_precision(data, pred)
    print("precision=", precision)
    assert precision == 1.0, "Test precision must be high enough, got {}".format(
        precision
    )

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import platform
import weakref

import numpy as np
import pytest

import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim


class Net(M.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = M.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = M.BatchNorm2d(64)
        self.avgpool = M.AvgPool2d(kernel_size=5, stride=5, padding=0)
        self.fc = M.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.avgpool(x)
        x = F.avg_pool2d(x, 22)
        x = F.flatten(x, 1)
        x = self.fc(x)
        return x


def save_grad_value(net):
    for param in net.parameters():
        param.grad_backup = param.grad.numpy().copy()


def test_clip_grad_norm():
    net = Net()
    x = mge.tensor(np.random.randn(10, 3, 224, 224))
    gm = ad.GradManager().attach(net.parameters())
    opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
    with gm:
        loss = net(x).sum()
        gm.backward(loss)
    save_grad_value(net)
    max_norm = 1.0
    original_norm = optim.clip_grad_norm(net.parameters(), max_norm=max_norm, ord=2)
    scale = max_norm / original_norm
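    # the grads were saved before clipping, so each clipped grad should equal
    # its saved copy times max_norm / original_norm (assuming the original
    # norm exceeds max_norm, which holds for this random input)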
    for param in net.parameters():
        np.testing.assert_almost_equal(param.grad.numpy(), param.grad_backup * scale)
    opt.step().clear_grad()


def test_clip_grad_value():
    net = Net()
    # wrap the input in a Tensor; the module expects a Tensor, not a raw ndarray
    x = mge.tensor(np.random.randn(10, 3, 224, 224).astype("float32"))
    gm = ad.GradManager().attach(net.parameters())
    opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
    with gm:
        y = net(x)
        y = y.mean()
        gm.backward(y)
    save_grad_value(net)
    max_val = 5
    min_val = -2
    optim.clip_grad_value(net.parameters(), lower=min_val, upper=max_val)
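    # clipping should match an elementwise clamp of the saved gradients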
    for param in net.parameters():
        np.testing.assert_almost_equal(
            param.grad.numpy(),
            np.maximum(np.minimum(param.grad_backup, max_val), min_val),
        )
    opt.step().clear_grad()

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import platform
import weakref

import numpy as np
import pytest
import torch

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim


def make_fake_params():
    shapes = [(1,), (3, 3), (5, 5, 5), (6, 7, 8, 9)]
    params = [np.random.randn(*shape).astype("float32") for shape in shapes]
    params_mge = []
    params_torch = []
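    # build torch/mge tensor pairs whose .grad fields hold identical values, so
    # the two clipping implementations can be compared elementwise afterwards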
    for param in params:
        t = torch.ones(param.shape)
        t.grad = torch.Tensor(param.copy())
        params_torch.append(t)
        t = mge.functional.ones(param.shape)
        t.grad = mge.tensor(param.copy())
        params_mge.append(t)
    return params_mge, params_torch


def test_clip_grad_norm_torch():
    max_norm = 1.0
    params_mge, params_torch = make_fake_params()
    norm_torch = torch.nn.utils.clip_grad_norm_(params_torch, max_norm, norm_type=2.0)
    norm_mge = optim.clip_grad_norm(params_mge, max_norm=max_norm, ord=2.0)
    np.testing.assert_allclose(norm_mge.numpy(), norm_torch.numpy(), atol=1e-4)
    for i in range(len(params_mge)):
        np.testing.assert_allclose(
            params_mge[i].grad.numpy(), params_torch[i].grad.numpy(), atol=1e-7
        )


def test_clip_grad_value_torch():
    max_val = 0.5
    min_val = -0.5
    params_mge, params_torch = make_fake_params()
    torch.nn.utils.clip_grad_value_(params_torch, clip_value=max_val)
    optim.clip_grad_value(params_mge, lower=min_val, upper=max_val)
    for i in range(len(params_mge)):
        np.testing.assert_allclose(
            params_mge[i].grad.numpy(), params_torch[i].grad.numpy(), atol=1e-7
        )