diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index f23db10ac774e6cf53bd5a795a565239492a58b8..e98d749a6cb6c9641b60ac4a341497b83c6fc30c 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -34,170 +34,6 @@ namespace experimental { ////////////////// Forward api impls ////////////////////// -std::tuple adamw_impl( - const Tensor& param, - const Tensor& grad, - const Tensor& learning_rate, - const Tensor& moment1, - const Tensor& moment2, - const Tensor& beta1_pow, - const Tensor& beta2_pow, - const paddle::optional& master_param, - const paddle::optional& skip_update, - const Scalar& beta1, - const Scalar& beta2, - const Scalar& epsilon, - float lr_ratio, - float coeff, - bool with_decay, - bool lazy_mode, - int64_t min_row_size_to_use_multithread, - bool multi_precision, - bool use_global_beta_pow) { - Backend kernel_backend = Backend::UNDEFINED; - DataLayout kernel_layout = DataLayout::UNDEFINED; - DataType kernel_data_type = DataType::UNDEFINED; - if (kernel_backend == Backend::UNDEFINED || - kernel_layout == DataLayout::UNDEFINED || - kernel_data_type == DataType::UNDEFINED) { - auto kernel_key_set = ParseKernelKeyByInputArgs(param); - auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - if (kernel_backend == Backend::UNDEFINED) { - kernel_backend = kernel_key.backend(); - } - if (kernel_layout == DataLayout::UNDEFINED) { - kernel_layout = kernel_key.layout(); - } - if (kernel_data_type == DataType::UNDEFINED) { - kernel_data_type = kernel_key.dtype(); - } - } - std::string kernel_name = "adamw"; - auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( - kernel_name, {kernel_backend, kernel_layout, kernel_data_type}); - const auto& kernel = kernel_result.kernel; - VLOG(6) << kernel_name << " API kernel key: [" << kernel_backend << ", " - << kernel_layout << ", " << kernel_data_type << "]"; - VLOG(6) << kernel_name << " API kernel: " << kernel; - - auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); - - auto input_param = PrepareData(param, kernel.InputAt(0), {}); - auto input_grad = PrepareData(grad, kernel.InputAt(1), {}); - auto input_lr = PrepareData(learning_rate, kernel.InputAt(2), {}); - auto input_moment1 = PrepareData(moment1, kernel.InputAt(3), {}); - auto input_moment2 = PrepareData(moment2, kernel.InputAt(4), {}); - auto input_beta1_pow = PrepareData(beta1_pow, kernel.InputAt(5), {}); - auto input_beta2_pow = PrepareData(beta2_pow, kernel.InputAt(6), {}); - auto input_master_param = PrepareData(master_param, kernel.InputAt(7), {}); - auto input_skip_update = PrepareData(skip_update, kernel.InputAt(8), {}); - - std::tuple api_output; - auto kernel_out_0 = input_param.get(); - auto kernel_out_1 = input_moment1.get(); - auto kernel_out_2 = input_moment2.get(); - auto kernel_out_3 = input_beta1_pow.get(); - auto kernel_out_4 = input_beta2_pow.get(); - phi::DenseTensor* kernel_out_5 = nullptr; - if (input_master_param) { - kernel_out_5 = input_master_param.get_ptr(); - } - - auto input_meta_ref_master_param = MakeMetaTensor(input_master_param); - - auto input_meta_ref_skip_update = MakeMetaTensor(input_skip_update); - - phi::MetaTensor meta_out_0(kernel_out_0); - phi::MetaTensor meta_out_1(kernel_out_1); - phi::MetaTensor meta_out_2(kernel_out_2); - phi::MetaTensor meta_out_3(kernel_out_3); - phi::MetaTensor meta_out_4(kernel_out_4); - phi::MetaTensor meta_out_5(kernel_out_5); - - phi::AdamwInferMeta(MakeMetaTensor(*input_param), - MakeMetaTensor(*input_grad), - MakeMetaTensor(*input_lr), - MakeMetaTensor(*input_moment1), - MakeMetaTensor(*input_moment2), - MakeMetaTensor(*input_beta1_pow), - MakeMetaTensor(*input_beta2_pow), - input_meta_ref_master_param, - input_meta_ref_skip_update, - beta1, - beta2, - epsilon, - lr_ratio, - coeff, - with_decay, - lazy_mode, - min_row_size_to_use_multithread, - multi_precision, - use_global_beta_pow, - &meta_out_0, - &meta_out_1, - &meta_out_2, - &meta_out_3, - &meta_out_4, - &meta_out_5); - - using kernel_signature = void (*)(const platform::DeviceContext&, - const phi::DenseTensor&, - const phi::DenseTensor&, - const phi::DenseTensor&, - const phi::DenseTensor&, - const phi::DenseTensor&, - const phi::DenseTensor&, - const phi::DenseTensor&, - const paddle::optional&, - const paddle::optional&, - const Scalar&, - const Scalar&, - const Scalar&, - float, - float, - bool, - bool, - int64_t, - bool, - bool, - phi::DenseTensor*, - phi::DenseTensor*, - phi::DenseTensor*, - phi::DenseTensor*, - phi::DenseTensor*, - phi::DenseTensor*); - auto* kernel_fn = kernel.GetVariadicKernelFn(); - - (*kernel_fn)(*dev_ctx, - *input_param, - *input_grad, - *input_lr, - *input_moment1, - *input_moment2, - *input_beta1_pow, - *input_beta2_pow, - input_master_param, - input_skip_update, - beta1, - beta2, - epsilon, - lr_ratio, - coeff, - with_decay, - lazy_mode, - min_row_size_to_use_multithread, - multi_precision, - use_global_beta_pow, - kernel_out_0, - kernel_out_1, - kernel_out_2, - kernel_out_3, - kernel_out_4, - kernel_out_5); - - return api_output; -} - Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) { Tensor out; copy(x, place, blocking, &out); diff --git a/paddle/phi/api/lib/api_custom_impl.h b/paddle/phi/api/lib/api_custom_impl.h index b4b09e0e29eed1ae6716042bd5a363490ddbfd10..907a053d4f10b7e5f899c964cf533c8083d29700 100644 --- a/paddle/phi/api/lib/api_custom_impl.h +++ b/paddle/phi/api/lib/api_custom_impl.h @@ -31,27 +31,6 @@ namespace experimental { ////////////////// Forward api impls ////////////////////// -std::tuple adamw_impl( - const Tensor& param, - const Tensor& grad, - const Tensor& learning_rate, - const Tensor& moment1, - const Tensor& moment2, - const Tensor& beta1_pow, - const Tensor& beta2_pow, - const paddle::optional& master_param, - const paddle::optional& skip_update, - const Scalar& beta1, - const Scalar& beta2, - const Scalar& epsilon, - float lr_ratio, - float coeff, - bool with_decay, - bool lazy_mode, - int64_t min_row_size_to_use_multithread, - bool multi_precision, - bool use_global_beta_pow); - std::tuple batch_norm_impl( const Tensor& x, const Tensor& scale, diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index b50b3096895aa23224c94bf38f11daaabf817531..1812e400b9d171d8fdcdc3cbc8218b5ff2843814 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -79,11 +79,16 @@ kernel : func : adamax -- api : adamw +- api : adamw_ args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, float lr_ratio, float coeff, bool with_decay, bool lazy_mode, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow) output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs) + infer_meta : + func : AdamwInferMeta + kernel : + func : adamw + data_type : param optional : master_param, skip_update - invoke : adamw_impl(param, grad, learning_rate, moment1, moment2, beta1_pow, beta2_pow, master_param, skip_update, beta1, beta2, epsilon, lr_ratio, coeff, with_decay, lazy_mode, min_row_size_to_use_multithread, multi_precision, use_global_beta_pow) + inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_outs) - api : add args : (Tensor x, Tensor y) diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py index 25f4006327d7549986ab5606dc4436d94af9d487..0c9e2645ef3e2d2b7130b2f13c04e7bffcfb1fcd 100644 --- a/python/paddle/optimizer/adamw.py +++ b/python/paddle/optimizer/adamw.py @@ -443,7 +443,7 @@ class AdamW(Optimizer): if framework.in_dygraph_mode(): found_inf = self._get_auxiliary_var('found_inf') - _, _, _, _, _, _ = _C_ops.final_state_adamw( + _, _, _, _, _, _ = _C_ops.final_state_adamw_( param_and_grad[0], param_and_grad[1], lr, moment1, moment2, beta1_pow_acc, beta2_pow_acc, master_weight, found_inf, _beta1, _beta2, self._epsilon, lr_ratio_,