Commit 28061f53, authored by zhangbo9674

refine optimizer init logic

Parent b54ee044
@@ -21,6 +21,7 @@ from ppcls.utils import profiler
 def train_epoch(engine, epoch_id, print_batch_step):
     tic = time.time()
+    v_current = [int(i) for i in paddle.__version__.split(".")]
     for iter_id, batch in enumerate(engine.train_dataloader):
         if iter_id >= engine.max_iter:
             break
@@ -59,7 +60,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
         else:
             loss_dict["loss"].backward()
             engine.optimizer.step()
-        engine.optimizer.clear_grad(set_to_zero=True)
+        engine.optimizer.clear_grad()
         engine.lr_sch.step()
         # below code just for logging
...
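The training-loop change above stores the installed Paddle version as a list of integers and drops the explicit `set_to_zero=True` argument from `clear_grad()`. Below is a minimal sketch of how such a version list can gate an optional keyword argument; the version threshold and the `clear_gradients` helper are illustrative assumptions, not part of this commit.

import paddle

# "2.2.1" -> [2, 2, 1]; pre-release suffixes such as "rc0" would need extra handling
v_current = [int(i) for i in paddle.__version__.split(".")]

def clear_gradients(optimizer):
    # Hypothetical guard: only pass set_to_zero where the installed Paddle accepts it.
    if v_current[0] > 2 or (v_current[0] == 2 and v_current[1] >= 1):
        optimizer.clear_grad(set_to_zero=True)
    else:
        optimizer.clear_grad()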
@@ -17,6 +17,7 @@ from __future__ import division
 from __future__ import print_function
 
 from paddle import optimizer as optim
+import paddle
 
 from ppcls.utils import logger
@@ -36,15 +37,13 @@ class Momentum(object):
                  momentum,
                  weight_decay=None,
                  grad_clip=None,
-                 multi_precision=True,
-                 use_multi_tensor=True):
+                 multi_precision=True):
         super().__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
         self.weight_decay = weight_decay
         self.grad_clip = grad_clip
         self.multi_precision = multi_precision
-        self.use_multi_tensor = use_multi_tensor
 
     def __call__(self, model_list):
         # model_list is None in static graph
@@ -56,8 +55,16 @@ class Momentum(object):
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
             multi_precision=self.multi_precision,
-            use_multi_tensor=self.use_multi_tensor,
             parameters=parameters)
+        if hasattr(opt, '_use_multi_tensor'):
+            opt = optim.Momentum(
+                learning_rate=self.learning_rate,
+                momentum=self.momentum,
+                weight_decay=self.weight_decay,
+                grad_clip=self.grad_clip,
+                multi_precision=self.multi_precision,
+                parameters=parameters,
+                use_multi_tensor=False)
         return opt
...
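The optimizer change follows a feature-detection pattern: build `optim.Momentum` with only the always-available arguments, then probe the resulting instance for the private `_use_multi_tensor` attribute and rebuild it with the `use_multi_tensor` keyword only when the installed Paddle exposes it, so older releases never see an unknown argument. Below is a standalone sketch of the same pattern; the `build_momentum` helper and its default hyperparameters are illustrative, not part of this commit.

import paddle
from paddle import optimizer as optim

def build_momentum(parameters, learning_rate=0.1, momentum=0.9):
    # First construct the optimizer with arguments every Paddle 2.x release accepts.
    opt = optim.Momentum(
        learning_rate=learning_rate,
        momentum=momentum,
        parameters=parameters)
    # Newer Paddle builds carry a private `_use_multi_tensor` attribute; only there
    # is it safe to re-create the optimizer with the `use_multi_tensor` keyword.
    if hasattr(opt, '_use_multi_tensor'):
        opt = optim.Momentum(
            learning_rate=learning_rate,
            momentum=momentum,
            parameters=parameters,
            use_multi_tensor=False)
    return opt

model = paddle.nn.Linear(4, 2)
optimizer = build_momentum(model.parameters())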