From 6c2bc29cc06b74153d0c5e3af43e7a011a27df71 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Tue, 10 Sep 2019 10:50:07 +0800 Subject: [PATCH] Fix float16 optimizer. (#19682) Fix float16 optimizer --- paddle/fluid/API.spec | 2 +- .../contrib/mixed_precision/decorator.py | 22 ++++++++++--- .../tests/test_image_classification_fp16.py | 3 +- .../fluid/incubate/fleet/base/fleet_base.py | 4 ++- .../incubate/fleet/collective/__init__.py | 5 ++- python/paddle/fluid/optimizer.py | 32 +++++++++++-------- 6 files changed, 47 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index d626c8c678f..f31e8a10642 100755 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -505,7 +505,7 @@ paddle.fluid.contrib.HDFSClient.upload (ArgSpec(args=['self', 'hdfs_path', 'loca paddle.fluid.contrib.multi_download (ArgSpec(args=['client', 'hdfs_path', 'local_path', 'trainer_id', 'trainers', 'multi_processes'], varargs=None, keywords=None, defaults=(5,)), ('document', '100927be598ed8f9eaa1f3ef1b23568a')) paddle.fluid.contrib.multi_upload (ArgSpec(args=['client', 'hdfs_path', 'local_path', 'multi_processes', 'overwrite', 'sync'], varargs=None, keywords=None, defaults=(5, False, True)), ('document', '183f34c83d30dbe16e09e8716c41958a')) paddle.fluid.contrib.extend_with_decoupled_weight_decay (ArgSpec(args=['base_optimizer'], varargs=None, keywords=None, defaults=None), ('document', 'a1095dfd4ec725747f662d69cd7659d4')) -paddle.fluid.contrib.mixed_precision.decorate (ArgSpec(args=['optimizer', 'amp_lists', 'init_loss_scaling', 'incr_every_n_steps', 'decr_every_n_nan_or_inf', 'incr_ratio', 'decr_ratio', 'use_dynamic_loss_scaling'], varargs=None, keywords=None, defaults=(None, 1.0, 1000, 2, 2.0, 0.8, False)), ('document', 'd05e71f5b0bd6d92bb94e70e00b3f9cf')) +paddle.fluid.contrib.mixed_precision.decorate (ArgSpec(args=['optimizer', 'amp_lists', 'init_loss_scaling', 'incr_every_n_steps', 'decr_every_n_nan_or_inf', 'incr_ratio', 'decr_ratio', 'use_dynamic_loss_scaling'], varargs=None, keywords=None, defaults=(None, 1.0, 1000, 2, 2.0, 0.8, False)), ('document', '5f118631fc8632afb981b3a26daae731')) paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists ('paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists', ('document', 'c116ec6bb5d30998792daea8db21ee40')) paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists.__init__ (ArgSpec(args=['self', 'custom_white_list', 'custom_black_list'], varargs=None, keywords=None, defaults=(None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.contrib.fused_elemwise_activation (ArgSpec(args=['x', 'y', 'functor_list', 'axis', 'scale', 'save_intermediate_out'], varargs=None, keywords=None, defaults=(-1, 0.0, True)), ('document', '1c4b247a2858cea8d9d8750693688270')) diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py index d7e028a8675..83a75699836 100644 --- a/python/paddle/fluid/contrib/mixed_precision/decorator.py +++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py @@ -172,21 +172,34 @@ class OptimizerWithMixedPrecison(object): return optimize_ops - def minimize(self, loss): + def minimize(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None): """ Perform optimization by minimizing the given loss. Args: loss (Variable): The loss Variable. + startup_program (Program): startup_program for initializing parameters + in `parameter_list`. + parameter_list (list): list of Variables to update. 
+ no_grad_set (set|None): set of Variables should be ignored. Returns: The scaled loss by scaling factor, the list of optimize ops, and a list of scaled parameters and gradients. """ - scaled_params_grads, scaled_loss = self.backward(loss) + scaled_params_grads, scaled_loss = self.backward( + loss, + startup_program=startup_program, + parameter_list=parameter_list, + no_grad_set=no_grad_set) + optimize_ops = self.apply_gradients(scaled_params_grads) - return scaled_loss, optimize_ops, scaled_params_grads + return optimize_ops, scaled_params_grads def decorate(optimizer, @@ -228,7 +241,8 @@ def decorate(optimizer, mp_optimizer = fluid.contrib.mixed_precision.decorate( optimizer=optimizer, init_loss_scaling=8.0) - scaled_loss, _, _ = mp_optimizer.minimize(loss) + ops, param_grads = mp_optimizer.minimize(loss) + scaled_loss = mp_optimizer.get_loss_scaling() """ if amp_lists is None: amp_lists = AutoMixedPrecisionLists() diff --git a/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py b/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py index 47beb9e21cb..982e380c7e7 100644 --- a/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py +++ b/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py @@ -138,7 +138,8 @@ def train(net_type, use_cuda, save_dirname, is_local): init_loss_scaling=8.0, use_dynamic_loss_scaling=True) - scaled_loss, _, _ = mp_optimizer.minimize(avg_cost) + mp_optimizer.minimize(avg_cost) + scaled_loss = mp_optimizer.get_loss_scaling() BATCH_SIZE = 128 PASS_NUM = 1 diff --git a/python/paddle/fluid/incubate/fleet/base/fleet_base.py b/python/paddle/fluid/incubate/fleet/base/fleet_base.py index 4c98d9e0e6a..8e7cee1fb69 100644 --- a/python/paddle/fluid/incubate/fleet/base/fleet_base.py +++ b/python/paddle/fluid/incubate/fleet/base/fleet_base.py @@ -23,6 +23,7 @@ from paddle.fluid.optimizer import SGD from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker from paddle.fluid.incubate.fleet.base.role_maker import RoleMakerBase from paddle.fluid.incubate.fleet.base.role_maker import UserDefinedRoleMaker +from paddle.fluid.contrib.mixed_precision.decorator import OptimizerWithMixedPrecison class Mode: @@ -257,7 +258,8 @@ class DistributedOptimizer(object): __metaclass__ = abc.ABCMeta def __init__(self, optimizer, strategy=None): - if not isinstance(optimizer, SGD.__bases__): + if not isinstance(optimizer, SGD.__bases__) \ + and not isinstance(optimizer, OptimizerWithMixedPrecison): raise TypeError("optimizer must be an instance of Optimizer") self._optimizer = optimizer diff --git a/python/paddle/fluid/incubate/fleet/collective/__init__.py b/python/paddle/fluid/incubate/fleet/collective/__init__.py index 6a0984240bb..26b2dbeb29c 100644 --- a/python/paddle/fluid/incubate/fleet/collective/__init__.py +++ b/python/paddle/fluid/incubate/fleet/collective/__init__.py @@ -347,7 +347,10 @@ class CollectiveOptimizer(DistributedOptimizer): self._strategy) optimize_ops, param_grads = self._optimizer.minimize( - loss, startup_program, parameter_list, no_grad_set) + loss, + startup_program=startup_program, + parameter_list=parameter_list, + no_grad_set=no_grad_set) fleet._origin_program = main_program fleet.main_program = self._try_to_compile(startup_program, main_program) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 03578604ad6..f7db8ce32b3 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -464,6 +464,8 @@ class Optimizer(object): 
Examples: See examples in `apply_gradients`. """ + no_grad_set = self._get_no_grad_set(loss, no_grad_set) + self._dtype = loss.dtype if framework.in_dygraph_mode(): if parameter_list is not None: @@ -563,6 +565,23 @@ class Optimizer(object): optimize_ops = self.apply_gradients(params_grads) return optimize_ops + def _get_no_grad_set(self, loss, no_grad_set=None): + if no_grad_set is None: + no_grad_set = set() + elif isinstance(no_grad_set, set) or isinstance( + no_grad_set, list) or isinstance(no_grad_set, tuple): + no_grad_set = set(no_grad_set) + else: + assert "no_grad_set should be a set, but the passed type is {}".format( + type(no_grad_set)) + parameters = loss.block.program.global_block().all_parameters() + param_no_trainable = set( + [param.name for param in parameters if param.trainable is False]) + # If the parameter is no trainable, it should not have a gradient. + no_grad_set.update(param_no_trainable) + + return no_grad_set + @imperative_base.no_grad def minimize(self, loss, @@ -589,19 +608,6 @@ class Optimizer(object): and list of (param, grad) Variables pair for optimization. """ assert isinstance(loss, Variable), "The loss should be an Variable." - if no_grad_set is None: - no_grad_set = set() - elif isinstance(no_grad_set, set) or isinstance( - no_grad_set, list) or isinstance(no_grad_set, tuple): - no_grad_set = set(no_grad_set) - else: - assert "no_grad_set should be a set, but the passed type is {}".format( - type(no_grad_set)) - parameters = loss.block.program.global_block().all_parameters() - param_no_trainable = set( - [param.name for param in parameters if param.trainable is False]) - # If the parameter is no trainable, it should not have a gradient. - no_grad_set.update(param_no_trainable) params_grads = self.backward( loss, startup_program=startup_program, -- GitLab
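
The decorator hunk above changes `OptimizerWithMixedPrecison.minimize` to accept the same `startup_program`, `parameter_list` and `no_grad_set` arguments as the base `Optimizer` and to return only `(optimize_ops, param_grads)`; the scaled loss is now obtained separately, as the patched `decorate()` docstring and the updated fp16 test show. The following is a minimal usage sketch mirroring that docstring. The tiny regression network is illustrative only and not part of the patch, and the snippet only builds the program; actually running fp16 training would normally require CUDA.

```python
# Sketch of the post-patch mixed-precision API (Paddle fluid 1.x style).
# The small regression network is illustrative, not taken from the patch.
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
pred = fluid.layers.fc(input=x, size=1)
avg_cost = fluid.layers.mean(
    fluid.layers.square_error_cost(input=pred, label=y))

optimizer = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
mp_optimizer = fluid.contrib.mixed_precision.decorate(
    optimizer=optimizer,
    init_loss_scaling=8.0,
    use_dynamic_loss_scaling=True)

# minimize() now mirrors the base Optimizer signature and returns only
# (optimize_ops, param_grads); the scaled loss is fetched via the accessor
# shown in the patched docstring instead of being returned by minimize().
optimize_ops, param_grads = mp_optimizer.minimize(avg_cost)
scaled_loss = mp_optimizer.get_loss_scaling()
```

This matches the updated call sites in `test_image_classification_fp16.py` and the example in the `decorate()` docstring.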
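
The `optimizer.py` hunk moves the `no_grad_set` normalization out of `minimize` and into `backward` through the new `_get_no_grad_set` helper, so direct callers of `backward` (such as the mixed-precision decorator) get the same treatment: a list or tuple is converted to a set, and non-trainable parameters are always added so no gradient ops are generated for them. A small sketch under that assumption; the network and the parameter name `frozen_w` are hypothetical, not taken from the patch.

```python
# Sketch of the no_grad_set handling now performed inside Optimizer.backward().
import paddle.fluid as fluid

img = fluid.layers.data(name='img', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

# A non-trainable parameter: _get_no_grad_set adds it to no_grad_set
# automatically, so it is skipped during gradient generation.
hidden = fluid.layers.fc(
    input=img, size=128,
    param_attr=fluid.ParamAttr(name='frozen_w', trainable=False))
logits = fluid.layers.fc(input=hidden, size=10, act='softmax')
loss = fluid.layers.mean(
    fluid.layers.cross_entropy(input=logits, label=label))

sgd = fluid.optimizer.SGD(learning_rate=0.01)
# backward() now normalizes no_grad_set itself: a list or tuple of parameter
# names is accepted and converted to a set before the backward ops are added.
params_grads = sgd.backward(loss, no_grad_set=['frozen_w'])
# 'frozen_w' should not appear among the returned (param, grad) pairs.
print([p.name for p, _ in params_grads])
```

Because the normalization now lives in `backward`, the decorated fp16 optimizer and `CollectiveOptimizer.minimize` (which forwards `startup_program`, `parameter_list` and `no_grad_set` as keyword arguments after this patch) both benefit without duplicating the logic in `minimize`.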