Commit 5cef74a7 authored by Megvii Engine Team, committed by huangxinda

feat(mge/amp): add GradScaler support

GitOrigin-RevId: 0ab4910360757d783f132be7041297c846cf513f
Parent 1bf18252
......@@ -5,10 +5,10 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import mprop
from ..core.tensor.amp import *
from .autocast import autocast
from .grad_scaler import GradScaler
mprop.init()
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable, List, Union
import numpy as np
from ..autodiff import GradManager
from ..functional import full_like
from ..functional.math import _has_inf
from ..tensor import Tensor
class GradScaler:
r"""
A helper class that performs grad scaling to prevent data overflow in
:class:`~.autocast` mode.
:param init_scale: Initial scale factor.
:param growth_factor: Factor that the scale is multiplied by in the
:meth:`update` stage.
:param backoff_factor: Factor that the scale is multiplied by when an
overflow grad is encountered.
:param growth_interval: The interval between two scale update stages. If set
to 0, the scale factor will never be updated.
Example::
gm = GradManager()
opt = ...
scaler = GradScaler()
gm.attach(model.parameters())
@autocast()
def train_step(image, label):
with gm:
logits = model(image)
loss = F.nn.cross_entropy(logits, label)
scaler.backward(gm, loss)
opt.step().clear_grad()
return loss
For more flexible usage, ``scaler.backward`` can be split into three lines:
.. code-block::
@autocast()
def train_step(image, label):
with gm:
logits = model(image)
loss = F.nn.cross_entropy(logits, label)
gm.backward(loss, dy=megengine.tensor(scaler.scale_factor))
scaler.unscale(gm.attached_tensors())
scaler.update()
opt.step().clear_grad()
return loss
This is useful when grads need to be accumulated over multiple batches.
"""
def __init__(
self,
init_scale: float = 2.0 ** 4,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000,
):
self.scale_factor = float(init_scale)
self.growth_factor = float(growth_factor)
self.backoff_factor = float(backoff_factor)
self.growth_interval = growth_interval
self._growth_tracker = 0
self._found_inf = False
def backward(
self,
gm: GradManager,
y: Union[Tensor, List[Tensor]] = None,
dy: Union[Tensor, List[Tensor]] = None,
*,
unscale_grad: bool = True,
update_scale: bool = "if_unscale_grad"
):
r"""
A wrapper of GradManager's :meth:`~.GradManager.backward`, used to scale
``y``'s grad and unscale parameters' grads.
:param gm: The GradManager to be wrapped.
:param y: Same as GradManager backward's ``y``.
:param dy: Same as GradManager backward's ``dy``. Will be multiplied
by ``scale_factor``.
:param unscale_grad: Whether to call :meth:`unscale` at the same time. Could be
``False`` if grads need to be accumulated.
:param update_scale: Whether to call :meth:`update` after unscaling. Will be
ignored if ``unscale_grad`` is ``False``.
"""
# These checks should be consistent with GradManager's
if y is None:
ys = []
elif isinstance(y, (tuple, list)):
ys = y
else:
ys = [y]
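# scale dy by scale_factor (the default dy is a scale_factor-filled tensor),
# so every grad produced by the backward pass is scaled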
if dy is None:
dys = [full_like(y, self.scale_factor) for y in ys]
elif isinstance(dy, (tuple, list)):
dys = [dy_ * self.scale_factor for dy_ in dy]
else:
dys = [dy * self.scale_factor]
gm.backward(y=ys, dy=dys)
if unscale_grad:
self.unscale(gm.attached_tensors())
if update_scale:
self.update()
def unscale(self, grad_tensors: Iterable[Tensor]):
r"""
Unscale the grads of all ``grad_tensors``.
:param grad_tensors: Tensors whose grads need to be unscaled. Should be all
tensors that are affected by the ``target`` tensor in GradManager's backward.
"""
# use float64 for better precision
inv_scale = Tensor(1.0 / self.scale_factor)
for tensor in grad_tensors:
if tensor is None or getattr(tensor, "grad", None) is None:
continue
# to support tracing, _check_gradients should be applied to every grad.
if self._check_gradients(tensor.grad):
self._found_inf = True
tensor.grad *= inv_scale
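# if any grad overflowed, drop all grads so the corrupted values are never applied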
if self._found_inf:
for tensor in grad_tensors:
if tensor is None or getattr(tensor, "grad", None) is None:
continue
tensor.grad = None
return self
def _check_gradients(self, grad):
if self.growth_interval == 0:
return False
return _has_inf(grad)
def update(self, new_scale: float = None):
r"""Update the scale factor according to whether encountered overflow grad.
If ``new_scale`` is provided, internal update mechanism will be ignored."""
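# illustration with the defaults (init_scale=16, growth_factor=2,
# backoff_factor=0.5, growth_interval=2000): 2000 consecutive overflow-free
# updates double the scale (16 -> 32), while a single overflow halves it
# (16 -> 8) and resets the growth counter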
if self.growth_interval == 0:
return
if new_scale is not None:
self.scale_factor = float(new_scale)
else:
if self._found_inf:
self.scale_factor *= self.backoff_factor
self._growth_tracker = 0
else:
self._growth_tracker += 1
if self._growth_tracker >= self.growth_interval:
self.scale_factor *= self.growth_factor
self._growth_tracker = 0
self._found_inf = False
def state_dict(self):
return {
"scale_factor": self.scale_factor,
"growth_factor": self.growth_factor,
"backoff_factor": self.backoff_factor,
"growth_interval": self.growth_interval,
"_growth_tracker": self._growth_tracker,
}
def load_state_dict(self, state):
self.scale_factor = state["scale_factor"]
self.growth_factor = state["growth_factor"]
self.backoff_factor = state["backoff_factor"]
self.growth_interval = state["growth_interval"]
self._growth_tracker = state["_growth_tracker"]
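A minimal sketch of the multi-batch accumulation pattern described in the GradScaler docstring, assuming a hypothetical ``model``, ``opt`` and ``accum_steps``: grads stay scaled across micro-batches (``unscale_grad=False``) and are unscaled and applied once per accumulation window.

import megengine.functional as F
from megengine.amp import GradScaler, autocast
from megengine.autodiff import GradManager

# model and opt are assumed to be defined elsewhere (hypothetical)
gm = GradManager().attach(model.parameters())
scaler = GradScaler()
accum_steps = 4  # micro-batches per optimizer step (assumption)

@autocast()
def train_step(image, label, step):
    with gm:
        logits = model(image)
        loss = F.nn.cross_entropy(logits, label)
        # keep grads scaled while accumulating across micro-batches
        scaler.backward(gm, loss, unscale_grad=False)
    if (step + 1) % accum_steps == 0:
        # unscale accumulated grads once, adjust the scale, then step
        scaler.unscale(gm.attached_tensors())
        scaler.update()
        opt.step().clear_grad()
    return loss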
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import weakref
from typing import Callable, Iterable
from collections import OrderedDict
from typing import Callable, Iterable, List, Union
from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
from ..core.autodiff.grad import Grad
......@@ -123,6 +131,10 @@ class GradManager:
self._gradients = {}
self._priority = None
def attached_tensors(self):
r"""Return attached tensor list from :meth:`attach`."""
return [spec.tensor() for spec in self._attach_specs.values()]
def attach(self, tensors: Iterable[Tensor], callbacks=None):
r"""
Instruct GradManager to track operations on tensors, so that gradients with respect
......@@ -210,13 +222,18 @@ class GradManager:
spec.callbacks.extend(callbacks)
if new_attach and self._recording:
self._do_record(spec)
return self
def _register_after_backward_callback(self, callback):
self._after_backward_callback.append(callback)
return self
def backward(self, y=None, dy=None):
def backward(
self,
y: Union[Tensor, List[Tensor]] = None,
dy: Union[Tensor, List[Tensor]] = None,
):
r"""
Compute gradients (or vector-Jacobian product) for all attached tensors, accumulate to
corresponding .grad attribute, and release resources along the way.
......@@ -257,6 +274,7 @@ class GradManager:
"call a method that clears the history?"
)
assert self._grad is not None
# These checks should be consistent with GradScaler's
if y is None:
ys = []
elif isinstance(y, (tuple, list)):
......
......@@ -1019,7 +1019,7 @@ def batch_norm(
momentum: float = 0.9,
eps: float = 1e-5,
inplace: bool = True,
compute_mode="default",
compute_mode="default"
):
r"""
Applies batch normalization to the input.
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import megengine as mge
from megengine.amp import GradScaler
from megengine.autodiff import GradManager
def test_grad_scaler():
gm = GradManager()
scaler = GradScaler()
x = mge.tensor(1.0)
for _ in range(3):
with gm:
y = x + 1
gm.attach(y)
loss = y + 1
scaler.backward(gm, loss, unscale_grad=False)
np.testing.assert_equal(y.grad.numpy(), scaler.scale_factor)
scaler.unscale(gm.attached_tensors())
np.testing.assert_equal(y.grad.numpy(), 1)
# test that None elements (released tensors) are handled
scaler.unscale(gm.attached_tensors())
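A minimal checkpointing sketch for the ``state_dict``/``load_state_dict`` pair defined above; the file name and the use of ``pickle`` are assumptions for illustration, not part of this commit's tests.

import pickle
from megengine.amp import GradScaler

scaler = GradScaler()
# "scaler_ckpt.pkl" is an illustrative path (assumption)
with open("scaler_ckpt.pkl", "wb") as f:
    pickle.dump(scaler.state_dict(), f)

# restore the scaler state when resuming training
resumed = GradScaler()
with open("scaler_ckpt.pkl", "rb") as f:
    resumed.load_state_dict(pickle.load(f))
assert resumed.scale_factor == scaler.scale_factor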
......@@ -49,6 +49,32 @@ def test_basic():
np.testing.assert_equal(b.grad.numpy(), [1])
def test_dy():
x = mge.tensor([1.0, 3.0, 5.0]).reshape(1, 3)
w = mge.tensor([2.0, 4.0, 6.0]).reshape(3, 1)
b = mge.tensor(-1.0)
gm = GradManager().attach([w, b])
def get_grad(grad, dy, idx):
if isinstance(dy, (list, tuple)):
return np.array(grad) * dy[idx]
else:
return np.array(grad) * dy
# dy's shape should be the same as y's
dy = mge.tensor(2.5).reshape(1, 1)
w.grad = None
b.grad = None
with gm:
p = F.matmul(x, w)
y = p + b
gm.backward(y, dy=dy)
np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]] * dy.numpy())
np.testing.assert_equal(b.grad.numpy(), [1] * dy.numpy())
def test_attach_in_with_block():
a = mge.Parameter([1.0])
gm = GradManager()
......@@ -93,6 +119,25 @@ def test_attach_temporary():
# gm.backward(y)
def test_attached_tensors():
w1 = mge.Parameter(2.0)
w2 = mge.Parameter(2.0)
gm = GradManager()
def check(expected):
actual = gm.attached_tensors()
assert len(expected) == len(actual)
for exp, act in zip(expected, actual):
assert exp is act
gm.attach(w1)
check([w1])
gm.attach(w2)
check([w1, w2])
gm.attach(w1)
check([w1, w2])
def test_no_dependency():
x = mge.tensor(3)
......