Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1353761a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1353761a
编写于
8月 15, 2022
作者:
C
Charles-hit
提交者:
GitHub
8月 15, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support adamw generation (#45149)
上级
8636d2a2
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
8 addition
and
188 deletion
+8
-188
paddle/phi/api/lib/api_custom_impl.cc
paddle/phi/api/lib/api_custom_impl.cc
+0
-164
paddle/phi/api/lib/api_custom_impl.h
paddle/phi/api/lib/api_custom_impl.h
+0
-21
paddle/phi/api/yaml/legacy_api.yaml
paddle/phi/api/yaml/legacy_api.yaml
+7
-2
python/paddle/optimizer/adamw.py
python/paddle/optimizer/adamw.py
+1
-1
未找到文件。
paddle/phi/api/lib/api_custom_impl.cc
浏览文件 @
1353761a
...
...
@@ -34,170 +34,6 @@ namespace experimental {
////////////////// Forward api impls //////////////////////
std
::
tuple
<
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
>
adamw_impl
(
const
Tensor
&
param
,
const
Tensor
&
grad
,
const
Tensor
&
learning_rate
,
const
Tensor
&
moment1
,
const
Tensor
&
moment2
,
const
Tensor
&
beta1_pow
,
const
Tensor
&
beta2_pow
,
const
paddle
::
optional
<
Tensor
>&
master_param
,
const
paddle
::
optional
<
Tensor
>&
skip_update
,
const
Scalar
&
beta1
,
const
Scalar
&
beta2
,
const
Scalar
&
epsilon
,
float
lr_ratio
,
float
coeff
,
bool
with_decay
,
bool
lazy_mode
,
int64_t
min_row_size_to_use_multithread
,
bool
multi_precision
,
bool
use_global_beta_pow
)
{
Backend
kernel_backend
=
Backend
::
UNDEFINED
;
DataLayout
kernel_layout
=
DataLayout
::
UNDEFINED
;
DataType
kernel_data_type
=
DataType
::
UNDEFINED
;
if
(
kernel_backend
==
Backend
::
UNDEFINED
||
kernel_layout
==
DataLayout
::
UNDEFINED
||
kernel_data_type
==
DataType
::
UNDEFINED
)
{
auto
kernel_key_set
=
ParseKernelKeyByInputArgs
(
param
);
auto
kernel_key
=
kernel_key_set
.
GetHighestPriorityKernelKey
();
if
(
kernel_backend
==
Backend
::
UNDEFINED
)
{
kernel_backend
=
kernel_key
.
backend
();
}
if
(
kernel_layout
==
DataLayout
::
UNDEFINED
)
{
kernel_layout
=
kernel_key
.
layout
();
}
if
(
kernel_data_type
==
DataType
::
UNDEFINED
)
{
kernel_data_type
=
kernel_key
.
dtype
();
}
}
std
::
string
kernel_name
=
"adamw"
;
auto
kernel_result
=
phi
::
KernelFactory
::
Instance
().
SelectKernelOrThrowError
(
kernel_name
,
{
kernel_backend
,
kernel_layout
,
kernel_data_type
});
const
auto
&
kernel
=
kernel_result
.
kernel
;
VLOG
(
6
)
<<
kernel_name
<<
" API kernel key: ["
<<
kernel_backend
<<
", "
<<
kernel_layout
<<
", "
<<
kernel_data_type
<<
"]"
;
VLOG
(
6
)
<<
kernel_name
<<
" API kernel: "
<<
kernel
;
auto
*
dev_ctx
=
GetDeviceContextByBackend
(
kernel_backend
);
auto
input_param
=
PrepareData
(
param
,
kernel
.
InputAt
(
0
),
{});
auto
input_grad
=
PrepareData
(
grad
,
kernel
.
InputAt
(
1
),
{});
auto
input_lr
=
PrepareData
(
learning_rate
,
kernel
.
InputAt
(
2
),
{});
auto
input_moment1
=
PrepareData
(
moment1
,
kernel
.
InputAt
(
3
),
{});
auto
input_moment2
=
PrepareData
(
moment2
,
kernel
.
InputAt
(
4
),
{});
auto
input_beta1_pow
=
PrepareData
(
beta1_pow
,
kernel
.
InputAt
(
5
),
{});
auto
input_beta2_pow
=
PrepareData
(
beta2_pow
,
kernel
.
InputAt
(
6
),
{});
auto
input_master_param
=
PrepareData
(
master_param
,
kernel
.
InputAt
(
7
),
{});
auto
input_skip_update
=
PrepareData
(
skip_update
,
kernel
.
InputAt
(
8
),
{});
std
::
tuple
<
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
>
api_output
;
auto
kernel_out_0
=
input_param
.
get
();
auto
kernel_out_1
=
input_moment1
.
get
();
auto
kernel_out_2
=
input_moment2
.
get
();
auto
kernel_out_3
=
input_beta1_pow
.
get
();
auto
kernel_out_4
=
input_beta2_pow
.
get
();
phi
::
DenseTensor
*
kernel_out_5
=
nullptr
;
if
(
input_master_param
)
{
kernel_out_5
=
input_master_param
.
get_ptr
();
}
auto
input_meta_ref_master_param
=
MakeMetaTensor
(
input_master_param
);
auto
input_meta_ref_skip_update
=
MakeMetaTensor
(
input_skip_update
);
phi
::
MetaTensor
meta_out_0
(
kernel_out_0
);
phi
::
MetaTensor
meta_out_1
(
kernel_out_1
);
phi
::
MetaTensor
meta_out_2
(
kernel_out_2
);
phi
::
MetaTensor
meta_out_3
(
kernel_out_3
);
phi
::
MetaTensor
meta_out_4
(
kernel_out_4
);
phi
::
MetaTensor
meta_out_5
(
kernel_out_5
);
phi
::
AdamwInferMeta
(
MakeMetaTensor
(
*
input_param
),
MakeMetaTensor
(
*
input_grad
),
MakeMetaTensor
(
*
input_lr
),
MakeMetaTensor
(
*
input_moment1
),
MakeMetaTensor
(
*
input_moment2
),
MakeMetaTensor
(
*
input_beta1_pow
),
MakeMetaTensor
(
*
input_beta2_pow
),
input_meta_ref_master_param
,
input_meta_ref_skip_update
,
beta1
,
beta2
,
epsilon
,
lr_ratio
,
coeff
,
with_decay
,
lazy_mode
,
min_row_size_to_use_multithread
,
multi_precision
,
use_global_beta_pow
,
&
meta_out_0
,
&
meta_out_1
,
&
meta_out_2
,
&
meta_out_3
,
&
meta_out_4
,
&
meta_out_5
);
using
kernel_signature
=
void
(
*
)(
const
platform
::
DeviceContext
&
,
const
phi
::
DenseTensor
&
,
const
phi
::
DenseTensor
&
,
const
phi
::
DenseTensor
&
,
const
phi
::
DenseTensor
&
,
const
phi
::
DenseTensor
&
,
const
phi
::
DenseTensor
&
,
const
phi
::
DenseTensor
&
,
const
paddle
::
optional
<
phi
::
DenseTensor
>&
,
const
paddle
::
optional
<
phi
::
DenseTensor
>&
,
const
Scalar
&
,
const
Scalar
&
,
const
Scalar
&
,
float
,
float
,
bool
,
bool
,
int64_t
,
bool
,
bool
,
phi
::
DenseTensor
*
,
phi
::
DenseTensor
*
,
phi
::
DenseTensor
*
,
phi
::
DenseTensor
*
,
phi
::
DenseTensor
*
,
phi
::
DenseTensor
*
);
auto
*
kernel_fn
=
kernel
.
GetVariadicKernelFn
<
kernel_signature
>
();
(
*
kernel_fn
)(
*
dev_ctx
,
*
input_param
,
*
input_grad
,
*
input_lr
,
*
input_moment1
,
*
input_moment2
,
*
input_beta1_pow
,
*
input_beta2_pow
,
input_master_param
,
input_skip_update
,
beta1
,
beta2
,
epsilon
,
lr_ratio
,
coeff
,
with_decay
,
lazy_mode
,
min_row_size_to_use_multithread
,
multi_precision
,
use_global_beta_pow
,
kernel_out_0
,
kernel_out_1
,
kernel_out_2
,
kernel_out_3
,
kernel_out_4
,
kernel_out_5
);
return
api_output
;
}
Tensor
copy_to_impl
(
const
Tensor
&
x
,
Place
place
,
bool
blocking
)
{
Tensor
out
;
copy
(
x
,
place
,
blocking
,
&
out
);
...
...
paddle/phi/api/lib/api_custom_impl.h
浏览文件 @
1353761a
...
...
@@ -31,27 +31,6 @@ namespace experimental {
////////////////// Forward api impls //////////////////////
std
::
tuple
<
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
>
adamw_impl
(
const
Tensor
&
param
,
const
Tensor
&
grad
,
const
Tensor
&
learning_rate
,
const
Tensor
&
moment1
,
const
Tensor
&
moment2
,
const
Tensor
&
beta1_pow
,
const
Tensor
&
beta2_pow
,
const
paddle
::
optional
<
Tensor
>&
master_param
,
const
paddle
::
optional
<
Tensor
>&
skip_update
,
const
Scalar
&
beta1
,
const
Scalar
&
beta2
,
const
Scalar
&
epsilon
,
float
lr_ratio
,
float
coeff
,
bool
with_decay
,
bool
lazy_mode
,
int64_t
min_row_size_to_use_multithread
,
bool
multi_precision
,
bool
use_global_beta_pow
);
std
::
tuple
<
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
>
batch_norm_impl
(
const
Tensor
&
x
,
const
Tensor
&
scale
,
...
...
paddle/phi/api/yaml/legacy_api.yaml
浏览文件 @
1353761a
...
...
@@ -79,11 +79,16 @@
kernel
:
func
:
adamax
-
api
:
adamw
-
api
:
adamw
_
args
:
(Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, float lr_ratio, float coeff, bool with_decay, bool lazy_mode, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow)
output
:
Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs)
infer_meta
:
func
:
AdamwInferMeta
kernel
:
func
:
adamw
data_type
:
param
optional
:
master_param, skip_update
in
voke
:
adamw_impl(param, grad, learning_rate, moment1, moment2, beta1_pow, beta2_pow, master_param, skip_update, beta1, beta2, epsilon, lr_ratio, coeff, with_decay, lazy_mode, min_row_size_to_use_multithread, multi_precision, use_global_beta_pow
)
in
place
:
(param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_outs
)
-
api
:
add
args
:
(Tensor x, Tensor y)
...
...
python/paddle/optimizer/adamw.py
浏览文件 @
1353761a
...
...
@@ -443,7 +443,7 @@ class AdamW(Optimizer):
if
framework
.
in_dygraph_mode
():
found_inf
=
self
.
_get_auxiliary_var
(
'found_inf'
)
_
,
_
,
_
,
_
,
_
,
_
=
_C_ops
.
final_state_adamw
(
_
,
_
,
_
,
_
,
_
,
_
=
_C_ops
.
final_state_adamw
_
(
param_and_grad
[
0
],
param_and_grad
[
1
],
lr
,
moment1
,
moment2
,
beta1_pow_acc
,
beta2_pow_acc
,
master_weight
,
found_inf
,
_beta1
,
_beta2
,
self
.
_epsilon
,
lr_ratio_
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录