diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index e98d749a6cb6c9641b60ac4a341497b83c6fc30c..b37fac194137ce97d4b8c0958683aba03b5641e8 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -178,136 +178,6 @@ std::vector split_impl(const Tensor& x, return out; } -std::tuple momentum_impl( - const Tensor& param, - const Tensor& grad, - const Tensor& velocity, - const Tensor& learning_rate, - const paddle::optional& master_param, - float mu, - bool use_nesterov, - const std::string& regularization_method, - float regularization_coeff, - bool multi_precision, - float rescale_grad) { - Backend kernel_backend = Backend::UNDEFINED; - DataLayout kernel_layout = DataLayout::UNDEFINED; - DataType kernel_data_type = DataType::UNDEFINED; - if (kernel_backend == Backend::UNDEFINED || - kernel_layout == DataLayout::UNDEFINED || - kernel_data_type == DataType::UNDEFINED) { - auto kernel_key_set = ParseKernelKeyByInputArgs(param); - auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - if (kernel_backend == Backend::UNDEFINED) { - kernel_backend = kernel_key.backend(); - } - if (kernel_layout == DataLayout::UNDEFINED) { - kernel_layout = kernel_key.layout(); - } - if (kernel_data_type == DataType::UNDEFINED) { - kernel_data_type = kernel_key.dtype(); - } - } - std::string kernel_name = "momentum"; - if (grad.is_selected_rows()) { - kernel_name = "momentum_dense_param_sparse_grad"; - } - auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( - kernel_name, {kernel_backend, kernel_layout, kernel_data_type}); - const auto& kernel = kernel_result.kernel; - VLOG(6) << kernel_name << " API kernel key: [" << kernel_backend << ", " - << kernel_layout << ", " << kernel_data_type << "]"; - VLOG(6) << kernel_name << " API kernel: " << kernel; - - auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); - - auto input_param = PrepareData(param, kernel.InputAt(0), {}); - auto input_grad = PrepareData(grad, kernel.InputAt(1), {}); - auto input_velocity = PrepareData(velocity, kernel.InputAt(2), {}); - auto input_learning_rate = PrepareData(learning_rate, kernel.InputAt(3), {}); - auto input_master_param = PrepareData(master_param, kernel.InputAt(4), {}); - - std::tuple api_output; - auto kernel_out_0 = input_param.get(); - auto kernel_out_1 = input_velocity.get(); - phi::DenseTensor* kernel_out_2 = nullptr; - if (input_master_param) { - kernel_out_2 = input_master_param.get_ptr(); - } - - auto input_meta_ref_master_param = MakeMetaTensor(input_master_param); - - phi::MetaTensor meta_out_0(kernel_out_0); - phi::MetaTensor meta_out_1(kernel_out_1); - if (kernel_out_2) { - phi::MetaTensor meta_out_2(kernel_out_2); - phi::MomentumInferMeta(MakeMetaTensor(*input_param), - MakeMetaTensor(*input_grad), - MakeMetaTensor(*input_velocity), - MakeMetaTensor(*input_learning_rate), - input_meta_ref_master_param, - mu, - use_nesterov, - regularization_method, - regularization_coeff, - multi_precision, - rescale_grad, - &meta_out_0, - &meta_out_1, - &meta_out_2); - } else { - phi::MomentumInferMeta(MakeMetaTensor(*input_param), - MakeMetaTensor(*input_grad), - MakeMetaTensor(*input_velocity), - MakeMetaTensor(*input_learning_rate), - input_meta_ref_master_param, - mu, - use_nesterov, - regularization_method, - regularization_coeff, - multi_precision, - rescale_grad, - &meta_out_0, - &meta_out_1, - nullptr); - } - - using kernel_signature = void (*)(const platform::DeviceContext&, - const phi::DenseTensor&, - const phi::DenseTensor&, - const phi::DenseTensor&, - const phi::DenseTensor&, - const paddle::optional&, - float, - bool, - const std::string&, - float, - bool, - float, - phi::DenseTensor*, - phi::DenseTensor*, - phi::DenseTensor*); - auto* kernel_fn = kernel.GetVariadicKernelFn(); - - (*kernel_fn)(*dev_ctx, - *input_param, - *input_grad, - *input_velocity, - *input_learning_rate, - input_master_param, - mu, - use_nesterov, - regularization_method, - regularization_coeff, - multi_precision, - rescale_grad, - kernel_out_0, - kernel_out_1, - kernel_out_2); - - return api_output; -} - ////////////////// Backward(grad) api impls ////////////////////// std::tuple batch_norm_impl( diff --git a/paddle/phi/api/lib/api_custom_impl.h b/paddle/phi/api/lib/api_custom_impl.h index 907a053d4f10b7e5f899c964cf533c8083d29700..e7fca7bfbc84d1b8740444d390dc3775b103d582 100644 --- a/paddle/phi/api/lib/api_custom_impl.h +++ b/paddle/phi/api/lib/api_custom_impl.h @@ -56,19 +56,6 @@ std::vector split_impl(const Tensor& x, const IntArray& num_or_sections, const Scalar& axis); -std::tuple momentum_impl( - const Tensor& param, - const Tensor& grad, - const Tensor& velocity, - const Tensor& learning_rate, - const paddle::optional& master_param, - float mu, - bool use_nesterov, - const std::string& regularization_method, - float regularization_coeff, - bool multi_precision, - float rescale_grad); - ////////////////// Backward(grad) api impls ////////////////////// void imag_grad_impl(const Tensor& out_grad, Tensor* x_grad); diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 1812e400b9d171d8fdcdc3cbc8218b5ff2843814..a98009903f2109222b390a9559c35d4068493780 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1795,11 +1795,16 @@ func : modulo backward : modulo_grad -- api : momentum +- api : momentum_ args : (Tensor param, Tensor grad, Tensor velocity, Tensor learning_rate, Tensor master_param, float mu, bool use_nesterov = false, str regularization_method = "", float regularization_coeff = 0.0, bool multi_precision = false, float rescale_grad = 1.0f) output : Tensor(param_out), Tensor(velocity_out), Tensor(master_param_out) - invoke : momentum_impl(param, grad, velocity, learning_rate, master_param, mu, use_nesterov, regularization_method, regularization_coeff, multi_precision, rescale_grad) + infer_meta: + func : MomentumInferMeta + kernel : + func : momentum + data_type : param optional : master_param + inplace : (param -> param_out), (velocity -> velocity_out), (master_param -> master_param_out) - api : multi_dot args : (Tensor[] x) diff --git a/python/paddle/optimizer/momentum.py b/python/paddle/optimizer/momentum.py index bb7765ac715dd8e0f34eed1fe4087fa756933f12..2839b80b4e5a588a7dcf3c6212cee1d8ed81bfc2 100644 --- a/python/paddle/optimizer/momentum.py +++ b/python/paddle/optimizer/momentum.py @@ -327,7 +327,7 @@ class Momentum(Optimizer): if in_dygraph_mode(): if isinstance(param_and_grad, dict): self._update_regularization(param_and_grad['weight_decay']) - return _C_ops.final_state_momentum( + return _C_ops.final_state_momentum_( param_and_grad[0], param_and_grad[1], velocity_acc, lr, master_weight, self._momentum, self._use_nesterov, regularization_method, regularization_coeff, find_master,