From 9f7c66b43987ffcc13fbe42b096d9148572ce165 Mon Sep 17 00:00:00 2001 From: huangxu96 <46740794+huangxu96@users.noreply.github.com> Date: Fri, 8 Jan 2021 17:58:54 +0800 Subject: [PATCH] [Cherry-pick] amp related PR cherry pick into Release/2.0 (#30212) * Optimizer trans momentum (#29597) * merge amp related function in Momentum from paddle.fluid.contrib.optimizer into paddle.optimizer. * Add unittest for 2.0 Momentum API. * fix some bugs in weight_decay. * add alias for fluid.contrib.mixed_precision (#29562) * add alias for fluid.contrib.mixed_precision * add static.amp into setup.pu.in (#29621) * add static.amp into setup.pu.in * add unittest for api * fix a bug in multi_precision_fp16 unittest. (#29756) --- .../fluid/contrib/mixed_precision/__init__.py | 7 +- .../contrib/mixed_precision/fp16_utils.py | 3 + .../tests/test_image_classification_fp16.py | 5 +- .../tests/test_multi_precision_fp16_train.py | 13 ++- python/paddle/optimizer/momentum.py | 105 ++++++++++++++++-- python/paddle/static/__init__.py | 1 + python/paddle/static/amp/__init__.py | 18 +++ python/setup.py.in | 1 + 8 files changed, 137 insertions(+), 16 deletions(-) create mode 100644 python/paddle/static/amp/__init__.py diff --git a/python/paddle/fluid/contrib/mixed_precision/__init__.py b/python/paddle/fluid/contrib/mixed_precision/__init__.py index c6296bcac93..a580ae5574c 100644 --- a/python/paddle/fluid/contrib/mixed_precision/__init__.py +++ b/python/paddle/fluid/contrib/mixed_precision/__init__.py @@ -13,9 +13,14 @@ # limitations under the License. from __future__ import print_function + from . import decorator from .decorator import * -from .fp16_lists import AutoMixedPrecisionLists +from . import fp16_lists +from .fp16_lists import * +from . import fp16_utils +from .fp16_utils import * __all__ = decorator.__all__ __all__ += fp16_lists.__all__ +__all__ += fp16_utils.__all__ diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index 2f2f476a875..c9a070a03a4 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -20,6 +20,9 @@ from ... import global_scope from ...log_helper import get_logger import logging import numpy as np + +__all__ = ["cast_model_to_fp16", "cast_parameters_to_fp16"] + _logger = get_logger( __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') diff --git a/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py b/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py index b29cd265bd6..0280dfcf67b 100644 --- a/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py +++ b/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py @@ -24,6 +24,7 @@ import unittest import os import copy import numpy as np +from paddle.static.amp import decorate paddle.enable_static() @@ -138,7 +139,7 @@ def train(net_type, use_cuda, save_dirname, is_local): amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( custom_black_varnames={"loss", "conv2d_0.w_0"}) - mp_optimizer = fluid.contrib.mixed_precision.decorate( + mp_optimizer = decorate( optimizer=optimizer, amp_lists=amp_lists, init_loss_scaling=8.0, @@ -442,7 +443,7 @@ class TestAmpWithNonIterableDataLoader(unittest.TestCase): optimizer = fluid.optimizer.Lamb(learning_rate=0.001) amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists( custom_black_varnames={"loss", "conv2d_0.w_0"}) - mp_optimizer = fluid.contrib.mixed_precision.decorate( + mp_optimizer = decorate( optimizer=optimizer, amp_lists=amp_lists, init_loss_scaling=8.0, diff --git a/python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py b/python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py index 64ef2e26bbd..3526a3d761c 100644 --- a/python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py +++ b/python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py @@ -19,8 +19,8 @@ import paddle.fluid as fluid import contextlib import unittest import numpy as np -from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16 -from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16 +from paddle.static.amp import cast_model_to_fp16 +from paddle.static.amp import cast_parameters_to_fp16 paddle.enable_static() @@ -122,11 +122,11 @@ def train(use_pure_fp16=True, use_nesterov=False): # Test program test_program = train_program.clone(for_test=True) - optimizer = fluid.contrib.optimizer.Momentum( + optimizer = paddle.optimizer.Momentum( learning_rate=0.001, momentum=0.9, use_nesterov=use_nesterov, - regularization=fluid.regularizer.L2Decay(1e-4), + weight_decay=fluid.regularizer.L2Decay(1e-4), multi_precision=use_pure_fp16, rescale_grad=1.0 / BATCH_SIZE) @@ -155,9 +155,10 @@ def train(use_pure_fp16=True, use_nesterov=False): loss, = exe.run(compiled_program, feed=feeder.feed(data), fetch_list=[sum_cost]) + loss_v = loss[0] if isinstance(loss, np.ndarray) else loss print('PassID {0:1}, Train Batch ID {1:04}, train loss {2:2.4}'. - format(pass_id, batch_id + 1, float(loss))) - train_loss_list.append(float(loss)) + format(pass_id, batch_id + 1, float(loss_v))) + train_loss_list.append(float(loss_v)) if batch_id >= 4: # For speeding up CI test_loss_list = [] diff --git a/python/paddle/optimizer/momentum.py b/python/paddle/optimizer/momentum.py index 601fdce7a34..9f367f1c9fc 100644 --- a/python/paddle/optimizer/momentum.py +++ b/python/paddle/optimizer/momentum.py @@ -17,8 +17,10 @@ from ..fluid import core from ..fluid import framework from ..fluid.framework import Variable, name_scope from ..fluid.layer_helper import LayerHelper +from ..fluid import unique_name +from ..fluid import layers import paddle.fluid as fluid - +from paddle.fluid.regularizer import L2DecayRegularizer __all__ = ["Momentum"] @@ -62,6 +64,9 @@ class Momentum(Optimizer): some derived class of ``GradientClipBase`` . There are three cliping strategies ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping. + multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false. + rescale_grad (float, optional): Multiply the gradient with `rescale_grad` before updating. \ + Often choose to be ``1.0/batch_size``. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . @@ -93,20 +98,33 @@ class Momentum(Optimizer): use_nesterov=False, weight_decay=None, grad_clip=None, + multi_precision=False, + rescale_grad=1.0, name=None): if learning_rate is None: raise ValueError("learning_rate is not set") if momentum is None: raise ValueError("momentum is not set") + predicate = lambda regular: isinstance(regular, L2DecayRegularizer) + py_regular = None if predicate(weight_decay) else weight_decay super(Momentum, self).__init__( learning_rate=learning_rate, parameters=parameters, - weight_decay=weight_decay, + weight_decay=py_regular, grad_clip=grad_clip, name=name) self.type = "momentum" self._momentum = momentum self._use_nesterov = bool(use_nesterov) + self._regularization_method = "" + self._regularization_coeff = 0 + if (isinstance(weight_decay, L2DecayRegularizer)): + self._regularization_method = "l2_decay" + self._regularization_coeff = weight_decay._regularization_coeff + self._multi_precision = multi_precision + self._rescale_grad = rescale_grad + self._master_weights = {} + if framework.in_dygraph_mode(): self.helper = LayerHelper(self.__class__.__name__) for p in parameters: @@ -116,8 +134,62 @@ class Momentum(Optimizer): ).all_parameters() self.helper = LayerHelper(self.__class__.__name__) for p in all_parameters: + if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + master_p = self._create_master_weight(p) + self._add_accumulator(self._velocity_acc_str, master_p) + continue + if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision: + warnings.warn( + "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." + "Consider using multi_precision=True option of the Momentum optimizer." + ) self._add_accumulator(self._velocity_acc_str, p) + def _create_master_weight(self, param): + assert isinstance(self.helper, LayerHelper) + + var_name = param.name + "_fp32_master" + var_name = unique_name.generate(var_name) + var = layers.create_global_var( + name=var_name, + shape=param.shape, + value=0, + dtype='float32', + persistable=True) + block = self.helper.startup_program.global_block() + block.append_op( + type="cast", + inputs={"X": [param]}, + outputs={"Out": [var]}, + attrs={ + "in_dtype": param.dtype, + "out_dtype": core.VarDesc.VarType.FP32 + }) + self._master_weights[param.name] = var + return var + + def _get_accumulator(self, name, param): + """Utility function to fetch an accumulator for a parameter + + Args: + name: name of the accumulator + param: parameter variable for which accumulator is to be fetched + + Returns: + accumulator variable for the parameter + """ + if self._name is not None: + name = self._name + "_" + name + find_master = self._multi_precision and param.dtype == core.VarDesc.VarType.FP16 + target_param = self._master_weights[ + param.name] if find_master else param + target_name = target_param.name + if (name not in self._accumulators or + target_name not in self._accumulators[name]): + raise Exception("Accumulator {} does not exist for parameter {}". + format(name, target_name)) + return self._accumulators[name][target_name] + def _create_accumulators(self, block, parameters): assert isinstance(block, framework.Block) # create accumulator in init func, so no implementation here @@ -127,16 +199,30 @@ class Momentum(Optimizer): velocity_acc = self._get_accumulator(self._velocity_acc_str, param_and_grad[0]) + find_master = self._multi_precision and param_and_grad[ + 0].dtype == core.VarDesc.VarType.FP16 + master_weight = (self._master_weights[param_and_grad[0].name] + if find_master else None) lr = self._create_param_lr(param_and_grad) if framework.in_dygraph_mode(): - _, _ = core.ops.momentum(param_and_grad[0], param_and_grad[1], - velocity_acc, lr, param_and_grad[0], - velocity_acc, 'mu', self._momentum, - 'use_nesterov', self._use_nesterov) + _, _ = core.ops.momentum( + param_and_grad[0], param_and_grad[1], velocity_acc, lr, + param_and_grad[0], velocity_acc, 'mu', self._momentum, + 'use_nesterov', self._use_nesterov, 'regularization_method', + self._regularization_method, 'regularization_coeff', + self._regularization_coeff) return None - attrs = {"mu": self._momentum, "use_nesterov": self._use_nesterov} + attrs = { + "mu": self._momentum, + "use_nesterov": self._use_nesterov, + "regularization_method": self._regularization_method, + "regularization_coeff": self._regularization_coeff, + "multi_precision": find_master, + "rescale_grad": self._rescale_grad + } + inputs = { "Param": [param_and_grad[0]], "Grad": [param_and_grad[1]], @@ -148,6 +234,11 @@ class Momentum(Optimizer): "ParamOut": [param_and_grad[0]], "VelocityOut": [velocity_acc] } + + if find_master: + inputs["MasterParam"] = master_weight + outputs["MasterParamOut"] = master_weight + # create the momentum optimize op momentum_op = block.append_op( type=self.type, diff --git a/python/paddle/static/__init__.py b/python/paddle/static/__init__.py index e67676708bc..7a6a064787b 100644 --- a/python/paddle/static/__init__.py +++ b/python/paddle/static/__init__.py @@ -24,6 +24,7 @@ __all__ = [ ] from . import nn +from . import amp from .io import save_inference_model #DEFINE_ALIAS from .io import load_inference_model #DEFINE_ALIAS from .io import deserialize_persistables #DEFINE_ALIAS diff --git a/python/paddle/static/amp/__init__.py b/python/paddle/static/amp/__init__.py new file mode 100644 index 00000000000..604c7c3d2b4 --- /dev/null +++ b/python/paddle/static/amp/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...fluid.contrib import mixed_precision +from ...fluid.contrib.mixed_precision import * + +__all__ = mixed_precision.__all__ diff --git a/python/setup.py.in b/python/setup.py.in index 428b0a057bb..e3517adc194 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -211,6 +211,7 @@ packages=['paddle', 'paddle.metric', 'paddle.static', 'paddle.static.nn', + 'paddle.static.amp', 'paddle.tensor', 'paddle.onnx', ] -- GitLab