From b4f74eed17fc114ebd53af35bbc933212f5506c4 Mon Sep 17 00:00:00 2001
From: HongyuJia
Date: Mon, 29 Aug 2022 11:09:13 +0800
Subject: [PATCH] [phi] Transfer merged_adam yaml to phi (#45367)

* add legacy_api.yaml

* set merged_momentum inplace only

* support inplace optional<std::vector<Tensor>>

* add dygraph_mode api

* add merged_adam yaml

* add merged_adam python api

* change testcase of merged_adam and adam

* fix import of test_merged_adam_op
---
 paddle/phi/api/yaml/legacy_api.yaml            | 11 ++++++
 .../fluid/tests/unittests/test_adam_op.py      |  4 +++
 .../tests/unittests/test_merged_adam_op.py     | 19 +++++++----
 python/paddle/optimizer/adam.py                | 34 ++++++++++++-------
 4 files changed, 50 insertions(+), 18 deletions(-)

diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml
index f71674ec91b..840738d4bb6 100755
--- a/paddle/phi/api/yaml/legacy_api.yaml
+++ b/paddle/phi/api/yaml/legacy_api.yaml
@@ -1730,6 +1730,17 @@
     func : mean_all
   backward : mean_all_grad
 
+- api : merged_adam_
+  args : (Tensor[] param, Tensor[] grad, Tensor[] learning_rate, Tensor[] moment1, Tensor[] moment2, Tensor[] beta1_pow, Tensor[] beta2_pow, Tensor[] master_param, Scalar beta1, Scalar beta2, Scalar epsilon, bool multi_precision, bool use_global_beta_pow)
+  output : Tensor[](param_out){param.size()}, Tensor[](moment1_out){param.size()}, Tensor[](moment2_out){param.size()}, Tensor[](beta1_pow_out){param.size()}, Tensor[](beta2_pow_out){param.size()}, Tensor[](master_param_out){param.size()}
+  infer_meta :
+    func : MergedAdamInferMeta
+  optional: master_param
+  kernel :
+    func : merged_adam
+    data_type : param
+  inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_out)
+
 - api : merged_momentum_
   args : (Tensor[] param, Tensor[] grad, Tensor[] velocity, Tensor[] learning_rate, Tensor[] master_param, float mu, bool use_nesterov = false, str[] regularization_method = {}, float[] regularization_coeff = {}, bool multi_precision = false, float rescale_grad = 1.0f)
   output : Tensor[](param_out){param.size()}, Tensor[](velocity_out){param.size()}, Tensor[](master_param_out){param.size()}
diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py
index 1396d073f7c..eb2ea8c56ac 100644
--- a/python/paddle/fluid/tests/unittests/test_adam_op.py
+++ b/python/paddle/fluid/tests/unittests/test_adam_op.py
@@ -1230,6 +1230,10 @@ class TestMultiTensorAdam(unittest.TestCase):
                 self._check_with_param_arrt(place, use_amp)
                 self._check_with_param_group(place, use_amp)
 
+    def test_api_eager_dygraph(self):
+        with _test_eager_guard():
+            self.test_main()
+
 
 if __name__ == "__main__":
     paddle.enable_static()
diff --git a/python/paddle/fluid/tests/unittests/test_merged_adam_op.py b/python/paddle/fluid/tests/unittests/test_merged_adam_op.py
index 21749a92f31..c6aadeda5fc 100644
--- a/python/paddle/fluid/tests/unittests/test_merged_adam_op.py
+++ b/python/paddle/fluid/tests/unittests/test_merged_adam_op.py
@@ -16,6 +16,7 @@ import unittest
 import paddle
 import numpy as np
 from paddle import _C_ops, _legacy_C_ops
+from paddle.fluid.framework import in_dygraph_mode
 
 
 def run_adam_op(params,
@@ -63,12 +64,18 @@ def run_adam_op(params,
                     master_param_vars[i], 'epsilon', epsilon, 'beta1', beta1,
                     'beta2', beta2, 'multi_precision', multi_precision)
         else:
-            _, _, _, _, _, _ = _legacy_C_ops.merged_adam(
-                param_vars, grad_vars, lr_vars, moment1_vars, moment2_vars,
-                beta1_pow_vars, beta2_pow_vars, master_param_vars, param_vars,
-                moment1_vars, moment2_vars, beta1_pow_vars, beta2_pow_vars,
-                master_param_vars, 'epsilon', epsilon, 'beta1', beta1, 'beta2',
-                beta2, 'multi_precision', multi_precision)
+            if in_dygraph_mode():
+                _, _, _, _, _, _ = _C_ops.merged_adam_(
+                    param_vars, grad_vars, lr_vars, moment1_vars, moment2_vars,
+                    beta1_pow_vars, beta2_pow_vars, master_param_vars, beta1, beta2,
+                    epsilon, multi_precision, False)
+            else:
+                _, _, _, _, _, _ = _legacy_C_ops.merged_adam(
+                    param_vars, grad_vars, lr_vars, moment1_vars, moment2_vars,
+                    beta1_pow_vars, beta2_pow_vars, master_param_vars, param_vars,
+                    moment1_vars, moment2_vars, beta1_pow_vars, beta2_pow_vars,
+                    master_param_vars, 'epsilon', epsilon, 'beta1', beta1, 'beta2',
+                    beta2, 'multi_precision', multi_precision)
 
     outputs = {
         'ParamOut': param_vars,
diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py
index 96ff625f1f9..1140516cdc5 100644
--- a/python/paddle/optimizer/adam.py
+++ b/python/paddle/optimizer/adam.py
@@ -583,18 +583,28 @@ class Adam(Optimizer):
                     self._beta2, Variable) else self._beta2.numpy().item(0)
 
                 if framework._non_static_mode():
-                    _, _, _, _, _, _ = _legacy_C_ops.merged_adam(
-                        self._param_dict[key], grad_dict[key], lr_dict[key],
-                        self._moment1_dict[key], self._moment2_dict[key],
-                        self._beta1_pow_acc_dict[key],
-                        self._beta2_pow_acc_dict[key],
-                        self._master_weight_dict[key], self._param_dict[key],
-                        self._moment1_dict[key], self._moment2_dict[key],
-                        self._beta1_pow_acc_dict[key],
-                        self._beta2_pow_acc_dict[key],
-                        self._master_weight_dict[key], 'epsilon', self._epsilon,
-                        'beta1', _beta1, 'beta2', _beta2, 'multi_precision',
-                        find_master)
+                    if in_dygraph_mode():
+                        _, _, _, _, _, _ = _C_ops.merged_adam_(
+                            self._param_dict[key], grad_dict[key], lr_dict[key],
+                            self._moment1_dict[key], self._moment2_dict[key],
+                            self._beta1_pow_acc_dict[key],
+                            self._beta2_pow_acc_dict[key],
+                            self._master_weight_dict[key], _beta1, _beta2,
+                            self._epsilon, find_master, False)
+                    else:
+                        _, _, _, _, _, _ = _legacy_C_ops.merged_adam(
+                            self._param_dict[key], grad_dict[key], lr_dict[key],
+                            self._moment1_dict[key], self._moment2_dict[key],
+                            self._beta1_pow_acc_dict[key],
+                            self._beta2_pow_acc_dict[key],
+                            self._master_weight_dict[key],
+                            self._param_dict[key], self._moment1_dict[key],
+                            self._moment2_dict[key],
+                            self._beta1_pow_acc_dict[key],
+                            self._beta2_pow_acc_dict[key],
+                            self._master_weight_dict[key], 'epsilon',
+                            self._epsilon, 'beta1', _beta1, 'beta2', _beta2,
+                            'multi_precision', find_master)
                 else:
                     inputs = {
                         "Param": self._param_dict[key],
-- 
GitLab
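
Note (reviewer sketch, not part of the patch): the snippet below shows one way the new eager-mode fused op could be exercised directly, mirroring the call added to test_merged_adam_op.py. Tensor shapes, hyperparameter values, and passing None for the optional master_param are illustrative assumptions; only the argument order comes from the merged_adam_ entry added to legacy_api.yaml above.

import numpy as np
import paddle
from paddle import _C_ops

# Dygraph (eager) mode is Paddle's default; merged_adam_ fuses the Adam update
# for a list of parameters into a single in-place call.
n = 3  # number of parameters updated by one fused call (illustrative)
params = [paddle.to_tensor(np.random.rand(4, 4).astype('float32')) for _ in range(n)]
grads = [paddle.to_tensor(np.random.rand(4, 4).astype('float32')) for _ in range(n)]
lrs = [paddle.to_tensor(np.array([1e-3], dtype='float32')) for _ in range(n)]
moment1s = [paddle.zeros_like(p) for p in params]
moment2s = [paddle.zeros_like(p) for p in params]
beta1_pows = [paddle.to_tensor(np.array([0.9], dtype='float32')) for _ in range(n)]
beta2_pows = [paddle.to_tensor(np.array([0.999], dtype='float32')) for _ in range(n)]

# Argument order follows the yaml signature: (param, grad, learning_rate,
# moment1, moment2, beta1_pow, beta2_pow, master_param, beta1, beta2, epsilon,
# multi_precision, use_global_beta_pow). master_param is marked optional in the
# yaml, so None is assumed to be acceptable when multi_precision is False.
_C_ops.merged_adam_(params, grads, lrs, moment1s, moment2s, beta1_pows,
                    beta2_pows, None, 0.9, 0.999, 1e-8, False, False)

# params, moment1s, moment2s, beta1_pows and beta2_pows are updated in place,
# per the inplace mapping declared for merged_adam_.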