diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py
index bfe1429342f29e3bc624360f49eb65da8c3d2d8e..937d30cd74aac6a6b14375cdf9cbd7f65fdf6ddb 100644
--- a/python/paddle/fluid/tests/unittests/test_adam_op.py
+++ b/python/paddle/fluid/tests/unittests/test_adam_op.py
@@ -1104,14 +1104,22 @@ class TestMultiTensorAdam(unittest.TestCase):
                 multi_precision=use_amp,
             )
         else:
+            parameters = list(model.parameters())
+            param_num = len(parameters)
             optimizer = paddle.optimizer.Adam(
                 parameters=[
                     {
-                        'params': model.parameters(),
+                        'params': parameters[: int(param_num / 2)],
                         'weight_decay': 0.001,
                         'beta1': 0.1,
                         'beta2': 0.99,
-                    }
+                    },
+                    {
+                        'params': parameters[int(param_num / 2) :],
+                        'weight_decay': 0.001,
+                        'beta1': 0.1,
+                        'beta2': 0.99,
+                    },
                 ],
                 use_multi_tensor=use_multi_tensor,
                 multi_precision=use_amp,
diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py
index 991291e36fb7cc19f5c70b226e0135a3c22fbe6d..017b001e259d634679e180465d1c73d5abf0621a 100644
--- a/python/paddle/fluid/tests/unittests/test_momentum_op.py
+++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py
@@ -889,14 +889,22 @@ class TestMultiTensorMomentumDygraph(unittest.TestCase):
                 multi_precision=use_amp,
             )
         else:
+            parameters = list(model.parameters())
+            n = len(parameters)
             optimizer = paddle.optimizer.Momentum(
                 parameters=[
                     {
-                        'params': model.parameters(),
+                        'params': parameters[: int(n / 2)],
                         'weight_decay': 0.001,
                         'learning_rate': 0.1,
                         'momentum': 0.99,
-                    }
+                    },
+                    {
+                        'params': parameters[int(n / 2) :],
+                        'weight_decay': 0.001,
+                        'learning_rate': 0.1,
+                        'momentum': 0.99,
+                    },
                 ],
                 use_multi_tensor=use_multi_tensor,
                 multi_precision=use_amp,
diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py
index c6e03ccf64cc5ab5edc65045a1ac889bd221f405..37225f62c313572c35df9d491729fa4a30517bd6 100644
--- a/python/paddle/optimizer/adam.py
+++ b/python/paddle/optimizer/adam.py
@@ -217,21 +217,13 @@ class Adam(Optimizer):
         self._use_multi_tensor = use_multi_tensor
         if self._use_multi_tensor:
-            self._param_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
-            self._moment1_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
-            self._moment2_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
-            self._beta1_pow_acc_dict = {
-                'FP32_LODTensor': [],
-                'FP16_LODTensor': [],
-            }
-            self._beta2_pow_acc_dict = {
-                'FP32_LODTensor': [],
-                'FP16_LODTensor': [],
-            }
-            self._master_weight_dict = {
-                'FP32_LODTensor': None,
-                'FP16_LODTensor': [],
-            }
+            self._param_dict = self._create_multi_tensor_dict()
+            self._moment1_dict = self._create_multi_tensor_dict()
+            self._moment2_dict = self._create_multi_tensor_dict()
+            self._beta1_pow_acc_dict = self._create_multi_tensor_dict()
+            self._beta2_pow_acc_dict = self._create_multi_tensor_dict()
+            self._master_weight_dict = self._create_multi_tensor_dict()
+            self._master_weight_dict['FP32_LODTensor'] = None

     def _create_master_weight(self, param):
         if param.name in self._master_weights:
@@ -550,11 +542,14 @@ class Adam(Optimizer):
                     params_grads.append((param, grad_var))

             optimize_ops = self._apply_optimize(
-                loss=None, startup_program=None, params_grads=params_grads
+                loss=None,
+                startup_program=None,
+                params_grads=params_grads,
+                param_group_idx=0,
             )
         else:
             # optimize parameters in groups
-            for param_group in self._param_groups:
+            for idx, param_group in enumerate(self._param_groups):
                 params_grads = defaultdict(lambda: list())
                 for param in param_group['params']:
                     if param.stop_gradient:
@@ -566,10 +561,13 @@
                     {k: v for k, v in param_group.items() if k != 'params'}
                 )
                 self._apply_optimize(
-                    loss=None, startup_program=None, params_grads=params_grads
+                    loss=None,
+                    startup_program=None,
+                    params_grads=params_grads,
+                    param_group_idx=idx,
                 )

-    def _multi_tensor_init(self, target_block, parameters):
+    def _multi_tensor_init(self, target_block, parameters, param_group_idx):
         """
         All parameters used for optimizer (such as: parameters, master_weight, velocity_acc for momentum) calculations are grouped into a python list by data type (float16, float32).
         This function will be overridden in the corresponding optimizer file.
@@ -589,21 +587,41 @@
             )

             if param.dtype == paddle.float32:
-                self._param_dict['FP32_LODTensor'].append(param)
-                self._moment1_dict['FP32_LODTensor'].append(moment1)
-                self._moment2_dict['FP32_LODTensor'].append(moment2)
-                self._beta1_pow_acc_dict['FP32_LODTensor'].append(beta1_pow_acc)
-                self._beta2_pow_acc_dict['FP32_LODTensor'].append(beta2_pow_acc)
+                self._param_dict['FP32_LODTensor'][param_group_idx].append(
+                    param
+                )
+                self._moment1_dict['FP32_LODTensor'][param_group_idx].append(
+                    moment1
+                )
+                self._moment2_dict['FP32_LODTensor'][param_group_idx].append(
+                    moment2
+                )
+                self._beta1_pow_acc_dict['FP32_LODTensor'][
+                    param_group_idx
+                ].append(beta1_pow_acc)
+                self._beta2_pow_acc_dict['FP32_LODTensor'][
+                    param_group_idx
+                ].append(beta2_pow_acc)
             elif param.dtype == paddle.float16:
-                self._param_dict['FP16_LODTensor'].append(param)
-                self._moment1_dict['FP16_LODTensor'].append(moment1)
-                self._moment2_dict['FP16_LODTensor'].append(moment2)
-                self._beta1_pow_acc_dict['FP16_LODTensor'].append(beta1_pow_acc)
-                self._beta2_pow_acc_dict['FP16_LODTensor'].append(beta2_pow_acc)
+                self._param_dict['FP16_LODTensor'][param_group_idx].append(
+                    param
+                )
+                self._moment1_dict['FP16_LODTensor'][param_group_idx].append(
+                    moment1
+                )
+                self._moment2_dict['FP16_LODTensor'][param_group_idx].append(
+                    moment2
+                )
+                self._beta1_pow_acc_dict['FP16_LODTensor'][
+                    param_group_idx
+                ].append(beta1_pow_acc)
+                self._beta2_pow_acc_dict['FP16_LODTensor'][
+                    param_group_idx
+                ].append(beta2_pow_acc)
                 if self._multi_precision:
-                    self._master_weight_dict['FP16_LODTensor'].append(
-                        self._master_weights[param.name]
-                    )
+                    self._master_weight_dict['FP16_LODTensor'][
+                        param_group_idx
+                    ].append(self._master_weights[param.name])
                 else:
                     self._master_weight_dict['FP16_LODTensor'] = None
             else:
@@ -612,7 +630,10 @@ class Adam(Optimizer):
                 )

     def _append_optimize_multi_tensor_op(
-        self, target_block, parameters_and_grads
+        self,
+        target_block,
+        parameters_and_grads,
+        param_group_idx,
     ):
         """
         For Multi Tensor, append optimize merged_operator to block.
@@ -677,7 +698,7 @@ class Adam(Optimizer):
         multi_tensor_list = ['FP32_LODTensor', 'FP16_LODTensor']
         for key in multi_tensor_list:
-            if len(self._param_dict[key]) > 0:
+            if len(self._param_dict[key][param_group_idx]) > 0:
                 find_master = (
                     self._multi_precision and key == 'FP16_LODTensor'
                 )
                 _beta1 = (
@@ -692,16 +713,23 @@
                 )

                 if framework._non_static_mode():
+                    master_weight = self._master_weight_dict[key]
+                    master_weight = (
+                        master_weight[param_group_idx]
+                        if master_weight is not None
+                        else None
+                    )
                     if in_dygraph_mode():
+
                         _, _, _, _, _, _ = _C_ops.merged_adam_(
-                            self._param_dict[key],
+                            self._param_dict[key][param_group_idx],
                             grad_dict[key],
                             lr_dict[key],
-                            self._moment1_dict[key],
-                            self._moment2_dict[key],
-                            self._beta1_pow_acc_dict[key],
-                            self._beta2_pow_acc_dict[key],
-                            self._master_weight_dict[key],
+                            self._moment1_dict[key][param_group_idx],
+                            self._moment2_dict[key][param_group_idx],
+                            self._beta1_pow_acc_dict[key][param_group_idx],
+                            self._beta2_pow_acc_dict[key][param_group_idx],
+                            master_weight,
                             _beta1,
                             _beta2,
                             self._epsilon,
@@ -710,20 +738,20 @@
                         )
                     else:
                         _, _, _, _, _, _ = _legacy_C_ops.merged_adam(
-                            self._param_dict[key],
+                            self._param_dict[key][param_group_idx],
                             grad_dict[key],
                             lr_dict[key],
-                            self._moment1_dict[key],
-                            self._moment2_dict[key],
-                            self._beta1_pow_acc_dict[key],
-                            self._beta2_pow_acc_dict[key],
-                            self._master_weight_dict[key],
-                            self._param_dict[key],
-                            self._moment1_dict[key],
-                            self._moment2_dict[key],
-                            self._beta1_pow_acc_dict[key],
-                            self._beta2_pow_acc_dict[key],
-                            self._master_weight_dict[key],
+                            self._moment1_dict[key][param_group_idx],
+                            self._moment2_dict[key][param_group_idx],
+                            self._beta1_pow_acc_dict[key][param_group_idx],
+                            self._beta2_pow_acc_dict[key][param_group_idx],
+                            master_weight,
+                            self._param_dict[key][param_group_idx],
+                            self._moment1_dict[key][param_group_idx],
+                            self._moment2_dict[key][param_group_idx],
+                            self._beta1_pow_acc_dict[key][param_group_idx],
+                            self._beta2_pow_acc_dict[key][param_group_idx],
+                            master_weight,
                             'epsilon',
                             self._epsilon,
                             'beta1',
@@ -735,20 +763,28 @@
                         )
                 else:
                     inputs = {
-                        "Param": self._param_dict[key],
+                        "Param": self._param_dict[key][param_group_idx],
                         "Grad": grad_dict[key],
                         "LearningRate": lr_dict[key],
-                        "Moment1": self._moment1_dict[key],
-                        "Moment2": self._moment2_dict[key],
-                        "Beta1Pow": self._beta1_pow_acc_dict[key],
-                        "Beta2Pow": self._beta2_pow_acc_dict[key],
+                        "Moment1": self._moment1_dict[key][param_group_idx],
+                        "Moment2": self._moment2_dict[key][param_group_idx],
+                        "Beta1Pow": self._beta1_pow_acc_dict[key][
+                            param_group_idx
+                        ],
+                        "Beta2Pow": self._beta2_pow_acc_dict[key][
+                            param_group_idx
+                        ],
                     }
                     outputs = {
-                        "ParamOut": self._param_dict[key],
-                        "Moment1Out": self._moment1_dict[key],
-                        "Moment2Out": self._moment2_dict[key],
-                        "Beta1PowOut": self._beta1_pow_acc_dict[key],
-                        "Beta2PowOut": self._beta2_pow_acc_dict[key],
+                        "ParamOut": self._param_dict[key][param_group_idx],
+                        "Moment1Out": self._moment1_dict[key][param_group_idx],
+                        "Moment2Out": self._moment2_dict[key][param_group_idx],
+                        "Beta1PowOut": self._beta1_pow_acc_dict[key][
+                            param_group_idx
+                        ],
+                        "Beta2PowOut": self._beta2_pow_acc_dict[key][
+                            param_group_idx
+                        ],
                     }
                     attrs = {
                         "epsilon": self._epsilon,
@@ -756,10 +792,12 @@
                         "beta2": _beta2,
                     }
                     if find_master:
-                        inputs["MasterParam"] = self._master_weight_dict[key]
+                        inputs["MasterParam"] = self._master_weight_dict[key][
+                            param_group_idx
+                        ]
                         outputs["MasterParamOut"] = self._master_weight_dict[
                             key
-                        ]
+                        ][param_group_idx]
                         attrs["multi_precision"] = find_master
                 target_block.append_op(
                     type="merged_adam",
diff --git a/python/paddle/optimizer/momentum.py b/python/paddle/optimizer/momentum.py
index 8cca050625a425c54c24336bc3336f626e17aae9..320f102cd927a8938e49a1ccd5a390c772241c25 100644
--- a/python/paddle/optimizer/momentum.py
+++ b/python/paddle/optimizer/momentum.py
@@ -184,20 +184,12 @@ class Momentum(Optimizer):
         }
         self._use_multi_tensor = use_multi_tensor
         if self._use_multi_tensor:
-            self._param_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
-            self._velocity_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
-            self._master_weight_dict = {
-                'FP32_LODTensor': None,
-                'FP16_LODTensor': [],
-            }
-            self._regularization_method_dict = {
-                'FP32_LODTensor': [],
-                'FP16_LODTensor': [],
-            }
-            self._regularization_coeff_dict = {
-                'FP32_LODTensor': [],
-                'FP16_LODTensor': [],
-            }
+            self._param_dict = self._create_multi_tensor_dict()
+            self._velocity_dict = self._create_multi_tensor_dict()
+            self._master_weight_dict = self._create_multi_tensor_dict()
+            self._master_weight_dict['FP32_LODTensor'] = None
+            self._regularization_method_dict = self._create_multi_tensor_dict()
+            self._regularization_coeff_dict = self._create_multi_tensor_dict()

     def _update_regularization(self, weight_decay):
         reg_method = ""
@@ -420,7 +412,7 @@ class Momentum(Optimizer):
         return momentum_op

-    def _multi_tensor_init(self, target_block, parameters):
+    def _multi_tensor_init(self, target_block, parameters, param_group_idx):
         """
         All parameters used for optimizer (such as: parameters, master_weight, velocity_acc for momentum) calculations are grouped into a python list by data type (float16, float32).
         This function will be overridden in the corresponding optimizer file.
@@ -445,37 +437,50 @@
             regularization_method = ""
             regularization_coeff = 0.0
             if param.dtype == paddle.float32:
-                self._param_dict['FP32_LODTensor'].append(param)
-                self._velocity_dict['FP32_LODTensor'].append(velocity_acc)
-                # fp32 no master weight
-                self._regularization_method_dict['FP32_LODTensor'].append(
-                    regularization_method
+                self._param_dict['FP32_LODTensor'][param_group_idx].append(
+                    param
                 )
-                self._regularization_coeff_dict['FP32_LODTensor'].append(
-                    regularization_coeff
+                self._velocity_dict['FP32_LODTensor'][param_group_idx].append(
+                    velocity_acc
                 )
+                # fp32 no master weight
+                self._regularization_method_dict['FP32_LODTensor'][
+                    param_group_idx
+                ].append(regularization_method)
+                self._regularization_coeff_dict['FP32_LODTensor'][
+                    param_group_idx
+                ].append(regularization_coeff)
             elif param.dtype == paddle.float16:
-                self._param_dict['FP16_LODTensor'].append(param)
-                self._velocity_dict['FP16_LODTensor'].append(velocity_acc)
-                if self._multi_precision:
-                    self._master_weight_dict['FP16_LODTensor'].append(
-                        self._master_weights[param.name]
-                    )
-                else:
-                    self._master_weight_dict['FP16_LODTensor'] = None
-                self._regularization_method_dict['FP16_LODTensor'].append(
-                    regularization_method
+                self._param_dict['FP16_LODTensor'][param_group_idx].append(
+                    param
                 )
-                self._regularization_coeff_dict['FP16_LODTensor'].append(
-                    regularization_coeff
+                self._velocity_dict['FP16_LODTensor'][param_group_idx].append(
+                    velocity_acc
                 )
+                if self._multi_precision:
+                    self._master_weight_dict['FP16_LODTensor'][
+                        param_group_idx
+                    ].append(self._master_weights[param.name])
+                else:
+                    self._master_weight_dict['FP16_LODTensor'][
+                        param_group_idx
+                    ] = None
+                self._regularization_method_dict['FP16_LODTensor'][
+                    param_group_idx
+                ].append(regularization_method)
+                self._regularization_coeff_dict['FP16_LODTensor'][
+                    param_group_idx
+                ].append(regularization_coeff)
             else:
                 raise ValueError(
                     "Now multi_tensor_momentum only support fp32 and fp16 parameters and grad is LOD_TENSOR."
                 )

     def _append_optimize_multi_tensor_op(
-        self, target_block, parameters_and_grads
+        self,
+        target_block,
+        parameters_and_grads,
+        param_group_idx,
     ):
         """
         For Multi Tensor, append optimize merged_operator to block.
@@ -540,71 +545,92 @@
         multi_tensor_list = ['FP32_LODTensor', 'FP16_LODTensor']
         for key in multi_tensor_list:
-            if len(self._param_dict[key]) > 0:
+            if len(self._param_dict[key][param_group_idx]) > 0:
                 find_master = self._multi_precision and key == 'FP16_LODTensor'
+                master_weight = self._master_weight_dict[key]
+                master_weight = (
+                    master_weight[param_group_idx]
+                    if master_weight is not None
+                    else None
+                )
+
                 if framework._non_static_mode():
                     if in_dygraph_mode():
                         _, _, _ = _C_ops.merged_momentum_(
-                            self._param_dict[key],
+                            self._param_dict[key][param_group_idx],
                             grad_dict[key],
-                            self._velocity_dict[key],
+                            self._velocity_dict[key][param_group_idx],
                             lr_dict[key],
-                            self._master_weight_dict[key],
+                            master_weight,
                             self._momentum,
                             self._use_nesterov,
-                            self._regularization_method_dict[key],
-                            self._regularization_coeff_dict[key],
+                            self._regularization_method_dict[key][
+                                param_group_idx
+                            ],
+                            self._regularization_coeff_dict[key][
+                                param_group_idx
+                            ],
                             find_master,
                             self._rescale_grad,
                         )
                     else:
                         _, _, _ = _legacy_C_ops.merged_momentum(
-                            self._param_dict[key],
+                            self._param_dict[key][param_group_idx],
                             grad_dict[key],
-                            self._velocity_dict[key],
+                            self._velocity_dict[key][param_group_idx],
                             lr_dict[key],
-                            self._master_weight_dict[key],
-                            self._param_dict[key],
-                            self._velocity_dict[key],
-                            self._master_weight_dict[key],
+                            master_weight,
+                            self._param_dict[key][param_group_idx],
+                            self._velocity_dict[key][param_group_idx],
+                            master_weight,
                             'mu',
                             self._momentum,
                             'use_nesterov',
                             self._use_nesterov,
                             'regularization_method',
-                            self._regularization_method_dict[key],
+                            self._regularization_method_dict[key][
+                                param_group_idx
+                            ],
                             'regularization_coeff',
-                            self._regularization_coeff_dict[key],
+                            self._regularization_coeff_dict[key][
+                                param_group_idx
+                            ],
                             'multi_precision',
                             find_master,
                         )
                 else:
                     inputs = {
-                        "Param": self._param_dict[key],
+                        "Param": self._param_dict[key][param_group_idx],
                         "Grad": grad_dict[key],
-                        "Velocity": self._velocity_dict[key],
+                        "Velocity": self._velocity_dict[key][param_group_idx],
                         "LearningRate": lr_dict[key],
                     }
                     outputs = {
-                        "ParamOut": self._param_dict[key],
-                        "VelocityOut": self._velocity_dict[key],
+                        "ParamOut": self._param_dict[key][param_group_idx],
+                        "VelocityOut": self._velocity_dict[key][
+                            param_group_idx
+                        ],
                     }
                     attrs = {
                         "mu": self._momentum,
                         "use_nesterov": self._use_nesterov,
                         "regularization_method": self._regularization_method_dict[
                             key
+                        ][
+                            param_group_idx
                         ],
                         "regularization_coeff": self._regularization_coeff_dict[
                             key
-                        ],
+                        ][param_group_idx],
                     }
                     if find_master:
-                        inputs["MasterParam"] = self._master_weight_dict[key]
+                        inputs["MasterParam"] = self._master_weight_dict[key][
+                            param_group_idx
+                        ]
                         outputs["MasterParamOut"] = self._master_weight_dict[
                             key
-                        ]
+                        ][param_group_idx]
                         attrs["multi_precision"] = find_master
                 target_block.append_op(
                     type="merged_momentum",
diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py
index 54ac0db5536cda4e466d4a8f3ae7b3c54ccc5b13..a7b672383a9ec61ff834b9b5bd0acf2de6c824e1 100644
--- a/python/paddle/optimizer/optimizer.py
+++ b/python/paddle/optimizer/optimizer.py
@@ -282,13 +282,20 @@ class Optimizer(object):
         # NOTE: Multi Tensor: Pass in all parameters and gradients to the op kernel of the Optimizer at one time for updating for dygraph mode.
         # Optimizer support list: [ paddle.optimizer.Momentum, paddle.optimizer.Adam].
         self._use_multi_tensor = None
-        self._param_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
+        self._param_dict = self._create_multi_tensor_dict()
         self._auxiliary_vars = {}

     def _set_auxiliary_var(self, key, val):
         self._auxiliary_vars[key] = val

+    def _create_multi_tensor_dict(self):
+        n = len(self._param_groups) if self._param_groups is not None else 1
+        return {
+            'FP32_LODTensor': [[] for _ in range(n)],
+            'FP16_LODTensor': [[] for _ in range(n)],
+        }
+
     def _get_auxiliary_var(self, key):
         return self._auxiliary_vars.get(key, None)
@@ -779,7 +786,9 @@ class Optimizer(object):
             device = self._param_device_map[param_name]
         return device

-    def _create_optimization_pass(self, parameters_and_grads):
+    def _create_optimization_pass(
+        self, parameters_and_grads, param_group_idx=0
+    ):
         """Add optimization operators to update gradients to tensors.

         Args:
@@ -825,10 +834,12 @@
             'Adam',
         ]:
             if (
-                len(self._param_dict['FP32_LODTensor']) == 0
-                and len(self._param_dict['FP16_LODTensor']) == 0
+                len(self._param_dict['FP32_LODTensor'][param_group_idx]) == 0
+                and len(self._param_dict['FP16_LODTensor'][param_group_idx])
+                == 0
             ):
                 if isinstance(parameters_and_grads, list):
+                    assert param_group_idx == 0
                     self._multi_tensor_init(
                         target_block,
                         [
@@ -836,6 +847,7 @@
                             for p in parameters_and_grads
                             if not p[0].stop_gradient
                         ],
+                        param_group_idx,
                     )
                 else:
                     self._update_param_group(parameters_and_grads)
@@ -846,10 +858,13 @@
                             for p in parameters_and_grads['params']
                             if not p[0].stop_gradient
                         ],
+                        param_group_idx,
                     )
             if framework._non_static_mode():
                 self._append_optimize_multi_tensor_op(
-                    target_block, parameters_and_grads
+                    target_block,
+                    parameters_and_grads,
+                    param_group_idx=param_group_idx,
                 )
             else:
                 self._update_param_device_map(
@@ -871,7 +886,9 @@
                 device = self._get_device_for_param(param_grad_list[0].name)
                 with device_guard(device):
                     self._append_optimize_multi_tensor_op(
-                        target_block, parameters_and_grads
+                        target_block,
+                        parameters_and_grads,
+                        param_group_idx=param_group_idx,
                     )
         else:
             if not framework._non_static_mode():
@@ -1095,7 +1112,9 @@
             optimize_ops = self._create_optimization_pass(params_grads)
         return optimize_ops

-    def _apply_optimize(self, loss, startup_program, params_grads):
+    def _apply_optimize(
+        self, loss, startup_program, params_grads, param_group_idx=0
+    ):
         """
         Second part of `minimize`, appending optimization operators for
         given `params_grads` pairs.
@@ -1128,8 +1147,11 @@
                     params_grads['params'] = self.append_regularization_ops(
                         params_grads['params'], self.regularization
                     )
-                optimize_ops = self._create_optimization_pass(params_grads)
+                optimize_ops = self._create_optimization_pass(
+                    params_grads, param_group_idx=param_group_idx
+                )
         else:
+            assert param_group_idx == 0
             program = loss.block.program
             with program_guard(program, startup_program):
                 optimize_ops = self.apply_gradients(params_grads)
@@ -1398,12 +1420,15 @@
                     params_grads.append((param, grad_var))

             self._apply_optimize(
-                loss=None, startup_program=None, params_grads=params_grads
+                loss=None,
+                startup_program=None,
+                params_grads=params_grads,
+                param_group_idx=0,
             )

         else:
             # optimize parameters in groups
-            for param_group in self._param_groups:
+            for idx, param_group in enumerate(self._param_groups):
                 params_grads = defaultdict(lambda: list())
                 for param in param_group['params']:
                     if param.stop_gradient:
@@ -1415,7 +1440,10 @@
                     {k: v for k, v in param_group.items() if k != 'params'}
                 )
                 self._apply_optimize(
-                    loss=None, startup_program=None, params_grads=params_grads
+                    loss=None,
+                    startup_program=None,
+                    params_grads=params_grads,
+                    param_group_idx=idx,
                 )

     def _add_param_group(self, param_group):
@@ -1475,7 +1503,7 @@
         pass

     @framework.dygraph_only
-    def _multi_tensor_init(self, target_block, parameters):
+    def _multi_tensor_init(self, target_block, parameters, param_group_idx):
         """
         All parameters used for optimizer (such as: parameters, master_weight, velocity_acc for momentum) calculations are grouped into a python list by data type (float16, float32).
         This function will be overridden in the corresponding optimizer file.
@@ -1488,7 +1516,7 @@

     @framework.dygraph_only
     def _append_optimize_multi_tensor_op(
-        self, target_block, parameters_and_grads
+        self, target_block, parameters_and_grads, param_group_idx
     ):
         """
         For Multi Tensor, append optimize merged_operator to block.
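
Reviewer note (illustration only, not part of the patch): the sketch below mirrors the updated unit tests and shows the usage this change targets. The toy two-layer model is a hypothetical example; what matters is that with use_multi_tensor=True each parameter group now gets its own slot, indexed by param_group_idx, in the per-dtype lists returned by _create_multi_tensor_dict().

    import paddle

    # Hypothetical toy network, used only to produce two parameter groups.
    model = paddle.nn.Sequential(
        paddle.nn.Linear(4, 8),
        paddle.nn.Linear(8, 2),
    )
    parameters = list(model.parameters())
    half = len(parameters) // 2

    # Same construction as the updated tests: two groups, multi-tensor path on.
    optimizer = paddle.optimizer.Adam(
        parameters=[
            {'params': parameters[:half], 'weight_decay': 0.001, 'beta1': 0.1, 'beta2': 0.99},
            {'params': parameters[half:], 'weight_decay': 0.001, 'beta1': 0.1, 'beta2': 0.99},
        ],
        use_multi_tensor=True,
    )

    out = model(paddle.rand([3, 4]))
    out.sum().backward()
    optimizer.step()        # one merged op per (parameter group, dtype) bucket
    optimizer.clear_grad()

The structural change in optimizer.py (_create_multi_tensor_dict() returning one list per parameter group for each of 'FP32_LODTensor' and 'FP16_LODTensor') is what makes this per-group bookkeeping possible; adam.py and momentum.py then index every accumulator and master-weight dict with param_group_idx so each group keeps its own state.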