未验证 提交 88e6dc4a 编写于 作者: W wanghuancoder 提交者: GitHub

optimize momentum to speedup dygraph, a little, test=develop (#30099)

上级 254ad619
......@@ -198,10 +198,6 @@ class Momentum(Optimizer):
velocity_acc = self._get_accumulator(self._velocity_acc_str,
param_and_grad[0])
find_master = self._multi_precision and param_and_grad[
0].dtype == core.VarDesc.VarType.FP16
master_weight = (self._master_weights[param_and_grad[0].name]
if find_master else None)
lr = self._create_param_lr(param_and_grad)
if framework.in_dygraph_mode():
......@@ -213,6 +209,11 @@ class Momentum(Optimizer):
self._regularization_coeff)
return None
find_master = self._multi_precision and param_and_grad[
0].dtype == core.VarDesc.VarType.FP16
master_weight = (self._master_weights[param_and_grad[0].name]
if find_master else None)
attrs = {
"mu": self._momentum,
"use_nesterov": self._use_nesterov,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册