Unverified commit 48060b2e
Authored Mar 01, 2023 by niuliling123; committed by GitHub on Mar 01, 2023.
Add multiprecision for rms op (#50132)
Parent: 798b527c
Showing 14 changed files with 879 additions and 210 deletions (+879 −210).
Changed files:

paddle/fluid/operators/optimizers/rmsprop_op.cc (+10 −0)
paddle/fluid/pybind/eager_generator.h (+20 −1)
paddle/phi/api/yaml/legacy_ops.yaml (+7 −6)
paddle/phi/infermeta/multiary.cc (+4 −1)
paddle/phi/infermeta/multiary.h (+4 −1)
paddle/phi/kernels/cpu/rmsprop_kernel.cc (+92 −0)
paddle/phi/kernels/gpu/rmsprop_kernel.cu (+90 −3)
paddle/phi/kernels/impl/rmsprop_kernel_impl.h (+125 −160)
paddle/phi/kernels/rmsprop_kernel.h (+8 −2)
paddle/phi/kernels/xpu/rmsprop_kernel.cc (+4 −1)
paddle/phi/ops/compat/rmsprop_sig.cc (+26 −6)
python/paddle/fluid/optimizer.py (+106 −14)
python/paddle/fluid/tests/unittests/test_rmsprop_op.py (+274 −0)
python/paddle/optimizer/rmsprop.py (+109 −15)
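This commit threads an optional FP32 master weight ("MasterParam" / master_param) and a multi_precision flag through the whole rmsprop stack: the fluid op definition, the eager code generator, the phi op signature and kernels, and both Python optimizer front ends. A rough usage sketch, mirroring the new unit tests near the end of this diff (the _multi_precision switch is a private attribute, and a CUDA build of Paddle is assumed):

import paddle

paddle.set_device('gpu')
model = paddle.nn.Linear(5, 5)
optimizer = paddle.optimizer.RMSProp(
    learning_rate=0.01, parameters=model.parameters()
)
optimizer._multi_precision = True  # keep FP32 master copies of FP16 parameters

# O2 AMP casts the parameters to FP16; the optimizer updates the FP32 masters.
model = paddle.amp.decorate(models=model, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.randn((5, 5))
with paddle.amp.auto_cast(level='O2'):
    loss = paddle.mean(model(x))
scaled = scaler.scale(loss)
scaled.backward()
scaler.step(optimizer)
optimizer.clear_grad()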
paddle/fluid/operators/optimizers/rmsprop_op.cc

@@ -38,6 +38,7 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
              "(Tensor, default Tensor<float>)"
              " The moving average of gradient")
        .AsDispensable();
    AddInput("LearningRate",
             "(Tensor, default Tensor<float>) "
             "The learning rate should be a tensor of size 1.");

@@ -46,12 +47,17 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
             "Input gradient of the parameter.");
    AddInput("Moment",
             "(Tensor, default Tensor<float>) The moment that gets updated.");
+   AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
    AddOutput("ParamOut", "(Tensor) Output updated parameter value.");
    AddOutput("MomentOut", "(Tensor) Output updated moment.");
    AddOutput("MeanSquareOut", "(Tensor) Output Mean squared updated value.");
    AddOutput("MeanGradOut",
              "(Tensor) Output moving average of gradient updated value.");
+   AddOutput("MasterParamOut",
+             "The updated FP32 master weight for AMP. "
+             "It shared memory with Input(MasterParam).")
+       .AsDispensable();
    AddAttr<float>("epsilon",
                   "(float, default 1e-10) Constant "

@@ -65,6 +71,10 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
        .SetDefault(0.0f);
    AddAttr<bool>("centered", "(bool, default false) use centered rmsprop.")
        .SetDefault(false);
+   AddAttr<bool>("multi_precision",
+                 "(bool, default false) "
+                 "Whether to use multi-precision during weight updating.")
+       .SetDefault(false);
    AddComment(R"DOC(
Rmsprop Optimizer.
paddle/fluid/pybind/eager_generator.h

@@ -148,6 +148,14 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
      "Ln2Bias"}},
    {"faster_tokenizer", {"Text", "Vocab", "TextPair"}},
    {"matrix_rank", {"X", "TolTensor"}},
+   {"rmsprop",
+    {"Param",
+     "MeanSquare",
+     "Grad",
+     "Moment",
+     "LearningRate",
+     "MeanGrad",
+     "MasterParam"}},
    {"adam",
     {"Param",
      "Grad",

@@ -311,6 +319,12 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
     {"MultiFpnRois", "RestoreIndex", "MultiLevelRoIsNum"}},
    {"moving_average_abs_max_scale",
     {"Out", "OutScale", "OutAccum", "OutState"}},
+   {"rmsprop",
+    {"ParamOut",
+     "MomentOut",
+     "MeanSquareOut",
+     "MeanGradOut",
+     "MasterParamOut"}},
    {"multiclass_nms3", {"Out", "NmsRoisNum"}},
    {"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
    {"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},

@@ -377,7 +391,12 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
// For those OPs, we need to manually specify the outs need to pass in this map.
std::map<std::string, std::set<std::string>> op_passing_outs_map = {
    {"sgd", {"ParamOut", "MasterParamOut"}},
-   {"rmsprop", {"ParamOut", "MomentOut", "MeanSquareOut", "MeanGradOut"}},
+   {"rmsprop",
+    {"ParamOut",
+     "MomentOut",
+     "MeanSquareOut",
+     "MeanGradOut",
+     "MasterParamOut"}},
    {"ftrl", {"ParamOut", "SquaredAccumOut", "LinearAccumOut"}},
    {"adadelta", {"ParamOut", "AvgSquaredGradOut", "AvgSquaredUpdateOut"}},
    {"adagrad", {"ParamOut", "MomentOut"}},
paddle/phi/api/yaml/legacy_ops.yaml

@@ -1459,15 +1459,16 @@
  backward : reverse_grad

- op : rmsprop_
-  args : (Tensor param, Tensor mean_square, Tensor grad, Tensor moment, Tensor learning_rate, Tensor mean_grad, float epsilon, float decay, float momentum, bool centered)
-  output : Tensor(param_out), Tensor(moment_out), Tensor(mean_square_out), Tensor(mean_grad_out)
+  args : (Tensor param, Tensor mean_square, Tensor grad, Tensor moment, Tensor learning_rate, Tensor mean_grad, Tensor master_param, float epsilon, float decay, float momentum, bool centered, bool multi_precision)
+  output : Tensor(param_out), Tensor(moment_out), Tensor(mean_square_out), Tensor(mean_grad_out), Tensor(master_param_out)
  infer_meta :
    func : RmspropInferMeta
  kernel :
-    func : rmsprop {dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense}
-           rmsprop_dense_param_sparse_grad {dense, dense, selected_rows, dense, dense, dense -> dense, dense, dense, dense}
-  optional : mean_grad
-  inplace : (param -> param_out), (moment -> moment_out), (mean_square -> mean_square_out), (mean_grad -> mean_grad_out)
+    func : rmsprop {dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense}
+           rmsprop_dense_param_sparse_grad {dense, dense, selected_rows, dense, dense, dense, dense -> dense, dense, dense, dense, dense}
+    data_type : param
+  optional : mean_grad, master_param
+  inplace : (param -> param_out), (moment -> moment_out), (mean_square -> mean_square_out), (mean_grad -> mean_grad_out), (master_param -> master_param_out)

- op : rnn
  args : (Tensor x, Tensor[] pre_state, Tensor[] weight_list, Tensor sequence_length, Tensor dropout_state_in, float dropout_prob=0.0, bool is_bidirec=false, int input_size=10, int hidden_size=100, int num_layers=1, str mode="RNN_TANH", int seed=0, bool is_test=false)
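With the extended signature, the in-place eager op takes its arguments in exactly the order listed in args above. A minimal dygraph sketch of the call (a sketch only, assuming a build that contains this change; None is passed for the optional master_param when multi-precision is off):

import paddle
from paddle import _C_ops

n = 4
param = paddle.randn([n])
mean_square = paddle.zeros([n])
grad = paddle.randn([n])
moment = paddle.zeros([n])
lr = paddle.to_tensor([0.01], dtype='float32')
mean_grad = paddle.zeros([n])  # only meaningful when centered=True
master_param = None            # FP32 master copy when multi_precision=True

# (param, mean_square, grad, moment, learning_rate, mean_grad, master_param,
#  epsilon, decay, momentum, centered, multi_precision)
_C_ops.rmsprop_(param, mean_square, grad, moment, lr, mean_grad, master_param,
                1e-10, 0.95, 0.0, False, False)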
paddle/phi/infermeta/multiary.cc

@@ -2313,14 +2313,17 @@ void RmspropInferMeta(const MetaTensor& param,
                      const MetaTensor& moment,
                      const MetaTensor& learning_rate,
                      const MetaTensor& mean_grad,
+                     const MetaTensor& master_param,
                      float epsilon,
                      float decay,
                      float momentum,
                      bool centered,
+                     bool multi_precision,
                      MetaTensor* param_out,
                      MetaTensor* moment_out,
                      MetaTensor* mean_square_out,
-                     MetaTensor* mean_grad_out) {
+                     MetaTensor* mean_grad_out,
+                     MetaTensor* master_param_outs) {
  if (centered) {
    PADDLE_ENFORCE_NOT_NULL(
        mean_grad_out,
paddle/phi/infermeta/multiary.h

@@ -421,14 +421,17 @@ void RmspropInferMeta(const MetaTensor& param,
                      const MetaTensor& moment,
                      const MetaTensor& learning_rate,
                      const MetaTensor& mean_grad,
+                     const MetaTensor& master_param,
                      float epsilon,
                      float decay,
                      float momentum,
                      bool centered,
+                     bool multi_precision,
                      MetaTensor* param_out,
                      MetaTensor* moment_out,
                      MetaTensor* mean_square_out,
-                     MetaTensor* mean_grad_out);
+                     MetaTensor* mean_grad_out,
+                     MetaTensor* master_param_outs);

void RnnInferMeta(const MetaTensor& x,
                  const std::vector<const MetaTensor*>& pre_state,
paddle/phi/kernels/cpu/rmsprop_kernel.cc

@@ -17,7 +17,99 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/rmsprop_kernel_impl.h"

namespace phi {
template <typename T>
struct RmsFunctor<T, phi::CPUContext> {
  RmsFunctor(const phi::CPUContext &ctx,
             const DenseTensor &param,
             const DenseTensor &mean_square,
             const DenseTensor &grad,
             const DenseTensor &moment,
             const DenseTensor &learning_rate,
             const paddle::optional<DenseTensor> &mean_grad_opt,
             const paddle::optional<DenseTensor> &master_param,
             float epsilon_t,
             float decay_t,
             float momentum_t,
             bool centered,
             bool multi_precision,
             DenseTensor *param_out,
             DenseTensor *moment_out,
             DenseTensor *mean_square_out,
             DenseTensor *mean_grad_out,
             DenseTensor *master_param_outs) {
    auto epsilon = static_cast<T>(epsilon_t);
    auto rho = static_cast<T>(decay_t);
    auto momentum = static_cast<T>(momentum_t);

    auto &p_tensor = param;
    auto &ms_tensor = mean_square;
    auto &lr_tensor = learning_rate;
    auto &mom_tensor = moment;
    PADDLE_ENFORCE_EQ(p_tensor.IsSharedBufferWith(*param_out),
                      true,
                      phi::errors::InvalidArgument(
                          "Param and ParamOut must be the same Tensor"));
    PADDLE_ENFORCE_EQ(mom_tensor.IsSharedBufferWith(*moment_out),
                      true,
                      phi::errors::InvalidArgument(
                          "Moment and MomentOut must be the same Tensor"));
    PADDLE_ENFORCE_EQ(
        ms_tensor.IsSharedBufferWith(*mean_square_out),
        true,
        phi::errors::InvalidArgument(
            "MeanSquare and MeanSquareOut must be the same Tensor"));

    auto &grad_tensor = grad;
    auto &place = *ctx.eigen_device();
    auto lr_value = lr_tensor.data<T>()[0];

    auto p = EigenVector<T>::Flatten(p_tensor);
    auto ms = EigenVector<T>::Flatten(ms_tensor);
    auto g = EigenVector<T>::Flatten(grad_tensor);
    auto mom = EigenVector<T>::Flatten(mom_tensor);

    auto p_out = EigenVector<T>::Flatten(*param_out);
    auto mom_out = EigenVector<T>::Flatten(*moment_out);
    auto ms_out = EigenVector<T>::Flatten(*mean_square_out);

    ms_out.device(place) = rho * ms + (1 - rho) * g * g;

    if (centered) {
      auto mg_tensor = mean_grad_opt.get_ptr();
      if (mg_tensor) {
        PADDLE_ENFORCE_EQ(
            mg_tensor->Holder(),
            mean_grad_out->Holder(),
            phi::errors::InvalidArgument(
                "MeanGrad and MeanGradOut must be the same Tensor"));
      } else {
        PADDLE_ENFORCE_EQ(
            mg_tensor,
            mean_grad_out,
            phi::errors::InvalidArgument(
                "MeanGrad and MeanGradOut must be the same Tensor"));
      }

      auto mg = EigenVector<T>::Flatten(*mg_tensor);
      auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);

      mg_out.device(place) = rho * mg + (1 - rho) * g;
      mom_out.device(place) =
          momentum * mom +
          lr_value * g / (ms_out - mg_out.square() + epsilon).sqrt();
    } else {
      mom_out.device(place) =
          momentum * mom + lr_value * g / (ms_out + epsilon).sqrt();
    }
    p_out.device(place) = p - mom_out;
  }
};

template struct RmsFunctor<phi::GPUContext, float>;
template struct RmsFunctor<phi::GPUContext, double>;
template struct RmsFunctor<phi::GPUContext, phi::dtype::float16>;

}  // namespace phi

PD_REGISTER_KERNEL(
    rmsprop, CPU, ALL_LAYOUT, phi::RmspropDenseKernel, float, double) {}
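The Eigen expressions above are the usual (optionally centered) RMSProp recurrences. A NumPy reference of the same update, useful for checking a single step by hand (a sketch; rho corresponds to the op's decay attribute):

import numpy as np

def rmsprop_step(p, ms, mom, g, lr, rho=0.95, eps=1e-10, momentum=0.0,
                 centered=False, mg=None):
    """One RMSProp update, mirroring the CPU functor above."""
    ms_out = rho * ms + (1 - rho) * g * g
    if centered:
        mg_out = rho * mg + (1 - rho) * g
        mom_out = momentum * mom + lr * g / np.sqrt(ms_out - mg_out**2 + eps)
    else:
        mg_out = mg
        mom_out = momentum * mom + lr * g / np.sqrt(ms_out + eps)
    return p - mom_out, ms_out, mom_out, mg_out

# One step on a tiny vector.
p = np.ones(3, dtype=np.float32)
g = np.full(3, 0.1, dtype=np.float32)
p, ms, mom, _ = rmsprop_step(p, np.zeros(3, np.float32),
                             np.zeros(3, np.float32), g, lr=0.01)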
paddle/phi/kernels/gpu/rmsprop_kernel.cu

@@ -18,12 +18,99 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/rmsprop_kernel_impl.h"

-PD_REGISTER_KERNEL(rmsprop, GPU, ALL_LAYOUT, phi::RmspropDenseKernel, float, double) {}
+namespace phi {
+template <typename T>
+struct RmsFunctor<T, phi::GPUContext> {
+  RmsFunctor(const phi::GPUContext &ctx,
+             const DenseTensor &param,
+             const DenseTensor &mean_square,
+             const DenseTensor &grad,
+             const DenseTensor &moment,
+             const DenseTensor &learning_rate,
+             const paddle::optional<DenseTensor> &mean_grad_opt,
+             const paddle::optional<DenseTensor> &master_param,
+             float epsilon_t,
+             float decay_t,
+             float momentum_t,
+             bool centered,
+             bool multi_precision,
+             DenseTensor *param_out,
+             DenseTensor *moment_out,
+             DenseTensor *mean_square_out,
+             DenseTensor *mean_grad_out,
+             DenseTensor *master_param_outs) {
+    auto &p_tensor = param;
+    auto &ms_tensor = mean_square;
+    auto &lr_tensor = learning_rate;
+    auto &mom_tensor = moment;
+    auto &grad_tensor = grad;
+    size_t limit = static_cast<size_t>(ms_tensor.numel());
+    DenseRmspropGradFunctor<T> grad_func(grad_tensor.data<T>());
+    funcs::ForRange<phi::GPUContext> for_range(ctx, limit);
+    using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
+    MPDType *master_out_data =
+        multi_precision ? ctx.template Alloc<MPDType>(master_param_outs)
+                        : nullptr;
+    if (centered) {
+      auto mg_tensor = mean_grad_opt.get_ptr();
+      if (mg_tensor) {
+        PADDLE_ENFORCE_EQ(
+            mg_tensor->Holder(),
+            mean_grad_out->Holder(),
+            phi::errors::InvalidArgument(
+                "MeanGrad and MeanGradOut must be the same Tensor"));
+      } else {
+        PADDLE_ENFORCE_EQ(
+            mg_tensor,
+            mean_grad_out,
+            phi::errors::InvalidArgument(
+                "MeanGrad and MeanGradOut must be the same Tensor"));
+      }
+      for_range(CenteredRmspropFunctor<T, MPDType, DenseRmspropGradFunctor<T>>(
+          ctx.template Alloc<T>(param_out),
+          ctx.template Alloc<MPDType>(mean_square_out),
+          ctx.template Alloc<MPDType>(moment_out),
+          ctx.template Alloc<MPDType>(mean_grad_out),
+          lr_tensor.data<MPDType>(),
+          master_out_data,
+          static_cast<MPDType>(decay_t),
+          static_cast<MPDType>(epsilon_t),
+          static_cast<MPDType>(momentum_t),
+          grad_func));
+    } else {
+      for_range(UncenteredRmspropFunctor<T, MPDType, DenseRmspropGradFunctor<T>>(
+          ctx.template Alloc<T>(param_out),
+          ctx.template Alloc<MPDType>(mean_square_out),
+          ctx.template Alloc<MPDType>(moment_out),
+          lr_tensor.data<MPDType>(),
+          master_out_data,
+          static_cast<MPDType>(decay_t),
+          static_cast<MPDType>(epsilon_t),
+          static_cast<MPDType>(momentum_t),
+          grad_func));
+    }
+  }
+};
+
+template struct RmsFunctor<phi::GPUContext, float>;
+template struct RmsFunctor<phi::GPUContext, double>;
+template struct RmsFunctor<phi::GPUContext, phi::dtype::float16>;
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(rmsprop,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::RmspropDenseKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}

PD_REGISTER_KERNEL(rmsprop_dense_param_sparse_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::RmspropSparseKernel,
                   float,
-                   double) {}
+                   double,
+                   phi::dtype::float16) {}
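For float16 the GPU functors now do the accumulator arithmetic in MPDType (float32, via MPTypeTrait<T>) and, when multi_precision is enabled, read and update an FP32 master copy while writing a rounded FP16 value back into the parameter. A NumPy sketch of that pattern for the uncentered case (illustrative only, not a Paddle API):

import numpy as np

def mp_uncentered_step(param_fp16, master_fp32, ms, mom, g_fp16,
                       lr=0.01, rho=0.95, eps=1e-10, momentum=0.0):
    # All state math happens in float32 (the "MPDType").
    g = g_fp16.astype(np.float32)
    ms[:] = rho * ms + (1 - rho) * g * g
    mom[:] = momentum * mom + lr * g / np.sqrt(ms + eps)
    # Prefer the FP32 master weight; fall back to casting the FP16 parameter.
    p = master_fp32 if master_fp32 is not None else param_fp16.astype(np.float32)
    p = p - mom
    param_fp16[:] = p.astype(np.float16)  # rounded copy kept for the model
    if master_fp32 is not None:
        master_fp32[:] = p                # full-precision copy for the optimizer

param = np.ones(4, np.float16)
master = param.astype(np.float32)
ms, mom = np.zeros(4, np.float32), np.zeros(4, np.float32)
mp_uncentered_step(param, master, ms, mom, np.full(4, 0.1, np.float16))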
paddle/phi/kernels/impl/rmsprop_kernel_impl.h

@@ -16,14 +16,36 @@
#include <math.h>

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/algorithm.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/selected_rows_functor.h"
#include "paddle/phi/kernels/rmsprop_kernel.h"

namespace phi {

template <typename T, typename Context>
struct RmsFunctor {
  RmsFunctor(const Context &ctx,
             const DenseTensor &param,
             const DenseTensor &mean_square,
             const DenseTensor &grad,
             const DenseTensor &moment,
             const DenseTensor &learning_rate,
             const paddle::optional<DenseTensor> &mean_grad_opt,
             const paddle::optional<DenseTensor> &master_param,
             float epsilon_t,
             float decay_t,
             float momentum_t,
             bool centered,
             bool multi_precision,
             DenseTensor *param_out,
             DenseTensor *moment_out,
             DenseTensor *mean_square_out,
             DenseTensor *mean_grad_out,
             DenseTensor *master_param_outs);
};

template <typename T>
struct DenseRmspropGradFunctor {
  inline explicit DenseRmspropGradFunctor(const T *grad) : grad_(grad) {}

@@ -47,7 +69,8 @@ struct SparseRmspropGradFunctor {
  HOSTDEVICE inline T operator()(int64_t idx) const {
    auto row_idx =
        phi::funcs::BinarySearch(rows_, row_count_, idx / row_numel_);
-    return row_idx >= 0 ? grad_[row_idx * row_numel_ + idx % row_numel_] : 0;
+    return row_idx >= 0 ? grad_[row_idx * row_numel_ + idx % row_numel_]
+                        : static_cast<T>(0);
  }

  const T *grad_;

@@ -56,19 +79,21 @@ struct SparseRmspropGradFunctor {
  int64_t row_count_;
};

-template <typename T, typename GradFunctor>
+template <typename T, typename MT, typename GradFunctor>
struct UncenteredRmspropFunctor {
  UncenteredRmspropFunctor(T *param,
-                           T *ms,
-                           T *mom,
-                           const T *lr,
-                           T rho,
-                           T epsilon,
-                           T momentum,
+                           MT *ms,
+                           MT *mom,
+                           const MT *lr,
+                           MT *master_p,
+                           MT rho,
+                           MT epsilon,
+                           MT momentum,
                           const GradFunctor &grad_functor)
      : param_(param),
        ms_(ms),
        mom_(mom),
+        master_p_(master_p),
        lr_(lr),
        rho_(rho),
        epsilon_(epsilon),

@@ -76,38 +101,46 @@ struct UncenteredRmspropFunctor {
        grad_functor_(grad_functor) {}

  HOSTDEVICE inline void operator()(int64_t idx) const {
-    T g = grad_functor_(idx);
-    T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g;
-    T mom_out = momentum_ * mom_[idx] + lr_[0] * g / sqrt(ms_out + epsilon_);
-    param_[idx] -= mom_out;
+    MT g = static_cast<MT>(grad_functor_(idx));
+    MT l_rho = static_cast<MT>(1) - rho_;
+    MT ms_out = rho_ * ms_[idx] + l_rho * g * g;
+    MT mom_out = momentum_ * mom_[idx] +
+                 static_cast<MT>(lr_[0]) * g / sqrt(ms_out + epsilon_);
+    MT p = master_p_ ? master_p_[idx] : static_cast<MT>(param_[idx]);
+    MT p_m = p - mom_out;
+    param_[idx] = static_cast<T>(p_m);
    ms_[idx] = ms_out;
    mom_[idx] = mom_out;
+    if (master_p_) master_p_[idx] = p_m;
  }

  T *param_;
-  T *ms_;
-  T *mom_;
-  const T *lr_;
-  T rho_;
-  T epsilon_;
-  T momentum_;
+  MT *ms_;
+  MT *mom_;
+  MT *master_p_;
+  const MT *lr_;
+  MT rho_;
+  MT epsilon_;
+  MT momentum_;
  GradFunctor grad_functor_;
};

-template <typename T, typename GradFunctor>
+template <typename T, typename MT, typename GradFunctor>
struct CenteredRmspropFunctor {
  CenteredRmspropFunctor(T *param,
-                         T *ms,
-                         T *mom,
-                         T *mean_grad,
-                         const T *lr,
-                         T rho,
-                         T epsilon,
-                         T momentum,
+                         MT *ms,
+                         MT *mom,
+                         MT *mean_grad,
+                         const MT *lr,
+                         MT *master_param,
+                         MT rho,
+                         MT epsilon,
+                         MT momentum,
                         const GradFunctor &grad_functor)
      : param_(param),
        ms_(ms),
        mom_(mom),
+        master_p_(master_param),
        mean_grad_(mean_grad),
        lr_(lr),
        rho_(rho),

@@ -116,25 +149,32 @@ struct CenteredRmspropFunctor {
        grad_functor_(grad_functor) {}

  HOSTDEVICE inline void operator()(int64_t idx) const {
-    T g = grad_functor_(idx);
-    T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g;
-    T mg_out = rho_ * mean_grad_[idx] + (1 - rho_) * g;
-    T mom_out = momentum_ * mom_[idx] +
-                lr_[0] * g / sqrt(ms_out - mg_out * mg_out + epsilon_);
-    param_[idx] -= mom_out;
+    MT g = static_cast<MT>(grad_functor_(idx));
+    MT l_rho = static_cast<MT>(1) - rho_;
+    MT ms_out = rho_ * ms_[idx] + l_rho * g * g;
+    MT mg_out = rho_ * mean_grad_[idx] + l_rho * g;
+    MT mom_out =
+        momentum_ * mom_[idx] +
+        static_cast<MT>(lr_[0]) * g / sqrt(ms_out - mg_out * mg_out + epsilon_);
+    MT p = master_p_ ? master_p_[idx] : static_cast<MT>(param_[idx]);
+    MT p_m = p - mom_out;
+    param_[idx] = static_cast<T>(p_m);
    ms_[idx] = ms_out;
    mom_[idx] = mom_out;
    mean_grad_[idx] = mg_out;
+    if (master_p_) master_p_[idx] = p_m;
  }

  T *param_;
-  T *ms_;
-  T *mom_;
-  T *mean_grad_;
-  const T *lr_;
-  T rho_;
-  T epsilon_;
-  T momentum_;
+  MT *ms_;
+  MT *mom_;
+  MT *master_p_;
+  MT *mean_grad_;
+  const MT *lr_;
+  MT rho_;
+  MT epsilon_;
+  MT momentum_;
  GradFunctor grad_functor_;
};

@@ -146,120 +186,35 @@ void RmspropDenseKernel(const Context &ctx,
                        const DenseTensor &moment,
                        const DenseTensor &learning_rate,
                        const paddle::optional<DenseTensor> &mean_grad_opt,
+                        const paddle::optional<DenseTensor> &master_param,
                        float epsilon_t,
                        float decay_t,
                        float momentum_t,
                        bool centered,
+                        bool multi_precision,
                        DenseTensor *param_out,
                        DenseTensor *moment_out,
                        DenseTensor *mean_square_out,
-                        DenseTensor *mean_grad_out) {
-  auto epsilon = static_cast<T>(epsilon_t);
-  auto rho = static_cast<T>(decay_t);
-  auto momentum = static_cast<T>(momentum_t);
-
-  auto &p_tensor = param;
-  auto &ms_tensor = mean_square;
-  auto &lr_tensor = learning_rate;
-  auto &mom_tensor = moment;
-
-  PADDLE_ENFORCE_EQ(p_tensor.IsSharedBufferWith(*param_out),
-                    true,
-                    phi::errors::InvalidArgument(
-                        "Param and ParamOut must be the same Tensor"));
-  PADDLE_ENFORCE_EQ(mom_tensor.IsSharedBufferWith(*moment_out),
-                    true,
-                    phi::errors::InvalidArgument(
-                        "Moment and MomentOut must be the same Tensor"));
-  PADDLE_ENFORCE_EQ(ms_tensor.IsSharedBufferWith(*mean_square_out),
-                    true,
-                    phi::errors::InvalidArgument(
-                        "MeanSquare and MeanSquareOut must be the same Tensor"));
-
-  size_t limit = static_cast<size_t>(ms_tensor.numel());
-  auto &grad_tensor = grad;
-
-  if (paddle::platform::is_cpu_place(ctx.GetPlace())) {
-    auto &place = *ctx.eigen_device();
-    auto lr_value = lr_tensor.data<T>()[0];
-
-    auto p = EigenVector<T>::Flatten(p_tensor);
-    auto ms = EigenVector<T>::Flatten(ms_tensor);
-    auto g = EigenVector<T>::Flatten(grad_tensor);
-    auto mom = EigenVector<T>::Flatten(mom_tensor);
-
-    auto p_out = EigenVector<T>::Flatten(*param_out);
-    auto mom_out = EigenVector<T>::Flatten(*moment_out);
-    auto ms_out = EigenVector<T>::Flatten(*mean_square_out);
-
-    ms_out.device(place) = rho * ms + (1 - rho) * g * g;
-    if (centered) {
-      auto mg_tensor = mean_grad_opt.get_ptr();
-      auto mg = EigenVector<T>::Flatten(*mg_tensor);
-      if (mg_tensor) {
-        PADDLE_ENFORCE_EQ(
-            mg_tensor->Holder(),
-            mean_grad_out->Holder(),
-            phi::errors::InvalidArgument(
-                "MeanGrad and MeanGradOut must be the same Tensor"));
-      } else {
-        PADDLE_ENFORCE_EQ(
-            mg_tensor,
-            mean_grad_out,
-            phi::errors::InvalidArgument(
-                "MeanGrad and MeanGradOut must be the same Tensor"));
-      }
-      auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
-
-      mg_out.device(place) = rho * mg + (1 - rho) * g;
-      mom_out.device(place) =
-          momentum * mom +
-          lr_value * g / (ms_out - mg_out.square() + epsilon).sqrt();
-    } else {
-      mom_out.device(place) =
-          momentum * mom + lr_value * g / (ms_out + epsilon).sqrt();
-    }
-    p_out.device(place) = p - mom_out;
-  } else {
-    DenseRmspropGradFunctor<T> grad_func(grad_tensor.data<T>());
-    funcs::ForRange<Context> for_range(ctx, limit);
-    if (centered) {
-      auto mg_tensor = mean_grad_opt.get_ptr();
-      if (mg_tensor) {
-        PADDLE_ENFORCE_EQ(
-            mg_tensor->Holder(),
-            mean_grad_out->Holder(),
-            phi::errors::InvalidArgument(
-                "MeanGrad and MeanGradOut must be the same Tensor"));
-      } else {
-        PADDLE_ENFORCE_EQ(
-            mg_tensor,
-            mean_grad_out,
-            phi::errors::InvalidArgument(
-                "MeanGrad and MeanGradOut must be the same Tensor"));
-      }
-
-      for_range(CenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
-          ctx.template Alloc<T>(param_out),
-          ctx.template Alloc<T>(mean_square_out),
-          ctx.template Alloc<T>(moment_out),
-          ctx.template Alloc<T>(mean_grad_out),
-          lr_tensor.data<T>(),
-          rho,
-          epsilon,
-          momentum,
-          grad_func));
-    } else {
-      for_range(UncenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
-          ctx.template Alloc<T>(param_out),
-          ctx.template Alloc<T>(mean_square_out),
-          ctx.template Alloc<T>(moment_out),
-          lr_tensor.data<T>(),
-          rho,
-          epsilon,
-          momentum,
-          grad_func));
-    }
-  }
+                        DenseTensor *mean_grad_out,
+                        DenseTensor *master_param_outs) {
+  RmsFunctor<T, Context> functor(ctx,
+                                 param,
+                                 mean_square,
+                                 grad,
+                                 moment,
+                                 learning_rate,
+                                 mean_grad_opt,
+                                 master_param,
+                                 epsilon_t,
+                                 decay_t,
+                                 momentum_t,
+                                 centered,
+                                 multi_precision,
+                                 param_out,
+                                 moment_out,
+                                 mean_square_out,
+                                 mean_grad_out,
+                                 master_param_outs);
}

template <typename T, typename Context>

@@ -270,17 +225,21 @@ void RmspropSparseKernel(const Context &ctx,
                         const DenseTensor &moment,
                         const DenseTensor &learning_rate,
                         const paddle::optional<DenseTensor> &mean_grad_opt,
+                         const paddle::optional<DenseTensor> &master_param,
                         float epsilon_t,
                         float decay_t,
                         float momentum_t,
                         bool centered,
+                         bool multi_precision,
                         DenseTensor *param_out,
                         DenseTensor *moment_out,
                         DenseTensor *mean_square_out,
-                         DenseTensor *mean_grad_out) {
-  auto epsilon = static_cast<T>(epsilon_t);
-  auto rho = static_cast<T>(decay_t);
-  auto momentum = static_cast<T>(momentum_t);
+                         DenseTensor *mean_grad_out,
+                         DenseTensor *master_param_outs) {
+  using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
+  auto epsilon = static_cast<MPDType>(epsilon_t);
+  auto rho = static_cast<MPDType>(decay_t);
+  auto momentum = static_cast<MPDType>(momentum_t);

  auto &p_tensor = param;
  auto &ms_tensor = mean_square;

@@ -318,6 +277,10 @@ void RmspropSparseKernel(const Context &ctx,
  SparseRmspropGradFunctor<T> grad_func(
      merged_tensor.data<T>(), rows, row_numel, row_count);

+  MPDType *master_out_data =
+      multi_precision ? ctx.template Alloc<MPDType>(master_param_outs)
+                      : nullptr;
+
  if (centered) {
    auto mg_tensor = mean_grad_opt.get_ptr();
    if (mg_tensor) {

@@ -334,22 +297,24 @@ void RmspropSparseKernel(const Context &ctx,
          "MeanGrad and MeanGradOut must be the same Tensor"));
    }

-    for_range(CenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
+    for_range(CenteredRmspropFunctor<T, MPDType, SparseRmspropGradFunctor<T>>(
        ctx.template Alloc<T>(param_out),
-        ctx.template Alloc<T>(mean_square_out),
-        ctx.template Alloc<T>(moment_out),
-        ctx.template Alloc<T>(mean_grad_out),
-        lr_tensor.data<T>(),
+        ctx.template Alloc<MPDType>(mean_square_out),
+        ctx.template Alloc<MPDType>(moment_out),
+        ctx.template Alloc<MPDType>(mean_grad_out),
+        lr_tensor.data<MPDType>(),
+        master_out_data,
        rho,
        epsilon,
        momentum,
        grad_func));
  } else {
-    for_range(UncenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
+    for_range(UncenteredRmspropFunctor<T, MPDType, SparseRmspropGradFunctor<T>>(
        ctx.template Alloc<T>(param_out),
-        ctx.template Alloc<T>(mean_square_out),
-        ctx.template Alloc<T>(moment_out),
-        lr_tensor.data<T>(),
+        ctx.template Alloc<MPDType>(mean_square_out),
+        ctx.template Alloc<MPDType>(moment_out),
+        lr_tensor.data<MPDType>(),
+        master_out_data,
        rho,
        epsilon,
        momentum,
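SparseRmspropGradFunctor lets the same element-wise functors run over a SelectedRows gradient: for every flat element index it binary-searches the stored row ids and yields zero for rows the sparse gradient does not contain (now returned as static_cast<T>(0) so half-precision types work too). A small Python sketch of that lookup (a hypothetical helper, not part of Paddle):

import bisect

def sparse_grad_at(flat_idx, rows, grad_rows, row_numel):
    """Value of a SelectedRows gradient at a dense flat index.

    rows      -- sorted row ids present in the (merged) sparse gradient
    grad_rows -- one list of length row_numel per entry in rows
    """
    dense_row = flat_idx // row_numel
    pos = bisect.bisect_left(rows, dense_row)
    if pos < len(rows) and rows[pos] == dense_row:
        return grad_rows[pos][flat_idx % row_numel]
    return 0  # rows absent from the sparse gradient contribute nothing

rows = [1, 3]
grad_rows = [[0.5, 0.5], [1.0, 1.0]]
print(sparse_grad_at(0, rows, grad_rows, row_numel=2))  # 0   (row 0 missing)
print(sparse_grad_at(6, rows, grad_rows, row_numel=2))  # 1.0 (row 3, column 0)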
paddle/phi/kernels/rmsprop_kernel.h

@@ -27,14 +27,17 @@ void RmspropDenseKernel(const Context& dev_ctx,
                        const DenseTensor& moment,
                        const DenseTensor& learning_rate,
                        const paddle::optional<DenseTensor>& mean_grad,
+                        const paddle::optional<DenseTensor>& master_param,
                        float epsilon,
                        float decay,
                        float momentum,
                        bool centered,
+                        bool multi_precision,
                        DenseTensor* param_out,
                        DenseTensor* moment_out,
                        DenseTensor* mean_square_out,
-                        DenseTensor* mean_grad_out);
+                        DenseTensor* mean_grad_out,
+                        DenseTensor* master_param_outs);

template <typename T, typename Context>
void RmspropSparseKernel(const Context& dev_ctx,

@@ -44,13 +47,16 @@ void RmspropSparseKernel(const Context& dev_ctx,
                         const DenseTensor& moment,
                         const DenseTensor& learning_rate,
                         const paddle::optional<DenseTensor>& mean_grad,
+                         const paddle::optional<DenseTensor>& master_param,
                         float epsilon,
                         float decay,
                         float momentum,
                         bool centered,
+                         bool multi_precision,
                         DenseTensor* param_out,
                         DenseTensor* moment_out,
                         DenseTensor* mean_square_out,
-                         DenseTensor* mean_grad_out);
+                         DenseTensor* mean_grad_out,
+                         DenseTensor* master_param_outs);

}  // namespace phi
paddle/phi/kernels/xpu/rmsprop_kernel.cc

@@ -29,14 +29,17 @@ void RmspropDenseKernel(const Context& dev_ctx,
                        const DenseTensor& moment,
                        const DenseTensor& learning_rate,
                        const paddle::optional<DenseTensor>& mean_grad,
+                        const paddle::optional<DenseTensor>& master_param,
                        float epsilon,
                        float decay,
                        float momentum,
                        bool centered,
+                        bool multi_precision,
                        DenseTensor* param_out,
                        DenseTensor* moment_out,
                        DenseTensor* mean_square_out,
-                        DenseTensor* mean_grad_out) {
+                        DenseTensor* mean_grad_out,
+                        DenseTensor* master_param_outs) {
  // copy learning_rate to cpu
  PADDLE_ENFORCE_EQ(
      learning_rate.dims().size(),
paddle/phi/ops/compat/rmsprop_sig.cc

@@ -20,15 +20,35 @@ KernelSignature RmspropOpArgumentMapping(const ArgumentMappingContext& ctx) {
  if (ctx.IsDenseTensorInput("Grad")) {
    return KernelSignature(
        "rmsprop",
-       {"Param", "MeanSquare", "Grad", "Moment", "LearningRate", "MeanGrad"},
-       {"epsilon", "decay", "momentum", "centered"},
-       {"ParamOut", "MomentOut", "MeanSquareOut", "MeanGradOut"});
+       {"Param",
+        "MeanSquare",
+        "Grad",
+        "Moment",
+        "LearningRate",
+        "MeanGrad",
+        "MasterParam"},
+       {"epsilon", "decay", "momentum", "centered", "multi_precision"},
+       {"ParamOut",
+        "MomentOut",
+        "MeanSquareOut",
+        "MeanGradOut",
+        "MasterParamOut"});
  } else if (ctx.IsSelectedRowsInput("Grad")) {
    return KernelSignature(
        "rmsprop_dense_param_sparse_grad",
-       {"Param", "MeanSquare", "Grad", "Moment", "LearningRate", "MeanGrad"},
-       {"epsilon", "decay", "momentum", "centered"},
-       {"ParamOut", "MomentOut", "MeanSquareOut", "MeanGradOut"});
+       {"Param",
+        "MeanSquare",
+        "Grad",
+        "Moment",
+        "LearningRate",
+        "MeanGrad",
+        "MasterParam"},
+       {"epsilon", "decay", "momentum", "centered", "multi_precision"},
+       {"ParamOut",
+        "MomentOut",
+        "MeanSquareOut",
+        "MeanGradOut",
+        "MasterParamOut"});
  }

  return KernelSignature("unregistered", {}, {}, {});
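The compat mapping only renames the fluid op's inputs, attributes, and outputs into the positional argument lists of the phi kernels. The same mapping written out as plain Python data, copied from the code above for quick reference:

RMSPROP_SIGNATURE = {
    "inputs": ["Param", "MeanSquare", "Grad", "Moment",
               "LearningRate", "MeanGrad", "MasterParam"],
    "attrs": ["epsilon", "decay", "momentum", "centered", "multi_precision"],
    "outputs": ["ParamOut", "MomentOut", "MeanSquareOut",
                "MeanGradOut", "MasterParamOut"],
}
# The SelectedRows branch maps to "rmsprop_dense_param_sparse_grad"
# with the same three argument lists.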
python/paddle/fluid/optimizer.py

@@ -3287,12 +3287,84 @@ class RMSPropOptimizer(Optimizer):
        self._epsilon = epsilon
        self._momentum = momentum
        self._centered = centered
        self._multi_precision = False
        self._master_weights = {}

    def _create_master_weight(self, param):
        if param.name in self._master_weights:
            var = self._master_weights[param.name]
        else:
            assert isinstance(self.helper, LayerHelper)

            var_name = param.name + '_fp32_master'
            var_name = unique_name.generate(var_name)
            var = paddle.static.create_global_var(
                name=var_name,
                shape=param.shape,
                value=0,
                dtype='float32',
                persistable=True,
            )
            block = self.helper.startup_program.global_block()
            block.append_op(
                type="cast",
                inputs={"X": [param]},
                outputs={"Out": [var]},
                attrs={
                    "in_dtype": param.dtype,
                    "out_dtype": core.VarDesc.VarType.FP32,
                },
            )
            self._master_weights[param.name] = var
        return var

    def _get_accumulator(self, name, param):
        """Utility function to fetch an accumulator for a parameter

        Args:
            name: name of the accumulator
            param: parameter variable for which accumulator is to be fetched

        Returns:
            accumulator variable for the parameter
        """
        if self._name is not None:
            name = self._name + "_" + name
        find_master = (
            self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
        )
        target_param = (
            self._master_weights[param.name] if find_master else param
        )
        target_name = target_param.name
        if (
            name not in self._accumulators
            or target_name not in self._accumulators[name]
        ):
            raise Exception(
                "Accumulator {} does not exist for parameter {}".format(
                    name, target_name
                )
            )
        return self._accumulators[name][target_name]

    def _create_accumulators(self, block, parameters):
        if not isinstance(block, framework.Block):
            raise TypeError("block is not instance of framework.Block.")

        for p in parameters:
            if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
                master_p = self._create_master_weight(p)
                self._add_accumulator(self._momentum_acc_str, master_p)
                self._add_accumulator(self._mean_square_acc_str, master_p)
                self._add_accumulator(self._mean_grad_acc_str, master_p)
                continue
            if (
                p.dtype == core.VarDesc.VarType.FP16
                and not self._multi_precision
            ):
                warnings.warn(
                    "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
                    "Consider using multi_precision=True option of the Lars optimizer."
                )
            self._add_accumulator(self._momentum_acc_str, p)
            self._add_accumulator(self._mean_square_acc_str, p)
            self._add_accumulator(self._mean_grad_acc_str, p)

@@ -3310,6 +3382,15 @@ class RMSPropOptimizer(Optimizer):
        mean_grad_acc = self._get_accumulator(
            self._mean_grad_acc_str, param_and_grad[0]
        )
+        find_master = (
+            self._multi_precision
+            and param_and_grad[0].dtype == core.VarDesc.VarType.FP16
+        )
+        master_weight = (
+            self._master_weights[param_and_grad[0].name]
+            if find_master
+            else None
+        )
        if in_dygraph_mode():
            _C_ops.rmsprop_(
                param_and_grad[0],

@@ -3318,34 +3399,45 @@ class RMSPropOptimizer(Optimizer):
                momentum_acc,
                self._create_param_lr(param_and_grad),
                mean_grad_acc,
+                master_weight,
                self._epsilon,
                self._rho,
                self._momentum,
                self._centered,
+                find_master,
            )
            return None
        else:
+            inputs = {
+                "Param": param_and_grad[0],
+                "Grad": param_and_grad[1],
+                "Moment": momentum_acc,
+                "MeanSquare": mean_square_acc,
+                "MeanGrad": mean_grad_acc,
+                "LearningRate": self._create_param_lr(param_and_grad),
+            }
+
+            outputs = {
+                "ParamOut": param_and_grad[0],
+                "MomentOut": momentum_acc,
+                "MeanSquareOut": mean_square_acc,
+                "MeanGradOut": mean_grad_acc,
+            }
+
+            if find_master:
+                inputs["MasterParam"] = master_weight
+                outputs["MasterParamOut"] = master_weight
            rmsprop_op = block.append_op(
                type=self.type,
-                inputs={
-                    "Param": param_and_grad[0],
-                    "Grad": param_and_grad[1],
-                    "Moment": momentum_acc,
-                    "MeanSquare": mean_square_acc,
-                    "MeanGrad": mean_grad_acc,
-                    "LearningRate": self._create_param_lr(param_and_grad),
-                },
-                outputs={
-                    "ParamOut": param_and_grad[0],
-                    "MomentOut": momentum_acc,
-                    "MeanSquareOut": mean_square_acc,
-                    "MeanGradOut": mean_grad_acc,
-                },
+                inputs=inputs,
+                outputs=outputs,
                attrs={
                    "epsilon": self._epsilon,
                    "decay": self._rho,
                    "momentum": self._momentum,
                    "centered": self._centered,
+                    "multi_precision": find_master,
                },
                stop_gradient=True,
            )
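_create_master_weight builds one persistable FP32 global variable per FP16 parameter (named from the parameter via unique_name.generate(param.name + '_fp32_master')) and initializes it with a cast op in the startup program; accumulators are then keyed by the master weight instead of the raw parameter. A stripped-down sketch of the same bookkeeping, with NumPy standing in for the variables (illustrative only):

import numpy as np

class MasterWeightStore:
    """Caches one FP32 master copy per FP16 parameter, keyed by name."""

    def __init__(self):
        self._master_weights = {}

    def get_or_create(self, name, param_fp16):
        if name not in self._master_weights:
            # Mirrors the startup-program "cast" op: FP16 -> FP32, done once.
            self._master_weights[name] = param_fp16.astype(np.float32)
        return self._master_weights[name]

store = MasterWeightStore()
w = np.ones(3, np.float16)
master = store.get_or_create("linear_0.w_0", w)          # created and cached
assert master is store.get_or_create("linear_0.w_0", w)  # reused afterwards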
python/paddle/fluid/tests/unittests/test_rmsprop_op.py

@@ -356,6 +356,280 @@ class TestRMSPropV2Group(TestRMSPropV2):
        adam.clear_gradients()


class TestRMSOpMultiPrecison(unittest.TestCase):
    def _test_rms_op_dygraph_place_amp(self, place, use_amp=False):
        import paddle

        paddle.disable_static()
        paddle.seed(10)
        paddle.set_device(place)

        input = paddle.randn((5, 5))

        model = paddle.nn.Linear(5, 5)

        optimizer = paddle.optimizer.RMSProp(
            learning_rate=0.01,
            parameters=model.parameters(),
            weight_decay=0.01,
        )
        optimizer._multi_precision = use_amp
        for idx in range(2):
            if place == 'gpu' and use_amp:
                model = paddle.amp.decorate(models=model, level='O2')
                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

            if place == 'gpu' and use_amp:
                with paddle.amp.auto_cast(level='O2'):
                    output = model(input)
                    loss = paddle.mean(output)
                scaled = scaler.scale(loss)
                scaled.backward()
                scaler.step(optimizer)
                optimizer.clear_grad()
            else:
                output = model(input)
                loss = paddle.mean(output)
                loss.backward()
                optimizer.step()
                optimizer.clear_grad()
        paddle.enable_static()

    def _get_places(self):
        import paddle

        places = ['cpu']
        if paddle.is_compiled_with_cuda():
            places.append('gpu')
        return places

    def test_main(self):
        for place in self._get_places():
            use_amp_list = [True, False]
            for use_amp in use_amp_list:
                self._test_rms_op_dygraph_place_amp(place, use_amp)


class TestRMSPropMultiPrecision2_0(unittest.TestCase):
    def dygraph_rmsprop_mp(self, mp, use_amp):
        paddle.disable_static()
        paddle.seed(100)
        paddle.set_device('gpu')
        input = paddle.randn((2, 2))
        model = paddle.nn.Linear(2, 2)
        optimizer = paddle.optimizer.RMSProp(0.5, parameters=model.parameters())
        optimizer._multi_precision = mp
        if use_amp:
            model = paddle.amp.decorate(models=model, level='O2')
            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

        for idx in range(5):
            if use_amp:
                with paddle.amp.auto_cast(level='O2'):
                    output = model(input)
                    loss = paddle.mean(output)
                scaled = scaler.scale(loss)
                scaled.backward()
                scaler.minimize(optimizer, scaled)
                optimizer.clear_grad()
            else:
                output = model(input)
                loss = paddle.mean(output)
                loss.backward()
                optimizer.step()
                optimizer.clear_grad()

        return output, model.parameters()

    def static_rmsprop_mp(self, mp, use_amp):
        paddle.enable_static()
        paddle.seed(100)
        np.random.seed(100)
        exe = paddle.static.Executor('gpu')
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        optimizer = paddle.optimizer.RMSProp(0.1)
        optimizer._multi_precision = mp

        if use_amp:
            optimizer = paddle.static.amp.decorate(
                optimizer,
                init_loss_scaling=128.0,
                use_dynamic_loss_scaling=True,
                use_pure_fp16=True,
                use_fp16_guard=False,
            )
        with paddle.static.program_guard(train_program, startup_program):
            if use_amp:
                data = paddle.static.data(
                    shape=[2, 2], name='X', dtype='float16'
                )
            else:
                data = paddle.static.data(
                    shape=[2, 2], name='X', dtype='float32'
                )
            hidden = paddle.static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            optimizer.minimize(loss)
        exe.run(startup_program)

        if use_amp:
            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
            x = np.random.random(size=(2, 2)).astype('float16')
        else:
            x = np.random.random(size=(2, 2)).astype('float32')
        out = []
        for idx in range(5):
            (loss_data,) = exe.run(
                train_program, feed={"X": x}, fetch_list=[loss.name]
            )
            out.append(loss_data)
        return out

    def test_main(self):
        if not paddle.is_compiled_with_cuda():
            return
        "Test dygraph mode"
        output1_dy, params1_dy = self.dygraph_rmsprop_mp(use_amp=True, mp=True)
        output2_dy, params2_dy = self.dygraph_rmsprop_mp(
            use_amp=False, mp=False
        )
        np.testing.assert_allclose(
            output1_dy.astype('float32').numpy(),
            output2_dy.astype('float32').numpy(),
            rtol=1e-05,
            atol=0.1,
        )
        for idx in range(len(params1_dy)):
            np.testing.assert_allclose(
                params1_dy[idx].astype('float32').numpy(),
                params2_dy[idx].astype('float32').numpy(),
                rtol=1e-05,
                atol=0.1,
            )
        "Test static mode"
        output1_st = self.static_rmsprop_mp(use_amp=True, mp=True)
        output2_st = self.static_rmsprop_mp(use_amp=False, mp=False)
        for idx in range(len(output1_st)):
            np.testing.assert_allclose(
                output1_st[idx].astype('float32'),
                output2_st[idx].astype('float32'),
                rtol=1e-05,
                atol=0.1,
            )


class TestRMSPropMultiPrecision1_0(unittest.TestCase):
    def dygraph_rmsprop_mp(self, use_amp, mp):
        paddle.disable_static()
        paddle.seed(10)
        paddle.set_device('gpu')
        input = paddle.randn((2, 2))
        model = paddle.nn.Linear(2, 2)
        optimizer = paddle.fluid.optimizer.RMSProp(
            learning_rate=0.001,
            parameter_list=model.parameters(),
        )
        optimizer._multi_precision = mp
        if use_amp:
            model = paddle.amp.decorate(models=model, level='O2')
            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

        for idx in range(5):
            if use_amp:
                with paddle.amp.auto_cast(level='O2'):
                    output = model(input)
                    loss = paddle.mean(output)
                scaled = scaler.scale(loss)
                scaled.backward()
                scaler.minimize(optimizer, scaled)
                optimizer.clear_gradients()
            else:
                output = model(input)
                loss = paddle.mean(output)
                optimizer.minimize(loss)
                optimizer.clear_gradients()

        return output, model.parameters()

    def static_rmsprop_mp(self, use_amp, mp):
        paddle.enable_static()
        paddle.seed(100)
        np.random.seed(100)
        exe = paddle.static.Executor('gpu')
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        optimizer = paddle.fluid.optimizer.RMSProp(learning_rate=0.001)
        optimizer._multi_precision = mp

        if use_amp:
            optimizer = paddle.static.amp.decorate(
                optimizer,
                init_loss_scaling=128.0,
                use_dynamic_loss_scaling=True,
                use_pure_fp16=True,
                use_fp16_guard=False,
            )
        with paddle.static.program_guard(train_program, startup_program):
            if use_amp:
                data = paddle.static.data(
                    shape=[2, 2], name='X', dtype='float16'
                )
            else:
                data = paddle.static.data(
                    shape=[2, 2], name='X', dtype='float32'
                )
            hidden = paddle.static.nn.fc(x=data, size=10)
            loss = paddle.mean(hidden)
            optimizer.minimize(loss)
        exe.run(startup_program)

        if use_amp:
            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
            x = np.random.random(size=(2, 2)).astype('float16')
        else:
            x = np.random.random(size=(2, 2)).astype('float32')
        out = []
        for idx in range(5):
            (loss_data,) = exe.run(
                train_program, feed={"X": x}, fetch_list=[loss.name]
            )
            out.append(loss_data)
        return out

    def test_main(self):
        if not paddle.is_compiled_with_cuda():
            return
        "Test dygraph mode"
        output1_dy, params1_dy = self.dygraph_rmsprop_mp(use_amp=True, mp=True)
        output2_dy, params2_dy = self.dygraph_rmsprop_mp(
            use_amp=False, mp=False
        )
        np.testing.assert_allclose(
            output1_dy.astype('float32').numpy(),
            output2_dy.astype('float32').numpy(),
            rtol=1e-05,
            atol=0.1,
        )
        for idx in range(len(params1_dy)):
            np.testing.assert_allclose(
                params1_dy[idx].astype('float32').numpy(),
                params2_dy[idx].astype('float32').numpy(),
                rtol=1e-05,
                atol=0.1,
            )
        "Test static mode"
        output1_st = self.static_rmsprop_mp(use_amp=True, mp=True)
        output2_st = self.static_rmsprop_mp(use_amp=False, mp=False)
        for idx in range(len(output1_st)):
            np.testing.assert_allclose(
                output1_st[idx].astype('float32'),
                output2_st[idx].astype('float32'),
                rtol=1e-05,
                atol=0.1,
            )


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
python/paddle/optimizer/rmsprop.py

@@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import warnings
+
+import paddle
from paddle import _C_ops

-from ..fluid import framework
+from ..fluid import core, framework, unique_name
from ..fluid.framework import in_dygraph_mode
+from ..fluid.layer_helper import LayerHelper
from .optimizer import Optimizer

__all__ = []

@@ -184,6 +188,8 @@ class RMSProp(Optimizer):
        self._epsilon = epsilon
        self._momentum = momentum
        self._centered = centered
+        self._multi_precision = False
+        self._master_weights = {}
        self._default_dict = {
            'rho': rho,
            'epsilon': epsilon,

@@ -191,6 +197,62 @@ class RMSProp(Optimizer):
            'centered': centered,
        }

    def _create_master_weight(self, param):
        if param.name in self._master_weights:
            var = self._master_weights[param.name]
        else:
            assert isinstance(self.helper, LayerHelper)

            var_name = param.name + "_fp32_master"
            var_name = unique_name.generate(var_name)
            var = paddle.static.create_global_var(
                name=var_name,
                shape=param.shape,
                value=0,
                dtype='float32',
                persistable=True,
            )
            block = self.helper.startup_program.global_block()
            block.append_op(
                type="cast",
                inputs={"X": [param]},
                outputs={"Out": [var]},
                attrs={
                    "in_dtype": param.dtype,
                    "out_dtype": core.VarDesc.VarType.FP32,
                },
            )
            self._master_weights[param.name] = var
        return var

    def _get_accumulator(self, name, param):
        """Utility function to fetch an accumulator for a parameter

        Args:
            name: name of the accumulator
            param: parameter variable for which accumulator is to be fetched

        Returns:
            accumulator variable for the parameter
        """
        if self._name is not None:
            name = self._name + "_" + name
        find_master = (
            self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
        )
        target_param = (
            self._master_weights[param.name] if find_master else param
        )
        target_name = target_param.name
        if (
            name not in self._accumulators
            or target_name not in self._accumulators[name]
        ):
            raise Exception(
                "Accumulator {} does not exist for parameter {}".format(
                    name, target_name
                )
            )
        return self._accumulators[name][target_name]

    def _create_accumulators(self, block, parameters):
        if not isinstance(block, framework.Block):
            raise TypeError("block is not instance of framework.Block.")

@@ -199,6 +261,20 @@ class RMSProp(Optimizer):
            parameters = parameters.get('params')

        for p in parameters:
+            if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
+                master_p = self._create_master_weight(p)
+                self._add_accumulator(self._momentum_acc_str, master_p)
+                self._add_accumulator(self._mean_square_acc_str, master_p)
+                self._add_accumulator(self._mean_grad_acc_str, master_p)
+                continue
+            if (
+                p.dtype == core.VarDesc.VarType.FP16
+                and not self._multi_precision
+            ):
+                warnings.warn(
+                    "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
+                    "Consider using multi_precision=True option of the Lars optimizer."
+                )
            self._add_accumulator(self._momentum_acc_str, p)
            self._add_accumulator(self._mean_square_acc_str, p)
            self._add_accumulator(self._mean_grad_acc_str, p)

@@ -219,6 +295,15 @@ class RMSProp(Optimizer):
        mean_grad_acc = self._get_accumulator(
            self._mean_grad_acc_str, param_and_grad[0]
        )
+        find_master = (
+            self._multi_precision
+            and param_and_grad[0].dtype == core.VarDesc.VarType.FP16
+        )
+        master_weight = (
+            self._master_weights[param_and_grad[0].name]
+            if find_master
+            else None
+        )
        if in_dygraph_mode():
            _C_ops.rmsprop_(

@@ -228,29 +313,38 @@ class RMSProp(Optimizer):
                momentum_acc,
                self._create_param_lr(param_and_grad),
                mean_grad_acc,
+                master_weight,
                self._epsilon,
                self._rho,
                self._momentum,
                self._centered,
+                find_master,
            )
            return None
        else:
+            inputs = {
+                "Param": param_and_grad[0],
+                "Grad": param_and_grad[1],
+                "Moment": momentum_acc,
+                "MeanSquare": mean_square_acc,
+                "MeanGrad": mean_grad_acc,
+                "LearningRate": self._create_param_lr(param_and_grad),
+            }
+
+            outputs = {
+                "ParamOut": param_and_grad[0],
+                "MomentOut": momentum_acc,
+                "MeanSquareOut": mean_square_acc,
+                "MeanGradOut": mean_grad_acc,
+            }
+
+            if find_master:
+                inputs["MasterParam"] = master_weight
+                outputs["MasterParamOut"] = master_weight
            rmsprop_op = block.append_op(
                type=self.type,
-                inputs={
-                    "Param": param_and_grad[0],
-                    "Grad": param_and_grad[1],
-                    "Moment": momentum_acc,
-                    "MeanSquare": mean_square_acc,
-                    "MeanGrad": mean_grad_acc,
-                    "LearningRate": self._create_param_lr(param_and_grad),
-                },
-                outputs={
-                    "ParamOut": param_and_grad[0],
-                    "MomentOut": momentum_acc,
-                    "MeanSquareOut": mean_square_acc,
-                    "MeanGradOut": mean_grad_acc,
-                },
+                inputs=inputs,
+                outputs=outputs,
                attrs={
                    "epsilon": self._epsilon,
                    "decay": self._rho,