未验证 提交 642f6df9 编写于 作者: C Charles-hit 提交者: GitHub

Support auto-generation of the momentum op (#45163)

上级 59241336
......@@ -178,136 +178,6 @@ std::vector<Tensor> split_impl(const Tensor& x,
return out;
}
// Selects and invokes the phi momentum kernel for the given inputs.
//
// The kernel key (backend / layout / dtype) is derived from `param` only.
// The kernel name is "momentum", or "momentum_dense_param_sparse_grad" when
// `grad.is_selected_rows()` is true.
//
// Parameters:
//   param, grad, velocity, learning_rate - tensor inputs, each passed through
//       PrepareData against the matching kernel input slot (0..3).
//   master_param - optional tensor input (kernel input slot 4); when absent,
//       the third kernel output pointer is nullptr.
//   mu, use_nesterov, regularization_method, regularization_coeff,
//   multi_precision, rescale_grad - forwarded unchanged to MomentumInferMeta
//       and to the kernel function.
//
// Returns:
//   `api_output`, a default-constructed tuple of three Tensors.
//   NOTE(review): nothing in this function assigns to `api_output`; the
//   kernel outputs are written through pointers obtained from the prepared
//   inputs instead. Presumably callers rely on that in-place update rather
//   than on the return value - confirm against callers.
std::tuple<Tensor, Tensor, Tensor> momentum_impl(
const Tensor& param,
const Tensor& grad,
const Tensor& velocity,
const Tensor& learning_rate,
const paddle::optional<Tensor>& master_param,
float mu,
bool use_nesterov,
const std::string& regularization_method,
float regularization_coeff,
bool multi_precision,
float rescale_grad) {
// All three key components start out UNDEFINED, so every check below is
// taken on each call and the key always comes from `param`.
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(param);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
// A selected-rows gradient is routed to a differently named kernel.
std::string kernel_name = "momentum";
if (grad.is_selected_rows()) {
kernel_name = "momentum_dense_param_sparse_grad";
}
// NOTE(review): judging by its name, SelectKernelOrThrowError raises when no
// kernel matches the key - confirm against KernelFactory.
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, {kernel_backend, kernel_layout, kernel_data_type});
const auto& kernel = kernel_result.kernel;
VLOG(6) << kernel_name << " API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << kernel_name << " API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
// Each input is prepared against the kernel input slot with the same index.
// NOTE(review): presumably PrepareData adapts the tensor to what that slot
// requires and may return a different object than the caller's tensor -
// confirm, since the outputs below alias these prepared inputs.
auto input_param = PrepareData(param, kernel.InputAt(0), {});
auto input_grad = PrepareData(grad, kernel.InputAt(1), {});
auto input_velocity = PrepareData(velocity, kernel.InputAt(2), {});
auto input_learning_rate = PrepareData(learning_rate, kernel.InputAt(3), {});
auto input_master_param = PrepareData(master_param, kernel.InputAt(4), {});
// Declared here and returned at the end; never assigned in between.
std::tuple<Tensor, Tensor, Tensor> api_output;
// The output pointers are taken from the prepared inputs, so the kernel is
// handed the same objects as both input and output.
auto kernel_out_0 = input_param.get();
auto kernel_out_1 = input_velocity.get();
// The third output exists only when a master parameter was supplied.
phi::DenseTensor* kernel_out_2 = nullptr;
if (input_master_param) {
kernel_out_2 = input_master_param.get_ptr();
}
auto input_meta_ref_master_param = MakeMetaTensor(input_master_param);
phi::MetaTensor meta_out_0(kernel_out_0);
phi::MetaTensor meta_out_1(kernel_out_1);
// The two MomentumInferMeta calls differ only in the last argument: a
// MetaTensor for the master-param output when present, nullptr otherwise.
if (kernel_out_2) {
phi::MetaTensor meta_out_2(kernel_out_2);
phi::MomentumInferMeta(MakeMetaTensor(*input_param),
MakeMetaTensor(*input_grad),
MakeMetaTensor(*input_velocity),
MakeMetaTensor(*input_learning_rate),
input_meta_ref_master_param,
mu,
use_nesterov,
regularization_method,
regularization_coeff,
multi_precision,
rescale_grad,
&meta_out_0,
&meta_out_1,
&meta_out_2);
} else {
phi::MomentumInferMeta(MakeMetaTensor(*input_param),
MakeMetaTensor(*input_grad),
MakeMetaTensor(*input_velocity),
MakeMetaTensor(*input_learning_rate),
input_meta_ref_master_param,
mu,
use_nesterov,
regularization_method,
regularization_coeff,
multi_precision,
rescale_grad,
&meta_out_0,
&meta_out_1,
nullptr);
}
// Function-pointer type used to call the selected kernel directly; the
// argument order below mirrors this signature one-to-one.
// NOTE(review): this must match the signature the kernel was registered
// with - confirm against the kernel registration.
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const paddle::optional<phi::DenseTensor>&,
float,
bool,
const std::string&,
float,
bool,
float,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
*input_param,
*input_grad,
*input_velocity,
*input_learning_rate,
input_master_param,
mu,
use_nesterov,
regularization_method,
regularization_coeff,
multi_precision,
rescale_grad,
kernel_out_0,
kernel_out_1,
kernel_out_2);
return api_output;
}
////////////////// Backward(grad) api impls //////////////////////
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
......
......@@ -56,19 +56,6 @@ std::vector<Tensor> split_impl(const Tensor& x,
const IntArray& num_or_sections,
const Scalar& axis);
// Momentum API implementation entry point.
// Takes four required tensors (param, grad, velocity, learning_rate), an
// optional master_param, and six attribute arguments; returns a tuple of
// three Tensors.
std::tuple<Tensor, Tensor, Tensor> momentum_impl(
const Tensor& param,
const Tensor& grad,
const Tensor& velocity,
const Tensor& learning_rate,
const paddle::optional<Tensor>& master_param,
float mu,
bool use_nesterov,
const std::string& regularization_method,
float regularization_coeff,
bool multi_precision,
float rescale_grad);
////////////////// Backward(grad) api impls //////////////////////
void imag_grad_impl(const Tensor& out_grad, Tensor* x_grad);
......
......@@ -1795,11 +1795,16 @@
func : modulo
backward : modulo_grad
- api : momentum
# In-place momentum API: param, velocity and master_param are each mapped onto
# their corresponding *_out output (see `inplace`); master_param is optional.
# The kernel data type is taken from `param`.
# NOTE(review): both an `invoke` line and `infer_meta`/`kernel` sections appear
# in this entry; presumably only one of the two forms is intended - confirm.
- api : momentum_
args : (Tensor param, Tensor grad, Tensor velocity, Tensor learning_rate, Tensor master_param, float mu, bool use_nesterov = false, str regularization_method = "", float regularization_coeff = 0.0, bool multi_precision = false, float rescale_grad = 1.0f)
output : Tensor(param_out), Tensor(velocity_out), Tensor(master_param_out)
invoke : momentum_impl(param, grad, velocity, learning_rate, master_param, mu, use_nesterov, regularization_method, regularization_coeff, multi_precision, rescale_grad)
infer_meta:
func : MomentumInferMeta
kernel :
func : momentum
data_type : param
optional : master_param
inplace : (param -> param_out), (velocity -> velocity_out), (master_param -> master_param_out)
- api : multi_dot
args : (Tensor[] x)
......
......@@ -327,7 +327,7 @@ class Momentum(Optimizer):
if in_dygraph_mode():
if isinstance(param_and_grad, dict):
self._update_regularization(param_and_grad['weight_decay'])
return _C_ops.final_state_momentum(
return _C_ops.final_state_momentum_(
param_and_grad[0], param_and_grad[1], velocity_acc, lr,
master_weight, self._momentum, self._use_nesterov,
regularization_method, regularization_coeff, find_master,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册