Unverified commit c47d6729, authored by sneaxiy, committed by GitHub

Add _get_parameter method to Lamb optimizer (#39416)

* add _get_parameter func to lamb

* remove duplicate code
Parent 32d79bb9
@@ -195,32 +195,57 @@ class TestLambOpMultiPrecision(unittest.TestCase):
                 hidden = linear(x)
                 loss = paddle.mean(hidden)
 
-            optimizer = paddle.optimizer.Lamb(learning_rate=1e-3)
-            optimizer._multi_precision = multi_precision
+            original_optimizer = paddle.optimizer.Lamb(learning_rate=1e-3)
+            original_optimizer._multi_precision = multi_precision
             if multi_precision:
                 optimizer = paddle.static.amp.decorate(
-                    optimizer, use_pure_fp16=True, use_fp16_guard=True)
+                    original_optimizer, use_pure_fp16=True, use_fp16_guard=True)
+            else:
+                optimizer = original_optimizer
             optimizer.minimize(loss)
 
         weight, bias = linear.weight, linear.bias
-        scope = paddle.static.Scope()
         exe = paddle.static.Executor(place)
         scope = paddle.static.Scope()
         x = main_prog.global_block().var(x.name)
         if x.dtype == core.VarDesc.VarType.FP16:
             x_np = x_np.astype(np.float16)
 
+        def get_parameter(var):
+            name = var if isinstance(var, (str, bytes)) else var.name
+            params = original_optimizer._get_parameter(name, scope)
+            assert isinstance(params, (list, tuple))
+            params = list(params)
+            assert len(params) == 2
+            if multi_precision:
+                params[0] = np.array(params[0])
+                params[1] = np.array(params[1])
+                self.assertTrue(
+                    np.array_equal(params[0], params[1].astype(np.float16)))
+                return params[0].astype(np.float32)
+            else:
+                self.assertTrue(params[0] is not None)
+                self.assertTrue(params[1] is None)
+                params[0] = np.array(params[0])
+                return params[0]
+
         with paddle.static.scope_guard(scope):
             exe.run(startup_prog)
             if multi_precision:
                 optimizer.amp_init(place)
 
             weight_np, bias_np = None, None
             for i in range(n):
                 feed_dict = {x.name: x_np}
                 weight_np, bias_np = exe.run(main_prog,
                                              feed=feed_dict,
                                              fetch_list=[weight, bias])
-            return weight_np.astype('float32'), bias_np.astype('float32')
+            weight_np = weight_np.astype('float32')
+            bias_np = bias_np.astype('float32')
+            self.assertTrue(
+                np.array_equal(weight_np, get_parameter(weight)))
+            self.assertTrue(np.array_equal(bias_np, get_parameter(bias)))
+            return weight_np, bias_np
 
     @switch_to_static_graph
     def test_main(self):
...
@@ -20,6 +20,7 @@ from ..fluid import layers
 from ..fluid import unique_name
 from ..fluid.layer_helper import LayerHelper
 from paddle import _C_ops
+from paddle.fluid.executor import global_scope
 
 __all__ = []
 
@@ -131,9 +132,25 @@ class Lamb(Optimizer):
             'exclude_from_weight_decay_fn': exclude_from_weight_decay_fn,
         }
         self._master_weights = {}
+        self._used_master_weights = {}
         # TODO(zengjinle): expose API as soon as possible
         self._multi_precision = False
 
+    def _get_parameter(self, name, scope=None):
+        if scope is None:
+            scope = global_scope()
+
+        p_t = scope.find_var(name).get_tensor()
+
+        master_name = self._used_master_weights.get(name)
+        if master_name is not None:
+            master_p_t = scope.find_var(master_name).get_tensor()
+            assert master_p_t._dtype() != p_t._dtype()
+            assert master_p_t.shape() == p_t.shape()
+        else:
+            master_p_t = None
+        return p_t, master_p_t
+
     def _create_master_weight(self, param):
         assert self._multi_precision
         if param.name in self._master_weights:
@@ -243,8 +260,12 @@ class Lamb(Optimizer):
         find_master = self._multi_precision and param_and_grad[
             0].dtype == core.VarDesc.VarType.FP16
-        master_weight = self._master_weights[param_and_grad[0]
-                                             .name] if find_master else None
+        p_name = param_and_grad[0].name
+        if find_master:
+            master_weight = self._master_weights[p_name]
+            self._used_master_weights[p_name] = master_weight.name
+        else:
+            master_weight = None
 
         found_inf = self._get_auxiliary_var('found_inf')
 
         if framework.in_dygraph_mode():
...
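For orientation, here is a minimal sketch (not part of commit c47d6729) of how the new Lamb._get_parameter added in the diff above could be called from a static-graph program. The toy network, batch size, and CPU place are assumptions made for the example; only the _get_parameter call and its (parameter, master weight) return pair come from the diff.

import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 10], dtype='float32')
    linear = paddle.nn.Linear(10, 2)
    loss = paddle.mean(linear(x))
    optimizer = paddle.optimizer.Lamb(learning_rate=1e-3)
    optimizer.minimize(loss)

scope = paddle.static.Scope()
exe = paddle.static.Executor(paddle.CPUPlace())
with paddle.static.scope_guard(scope):
    exe.run(startup_prog)
    exe.run(main_prog, feed={'x': np.random.rand(4, 10).astype('float32')})

    # _get_parameter returns (parameter_tensor, master_weight_tensor); the
    # second item is None here because no FP16 master weight was registered.
    param, master = optimizer._get_parameter(linear.weight.name, scope)
    print(np.array(param).shape, master)

With _multi_precision enabled (as in the decorated branch of the test above), the second element would instead be the FP32 master weight tensor tracked through self._used_master_weights.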