Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
dda74715
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
dda74715
编写于
2月 28, 2023
作者:
T
taixiurong
提交者:
GitHub
2月 28, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
xpu-paddlepaddle-57 [任务] adamw lr_radio支持 (#50979)
上级
a8fff38f
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
380 addition
and
41 deletion
+380
-41
paddle/phi/kernels/xpu/adamw_kernel.cc
paddle/phi/kernels/xpu/adamw_kernel.cc
+35
-39
python/paddle/fluid/tests/unittests/xpu/test_adamw_op_xpu.py
python/paddle/fluid/tests/unittests/xpu/test_adamw_op_xpu.py
+340
-0
python/paddle/optimizer/adamw.py
python/paddle/optimizer/adamw.py
+5
-2
未找到文件。
paddle/phi/kernels/xpu/adamw_kernel.cc
浏览文件 @
dda74715
...
...
@@ -87,47 +87,43 @@ void AdamwDenseKernel(const Context& dev_ctx,
beta1_pow_ptr
=
xpu_beta1_pow
.
template
data
<
float
>();
beta2_pow_ptr
=
xpu_beta2_pow
.
template
data
<
float
>();
}
if
(
with_decay
)
{
int
r
=
xpu
::
adamw
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
grad
.
template
data
<
T
>()),
moment1
.
template
data
<
float
>(),
moment2
.
template
data
<
float
>(),
reinterpret_cast
<
const
XPUType
*>
(
param
.
template
data
<
T
>()),
beta1_pow_ptr
,
beta2_pow_ptr
,
learning_rate
.
template
data
<
float
>(),
dev_ctx
.
template
Alloc
<
float
>(
moment1_out
),
dev_ctx
.
template
Alloc
<
float
>(
moment2_out
),
reinterpret_cast
<
XPUType
*>
(
dev_ctx
.
template
Alloc
<
T
>(
param_out
)),
beta1_
,
beta2_
,
epsilon_
,
coeff
,
param
.
numel
());
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"adamw"
);
}
else
{
int
r
=
xpu
::
adam
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
grad
.
template
data
<
T
>()),
moment1
.
template
data
<
float
>(),
moment2
.
template
data
<
float
>(),
reinterpret_cast
<
const
XPUType
*>
(
param
.
template
data
<
T
>()),
beta1_pow_ptr
,
beta2_pow_ptr
,
learning_rate
.
template
data
<
float
>(),
dev_ctx
.
template
Alloc
<
float
>(
moment1_out
),
dev_ctx
.
template
Alloc
<
float
>(
moment2_out
),
reinterpret_cast
<
XPUType
*>
(
dev_ctx
.
template
Alloc
<
T
>(
param_out
)),
beta1_
,
beta2_
,
epsilon_
,
param
.
numel
());
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"adam"
);
if
(
!
with_decay
)
{
coeff
=
static_cast
<
float
>
(
0.0
);
}
xpu
::
ctx_guard
RAII_GUARD
(
dev_ctx
.
x_context
());
float
*
new_lr
=
RAII_GUARD
.
alloc_l3_or_gm
<
float
>
(
learning_rate
.
numel
());
PADDLE_ENFORCE_XDNN_NOT_NULL
(
new_lr
);
int
r
=
0
;
r
=
xpu
::
scale
(
dev_ctx
.
x_context
(),
learning_rate
.
template
data
<
float
>(),
new_lr
,
learning_rate
.
numel
(),
false
,
lr_ratio
,
0.0
f
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"scale"
);
r
=
xpu
::
adamw
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
grad
.
template
data
<
T
>()),
moment1
.
template
data
<
float
>(),
moment2
.
template
data
<
float
>(),
reinterpret_cast
<
const
XPUType
*>
(
param
.
template
data
<
T
>()),
beta1_pow_ptr
,
beta2_pow_ptr
,
new_lr
,
dev_ctx
.
template
Alloc
<
float
>(
moment1_out
),
dev_ctx
.
template
Alloc
<
float
>(
moment2_out
),
reinterpret_cast
<
XPUType
*>
(
dev_ctx
.
template
Alloc
<
T
>(
param_out
)),
beta1_
,
beta2_
,
epsilon_
,
coeff
,
param
.
numel
());
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"adamw"
);
if
(
!
use_global_beta_pow
)
{
// update in cpu
and then copy to xpu
// update in cpu
if
(
beta1_pow
.
place
()
==
CPUPlace
()
&&
beta2_pow
.
place
()
==
CPUPlace
())
{
const
float
*
beta1_pow_p
=
beta1_pow
.
template
data
<
float
>();
dev_ctx
.
template
HostAlloc
<
float
>(
beta1_pow_out
)[
0
]
=
...
...
@@ -136,7 +132,7 @@ void AdamwDenseKernel(const Context& dev_ctx,
dev_ctx
.
template
HostAlloc
<
float
>(
beta2_pow_out
)[
0
]
=
beta2_
*
beta2_pow_p
[
0
];
xpu_wait
(
dev_ctx
.
x_context
()
->
xpu_stream
);
}
else
{
}
else
{
// update in xpu
float
*
beta1_pow_out_p
=
dev_ctx
.
template
Alloc
<
float
>(
beta1_pow_out
);
float
*
beta2_pow_out_p
=
dev_ctx
.
template
Alloc
<
float
>(
beta2_pow_out
);
int
r
=
xpu
::
scale
(
dev_ctx
.
x_context
(),
...
...
python/paddle/fluid/tests/unittests/xpu/test_adamw_op_xpu.py
浏览文件 @
dda74715
...
...
@@ -17,6 +17,7 @@ import sys
sys
.
path
.
append
(
".."
)
import
unittest
from
functools
import
partial
import
numpy
as
np
from
op_test_xpu
import
XPUOpTest
...
...
@@ -301,6 +302,345 @@ class XPUTestAdamwOp2(XPUOpTestWrapper):
adam
.
step
()
adam
.
clear_gradients
()
class
TestAdamWOpLayerwiseLR
(
TestAdamWOp
):
def
setUp
(
self
):
np
.
random
.
seed
(
2022
)
paddle
.
seed
(
2022
)
def
test_adamw_op_dygraph
(
self
):
paddle
.
disable_static
()
linear1
=
paddle
.
nn
.
Linear
(
13
,
8
,
bias_attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
1.0
)
)
linear2
=
paddle
.
nn
.
Linear
(
8
,
5
,
bias_attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
1.0
)
)
# fix the linear name, simple_lr_setting function will use the name
linear1
.
weight
.
name
=
"linear_1.w_0"
linear1
.
bias
.
name
=
"linear_1.b_0"
linear2
.
weight
.
name
=
"linear_2.w_0"
linear2
.
bias
.
name
=
"linear_2.b_0"
fc1_w
=
np
.
array
(
linear1
.
weight
)
fc1_w_mon1
=
np
.
zeros_like
(
fc1_w
)
fc1_w_mon2
=
np
.
zeros_like
(
fc1_w
)
fc1_b
=
np
.
array
(
linear1
.
bias
)
fc1_b_mon1
=
np
.
zeros_like
(
fc1_b
)
fc1_b_mon2
=
np
.
zeros_like
(
fc1_b
)
fc2_w
=
np
.
array
(
linear2
.
weight
)
fc2_w_mon1
=
np
.
zeros_like
(
fc2_w
)
fc2_w_mon2
=
np
.
zeros_like
(
fc2_w
)
fc2_b
=
np
.
array
(
linear2
.
bias
)
fc2_b_mon1
=
np
.
zeros_like
(
fc2_b
)
fc2_b_mon2
=
np
.
zeros_like
(
fc2_b
)
simple_lr_fun
=
partial
(
simple_lr_setting
,
decay_rate
=
0.8
,
n_layers
=
2
)
learning_rate
=
0.001
weight_decay
=
0.01
beta1
=
0.9
beta2
=
0.999
opt
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
learning_rate
,
parameters
=
[
{
'params'
:
linear1
.
parameters
()},
{
'params'
:
linear2
.
parameters
(),
},
],
apply_decay_param_fun
=
lambda
name
:
True
,
weight_decay
=
weight_decay
,
lr_ratio
=
simple_lr_fun
,
)
def
get_numpy_output
(
param
,
grad
,
moment1
,
moment2
,
lr_ratio
,
t
):
np_inputs
=
{
'Param'
:
param
,
'Grad'
:
grad
,
'Moment1'
:
moment1
,
'Moment2'
:
moment2
,
'LearningRate'
:
np
.
array
([
learning_rate
]).
astype
(
"float32"
),
'Beta1Pow'
:
np
.
array
([
beta1
**
t
]).
astype
(
"float32"
),
'Beta2Pow'
:
np
.
array
([
beta2
**
t
]).
astype
(
"float32"
),
}
np_attrs
=
{
'epsilon'
:
1e-8
,
'beta1'
:
beta1
,
'beta2'
:
beta2
,
"lr_ratio"
:
lr_ratio
,
"coeff"
:
weight_decay
,
"with_decay"
:
True
,
}
param_out
,
moment1_out
,
moment2_out
=
adamw_step
(
np_inputs
,
np_attrs
)
return
param_out
,
moment1_out
,
moment2_out
for
i
in
range
(
5
):
a
=
paddle
.
to_tensor
(
np
.
random
.
uniform
(
-
1
,
1
,
(
2
,
13
)).
astype
(
"float32"
)
)
a1
=
linear1
(
a
)
out
=
linear2
(
a1
)
out
=
paddle
.
mean
(
out
)
out
.
backward
()
fc1_w
,
fc1_w_mon1
,
fc1_w_mon2
=
get_numpy_output
(
fc1_w
,
np
.
array
(
linear1
.
weight
.
grad
),
fc1_w_mon1
,
fc1_w_mon2
,
simple_lr_fun
(
linear1
.
weight
),
i
+
1
,
)
fc1_b
,
fc1_b_mon1
,
fc1_b_mon2
=
get_numpy_output
(
fc1_b
,
np
.
array
(
linear1
.
bias
.
grad
),
fc1_b_mon1
,
fc1_b_mon2
,
simple_lr_fun
(
linear1
.
bias
),
i
+
1
,
)
fc2_w
,
fc2_w_mon1
,
fc2_w_mon2
=
get_numpy_output
(
fc2_w
,
np
.
array
(
linear2
.
weight
.
grad
),
fc2_w_mon1
,
fc2_w_mon2
,
simple_lr_fun
(
linear2
.
weight
),
i
+
1
,
)
fc2_b
,
fc2_b_mon1
,
fc2_b_mon2
=
get_numpy_output
(
fc2_b
,
np
.
array
(
linear2
.
bias
.
grad
),
fc2_b_mon1
,
fc2_b_mon2
,
simple_lr_fun
(
linear2
.
bias
),
i
+
1
,
)
opt
.
step
()
opt
.
clear_gradients
()
np
.
testing
.
assert_allclose
(
linear1
.
weight
.
numpy
(),
fc1_w
,
rtol
=
1e-5
,
atol
=
1e-5
)
np
.
testing
.
assert_allclose
(
linear1
.
bias
.
numpy
(),
fc1_b
,
rtol
=
1e-5
,
atol
=
1e-5
)
np
.
testing
.
assert_allclose
(
linear2
.
weight
.
numpy
(),
fc2_w
,
rtol
=
1e-5
,
atol
=
1e-5
)
np
.
testing
.
assert_allclose
(
linear2
.
bias
.
numpy
(),
fc2_b
,
rtol
=
1e-5
,
atol
=
1e-5
)
def
test_adamw_op
(
self
):
paddle
.
enable_static
()
place
=
fluid
.
XPUPlace
(
0
)
learning_rate
=
0.0001
beta1
=
0.85
beta2
=
0.95
weight_decay
=
0.01
epsilon
=
1e-8
train_prog
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
train_prog
,
startup
):
with
fluid
.
unique_name
.
guard
():
x
=
fluid
.
data
(
name
=
'x'
,
shape
=
[
None
,
10
],
dtype
=
'float32'
)
y
=
fluid
.
data
(
name
=
'y'
,
shape
=
[
None
,
1
],
dtype
=
'float32'
)
weight_attr1
=
paddle
.
framework
.
ParamAttr
(
name
=
"linear_0.w_0"
)
bias_attr1
=
paddle
.
framework
.
ParamAttr
(
name
=
"linear_0.b_0"
,
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
1.0
),
)
weight_attr2
=
paddle
.
framework
.
ParamAttr
(
name
=
"linear_1.w_0"
)
bias_attr2
=
paddle
.
framework
.
ParamAttr
(
name
=
"linear_1.b_0"
,
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
1.0
),
)
linear1
=
paddle
.
nn
.
Linear
(
10
,
32
,
weight_attr
=
weight_attr1
,
bias_attr
=
bias_attr1
)
linear2
=
paddle
.
nn
.
Linear
(
32
,
1
,
weight_attr
=
weight_attr2
,
bias_attr
=
bias_attr2
)
out
=
linear1
(
x
)
out
=
linear2
(
out
)
fc1_w_mon1
=
np
.
zeros
((
linear1
.
weight
.
shape
)).
astype
(
"float32"
)
fc1_w_mon2
=
np
.
zeros
((
linear1
.
weight
.
shape
)).
astype
(
"float32"
)
fc1_b_mon1
=
np
.
zeros
((
linear1
.
bias
.
shape
)).
astype
(
"float32"
)
fc1_b_mon2
=
np
.
zeros
((
linear1
.
bias
.
shape
)).
astype
(
"float32"
)
fc2_w_mon1
=
np
.
zeros
((
linear2
.
weight
.
shape
)).
astype
(
"float32"
)
fc2_w_mon2
=
np
.
zeros
((
linear2
.
weight
.
shape
)).
astype
(
"float32"
)
fc2_b_mon1
=
np
.
zeros
((
linear2
.
bias
.
shape
)).
astype
(
"float32"
)
fc2_b_mon2
=
np
.
zeros
((
linear2
.
bias
.
shape
)).
astype
(
"float32"
)
cost
=
paddle
.
nn
.
functional
.
square_error_cost
(
input
=
out
,
label
=
y
)
avg_cost
=
paddle
.
mean
(
cost
)
simple_lr_fun
=
partial
(
simple_lr_setting
,
decay_rate
=
0.8
,
n_layers
=
2
)
opt
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
learning_rate
,
beta1
=
beta1
,
beta2
=
beta2
,
weight_decay
=
weight_decay
,
epsilon
=
epsilon
,
lr_ratio
=
simple_lr_fun
,
)
opt
.
minimize
(
avg_cost
)
def
get_numpy_output
(
param
,
grad
,
moment1
,
moment2
,
lr_ratio
,
t
):
np_inputs
=
{
'Param'
:
param
,
'Grad'
:
grad
,
'Moment1'
:
moment1
,
'Moment2'
:
moment2
,
'LearningRate'
:
np
.
array
([
learning_rate
]).
astype
(
"float32"
),
'Beta1Pow'
:
np
.
array
([
beta1
**
t
]).
astype
(
"float32"
),
'Beta2Pow'
:
np
.
array
([
beta2
**
t
]).
astype
(
"float32"
),
}
np_attrs
=
{
'epsilon'
:
epsilon
,
'beta1'
:
beta1
,
'beta2'
:
beta2
,
"lr_ratio"
:
lr_ratio
,
"coeff"
:
weight_decay
,
"with_decay"
:
True
,
}
param_out
,
moment1_out
,
moment2_out
=
adamw_step
(
np_inputs
,
np_attrs
)
return
param_out
,
moment1_out
,
moment2_out
fetch_list1
=
[
"linear_0.w_0"
,
"linear_0.b_0"
,
"linear_1.w_0"
,
"linear_1.b_0"
,
]
fetch_list2
=
[
"linear_0.w_0"
,
"linear_0.w_0@GRAD"
,
"linear_0.b_0"
,
"linear_0.b_0@GRAD"
,
"linear_1.w_0"
,
"linear_1.w_0@GRAD"
,
"linear_1.b_0"
,
"linear_1.b_0@GRAD"
,
]
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup
)
test_prog
=
train_prog
.
clone
(
for_test
=
True
)
for
i
in
range
(
5
):
inputs
=
np
.
random
.
random
(
size
=
[
8
,
10
]).
astype
(
'float32'
)
outputs
=
np
.
random
.
random
(
size
=
[
8
,
1
]).
astype
(
'float32'
)
param
=
exe
.
run
(
test_prog
,
feed
=
{
"x"
:
inputs
,
"y"
:
outputs
},
fetch_list
=
fetch_list1
,
)
params_and_gras
=
exe
.
run
(
train_prog
,
feed
=
{
"x"
:
inputs
,
"y"
:
outputs
},
fetch_list
=
fetch_list2
,
)
fc1_w
=
param
[
0
]
fc1_w_grad
=
params_and_gras
[
1
]
fc1_b
=
param
[
1
]
fc1_b_grad
=
params_and_gras
[
3
]
fc2_w
=
param
[
2
]
fc2_w_grad
=
params_and_gras
[
5
]
fc2_b
=
param
[
3
]
fc2_b_grad
=
params_and_gras
[
7
]
fc1_w
,
fc1_w_mon1
,
fc1_w_mon2
=
get_numpy_output
(
fc1_w
,
fc1_w_grad
,
fc1_w_mon1
,
fc1_w_mon2
,
simple_lr_fun
(
linear1
.
weight
),
i
+
1
,
)
fc1_b
,
fc1_b_mon1
,
fc1_b_mon2
=
get_numpy_output
(
fc1_b
,
fc1_b_grad
,
fc1_b_mon1
,
fc1_b_mon2
,
simple_lr_fun
(
linear1
.
bias
),
i
+
1
,
)
fc2_w
,
fc2_w_mon1
,
fc2_w_mon2
=
get_numpy_output
(
fc2_w
,
fc2_w_grad
,
fc2_w_mon1
,
fc2_w_mon2
,
simple_lr_fun
(
linear2
.
weight
),
i
+
1
,
)
fc2_b
,
fc2_b_mon1
,
fc2_b_mon2
=
get_numpy_output
(
fc2_b
,
fc2_b_grad
,
fc2_b_mon1
,
fc2_b_mon2
,
simple_lr_fun
(
linear2
.
bias
),
i
+
1
,
)
np
.
testing
.
assert_allclose
(
params_and_gras
[
0
],
fc1_w
,
rtol
=
1e-5
,
atol
=
1e-5
)
np
.
testing
.
assert_allclose
(
params_and_gras
[
2
],
fc1_b
,
rtol
=
1e-5
,
atol
=
1e-5
)
np
.
testing
.
assert_allclose
(
params_and_gras
[
4
],
fc2_w
,
rtol
=
1e-5
,
atol
=
1e-5
)
np
.
testing
.
assert_allclose
(
params_and_gras
[
6
],
fc2_b
,
rtol
=
1e-5
,
atol
=
1e-5
)
paddle
.
disable_static
()
support_types
=
get_xpu_op_support_types
(
'adamw'
)
for
stype
in
support_types
:
...
...
python/paddle/optimizer/adamw.py
浏览文件 @
dda74715
...
...
@@ -178,9 +178,12 @@ class AdamW(Optimizer):
raise
TypeError
(
"weight_decay should be float or Tensor."
)
if
lr_ratio
is
not
None
:
assert
isinstance
(
lr_ratio
,
Callable
)
if
not
core
.
is_compiled_with_cuda
():
if
(
not
core
.
is_compiled_with_cuda
()
and
not
core
.
is_compiled_with_xpu
()
):
raise
NotImplementedError
(
"'lr_ratio' is unimplemented in CPU,
XPU
and NPU"
"'lr_ratio' is unimplemented in CPU, and NPU"
)
if
parameters
is
not
None
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录