From c47d6729b4c7824414fe0ebe5316d52796cef466 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Thu, 10 Feb 2022 14:33:01 +0800 Subject: [PATCH] Add _get_parameter method to Lamb optimizer (#39416) * add _get_parameter func to lamb * remove duplicate code --- .../fluid/tests/unittests/test_lambv2_op.py | 35 ++++++++++++++++--- python/paddle/optimizer/lamb.py | 25 +++++++++++-- 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_lambv2_op.py b/python/paddle/fluid/tests/unittests/test_lambv2_op.py index 24a22f802ce..674cd9a3e9c 100644 --- a/python/paddle/fluid/tests/unittests/test_lambv2_op.py +++ b/python/paddle/fluid/tests/unittests/test_lambv2_op.py @@ -195,32 +195,57 @@ class TestLambOpMultiPrecision(unittest.TestCase): hidden = linear(x) loss = paddle.mean(hidden) - optimizer = paddle.optimizer.Lamb(learning_rate=1e-3) - optimizer._multi_precision = multi_precision + original_optimizer = paddle.optimizer.Lamb(learning_rate=1e-3) + original_optimizer._multi_precision = multi_precision if multi_precision: optimizer = paddle.static.amp.decorate( - optimizer, use_pure_fp16=True, use_fp16_guard=True) + original_optimizer, use_pure_fp16=True, use_fp16_guard=True) + else: + optimizer = original_optimizer optimizer.minimize(loss) weight, bias = linear.weight, linear.bias - scope = paddle.static.Scope() exe = paddle.static.Executor(place) scope = paddle.static.Scope() x = main_prog.global_block().var(x.name) if x.dtype == core.VarDesc.VarType.FP16: x_np = x_np.astype(np.float16) + def get_parameter(var): + name = var if isinstance(var, (str, bytes)) else var.name + params = original_optimizer._get_parameter(name, scope) + assert isinstance(params, (list, tuple)) + params = list(params) + assert len(params) == 2 + if multi_precision: + params[0] = np.array(params[0]) + params[1] = np.array(params[1]) + self.assertTrue( + np.array_equal(params[0], params[1].astype(np.float16))) + return params[0].astype(np.float32) + else: + self.assertTrue(params[0] is not None) + self.assertTrue(params[1] is None) + params[0] = np.array(params[0]) + return params[0] + with paddle.static.scope_guard(scope): exe.run(startup_prog) if multi_precision: optimizer.amp_init(place) + weight_np, bias_np = None, None for i in range(n): feed_dict = {x.name: x_np} weight_np, bias_np = exe.run(main_prog, feed=feed_dict, fetch_list=[weight, bias]) - return weight_np.astype('float32'), bias_np.astype('float32') + weight_np = weight_np.astype('float32') + bias_np = bias_np.astype('float32') + self.assertTrue( + np.array_equal(weight_np, get_parameter(weight))) + self.assertTrue(np.array_equal(bias_np, get_parameter(bias))) + return weight_np, bias_np @switch_to_static_graph def test_main(self): diff --git a/python/paddle/optimizer/lamb.py b/python/paddle/optimizer/lamb.py index 894c829f588..b8d7b101b1d 100644 --- a/python/paddle/optimizer/lamb.py +++ b/python/paddle/optimizer/lamb.py @@ -20,6 +20,7 @@ from ..fluid import layers from ..fluid import unique_name from ..fluid.layer_helper import LayerHelper from paddle import _C_ops +from paddle.fluid.executor import global_scope __all__ = [] @@ -131,9 +132,25 @@ class Lamb(Optimizer): 'exclude_from_weight_decay_fn': exclude_from_weight_decay_fn, } self._master_weights = {} + self._used_master_weights = {} # TODO(zengjinle): expose API as soon as possible self._multi_precision = False + def _get_parameter(self, name, scope=None): + if scope is None: + scope = global_scope() + + p_t = scope.find_var(name).get_tensor() + + master_name = self._used_master_weights.get(name) + if master_name is not None: + master_p_t = scope.find_var(master_name).get_tensor() + assert master_p_t._dtype() != p_t._dtype() + assert master_p_t.shape() == p_t.shape() + else: + master_p_t = None + return p_t, master_p_t + def _create_master_weight(self, param): assert self._multi_precision if param.name in self._master_weights: @@ -243,8 +260,12 @@ class Lamb(Optimizer): find_master = self._multi_precision and param_and_grad[ 0].dtype == core.VarDesc.VarType.FP16 - master_weight = self._master_weights[param_and_grad[0] - .name] if find_master else None + p_name = param_and_grad[0].name + if find_master: + master_weight = self._master_weights[p_name] + self._used_master_weights[p_name] = master_weight.name + else: + master_weight = None found_inf = self._get_auxiliary_var('found_inf') if framework.in_dygraph_mode(): -- GitLab