未验证 提交 b8a593e7 编写于 作者: Z Zhen Wang 提交者: GitHub

Use correct master weights in AdamW. (#30895) (#31142)

* Use correct master weights in AdamW.

* Just modify the master weight.

* Update for CI Coverage.
上级 37b71828
...@@ -97,7 +97,7 @@ def train(use_pure_fp16=True, use_nesterov=False, use_adam=False): ...@@ -97,7 +97,7 @@ def train(use_pure_fp16=True, use_nesterov=False, use_adam=False):
test_program = train_program.clone(for_test=True) test_program = train_program.clone(for_test=True)
if use_adam: if use_adam:
optimizer = paddle.optimizer.Adam( optimizer = paddle.optimizer.AdamW(
learning_rate=0.001, learning_rate=0.001,
epsilon=1e-8, epsilon=1e-8,
weight_decay=0.0, weight_decay=0.0,
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from .optimizer import Optimizer from .optimizer import Optimizer
from .adam import Adam from .adam import Adam
from ..fluid import core
from ..fluid import framework from ..fluid import framework
from ..fluid.dygraph import base as imperative_base from ..fluid.dygraph import base as imperative_base
import paddle import paddle
...@@ -182,8 +183,16 @@ class AdamW(Adam): ...@@ -182,8 +183,16 @@ class AdamW(Adam):
decay_coeff = 1.0 - learning_rate * self._coeff decay_coeff = 1.0 - learning_rate * self._coeff
self._lr_to_coeff[learning_rate] = decay_coeff self._lr_to_coeff[learning_rate] = decay_coeff
scaled_param = param * decay_coeff find_master = (self._multi_precision and
paddle.fluid.layers.assign(input=scaled_param, output=param) param.dtype == core.VarDesc.VarType.FP16)
if find_master:
master_weight = self._master_weights[param.name]
scaled_param = master_weight * decay_coeff
paddle.fluid.layers.assign(
input=scaled_param, output=master_weight)
else:
scaled_param = param * decay_coeff
paddle.fluid.layers.assign(input=scaled_param, output=param)
def _append_optimize_op(self, block, param_and_grad): def _append_optimize_op(self, block, param_and_grad):
self._append_decoupled_weight_decay(block, param_and_grad) self._append_decoupled_weight_decay(block, param_and_grad)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册