diff --git a/python/paddle/fluid/tests/unittests/test_lambv2_op.py b/python/paddle/fluid/tests/unittests/test_lambv2_op.py
index 24a22f802ce92f2efffe15169ef36496f82664b4..674cd9a3e9c5bb4b4fdb4eddec5df0b2a23f64b7 100644
--- a/python/paddle/fluid/tests/unittests/test_lambv2_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lambv2_op.py
@@ -195,32 +195,57 @@ class TestLambOpMultiPrecision(unittest.TestCase):
                 hidden = linear(x)
                 loss = paddle.mean(hidden)
 
-            optimizer = paddle.optimizer.Lamb(learning_rate=1e-3)
-            optimizer._multi_precision = multi_precision
+            original_optimizer = paddle.optimizer.Lamb(learning_rate=1e-3)
+            original_optimizer._multi_precision = multi_precision
             if multi_precision:
                 optimizer = paddle.static.amp.decorate(
-                    optimizer, use_pure_fp16=True, use_fp16_guard=True)
+                    original_optimizer, use_pure_fp16=True, use_fp16_guard=True)
+            else:
+                optimizer = original_optimizer
             optimizer.minimize(loss)
 
         weight, bias = linear.weight, linear.bias
-        scope = paddle.static.Scope()
         exe = paddle.static.Executor(place)
         scope = paddle.static.Scope()
         x = main_prog.global_block().var(x.name)
         if x.dtype == core.VarDesc.VarType.FP16:
             x_np = x_np.astype(np.float16)
 
+        def get_parameter(var):
+            name = var if isinstance(var, (str, bytes)) else var.name
+            params = original_optimizer._get_parameter(name, scope)
+            assert isinstance(params, (list, tuple))
+            params = list(params)
+            assert len(params) == 2
+            if multi_precision:
+                params[0] = np.array(params[0])
+                params[1] = np.array(params[1])
+                self.assertTrue(
+                    np.array_equal(params[0], params[1].astype(np.float16)))
+                return params[0].astype(np.float32)
+            else:
+                self.assertTrue(params[0] is not None)
+                self.assertTrue(params[1] is None)
+                params[0] = np.array(params[0])
+                return params[0]
+
         with paddle.static.scope_guard(scope):
             exe.run(startup_prog)
             if multi_precision:
                 optimizer.amp_init(place)
 
+            weight_np, bias_np = None, None
             for i in range(n):
                 feed_dict = {x.name: x_np}
                 weight_np, bias_np = exe.run(main_prog,
                                              feed=feed_dict,
                                              fetch_list=[weight, bias])
-            return weight_np.astype('float32'), bias_np.astype('float32')
+            weight_np = weight_np.astype('float32')
+            bias_np = bias_np.astype('float32')
+            self.assertTrue(
+                np.array_equal(weight_np, get_parameter(weight)))
+            self.assertTrue(np.array_equal(bias_np, get_parameter(bias)))
+            return weight_np, bias_np
 
     @switch_to_static_graph
     def test_main(self):
diff --git a/python/paddle/optimizer/lamb.py b/python/paddle/optimizer/lamb.py
index 894c829f58830540d7e8b74a9ce1da6e287dded5..b8d7b101b1d66b32a1b7a1b0f98c71ed3a1c1896 100644
--- a/python/paddle/optimizer/lamb.py
+++ b/python/paddle/optimizer/lamb.py
@@ -20,6 +20,7 @@ from ..fluid import layers
 from ..fluid import unique_name
 from ..fluid.layer_helper import LayerHelper
 from paddle import _C_ops
+from paddle.fluid.executor import global_scope
 
 __all__ = []
 
@@ -131,9 +132,25 @@ class Lamb(Optimizer):
             'exclude_from_weight_decay_fn': exclude_from_weight_decay_fn,
         }
         self._master_weights = {}
+        self._used_master_weights = {}
         # TODO(zengjinle): expose API as soon as possible
         self._multi_precision = False
 
+    def _get_parameter(self, name, scope=None):
+        if scope is None:
+            scope = global_scope()
+
+        p_t = scope.find_var(name).get_tensor()
+
+        master_name = self._used_master_weights.get(name)
+        if master_name is not None:
+            master_p_t = scope.find_var(master_name).get_tensor()
+            assert master_p_t._dtype() != p_t._dtype()
+            assert master_p_t.shape() == p_t.shape()
+        else:
+            master_p_t = None
+        return p_t, master_p_t
+
     def _create_master_weight(self, param):
         assert self._multi_precision
         if param.name in self._master_weights:
@@ -243,8 +260,12 @@ class Lamb(Optimizer):
 
         find_master = self._multi_precision and param_and_grad[
             0].dtype == core.VarDesc.VarType.FP16
-        master_weight = self._master_weights[param_and_grad[0]
-                                             .name] if find_master else None
+        p_name = param_and_grad[0].name
+        if find_master:
+            master_weight = self._master_weights[p_name]
+            self._used_master_weights[p_name] = master_weight.name
+        else:
+            master_weight = None
         found_inf = self._get_auxiliary_var('found_inf')
 
         if framework.in_dygraph_mode():
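The sketch below shows how the new `Lamb._get_parameter` helper could be exercised outside the unit test: after pure-FP16 training in static-graph mode, it returns the FP16 parameter tensor together with its FP32 master weight (or `None` when no master copy was used), so the two copies can be compared directly. This is a hypothetical, condensed usage example rather than part of the patch; the network, feed data, and CUDA place are illustrative assumptions, and `_multi_precision` / `_get_parameter` remain private APIs per the TODO in lamb.py.

# Minimal sketch (assumed setup): check that an FP16 parameter stays in sync
# with its FP32 master weight via Lamb._get_parameter after a pure-FP16 step.
import numpy as np
import paddle
from paddle.fluid import core

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
place = paddle.CUDAPlace(0)  # pure-FP16 training assumes a CUDA device

with paddle.static.program_guard(main_prog, startup_prog):
    with paddle.static.amp.fp16_guard():
        x = paddle.static.data(name='x', shape=[None, 10], dtype='float32')
        linear = paddle.nn.Linear(10, 2)
        loss = paddle.mean(linear(x))

    lamb = paddle.optimizer.Lamb(learning_rate=1e-3)
    lamb._multi_precision = True  # private flag, see the TODO in lamb.py
    decorated = paddle.static.amp.decorate(
        lamb, use_pure_fp16=True, use_fp16_guard=True)
    decorated.minimize(loss)

exe = paddle.static.Executor(place)
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
    exe.run(startup_prog)
    decorated.amp_init(place)

    # Feed FP16 data if the pure-FP16 pass rewrote the input var's dtype.
    x_np = np.random.rand(4, 10).astype('float32')
    if main_prog.global_block().var(x.name).dtype == core.VarDesc.VarType.FP16:
        x_np = x_np.astype('float16')
    exe.run(main_prog, feed={x.name: x_np}, fetch_list=[loss])

    # _get_parameter returns (fp16 param tensor, fp32 master tensor); the
    # master is None when no master weight was recorded for this parameter.
    p, master_p = lamb._get_parameter(linear.weight.name, scope)
    assert master_p is not None
    assert np.array_equal(np.array(p),
                          np.array(master_p).astype(np.float16))

Note that `_used_master_weights` only records the master copies actually consumed in `_append_optimize_op`, keyed by parameter name, which is what lets `_get_parameter` look up the FP32 master tensor in the scope by name.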