Unverified · Commit 0cc5e22c authored by zhangbo9674 and committed by GitHub

Add multi_tensor for momentum optimizer and clear_grads (#37564)

* add multi_tensor for momentum and clear_grads for optimizer

* fix bug for dygraph

* add unittest

* refine comment

* add param_group

* refine regularization logic

* del clear_grads

* add clear_grads

* add dispensable check of None

* refine clear_grad

* fix build bug

* refine code by comment

* refine code

* add multi tensor check

* refine param_group update

* add multi tensor for static mode

* refine comments

* delete useless comma for momentum

* refine comment for momentum

* refine code by comment
Parent 2f188341
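The two user-facing pieces of this change are the `use_multi_tensor` flag on `paddle.optimizer.Momentum` and the `set_to_zero` flag on `Optimizer.clear_grad`. A minimal dygraph sketch of how they are meant to be used together (everything other than those two flags is ordinary, pre-existing Paddle API):

    import paddle

    model = paddle.nn.Linear(5, 5)
    opt = paddle.optimizer.Momentum(
        learning_rate=0.1,
        parameters=model.parameters(),
        use_multi_tensor=True)  # fuse all per-parameter updates into one merged op per dtype

    loss = paddle.mean(model(paddle.randn((5, 5))))
    loss.backward()
    opt.step()
    opt.clear_grad(set_to_zero=False)  # release gradient storage instead of zeroing it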
@@ -688,7 +688,7 @@ std::vector<std::shared_ptr<imperative::VarBase>> GetVarBaseListFromArgs(
    ssize_t arg_idx, bool dispensable) {
  PyObject* list = PyTuple_GET_ITEM(args, arg_idx);

-  if (list == nullptr) {
+  if (list == nullptr || list == Py_None) {
    if (!dispensable) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument '%s' (position %d) must be list of Tensor, but got "
......
@@ -58,6 +58,8 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
    {"multiclass_nms3", {"BBoxes", "Scores", "RoisNum"}},
    {"box_coder", {"PriorBox", "PriorBoxVar", "TargetBox"}},
    {"momentum", {"Param", "Grad", "Velocity", "LearningRate", "MasterParam"}},
+    {"merged_momentum",
+     {"Param", "Grad", "Velocity", "LearningRate", "MasterParam"}},
    {"sparse_momentum", {"Param", "Grad", "Velocity", "Index", "LearningRate"}},
    {"rnn", {"Input", "PreState", "WeightList", "SequenceLength"}},
    {"run_program", {"X", "Params"}},
@@ -113,6 +115,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
    {"multiclass_nms3", {"Out", "NmsRoisNum"}},
    {"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
    {"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
+    {"merged_momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
    {"sparse_momentum", {"ParamOut", "VelocityOut"}},
    {"rnn", {"DropoutState", "Reserve", "Out", "State"}},
    {"run_program", {"DOut"}},
@@ -153,6 +156,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
    {"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates",
     "out_old_num_accumulates", "out_num_updates"}},
    {"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
+    {"merged_momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
    {"sparse_momentum", {"ParamOut", "VelocityOut"}},
    {"batch_norm", {"MeanOut", "VarianceOut"}},
    {"sync_batch_norm", {"MeanOut", "VarianceOut"}},
......
@@ -579,6 +579,14 @@ PYBIND11_MODULE(core_noavx, m) {
  m.def("disable_signal_handler", &DisableSignalHandler);

+  m.def("clear_gradients",
+        [](std::vector<std::shared_ptr<imperative::VarBase>> param_list,
+           bool set_to_zero) {
+          for (auto param : param_list) {
+            param->ClearGradient(set_to_zero);
+          }
+        });
+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  m.def("cudnn_version", &platform::DnnVersion);
  m.def("gpu_memory_available", []() {
......
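The binding above lets Python hand the whole parameter list to C++ in one call instead of clearing gradients one by one. A rough sketch of the call pattern (this goes through the internal `paddle.fluid.core` module; the public entry point is `Optimizer.clear_grad`, changed further down):

    import paddle
    from paddle.fluid import core

    model = paddle.nn.Linear(4, 4)
    paddle.mean(model(paddle.randn((2, 4)))).backward()

    params = [p for p in model.parameters() if not p.stop_gradient]
    core.clear_gradients(params, True)  # set_to_zero=True: keep the gradient buffers, fill them with zeros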
@@ -21,6 +21,7 @@ from paddle.fluid.op import Operator
from op_test import OpTest
import paddle
import paddle.fluid as fluid
+import numpy


def calculate_momentum_by_numpy(param,
@@ -805,5 +806,189 @@ class TestMomentumV2Group(TestMomentumV2):
        adam.clear_gradients()
class TestMultiTensorMomentumDygraph(unittest.TestCase):
def _momentum_optimize_dygraph(self,
place,
use_param_attr=False,
use_param_group=False,
use_amp=False,
use_multi_tensor=False):
paddle.disable_static()
paddle.seed(10)
paddle.set_device(place)
input = paddle.randn((5, 5))
weight_attr = paddle.ParamAttr(
learning_rate=0.5,
regularizer=paddle.regularizer.L2Decay(1.0),
trainable=True)
if use_param_attr:
model = paddle.nn.Linear(5, 5, weight_attr)
else:
model = paddle.nn.Linear(5, 5)
if not use_param_group:
optimizer = paddle.optimizer.Momentum(
parameters=model.parameters(),
use_multi_tensor=use_multi_tensor,
multi_precision=use_amp)
else:
optimizer = paddle.optimizer.Momentum(
parameters=[{
'params': model.parameters(),
'weight_decay': 0.001,
'learning_rate': 0.1,
'momentum': 0.99
}],
use_multi_tensor=use_multi_tensor,
multi_precision=use_amp)
for idx in range(5):
if place == 'gpu' and use_amp == True:
model = paddle.amp.decorate(models=model, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
if place == 'gpu' and use_amp == True:
with paddle.amp.auto_cast(level='O2'):
output = model(input)
loss = paddle.mean(output)
scaled = scaler.scale(loss)
scaled.backward()
scaler.step(optimizer)
optimizer.clear_grad(set_to_zero=False)
else:
output = model(input)
loss = paddle.mean(output)
# This can be any optimizer supported by dygraph.
loss.backward()
optimizer.step()
optimizer.clear_grad(set_to_zero=False)
return output, model.parameters()
def _get_places(self):
places = ['cpu']
if paddle.is_compiled_with_cuda():
places.append('gpu')
return places
def _check_with_place_amp(self, place, use_amp):
output1, params1 = self._momentum_optimize_dygraph(
place=place, use_amp=use_amp, use_multi_tensor=True)
output2, params2 = self._momentum_optimize_dygraph(
place=place, use_amp=use_amp, use_multi_tensor=False)
self.assertEqual(np.allclose(output1, output2, rtol=1e-05), True)
for idx in range(len(params1)):
self.assertEqual(
np.allclose(
params1[idx], params2[idx], rtol=1e-05), True)
def _check_with_param_arrt(self, place, use_amp):
output1, params1 = self._momentum_optimize_dygraph(
place=place,
use_amp=use_amp,
use_param_attr=True,
use_multi_tensor=True)
output2, params2 = self._momentum_optimize_dygraph(
place=place,
use_amp=use_amp,
use_param_attr=True,
use_multi_tensor=False)
self.assertEqual(np.allclose(output1, output2, rtol=1e-05), True)
for idx in range(len(params1)):
self.assertEqual(
np.allclose(
params1[idx], params2[idx], rtol=1e-05), True)
def _check_with_param_group(self, place, use_amp):
output1, params1 = self._momentum_optimize_dygraph(
place=place,
use_amp=use_amp,
use_param_group=True,
use_multi_tensor=True)
output2, params2 = self._momentum_optimize_dygraph(
place=place,
use_amp=use_amp,
use_param_group=True,
use_multi_tensor=False)
self.assertEqual(np.allclose(output1, output2, rtol=1e-05), True)
for idx in range(len(params1)):
self.assertEqual(
np.allclose(
params1[idx], params2[idx], rtol=1e-05), True)
def test_main(self):
for place in self._get_places():
use_amp_list = [True, False]
for use_amp in use_amp_list:
self._check_with_place_amp(place, use_amp)
self._check_with_param_arrt(place, use_amp)
self._check_with_param_group(place, use_amp)
class TestMultiTensorMomentumStatic(unittest.TestCase):
def _momentum_optimize_static(self,
place,
use_amp=False,
use_multi_tensor=False):
paddle.enable_static()
paddle.seed(10)
np.random.seed(10)
if place == 'cpu':
use_amp = False
exe = paddle.static.Executor(place=place)
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
optimizer = paddle.optimizer.Momentum(
multi_precision=use_amp, use_multi_tensor=use_multi_tensor)
if use_amp:
optimizer = paddle.static.amp.decorate(
optimizer,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
use_pure_fp16=True,
use_fp16_guard=False)
with paddle.static.program_guard(train_program, startup_program):
if use_amp:
data = paddle.static.data(
shape=[2, 2], name='X', dtype='float16')
else:
data = paddle.static.data(
shape=[2, 2], name='X', dtype='float32')
hidden = paddle.static.nn.fc(x=data, size=10)
loss = paddle.fluid.layers.mean(hidden)
optimizer.minimize(loss)
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place=place, scope=paddle.static.global_scope())
x = numpy.random.random(size=(2, 2)).astype('float16')
else:
x = numpy.random.random(size=(2, 2)).astype('float32')
out = []
for idx in range(5):
loss_data, = exe.run(train_program,
feed={"X": x},
fetch_list=[loss.name])
out.append(loss_data)
return out
def _get_places(self):
places = ['cpu']
if paddle.is_compiled_with_cuda():
places.append('gpu')
return places
def _check_with_place_amp(self, place, use_amp):
output1 = self._momentum_optimize_static(
place=place, use_amp=use_amp, use_multi_tensor=True)
output2 = self._momentum_optimize_static(
place=place, use_amp=use_amp, use_multi_tensor=False)
for idx in range(len(output1)):
self.assertEqual(
np.allclose(
output1[idx], output2[idx], rtol=1e-05), True)
def test_main(self):
for place in self._get_places():
use_amp_list = [True, False]
for use_amp in use_amp_list:
self._check_with_place_amp(place, use_amp)
if __name__ == "__main__":
    unittest.main()
@@ -24,6 +24,7 @@ from ..fluid import layers
import paddle.fluid as fluid
from paddle.fluid.regularizer import L2DecayRegularizer
from paddle import _C_ops
+import paddle

__all__ = []
@@ -74,6 +75,7 @@ class Momentum(Optimizer):
        multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false.
        rescale_grad (float, optional): Multiply the gradient with `rescale_grad` before updating. \
            Often choose to be ``1.0/batch_size``.
+        use_multi_tensor (bool, optional): Whether to use the multi-tensor strategy to update all parameters at once. Default is false.
        name (str, optional): The default value is None. Normally there is no need for user
            to set this property. For more information, please refer to
            :ref:`api_guide_Name` .
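Combined with `multi_precision`, the flag produces one fused update per dtype group (FP32, and FP16 with FP32 master weights). A sketch of the AMP combination exercised by the new dygraph unit test above, assuming a CUDA device for float16:

    import paddle

    paddle.set_device('gpu')
    model = paddle.nn.Linear(5, 5)
    opt = paddle.optimizer.Momentum(
        parameters=model.parameters(),
        multi_precision=True,   # keep FP32 master weights for FP16 parameters
        use_multi_tensor=True)  # one merged update per dtype group
    model = paddle.amp.decorate(models=model, level='O2')
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    with paddle.amp.auto_cast(level='O2'):
        loss = paddle.mean(model(paddle.randn((5, 5))))
    scaler.scale(loss).backward()
    scaler.step(opt)
    opt.clear_grad()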
@@ -129,6 +131,7 @@ class Momentum(Optimizer):
                 grad_clip=None,
                 multi_precision=False,
                 rescale_grad=1.0,
+                 use_multi_tensor=False,
                 name=None):
        if learning_rate is None:
            raise ValueError("learning_rate is not set")
@@ -170,17 +173,22 @@
            'regularization_method': self._regularization_method,
            'regularization_coeff': self._regularization_coeff,
        }
-        '''
-        if framework.in_dygraph_mode():
-            self.helper = LayerHelper(self.__class__.__name__)
-            if isinstance(self._parameter_list[0], dict):
-                for parameters in self._param_groups:
-                    for p in parameters['params']:
-                        self._add_accumulator(self._velocity_acc_str, p)
-            else:
-                for p in parameters:
-                    self._add_accumulator(self._velocity_acc_str, p)
-        '''
+        self._use_multi_tensor = use_multi_tensor
+        if self._use_multi_tensor:
+            self._param_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
+            self._velocity_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
+            self._master_weight_dict = {
+                'FP32_LODTensor': None,
+                'FP16_LODTensor': []
+            }
+            self._regularization_method_dict = {
+                'FP32_LODTensor': [],
+                'FP16_LODTensor': []
+            }
+            self._regularization_coeff_dict = {
+                'FP32_LODTensor': [],
+                'FP16_LODTensor': []
+            }
    def _update_regularization(self, weight_decay):
        reg_method = ""
@@ -353,6 +361,156 @@ class Momentum(Optimizer):
        return momentum_op
def _multi_tensor_init(self, target_block, parameters):
"""
        All parameters used in the optimizer update (such as: parameters, master_weight, velocity_acc for momentum) are grouped into python lists by data type (float16, float32).
This function will be overridden in the corresponding optimizer file.
Args:
target_block: the block in which the loss tensor is present
parameters: list of parameter tensors for the optimizer
"""
self._create_accumulators(target_block, parameters)
for param in parameters:
velocity_acc = self._get_accumulator(self._velocity_acc_str, param)
regularization_method = self._regularization_method
regularization_coeff = self._regularization_coeff
if hasattr(param, 'regularizer'):
                # The parameter's L2Decay regularization was skipped earlier, so fuse it into the momentum update here.
if isinstance(param.regularizer, L2DecayRegularizer):
regularization_method = "l2_decay"
regularization_coeff = param.regularizer._regularization_coeff
else:
regularization_method = ""
regularization_coeff = 0.0
if param.dtype == paddle.float32:
self._param_dict['FP32_LODTensor'].append(param)
self._velocity_dict['FP32_LODTensor'].append(velocity_acc)
# fp32 no master weight
self._regularization_method_dict['FP32_LODTensor'].append(
regularization_method)
self._regularization_coeff_dict['FP32_LODTensor'].append(
regularization_coeff)
elif param.dtype == paddle.float16:
self._param_dict['FP16_LODTensor'].append(param)
self._velocity_dict['FP16_LODTensor'].append(velocity_acc)
if self._multi_precision:
self._master_weight_dict['FP16_LODTensor'].append(
self._master_weights[param.name])
else:
self._master_weight_dict['FP16_LODTensor'] = None
self._regularization_method_dict['FP16_LODTensor'].append(
regularization_method)
self._regularization_coeff_dict['FP16_LODTensor'].append(
regularization_coeff)
else:
raise ValueError(
"Now multi_tensor_momentum only support fp32 and fp16 parameters and grad is LOD_TENSOR."
)
def _append_optimize_multi_tensor_op(self, target_block,
parameters_and_grads):
"""
        For the multi-tensor path, append the fused optimize op (merged_momentum) to the block.
"""
assert isinstance(target_block, framework.Block)
grad_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
lr_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
if isinstance(parameters_and_grads, list):
for param_and_grad in parameters_and_grads:
if param_and_grad[1] is None:
continue
if param_and_grad[0].stop_gradient is False:
if param_and_grad[
0].dtype == paddle.float32 and param_and_grad[
1].type == core.VarDesc.VarType.LOD_TENSOR:
grad_dict['FP32_LODTensor'].append(param_and_grad[1])
lr = self._create_param_lr(param_and_grad)
lr_dict['FP32_LODTensor'].append(lr)
elif param_and_grad[
0].dtype == paddle.float16 and param_and_grad[
1].type == core.VarDesc.VarType.LOD_TENSOR:
grad_dict['FP16_LODTensor'].append(param_and_grad[1])
lr = self._create_param_lr(param_and_grad)
lr_dict['FP16_LODTensor'].append(lr)
else:
for param_and_grad in parameters_and_grads['params']:
if param_and_grad[1] is None:
continue
if param_and_grad[0].stop_gradient is False:
param_grad_dict = dict()
param_grad_dict['params'] = param_and_grad
param_grad_dict.update({
k: v
for k, v in parameters_and_grads.items()
if k != 'params'
})
param_and_grad = self._update_param_group(param_grad_dict)
if param_and_grad[
0].dtype == paddle.float32 and param_and_grad[
1].type == core.VarDesc.VarType.LOD_TENSOR:
grad_dict['FP32_LODTensor'].append(param_and_grad[1])
lr = self._create_param_lr(param_and_grad)
lr_dict['FP32_LODTensor'].append(lr)
elif param_and_grad[
0].dtype == paddle.float16 and param_and_grad[
1].type == core.VarDesc.VarType.LOD_TENSOR:
grad_dict['FP16_LODTensor'].append(param_and_grad[1])
lr = self._create_param_lr(param_and_grad)
lr_dict['FP16_LODTensor'].append(lr)
multi_tensor_list = ['FP32_LODTensor', 'FP16_LODTensor']
for key in multi_tensor_list:
if len(self._param_dict[key]) > 0:
if key == 'FP32_LODTensor':
self._multi_precision = False
if framework.in_dygraph_mode():
_, _, _ = _C_ops.merged_momentum(
self._param_dict[key], grad_dict[key],
self._velocity_dict[key], lr_dict[key],
self._master_weight_dict[key], self._param_dict[key],
self._velocity_dict[key], self._master_weight_dict[key],
'mu', self._momentum, 'use_nesterov',
self._use_nesterov, 'regularization_method',
self._regularization_method_dict[key],
'regularization_coeff',
self._regularization_coeff_dict[key], 'multi_precision',
self._multi_precision)
else:
inputs = {
"Param": self._param_dict[key],
"Grad": grad_dict[key],
"Velocity": self._velocity_dict[key],
"LearningRate": lr_dict[key],
}
outputs = {
"ParamOut": self._param_dict[key],
"VelocityOut": self._velocity_dict[key],
}
attrs = {
"mu": self._momentum,
"use_nesterov": self._use_nesterov,
"regularization_method":
self._regularization_method_dict[key],
"regularization_coeff":
self._regularization_coeff_dict[key],
}
if self._multi_precision:
inputs["MasterParam"] = self._master_weight_dict[key]
outputs["MasterParamOut"] = self._master_weight_dict[
key]
attrs["multi_precision"] = self._multi_precision
target_block.append_op(
type="merged_momentum",
inputs=inputs,
outputs=outputs,
attrs=attrs,
stop_gradient=True)
return None
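`merged_momentum` applies the same update rule as the single-tensor `momentum` op to every tensor in the group, just in a single fused op per dtype group. For reference, a NumPy sketch of that per-parameter rule with fused L2 decay (Nesterov and master weights omitted; `calculate_momentum_by_numpy` in the unit test above covers the full variants):

    import numpy as np

    def momentum_step(param, grad, velocity, lr, mu, reg_coeff=0.0):
        # 'l2_decay' regularization is folded into the gradient before the update
        grad = grad + reg_coeff * param
        velocity_out = mu * velocity + grad
        param_out = param - lr * velocity_out
        return param_out, velocity_out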
    def _update_param_group(self, parameters):
        self._momentum = parameters.get('momentum',
                                        self._default_dict['momentum'])
......
@@ -217,6 +217,11 @@ class Optimizer(object):
        else:
            self._param_groups = self._parameter_list

+        # NOTE: Multi Tensor: in dygraph mode, pass all parameters and gradients to the optimizer's op kernel in one call when updating.
+        # Optimizer support list: [ paddle.optimizer.Momentum ].
+        self._use_multi_tensor = None
+        self._param_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
+
        self._auxiliary_vars = {}

    def _set_auxiliary_var(self, key, val):
@@ -676,14 +681,49 @@ class Optimizer(object):
        start = len(target_block.ops)
        self.helper = LayerHelper(self.__class__.__name__)
-        params_grads_device_map = parameters_and_grads['params'] if isinstance(
-            parameters_and_grads, dict) else parameters_and_grads
-        self._update_param_device_map(params_grads_device_map, target_block)
-        if isinstance(parameters_and_grads, list):
-            self._create_accumulators(
-                target_block,
-                [p[0] for p in parameters_and_grads if not p[0].stop_gradient])
+        self._create_global_learning_rate()
# NOTE: Multi Tensor support [ Momentum ] for dygraph mode
if self._use_multi_tensor and self.__class__.__name__ in ['Momentum']:
if len(self._param_dict['FP32_LODTensor']) == 0 and len(
self._param_dict['FP16_LODTensor']) == 0:
if isinstance(parameters_and_grads, list):
self._multi_tensor_init(target_block, [
p[0] for p in parameters_and_grads
if not p[0].stop_gradient
])
else:
self._update_param_group(parameters_and_grads)
self._multi_tensor_init(target_block, [
p[0] for p in parameters_and_grads['params']
if not p[0].stop_gradient
])
if framework.in_dygraph_mode():
self._append_optimize_multi_tensor_op(target_block,
parameters_and_grads)
else:
self._update_param_device_map(parameters_and_grads,
target_block)
# NOTE: Multi Tensor requires all parameters to be in the same device and program.
# param_grad_list = [p_0,g_0,p_1,g_1,....]
param_grad_list = []
for param_and_grad in parameters_and_grads:
if not param_and_grad[0].stop_gradient and param_and_grad[
1] is not None:
param_grad_list.append(param_and_grad[0])
param_grad_list.append(param_and_grad[1])
with param_grad_list[0].block.program._optimized_guard(
param_grad_list), name_scope("optimizer"):
device = self._get_device_for_param(param_grad_list[0].name)
with device_guard(device):
self._append_optimize_multi_tensor_op(
target_block, parameters_and_grads)
else:
if isinstance(parameters_and_grads, list):
self._create_accumulators(target_block, [
p[0] for p in parameters_and_grads if not p[0].stop_gradient
])
            else:
                params_acc_dict = parameters_and_grads.copy()
                params_acc_dict['params'] = [
@@ -692,16 +732,14 @@ class Optimizer(object):
                ]
                self._create_accumulators(target_block, params_acc_dict)
-        self._create_global_learning_rate()
            if framework.in_dygraph_mode():
                if isinstance(parameters_and_grads, list):
                    for param_and_grad in parameters_and_grads:
                        if param_and_grad[1] is None:
                            continue
                        if param_and_grad[0].stop_gradient is False:
-                            self._append_optimize_op(target_block, param_and_grad)
+                            self._append_optimize_op(target_block,
+                                                     param_and_grad)
                else:
                    for param_and_grad in parameters_and_grads['params']:
                        if param_and_grad[1] is None:
@@ -714,16 +752,22 @@ class Optimizer(object):
                                for k, v in parameters_and_grads.items()
                                if k != 'params'
                            })
-                            self._append_optimize_op(target_block, param_grad_dict)
+                            self._append_optimize_op(target_block,
+                                                     param_grad_dict)
            else:
params_grads_device_map = parameters_and_grads[
'params'] if isinstance(parameters_and_grads,
dict) else parameters_and_grads
self._update_param_device_map(params_grads_device_map,
target_block)
                for param_and_grad in parameters_and_grads:
                    if param_and_grad[1] is None:
                        continue
                    with param_and_grad[0].block.program._optimized_guard(
                            param_and_grad), name_scope("optimizer"):
                        if param_and_grad[0].stop_gradient is False:
-                            device = self._get_device_for_param(param_and_grad[0]
-                                                                .name)
+                            device = self._get_device_for_param(param_and_grad[
+                                0].name)
                            with device_guard(device):
                                optimize_op = self._append_optimize_op(
                                    target_block, param_and_grad)
@@ -1002,12 +1046,17 @@ class Optimizer(object):
        return no_grad_set
    @framework.dygraph_only
-    def clear_grad(self):
+    def clear_grad(self, set_to_zero=True):
        """
        Clear the gradients of all optimized parameters of the model.
        If the gradients are not cleared, the new gradient will accumulate on top of the previous one.
+
+        There are two ways to clear gradients: set them to zero, or delete them entirely.
+
+        Args:
+            set_to_zero (bool, optional): Whether to set the gradients to zero rather than deleting them. Default is True.

        Returns:
            None
@@ -1029,16 +1078,18 @@ class Optimizer(object):
                adam.clear_grad()
        """
+        param_list = []
        if self._parameter_list is None or not isinstance(
                self._parameter_list[0], dict):
            for p in self._parameter_list:
                if not p.stop_gradient:
-                    p.clear_gradient()
+                    param_list.append(p)
        else:
            for param_group in self._param_groups:
                for p in param_group['params']:
                    if not p.stop_gradient:
-                        p.clear_gradient()
+                        param_list.append(p)
+        core.clear_gradients(param_list, set_to_zero)
    @imperative_base.no_grad
    def minimize(self,
@@ -1210,3 +1261,23 @@ class Optimizer(object):
        different optimization options. Only used in child class.
        """
        pass
@framework.dygraph_only
def _multi_tensor_init(self, target_block, parameters):
"""
        All parameters used in the optimizer update (such as: parameters, master_weight, velocity_acc for momentum) are grouped into python lists by data type (float16, float32).
This function will be overridden in the corresponding optimizer file.
Args:
target_block: the block in which the loss tensor is present
parameters: list of parameter tensors for the optimizer
"""
pass
@framework.dygraph_only
def _append_optimize_multi_tensor_op(self, target_block,
parameters_and_grads):
"""
        For the multi-tensor path, append the fused optimize op to the block.
"""
pass