Unverified · Commit 244e7546 authored by W wanghuancoder, committed by GitHub

refine optimizer create accumulators (#50188)

* refine optimizer create accumulators

* refine
Parent eb8353a4
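All of the hunks below apply the same pattern: before creating accumulators for a parameter, the optimizer checks a per-instance set of parameter names, `self._already_create_accumulater` (initialized in `Optimizer.__init__` and `AdamW.__init__`), and records the name once the accumulators exist, so repeated calls to `_create_accumulators` skip work that was already done. A minimal sketch of the idea, using a hypothetical `SketchOptimizer` class and accumulator name rather than the real Paddle classes:

```python
# Minimal sketch of the dedup pattern used in this commit
# (hypothetical class, not the real paddle.optimizer API).
class SketchOptimizer:
    def __init__(self):
        # Names of parameters whose accumulators already exist.
        self._already_create_accumulater = set()
        self._accumulators = {}  # (acc_name, param_name) -> tensor-like value

    def _add_accumulator(self, acc_name, param_name):
        self._accumulators[(acc_name, param_name)] = 0.0

    def _create_accumulators(self, parameters):
        for name in parameters:
            if name in self._already_create_accumulater:
                continue  # skip work already done on an earlier call
            self._add_accumulator("moment", name)
            self._already_create_accumulater.add(name)


opt = SketchOptimizer()
opt._create_accumulators(["w", "b"])
opt._create_accumulators(["w", "b"])  # second call is a no-op
print(len(opt._accumulators))  # 2, not 4
```

Tracking plain names instead of parameter objects keeps the membership check cheap and avoids holding extra references to the parameters themselves.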
@@ -145,8 +145,11 @@ class Adadelta(Optimizer):
             parameters = parameters.get('params')
 
         for p in parameters:
+            if p.name in self._already_create_accumulater:
+                continue
             self._add_accumulator(self._avg_squared_grad_acc_str, p)
             self._add_accumulator(self._avg_squared_update_acc_str, p)
+            self._already_create_accumulater.add(p.name)
 
     def _append_optimize_op(self, block, param_and_grad):
         if isinstance(param_and_grad, dict):
......
@@ -139,11 +139,14 @@ class Adagrad(Optimizer):
             parameters = self._update_param_group(parameters)
 
         for p in parameters:
+            if p.name in self._already_create_accumulater:
+                continue
             self._add_accumulator(
                 self._moment_acc_str,
                 p,
                 fill_value=self.initial_accumulator_value,
             )
+            self._already_create_accumulater.add(p.name)
 
     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)
......
@@ -317,9 +317,12 @@ class Adam(Optimizer):
         # Create accumulator tensors for first and second moments
         for p in parameters:
+            if p.name in self._already_create_accumulater:
+                continue
             if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
                 master_p = self._create_master_weight(p)
                 self._add_moments_pows(master_p)
+                self._already_create_accumulater.add(p.name)
                 continue
             if (
                 self._is_dtype_fp16_or_bf16(p.dtype)
@@ -330,6 +333,7 @@ class Adam(Optimizer):
                     "Consider using multi_precision=True option of the Adam optimizer."
                 )
             self._add_moments_pows(p)
+            self._already_create_accumulater.add(p.name)
 
     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)
......
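In the Adam hunk above (and in the AdamW, Momentum, and SGD hunks below), the multi-precision branch records the name right after the master-weight accumulators are created, before its `continue`, and the regular branch records it after the ordinary accumulators, so both paths are covered; note that it is the original `p.name` that is recorded, not the master weight's name. A small standalone sketch of that control flow, with made-up helpers standing in for `_create_master_weight` and the accumulator creation:

```python
# Sketch of the dual-path bookkeeping (hypothetical helpers, not Paddle's API).
FP16_PARAMS = {"w_fp16"}  # pretend these parameters are float16

def create_accumulators(params, created, accumulators, multi_precision=True):
    for name in params:
        if name in created:
            continue
        if multi_precision and name in FP16_PARAMS:
            master = f"{name}.master"          # stand-in for _create_master_weight
            accumulators[master] = [0.0, 0.0]  # stand-in for _add_moments_pows
            created.add(name)                  # record the *original* name
            continue
        accumulators[name] = [0.0, 0.0]
        created.add(name)

created, accs = set(), {}
create_accumulators(["w_fp16", "b"], created, accs)
create_accumulators(["w_fp16", "b"], created, accs)  # no duplicates
print(sorted(accs))  # ['b', 'w_fp16.master']
```

Both calls leave one set of accumulators per parameter, and the FP16 parameter is tracked under its original name even though its accumulators live on the master weight.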
@@ -176,6 +176,8 @@ class Adamax(Optimizer):
         # Create accumulator tensors for first moment and infinity norm
         for p in parameters:
+            if p.name in self._already_create_accumulater:
+                continue
             self._add_accumulator(self._moment_acc_str, p)
             self._add_accumulator(self._inf_norm_acc_str, p)
             self._add_accumulator(
@@ -184,6 +186,7 @@ class Adamax(Optimizer):
                 fill_value=self._beta1,
                 shape=[1],
             )
+            self._already_create_accumulater.add(p.name)
 
     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)
......
@@ -281,6 +281,7 @@ class AdamW(Optimizer):
         self._use_multi_tensor = None
         self.regularization = None
         self._auxiliary_vars = {}
+        self._already_create_accumulater = set()
 
     def _set_auxiliary_var(self, key, val):
         self._auxiliary_vars[key] = val
@@ -422,9 +423,12 @@ class AdamW(Optimizer):
         # Create accumulator tensors for first and second moments
         for p in parameters:
+            if p.name in self._already_create_accumulater:
+                continue
             if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
                 master_p = self._create_master_weight(p)
                 self._add_moments_pows(master_p)
+                self._already_create_accumulater.add(p.name)
                 continue
             if (
                 self._is_dtype_fp16_or_bf16(p.dtype)
@@ -435,6 +439,7 @@ class AdamW(Optimizer):
                     "Consider using multi_precision=True option of the Adam optimizer."
                 )
             self._add_moments_pows(p)
+            self._already_create_accumulater.add(p.name)
 
     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)
......
@@ -190,11 +190,15 @@ class Lamb(Optimizer):
         # Create accumulator tensors for first and second moments
         for p in parameters:
+            if p.name in self._already_create_accumulater:
+                continue
             if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
                 master_p = self._create_master_weight(p)
                 self._add_moments_pows(master_p)
+                self._already_create_accumulater.add(p.name)
             else:
                 self._add_moments_pows(p)
+                self._already_create_accumulater.add(p.name)
 
     def _get_accumulator(self, name, param):
         """Utility function to fetch an accumulator for a parameter
......
@@ -270,9 +270,12 @@ class Momentum(Optimizer):
             parameters = self._update_param_group(parameters)
 
         for p in parameters:
+            if p.name in self._already_create_accumulater:
+                continue
             if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
                 master_p = self._create_master_weight(p)
                 self._add_accumulator(self._velocity_acc_str, master_p)
+                self._already_create_accumulater.add(p.name)
                 continue
             if (
                 p.dtype == core.VarDesc.VarType.FP16
@@ -283,6 +286,7 @@ class Momentum(Optimizer):
                     "Consider using multi_precision=True option of the Momentum optimizer."
                 )
             self._add_accumulator(self._velocity_acc_str, p)
+            self._already_create_accumulater.add(p.name)
 
     def _create_regularization_of_grad(self, param, grad, regularization=None):
         """Create and add backward regularization Operators
......
@@ -275,6 +275,7 @@ class Optimizer:
         self._param_dict = self._create_multi_tensor_dict()
         self._auxiliary_vars = {}
+        self._already_create_accumulater = set()
 
     def _set_auxiliary_var(self, key, val):
         self._auxiliary_vars[key] = val
......
@@ -199,9 +199,12 @@ class RMSProp(Optimizer):
             parameters = parameters.get('params')
 
         for p in parameters:
+            if p.name in self._already_create_accumulater:
+                continue
             self._add_accumulator(self._momentum_acc_str, p)
             self._add_accumulator(self._mean_square_acc_str, p)
             self._add_accumulator(self._mean_grad_acc_str, p)
+            self._already_create_accumulater.add(p.name)
 
     def _append_optimize_op(self, block, param_and_grad):
         if not isinstance(block, framework.Block):
......
@@ -129,8 +129,11 @@ class SGD(Optimizer):
         # Create accumulator tensors for first and second moments
         for p in parameters:
+            if p.name in self._already_create_accumulater:
+                continue
             if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
                 master_p = self._create_master_weight(p)
+                self._already_create_accumulater.add(p.name)
                 continue
             if (
                 p.dtype == core.VarDesc.VarType.FP16
......