From af2c31a65badb456b478e0edcd9ded77d4faf2f5 Mon Sep 17 00:00:00 2001 From: niuliling123 <51102941+niuliling123@users.noreply.github.com> Date: Fri, 10 Mar 2023 11:19:14 +0800 Subject: [PATCH] Delete duplicate code in optimizer.py and support master_param for bf16 in optimzer (#51367) --- python/paddle/fluid/optimizer.py | 468 +++++++-------------------- python/paddle/optimizer/adadelta.py | 75 +---- python/paddle/optimizer/adagrad.py | 76 +---- python/paddle/optimizer/adam.py | 75 +---- python/paddle/optimizer/adamax.py | 87 +---- python/paddle/optimizer/adamw.py | 67 +--- python/paddle/optimizer/lamb.py | 78 +---- python/paddle/optimizer/momentum.py | 88 +---- python/paddle/optimizer/optimizer.py | 56 ++++ python/paddle/optimizer/rmsprop.py | 75 +---- python/paddle/optimizer/sgd.py | 43 +-- 11 files changed, 252 insertions(+), 936 deletions(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 9b15bd1f455..08717825f79 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -638,6 +638,48 @@ class Optimizer: ), framework.name_scope('scale_with_param_lr'): return self._global_learning_rate() * param_lr + def _is_dtype_fp16_or_bf16(self, dtype): + """ + check the dtype is fp16 or the dtype is bf16 + :param dtype: instance of core.VarDesc.VarType + :return: True if dtype is one of fp16 or bf16, False otherwise + """ + assert isinstance( + dtype, core.VarDesc.VarType + ), "The dtype should be an instance of core.VarDesc.VarType." + return ( + dtype == core.VarDesc.VarType.FP16 + or dtype == core.VarDesc.VarType.BF16 + ) + + def _create_master_weight(self, param): + if param.name in self._master_weights: + var = self._master_weights[param.name] + else: + assert isinstance(self.helper, LayerHelper) + + var_name = param.name + "_fp32_master" + var_name = unique_name.generate(var_name) + var = paddle.static.create_global_var( + name=var_name, + shape=param.shape, + value=0, + dtype='float32', + persistable=True, + ) + block = self.helper.startup_program.global_block() + block.append_op( + type="cast", + inputs={"X": [param]}, + outputs={"Out": [var]}, + attrs={ + "in_dtype": param.dtype, + "out_dtype": core.VarDesc.VarType.FP32, + }, + ) + self._master_weights[param.name] = var + return var + def _create_accumulators(self, block, parameters): """Create all accumulators needed by the parameters @@ -819,6 +861,34 @@ class Optimizer: ) return self._accumulators[name][param.name] + def _get_accumulator_master(self, name, param): + """Utility function to fetch an accumulator for a parameter + Args: + name: name of the accumulator + param: parameter variable for which accumulator is to be fetched + Returns: + accumulator variable for the parameter + """ + if self._name is not None: + name = self._name + "_" + name + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param.dtype + ) + target_param = ( + self._master_weights[param.name] if find_master else param + ) + target_name = target_param.name + if ( + name not in self._accumulators + or target_name not in self._accumulators[name] + ): + raise Exception( + "Accumulator {} does not exist for parameter {}".format( + name, target_name + ) + ) + return self._accumulators[name][target_name] + def _get_global_accumulator(self, name): """Utility function to fetch a global accumulator @@ -1486,34 +1556,6 @@ class SGDOptimizer(Optimizer): self._multi_precision = multi_precision self._master_weights = {} - def _create_master_weight(self, param): - if param.name in 
self._master_weights: - var = self._master_weights[param.name] - else: - assert isinstance(self.helper, LayerHelper) - - var_name = param.name + "_fp32_master" - var_name = unique_name.generate(var_name) - var = paddle.static.create_global_var( - name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True, - ) - block = self.helper.startup_program.global_block() - block.append_op( - type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32, - }, - ) - self._master_weights[param.name] = var - return var - def _create_accumulators(self, block, parameters): assert isinstance(block, framework.Block) if isinstance(parameters, dict): @@ -1521,24 +1563,23 @@ class SGDOptimizer(Optimizer): # Create accumulator tensors for first and second moments for p in parameters: - if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) continue if ( - p.dtype == core.VarDesc.VarType.FP16 + self._is_dtype_fp16_or_bf16(p.dtype) and not self._multi_precision ): warnings.warn( - "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." + "Accumulating with FP16/BF16 in optimizer can lead to poor accuracy or slow convergence." "Consider using multi_precision=True option of the Adam optimizer." ) @no_grad def _append_optimize_op(self, block, param_and_grad): - find_master = ( - self._multi_precision - and param_and_grad[0].dtype == core.VarDesc.VarType.FP16 + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param_and_grad[0].dtype ) master_weight = ( self._master_weights[param_and_grad[0].name] @@ -1837,76 +1878,20 @@ class LarsMomentumOptimizer(Optimizer): self._rescale_grad = float(rescale_grad) self._master_weights = {} - def _create_master_weight(self, param): - if param.name in self._master_weights: - var = self._master_weights[param.name] - else: - assert isinstance(self.helper, LayerHelper) - - var_name = param.name + '_fp32_master' - var_name = unique_name.generate(var_name) - var = paddle.static.create_global_var( - name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True, - ) - block = self.helper.startup_program.global_block() - block.append_op( - type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32, - }, - ) - self._master_weights[param.name] = var - return var - - def _get_accumulator(self, name, param): - """Utility function to fetch an accumulator for a parameter - Args: - name: name of the accumulator - param: parameter variable for which accumulator is to be fetched - Returns: - accumulator variable for the parameter - """ - if self._name is not None: - name = self._name + "_" + name - find_master = ( - self._multi_precision and param.dtype == core.VarDesc.VarType.FP16 - ) - target_param = ( - self._master_weights[param.name] if find_master else param - ) - target_name = target_param.name - if ( - name not in self._accumulators - or target_name not in self._accumulators[name] - ): - raise Exception( - "Accumulator {} does not exist for parameter {}".format( - name, target_name - ) - ) - return self._accumulators[name][target_name] - def _create_accumulators(self, block, parameters): assert isinstance(block, framework.Block) for p in parameters: - if self._multi_precision and p.dtype == 
core.VarDesc.VarType.FP16: + if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_accumulator(self._velocity_acc_str, master_p) continue if ( - p.dtype == core.VarDesc.VarType.FP16 + self._is_dtype_fp16_or_bf16(p.dtype) and not self._multi_precision ): warnings.warn( - "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." + "Accumulating with FP16/BF16 in optimizer can lead to poor accuracy or slow convergence." "Consider using multi_precision=True option of the Lars optimizer." ) self._add_accumulator(self._velocity_acc_str, p) @@ -1921,14 +1906,13 @@ class LarsMomentumOptimizer(Optimizer): _lars_weight_decay = 0.0 break - velocity_acc = self._get_accumulator( + velocity_acc = self._get_accumulator_master( self._velocity_acc_str, param_and_grad[0] ) lr = self._create_param_lr(param_and_grad) - find_master = ( - self._multi_precision - and param_and_grad[0].dtype == core.VarDesc.VarType.FP16 + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param_and_grad[0].dtype ) master_weight = ( self._master_weights[param_and_grad[0].name] @@ -2084,76 +2068,20 @@ class AdagradOptimizer(Optimizer): self.initial_accumulator_value = initial_accumulator_value self._master_weights = {} - def _create_master_weight(self, param): - if param.name in self._master_weights: - var = self._master_weights[param.name] - else: - assert isinstance(self.helper, LayerHelper) - - var_name = param.name + '_fp32_master' - var_name = unique_name.generate(var_name) - var = paddle.static.create_global_var( - name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True, - ) - block = self.helper.startup_program.global_block() - block.append_op( - type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32, - }, - ) - self._master_weights[param.name] = var - return var - - def _get_accumulator(self, name, param): - """Utility function to fetch an accumulator for a parameter - Args: - name: name of the accumulator - param: parameter variable for which accumulator is to be fetched - Returns: - accumulator variable for the parameter - """ - if self._name is not None: - name = self._name + "_" + name - find_master = ( - self._multi_precision and param.dtype == core.VarDesc.VarType.FP16 - ) - target_param = ( - self._master_weights[param.name] if find_master else param - ) - target_name = target_param.name - if ( - name not in self._accumulators - or target_name not in self._accumulators[name] - ): - raise Exception( - "Accumulator {} does not exist for parameter {}".format( - name, target_name - ) - ) - return self._accumulators[name][target_name] - def _create_accumulators(self, block, parameters): assert isinstance(block, framework.Block) for p in parameters: - if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_accumulator(self._moment_acc_str, master_p) continue if ( - p.dtype == core.VarDesc.VarType.FP16 + self._is_dtype_fp16_or_bf16(p.dtype) and not self._multi_precision ): warnings.warn( - "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." + "Accumulating with FP16/BF16 in optimizer can lead to poor accuracy or slow convergence." "Consider using multi_precision=True option of the Lars optimizer." 
) self._add_accumulator( @@ -2165,13 +2093,12 @@ class AdagradOptimizer(Optimizer): def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) - moment_acc = self._get_accumulator( + moment_acc = self._get_accumulator_master( self._moment_acc_str, param_and_grad[0] ) - find_master = ( - self._multi_precision - and param_and_grad[0].dtype == core.VarDesc.VarType.FP16 + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param_and_grad[0].dtype ) master_weight = ( self._master_weights[param_and_grad[0].name] @@ -2764,38 +2691,10 @@ class AdamaxOptimizer(Optimizer): self._multi_precision = False self._master_weights = {} - def _create_master_weight(self, param): - if param.name in self._master_weights: - var = self._master_weights[param.name] - else: - assert isinstance(self.helper, LayerHelper) - - var_name = param.name + '_fp32_master' - var_name = unique_name.generate(var_name) - var = paddle.static.create_global_var( - name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True, - ) - block = self.helper.startup_program.global_block() - block.append_op( - type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32, - }, - ) - self._master_weights[param.name] = var - return var - def _create_accumulators(self, block, parameters): # Create accumulator tensors for first moment and infinity norm for p in parameters: - if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_accumulator(self._moment_acc_str, master_p) self._add_accumulator(self._inf_norm_acc_str, master_p) @@ -2807,11 +2706,11 @@ class AdamaxOptimizer(Optimizer): ) continue if ( - p.dtype == core.VarDesc.VarType.FP16 + self._is_dtype_fp16_or_bf16(p.dtype) and not self._multi_precision ): warnings.warn( - "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." + "Accumulating with FP16/BF16 in optimizer can lead to poor accuracy or slow convergence." "Consider using multi_precision=True option of the Lars optimizer." 
) self._add_accumulator(self._moment_acc_str, p) @@ -2823,48 +2722,21 @@ class AdamaxOptimizer(Optimizer): shape=[1], ) - def _get_accumulator(self, name, param): - """Utility function to fetch an accumulator for a parameter - Args: - name: name of the accumulator - param: parameter variable for which accumulator is to be fetched - Returns: - accumulator variable for the parameter - """ - if self._name is not None: - name = self._name + "_" + name - find_master = ( - self._multi_precision and core.VarDesc.VarType.FP16 == param.dtype - ) - target_param = ( - self._master_weights[param.name] if find_master else param - ) - target_name = target_param.name - if ( - name not in self._accumulators - or target_name not in self._accumulators[name] - ): - raise Exception( - "Accumulator {} does not exist for parameter {}".format( - name, target_name - ) - ) - return self._accumulators[name][target_name] - def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) - moment = self._get_accumulator(self._moment_acc_str, param_and_grad[0]) - inf_norm = self._get_accumulator( + moment = self._get_accumulator_master( + self._moment_acc_str, param_and_grad[0] + ) + inf_norm = self._get_accumulator_master( self._inf_norm_acc_str, param_and_grad[0] ) - beta1_pow_acc = self._get_accumulator( + beta1_pow_acc = self._get_accumulator_master( self._beta1_pow_acc_str, param_and_grad[0] ) - find_master = ( - self._multi_precision - and param_and_grad[0].dtype == core.VarDesc.VarType.FP16 + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param_and_grad[0].dtype ) master_weight = ( self._master_weights[param_and_grad[0].name] @@ -3283,68 +3155,12 @@ class AdadeltaOptimizer(Optimizer): self._epsilon = epsilon self._rho = rho - def _create_master_weight(self, param): - if param.name in self._master_weights: - var = self._master_weights[param.name] - else: - assert isinstance(self.helper, LayerHelper) - - var_name = param.name + '_fp32_master' - var_name = unique_name.generate(var_name) - var = paddle.static.create_global_var( - name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True, - ) - block = self.helper.startup_program.global_block() - block.append_op( - type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32, - }, - ) - self._master_weights[param.name] = var - return var - - def _get_accumulator(self, name, param): - """Utility function to fetch an accumulator for a parameter - Args: - name: name of the accumulator - param: parameter variable for which accumulator is to be fetched - Returns: - accumulator variable for the parameter - """ - if self._name is not None: - name = self._name + "_" + name - find_master = ( - self._multi_precision and param.dtype == core.VarDesc.VarType.FP16 - ) - target_param = ( - self._master_weights[param.name] if find_master else param - ) - target_name = target_param.name - if ( - name not in self._accumulators - or target_name not in self._accumulators[name] - ): - raise Exception( - "Accumulator {} does not exist for parameter {}".format( - name, target_name - ) - ) - return self._accumulators[name][target_name] - def _create_accumulators(self, block, parameters): if not isinstance(block, framework.Block): raise TypeError("block is not instance of framework.Block.") for p in parameters: - if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + if self._multi_precision and 
self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_accumulator(self._avg_squared_grad_acc_str, master_p) self._add_accumulator( @@ -3352,11 +3168,11 @@ class AdadeltaOptimizer(Optimizer): ) continue if ( - p.dtype == core.VarDesc.VarType.FP16 + self._is_dtype_fp16_or_bf16(p.dtype) and not self._multi_precision ): warnings.warn( - "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." + "Accumulating with FP16/BF16 in optimizer can lead to poor accuracy or slow convergence." "Consider using multi_precision=True option of the Lars optimizer." ) self._add_accumulator(self._avg_squared_grad_acc_str, p) @@ -3366,15 +3182,15 @@ class AdadeltaOptimizer(Optimizer): if not isinstance(block, framework.Block): raise TypeError("block is not instance of framework.Block.") - avg_squared_grad_acc = self._get_accumulator( + avg_squared_grad_acc = self._get_accumulator_master( self._avg_squared_grad_acc_str, param_and_grad[0] ) - avg_squared_update_acc = self._get_accumulator( + avg_squared_update_acc = self._get_accumulator_master( self._avg_squared_update_acc_str, param_and_grad[0] ) - find_master = ( - self._multi_precision - and param_and_grad[0].dtype == core.VarDesc.VarType.FP16 + + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param_and_grad[0].dtype ) master_weight = ( self._master_weights[param_and_grad[0].name] @@ -3574,79 +3390,23 @@ class RMSPropOptimizer(Optimizer): self._multi_precision = False self._master_weights = {} - def _create_master_weight(self, param): - if param.name in self._master_weights: - var = self._master_weights[param.name] - else: - assert isinstance(self.helper, LayerHelper) - - var_name = param.name + '_fp32_master' - var_name = unique_name.generate(var_name) - var = paddle.static.create_global_var( - name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True, - ) - block = self.helper.startup_program.global_block() - block.append_op( - type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32, - }, - ) - self._master_weights[param.name] = var - return var - - def _get_accumulator(self, name, param): - """Utility function to fetch an accumulator for a parameter - Args: - name: name of the accumulator - param: parameter variable for which accumulator is to be fetched - Returns: - accumulator variable for the parameter - """ - if self._name is not None: - name = self._name + "_" + name - find_master = ( - self._multi_precision and param.dtype == core.VarDesc.VarType.FP16 - ) - target_param = ( - self._master_weights[param.name] if find_master else param - ) - target_name = target_param.name - if ( - name not in self._accumulators - or target_name not in self._accumulators[name] - ): - raise Exception( - "Accumulator {} does not exist for parameter {}".format( - name, target_name - ) - ) - return self._accumulators[name][target_name] - def _create_accumulators(self, block, parameters): if not isinstance(block, framework.Block): raise TypeError("block is not instance of framework.Block.") for p in parameters: - if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_accumulator(self._momentum_acc_str, master_p) self._add_accumulator(self._mean_square_acc_str, master_p) self._add_accumulator(self._mean_grad_acc_str, master_p) continue if ( - p.dtype 
== core.VarDesc.VarType.FP16 + self._is_dtype_fp16_or_bf16(p.dtype) and not self._multi_precision ): warnings.warn( - "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." + "Accumulating with FP16/BF16 in optimizer can lead to poor accuracy or slow convergence." "Consider using multi_precision=True option of the Lars optimizer." ) self._add_accumulator(self._momentum_acc_str, p) @@ -3657,18 +3417,17 @@ class RMSPropOptimizer(Optimizer): if not isinstance(block, framework.Block): raise TypeError("block is not instance of framework.Block.") - momentum_acc = self._get_accumulator( + momentum_acc = self._get_accumulator_master( self._momentum_acc_str, param_and_grad[0] ) - mean_square_acc = self._get_accumulator( + mean_square_acc = self._get_accumulator_master( self._mean_square_acc_str, param_and_grad[0] ) - mean_grad_acc = self._get_accumulator( + mean_grad_acc = self._get_accumulator_master( self._mean_grad_acc_str, param_and_grad[0] ) - find_master = ( - self._multi_precision - and param_and_grad[0].dtype == core.VarDesc.VarType.FP16 + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param_and_grad[0].dtype ) master_weight = ( self._master_weights[param_and_grad[0].name] @@ -6185,6 +5944,7 @@ class PipelineOptimizer: def _get_var_size(self, var): dtype_to_size = { core.VarDesc.VarType.FP16: 2, + core.VarDesc.VarType.BF16: 2, core.VarDesc.VarType.FP32: 4, core.VarDesc.VarType.FP64: 8, core.VarDesc.VarType.INT16: 2, diff --git a/python/paddle/optimizer/adadelta.py b/python/paddle/optimizer/adadelta.py index 5afc3aef03d..d0d8d917e70 100644 --- a/python/paddle/optimizer/adadelta.py +++ b/python/paddle/optimizer/adadelta.py @@ -14,12 +14,10 @@ import warnings -import paddle from paddle import _C_ops -from ..fluid import core, framework, unique_name +from ..fluid import framework from ..fluid.dygraph import no_grad -from ..fluid.layer_helper import LayerHelper from ..framework import in_dygraph_mode from .optimizer import Optimizer @@ -144,62 +142,6 @@ class Adadelta(Optimizer): 'rho': rho, } - def _create_master_weight(self, param): - if param.name in self._master_weights: - var = self._master_weights[param.name] - else: - assert isinstance(self.helper, LayerHelper) - - var_name = param.name + "_fp32_master" - var_name = unique_name.generate(var_name) - var = paddle.static.create_global_var( - name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True, - ) - block = self.helper.startup_program.global_block() - block.append_op( - type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32, - }, - ) - self._master_weights[param.name] = var - return var - - def _get_accumulator(self, name, param): - """Utility function to fetch an accumulator for a parameter - Args: - name: name of the accumulator - param: parameter variable for which accumulator is to be fetched - Returns: - accumulator variable for the parameter - """ - if self._name is not None: - name = self._name + "_" + name - find_master = ( - self._multi_precision and param.dtype == core.VarDesc.VarType.FP16 - ) - target_param = ( - self._master_weights[param.name] if find_master else param - ) - target_name = target_param.name - if ( - name not in self._accumulators - or target_name not in self._accumulators[name] - ): - raise Exception( - "Accumulator {} does not exist for parameter {}".format( - name, target_name - ) - ) - return self._accumulators[name][target_name] - def 
_create_accumulators(self, block, parameters): if not isinstance(block, framework.Block): raise TypeError("block is not instance of framework.Block.") @@ -207,7 +149,7 @@ class Adadelta(Optimizer): parameters = parameters.get('params') for p in parameters: - if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_accumulator(self._avg_squared_grad_acc_str, master_p) self._add_accumulator( @@ -215,11 +157,11 @@ class Adadelta(Optimizer): ) continue if ( - p.dtype == core.VarDesc.VarType.FP16 + self._is_dtype_fp16_or_bf16(p.dtype) and not self._multi_precision ): warnings.warn( - "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." + "Accumulating with FP16/BF16 in optimizer can lead to poor accuracy or slow convergence." "Consider using multi_precision=True option of the Lars optimizer." ) self._add_accumulator(self._avg_squared_grad_acc_str, p) @@ -229,15 +171,14 @@ class Adadelta(Optimizer): if isinstance(param_and_grad, dict): param_and_grad = self._update_param_group(param_and_grad) - avg_squared_grad_acc = self._get_accumulator( + avg_squared_grad_acc = self._get_accumulator_master( self._avg_squared_grad_acc_str, param_and_grad[0] ) - avg_squared_update_acc = self._get_accumulator( + avg_squared_update_acc = self._get_accumulator_master( self._avg_squared_update_acc_str, param_and_grad[0] ) - find_master = ( - self._multi_precision - and param_and_grad[0].dtype == core.VarDesc.VarType.FP16 + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param_and_grad[0].dtype ) master_weight = ( self._master_weights[param_and_grad[0].name] diff --git a/python/paddle/optimizer/adagrad.py b/python/paddle/optimizer/adagrad.py index a2688a0544e..1052eefb22c 100644 --- a/python/paddle/optimizer/adagrad.py +++ b/python/paddle/optimizer/adagrad.py @@ -13,10 +13,7 @@ # limitations under the License. 
import warnings -import paddle - -from ..fluid import core, framework, unique_name -from ..fluid.layer_helper import LayerHelper +from ..fluid import framework from .optimizer import Optimizer __all__ = [] @@ -138,64 +135,6 @@ class Adagrad(Optimizer): 'initial_accumulator_value': initial_accumulator_value, } - def _create_master_weight(self, param): - if param.name in self._master_weights: - var = self._master_weights[param.name] - else: - assert isinstance(self.helper, LayerHelper) - - var_name = param.name + "_fp32_master" - var_name = unique_name.generate(var_name) - var = paddle.static.create_global_var( - name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True, - ) - block = self.helper.startup_program.global_block() - block.append_op( - type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32, - }, - ) - self._master_weights[param.name] = var - return var - - def _get_accumulator(self, name, param): - """Utility function to fetch an accumulator for a parameter - - Args: - name: name of the accumulator - param: parameter variable for which accumulator is to be fetched - - Returns: - accumulator variable for the parameter - """ - if self._name is not None: - name = self._name + "_" + name - find_master = ( - self._multi_precision and param.dtype == core.VarDesc.VarType.FP16 - ) - target_param = ( - self._master_weights[param.name] if find_master else param - ) - target_name = target_param.name - if ( - name not in self._accumulators - or target_name not in self._accumulators[name] - ): - raise Exception( - "Accumulator {} does not exist for parameter {}".format( - name, target_name - ) - ) - return self._accumulators[name][target_name] - def _create_accumulators(self, block, parameters): assert isinstance(block, framework.Block) @@ -203,16 +142,16 @@ class Adagrad(Optimizer): parameters = self._update_param_group(parameters) for p in parameters: - if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_accumulator(self._moment_acc_str, master_p) continue if ( - p.dtype == core.VarDesc.VarType.FP16 + self._is_dtype_fp16_or_bf16(p.dtype) and not self._multi_precision ): warnings.warn( - "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." + "Accumulating with FP16/BF16 in optimizer can lead to poor accuracy or slow convergence." "Consider using multi_precision=True option of the Momentum optimizer." 
) self._add_accumulator( @@ -227,13 +166,12 @@ class Adagrad(Optimizer): if isinstance(param_and_grad, dict): param_and_grad = self._update_param_group(param_and_grad) - moment_acc = self._get_accumulator( + moment_acc = self._get_accumulator_master( self._moment_acc_str, param_and_grad[0] ) - find_master = ( - self._multi_precision - and param_and_grad[0].dtype == core.VarDesc.VarType.FP16 + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param_and_grad[0].dtype ) master_weight = ( diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index 070efdff2d1..ff1ff74e398 100644 --- a/python/paddle/optimizer/adam.py +++ b/python/paddle/optimizer/adam.py @@ -18,10 +18,9 @@ from collections import defaultdict import paddle from paddle import _C_ops -from ..fluid import core, framework, unique_name +from ..fluid import core, framework from ..fluid.dygraph import base as imperative_base from ..fluid.framework import Variable, in_dygraph_mode -from ..fluid.layer_helper import LayerHelper from .optimizer import Optimizer __all__ = [] @@ -225,62 +224,6 @@ class Adam(Optimizer): self._master_weight_dict = self._create_multi_tensor_dict() self._master_weight_dict['FP32_LODTensor'] = None - def _create_master_weight(self, param): - if param.name in self._master_weights: - var = self._master_weights[param.name] - else: - assert isinstance(self.helper, LayerHelper) - - var_name = param.name + "_fp32_master" - var_name = unique_name.generate(var_name) - var = paddle.static.create_global_var( - name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True, - ) - block = self.helper.startup_program.global_block() - block.append_op( - type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32, - }, - ) - self._master_weights[param.name] = var - return var - - def _get_accumulator(self, name, param): - """Utility function to fetch an accumulator for a parameter - Args: - name: name of the accumulator - param: parameter variable for which accumulator is to be fetched - Returns: - accumulator variable for the parameter - """ - if self._name is not None: - name = self._name + "_" + name - find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( - param.dtype - ) - target_param = ( - self._master_weights[param.name] if find_master else param - ) - target_name = target_param.name - if ( - name not in self._accumulators - or target_name not in self._accumulators[name] - ): - raise Exception( - "Accumulator {} does not exist for parameter {}".format( - name, target_name - ) - ) - return self._accumulators[name][target_name] - def _add_moments_pows(self, p): acc_dtype = p.dtype if self._is_dtype_fp16_or_bf16(acc_dtype): @@ -336,16 +279,16 @@ class Adam(Optimizer): if isinstance(param_and_grad, dict): param_and_grad = self._update_param_group(param_and_grad) - moment1 = self._get_accumulator( + moment1 = self._get_accumulator_master( self._moment1_acc_str, param_and_grad[0] ) - moment2 = self._get_accumulator( + moment2 = self._get_accumulator_master( self._moment2_acc_str, param_and_grad[0] ) - beta1_pow_acc = self._get_accumulator( + beta1_pow_acc = self._get_accumulator_master( self._beta1_pow_acc_str, param_and_grad[0] ) - beta2_pow_acc = self._get_accumulator( + beta2_pow_acc = self._get_accumulator_master( self._beta2_pow_acc_str, param_and_grad[0] ) find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( @@ -530,12 +473,12 @@ class 
Adam(Optimizer): """ self._create_accumulators(target_block, parameters) for param in parameters: - moment1 = self._get_accumulator(self._moment1_acc_str, param) - moment2 = self._get_accumulator(self._moment2_acc_str, param) - beta1_pow_acc = self._get_accumulator( + moment1 = self._get_accumulator_master(self._moment1_acc_str, param) + moment2 = self._get_accumulator_master(self._moment2_acc_str, param) + beta1_pow_acc = self._get_accumulator_master( self._beta1_pow_acc_str, param ) - beta2_pow_acc = self._get_accumulator( + beta2_pow_acc = self._get_accumulator_master( self._beta2_pow_acc_str, param ) diff --git a/python/paddle/optimizer/adamax.py b/python/paddle/optimizer/adamax.py index 18d3aefec0f..e7a7c6d0d2e 100644 --- a/python/paddle/optimizer/adamax.py +++ b/python/paddle/optimizer/adamax.py @@ -14,13 +14,11 @@ import warnings -import paddle from paddle import _C_ops -from ..fluid import core, framework, unique_name +from ..fluid import core, framework from ..fluid.dygraph import no_grad from ..fluid.framework import name_scope -from ..fluid.layer_helper import LayerHelper from .optimizer import Optimizer __all__ = [] @@ -191,95 +189,40 @@ class Adamax(Optimizer): shape=[1], ) - def _create_master_weight(self, param): - if param.name in self._master_weights: - var = self._master_weights[param.name] - else: - assert isinstance(self.helper, LayerHelper) - - var_name = param.name + "_fp32_master" - var_name = unique_name.generate(var_name) - var = paddle.static.create_global_var( - name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True, - ) - block = self.helper.startup_program.global_block() - block.append_op( - type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32, - }, - ) - self._master_weights[param.name] = var - return var - def _create_accumulators(self, block, parameters): if isinstance(parameters, dict): parameters = self._update_param_group(parameters) # Create accumulator tensors for first moment and infinity norm for p in parameters: - if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_moments_pows(master_p) continue if ( - p.dtype == core.VarDesc.VarType.FP16 + self._is_dtype_fp16_or_bf16(p.dtype) and not self._multi_precision ): warnings.warn( - "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." + "Accumulating with FP16/BF16 in optimizer can lead to poor accuracy or slow convergence." "Consider using multi_precision=True option of the Adam optimizer." 
) self._add_moments_pows(p) - def _get_accumulator(self, name, param): - """Utility function to fetch an accumulator for a parameter - Args: - name: name of the accumulator - param: parameter variable for which accumulator is to be fetched - Returns: - accumulator variable for the parameter - """ - if self._name is not None: - name = self._name + "_" + name - find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( - param.dtype - ) - target_param = ( - self._master_weights[param.name] if find_master else param - ) - target_name = target_param.name - if ( - name not in self._accumulators - or target_name not in self._accumulators[name] - ): - raise Exception( - "Accumulator {} does not exist for parameter {}".format( - name, target_name - ) - ) - return self._accumulators[name][target_name] - def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) if isinstance(param_and_grad, dict): param_and_grad = self._update_param_group(param_and_grad) - moment = self._get_accumulator(self._moment_acc_str, param_and_grad[0]) - inf_norm = self._get_accumulator( + moment = self._get_accumulator_master( + self._moment_acc_str, param_and_grad[0] + ) + inf_norm = self._get_accumulator_master( self._inf_norm_acc_str, param_and_grad[0] ) - find_master = ( - self._multi_precision - and param_and_grad[0].dtype == core.VarDesc.VarType.FP16 + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param_and_grad[0].dtype ) master_weight = ( self._master_weights[param_and_grad[0].name] @@ -287,7 +230,7 @@ class Adamax(Optimizer): else None ) - beta1_pow_acc = self._get_accumulator( + beta1_pow_acc = self._get_accumulator_master( self._beta1_pow_acc_str, param_and_grad[0] ) if framework.in_dygraph_mode(): @@ -347,7 +290,7 @@ class Adamax(Optimizer): if grad is None or param.stop_gradient is True: continue if framework.in_dygraph_mode(): - beta1_pow_acc = self._get_accumulator( + beta1_pow_acc = self._get_accumulator_master( self._beta1_pow_acc_str, param ) with no_grad(): @@ -359,7 +302,7 @@ class Adamax(Optimizer): with param.block.program._optimized_guard( [param, grad] ), name_scope('adamax'): - beta1_pow_acc = self._get_accumulator( + beta1_pow_acc = self._get_accumulator_master( self._beta1_pow_acc_str, param ) block.append_op( @@ -374,7 +317,7 @@ class Adamax(Optimizer): if grad is None or param.stop_gradient is True: continue if framework.in_dygraph_mode(): - beta1_pow_acc = self._get_accumulator( + beta1_pow_acc = self._get_accumulator_master( self._beta1_pow_acc_str, param ) self._beta1 = parameters_and_grads.get( @@ -389,7 +332,7 @@ class Adamax(Optimizer): with param.block.program._optimized_guard( [param, grad] ), name_scope('adamax'): - beta1_pow_acc = self._get_accumulator( + beta1_pow_acc = self._get_accumulator_master( self._beta1_pow_acc_str, param ) self._beta1 = parameters_and_grads.get( diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py index a5cb7798353..8177f43ee58 100644 --- a/python/paddle/optimizer/adamw.py +++ b/python/paddle/optimizer/adamw.py @@ -19,10 +19,9 @@ from collections.abc import Callable import paddle from .. 
import _C_ops -from ..fluid import core, framework, unique_name +from ..fluid import core, framework from ..fluid.dygraph import base as imperative_base from ..fluid.framework import Parameter, Variable -from ..fluid.layer_helper import LayerHelper from ..nn.clip import GradientClipBase from .lr import LRScheduler from .optimizer import Optimizer @@ -333,62 +332,6 @@ class AdamW(Optimizer): self._param_groups.append(param_group) - def _create_master_weight(self, param): - if param.name in self._master_weights: - var = self._master_weights[param.name] - else: - assert isinstance(self.helper, LayerHelper) - - var_name = param.name + "_fp32_master" - var_name = unique_name.generate(var_name) - var = paddle.static.create_global_var( - name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True, - ) - block = self.helper.startup_program.global_block() - block.append_op( - type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32, - }, - ) - self._master_weights[param.name] = var - return var - - def _get_accumulator(self, name, param): - """Utility function to fetch an accumulator for a parameter - Args: - name: name of the accumulator - param: parameter variable for which accumulator is to be fetched - Returns: - accumulator variable for the parameter - """ - if self._name is not None: - name = self._name + "_" + name - find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( - param.dtype - ) - target_param = ( - self._master_weights[param.name] if find_master else param - ) - target_name = target_param.name - if ( - name not in self._accumulators - or target_name not in self._accumulators[name] - ): - raise Exception( - "Accumulator {} does not exist for parameter {}".format( - name, target_name - ) - ) - return self._accumulators[name][target_name] - def _add_moments_pows(self, p): acc_dtype = p.dtype if self._is_dtype_fp16_or_bf16(acc_dtype): @@ -453,16 +396,16 @@ class AdamW(Optimizer): ): with_decay = False - moment1 = self._get_accumulator( + moment1 = self._get_accumulator_master( self._moment1_acc_str, param_and_grad[0] ) - moment2 = self._get_accumulator( + moment2 = self._get_accumulator_master( self._moment2_acc_str, param_and_grad[0] ) - beta1_pow_acc = self._get_accumulator( + beta1_pow_acc = self._get_accumulator_master( self._beta1_pow_acc_str, param_and_grad[0] ) - beta2_pow_acc = self._get_accumulator( + beta2_pow_acc = self._get_accumulator_master( self._beta2_pow_acc_str, param_and_grad[0] ) find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( diff --git a/python/paddle/optimizer/lamb.py b/python/paddle/optimizer/lamb.py index e531e785e31..e7aeede370d 100644 --- a/python/paddle/optimizer/lamb.py +++ b/python/paddle/optimizer/lamb.py @@ -12,13 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import paddle from paddle import _C_ops from paddle.fluid.executor import global_scope -from ..fluid import core, framework, unique_name +from ..fluid import core, framework from ..fluid.framework import Variable -from ..fluid.layer_helper import LayerHelper from .optimizer import Optimizer __all__ = [] @@ -154,35 +152,6 @@ class Lamb(Optimizer): master_p_t = None return p_t, master_p_t - def _create_master_weight(self, param): - assert self._multi_precision - if param.name in self._master_weights: - var = self._master_weights[param.name] - else: - assert isinstance(self.helper, LayerHelper) - - var_name = param.name + "_fp32_master" - var_name = unique_name.generate(var_name) - var = paddle.static.create_global_var( - name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True, - ) - block = self.helper.startup_program.global_block() - block.append_op( - type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32, - }, - ) - self._master_weights[param.name] = var - return var - def _create_accumulators(self, block, parameters): assert isinstance(block, framework.Block) if isinstance(parameters, dict): @@ -190,43 +159,15 @@ class Lamb(Optimizer): # Create accumulator tensors for first and second moments for p in parameters: - if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_moments_pows(master_p) else: self._add_moments_pows(p) - def _get_accumulator(self, name, param): - """Utility function to fetch an accumulator for a parameter - Args: - name: name of the accumulator - param: parameter variable for which accumulator is to be fetched - Returns: - accumulator variable for the parameter - """ - if self._name is not None: - name = self._name + "_" + name - find_master = ( - self._multi_precision and param.dtype == core.VarDesc.VarType.FP16 - ) - target_param = ( - self._master_weights[param.name] if find_master else param - ) - target_name = target_param.name - if ( - name not in self._accumulators - or target_name not in self._accumulators[name] - ): - raise Exception( - "Accumulator {} does not exist for parameter {}".format( - name, target_name - ) - ) - return self._accumulators[name][target_name] - def _add_moments_pows(self, p): acc_dtype = p.dtype - if acc_dtype == core.VarDesc.VarType.FP16: + if self._is_dtype_fp16_or_bf16(acc_dtype): acc_dtype = core.VarDesc.VarType.FP32 self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype) @@ -261,16 +202,16 @@ class Lamb(Optimizer): block.program._use_lamb = True - moment1 = self._get_accumulator( + moment1 = self._get_accumulator_master( self._moment1_acc_str, param_and_grad[0] ) - moment2 = self._get_accumulator( + moment2 = self._get_accumulator_master( self._moment2_acc_str, param_and_grad[0] ) - beta1_pow_acc = self._get_accumulator( + beta1_pow_acc = self._get_accumulator_master( self._beta1_pow_acc_str, param_and_grad[0] ) - beta2_pow_acc = self._get_accumulator( + beta2_pow_acc = self._get_accumulator_master( self._beta2_pow_acc_str, param_and_grad[0] ) @@ -283,9 +224,8 @@ class Lamb(Optimizer): weight_decay = self._lamb_weight_decay lr = self._create_param_lr(param_and_grad) - find_master = ( - self._multi_precision - and param_and_grad[0].dtype == core.VarDesc.VarType.FP16 + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param_and_grad[0].dtype ) p_name = 
param_and_grad[0].name if find_master: diff --git a/python/paddle/optimizer/momentum.py b/python/paddle/optimizer/momentum.py index 3b20777599f..cf14efb8525 100644 --- a/python/paddle/optimizer/momentum.py +++ b/python/paddle/optimizer/momentum.py @@ -19,8 +19,7 @@ from paddle import _C_ops from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.regularizer import L2DecayRegularizer -from ..fluid import core, framework, unique_name -from ..fluid.layer_helper import LayerHelper +from ..fluid import core, framework from .optimizer import Optimizer __all__ = [] @@ -201,64 +200,6 @@ class Momentum(Optimizer): reg_coeff = weight_decay return reg_method, reg_coeff - def _create_master_weight(self, param): - if param.name in self._master_weights: - var = self._master_weights[param.name] - else: - assert isinstance(self.helper, LayerHelper) - - var_name = param.name + "_fp32_master" - var_name = unique_name.generate(var_name) - var = paddle.static.create_global_var( - name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True, - ) - block = self.helper.startup_program.global_block() - block.append_op( - type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32, - }, - ) - self._master_weights[param.name] = var - return var - - def _get_accumulator(self, name, param): - """Utility function to fetch an accumulator for a parameter - - Args: - name: name of the accumulator - param: parameter variable for which accumulator is to be fetched - - Returns: - accumulator variable for the parameter - """ - if self._name is not None: - name = self._name + "_" + name - find_master = ( - self._multi_precision and param.dtype == core.VarDesc.VarType.FP16 - ) - target_param = ( - self._master_weights[param.name] if find_master else param - ) - target_name = target_param.name - if ( - name not in self._accumulators - or target_name not in self._accumulators[name] - ): - raise Exception( - "Accumulator {} does not exist for parameter {}".format( - name, target_name - ) - ) - return self._accumulators[name][target_name] - def _create_accumulators(self, block, parameters): ''' if framework._non_static_mode(): @@ -270,16 +211,16 @@ class Momentum(Optimizer): parameters = self._update_param_group(parameters) for p in parameters: - if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_accumulator(self._velocity_acc_str, master_p) continue if ( - p.dtype == core.VarDesc.VarType.FP16 + self._is_dtype_fp16_or_bf16(p.dtype) and not self._multi_precision ): warnings.warn( - "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." + "Accumulating with FP16/BF16 in optimizer can lead to poor accuracy or slow convergence." "Consider using multi_precision=True option of the Momentum optimizer." 
) self._add_accumulator(self._velocity_acc_str, p) @@ -304,7 +245,7 @@ class Momentum(Optimizer): if isinstance(param_and_grad, dict): param_and_grad = self._update_param_group(param_and_grad) - velocity_acc = self._get_accumulator( + velocity_acc = self._get_accumulator_master( self._velocity_acc_str, param_and_grad[0] ) lr = self._create_param_lr(param_and_grad) @@ -323,9 +264,8 @@ class Momentum(Optimizer): regularization_method = "" regularization_coeff = 0.0 - find_master = ( - self._multi_precision - and param_and_grad[0].dtype == core.VarDesc.VarType.FP16 + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param_and_grad[0].dtype ) master_weight = ( self._master_weights[param_and_grad[0].name] @@ -388,7 +328,7 @@ class Momentum(Optimizer): def _multi_tensor_init(self, target_block, parameters, param_group_idx): """ - All parameters used for optimizer (such as: parameters, master_weight, velocity_acc for momentum) calculations are grouped into a python list by data type (float16, float32). + All parameters used for optimizer (such as: parameters, master_weight, velocity_acc for momentum) calculations are grouped into a python list by data type (float16, bf16, float32). This function will be overridden in the corresponding optimizer file. Args: @@ -397,7 +337,9 @@ class Momentum(Optimizer): """ self._create_accumulators(target_block, parameters) for param in parameters: - velocity_acc = self._get_accumulator(self._velocity_acc_str, param) + velocity_acc = self._get_accumulator_master( + self._velocity_acc_str, param + ) regularization_method = self._regularization_method regularization_coeff = self._regularization_coeff if hasattr(param, 'regularizer'): @@ -424,7 +366,7 @@ class Momentum(Optimizer): self._regularization_coeff_dict['FP32_LODTensor'][ param_group_idx ].append(regularization_coeff) - elif param.dtype == paddle.float16: + elif self._is_dtype_fp16_or_bf16(param.dtype): self._param_dict['FP16_LODTensor'][param_group_idx].append( param ) @@ -447,7 +389,7 @@ class Momentum(Optimizer): ].append(regularization_coeff) else: raise ValueError( - "Now multi_tensor_momentum only support fp32 and fp16 parameters and grad is LOD_TENSOR." + "Now multi_tensor_momentum only support fp32, fp16 or bf16 parameters and grad is LOD_TENSOR." 
) def _append_optimize_multi_tensor_op( @@ -478,7 +420,7 @@ class Momentum(Optimizer): lr = self._create_param_lr(param_and_grad) lr_dict['FP32_LODTensor'].append(lr) elif ( - param_and_grad[0].dtype == paddle.float16 + self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype) and param_and_grad[1].type == core.VarDesc.VarType.LOD_TENSOR ): @@ -509,7 +451,7 @@ class Momentum(Optimizer): lr = self._create_param_lr(param_and_grad) lr_dict['FP32_LODTensor'].append(lr) elif ( - param_and_grad[0].dtype == paddle.float16 + self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype) and param_and_grad[1].type == core.VarDesc.VarType.LOD_TENSOR ): diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 71ca201ff9f..43f15b55e3a 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -636,6 +636,34 @@ class Optimizer: else: return self._global_learning_rate() + def _create_master_weight(self, param): + if param.name in self._master_weights: + var = self._master_weights[param.name] + else: + assert isinstance(self.helper, LayerHelper) + + var_name = param.name + "_fp32_master" + var_name = unique_name.generate(var_name) + var = paddle.static.create_global_var( + name=var_name, + shape=param.shape, + value=0, + dtype='float32', + persistable=True, + ) + block = self.helper.startup_program.global_block() + block.append_op( + type="cast", + inputs={"X": [param]}, + outputs={"Out": [var]}, + attrs={ + "in_dtype": param.dtype, + "out_dtype": core.VarDesc.VarType.FP32, + }, + ) + self._master_weights[param.name] = var + return var + def _create_accumulators(self, block, parameters): """Create all accumulators needed by the parameters @@ -767,6 +795,34 @@ class Optimizer: ) return self._accumulators[name][param.name] + def _get_accumulator_master(self, name, param): + """Utility function to fetch an accumulator for a parameter + Args: + name: name of the accumulator + param: parameter variable for which accumulator is to be fetched + Returns: + accumulator variable for the parameter + """ + if self._name is not None: + name = self._name + "_" + name + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param.dtype + ) + target_param = ( + self._master_weights[param.name] if find_master else param + ) + target_name = target_param.name + if ( + name not in self._accumulators + or target_name not in self._accumulators[name] + ): + raise Exception( + "Accumulator {} does not exist for parameter {}".format( + name, target_name + ) + ) + return self._accumulators[name][target_name] + def _update_param_device_map(self, parameters_and_grads, target_block): for param_and_grad in parameters_and_grads: if param_and_grad[0].stop_gradient is False: diff --git a/python/paddle/optimizer/rmsprop.py b/python/paddle/optimizer/rmsprop.py index 65a827631d4..266e771647d 100644 --- a/python/paddle/optimizer/rmsprop.py +++ b/python/paddle/optimizer/rmsprop.py @@ -14,12 +14,10 @@ import warnings -import paddle from paddle import _C_ops -from ..fluid import core, framework, unique_name +from ..fluid import framework from ..fluid.framework import in_dygraph_mode -from ..fluid.layer_helper import LayerHelper from .optimizer import Optimizer __all__ = [] @@ -197,62 +195,6 @@ class RMSProp(Optimizer): 'centered': centered, } - def _create_master_weight(self, param): - if param.name in self._master_weights: - var = self._master_weights[param.name] - else: - assert isinstance(self.helper, LayerHelper) - - var_name = param.name + "_fp32_master" - 
var_name = unique_name.generate(var_name) - var = paddle.static.create_global_var( - name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True, - ) - block = self.helper.startup_program.global_block() - block.append_op( - type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32, - }, - ) - self._master_weights[param.name] = var - return var - - def _get_accumulator(self, name, param): - """Utility function to fetch an accumulator for a parameter - Args: - name: name of the accumulator - param: parameter variable for which accumulator is to be fetched - Returns: - accumulator variable for the parameter - """ - if self._name is not None: - name = self._name + "_" + name - find_master = ( - self._multi_precision and param.dtype == core.VarDesc.VarType.FP16 - ) - target_param = ( - self._master_weights[param.name] if find_master else param - ) - target_name = target_param.name - if ( - name not in self._accumulators - or target_name not in self._accumulators[name] - ): - raise Exception( - "Accumulator {} does not exist for parameter {}".format( - name, target_name - ) - ) - return self._accumulators[name][target_name] - def _create_accumulators(self, block, parameters): if not isinstance(block, framework.Block): raise TypeError("block is not instance of framework.Block.") @@ -261,14 +203,14 @@ class RMSProp(Optimizer): parameters = parameters.get('params') for p in parameters: - if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) self._add_accumulator(self._momentum_acc_str, master_p) self._add_accumulator(self._mean_square_acc_str, master_p) self._add_accumulator(self._mean_grad_acc_str, master_p) continue if ( - p.dtype == core.VarDesc.VarType.FP16 + self._is_dtype_fp16_or_bf16(p.dtype) and not self._multi_precision ): warnings.warn( @@ -286,18 +228,17 @@ class RMSProp(Optimizer): if isinstance(param_and_grad, dict): param_and_grad = self._update_param_group(param_and_grad) - momentum_acc = self._get_accumulator( + momentum_acc = self._get_accumulator_master( self._momentum_acc_str, param_and_grad[0] ) - mean_square_acc = self._get_accumulator( + mean_square_acc = self._get_accumulator_master( self._mean_square_acc_str, param_and_grad[0] ) - mean_grad_acc = self._get_accumulator( + mean_grad_acc = self._get_accumulator_master( self._mean_grad_acc_str, param_and_grad[0] ) - find_master = ( - self._multi_precision - and param_and_grad[0].dtype == core.VarDesc.VarType.FP16 + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param_and_grad[0].dtype ) master_weight = ( self._master_weights[param_and_grad[0].name] diff --git a/python/paddle/optimizer/sgd.py b/python/paddle/optimizer/sgd.py index c188cd15a8c..ffb091ae5a5 100644 --- a/python/paddle/optimizer/sgd.py +++ b/python/paddle/optimizer/sgd.py @@ -14,13 +14,11 @@ import warnings -import paddle from paddle import _C_ops -from ..fluid import core, framework, unique_name +from ..fluid import framework from ..fluid.dygraph import no_grad from ..fluid.framework import in_dygraph_mode -from ..fluid.layer_helper import LayerHelper from .optimizer import Optimizer __all__ = [] @@ -94,34 +92,6 @@ class SGD(Optimizer): self._multi_precision = multi_precision self._master_weights = {} - def _create_master_weight(self, param): - if param.name in self._master_weights: - var = 
self._master_weights[param.name] - else: - assert isinstance(self.helper, LayerHelper) - - var_name = param.name + "_fp32_master" - var_name = unique_name.generate(var_name) - var = paddle.static.create_global_var( - name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True, - ) - block = self.helper.startup_program.global_block() - block.append_op( - type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32, - }, - ) - self._master_weights[param.name] = var - return var - def _create_accumulators(self, block, parameters): assert isinstance(block, framework.Block) if isinstance(parameters, dict): @@ -129,15 +99,15 @@ class SGD(Optimizer): # Create accumulator tensors for first and second moments for p in parameters: - if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): master_p = self._create_master_weight(p) continue if ( - p.dtype == core.VarDesc.VarType.FP16 + self._is_dtype_fp16_or_bf16(p.dtype) and not self._multi_precision ): warnings.warn( - "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." + "Accumulating with FP16/BF16 in optimizer can lead to poor accuracy or slow convergence." "Consider using multi_precision=True option of the Adam optimizer." ) @@ -146,9 +116,8 @@ class SGD(Optimizer): if isinstance(param_and_grad, dict): param_and_grad = self._update_param_group(param_and_grad) - find_master = ( - self._multi_precision - and param_and_grad[0].dtype == core.VarDesc.VarType.FP16 + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param_and_grad[0].dtype ) master_weight = ( self._master_weights[param_and_grad[0].name] -- GitLab
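
The pattern this patch centralizes in the base Optimizer (via _create_master_weight, _get_accumulator_master and _is_dtype_fp16_or_bf16) is the FP32 "master weight" scheme: when a parameter is FP16 or BF16 (which carry only 10 and 7 mantissa bits, so small updates round away), the optimizer keeps an FP32 copy, applies updates to that copy, and casts the result back into the low-precision parameter. The sketch below is a minimal standalone illustration of that scheme in plain NumPy, not Paddle code; TinySGD and is_low_precision are illustrative names, and FP16 stands in for both FP16 and BF16 since NumPy has no native bfloat16.

import numpy as np

def is_low_precision(dtype):
    # Mirrors the intent of Optimizer._is_dtype_fp16_or_bf16 in this patch:
    # decide whether a parameter needs an FP32 master copy.
    return dtype == np.float16  # fp16 stands in for fp16/bf16 here

class TinySGD:
    def __init__(self, params, lr=0.1, multi_precision=True):
        self.lr = lr
        # One FP32 master copy per low-precision parameter, analogous to
        # _create_master_weight building a persistable "<param>_fp32_master" var.
        self.master = {
            id(p): p.astype(np.float32)
            for p in params
            if multi_precision and is_low_precision(p.dtype)
        }

    def step(self, params, grads):
        for p, g in zip(params, grads):
            m = self.master.get(id(p))
            if m is not None:
                # Accumulate the update in FP32, then cast back into the
                # parameter's own dtype (in Paddle the fused optimizer kernels
                # do this with the master weight tensor passed in the patch).
                m -= self.lr * g.astype(np.float32)
                p[...] = m.astype(p.dtype)
            else:
                p -= self.lr * g

# Usage: updates of 1e-4 round away entirely in FP16 (the parameter would stay
# at 1.0), but they accumulate in the FP32 master copy and show up after casting back.
w = np.ones(3, dtype=np.float16)
g = np.full(3, 1e-4, dtype=np.float16)
opt = TinySGD([w], lr=1.0)
for _ in range(5):
    opt.step([w], [g])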