Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
683f152a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
683f152a
编写于
4月 29, 2022
作者:
A
Aurelius84
提交者:
GitHub
4月 29, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[OP]Fix adamw not registered into AllKernels (#42391)
上级
e66d91b3
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
209 addition
and
165 deletion
+209
-165
paddle/fluid/operators/optimizers/adam_op.cc
paddle/fluid/operators/optimizers/adam_op.cc
+2
-165
paddle/fluid/operators/optimizers/adam_op.h
paddle/fluid/operators/optimizers/adam_op.h
+149
-0
paddle/fluid/operators/optimizers/adamw_op.cc
paddle/fluid/operators/optimizers/adamw_op.cc
+58
-0
未找到文件。
paddle/fluid/operators/optimizers/adam_op.cc
浏览文件 @
683f152a
...
@@ -12,168 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,168 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/
framework/op_version_registry
.h"
#include "paddle/fluid/
operators/optimizers/adam_op
.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_
version_
registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/multiary.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
class
AdamOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
input_data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"Param"
);
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
framework
::
OpKernelType
GetKernelTypeForVar
(
const
std
::
string
&
var_name
,
const
framework
::
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
{
if
(
var_name
==
"Beta1Pow"
||
var_name
==
"Beta2Pow"
||
var_name
==
"SkipUpdate"
)
{
return
expected_kernel_type
;
}
else
{
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
tensor
.
layout
());
}
}
};
class
AdamOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Param"
,
"(Tensor) Input parameter"
);
AddInput
(
"Grad"
,
"(Tensor) Input gradient"
);
AddInput
(
"LearningRate"
,
"(Tensor) Learning rate"
);
AddInput
(
"Moment1"
,
"(Tensor) Input first moment"
);
AddInput
(
"Moment2"
,
"(Tensor) Input second moment"
);
AddInput
(
"Beta1Pow"
,
"(Tensor) Input beta1 power accumulator"
);
AddInput
(
"Beta2Pow"
,
"(Tensor) Input beta2 power accumulator"
);
AddInput
(
"Beta1Tensor"
,
"(Tensor<float32>, optional) If provided, Adam will use this "
"as beta1, this has a higher priority than attr(beta1), the "
"shape of this tensor MUST BE [1]."
)
.
AsDispensable
();
AddInput
(
"Beta2Tensor"
,
"(Tensor<float32>, optional) If provided, Adam will use this "
"as beta2, this has a higher priority than attr(beta2), the "
"shape of this tensor MUST BE [1]."
)
.
AsDispensable
();
AddInput
(
"EpsilonTensor"
,
"(Tensor<float32>, optional) If provided, Adam will use this "
"as epsilon, this has a higher priority than attr(epsilon), the "
"shape of this tensor MUST BE [1]."
)
.
AsDispensable
();
AddInput
(
"MasterParam"
,
"FP32 master weight for AMP."
).
AsDispensable
();
AddInput
(
"SkipUpdate"
,
"(Tensor<bool>, optional), Skip the update or not."
)
.
AsDispensable
();
AddOutput
(
"ParamOut"
,
"(Tensor) Output parameter"
);
AddOutput
(
"Moment1Out"
,
"(Tensor) Output first moment"
);
AddOutput
(
"Moment2Out"
,
"(Tensor) Output second moment"
);
AddOutput
(
"Beta1PowOut"
,
"(Tensor) Output beta1 power accumulator"
);
AddOutput
(
"Beta2PowOut"
,
"(Tensor) Output beta2 power accumulator"
);
AddOutput
(
"MasterParamOut"
,
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam)."
)
.
AsDispensable
();
AddAttr
<
float
>
(
"beta1"
,
"(float, default 0.9) "
"Exponential decay rate for the "
"first moment estimates."
)
.
SetDefault
(
0.9
f
);
AddAttr
<
float
>
(
"beta2"
,
"(float, default 0.999) "
"exponential decay rate for the "
"second moment estimates."
)
.
SetDefault
(
0.999
f
);
AddAttr
<
float
>
(
"epsilon"
,
"(float, default 1.0e-8) "
"Constant for numerical stability"
)
.
SetDefault
(
1.0e-8
f
);
AddAttr
<
bool
>
(
"lazy_mode"
,
"(bool, default false) "
"only update the parameter that has gradient in sparse update"
)
.
SetDefault
(
false
);
AddAttr
<
int64_t
>
(
"min_row_size_to_use_multithread"
,
"(int64_t, default 0) "
"when not zero, if param row size is larger then "
"min_row_size_to_use_multithread and "
"inner_op_parallelism is larger then 0, sparse update "
"will run in multithread mode"
)
.
SetDefault
(
1000
);
AddAttr
<
bool
>
(
"multi_precision"
,
"(bool, default false) "
"Whether to use multi-precision during weight updating."
)
.
SetDefault
(
false
);
// TODO(zhiqiu): We could set Beta1PowOut and Beta2PowOut
// as dispensable since they are not used when use_global_beta_pow is true.
AddAttr
<
bool
>
(
"use_global_beta_pow"
,
"(bool, default false) "
"Whether to use global beta_pow for whole model instead of "
"creating beta_pow for each parameter."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
Adam Optimizer.
This implements the Adam optimizer from Section 2 of the Adam
paper : https://arxiv.org/abs/1412.6980.
Adam is a first-order gradient-based optimization method based on
adaptive estimates of lower-order moments.
Adam updates:
$$
moment\_1\_out = \beta_1 * moment\_1 + (1 - \beta_1) * grad \\
moment\_2_\out = \beta_2 * moment\_2 + (1 - \beta_2) * grad * grad \\
learning\_rate = learning\_rate *
\frac{\sqrt{1 - \beta_{2\_pow}}}{1 - \beta_{1\_pow}} \\
param\_out = param - learning\_rate * \frac{moment\_1}{\sqrt{moment\_2} + \epsilon}
$$
)DOC"
);
}
};
class
AdamWOp
:
public
AdamOp
{
using
AdamOp
::
AdamOp
;
};
class
AdamWOpMaker
:
public
AdamOpMaker
{
public:
void
Make
()
{
AdamOpMaker
::
Make
();
AddAttr
<
float
>
(
"lr_ratio"
,
"(float, default 1.0) "
"layerwise learning rate decay"
)
.
SetDefault
(
1.0
f
);
AddAttr
<
float
>
(
"coeff"
,
"(float, default 0.01) "
"coeff of the weight decay"
)
.
SetDefault
(
0.01
f
);
AddAttr
<
bool
>
(
"with_decay"
,
"(bool, default false) "
"whether to do weight decay"
)
.
SetDefault
(
false
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
DECLARE_INFER_SHAPE_FUNCTOR
(
adam
,
AdamInferMetaFunctor
,
DECLARE_INFER_SHAPE_FUNCTOR
(
adam
,
AdamInferMetaFunctor
,
...
@@ -185,14 +30,6 @@ REGISTER_OPERATOR(
...
@@ -185,14 +30,6 @@ REGISTER_OPERATOR(
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
AdamInferMetaFunctor
);
AdamInferMetaFunctor
);
DECLARE_INFER_SHAPE_FUNCTOR
(
adamw
,
AdamwInferMetaFunctor
,
PD_INFER_META
(
phi
::
AdamwInferMeta
));
REGISTER_OPERATOR
(
adamw
,
ops
::
AdamWOp
,
ops
::
AdamWOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
AdamwInferMetaFunctor
);
REGISTER_OP_VERSION
(
adam
)
REGISTER_OP_VERSION
(
adam
)
.
AddCheckpoint
(
.
AddCheckpoint
(
R"ROC(
R"ROC(
...
...
paddle/fluid/operators/optimizers/adam_op.h
0 → 100644
浏览文件 @
683f152a
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
class
AdamOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
input_data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"Param"
);
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
framework
::
OpKernelType
GetKernelTypeForVar
(
const
std
::
string
&
var_name
,
const
framework
::
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
{
if
(
var_name
==
"Beta1Pow"
||
var_name
==
"Beta2Pow"
||
var_name
==
"SkipUpdate"
)
{
return
expected_kernel_type
;
}
else
{
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
tensor
.
layout
());
}
}
};
class
AdamOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Param"
,
"(Tensor) Input parameter"
);
AddInput
(
"Grad"
,
"(Tensor) Input gradient"
);
AddInput
(
"LearningRate"
,
"(Tensor) Learning rate"
);
AddInput
(
"Moment1"
,
"(Tensor) Input first moment"
);
AddInput
(
"Moment2"
,
"(Tensor) Input second moment"
);
AddInput
(
"Beta1Pow"
,
"(Tensor) Input beta1 power accumulator"
);
AddInput
(
"Beta2Pow"
,
"(Tensor) Input beta2 power accumulator"
);
AddInput
(
"Beta1Tensor"
,
"(Tensor<float32>, optional) If provided, Adam will use this "
"as beta1, this has a higher priority than attr(beta1), the "
"shape of this tensor MUST BE [1]."
)
.
AsDispensable
();
AddInput
(
"Beta2Tensor"
,
"(Tensor<float32>, optional) If provided, Adam will use this "
"as beta2, this has a higher priority than attr(beta2), the "
"shape of this tensor MUST BE [1]."
)
.
AsDispensable
();
AddInput
(
"EpsilonTensor"
,
"(Tensor<float32>, optional) If provided, Adam will use this "
"as epsilon, this has a higher priority than attr(epsilon), the "
"shape of this tensor MUST BE [1]."
)
.
AsDispensable
();
AddInput
(
"MasterParam"
,
"FP32 master weight for AMP."
).
AsDispensable
();
AddInput
(
"SkipUpdate"
,
"(Tensor<bool>, optional), Skip the update or not."
)
.
AsDispensable
();
AddOutput
(
"ParamOut"
,
"(Tensor) Output parameter"
);
AddOutput
(
"Moment1Out"
,
"(Tensor) Output first moment"
);
AddOutput
(
"Moment2Out"
,
"(Tensor) Output second moment"
);
AddOutput
(
"Beta1PowOut"
,
"(Tensor) Output beta1 power accumulator"
);
AddOutput
(
"Beta2PowOut"
,
"(Tensor) Output beta2 power accumulator"
);
AddOutput
(
"MasterParamOut"
,
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam)."
)
.
AsDispensable
();
AddAttr
<
float
>
(
"beta1"
,
"(float, default 0.9) "
"Exponential decay rate for the "
"first moment estimates."
)
.
SetDefault
(
0.9
f
);
AddAttr
<
float
>
(
"beta2"
,
"(float, default 0.999) "
"exponential decay rate for the "
"second moment estimates."
)
.
SetDefault
(
0.999
f
);
AddAttr
<
float
>
(
"epsilon"
,
"(float, default 1.0e-8) "
"Constant for numerical stability"
)
.
SetDefault
(
1.0e-8
f
);
AddAttr
<
bool
>
(
"lazy_mode"
,
"(bool, default false) "
"only update the parameter that has gradient in sparse update"
)
.
SetDefault
(
false
);
AddAttr
<
int64_t
>
(
"min_row_size_to_use_multithread"
,
"(int64_t, default 0) "
"when not zero, if param row size is larger then "
"min_row_size_to_use_multithread and "
"inner_op_parallelism is larger then 0, sparse update "
"will run in multithread mode"
)
.
SetDefault
(
1000
);
AddAttr
<
bool
>
(
"multi_precision"
,
"(bool, default false) "
"Whether to use multi-precision during weight updating."
)
.
SetDefault
(
false
);
// TODO(zhiqiu): We could set Beta1PowOut and Beta2PowOut
// as dispensable since they are not used when use_global_beta_pow is true.
AddAttr
<
bool
>
(
"use_global_beta_pow"
,
"(bool, default false) "
"Whether to use global beta_pow for whole model instead of "
"creating beta_pow for each parameter."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
Adam Optimizer.
This implements the Adam optimizer from Section 2 of the Adam
paper : https://arxiv.org/abs/1412.6980.
Adam is a first-order gradient-based optimization method based on
adaptive estimates of lower-order moments.
Adam updates:
$$
moment\_1\_out = \beta_1 * moment\_1 + (1 - \beta_1) * grad \\
moment\_2_\out = \beta_2 * moment\_2 + (1 - \beta_2) * grad * grad \\
learning\_rate = learning\_rate *
\frac{\sqrt{1 - \beta_{2\_pow}}}{1 - \beta_{1\_pow}} \\
param\_out = param - learning\_rate * \frac{moment\_1}{\sqrt{moment\_2} + \epsilon}
$$
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/optimizers/adamw_op.cc
0 → 100644
浏览文件 @
683f152a
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/adam_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace
paddle
{
namespace
operators
{
class
AdamWOp
:
public
AdamOp
{
using
AdamOp
::
AdamOp
;
};
class
AdamWOpMaker
:
public
AdamOpMaker
{
public:
void
Make
()
{
AdamOpMaker
::
Make
();
AddAttr
<
float
>
(
"lr_ratio"
,
"(float, default 1.0) "
"layerwise learning rate decay"
)
.
SetDefault
(
1.0
f
);
AddAttr
<
float
>
(
"coeff"
,
"(float, default 0.01) "
"coeff of the weight decay"
)
.
SetDefault
(
0.01
f
);
AddAttr
<
bool
>
(
"with_decay"
,
"(bool, default false) "
"whether to do weight decay"
)
.
SetDefault
(
false
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
DECLARE_INFER_SHAPE_FUNCTOR
(
adamw
,
AdamwInferMetaFunctor
,
PD_INFER_META
(
phi
::
AdamwInferMeta
));
REGISTER_OPERATOR
(
adamw
,
ops
::
AdamWOp
,
ops
::
AdamWOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
AdamwInferMetaFunctor
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录