diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc
index 033ec569de811c935c7b43eb4feff8e300a9120f..ae248a7bf12803ba996f059c936db7693f5ca796 100644
--- a/paddle/phi/api/lib/api_custom_impl.cc
+++ b/paddle/phi/api/lib/api_custom_impl.cc
@@ -217,6 +217,199 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adam_impl(
 
 ////////////////// Forward api impls //////////////////////
 
+std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adamw_impl(
+    const Tensor& param,
+    const Tensor& grad,
+    const Tensor& learning_rate,
+    const Tensor& moment1,
+    const Tensor& moment2,
+    const Tensor& beta1_pow,
+    const Tensor& beta2_pow,
+    paddle::optional<const Tensor&> master_param,
+    paddle::optional<const Tensor&> skip_update,
+    const Scalar& beta1,
+    const Scalar& beta2,
+    const Scalar& epsilon,
+    float lr_ratio,
+    float coeff,
+    bool with_decay,
+    bool lazy_mode,
+    int64_t min_row_size_to_use_multithread,
+    bool multi_precision,
+    bool use_global_beta_pow) {
+  Backend kernel_backend = Backend::UNDEFINED;
+  DataLayout kernel_layout = DataLayout::UNDEFINED;
+  DataType kernel_data_type = DataType::UNDEFINED;
+  if (kernel_backend == Backend::UNDEFINED ||
+      kernel_layout == DataLayout::UNDEFINED ||
+      kernel_data_type == DataType::UNDEFINED) {
+    auto kernel_key_set = ParseKernelKeyByInputArgs(param);
+    auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
+    if (kernel_backend == Backend::UNDEFINED) {
+      kernel_backend = kernel_key.backend();
+    }
+    if (kernel_layout == DataLayout::UNDEFINED) {
+      kernel_layout = kernel_key.layout();
+    }
+    if (kernel_data_type == DataType::UNDEFINED) {
+      kernel_data_type = kernel_key.dtype();
+    }
+  }
+  std::string kernel_name = "adamw";
+  const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
+      kernel_name, {kernel_backend, kernel_layout, kernel_data_type});
+  VLOG(6) << kernel_name << " API kernel key: [" << kernel_backend << ", "
+          << kernel_layout << ", " << kernel_data_type << "]";
+  VLOG(6) << kernel_name << " API kernel: " << kernel;
+
+  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
+
+  auto input_param = PrepareData(param, kernel.InputAt(0), {});
+  auto input_grad = PrepareData(grad, kernel.InputAt(1), {});
+  auto input_lr = PrepareData(learning_rate, kernel.InputAt(2), {});
+  auto input_moment1 = PrepareData(moment1, kernel.InputAt(3), {});
+  auto input_moment2 = PrepareData(moment2, kernel.InputAt(4), {});
+  auto input_beta1_pow = PrepareData(beta1_pow, kernel.InputAt(5), {});
+  auto input_beta2_pow = PrepareData(beta2_pow, kernel.InputAt(6), {});
+  paddle::optional<const phi::DenseTensor&> input_master_param(paddle::none);
+  auto input_master_param_ptr =
+      PrepareData(master_param, kernel.InputAt(7), {});
+  paddle::optional<const phi::DenseTensor&> input_skip_update(paddle::none);
+  auto input_skip_update_ptr = PrepareData(skip_update, kernel.InputAt(8), {});
+
+  std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> api_output;
+  auto kernel_out_0 = input_param.get();
+  auto kernel_out_1 = input_moment1.get();
+  auto kernel_out_2 = input_moment2.get();
+  auto kernel_out_3 = input_beta1_pow.get();
+  auto kernel_out_4 = input_beta2_pow.get();
+  phi::DenseTensor* kernel_out_5 = nullptr;
+  if (input_master_param_ptr) {
+    input_master_param =
+        paddle::make_optional<const phi::DenseTensor&>(*input_master_param_ptr);
+    kernel_out_5 =
+        paddle::make_optional<phi::DenseTensor&>(*input_master_param_ptr)
+            .get_ptr();
+  }
+
+  if (input_skip_update_ptr) {
+    input_skip_update =
+        paddle::make_optional<const phi::DenseTensor&>(*input_skip_update_ptr);
+  }
+
+  paddle::optional<const phi::MetaTensor&> input_meta_ref_master_param(
+      paddle::none);
+  phi::DenseTensor dt;
+  phi::MetaTensor input_meta_tmp_master_param(dt);
+  if (input_master_param_ptr) {
+    input_meta_tmp_master_param.set_dtype(input_master_param_ptr->dtype());
+    input_meta_tmp_master_param.set_dims(input_master_param_ptr->dims());
+    input_meta_tmp_master_param.set_layout(input_master_param_ptr->layout());
+    input_meta_ref_master_param = input_meta_tmp_master_param;
+  }
+
+  paddle::optional<const phi::MetaTensor&> input_meta_ref_skip_update(
+      paddle::none);
+  phi::DenseTensor dt1;
+  phi::MetaTensor input_meta_tmp_skip_update(dt1);
+  if (input_skip_update_ptr) {
+    input_meta_tmp_skip_update.set_dtype(input_skip_update_ptr->dtype());
+    input_meta_tmp_skip_update.set_dims(input_skip_update_ptr->dims());
+    input_meta_tmp_skip_update.set_layout(input_skip_update_ptr->layout());
+    input_meta_ref_skip_update = input_meta_tmp_skip_update;
+  }
+
+  phi::MetaTensor meta_out_0(kernel_out_0);
+  phi::MetaTensor meta_out_1(kernel_out_1);
+  phi::MetaTensor meta_out_2(kernel_out_2);
+  phi::MetaTensor meta_out_3(kernel_out_3);
+  phi::MetaTensor meta_out_4(kernel_out_4);
+  phi::MetaTensor meta_out_5(kernel_out_5);
+
+  phi::AdamwInferMeta(MakeMetaTensor(*input_param),
+                      MakeMetaTensor(*input_grad),
+                      MakeMetaTensor(*input_lr),
+                      MakeMetaTensor(*input_moment1),
+                      MakeMetaTensor(*input_moment2),
+                      MakeMetaTensor(*input_beta1_pow),
+                      MakeMetaTensor(*input_beta2_pow),
+                      input_meta_ref_master_param,
+                      input_meta_ref_skip_update,
+                      beta1,
+                      beta2,
+                      epsilon,
+                      lr_ratio,
+                      coeff,
+                      with_decay,
+                      lazy_mode,
+                      min_row_size_to_use_multithread,
+                      multi_precision,
+                      use_global_beta_pow,
+                      &meta_out_0,
+                      &meta_out_1,
+                      &meta_out_2,
+                      &meta_out_3,
+                      &meta_out_4,
+                      &meta_out_5);
+
+  using kernel_signature = void (*)(const platform::DeviceContext&,
+                                    const phi::DenseTensor&,
+                                    const phi::DenseTensor&,
+                                    const phi::DenseTensor&,
+                                    const phi::DenseTensor&,
+                                    const phi::DenseTensor&,
+                                    const phi::DenseTensor&,
+                                    const phi::DenseTensor&,
+                                    paddle::optional<const phi::DenseTensor&>,
+                                    paddle::optional<const phi::DenseTensor&>,
+                                    const Scalar&,
+                                    const Scalar&,
+                                    const Scalar&,
+                                    float,
+                                    float,
+                                    bool,
+                                    bool,
+                                    int64_t,
+                                    bool,
+                                    bool,
+                                    phi::DenseTensor*,
+                                    phi::DenseTensor*,
+                                    phi::DenseTensor*,
+                                    phi::DenseTensor*,
+                                    phi::DenseTensor*,
+                                    phi::DenseTensor*);
+  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
+
+  (*kernel_fn)(*dev_ctx,
+               *input_param,
+               *input_grad,
+               *input_lr,
+               *input_moment1,
+               *input_moment2,
+               *input_beta1_pow,
+               *input_beta2_pow,
+               input_master_param,
+               input_skip_update,
+               beta1,
+               beta2,
+               epsilon,
+               lr_ratio,
+               coeff,
+               with_decay,
+               lazy_mode,
+               min_row_size_to_use_multithread,
+               multi_precision,
+               use_global_beta_pow,
+               kernel_out_0,
+               kernel_out_1,
+               kernel_out_2,
+               kernel_out_3,
+               kernel_out_4,
+               kernel_out_5);
+
+  return api_output;
+}
+
 Tensor conv2d_impl(const Tensor& input,
                    const Tensor& filter,
                    const std::vector<int>& strides,
diff --git a/paddle/phi/api/lib/api_custom_impl.h b/paddle/phi/api/lib/api_custom_impl.h
index 4ddc3e5f4e0d2edda8864960b79dc8eb22de48ff..46abcd90de32a610794892aec9d828f156239dd0 100644
--- a/paddle/phi/api/lib/api_custom_impl.h
+++ b/paddle/phi/api/lib/api_custom_impl.h
@@ -49,6 +49,27 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adam_impl(
     bool multi_precision,
     bool use_global_beta_pow);
 
+std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adamw_impl(
+    const Tensor& param,
+    const Tensor& grad,
+    const Tensor& learning_rate,
+    const Tensor& moment1,
+    const Tensor& moment2,
+    const Tensor& beta1_pow,
+    const Tensor& beta2_pow,
+    paddle::optional<const Tensor&> master_param,
+    paddle::optional<const Tensor&> skip_update,
+    const Scalar& beta1,
+    const Scalar& beta2,
+    const Scalar& epsilon,
+    float lr_ratio,
+    float coeff,
+    bool with_decay,
+    bool lazy_mode,
+    int64_t min_row_size_to_use_multithread,
+    bool multi_precision,
+    bool use_global_beta_pow);
+
 std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
     const Tensor& x,
     const Tensor& scale,
diff --git a/python/paddle/fluid/tests/unittests/test_adamw_op.py b/python/paddle/fluid/tests/unittests/test_adamw_op.py
index d5fa944802a4763c9686811fdb85cb0be1586f59..d2eef785f6e07e8714101b4fb1915625e3bcbb37 100644
--- a/python/paddle/fluid/tests/unittests/test_adamw_op.py
+++ b/python/paddle/fluid/tests/unittests/test_adamw_op.py
@@ -20,6 +20,7 @@ import paddle.fluid as fluid
 from op_test import OpTest
 from functools import partial
 from paddle.framework import core
+from paddle.fluid.framework import _test_eager_guard
 
 
 def adamw_step(inputs, attributes):
@@ -238,6 +239,11 @@ class TestAdamWOp(unittest.TestCase):
             adam = paddle.optimizer.AdamW(
                 0.1, epsilon=-1, parameters=linear.parameters())
 
+    def test_api_eager_dygraph(self):
+        with _test_eager_guard():
+            self.test_adamw_op_dygraph()
+            self.test_adamw_op_invalid_input()
+
 
 class TestAdamWOpGroup(TestAdamWOp):
     def test_adamw_op_dygraph(self):
@@ -319,6 +325,12 @@ class TestAdamWOpLayerwiseLR(TestAdamWOp):
         linear1 = paddle.nn.Linear(13, 8)
         linear2 = paddle.nn.Linear(8, 5)
 
+        # fix the linear name, simple_lr_setting function will use the name
+        linear1.weight.name = "linear_1.w_0"
+        linear1.bias.name = "linear_1.b_0"
+        linear2.weight.name = "linear_2.w_0"
+        linear2.bias.name = "linear_2.b_0"
+
         simple_lr_fun = partial(simple_lr_setting, decay_rate=0.8, n_layers=2)
 
         adam = paddle.optimizer.AdamW(
diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py
index e69dcf170d93c5a188ddbe56da9f31a8c5270311..0fa49745a95fb8958f45139c030cd01ecdb6ce87 100644
--- a/python/paddle/optimizer/adamw.py
+++ b/python/paddle/optimizer/adamw.py
@@ -290,14 +290,24 @@ class AdamW(Adam):
             _beta2 = self._beta2 if not isinstance(
                 self._beta2, Variable) else self._beta2.numpy().item(0)
-            _, _, _, _, _, _ = _C_ops.adamw(
-                param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
-                beta1_pow_acc, beta2_pow_acc, master_weight, param_and_grad[0],
-                moment1, moment2, beta1_pow_acc, beta2_pow_acc, master_weight,
-                'epsilon', self._epsilon, 'lazy_mode', self._lazy_mode,
-                'min_row_size_to_use_multithread', 1000, 'beta1', _beta1,
-                'beta2', _beta2, "with_decay", with_decay, 'coeff', self._coeff,
-                'multi_precision', find_master, 'lr_ratio', lr_ratio_)
+            if framework.in_dygraph_mode():
+                found_inf = self._get_auxiliary_var('found_inf')
+                _, _, _, _, _, _ = _C_ops.final_state_adamw(
+                    param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
+                    beta1_pow_acc, beta2_pow_acc, master_weight, found_inf,
+                    _beta1, _beta2, self._epsilon, lr_ratio_, self._coeff,
+                    with_decay, self._lazy_mode, 1000, find_master, False)
+            else:
+                _, _, _, _, _, _ = _C_ops.adamw(
+                    param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
+                    beta1_pow_acc, beta2_pow_acc, master_weight,
+                    param_and_grad[0], moment1, moment2, beta1_pow_acc,
+                    beta2_pow_acc, master_weight, 'epsilon', self._epsilon,
+                    'lazy_mode', self._lazy_mode,
+                    'min_row_size_to_use_multithread', 1000, 'beta1', _beta1,
+                    'beta2', _beta2, "with_decay", with_decay, 'coeff',
+                    self._coeff, 'multi_precision', find_master, 'lr_ratio',
+                    lr_ratio_)
 
             return None
 
         inputs = {
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index a142225e6578cd95b6837445671c7d2170283c9e..41b5fc26fa9413f67fa898915164f4878e6edbf3 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -58,6 +58,12 @@
     func : AdamaxInferMeta
   kernel :
    func : adamax
+
+- api : adamw
+  args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, float lr_ratio, float coeff, bool with_decay, bool lazy_mode, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow)
+  output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs)
+  optional : master_param, skip_update
+  invoke : adamw_impl(param, grad, learning_rate, moment1, moment2, beta1_pow, beta2_pow, master_param, skip_update, beta1, beta2, epsilon, lr_ratio, coeff, with_decay, lazy_mode, min_row_size_to_use_multithread, multi_precision, use_global_beta_pow)
 
 - api : add
   args : (Tensor x, Tensor y)
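
Note: below is a minimal sketch of how the new eager-mode AdamW path can be exercised once this change is in place. It assumes Paddle is built with this patch (so _C_ops.final_state_adamw exists) and that _test_eager_guard is still exposed by paddle.fluid.framework; the layer shape and hyper-parameters are illustrative only and are not part of the change.

    import paddle
    from paddle.fluid.framework import _test_eager_guard

    with _test_eager_guard():
        linear = paddle.nn.Linear(13, 5)
        # Under the eager guard, AdamW._append_optimize_op takes the
        # framework.in_dygraph_mode() branch and calls final_state_adamw.
        opt = paddle.optimizer.AdamW(
            learning_rate=0.01,
            weight_decay=0.01,
            parameters=linear.parameters())
        out = linear(paddle.rand([4, 13]))
        out.mean().backward()
        opt.step()
        opt.clear_grad()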