Unverified · Commit a4b9daf9, authored by Leo Chen, committed by GitHub

fix optimizer dtype (#29917)

Parent: 9602a182
@@ -108,8 +108,12 @@ class Optimizer(object):
         self.regularization = regularization
         self._grad_clip = grad_clip
         self._learning_rate = learning_rate
-        # the learning rate type should be inferenced from loss
         self._dtype = None
+        # Infer the dtype from the parameters
+        if self._parameter_list:
+            self._dtype = self._parameter_list[0].dtype
         # each program should have a independent learning rate
         # program -> Variable(learning_rate)
         self._learning_rate_map = dict()
@@ -768,7 +772,10 @@ class Optimizer(object):
         else:
             act_no_grad_set = self._get_no_grad_set(loss, no_grad_set)

-        self._dtype = loss.dtype
+        # Infer dtype by loss if None
+        if self._dtype is None:
+            self._dtype = loss.dtype
         if framework.in_dygraph_mode():
             parameter_list = parameter_list if parameter_list \
                 else self._parameter_list
......
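
Taken together, the two hunks above move the dtype decision forward: `_dtype` is now inferred from the first parameter when the optimizer is constructed, and `minimize()` only falls back to `loss.dtype` when construction could not infer it. A minimal standalone sketch of that ordering (illustrative names only, not Paddle's actual Optimizer class):

class _OptimizerDtypeSketch(object):
    def __init__(self, parameter_list=None):
        self._parameter_list = parameter_list
        self._dtype = None
        # New behaviour: infer the dtype from the parameters, if any were given.
        if self._parameter_list:
            self._dtype = self._parameter_list[0].dtype

    def minimize(self, loss):
        # Old behaviour kept only as a fallback: use the loss dtype when unset.
        if self._dtype is None:
            self._dtype = loss.dtype
        return self._dtype

class _FakeTensor(object):
    def __init__(self, dtype):
        self.dtype = dtype

print(_OptimizerDtypeSketch([_FakeTensor('float64')]).minimize(_FakeTensor('float32')))  # float64
print(_OptimizerDtypeSketch().minimize(_FakeTensor('float32')))                          # float32
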
@@ -23,7 +23,8 @@ import paddle.fluid.core as core
 import paddle.compat as cpt
 import numpy as np
 from paddle.fluid.backward import append_backward
-from paddle.fluid.framework import Program, program_guard
+from paddle.fluid.framework import Program, program_guard, convert_np_dtype_to_dtype_
+import paddle

 class TestOptimizer(unittest.TestCase):
@@ -1042,5 +1043,37 @@ class TestGradientMergeOptimizer(unittest.TestCase):
                          ['sgd', 'sgd'])

+class TestOptimizerDtype(unittest.TestCase):
+    '''
+    The dtype of the optimizer should be inferred from its parameters, and the
+    learning rate is created with the same dtype.
+    '''
+
+    def check_with_dtype(self, dtype):
+        class MyLayer(paddle.nn.Layer):
+            def __init__(self, dtype):
+                super(MyLayer, self).__init__()
+                self._w = self.create_parameter([2, 3], dtype=dtype)
+                self._b = self.create_parameter([2, 3], dtype=dtype)
+
+            def forward(self, x):
+                return x * self._w + self._b
+
+        with paddle.fluid.dygraph.guard():
+            model = MyLayer(dtype)
+            x = paddle.rand([10, 2, 3], dtype=dtype)
+            loss = model(x)
+            adam = paddle.optimizer.Adam(parameters=model.parameters())
+            loss.backward()
+            adam.step()
+            self.assertEqual(adam._dtype, convert_np_dtype_to_dtype_(dtype))
+
+    def test_float64(self):
+        self.check_with_dtype('float64')
+
+    def test_float32(self):
+        self.check_with_dtype('float32')
+
 if __name__ == '__main__':
     unittest.main()
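
The assertion at the end of `check_with_dtype` compares the optimizer's private `_dtype` against the framework's dtype enum. A rough illustration of what `convert_np_dtype_to_dtype_` returns (the printed representation may vary across Paddle versions, so treat the comments as approximate):

from paddle.fluid.framework import convert_np_dtype_to_dtype_

# Maps a numpy-style dtype name to the framework's internal dtype enum,
# which is the representation the optimizer stores in `_dtype`.
print(convert_np_dtype_to_dtype_('float32'))  # e.g. VarType.FP32
print(convert_np_dtype_to_dtype_('float64'))  # e.g. VarType.FP64
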
@@ -270,7 +270,6 @@ class Adam(Optimizer):
                 adam.step()
                 adam.clear_grad()
         """
-        self._dtype = None
         params_grads = []
         for param in self._parameter_list:
             if not param.trainable:
......
@@ -210,7 +210,6 @@ class AdamW(Adam):
     @framework.dygraph_only
     @imperative_base.no_grad
     def step(self):
-        self._dtype = None
         params_grads = []
         for param in self._parameter_list:
             if not param.trainable:
......
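
The same one-line removal appears in both the Adam and AdamW `step()` hunks above: now that `__init__` infers `_dtype` from the parameters, resetting it to None on every step would throw that information away. A small dygraph check of the resulting behaviour (illustrative only, not part of the commit; it assumes a Paddle 2.x install and inspects the same private `_dtype` attribute the new unit test uses):

import paddle

paddle.set_default_dtype('float64')  # assumption: we want float64 parameters
linear = paddle.nn.Linear(3, 1)
opt = paddle.optimizer.AdamW(learning_rate=0.01, parameters=linear.parameters())

for _ in range(2):
    loss = linear(paddle.rand([4, 3], dtype='float64')).mean()
    loss.backward()
    opt.step()        # this method used to begin with `self._dtype = None`
    opt.clear_grad()

# The dtype inferred from the parameters at construction time survives step().
print(opt._dtype)
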
@@ -132,8 +132,12 @@ class Optimizer(object):
         self.regularization = weight_decay
         self._grad_clip = grad_clip
         self._learning_rate = learning_rate
-        # the learning rate type should be inferenced from loss
         self._dtype = None
+        # Infer the dtype from the parameters
+        if self._parameter_list:
+            self._dtype = self._parameter_list[0].dtype
         # each program should have a independent learning rate
         # program -> tensor(learning_rate)
         self._learning_rate_map = dict()
@@ -675,7 +679,10 @@ class Optimizer(object):
         else:
             act_no_grad_set = self._get_no_grad_set(loss, no_grad_set)

-        self._dtype = loss.dtype
+        # Infer dtype by loss if None
+        if self._dtype is None:
+            self._dtype = loss.dtype
         if framework.in_dygraph_mode():
             parameter_list = parameters if parameters \
                 else self._parameter_list
@@ -885,6 +892,7 @@ class Optimizer(object):

         return optimize_ops, params_grads

+    @imperative_base.no_grad
     @framework.dygraph_only
     def step(self):
         """
@@ -910,7 +918,6 @@ class Optimizer(object):
                 adam.step()
                 adam.clear_grad()
         """
-        self._dtype = None
         params_grads = []
         for param in self._parameter_list:
             if not param.trainable:
......
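
The remaining change in this hunk adds `@imperative_base.no_grad` to the base `step()`, matching the decorator already visible on `AdamW.step()` above: the parameter update itself must not be traced by autograd. Roughly, with the public `paddle.no_grad` standing in for the internal decorator (a hand-rolled, purely illustrative update, not the optimizer's real kernels):

import paddle

linear = paddle.nn.Linear(3, 1)
loss = linear(paddle.rand([4, 3])).mean()
loss.backward()

with paddle.no_grad():
    for p in linear.parameters():
        # Naive SGD-style update; the point is only that no autograd history
        # is recorded for the update arithmetic itself.
        p.set_value(p.numpy() - 0.01 * p.gradient())
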