提交 28061f53 编写于 作者: Z zhangbo9674

refine optimizer init logic

上级 b54ee044
...@@ -21,6 +21,7 @@ from ppcls.utils import profiler ...@@ -21,6 +21,7 @@ from ppcls.utils import profiler
def train_epoch(engine, epoch_id, print_batch_step): def train_epoch(engine, epoch_id, print_batch_step):
tic = time.time() tic = time.time()
v_current = [int(i) for i in paddle.__version__.split(".")]
for iter_id, batch in enumerate(engine.train_dataloader): for iter_id, batch in enumerate(engine.train_dataloader):
if iter_id >= engine.max_iter: if iter_id >= engine.max_iter:
break break
...@@ -59,7 +60,7 @@ def train_epoch(engine, epoch_id, print_batch_step): ...@@ -59,7 +60,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
else: else:
loss_dict["loss"].backward() loss_dict["loss"].backward()
engine.optimizer.step() engine.optimizer.step()
engine.optimizer.clear_grad(set_to_zero=True) engine.optimizer.clear_grad()
engine.lr_sch.step() engine.lr_sch.step()
# below code just for logging # below code just for logging
......
...@@ -17,6 +17,7 @@ from __future__ import division ...@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from paddle import optimizer as optim from paddle import optimizer as optim
import paddle
from ppcls.utils import logger from ppcls.utils import logger
...@@ -36,15 +37,13 @@ class Momentum(object): ...@@ -36,15 +37,13 @@ class Momentum(object):
momentum, momentum,
weight_decay=None, weight_decay=None,
grad_clip=None, grad_clip=None,
multi_precision=True, multi_precision=True):
use_multi_tensor=True):
super().__init__() super().__init__()
self.learning_rate = learning_rate self.learning_rate = learning_rate
self.momentum = momentum self.momentum = momentum
self.weight_decay = weight_decay self.weight_decay = weight_decay
self.grad_clip = grad_clip self.grad_clip = grad_clip
self.multi_precision = multi_precision self.multi_precision = multi_precision
self.use_multi_tensor = use_multi_tensor
def __call__(self, model_list): def __call__(self, model_list):
# model_list is None in static graph # model_list is None in static graph
...@@ -56,8 +55,16 @@ class Momentum(object): ...@@ -56,8 +55,16 @@ class Momentum(object):
weight_decay=self.weight_decay, weight_decay=self.weight_decay,
grad_clip=self.grad_clip, grad_clip=self.grad_clip,
multi_precision=self.multi_precision, multi_precision=self.multi_precision,
use_multi_tensor=self.use_multi_tensor,
parameters=parameters) parameters=parameters)
if hasattr(opt, '_use_multi_tensor'):
opt = optim.Momentum(
learning_rate=self.learning_rate,
momentum=self.momentum,
weight_decay=self.weight_decay,
grad_clip=self.grad_clip,
multi_precision=self.multi_precision,
parameters=parameters,
use_multi_tensor=False)
return opt return opt
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册