BaiXuePrincess / Paddle — fork of PaddlePaddle / Paddle (in sync with the upstream project)
Commit 77a8a394 (unverified)
Authored on Aug 23, 2021 by zhaoyingli; committed via GitHub on Aug 23, 2021
Parent: cf99c0d5

add adamw cuda kernel (#35020)
* adamw support cuda
* adamw support cuda

Showing 5 changed files with 551 additions and 102 deletions (+551, -102)
paddle/fluid/operators/optimizers/adamw_op.cu  +438 -0
paddle/fluid/operators/optimizers/adamw_op.h  +103 -1
paddle/fluid/pybind/op_function_generator.cc  +2 -0
python/paddle/fluid/tests/unittests/test_adamw_op.py  +0 -27
python/paddle/optimizer/adamw.py  +8 -74
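For context (not part of the commit itself), here is a minimal dygraph usage sketch of the optimizer this kernel accelerates, modeled on the unit test touched below; the hyperparameter values are illustrative:

import numpy as np
import paddle

# A small layer and a batch of data, mirroring test_adamw_op.py in this diff.
linear = paddle.nn.Linear(13, 5)
x = paddle.to_tensor(np.random.rand(2, 13).astype("float32"))

# Decoupled weight decay (AdamW). With this commit the update runs in the
# dedicated adamw CUDA kernel instead of being emulated around the adam op.
opt = paddle.optimizer.AdamW(learning_rate=0.01,
                             parameters=linear.parameters(),
                             weight_decay=0.01)

loss = paddle.mean(linear(x))
loss.backward()
opt.step()
opt.clear_gradients()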
paddle/fluid/operators/optimizers/adamw_op.cu
new file mode 100644
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/adamw_op.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {

template <typename T, typename MT>
__global__ void AdamWKernelREG(MT beta1, MT beta2, MT epsilon, MT coeff,
                               MT beta1_pow_, MT beta2_pow_, const MT* moment1,
                               MT* moment1_out, const MT* moment2,
                               MT* moment2_out, const MT* lr_, const T* grad,
                               const T* param, T* param_out,
                               const MT* master_param, MT* master_param_out,
                               int ndim) {
  MT lr = *lr_;
  MT beta1_pow = beta1_pow_;
  MT beta2_pow = beta2_pow_;

  MT wd = static_cast<MT>(1.0) - coeff * lr;
  lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
        (static_cast<MT>(1.0) - beta1_pow);

  int id = blockIdx.x * blockDim.x + threadIdx.x;

  for (; id < ndim; id += gridDim.x * blockDim.x) {
    MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
    MT g = static_cast<MT>(grad[id]);
    MT mom1 = moment1[id];
    MT mom2 = moment2[id];
    mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
    mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
    p = wd * p -
        lr * (mom1 /
              (sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));

    moment1_out[id] = mom1;
    moment2_out[id] = mom2;
    param_out[id] = static_cast<T>(p);
    if (master_param_out) {
      master_param_out[id] = p;
    }
  }
}
template <typename T, typename MT>
__global__ void AdamWKernelMEM(MT beta1, MT beta2, MT epsilon, MT coeff,
                               const MT* beta1_pow_, const MT* beta2_pow_,
                               const MT* moment1, MT* moment1_out,
                               const MT* moment2, MT* moment2_out,
                               const MT* lr_, const T* grad, const T* param,
                               T* param_out, const MT* master_param,
                               MT* master_param_out, int ndim) {
  MT lr = *lr_;
  MT beta1_pow = *beta1_pow_;
  MT beta2_pow = *beta2_pow_;

  MT wd = static_cast<MT>(1.0) - coeff * lr;
  lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
        (static_cast<MT>(1.0) - beta1_pow);

  int id = blockIdx.x * blockDim.x + threadIdx.x;

  for (; id < ndim; id += gridDim.x * blockDim.x) {
    MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
    MT g = static_cast<MT>(grad[id]);
    MT mom1 = static_cast<MT>(moment1[id]);
    MT mom2 = static_cast<MT>(moment2[id]);
    mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
    mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
    p = wd * p -
        lr * (mom1 /
              (sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));

    moment1_out[id] = mom1;
    moment2_out[id] = mom2;
    param_out[id] = static_cast<T>(p);
    if (master_param_out) {
      master_param_out[id] = p;
    }
  }
}
template <typename T>
__global__ void UpdateAdamWBetaPow(T beta1, T beta2, const T* beta1_pow_,
                                   const T* beta2_pow_, T* beta1_pow_out,
                                   T* beta2_pow_out) {
  *beta1_pow_out = beta1 * beta1_pow_[0];
  *beta2_pow_out = beta2 * beta2_pow_[0];
}
template <typename T, typename MT>
__global__ void SparseAdamWCUDAKernelREG(
    MT beta1, MT beta2, MT epsilon, MT coeff, const MT beta1_pow,
    const MT beta2_pow, const MT* mom1_, MT* mom1_out_, const MT* mom2_,
    MT* mom2_out_, const MT* lr_, const T* grad_, const T* param_,
    T* param_out_, const MT* master_param, MT* master_param_out,
    const int64_t* rows_, int64_t row_numel, int64_t row_count,
    bool lazy_mode, int ndim) {
  int id = blockIdx.x * blockDim.x + threadIdx.x;
  MT lr = *lr_;
  MT wd = static_cast<MT>(1.0) - coeff * lr;
  lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
        (static_cast<MT>(1.0) - beta1_pow);

  for (; id < ndim; id += blockDim.x * gridDim.x) {
    auto row_idx =
        math::BinarySearch<int64_t>(rows_, row_count, id / row_numel);
    if (lazy_mode && row_idx < 0) {
      return;
    } else {
      MT mom1 = mom1_[id];
      MT mom2 = mom2_[id];
      MT p = master_param ? master_param[id] : static_cast<MT>(param_[id]);
      MT g = row_idx >= 0
                 ? static_cast<MT>(grad_[row_idx * row_numel + id % row_numel])
                 : static_cast<MT>(0);
      mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
      mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
      p = wd * p -
          lr * (mom1 /
                (sqrt(mom2) +
                 epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));

      // Write back to global memory
      mom1_out_[id] = mom1;
      mom2_out_[id] = mom2;
      param_out_[id] = static_cast<T>(p);
      if (master_param_out) {
        master_param_out[id] = p;
      }
    }
  }
}
template <typename T>
class AdamWOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* param_var = ctx.InputVar("Param");
    PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
                      platform::errors::InvalidArgument(
                          "The Var(%s)'s type should be LoDTensor, "
                          "but the received is %s",
                          ctx.InputNames("Param").front(),
                          framework::ToTypeName(param_var->Type())));

    using paddle::framework::LoDTensor;
    using MPDType = typename details::MPTypeTrait<T>::Type;

    int64_t min_row_size_to_use_multithread =
        ctx.Attr<int64_t>("min_row_size_to_use_multithread");
    bool lazy_mode = ctx.Attr<bool>("lazy_mode");
    bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
    VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;

    float coeff = ctx.Attr<float>("coeff");

    auto* param = ctx.Input<LoDTensor>("Param");
    auto* grad_var = ctx.InputVar("Grad");
    auto* mom1 = ctx.Input<LoDTensor>("Moment1");
    auto* mom2 = ctx.Input<LoDTensor>("Moment2");
    auto* lr = ctx.Input<LoDTensor>("LearningRate");

    auto* beta1_pow = ctx.Input<LoDTensor>("Beta1Pow");
    auto* beta2_pow = ctx.Input<LoDTensor>("Beta2Pow");

    auto* param_out = ctx.Output<LoDTensor>("ParamOut");
    auto* mom1_out = ctx.Output<LoDTensor>("Moment1Out");
    auto* mom2_out = ctx.Output<LoDTensor>("Moment2Out");
    auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
    auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");

    bool skip_update = false;
    if (ctx.HasInput("SkipUpdate")) {
      auto* skip_update_tensor = ctx.Input<framework::Tensor>("SkipUpdate");
      PADDLE_ENFORCE_EQ(skip_update_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(SkipUpdate) size must be 1, but get %d",
                            skip_update_tensor->numel()));
      std::vector<bool> skip_update_vec;
      TensorToVector(*skip_update_tensor, ctx.device_context(),
                     &skip_update_vec);
      skip_update = skip_update_vec[0];
    }

    // skip_update=true, just copy input to output, and TensorCopy will call
    // mutable_data
    if (skip_update) {
      VLOG(4) << "Adamw skip update";
      framework::TensorCopy(
          *param, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), param_out);
      framework::TensorCopy(
          *mom1, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), mom1_out);
      framework::TensorCopy(
          *mom2, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), mom2_out);
      framework::TensorCopy(
          *beta1_pow, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(),
          beta1_pow_out);
      framework::TensorCopy(
          *beta2_pow, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(),
          beta2_pow_out);
      return;
    }

    // if with_decay = false, coeff = 0
    bool with_decay = ctx.Attr<bool>("with_decay");
    if (!with_decay) {
      coeff = static_cast<float>(0.0);
    }

    MPDType beta1 = static_cast<MPDType>(ctx.Attr<float>("beta1"));
    if (ctx.HasInput("Beta1Tensor")) {
      auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
      PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(Beta1Tensor) size must be 1, but get %d",
                            beta1_tensor->numel()));
      beta1 = static_cast<MPDType>(GetAttrFromTensor(beta1_tensor));
    }
    MPDType beta2 = static_cast<MPDType>(ctx.Attr<float>("beta2"));
    if (ctx.HasInput("Beta2Tensor")) {
      auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
      PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(Beta2Tensor) size must be 1, but get %d",
                            beta2_tensor->numel()));
      beta2 = static_cast<MPDType>(GetAttrFromTensor(beta2_tensor));
    }
    MPDType epsilon = static_cast<MPDType>(ctx.Attr<float>("epsilon"));
    if (ctx.HasInput("EpsilonTensor")) {
      auto* epsilon_tensor = ctx.Input<framework::Tensor>("EpsilonTensor");
      PADDLE_ENFORCE_EQ(epsilon_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(EpsilonTensor) size must be 1, but get %d",
                            epsilon_tensor->numel()));
      epsilon = static_cast<MPDType>(GetAttrFromTensor(epsilon_tensor));
    }
    VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
            << "beta2_pow.numel() : " << beta2_pow->numel();
    VLOG(3) << "param.numel(): " << param->numel();
    PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
                      platform::errors::InvalidArgument(
                          "beta1 pow output size should be 1, but received "
                          "value is:%d.",
                          beta1_pow_out->numel()));
    PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1,
                      platform::errors::InvalidArgument(
                          "beta2 pow output size should be 1, but received "
                          "value is:%d.",
                          beta2_pow_out->numel()));

    const bool multi_precision = ctx.Attr<bool>("multi_precision");
    const LoDTensor* master_param = nullptr;
    LoDTensor* master_param_out = nullptr;
    if (multi_precision) {
      bool has_master =
          ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
      PADDLE_ENFORCE_EQ(
          has_master, true,
          platform::errors::InvalidArgument(
              "The Input(MasterParam) and Output(MasterParamOut) "
              "should not be null when "
              "the attr `multi_precision` is true"));
      master_param = ctx.Input<LoDTensor>("MasterParam");
      master_param_out = ctx.Output<LoDTensor>("MasterParamOut");
    }
    const MPDType* master_in_data =
        multi_precision ? master_param->data<MPDType>() : nullptr;
    MPDType* master_out_data =
        multi_precision
            ? master_param_out->mutable_data<MPDType>(ctx.GetPlace())
            : nullptr;

    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    if (grad_var->IsType<framework::LoDTensor>()) {
      auto* grad = ctx.Input<LoDTensor>("Grad");

      // update param and moment
      int threads = 512;
      int blocks = (param->numel() + threads - 1) / threads;

      if (beta1_pow->place() == platform::CPUPlace() &&
          beta2_pow->place() == platform::CPUPlace()) {
        // Compute with betapow in REG
        AdamWKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
            beta1, beta2, epsilon, coeff, *beta1_pow->data<MPDType>(),
            *beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad->data<T>(), param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, param->numel());
        if (!use_global_beta_pow) {
          // Cpu update
          beta1_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
              beta1 * beta1_pow->data<MPDType>()[0];
          beta2_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
              beta2 * beta2_pow->data<MPDType>()[0];
        }
      } else {
        AdamWKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
            beta1, beta2, epsilon, coeff, beta1_pow->data<MPDType>(),
            beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad->data<T>(), param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, param->numel());
        if (!use_global_beta_pow) {
          // Update with gpu
          UpdateAdamWBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
              beta1, beta2, beta1_pow->data<MPDType>(),
              beta2_pow->data<MPDType>(),
              beta1_pow_out->mutable_data<MPDType>(ctx.GetPlace()),
              beta2_pow_out->mutable_data<MPDType>(ctx.GetPlace()));
        }
      }
    } else if (grad_var->IsType<framework::SelectedRows>()) {
      auto* grad = ctx.Input<framework::SelectedRows>("Grad");
      if (grad->rows().size() == 0) {
        VLOG(3) << "grad row size is 0!!";
        return;
      }

      std::vector<int64_t> cpu_rows(grad->rows().begin(), grad->rows().end());
      bool is_strict_sorted = true;
      for (size_t i = 1; i < cpu_rows.size(); ++i) {
        if (cpu_rows[i - 1] >= cpu_rows[i]) {
          is_strict_sorted = false;
          break;
        }
      }

      framework::SelectedRows tmp_grad_merge;
      const framework::SelectedRows* grad_merge_ptr;
      if (is_strict_sorted) {
        grad_merge_ptr = grad;
      } else {
        // merge duplicated rows if any.
        // The rows of grad_merge have been sorted inside MergeAdd functor
        scatter::MergeAdd<platform::CUDADeviceContext, T> merge_func;
        merge_func(ctx.template device_context<platform::CUDADeviceContext>(),
                   *grad, &tmp_grad_merge, true);
        grad_merge_ptr = &tmp_grad_merge;
      }
      auto& grad_merge = *grad_merge_ptr;
      auto& grad_tensor = grad_merge.value();
      const T* grad_data = grad_tensor.template data<T>();
      const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
      auto row_numel = grad_tensor.numel() / grad_merge.rows().size();

      if (beta1_pow->place() == platform::CPUPlace() &&
          beta2_pow->place() == platform::CPUPlace()) {
        int threads = 512;
        int ndim = param->numel();
        int blocks = (ndim + threads - 1) / threads;

        SparseAdamWCUDAKernelREG<
            T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
            beta1, beta2, epsilon, coeff, *beta1_pow->data<MPDType>(),
            *beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad_data, param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, rows, row_numel, grad_merge.rows().size(),
            lazy_mode, ndim);
        if (!use_global_beta_pow) {
          // Update with cpu
          beta1_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
              beta1 * beta1_pow->data<MPDType>()[0];
          beta2_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
              beta2 * beta2_pow->data<MPDType>()[0];
        }
      } else {
        SparseAdamWFunctor<T, GPUAdamW, MPDType> functor(
            beta1, beta2, epsilon, coeff, beta1_pow->data<MPDType>(),
            beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad_data, param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, rows, row_numel, grad_merge.rows().size(),
            lazy_mode);

        // FIXME(minqiyang): remove BinarySearch in GPU later
        platform::ForRange<platform::CUDADeviceContext> for_range(
            static_cast<const platform::CUDADeviceContext&>(
                ctx.device_context()),
            param->numel());
        for_range(functor);
        if (!use_global_beta_pow) {
          // update beta1 and beta2
          UpdateAdamWBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
              beta1, beta2, beta1_pow->data<MPDType>(),
              beta2_pow->data<MPDType>(),
              beta1_pow_out->mutable_data<MPDType>(ctx.GetPlace()),
              beta2_pow_out->mutable_data<MPDType>(ctx.GetPlace()));
        }
      }
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Variable type not supported by adamw_op"));
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(adamw, ops::AdamWOpCUDAKernel<float>,
                        ops::AdamWOpCUDAKernel<double>,
                        ops::AdamWOpCUDAKernel<plat::float16>);
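For readability, the per-element update implemented by AdamWKernelREG/AdamWKernelMEM above can be summarized as follows (a transcription of the kernel arithmetic, not text from the commit), with learning rate $\eta$, weight-decay coefficient $\mathrm{coeff}$, and moments $m_t, v_t$:

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2,\\
\hat{\eta}_t &= \eta\,\frac{\sqrt{1-\beta_2^{\,t}}}{1-\beta_1^{\,t}}, \qquad
\theta_t = (1-\mathrm{coeff}\cdot\eta)\,\theta_{t-1}
  - \hat{\eta}_t\,\frac{m_t}{\sqrt{v_t}+\epsilon\sqrt{1-\beta_2^{\,t}}}.
\end{aligned}
$$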
paddle/fluid/operators/optimizers/adamw_op.h
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
...
...
@@ -22,6 +22,7 @@ class AdamWOp : public AdamOp {
   using AdamOp::AdamOp;
 };

+struct GPUAdamW;
 struct CPUAdamW;

 template <typename T, typename Flavour>
...
...
@@ -46,6 +47,107 @@ class AdamWFunctor<T, CPUAdamW> {
}
};
template <typename T, typename Flavour, typename MT = T>
class SparseAdamWFunctor;

template <typename T, typename MT>
class SparseAdamWFunctor<T, GPUAdamW, MT> {
 private:
  MT beta1_;
  MT beta2_;
  MT epsilon_;
  MT coeff_;

  const MT* beta1_pow_;
  const MT* beta2_pow_;
  const MT* moment1_;
  MT* moment1_out_;
  const MT* moment2_;
  MT* moment2_out_;
  const MT* lr_;
  const T* grad_;
  const T* param_;
  T* param_out_;
  const MT* master_param_;
  MT* master_param_out_;

  const int64_t* rows_;
  int64_t row_numel_;
  int64_t row_count_;
  bool lazy_mode_;

 public:
  SparseAdamWFunctor(MT beta1, MT beta2, MT epsilon, MT coeff,
                     const MT* beta1_pow, const MT* beta2_pow, const MT* mom1,
                     MT* mom1_out, const MT* mom2, MT* mom2_out, const MT* lr,
                     const T* grad, const T* param, T* param_out,
                     const MT* master_param, MT* master_param_out,
                     const int64_t* rows, int64_t row_numel, int64_t row_count,
                     bool lazy_mode)
      : beta1_(beta1),
        beta2_(beta2),
        epsilon_(epsilon),
        coeff_(coeff),
        beta1_pow_(beta1_pow),
        beta2_pow_(beta2_pow),
        moment1_(mom1),
        moment1_out_(mom1_out),
        moment2_(mom2),
        moment2_out_(mom2_out),
        lr_(lr),
        grad_(grad),
        param_(param),
        param_out_(param_out),
        master_param_(master_param),
        master_param_out_(master_param_out),
        rows_(rows),
        row_numel_(row_numel),
        row_count_(row_count),
        lazy_mode_(lazy_mode) {}

  inline HOSTDEVICE void adamw_update(size_t i, MT g) const {
    // The following code is the same as dense
    MT mom1 = moment1_[i];
    MT mom2 = moment2_[i];
    MT lr = *lr_;
    MT beta1_pow = *beta1_pow_;
    MT beta2_pow = *beta2_pow_;
    MT p = master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);

    // Calculation
    MT wd = static_cast<MT>(1.0) - coeff_ * lr;
    lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
          (static_cast<MT>(1.0) - beta1_pow);

    mom1 = beta1_ * mom1 + (static_cast<MT>(1.0) - beta1_) * g;
    mom2 = beta2_ * mom2 + (static_cast<MT>(1.0) - beta2_) * g * g;
    p = wd * p -
        lr * (mom1 /
              (sqrt(mom2) + epsilon_ * sqrt(static_cast<MT>(1.0) - beta2_pow)));

    // Write back to global memory
    moment1_out_[i] = mom1;
    moment2_out_[i] = mom2;
    param_out_[i] = static_cast<T>(p);
    if (master_param_out_) {
      master_param_out_[i] = p;
    }
  }

  inline HOSTDEVICE void operator()(size_t i) const {
    auto row_idx =
        math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
    if (lazy_mode_ && row_idx < 0) {
      return;
    } else {
      MT g = row_idx >= 0
                 ? static_cast<MT>(grad_[row_idx * row_numel_ + i % row_numel_])
                 : static_cast<MT>(0);
      adamw_update(i, g);
    }
  }
};

template <typename DeviceContext, typename T>
class AdamWOpKernel : public AdamOpKernel<DeviceContext, T> {
 public:
...
...
paddle/fluid/pybind/op_function_generator.cc
...
...
@@ -118,6 +118,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
     {"sgd", {"ParamOut"}},
     {"adam",
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
+    {"adamw",
+     {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
     {"average_accumulates",
      {"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates",
       "out_old_num_accumulates", "out_num_updates"}},
...
...
python/paddle/fluid/tests/unittests/test_adamw_op.py
...
...
@@ -93,33 +93,6 @@ class TestAdamWOp(unittest.TestCase):
         adam = paddle.optimizer.AdamW(
             0.1, epsilon=-1, parameters=linear.parameters())

-    def test_adamw_lr_decay(self):
-        paddle.disable_static()
-        value = np.arange(26).reshape(2, 13).astype("float32")
-        a = paddle.to_tensor(value)
-        linear = paddle.nn.Linear(13, 5)
-
-        lr = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=10)
-        wd = 0.1
-        adam = paddle.optimizer.AdamW(
-            learning_rate=lr,
-            parameters=linear.parameters(),
-            apply_decay_param_fun=lambda name: True,
-            weight_decay=wd)
-
-        for _ in range(2):
-            out = linear(a)
-            out.backward()
-            lr_to_coeff = adam._lr_to_coeff
-            adam.step()
-
-            for i, value in enumerate(lr_to_coeff.values()):
-                self.assertAlmostEqual(value.numpy()[0], 1.0 - lr() * wd)
-            self.assertEqual(len(adam._lr_to_coeff), 0)
-
-            lr.step()
-            adam.clear_gradients()
-
 class TestAdamWOpGroup(TestAdamWOp):
     def test_adamw_op_dygraph(self):
...
...
python/paddle/optimizer/adamw.py
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -48,8 +48,8 @@ class AdamW(Adam):
     Args:
         learning_rate (float|LRScheduler, optional): The learning rate used to update ``Parameter``.
             It can be a float value or a LRScheduler. The default value is 0.001.
         parameters (list|tuple, optional): List/Tuple of ``Tensor`` names to update to minimize ``loss``. \
             This parameter is required in dygraph mode. And you can specify different options for \
             different parameter groups such as the learning rate, weight decay, etc, \
             then the parameters are list of dict. Note that the learning_rate in paramter groups \
             represents the scale of base learning_rate. \
...
...
@@ -162,7 +162,6 @@ class AdamW(Adam):
         self._params_name = set()
         self._apply_decay_param_fun = apply_decay_param_fun
         self._coeff = coeff
-        self._lr_to_coeff = dict()
         super(AdamW, self).__init__(
             learning_rate=learning_rate,
...
...
@@ -178,9 +177,6 @@ class AdamW(Adam):
         self.type = "adamw"

-        # now the adamw op doesn't support cuda
-        if core.is_compiled_with_cuda():
-            self.type = "adam"
         # Use _auxiliary_vars together with _set_auxiliary_var/_get_auxiliary_var to achieve that.
         self._auxiliary_vars = dict()
...
...
@@ -193,64 +189,7 @@ class AdamW(Adam):
         else:
             return None

-    def _append_decoupled_weight_decay(self, block, param_and_grad):
-        """
-        Add decoupled weight decay op.
-            parameter = parameter - parameter * coeff * lr
-        Args:
-            block: block in which variable is to be created
-            param_and_grad: (parameters, gradients) pairs,
-                the parameters need to decay.
-        Raises:
-            Exception: The type of coeff and parameter is not consistent.
-        """
-        if isinstance(param_and_grad, dict):
-            param_and_grad = self._update_param_group(param_and_grad)
-        param, grad = param_and_grad
-
-        if self._apply_decay_param_fun is not None \
-                and not self._apply_decay_param_fun(param.name):
-            return
-
-        if isinstance(self._learning_rate, float):
-            learning_rate = self._learning_rate
-        else:
-            # NOTE. We add this function to the _append_optimize_op(),
-            # for we must make sure _create_param_lr() be called after
-            # optimizer._create_global_learning_rate().
-            learning_rate = self._create_param_lr(param_and_grad)
-
-        with block.program._optimized_guard(
-                [param, grad]), framework.name_scope('weight decay'):
-            self._params_name.add(param.name)
-
-            # If it has been calculated, the result will be reused.
-            # NOTE(wangxi): In dygraph mode, apply_gradient will be executed
-            # every step, so need clear _lr_to_coeff every step,
-            # we do this in _create_optimization_pass
-            decay_coeff = self._lr_to_coeff.get(learning_rate, None)
-            if decay_coeff is None:
-                # NOTE(wangxi): for pipeline to set device:all
-                with paddle.static.device_guard(None):
-                    decay_coeff = 1.0 - learning_rate * self._coeff
-                self._lr_to_coeff[learning_rate] = decay_coeff
-
-            find_master = (self._multi_precision and
-                           param.dtype == core.VarDesc.VarType.FP16)
-            if find_master:
-                master_weight = self._master_weights[param.name]
-                scaled_param = master_weight * decay_coeff
-                paddle.fluid.layers.assign(
-                    input=scaled_param, output=master_weight)
-            else:
-                scaled_param = param * decay_coeff
-                paddle.fluid.layers.assign(input=scaled_param, output=param)
-
     def _append_optimize_op(self, block, param_and_grad):
-        if not core.is_compiled_with_npu():
-            self._append_decoupled_weight_decay(block, param_and_grad)
-            return super(AdamW, self)._append_optimize_op(block, param_and_grad)
-
         assert isinstance(block, framework.Block)
         if isinstance(param_and_grad, dict):
...
...
@@ -262,6 +201,8 @@ class AdamW(Adam):
         if self._apply_decay_param_fun is not None \
                 and not self._apply_decay_param_fun(param.name):
             with_decay = False
         else:
             self._params_name.add(param.name)

         moment1 = self._get_accumulator(self._moment1_acc_str,
                                         param_and_grad[0])
...
...
@@ -277,19 +218,19 @@ class AdamW(Adam):
             if find_master else None)
         lr = self._create_param_lr(param_and_grad)

-        # create the adam optimize op
+        # create the adamw optimize op
         if framework.in_dygraph_mode():
             _beta1 = self._beta1 if not isinstance(
                 self._beta1, Variable) else self._beta1.numpy().item(0)
             _beta2 = self._beta2 if not isinstance(
                 self._beta2, Variable) else self._beta2.numpy().item(0)

-            _, _, _, _, _ = _C_ops.adam(
+            _, _, _, _, _ = _C_ops.adamw(
                 param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
                 beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1,
                 moment2, beta1_pow_acc, beta2_pow_acc, 'epsilon', self._epsilon,
                 'lazy_mode', self._lazy_mode, 'min_row_size_to_use_multithread',
-                1000, 'beta1', _beta1, 'beta2', _beta2)
+                1000, 'beta1', _beta1, 'beta2', _beta2, 'coeff', self._coeff)

             return None
...
...
@@ -350,13 +291,6 @@ class AdamW(Adam):
         return adamw_op

-    def _create_optimization_pass(self, parameters_and_grads):
-        optimize_ops = super(
-            AdamW, self)._create_optimization_pass(parameters_and_grads)
-        # In dygraph mode, clear _lr_to_coeff after applied gradient
-        self._lr_to_coeff = dict()
-        return optimize_ops
-
     def __str__(self):
         return " ".join(["Weight Decay, params:", ",".join(self._params_name)])
...
...
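As a cross-check (illustrative only, not part of the diff), here is a NumPy transcription of one AdamW step as the new op computes it: the decoupled decay is folded into the kernel as wd = 1 - coeff * lr instead of being applied as a separate Python-side rescaling, which is why _append_decoupled_weight_decay and _lr_to_coeff are removed above. The function and variable names below are hypothetical.

import numpy as np

def adamw_step(p, g, m, v, t, lr=0.001, beta1=0.9, beta2=0.999,
               eps=1e-8, coeff=0.01):
    """One element-wise AdamW update mirroring AdamWKernelREG (reference sketch)."""
    m = beta1 * m + (1.0 - beta1) * g            # first moment
    v = beta2 * v + (1.0 - beta2) * g * g        # second moment
    beta1_pow, beta2_pow = beta1 ** t, beta2 ** t
    wd = 1.0 - coeff * lr                        # decoupled weight-decay factor
    lr_t = lr * np.sqrt(1.0 - beta2_pow) / (1.0 - beta1_pow)
    p = wd * p - lr_t * (m / (np.sqrt(v) + eps * np.sqrt(1.0 - beta2_pow)))
    return p, m, v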