From e262125d19cb327a9068b6ca08998e34adbb4c5c Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Sun, 26 Sep 2021 13:04:05 +0800 Subject: [PATCH] [cherry pick]split minimize and add unscale_ for GradScaler (#35927) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1、Split function GradScaler::minimize() to GradScaler::step() + GradScaler::update() 2、Add GradScaler::unscale_(optimizer) --- python/paddle/amp/grad_scaler.py | 99 +++++++++++++-- .../fleet/meta_parallel/pipeline_parallel.py | 1 + .../paddle/fluid/dygraph/amp/loss_scaler.py | 114 ++++++++++++------ .../tests/unittests/hybrid_parallel_mp_amp.py | 1 + .../test_imperative_auto_mixed_precision.py | 90 +++++++++++++- 5 files changed, 260 insertions(+), 45 deletions(-) diff --git a/python/paddle/amp/grad_scaler.py b/python/paddle/amp/grad_scaler.py index 5c3b575f2f..83f57fc74e 100644 --- a/python/paddle/amp/grad_scaler.py +++ b/python/paddle/amp/grad_scaler.py @@ -13,18 +13,28 @@ # limitations under the License. from paddle.fluid.dygraph.amp import AmpScaler +from paddle.fluid.dygraph.amp import OptimizerState +from collections import defaultdict __all__ = [] +def _refresh_optimizer_state(): + return {"state": OptimizerState.INIT} + + class GradScaler(AmpScaler): """ GradScaler is used for Auto-Mixed-Precision training in dynamic graph mode. It controls the scaling of loss, helps avoiding numerical overflow. - The object of this class has two methods `scale()`, `minimize()`. + The object of this class has nineteen methods `scale()`, `unscale_()`, `minimize()`, `step()`, `update()` and `get`/`set` api of parameters. `scale()` is used to multiply the loss by a scale ratio. - `minimize()` is similar as `optimizer.minimize()`, performs parameters updating. + `unscale_()` is used to unscale the gradients of parameters, multiplies the gradients of parameters by 1/(scale ratio) + `minimize()` is similar as `optimizer.minimize()`, performs parameters updating, and it will update the loss_scaling, it equal to `step()` + `update()`. + `step()` is similar as `optimizer.step()`, which performs parameters updating. + `update` is used to update the loss_scaling. + Commonly, it is used together with `paddle.amp.auto_cast` to achieve Auto-Mixed-Precision in dynamic graph mode. @@ -115,7 +125,7 @@ class GradScaler(AmpScaler): This function is similar as `optimizer.minimize()`, which performs parameters updating. If the scaled gradients of parameters contains NAN or INF, the parameters updating is skipped. - Otherwise, it first unscales the scaled gradients of parameters, then updates the parameters. + Otherwise, if `unscale_()` has not been called, it first unscales the scaled gradients of parameters, then updates the parameters. Finally, the loss scaling ratio is updated. @@ -151,16 +161,18 @@ class GradScaler(AmpScaler): This function is similar as `optimizer.step()`, which performs parameters updating. If the scaled gradients of parameters contains NAN or INF, the parameters updating is skipped. - Otherwise, it first unscales the scaled gradients of parameters, then updates the parameters. + Otherwise, if `unscale_()` has not been called, it first unscales the scaled gradients of parameters, then updates the parameters. Args: optimizer(Optimizer): The optimizer used to update parameters. Examples: + .. code-block:: python # required: gpu import paddle + model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True) optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters()) scaler = paddle.amp.GradScaler(init_loss_scaling=1024) @@ -170,14 +182,21 @@ class GradScaler(AmpScaler): loss = paddle.mean(conv) scaled = scaler.scale(loss) # scale the loss scaled.backward() # do backward - scaler.step(optimizer) + scaler.step(optimizer) # update parameters + scaler.update() # update the loss scaling ratio optimizer.clear_grad() """ if not self._enable: return optimizer.step() + optimizer_state = self._optimizer_states[id(optimizer)] + if optimizer_state["state"] is OptimizerState.STEPPED: + raise RuntimeError( + "step() has already been called since the last update().") + # unscale the grad - self._unscale(optimizer) + if optimizer_state["state"] is OptimizerState.INIT: + self._unscale(optimizer) if self._found_inf: self._cache_founf_inf = True @@ -185,9 +204,75 @@ class GradScaler(AmpScaler): optimizer.step() self._cache_founf_inf = False + optimizer_state["state"] = OptimizerState.STEPPED + + if not self._use_dynamic_loss_scaling: + self._optimizer_states = defaultdict(_refresh_optimizer_state) + + def update(self): + """ + Updates the loss_scaling. + + Examples: + + .. code-block:: python + + # required: gpu + import paddle + + model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True) + optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters()) + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + data = paddle.rand([10, 3, 32, 32]) + with paddle.amp.auto_cast(): + conv = model(data) + loss = paddle.mean(conv) + scaled = scaler.scale(loss) # scale the loss + scaled.backward() # do backward + scaler.step(optimizer) # update parameters + scaler.update() # update the loss scaling ratio + optimizer.clear_grad() + """ + if not self._enable: + return if self._use_dynamic_loss_scaling: - # uopdate the scale self._update() + self._optimizer_states = defaultdict(_refresh_optimizer_state) + return + + def unscale_(self, optimizer): + """ + Unscale the gradients of parameters, multiplies the gradients of parameters by 1/(loss scaling ratio). + If this instance of :class:`GradScaler` is not enabled, output are returned unmodified. + + Args: + optimizer(Optimizer): The optimizer used to update parameters. + + Returns: + The unscaled parameters or original parameters. + + Examples: + + .. code-block:: python + + # required: gpu + import paddle + + model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True) + optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters()) + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + data = paddle.rand([10, 3, 32, 32]) + with paddle.amp.auto_cast(): + conv = model(data) + loss = paddle.mean(conv) + scaled = scaler.scale(loss) # scale the loss + scaled.backward() # do backward + scaler.unscale_(optimizer) # unscale the parameter + scaler.step(optimizer) + scaler.update() + optimizer.clear_grad() + """ + return super(GradScaler, self)._unscale(optimizer) def is_enable(self): """ diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 8fad0686dd..431bc6d7bc 100755 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -329,6 +329,7 @@ class PipelineParallel(MetaParallelBase): def _optimizer_step(self): if self.scaler: self.scaler.step(self.optimizer) + self.scaler.update() else: self.optimizer.step() diff --git a/python/paddle/fluid/dygraph/amp/loss_scaler.py b/python/paddle/fluid/dygraph/amp/loss_scaler.py index 38881e43c0..432b178ea6 100644 --- a/python/paddle/fluid/dygraph/amp/loss_scaler.py +++ b/python/paddle/fluid/dygraph/amp/loss_scaler.py @@ -21,8 +21,20 @@ from ...wrapped_decorator import signature_safe_contextmanager, wrap_decorator import warnings import numpy as np from paddle import _C_ops +from collections import defaultdict +from enum import Enum -__all__ = ['AmpScaler'] +__all__ = ['AmpScaler', 'OptimizerState'] + + +class OptimizerState(Enum): + INIT = 0 + UNSCALED = 1 + STEPPED = 2 + + +def _refresh_optimizer_state(): + return {"state": OptimizerState.INIT} class AmpScaler(object): @@ -31,10 +43,11 @@ class AmpScaler(object): AmpScaler is used for Auto-Mixed-Precision training/inferring in imperative mode. It controls the scaling of loss, helps avoiding numerical overflow. - The object of this class has two methods `scale()`, `minimize()`. + The object of this class has seventeen methods `scale()`, `unscale_()`, `minimize()` and `get`/`set` api of parameters. `scale()` is used to multiply the loss by a scale ratio. - `minimize()` is similar as `Optimizer.minimize()`, performs parameters updating. + `unscale_()` is used to unscale the gradients of parameters, multiplies the gradients of parameters by 1/(scale ratio) + `minimize()` is similar as `optimizer.minimize()`, performs parameters updating, and it will update the loss_scaling. Commonly, it is used together with `amp_guard` to achieve Auto-Mixed-Precision in imperative mode. @@ -117,6 +130,7 @@ class AmpScaler(object): self._scale = to_variable( np.array([self._init_loss_scaling]).astype(np.float32)) self._cache_founf_inf = None + self._optimizer_states = defaultdict(_refresh_optimizer_state) def scale(self, var): """ @@ -129,24 +143,25 @@ class AmpScaler(object): The scaled variable or original variable. Examples: + .. code-block:: python - import numpy as np - import paddle.fluid as fluid - - data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') - with fluid.dygraph.guard(): - model = fluid.dygraph.Conv2D(3, 2, 3) - optimizer = fluid.optimizer.SGDOptimizer( - learning_rate=0.01, parameter_list=model.parameters()) - scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024) - data = fluid.dygraph.to_variable(data) - with fluid.dygraph.amp_guard(): - conv = model(data) - loss = fluid.layers.reduce_mean(conv) - scaled = scaler.scale(loss) - scaled.backward() - scaler.minimize(optimizer, scaled) + import numpy as np + import paddle.fluid as fluid + + data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') + with fluid.dygraph.guard(): + model = fluid.dygraph.Conv2D(3, 2, 3) + optimizer = fluid.optimizer.SGDOptimizer( + learning_rate=0.01, parameter_list=model.parameters()) + scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024) + data = fluid.dygraph.to_variable(data) + with fluid.dygraph.amp_guard(): + conv = model(data) + loss = fluid.layers.reduce_mean(conv) + scaled = scaler.scale(loss) + scaled.backward() + scaler.minimize(optimizer, scaled) """ check_type(var, "var", core.VarBase, 'AmpScaler.scale()') @@ -160,7 +175,7 @@ class AmpScaler(object): This function is similar as `Optimizer.minimize()`, which performs parameters updating. If the scaled gradients of parameters contains NAN or INF, the parameters updating is skipped. - Otherwise, it first unscales the scaled gradients of parameters, then updates the parameters. + Otherwise, if `unscale_()` has not been called, it first unscales the scaled gradients of parameters, then updates the parameters. Finally, the loss scaling ratio is updated. @@ -170,30 +185,34 @@ class AmpScaler(object): kwargs: Keyword arguments, which will be forward to `Optimizer.minimize()`. Examples: + .. code-block:: python - import numpy as np - import paddle.fluid as fluid - - data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') - with fluid.dygraph.guard(): - model = fluid.dygraph.Conv2D(3, 2, 3) - optimizer = fluid.optimizer.SGDOptimizer( - learning_rate=0.01, parameter_list=model.parameters()) - scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024) - data = fluid.dygraph.to_variable(data) - with fluid.dygraph.amp_guard(): - conv = model(data) - loss = fluid.layers.reduce_mean(conv) - scaled = scaler.scale(loss) - scaled.backward() - scaler.minimize(optimizer, scaled) + import numpy as np + import paddle.fluid as fluid + + data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') + with fluid.dygraph.guard(): + model = fluid.dygraph.Conv2D(3, 2, 3) + optimizer = fluid.optimizer.SGDOptimizer( + learning_rate=0.01, parameter_list=model.parameters()) + scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024) + data = fluid.dygraph.to_variable(data) + with fluid.dygraph.amp_guard(): + conv = model(data) + loss = fluid.layers.reduce_mean(conv) + scaled = scaler.scale(loss) + scaled.backward() + scaler.minimize(optimizer, scaled) """ if not self._enable: return optimizer.minimize(*args, **kwargs) + optimizer_state = self._optimizer_states[id(optimizer)] + # unscale the grad - self._unscale(optimizer) + if optimizer_state["state"] is OptimizerState.INIT: + self._unscale(optimizer) optimize_ops, params_grads = (None, None) @@ -207,12 +226,31 @@ class AmpScaler(object): # uopdate the scale self._update() + self._optimizer_states = defaultdict(_refresh_optimizer_state) + return optimize_ops, params_grads def _unscale(self, optimizer): + """ + Unscale the gradients of parameters, multiplies the gradients of parameters by 1/(loss scaling ratio). + If this instance of :class:`GradScaler` is not enabled, output are returned unmodified. + Args: + optimizer(Optimizer): The optimizer used to update parameters. + Returns: + The unscaled parameters or original parameters. + """ if not self._enable: return + optimizer_state = self._optimizer_states[id(optimizer)] + + if optimizer_state["state"] is OptimizerState.UNSCALED: + raise RuntimeError( + "unscale_() has already been called on this optimizer since the last update()." + ) + elif optimizer_state["state"] is OptimizerState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + if getattr(optimizer, '_param_groups', None) and isinstance( optimizer._param_groups[0], dict): param_grads = [] @@ -256,6 +294,8 @@ class AmpScaler(object): temp_found_inf_fp32) self._found_inf = temp_found_inf_fp16 or temp_found_inf_fp32 + optimizer_state["state"] = OptimizerState.UNSCALED + def _update(self): """ Updates the loss_scaling. diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_amp.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_amp.py index 083ad31930..4c966585d5 100644 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_amp.py +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_amp.py @@ -48,6 +48,7 @@ class TestMPClipGrad(TestDistMPTraning): scaled.backward() # do backward scaler.step(optimizer) # update parameters + scaler.update() optimizer.clear_grad() return scaled diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py index ed98195363..5f1f4a4641 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py @@ -222,6 +222,47 @@ class TestAmpScaler(unittest.TestCase): np.allclose(outs_with_scaler[1][i][0].numpy(), outs_no_scaler[1][i][0].numpy()), True) + def test_step(self): + inp_np = np.random.random(size=[1, 3, 128, 128]).astype(np.float32) + + def run_simple_conv(inp_np, use_scaler=True): + paddle.seed(10) + paddle.framework.random._manual_program_seed(10) + with fluid.dygraph.guard(): + model = SimpleConv( + num_channels=3, + num_filters=64, + filter_size=7, + stride=2, + act='relu') + optimizer = paddle.optimizer.SGD(learning_rate=0.01, + parameters=model.parameters()) + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + data = fluid.dygraph.to_variable(inp_np) + + out = model(data) + loss = fluid.layers.mean(out) + if use_scaler: + print('use scaler') + scaled_loss = scaler.scale(loss) + scaled_loss.backward() + scaler.step(optimizer) + scaler.update() + else: + print('use no scaler') + loss.backward() + optimizer.step() + return optimizer._parameter_list + + outs_with_scaler = run_simple_conv(inp_np, use_scaler=True) + outs_no_scaler = run_simple_conv(inp_np, use_scaler=False) + + for i in range(len(outs_with_scaler)): + # check each parameter + self.assertEqual( + np.allclose(outs_with_scaler[i].numpy(), + outs_no_scaler[i].numpy()), True) + def test_nan_inf(self): inp_np = np.random.random(size=[1, 3, 128, 128]).astype(np.float32) inp_np[0][1][2][3] = np.nan @@ -252,6 +293,52 @@ class TestAmpScaler(unittest.TestCase): self.assertTrue( np.array_equal(param.numpy(), params_init[param.name])) + def test_step_update_exception(self): + def func1(): + model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True) + optimizer = paddle.optimizer.SGD(learning_rate=0.01, + parameters=model.parameters()) + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + data = paddle.rand([10, 3, 32, 32]) + conv = model(data) + loss = paddle.mean(conv) + scaled = scaler.scale(loss) + scaled.backward() + scaler.unscale_(optimizer) + scaler.unscale_(optimizer) + + self.assertRaises(RuntimeError, func1) + + def func2(): + model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True) + optimizer = paddle.optimizer.SGD(learning_rate=0.01, + parameters=model.parameters()) + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + data = paddle.rand([10, 3, 32, 32]) + conv = model(data) + loss = paddle.mean(conv) + scaled = scaler.scale(loss) + scaled.backward() + scaler.step(optimizer) + scaler.unscale_(optimizer) + + self.assertRaises(RuntimeError, func2) + + def func3(): + model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True) + optimizer = paddle.optimizer.SGD(learning_rate=0.01, + parameters=model.parameters()) + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + data = paddle.rand([10, 3, 32, 32]) + conv = model(data) + loss = paddle.mean(conv) + scaled = scaler.scale(loss) + scaled.backward() + scaler.step(optimizer) + scaler.step(optimizer) + + self.assertRaises(RuntimeError, func3) + def test_get_and_set(self): with fluid.dygraph.guard(): scaler = paddle.amp.GradScaler( @@ -838,8 +925,9 @@ class TestResnet2(unittest.TestCase): scaled_loss = scaler.scale(avg_loss) scaled_loss.backward() - + scaler.unscale_(optimizer) scaler.step(optimizer) + scaler.update() dy_grad_value = {} for param in resnet.parameters(): -- GitLab