未验证 提交 1353761a 编写于 作者: C Charles-hit 提交者: GitHub

support adamw generation (#45149)

上级 8636d2a2
...@@ -34,170 +34,6 @@ namespace experimental { ...@@ -34,170 +34,6 @@ namespace experimental {
////////////////// Forward api impls ////////////////////// ////////////////// Forward api impls //////////////////////
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adamw_impl(
const Tensor& param,
const Tensor& grad,
const Tensor& learning_rate,
const Tensor& moment1,
const Tensor& moment2,
const Tensor& beta1_pow,
const Tensor& beta2_pow,
const paddle::optional<Tensor>& master_param,
const paddle::optional<Tensor>& skip_update,
const Scalar& beta1,
const Scalar& beta2,
const Scalar& epsilon,
float lr_ratio,
float coeff,
bool with_decay,
bool lazy_mode,
int64_t min_row_size_to_use_multithread,
bool multi_precision,
bool use_global_beta_pow) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(param);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
std::string kernel_name = "adamw";
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, {kernel_backend, kernel_layout, kernel_data_type});
const auto& kernel = kernel_result.kernel;
VLOG(6) << kernel_name << " API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << kernel_name << " API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto input_param = PrepareData(param, kernel.InputAt(0), {});
auto input_grad = PrepareData(grad, kernel.InputAt(1), {});
auto input_lr = PrepareData(learning_rate, kernel.InputAt(2), {});
auto input_moment1 = PrepareData(moment1, kernel.InputAt(3), {});
auto input_moment2 = PrepareData(moment2, kernel.InputAt(4), {});
auto input_beta1_pow = PrepareData(beta1_pow, kernel.InputAt(5), {});
auto input_beta2_pow = PrepareData(beta2_pow, kernel.InputAt(6), {});
auto input_master_param = PrepareData(master_param, kernel.InputAt(7), {});
auto input_skip_update = PrepareData(skip_update, kernel.InputAt(8), {});
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> api_output;
auto kernel_out_0 = input_param.get();
auto kernel_out_1 = input_moment1.get();
auto kernel_out_2 = input_moment2.get();
auto kernel_out_3 = input_beta1_pow.get();
auto kernel_out_4 = input_beta2_pow.get();
phi::DenseTensor* kernel_out_5 = nullptr;
if (input_master_param) {
kernel_out_5 = input_master_param.get_ptr();
}
auto input_meta_ref_master_param = MakeMetaTensor(input_master_param);
auto input_meta_ref_skip_update = MakeMetaTensor(input_skip_update);
phi::MetaTensor meta_out_0(kernel_out_0);
phi::MetaTensor meta_out_1(kernel_out_1);
phi::MetaTensor meta_out_2(kernel_out_2);
phi::MetaTensor meta_out_3(kernel_out_3);
phi::MetaTensor meta_out_4(kernel_out_4);
phi::MetaTensor meta_out_5(kernel_out_5);
phi::AdamwInferMeta(MakeMetaTensor(*input_param),
MakeMetaTensor(*input_grad),
MakeMetaTensor(*input_lr),
MakeMetaTensor(*input_moment1),
MakeMetaTensor(*input_moment2),
MakeMetaTensor(*input_beta1_pow),
MakeMetaTensor(*input_beta2_pow),
input_meta_ref_master_param,
input_meta_ref_skip_update,
beta1,
beta2,
epsilon,
lr_ratio,
coeff,
with_decay,
lazy_mode,
min_row_size_to_use_multithread,
multi_precision,
use_global_beta_pow,
&meta_out_0,
&meta_out_1,
&meta_out_2,
&meta_out_3,
&meta_out_4,
&meta_out_5);
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const paddle::optional<phi::DenseTensor>&,
const paddle::optional<phi::DenseTensor>&,
const Scalar&,
const Scalar&,
const Scalar&,
float,
float,
bool,
bool,
int64_t,
bool,
bool,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
*input_param,
*input_grad,
*input_lr,
*input_moment1,
*input_moment2,
*input_beta1_pow,
*input_beta2_pow,
input_master_param,
input_skip_update,
beta1,
beta2,
epsilon,
lr_ratio,
coeff,
with_decay,
lazy_mode,
min_row_size_to_use_multithread,
multi_precision,
use_global_beta_pow,
kernel_out_0,
kernel_out_1,
kernel_out_2,
kernel_out_3,
kernel_out_4,
kernel_out_5);
return api_output;
}
Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) { Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) {
Tensor out; Tensor out;
copy(x, place, blocking, &out); copy(x, place, blocking, &out);
......
...@@ -31,27 +31,6 @@ namespace experimental { ...@@ -31,27 +31,6 @@ namespace experimental {
////////////////// Forward api impls ////////////////////// ////////////////// Forward api impls //////////////////////
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adamw_impl(
const Tensor& param,
const Tensor& grad,
const Tensor& learning_rate,
const Tensor& moment1,
const Tensor& moment2,
const Tensor& beta1_pow,
const Tensor& beta2_pow,
const paddle::optional<Tensor>& master_param,
const paddle::optional<Tensor>& skip_update,
const Scalar& beta1,
const Scalar& beta2,
const Scalar& epsilon,
float lr_ratio,
float coeff,
bool with_decay,
bool lazy_mode,
int64_t min_row_size_to_use_multithread,
bool multi_precision,
bool use_global_beta_pow);
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl( std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
const Tensor& x, const Tensor& x,
const Tensor& scale, const Tensor& scale,
......
...@@ -79,11 +79,16 @@ ...@@ -79,11 +79,16 @@
kernel : kernel :
func : adamax func : adamax
- api : adamw - api : adamw_
args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, float lr_ratio, float coeff, bool with_decay, bool lazy_mode, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow) args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, float lr_ratio, float coeff, bool with_decay, bool lazy_mode, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow)
output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs) output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs)
infer_meta :
func : AdamwInferMeta
kernel :
func : adamw
data_type : param
optional : master_param, skip_update optional : master_param, skip_update
invoke : adamw_impl(param, grad, learning_rate, moment1, moment2, beta1_pow, beta2_pow, master_param, skip_update, beta1, beta2, epsilon, lr_ratio, coeff, with_decay, lazy_mode, min_row_size_to_use_multithread, multi_precision, use_global_beta_pow) inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_outs)
- api : add - api : add
args : (Tensor x, Tensor y) args : (Tensor x, Tensor y)
......
...@@ -443,7 +443,7 @@ class AdamW(Optimizer): ...@@ -443,7 +443,7 @@ class AdamW(Optimizer):
if framework.in_dygraph_mode(): if framework.in_dygraph_mode():
found_inf = self._get_auxiliary_var('found_inf') found_inf = self._get_auxiliary_var('found_inf')
_, _, _, _, _, _ = _C_ops.final_state_adamw( _, _, _, _, _, _ = _C_ops.final_state_adamw_(
param_and_grad[0], param_and_grad[1], lr, moment1, moment2, param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
beta1_pow_acc, beta2_pow_acc, master_weight, found_inf, beta1_pow_acc, beta2_pow_acc, master_weight, found_inf,
_beta1, _beta2, self._epsilon, lr_ratio_, _beta1, _beta2, self._epsilon, lr_ratio_,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册