Unverified · Commit c47d6729 · Author: sneaxiy · Committer: GitHub

Add _get_parameter method to Lamb optimizer (#39416)

* add _get_parameter func to lamb

* remove duplicate code
Parent 32d79bb9
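
For orientation, here is a hedged, self-contained sketch (not part of the diff below) of how the new `_get_parameter` helper can be called once this commit is in place. It mirrors the updated unit test; every variable name is illustrative.

```python
import numpy as np
import paddle

paddle.enable_static()

main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 10], dtype='float32')
    linear = paddle.nn.Linear(10, 2)
    loss = paddle.mean(linear(x))
    optimizer = paddle.optimizer.Lamb(learning_rate=1e-3)
    optimizer.minimize(loss)

scope = paddle.static.Scope()
exe = paddle.static.Executor(paddle.CPUPlace())
with paddle.static.scope_guard(scope):
    exe.run(startup_prog)
    exe.run(main_prog, feed={'x': np.random.rand(4, 10).astype('float32')})

    # New in this commit: fetch the parameter tensor (and its fp32 master weight,
    # if multi-precision training created one) directly by name from the scope.
    p_t, master_p_t = optimizer._get_parameter(linear.weight.name, scope)
    weight_np = np.array(p_t)   # dense numpy copy of the trained weight
    assert master_p_t is None   # no master weight without multi-precision training
```

With `_multi_precision` enabled and the program decorated for pure fp16 (as in the test below), the second element of the returned pair would instead be the fp32 master weight tensor.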
@@ -195,32 +195,57 @@ class TestLambOpMultiPrecision(unittest.TestCase):
                 hidden = linear(x)
                 loss = paddle.mean(hidden)
 
-            optimizer = paddle.optimizer.Lamb(learning_rate=1e-3)
-            optimizer._multi_precision = multi_precision
+            original_optimizer = paddle.optimizer.Lamb(learning_rate=1e-3)
+            original_optimizer._multi_precision = multi_precision
             if multi_precision:
                 optimizer = paddle.static.amp.decorate(
-                    optimizer, use_pure_fp16=True, use_fp16_guard=True)
+                    original_optimizer, use_pure_fp16=True, use_fp16_guard=True)
+            else:
+                optimizer = original_optimizer
             optimizer.minimize(loss)
 
         weight, bias = linear.weight, linear.bias
-        scope = paddle.static.Scope()
         exe = paddle.static.Executor(place)
+        scope = paddle.static.Scope()
         x = main_prog.global_block().var(x.name)
         if x.dtype == core.VarDesc.VarType.FP16:
             x_np = x_np.astype(np.float16)
 
+        def get_parameter(var):
+            name = var if isinstance(var, (str, bytes)) else var.name
+            params = original_optimizer._get_parameter(name, scope)
+            assert isinstance(params, (list, tuple))
+            params = list(params)
+            assert len(params) == 2
+            if multi_precision:
+                params[0] = np.array(params[0])
+                params[1] = np.array(params[1])
+                self.assertTrue(
+                    np.array_equal(params[0], params[1].astype(np.float16)))
+                return params[0].astype(np.float32)
+            else:
+                self.assertTrue(params[0] is not None)
+                self.assertTrue(params[1] is None)
+                params[0] = np.array(params[0])
+                return params[0]
+
         with paddle.static.scope_guard(scope):
             exe.run(startup_prog)
             if multi_precision:
                 optimizer.amp_init(place)
 
             weight_np, bias_np = None, None
             for i in range(n):
                 feed_dict = {x.name: x_np}
                 weight_np, bias_np = exe.run(main_prog,
                                              feed=feed_dict,
                                              fetch_list=[weight, bias])
-            return weight_np.astype('float32'), bias_np.astype('float32')
+            weight_np = weight_np.astype('float32')
+            bias_np = bias_np.astype('float32')
+            self.assertTrue(
+                np.array_equal(weight_np, get_parameter(weight)))
+            self.assertTrue(np.array_equal(bias_np, get_parameter(bias)))
+            return weight_np, bias_np
 
     @switch_to_static_graph
     def test_main(self):
......
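
To make the multi-precision branch of the test's `get_parameter` helper concrete, here is a hedged, numpy-only illustration (made-up values) of the invariant it asserts: the fp16 parameter used by the ops should equal the fp32 master weight cast down to fp16.

```python
import numpy as np

# Stand-in for the fp32 master weight kept when multi-precision Lamb is enabled.
master_weight = np.array([0.1, 0.2, 0.3], dtype=np.float32)
# Stand-in for the fp16 parameter that the forward/backward ops actually read.
param = master_weight.astype(np.float16)

# The equality checked by get_parameter() above in the multi_precision case.
assert np.array_equal(param, master_weight.astype(np.float16))
# What the helper hands back to the caller: the fp16 values widened to fp32.
recovered = param.astype(np.float32)
```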
@@ -20,6 +20,7 @@ from ..fluid import layers
 from ..fluid import unique_name
 from ..fluid.layer_helper import LayerHelper
 from paddle import _C_ops
+from paddle.fluid.executor import global_scope
 
 __all__ = []
@@ -131,9 +132,25 @@ class Lamb(Optimizer):
             'exclude_from_weight_decay_fn': exclude_from_weight_decay_fn,
         }
         self._master_weights = {}
+        self._used_master_weights = {}
         # TODO(zengjinle): expose API as soon as possible
         self._multi_precision = False
 
+    def _get_parameter(self, name, scope=None):
+        if scope is None:
+            scope = global_scope()
+
+        p_t = scope.find_var(name).get_tensor()
+
+        master_name = self._used_master_weights.get(name)
+        if master_name is not None:
+            master_p_t = scope.find_var(master_name).get_tensor()
+            assert master_p_t._dtype() != p_t._dtype()
+            assert master_p_t.shape() == p_t.shape()
+        else:
+            master_p_t = None
+        return p_t, master_p_t
+
     def _create_master_weight(self, param):
         assert self._multi_precision
         if param.name in self._master_weights:
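
As a hedged, dictionary-only sketch of the name resolution `_get_parameter` performs (the real method additionally looks both names up in the scope and returns the underlying tensors; all names below are illustrative):

```python
# Populated by _append_optimize_op: parameter name -> master-weight variable name.
# Empty unless multi-precision training created fp32 master copies.
used_master_weights = {"linear_0.w_0": "linear_0.w_0.master"}  # illustrative names

def resolve(name):
    # Mirrors _get_parameter: the parameter name always resolves; the master-weight
    # name resolves only if one was recorded for this parameter.
    return name, used_master_weights.get(name)

print(resolve("linear_0.w_0"))  # ('linear_0.w_0', 'linear_0.w_0.master')
print(resolve("linear_0.b_0"))  # ('linear_0.b_0', None)
```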
@@ -243,8 +260,12 @@
         find_master = self._multi_precision and param_and_grad[
             0].dtype == core.VarDesc.VarType.FP16
-        master_weight = self._master_weights[param_and_grad[0]
-                                             .name] if find_master else None
+        p_name = param_and_grad[0].name
+        if find_master:
+            master_weight = self._master_weights[p_name]
+            self._used_master_weights[p_name] = master_weight.name
+        else:
+            master_weight = None
 
         found_inf = self._get_auxiliary_var('found_inf')
 
         if framework.in_dygraph_mode():
......