From 0cc5e22c194445382e3bb89e90c6f2af76c7257a Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Mon, 20 Dec 2021 11:52:35 +0800 Subject: [PATCH] Add multi_tensor for momentum optimizer and clear_grads (#37564) * add multi_tensor for momentum and clear_grads for optimizer * fix bug for dygraph * add unittest * refine comment * add param_group * refine regularizaiton logic * del clear_grads * add clear_grads * add dispensable check of None * refine clear_grad * fix build bug * refine code by comment * refine code * add multi tensor check * refine param_group update * add multi tensor for static mode * refine comments * delete useless comma for momentum * refine comment for momentum * refine code by commment --- paddle/fluid/pybind/op_function_common.cc | 2 +- paddle/fluid/pybind/op_function_generator.h | 4 + paddle/fluid/pybind/pybind.cc | 8 + .../fluid/tests/unittests/test_momentum_op.py | 185 ++++++++++++++++++ python/paddle/optimizer/momentum.py | 180 +++++++++++++++-- python/paddle/optimizer/optimizer.py | 165 +++++++++++----- 6 files changed, 485 insertions(+), 59 deletions(-) diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index 1f761ae29c2..3ad4994a590 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -688,7 +688,7 @@ std::vector> GetVarBaseListFromArgs( ssize_t arg_idx, bool dispensable) { PyObject* list = PyTuple_GET_ITEM(args, arg_idx); - if (list == nullptr) { + if (list == nullptr || list == Py_None) { if (!dispensable) { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument '%s' (position %d) must be list of Tensor, but got " diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index 6fd5f659a99..c29228f2a5d 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -58,6 +58,8 @@ std::map> op_ins_map = { {"multiclass_nms3", {"BBoxes", "Scores", "RoisNum"}}, {"box_coder", {"PriorBox", "PriorBoxVar", "TargetBox"}}, {"momentum", {"Param", "Grad", "Velocity", "LearningRate", "MasterParam"}}, + {"merged_momentum", + {"Param", "Grad", "Velocity", "LearningRate", "MasterParam"}}, {"sparse_momentum", {"Param", "Grad", "Velocity", "Index", "LearningRate"}}, {"rnn", {"Input", "PreState", "WeightList", "SequenceLength"}}, {"run_program", {"X", "Params"}}, @@ -113,6 +115,7 @@ std::map> op_outs_map = { {"multiclass_nms3", {"Out", "NmsRoisNum"}}, {"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}}, {"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}}, + {"merged_momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}}, {"sparse_momentum", {"ParamOut", "VelocityOut"}}, {"rnn", {"DropoutState", "Reserve", "Out", "State"}}, {"run_program", {"DOut"}}, @@ -153,6 +156,7 @@ std::map> op_passing_outs_map = { {"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates", "out_old_num_accumulates", "out_num_updates"}}, {"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}}, + {"merged_momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}}, {"sparse_momentum", {"ParamOut", "VelocityOut"}}, {"batch_norm", {"MeanOut", "VarianceOut"}}, {"sync_batch_norm", {"MeanOut", "VarianceOut"}}, diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index f998c30dd15..a15e26b848e 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -579,6 +579,14 @@ PYBIND11_MODULE(core_noavx, 
m) { m.def("disable_signal_handler", &DisableSignalHandler); + m.def("clear_gradients", + [](std::vector> param_list, + bool set_to_zero) { + for (auto param : param_list) { + param->ClearGradient(set_to_zero); + } + }); + #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) m.def("cudnn_version", &platform::DnnVersion); m.def("gpu_memory_available", []() { diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py index 34e057a5a8a..a59b355b4a7 100644 --- a/python/paddle/fluid/tests/unittests/test_momentum_op.py +++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py @@ -21,6 +21,7 @@ from paddle.fluid.op import Operator from op_test import OpTest import paddle import paddle.fluid as fluid +import numpy def calculate_momentum_by_numpy(param, @@ -805,5 +806,189 @@ class TestMomentumV2Group(TestMomentumV2): adam.clear_gradients() +class TestMultiTensorMomentumDygraph(unittest.TestCase): + def _momentum_optimize_dygraph(self, + place, + use_param_attr=False, + use_param_group=False, + use_amp=False, + use_multi_tensor=False): + paddle.disable_static() + paddle.seed(10) + paddle.set_device(place) + input = paddle.randn((5, 5)) + weight_attr = paddle.ParamAttr( + learning_rate=0.5, + regularizer=paddle.regularizer.L2Decay(1.0), + trainable=True) + if use_param_attr: + model = paddle.nn.Linear(5, 5, weight_attr) + else: + model = paddle.nn.Linear(5, 5) + if not use_param_group: + optimizer = paddle.optimizer.Momentum( + parameters=model.parameters(), + use_multi_tensor=use_multi_tensor, + multi_precision=use_amp) + else: + optimizer = paddle.optimizer.Momentum( + parameters=[{ + 'params': model.parameters(), + 'weight_decay': 0.001, + 'learning_rate': 0.1, + 'momentum': 0.99 + }], + use_multi_tensor=use_multi_tensor, + multi_precision=use_amp) + for idx in range(5): + if place == 'gpu' and use_amp == True: + model = paddle.amp.decorate(models=model, level='O2') + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + if place == 'gpu' and use_amp == True: + with paddle.amp.auto_cast(level='O2'): + output = model(input) + loss = paddle.mean(output) + scaled = scaler.scale(loss) + scaled.backward() + scaler.step(optimizer) + optimizer.clear_grad(set_to_zero=False) + else: + output = model(input) + loss = paddle.mean(output) + # This can be any optimizer supported by dygraph. 
+ loss.backward() + optimizer.step() + optimizer.clear_grad(set_to_zero=False) + return output, model.parameters() + + def _get_places(self): + places = ['cpu'] + if paddle.is_compiled_with_cuda(): + places.append('gpu') + return places + + def _check_with_place_amp(self, place, use_amp): + output1, params1 = self._momentum_optimize_dygraph( + place=place, use_amp=use_amp, use_multi_tensor=True) + output2, params2 = self._momentum_optimize_dygraph( + place=place, use_amp=use_amp, use_multi_tensor=False) + self.assertEqual(np.allclose(output1, output2, rtol=1e-05), True) + for idx in range(len(params1)): + self.assertEqual( + np.allclose( + params1[idx], params2[idx], rtol=1e-05), True) + + def _check_with_param_arrt(self, place, use_amp): + output1, params1 = self._momentum_optimize_dygraph( + place=place, + use_amp=use_amp, + use_param_attr=True, + use_multi_tensor=True) + output2, params2 = self._momentum_optimize_dygraph( + place=place, + use_amp=use_amp, + use_param_attr=True, + use_multi_tensor=False) + self.assertEqual(np.allclose(output1, output2, rtol=1e-05), True) + for idx in range(len(params1)): + self.assertEqual( + np.allclose( + params1[idx], params2[idx], rtol=1e-05), True) + + def _check_with_param_group(self, place, use_amp): + output1, params1 = self._momentum_optimize_dygraph( + place=place, + use_amp=use_amp, + use_param_group=True, + use_multi_tensor=True) + output2, params2 = self._momentum_optimize_dygraph( + place=place, + use_amp=use_amp, + use_param_group=True, + use_multi_tensor=False) + self.assertEqual(np.allclose(output1, output2, rtol=1e-05), True) + for idx in range(len(params1)): + self.assertEqual( + np.allclose( + params1[idx], params2[idx], rtol=1e-05), True) + + def test_main(self): + for place in self._get_places(): + use_amp_list = [True, False] + for use_amp in use_amp_list: + self._check_with_place_amp(place, use_amp) + self._check_with_param_arrt(place, use_amp) + self._check_with_param_group(place, use_amp) + + +class TestMultiTensorMomentumStatic(unittest.TestCase): + def _momentum_optimize_static(self, + place, + use_amp=False, + use_multi_tensor=False): + paddle.enable_static() + paddle.seed(10) + np.random.seed(10) + if place == 'cpu': + use_amp = False + exe = paddle.static.Executor(place=place) + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + optimizer = paddle.optimizer.Momentum( + multi_precision=use_amp, use_multi_tensor=use_multi_tensor) + if use_amp: + optimizer = paddle.static.amp.decorate( + optimizer, + init_loss_scaling=128.0, + use_dynamic_loss_scaling=True, + use_pure_fp16=True, + use_fp16_guard=False) + with paddle.static.program_guard(train_program, startup_program): + if use_amp: + data = paddle.static.data( + shape=[2, 2], name='X', dtype='float16') + else: + data = paddle.static.data( + shape=[2, 2], name='X', dtype='float32') + hidden = paddle.static.nn.fc(x=data, size=10) + loss = paddle.fluid.layers.mean(hidden) + optimizer.minimize(loss) + exe.run(startup_program) + if use_amp: + optimizer.amp_init(place=place, scope=paddle.static.global_scope()) + x = numpy.random.random(size=(2, 2)).astype('float16') + else: + x = numpy.random.random(size=(2, 2)).astype('float32') + out = [] + for idx in range(5): + loss_data, = exe.run(train_program, + feed={"X": x}, + fetch_list=[loss.name]) + out.append(loss_data) + return out + + def _get_places(self): + places = ['cpu'] + if paddle.is_compiled_with_cuda(): + places.append('gpu') + return places + + def _check_with_place_amp(self, place, 
use_amp): + output1 = self._momentum_optimize_static( + place=place, use_amp=use_amp, use_multi_tensor=True) + output2 = self._momentum_optimize_static( + place=place, use_amp=use_amp, use_multi_tensor=False) + for idx in range(len(output1)): + self.assertEqual( + np.allclose( + output1[idx], output2[idx], rtol=1e-05), True) + + def test_main(self): + for place in self._get_places(): + use_amp_list = [True, False] + for use_amp in use_amp_list: + self._check_with_place_amp(place, use_amp) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/optimizer/momentum.py b/python/paddle/optimizer/momentum.py index fde3b286073..65425df72af 100644 --- a/python/paddle/optimizer/momentum.py +++ b/python/paddle/optimizer/momentum.py @@ -24,6 +24,7 @@ from ..fluid import layers import paddle.fluid as fluid from paddle.fluid.regularizer import L2DecayRegularizer from paddle import _C_ops +import paddle __all__ = [] @@ -74,6 +75,7 @@ class Momentum(Optimizer): multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false. rescale_grad (float, optional): Multiply the gradient with `rescale_grad` before updating. \ Often choose to be ``1.0/batch_size``. + use_multi_tensor (bool, optional): Whether to use multi-tensor strategy to update all parameters at once . Default is false. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . @@ -129,6 +131,7 @@ class Momentum(Optimizer): grad_clip=None, multi_precision=False, rescale_grad=1.0, + use_multi_tensor=False, name=None): if learning_rate is None: raise ValueError("learning_rate is not set") @@ -170,17 +173,22 @@ class Momentum(Optimizer): 'regularization_method': self._regularization_method, 'regularization_coeff': self._regularization_coeff, } - ''' - if framework.in_dygraph_mode(): - self.helper = LayerHelper(self.__class__.__name__) - if isinstance(self._parameter_list[0], dict): - for parameters in self._param_groups: - for p in parameters['params']: - self._add_accumulator(self._velocity_acc_str, p) - else: - for p in parameters: - self._add_accumulator(self._velocity_acc_str, p) - ''' + self._use_multi_tensor = use_multi_tensor + if self._use_multi_tensor: + self._param_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []} + self._velocity_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []} + self._master_weight_dict = { + 'FP32_LODTensor': None, + 'FP16_LODTensor': [] + } + self._regularization_method_dict = { + 'FP32_LODTensor': [], + 'FP16_LODTensor': [] + } + self._regularization_coeff_dict = { + 'FP32_LODTensor': [], + 'FP16_LODTensor': [] + } def _update_regularization(self, weight_decay): reg_method = "" @@ -353,6 +361,156 @@ class Momentum(Optimizer): return momentum_op + def _multi_tensor_init(self, target_block, parameters): + """ + All parameters used for optimizer (such as: parameters, master_weight, velocity_acc for momentum) calculations are grouped into a python list by data type (float16, float32). + This function will be overridden in the corresponding optimizer file. 
+ + Args: + target_block: the block in which the loss tensor is present + parameters: list of parameter tensors for the optimizer + """ + self._create_accumulators(target_block, parameters) + for param in parameters: + velocity_acc = self._get_accumulator(self._velocity_acc_str, param) + regularization_method = self._regularization_method + regularization_coeff = self._regularization_coeff + if hasattr(param, 'regularizer'): + # we skip param's l2decay before, so fuse it with momentum here. + if isinstance(param.regularizer, L2DecayRegularizer): + regularization_method = "l2_decay" + regularization_coeff = param.regularizer._regularization_coeff + else: + regularization_method = "" + regularization_coeff = 0.0 + if param.dtype == paddle.float32: + self._param_dict['FP32_LODTensor'].append(param) + self._velocity_dict['FP32_LODTensor'].append(velocity_acc) + # fp32 no master weight + self._regularization_method_dict['FP32_LODTensor'].append( + regularization_method) + self._regularization_coeff_dict['FP32_LODTensor'].append( + regularization_coeff) + elif param.dtype == paddle.float16: + self._param_dict['FP16_LODTensor'].append(param) + self._velocity_dict['FP16_LODTensor'].append(velocity_acc) + if self._multi_precision: + self._master_weight_dict['FP16_LODTensor'].append( + self._master_weights[param.name]) + else: + self._master_weight_dict['FP16_LODTensor'] = None + self._regularization_method_dict['FP16_LODTensor'].append( + regularization_method) + self._regularization_coeff_dict['FP16_LODTensor'].append( + regularization_coeff) + else: + raise ValueError( + "Now multi_tensor_momentum only support fp32 and fp16 parameters and grad is LOD_TENSOR." + ) + + def _append_optimize_multi_tensor_op(self, target_block, + parameters_and_grads): + """ + For Multi Tensor, append optimize merged_operator to block. 
+ """ + assert isinstance(target_block, framework.Block) + + grad_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []} + lr_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []} + + if isinstance(parameters_and_grads, list): + for param_and_grad in parameters_and_grads: + if param_and_grad[1] is None: + continue + if param_and_grad[0].stop_gradient is False: + if param_and_grad[ + 0].dtype == paddle.float32 and param_and_grad[ + 1].type == core.VarDesc.VarType.LOD_TENSOR: + grad_dict['FP32_LODTensor'].append(param_and_grad[1]) + lr = self._create_param_lr(param_and_grad) + lr_dict['FP32_LODTensor'].append(lr) + elif param_and_grad[ + 0].dtype == paddle.float16 and param_and_grad[ + 1].type == core.VarDesc.VarType.LOD_TENSOR: + grad_dict['FP16_LODTensor'].append(param_and_grad[1]) + lr = self._create_param_lr(param_and_grad) + lr_dict['FP16_LODTensor'].append(lr) + else: + for param_and_grad in parameters_and_grads['params']: + if param_and_grad[1] is None: + continue + if param_and_grad[0].stop_gradient is False: + param_grad_dict = dict() + param_grad_dict['params'] = param_and_grad + param_grad_dict.update({ + k: v + for k, v in parameters_and_grads.items() + if k != 'params' + }) + param_and_grad = self._update_param_group(param_grad_dict) + if param_and_grad[ + 0].dtype == paddle.float32 and param_and_grad[ + 1].type == core.VarDesc.VarType.LOD_TENSOR: + grad_dict['FP32_LODTensor'].append(param_and_grad[1]) + lr = self._create_param_lr(param_and_grad) + lr_dict['FP32_LODTensor'].append(lr) + elif param_and_grad[ + 0].dtype == paddle.float16 and param_and_grad[ + 1].type == core.VarDesc.VarType.LOD_TENSOR: + grad_dict['FP16_LODTensor'].append(param_and_grad[1]) + lr = self._create_param_lr(param_and_grad) + lr_dict['FP16_LODTensor'].append(lr) + + multi_tensor_list = ['FP32_LODTensor', 'FP16_LODTensor'] + for key in multi_tensor_list: + if len(self._param_dict[key]) > 0: + if key == 'FP32_LODTensor': + self._multi_precision = False + + if framework.in_dygraph_mode(): + _, _, _ = _C_ops.merged_momentum( + self._param_dict[key], grad_dict[key], + self._velocity_dict[key], lr_dict[key], + self._master_weight_dict[key], self._param_dict[key], + self._velocity_dict[key], self._master_weight_dict[key], + 'mu', self._momentum, 'use_nesterov', + self._use_nesterov, 'regularization_method', + self._regularization_method_dict[key], + 'regularization_coeff', + self._regularization_coeff_dict[key], 'multi_precision', + self._multi_precision) + else: + inputs = { + "Param": self._param_dict[key], + "Grad": grad_dict[key], + "Velocity": self._velocity_dict[key], + "LearningRate": lr_dict[key], + } + outputs = { + "ParamOut": self._param_dict[key], + "VelocityOut": self._velocity_dict[key], + } + attrs = { + "mu": self._momentum, + "use_nesterov": self._use_nesterov, + "regularization_method": + self._regularization_method_dict[key], + "regularization_coeff": + self._regularization_coeff_dict[key], + } + if self._multi_precision: + inputs["MasterParam"] = self._master_weight_dict[key] + outputs["MasterParamOut"] = self._master_weight_dict[ + key] + attrs["multi_precision"] = self._multi_precision + target_block.append_op( + type="merged_momentum", + inputs=inputs, + outputs=outputs, + attrs=attrs, + stop_gradient=True) + return None + def _update_param_group(self, parameters): self._momentum = parameters.get('momentum', self._default_dict['momentum']) diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index abfaf489822..a711d98df6f 100644 --- 
a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -217,6 +217,11 @@ class Optimizer(object): else: self._param_groups = self._parameter_list + # NOTE: Multi Tensor: Pass in all parameters and gradients to the op kernel of the Optimizer at one time for updating for dygraph mode. + # Optimizer support list: [ paddle.optimizer.Momentum ]. + self._use_multi_tensor = None + self._param_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []} + self._auxiliary_vars = {} def _set_auxiliary_var(self, key, val): @@ -676,57 +681,96 @@ class Optimizer(object): start = len(target_block.ops) self.helper = LayerHelper(self.__class__.__name__) - params_grads_device_map = parameters_and_grads['params'] if isinstance( - parameters_and_grads, dict) else parameters_and_grads - self._update_param_device_map(params_grads_device_map, target_block) - if isinstance(parameters_and_grads, list): - self._create_accumulators( - target_block, - [p[0] for p in parameters_and_grads if not p[0].stop_gradient]) - - else: - params_acc_dict = parameters_and_grads.copy() - params_acc_dict['params'] = [ - p[0] for p in params_acc_dict['params'] - if not p[0].stop_gradient - ] - self._create_accumulators(target_block, params_acc_dict) self._create_global_learning_rate() - if framework.in_dygraph_mode(): - - if isinstance(parameters_and_grads, list): + # NOTE: Multi Tensor support [ Momentum ] for dygraph mode + if self._use_multi_tensor and self.__class__.__name__ in ['Momentum']: + if len(self._param_dict['FP32_LODTensor']) == 0 and len( + self._param_dict['FP16_LODTensor']) == 0: + if isinstance(parameters_and_grads, list): + self._multi_tensor_init(target_block, [ + p[0] for p in parameters_and_grads + if not p[0].stop_gradient + ]) + else: + self._update_param_group(parameters_and_grads) + self._multi_tensor_init(target_block, [ + p[0] for p in parameters_and_grads['params'] + if not p[0].stop_gradient + ]) + if framework.in_dygraph_mode(): + self._append_optimize_multi_tensor_op(target_block, + parameters_and_grads) + else: + self._update_param_device_map(parameters_and_grads, + target_block) + # NOTE: Multi Tensor requires all parameters to be in the same device and program. + # param_grad_list = [p_0,g_0,p_1,g_1,....] 
+ param_grad_list = [] for param_and_grad in parameters_and_grads: - if param_and_grad[1] is None: - continue - if param_and_grad[0].stop_gradient is False: - self._append_optimize_op(target_block, param_and_grad) + if not param_and_grad[0].stop_gradient and param_and_grad[ + 1] is not None: + param_grad_list.append(param_and_grad[0]) + param_grad_list.append(param_and_grad[1]) + with param_grad_list[0].block.program._optimized_guard( + param_grad_list), name_scope("optimizer"): + device = self._get_device_for_param(param_grad_list[0].name) + with device_guard(device): + self._append_optimize_multi_tensor_op( + target_block, parameters_and_grads) + else: + if isinstance(parameters_and_grads, list): + self._create_accumulators(target_block, [ + p[0] for p in parameters_and_grads if not p[0].stop_gradient + ]) else: - for param_and_grad in parameters_and_grads['params']: + params_acc_dict = parameters_and_grads.copy() + params_acc_dict['params'] = [ + p[0] for p in params_acc_dict['params'] + if not p[0].stop_gradient + ] + self._create_accumulators(target_block, params_acc_dict) + + if framework.in_dygraph_mode(): + if isinstance(parameters_and_grads, list): + for param_and_grad in parameters_and_grads: + if param_and_grad[1] is None: + continue + if param_and_grad[0].stop_gradient is False: + self._append_optimize_op(target_block, + param_and_grad) + else: + for param_and_grad in parameters_and_grads['params']: + if param_and_grad[1] is None: + continue + if param_and_grad[0].stop_gradient is False: + param_grad_dict = dict() + param_grad_dict['params'] = param_and_grad + param_grad_dict.update({ + k: v + for k, v in parameters_and_grads.items() + if k != 'params' + }) + self._append_optimize_op(target_block, + param_grad_dict) + else: + params_grads_device_map = parameters_and_grads[ + 'params'] if isinstance(parameters_and_grads, + dict) else parameters_and_grads + self._update_param_device_map(params_grads_device_map, + target_block) + for param_and_grad in parameters_and_grads: if param_and_grad[1] is None: continue - if param_and_grad[0].stop_gradient is False: - param_grad_dict = dict() - param_grad_dict['params'] = param_and_grad - param_grad_dict.update({ - k: v - for k, v in parameters_and_grads.items() - if k != 'params' - }) - self._append_optimize_op(target_block, param_grad_dict) - else: - for param_and_grad in parameters_and_grads: - if param_and_grad[1] is None: - continue - with param_and_grad[0].block.program._optimized_guard( - param_and_grad), name_scope("optimizer"): - if param_and_grad[0].stop_gradient is False: - device = self._get_device_for_param(param_and_grad[0] - .name) - with device_guard(device): - optimize_op = self._append_optimize_op( - target_block, param_and_grad) + with param_and_grad[0].block.program._optimized_guard( + param_and_grad), name_scope("optimizer"): + if param_and_grad[0].stop_gradient is False: + device = self._get_device_for_param(param_and_grad[ + 0].name) + with device_guard(device): + optimize_op = self._append_optimize_op( + target_block, param_and_grad) # Get custom finish ops for subclasses # FIXME: Need to fix this once we figure out how to handle dependencies @@ -1002,11 +1046,16 @@ class Optimizer(object): return no_grad_set @framework.dygraph_only - def clear_grad(self): + def clear_grad(self, set_to_zero=True): """ Clear the gradients of all optimized parameters for model. If not, new gradient will accumulat on previous gradient. + + There are two method to clear grad: set_to_zero or delete grad. 
+ + Args: + set_to_zero (bool, optional): If set grads to zero or not, default is True. Returns: None @@ -1029,16 +1078,18 @@ class Optimizer(object): adam.clear_grad() """ + param_list = [] if self._parameter_list is None or not isinstance( self._parameter_list[0], dict): for p in self._parameter_list: if not p.stop_gradient: - p.clear_gradient() + param_list.append(p) else: for param_group in self._param_groups: for p in param_group['params']: if not p.stop_gradient: - p.clear_gradient() + param_list.append(p) + core.clear_gradients(param_list, set_to_zero) @imperative_base.no_grad def minimize(self, @@ -1210,3 +1261,23 @@ class Optimizer(object): different optimization options. Only used in child class. """ pass + + @framework.dygraph_only + def _multi_tensor_init(self, target_block, parameters): + """ + All parameters used for optimizer (such as: parameters, master_weight, velocity_acc for momentum) calculations are grouped into a python list by data type (float16, float32). + This function will be overridden in the corresponding optimizer file. + + Args: + target_block: the block in which the loss tensor is present + parameters: list of parameter tensors for the optimizer + """ + pass + + @framework.dygraph_only + def _append_optimize_multi_tensor_op(self, target_block, + parameters_and_grads): + """ + For Multi Tensor, append optimize merged_operator to block. + """ + pass -- GitLab