PaddlePaddle / Paddle
Commit 151ec311, authored on Mar 08, 2023 by niuliling123 and committed via GitHub on Mar 08, 2023.
Add mult_precision param for adamax op (#49705)
Parent: 8f398dd8

Showing 12 changed files with 717 additions and 58 deletions (+717, -58):
paddle/fluid/operators/optimizers/adamax_op.cc            +9    -1
paddle/fluid/pybind/eager_generator.h                     +11   -1
paddle/phi/api/yaml/legacy_ops.yaml                       +5    -3
paddle/phi/infermeta/multiary.cc                          +4    -1
paddle/phi/infermeta/multiary.h                           +4    -1
paddle/phi/kernels/adamax_kernel.h                        +4    -1
paddle/phi/kernels/gpu/adamax_kernel.cu                   +116  -2
paddle/phi/kernels/impl/adamax_kernel_impl.h              +6    -3
paddle/phi/ops/compat/adamax_sig.cc (new file)            +49   -0
python/paddle/fluid/optimizer.py                          +115  -18
python/paddle/fluid/tests/unittests/test_adamax_op.py     +265  -0
python/paddle/optimizer/adamax.py                         +129  -27
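In user-facing terms, the change lets the Adamax optimizer keep an FP32 master copy of each FP16 parameter and perform the update in FP32, the usual multi-precision (AMP O2) recipe. Below is a minimal dygraph sketch modeled on the unit tests added in this commit; it assumes a CUDA build of Paddle and uses the private `_multi_precision` attribute exactly as the new tests do.

import paddle

paddle.set_device('gpu')                      # multi-precision targets FP16 params
model = paddle.nn.Linear(5, 5)
opt = paddle.optimizer.Adamax(learning_rate=0.1, parameters=model.parameters())
opt._multi_precision = True                   # enable the new master-weight path
model = paddle.amp.decorate(models=model, level='O2')   # parameters become FP16
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.randn((5, 5))
with paddle.amp.auto_cast(level='O2'):
    loss = paddle.mean(model(x))
scaler.scale(loss).backward()
scaler.step(opt)      # the adamax op now receives MasterParam and multi_precision=True
opt.clear_grad()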
paddle/fluid/operators/optimizers/adamax_op.cc

@@ -42,12 +42,16 @@ class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker {
              "(Tensor) "
              "Input exponentially weighted infinity norm");
     AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator");
+    AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
     AddOutput("ParamOut", "(Tensor) Output parameter");
     AddOutput("MomentOut", "(Tensor) Output first moment");
     AddOutput("InfNormOut",
               "(Tensor) "
               "Output exponentially weighted infinity norm");
+    AddOutput("MasterParamOut",
+              "The updated FP32 master weight for AMP. "
+              "It shared memory with Input(MasterParam).")
+        .AsDispensable();
     AddAttr<float>("beta1",
                    "(float, default 0.9) "

@@ -63,6 +67,10 @@ class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker {
                    "(float, default 1.0e-8) "
                    "Constant for numerical stability")
         .SetDefault(1.0e-8f);
+    AddAttr<bool>("multi_precision",
+                  "(bool, default false) "
+                  "Whether to use multi-precision during weight updating.")
+        .SetDefault(false);
     AddComment(R"DOC(
 Adamax Optimizer.
paddle/fluid/pybind/eager_generator.h

@@ -193,6 +193,14 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
       "Beta1Pow",
       "Beta2Pow",
       "MasterParam"}},
+    {"adamax",
+     {"Param",
+      "Grad",
+      "LearningRate",
+      "Moment",
+      "InfNorm",
+      "Beta1Pow",
+      "MasterParam"}},
     {"lamb",
      {"Param",
       "Grad",

@@ -368,6 +376,8 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
       "Beta1PowOut",
       "Beta2PowOut",
       "MasterParamOut"}},
+    {"adamax",
+     {"ParamOut", "MomentOut", "InfNormOut", "Beta1Pow", "MasterParamOut"}},
     {"sgd", {"ParamOut", "MasterParamOut"}},
     {"adagrad", {"ParamOut", "MomentOut", "MasterParamOut"}},
     {"lamb",

@@ -413,7 +423,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
       "AvgSquaredUpdateOut",
       "MasterParamOut"}},
     {"adagrad", {"ParamOut", "MomentOut", "MasterParamOut"}},
-    {"adamax", {"ParamOut", "MomentOut", "InfNormOut"}},
+    {"adamax", {"ParamOut", "MomentOut", "InfNormOut", "MasterParamOut"}},
     {"dpsgd", {"ParamOut"}},
     {"decayed_adagrad", {"ParamOut", "MomentOut"}},
     {"lars_momentum", {"ParamOut", "VelocityOut"}},
paddle/phi/api/yaml/legacy_ops.yaml

@@ -55,13 +55,15 @@
   inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_outs)
 
 - op : adamax_
-  args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment, Tensor inf_norm, Tensor beta1_pow, float beta1, float beta2, float epsilon)
-  output : Tensor(param_out), Tensor(avg_squared_grad_out), Tensor(avg_squared_update_out)
+  args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment, Tensor inf_norm, Tensor beta1_pow, Tensor master_param, float beta1, float beta2, float epsilon, bool multi_precision)
+  output : Tensor(param_out), Tensor(avg_squared_grad_out), Tensor(avg_squared_update_out), Tensor(master_param_outs)
   infer_meta :
     func : AdamaxInferMeta
   kernel :
     func : adamax
-  inplace : (param -> param_out), (moment -> avg_squared_grad_out), (inf_norm -> avg_squared_update_out)
+    data_type : param
+  optional : master_param
+  inplace : (param -> param_out), (moment -> avg_squared_grad_out), (inf_norm -> avg_squared_update_out), (master_param ->master_param_outs)
 
 - op : adamw_
   args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, float lr_ratio, float coeff, bool with_decay, bool lazy_mode, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow)
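With the new `args` entry, the auto-generated eager API `_C_ops.adamax_` gains two positional slots, `master_param` (declared optional) and `multi_precision`. A sketch of the resulting call shape, assuming the generated binding keeps the YAML argument order (which is the order the updated Python optimizers below rely on):

from paddle import _C_ops

def adamax_step(param, grad, learning_rate, moment, inf_norm, beta1_pow,
                master_weight, beta1, beta2, epsilon, multi_precision):
    # Positional order follows the new `args` line above; master_weight may be
    # None because master_param is declared optional in the YAML.
    return _C_ops.adamax_(param, grad, learning_rate, moment, inf_norm,
                          beta1_pow, master_weight, beta1, beta2, epsilon,
                          multi_precision)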
paddle/phi/infermeta/multiary.cc

@@ -190,12 +190,15 @@ void AdamaxInferMeta(const MetaTensor& param,
                      const MetaTensor& moment,
                      const MetaTensor& inf_norm,
                      const MetaTensor& beta1_pow,
+                     const MetaTensor& master_param,
                      float beta1,
                      float beta2,
                      float epsilon,
+                     bool multi_precision,
                      MetaTensor* param_out,
                      MetaTensor* moment_out,
-                     MetaTensor* inf_norm_out) {
+                     MetaTensor* inf_norm_out,
+                     MetaTensor* master_param_outs) {
   auto lr_dims = learning_rate.dims();
   PADDLE_ENFORCE_NE(
       product(lr_dims),
paddle/phi/infermeta/multiary.h

@@ -69,12 +69,15 @@ void AdamaxInferMeta(const MetaTensor& param,
                      const MetaTensor& moment,
                      const MetaTensor& inf_norm,
                      const MetaTensor& beta1_pow,
+                     const MetaTensor& master_param,
                      float beta1,
                      float beta2,
                      float epsilon,
+                     bool multi_precision,
                      MetaTensor* param_out,
                      MetaTensor* moment_out,
-                     MetaTensor* inf_norm_out);
+                     MetaTensor* inf_norm_out,
+                     MetaTensor* master_param_outs);
 
 void AdamInferMeta(const MetaTensor& param,
                    const MetaTensor& grad,
paddle/phi/kernels/adamax_kernel.h

@@ -26,11 +26,14 @@ void AdamaxKernel(const Context& dev_ctx,
                   const DenseTensor& moment,
                   const DenseTensor& inf_norm,
                   const DenseTensor& beta1_pow,
+                  const paddle::optional<DenseTensor>& master_param,
                   float beta1,
                   float beta2,
                   float epsilon,
+                  bool multi_precision,
                   DenseTensor* param_out,
                   DenseTensor* moment_out,
-                  DenseTensor* inf_norm_out);
+                  DenseTensor* inf_norm_out,
+                  DenseTensor* master_param_outs);
 
 }  // namespace phi
paddle/phi/kernels/gpu/adamax_kernel.cu

@@ -15,7 +15,121 @@
 #include "paddle/phi/kernels/adamax_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/adamax_kernel_impl.h"
 
-PD_REGISTER_KERNEL(adamax, GPU, ALL_LAYOUT, phi::AdamaxKernel, float, double) {}
+namespace phi {
+
+template <typename T, typename MT>
+__global__ void AdamaxGPUKernel(const T* param,
+                                const T* grad,
+                                const MT* learning_rate,
+                                const MT* moment,
+                                const MT* inf_norm,
+                                const MT* beta1_pow,
+                                const MT* master_param,
+                                MT d_beta1,
+                                MT d_beta2,
+                                MT d_epsilon,
+                                int num,
+                                T* param_out,
+                                MT* moment_out,
+                                MT* inf_norm_out,
+                                MT* master_param_out) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+  MT lr = static_cast<MT>(learning_rate[0]);
+  MT d_pow = static_cast<MT>(beta1_pow[0]);
+  MT one = static_cast<MT>(1.0f);
+  auto l_r = lr / (one - d_pow);
+
+  for (int index = idx; index < num; index += gridDim.x * blockDim.x) {
+    // load and cast input to MT
+    MT d_param =
+        master_param ? master_param[index] : static_cast<MT>(param[index]);
+    MT d_grad = static_cast<MT>(grad[index]);
+    MT d_moment = static_cast<MT>(moment[index]);
+    MT d_inf = static_cast<MT>(inf_norm[index]);
+    // compute
+    auto mom_out = d_beta1 * d_moment + (one - d_beta1) * d_grad;
+    auto norm_out = std::max(std::abs(d_grad), d_beta2 * d_inf + d_epsilon);
+    auto out_data = d_param - l_r * (mom_out / norm_out);
+    // store
+    param_out[index] = static_cast<T>(out_data);
+    moment_out[index] = static_cast<T>(mom_out);
+    inf_norm_out[index] = static_cast<T>(norm_out);
+
+    if (master_param_out) {
+      master_param_out[index] = out_data;
+    }
+  }
+}
+
+template <typename T, typename Context>
+void AdamaxKernel(const Context& dev_ctx,
+                  const DenseTensor& param,
+                  const DenseTensor& grad,
+                  const DenseTensor& learning_rate,
+                  const DenseTensor& moment,
+                  const DenseTensor& inf_norm,
+                  const DenseTensor& beta1_pow,
+                  const paddle::optional<DenseTensor>& master_param,
+                  float beta1,
+                  float beta2,
+                  float epsilon,
+                  bool multi_precision,
+                  DenseTensor* param_out,
+                  DenseTensor* moment_out,
+                  DenseTensor* inf_norm_out,
+                  DenseTensor* master_param_outs) {
+  using MPDType = typename phi::dtype::template MPTypeTrait<T>::Type;
+
+  T* param_out_data = dev_ctx.template Alloc<T>(param_out);
+  MPDType* moment_out_data = dev_ctx.template Alloc<MPDType>(moment_out);
+  MPDType* inf_norm_out_data = dev_ctx.template Alloc<MPDType>(inf_norm_out);
+
+  const MPDType* master_in_data =
+      multi_precision ? master_param->data<MPDType>() : nullptr;
+  MPDType* master_out_data =
+      multi_precision ? dev_ctx.template Alloc<MPDType>(master_param_outs)
+                      : nullptr;
+
+  PADDLE_ENFORCE_EQ(
+      beta1_pow.numel(),
+      1,
+      errors::InvalidArgument("beta1 pow's size should be 1, but received "
+                              "value is:%d.",
+                              beta1_pow.numel()));
+
+  MPDType beta1_ = static_cast<MPDType>(beta1);
+  MPDType beta2_ = static_cast<MPDType>(beta2);
+  MPDType epsilon_ = static_cast<MPDType>(epsilon);
+  int numel = param.numel();
+  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, 1);
+  int grid = config.block_per_grid.x;
+  int block = config.thread_per_block.x;
+  auto stream = dev_ctx.stream();
+  AdamaxGPUKernel<T, MPDType>
+      <<<block, grid, 0, stream>>>(param.data<T>(),
+                                   grad.data<T>(),
+                                   learning_rate.data<MPDType>(),
+                                   moment.data<MPDType>(),
+                                   inf_norm.data<MPDType>(),
+                                   beta1_pow.data<MPDType>(),
+                                   master_in_data,
+                                   beta1_,
+                                   beta2_,
+                                   epsilon_,
+                                   numel,
+                                   param_out_data,
+                                   moment_out_data,
+                                   inf_norm_out_data,
+                                   master_out_data);
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(adamax,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::AdamaxKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
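For reference, the per-element arithmetic that AdamaxGPUKernel performs, restated as a NumPy sketch (MT corresponds to float32 here; launch configuration and grid-stride looping are omitted):

import numpy as np

def adamax_update(param, grad, lr, moment, inf_norm, beta1_pow,
                  master_param=None, beta1=0.9, beta2=0.999, epsilon=1e-8):
    # Read the parameter from the FP32 master copy when one is provided,
    # otherwise upcast the (possibly FP16) parameter itself.
    p = master_param if master_param is not None else param.astype(np.float32)
    g = grad.astype(np.float32)
    lr_t = lr / (1.0 - beta1_pow)                      # bias-corrected step size
    moment_out = beta1 * moment + (1.0 - beta1) * g    # first-moment update
    inf_norm_out = np.maximum(np.abs(g), beta2 * inf_norm + epsilon)
    p_out = p - lr_t * moment_out / inf_norm_out
    param_out = p_out.astype(param.dtype)              # cast back, e.g. to FP16
    master_out = p_out if master_param is not None else None
    return param_out, moment_out, inf_norm_out, master_out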
paddle/phi/kernels/impl/adamax_kernel_impl.h

@@ -28,12 +28,15 @@ void AdamaxKernel(const Context& dev_ctx,
                   const DenseTensor& moment,
                   const DenseTensor& inf_norm,
                   const DenseTensor& beta1_pow,
+                  const paddle::optional<DenseTensor>& master_param,
                   float beta1,
                   float beta2,
                   float epsilon,
+                  bool multi_precision,
                   DenseTensor* param_out,
                   DenseTensor* moment_out,
-                  DenseTensor* inf_norm_out) {
+                  DenseTensor* inf_norm_out,
+                  DenseTensor* master_param_outs) {
   dev_ctx.template Alloc<T>(param_out);
   dev_ctx.template Alloc<T>(moment_out);
   dev_ctx.template Alloc<T>(inf_norm_out);

@@ -56,10 +59,10 @@ void AdamaxKernel(const Context& dev_ctx,
   auto& place = *dev_ctx.eigen_device();
   eigen_moment_out.device(place) =
-      beta1_ * eigen_moment + (1 - beta1_) * eigen_grad;
+      beta1_ * eigen_moment + (static_cast<T>(1) - beta1_) * eigen_grad;
   eigen_inf_norm_out.device(place) =
       eigen_grad.abs().cwiseMax((beta2_ * eigen_inf_norm) + epsilon_);
-  auto lr_t = eigen_lr / (1 - eigen_beta1_pow);
+  auto lr_t = eigen_lr / (static_cast<T>(1) - eigen_beta1_pow);
   Eigen::DSizes<int, 1> m_dsize(moment_out->numel());
   eigen_param_out.device(place) =
       eigen_param -
paddle/phi/ops/compat/adamax_sig.cc (new file, mode 100644)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>

#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/small_vector.h"

namespace phi {

KernelSignature AdamaxOpArgumentMapping(const ArgumentMappingContext& ctx) {
  paddle::small_vector<const char*> in_names = {"Param",
                                                "Grad",
                                                "LearningRate",
                                                "Moment",
                                                "InfNorm",
                                                "Beta1Pow",
                                                "MasterParam"};
  paddle::small_vector<const char*> out_names = {"ParamOut",
                                                 "MomentOut",
                                                 "InfNormOut",
                                                 "MasterParamOut"};
  paddle::small_vector<const char*> attr_names;

  attr_names.emplace_back("beta1");
  attr_names.emplace_back("beta2");
  attr_names.emplace_back("epsilon");
  attr_names.emplace_back("multi_precision");

  if (ctx.IsDenseTensorInput("Grad")) {
    return KernelSignature("adamax",
                           std::move(in_names),
                           std::move(attr_names),
                           std::move(out_names));
  } else {
    return KernelSignature("unregistered", {}, {}, {});
  }
}

}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(adamax, phi::AdamaxOpArgumentMapping);
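The argument mapping registered above ties the legacy operator's named slots to the phi kernel signature. Roughly, and only as an illustrative table (the real wiring is done by the KernelSignature machinery shown above, not by Python), the correspondence is:

# Illustrative correspondence established by AdamaxOpArgumentMapping.
ADAMAX_SLOT_TO_PHI_ARG = {
    # inputs
    "Param": "param", "Grad": "grad", "LearningRate": "learning_rate",
    "Moment": "moment", "InfNorm": "inf_norm", "Beta1Pow": "beta1_pow",
    "MasterParam": "master_param",          # optional, AMP only
    # attributes
    "beta1": "beta1", "beta2": "beta2", "epsilon": "epsilon",
    "multi_precision": "multi_precision",
    # outputs
    "ParamOut": "param_out", "MomentOut": "moment_out",
    "InfNormOut": "inf_norm_out", "MasterParamOut": "master_param_outs",
}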
python/paddle/fluid/optimizer.py

@@ -2761,10 +2761,59 @@ class AdamaxOptimizer(Optimizer):
         self._beta1 = beta1
         self._beta2 = beta2
         self._epsilon = epsilon
+        self._multi_precision = False
+        self._master_weights = {}
+
+    def _create_master_weight(self, param):
+        if param.name in self._master_weights:
+            var = self._master_weights[param.name]
+        else:
+            assert isinstance(self.helper, LayerHelper)
+
+            var_name = param.name + '_fp32_master'
+            var_name = unique_name.generate(var_name)
+            var = paddle.static.create_global_var(
+                name=var_name,
+                shape=param.shape,
+                value=0,
+                dtype='float32',
+                persistable=True,
+            )
+            block = self.helper.startup_program.global_block()
+            block.append_op(
+                type="cast",
+                inputs={"X": [param]},
+                outputs={"Out": [var]},
+                attrs={
+                    "in_dtype": param.dtype,
+                    "out_dtype": core.VarDesc.VarType.FP32,
+                },
+            )
+            self._master_weights[param.name] = var
+        return var
 
     def _create_accumulators(self, block, parameters):
         # Create accumulator tensors for first moment and infinity norm
         for p in parameters:
+            if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
+                master_p = self._create_master_weight(p)
+                self._add_accumulator(self._moment_acc_str, master_p)
+                self._add_accumulator(self._inf_norm_acc_str, master_p)
+                self._add_accumulator(
+                    name=self._beta1_pow_acc_str,
+                    param=master_p,
+                    fill_value=self._beta1,
+                    shape=[1],
+                )
+                continue
+            if (
+                p.dtype == core.VarDesc.VarType.FP16
+                and not self._multi_precision
+            ):
+                warnings.warn(
+                    "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
+                    "Consider using multi_precision=True option of the Lars optimizer."
+                )
             self._add_accumulator(self._moment_acc_str, p)
             self._add_accumulator(self._inf_norm_acc_str, p)
             self._add_accumulator(

@@ -2774,6 +2823,34 @@ class AdamaxOptimizer(Optimizer):
                 shape=[1],
             )
 
+    def _get_accumulator(self, name, param):
+        """Utility function to fetch an accumulator for a parameter
+
+        Args:
+            name: name of the accumulator
+            param: parameter variable for which accumulator is to be fetched
+
+        Returns:
+            accumulator variable for the parameter
+        """
+        if self._name is not None:
+            name = self._name + "_" + name
+        find_master = (
+            self._multi_precision and core.VarDesc.VarType.FP16 == param.dtype
+        )
+        target_param = (
+            self._master_weights[param.name] if find_master else param
+        )
+        target_name = target_param.name
+        if (
+            name not in self._accumulators
+            or target_name not in self._accumulators[name]
+        ):
+            raise Exception(
+                "Accumulator {} does not exist for parameter {}".format(
+                    name, target_name
+                )
+            )
+        return self._accumulators[name][target_name]
+
     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)

@@ -2785,6 +2862,15 @@ class AdamaxOptimizer(Optimizer):
             self._beta1_pow_acc_str, param_and_grad[0]
         )
+        find_master = (
+            self._multi_precision
+            and param_and_grad[0].dtype == core.VarDesc.VarType.FP16
+        )
+        master_weight = (
+            self._master_weights[param_and_grad[0].name]
+            if find_master
+            else None
+        )
         if in_dygraph_mode():
             _C_ops.adamax_(
                 param_and_grad[0],

@@ -2793,32 +2879,43 @@ class AdamaxOptimizer(Optimizer):
                 moment,
                 inf_norm,
                 beta1_pow_acc,
+                master_weight,
                 self._beta1,
                 self._beta2,
                 self._epsilon,
+                find_master,
             )
         else:
             # create the adamax optimize op
+            inputs = {
+                "Param": param_and_grad[0],
+                "Grad": param_and_grad[1],
+                "LearningRate": self._create_param_lr(param_and_grad),
+                "Moment": moment,
+                "InfNorm": inf_norm,
+                "Beta1Pow": beta1_pow_acc,
+            }
+            outputs = {
+                "ParamOut": param_and_grad[0],
+                "MomentOut": moment,
+                "InfNormOut": inf_norm,
+            }
+            if find_master:
+                inputs["MasterParam"] = master_weight
+                outputs["MasterParamOut"] = master_weight
+            attrs = {
+                "beta1": self._beta1,
+                "beta2": self._beta2,
+                "epsilon": self._epsilon,
+                "multi_precision": find_master,
+            }
             adamax_op = block.append_op(
                 type=self.type,
-                inputs={
-                    "Param": param_and_grad[0],
-                    "Grad": param_and_grad[1],
-                    "LearningRate": self._create_param_lr(param_and_grad),
-                    "Moment": moment,
-                    "InfNorm": inf_norm,
-                    "Beta1Pow": beta1_pow_acc,
-                },
-                outputs={
-                    "ParamOut": param_and_grad[0],
-                    "MomentOut": moment,
-                    "InfNormOut": inf_norm,
-                },
-                attrs={
-                    "beta1": self._beta1,
-                    "beta2": self._beta2,
-                    "epsilon": self._epsilon,
-                },
+                inputs=inputs,
+                outputs=outputs,
+                attrs=attrs,
                 stop_gradient=True,
             )
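The warning added above exists because FP16 optimizer state silently drops small updates; keeping an FP32 master weight (which is what `_create_master_weight` builds via the appended cast op) avoids that. A small NumPy illustration of the effect:

import numpy as np

p = np.float16(1.0)
step = np.float16(1e-4)                    # a typical small parameter update
print(p - step)                            # 1.0: fp16 spacing near 1.0 is ~4.9e-4,
                                           # so the update is rounded away entirely
print(np.float32(p) - np.float32(step))    # ~0.9999: preserved in an fp32 master copy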
python/paddle/fluid/tests/unittests/test_adamax_op.py

@@ -17,6 +17,8 @@ import unittest
 import numpy as np
 from op_test import OpTest
 
+import paddle
+
 
 class TestAdamaxOp1(OpTest):
     def setUp(self):

@@ -203,5 +205,268 @@ class TestAdamaxOpV2(unittest.TestCase):
         )
 
 
+class TestAdamaxOpMultiPrecison(unittest.TestCase):
+    def _test_adamax_op_dygraph_place_amp(self, place, use_amp=False):
+        import paddle
+
+        paddle.disable_static()
+        paddle.seed(10)
+        paddle.set_device(place)
+
+        input = paddle.randn((5, 5))
+
+        model = paddle.nn.Linear(5, 5)
+
+        optimizer = paddle.optimizer.Adamax(
+            0.1, beta1=0.1, parameters=model.parameters()
+        )
+        optimizer._multi_precision = use_amp
+        for idx in range(2):
+            if place == 'gpu' and use_amp:
+                model = paddle.amp.decorate(models=model, level='O2')
+                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+
+            if place == 'gpu' and use_amp:
+                with paddle.amp.auto_cast(level='O2'):
+                    output = model(input)
+                    loss = paddle.mean(output)
+                scaled = scaler.scale(loss)
+                scaled.backward()
+                scaler.step(optimizer)
+                optimizer.clear_grad()
+            else:
+                output = model(input)
+                loss = paddle.mean(output)
+                loss.backward()
+                optimizer.step()
+                optimizer.clear_grad()
+
+        paddle.enable_static()
+
+    def _get_places(self):
+        import paddle
+
+        places = ['cpu']
+        if paddle.is_compiled_with_cuda():
+            places.append('gpu')
+        return places
+
+    def test_main(self):
+        for place in self._get_places():
+            use_amp_list = [True, False]
+            for use_amp in use_amp_list:
+                self._test_adamax_op_dygraph_place_amp(place, use_amp)
+
+
+class TestAdamaxMultiPrecision2_0(unittest.TestCase):
+    def dygraph_adamax_mp(self, mp, use_amp):
+        paddle.disable_static()
+        paddle.seed(100)
+        paddle.set_device('gpu')
+        input = paddle.randn((2, 2))
+        model = paddle.nn.Linear(2, 2)
+        optimizer = paddle.optimizer.Adamax(0.5, parameters=model.parameters())
+        optimizer._multi_precision = mp
+        if use_amp:
+            model = paddle.amp.decorate(models=model, level='O2')
+            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+
+        for idx in range(5):
+            if use_amp:
+                with paddle.amp.auto_cast(level='O2'):
+                    output = model(input)
+                    loss = paddle.mean(output)
+                scaled = scaler.scale(loss)
+                scaled.backward()
+                scaler.minimize(optimizer, scaled)
+                optimizer.clear_grad()
+            else:
+                output = model(input)
+                loss = paddle.mean(output)
+                loss.backward()
+                optimizer.step()
+                optimizer.clear_grad()
+
+        return output, model.parameters()
+
+    def static_adamax_mp(self, mp, use_amp):
+        paddle.enable_static()
+        paddle.seed(100)
+        np.random.seed(100)
+        exe = paddle.static.Executor('gpu')
+        train_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        optimizer = paddle.optimizer.Adamax(0.1)
+        optimizer._multi_precision = mp
+
+        if use_amp:
+            optimizer = paddle.static.amp.decorate(
+                optimizer,
+                init_loss_scaling=128.0,
+                use_dynamic_loss_scaling=True,
+                use_pure_fp16=True,
+                use_fp16_guard=False,
+            )
+        with paddle.static.program_guard(train_program, startup_program):
+            if use_amp:
+                data = paddle.static.data(
+                    shape=[2, 2], name='X', dtype='float16'
+                )
+            else:
+                data = paddle.static.data(
+                    shape=[2, 2], name='X', dtype='float32'
+                )
+            hidden = paddle.static.nn.fc(x=data, size=10)
+            loss = paddle.mean(hidden)
+            optimizer.minimize(loss)
+        exe.run(startup_program)
+
+        if use_amp:
+            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+            x = np.random.random(size=(2, 2)).astype('float16')
+        else:
+            x = np.random.random(size=(2, 2)).astype('float32')
+        out = []
+        for idx in range(5):
+            (loss_data,) = exe.run(
+                train_program, feed={"X": x}, fetch_list=[loss.name]
+            )
+            out.append(loss_data)
+        return out
+
+    def test_main(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+        "Test dygraph mode"
+        output1_dy, params1_dy = self.dygraph_adamax_mp(use_amp=True, mp=True)
+        output2_dy, params2_dy = self.dygraph_adamax_mp(use_amp=False, mp=False)
+        np.testing.assert_allclose(
+            output1_dy.astype('float32').numpy(),
+            output2_dy.astype('float32').numpy(),
+            rtol=1e-05,
+            atol=0.1,
+        )
+        for idx in range(len(params1_dy)):
+            np.testing.assert_allclose(
+                params1_dy[idx].astype('float32').numpy(),
+                params2_dy[idx].astype('float32').numpy(),
+                rtol=1e-05,
+                atol=0.1,
+            )
+        "Test static mode"
+        output1_st = self.static_adamax_mp(use_amp=True, mp=True)
+        output2_st = self.static_adamax_mp(use_amp=False, mp=False)
+        for idx in range(len(output1_st)):
+            np.testing.assert_allclose(
+                output1_st[idx].astype('float32'),
+                output2_st[idx].astype('float32'),
+                rtol=1e-05,
+                atol=0.1,
+            )
+
+
+class TestAdamaxMultiPrecision1_0(unittest.TestCase):
+    def dygraph_adamax_mp(self, use_amp, mp):
+        paddle.disable_static()
+        paddle.seed(10)
+        paddle.set_device('gpu')
+        input = paddle.randn((2, 2))
+        model = paddle.nn.Linear(2, 2)
+        optimizer = paddle.fluid.optimizer.Adamax(
+            learning_rate=0.001, parameter_list=model.parameters()
+        )
+        optimizer._multi_precision = mp
+        if use_amp:
+            model = paddle.amp.decorate(models=model, level='O2')
+            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+
+        for idx in range(5):
+            if use_amp:
+                with paddle.amp.auto_cast(level='O2'):
+                    output = model(input)
+                    loss = paddle.mean(output)
+                scaled = scaler.scale(loss)
+                scaled.backward()
+                scaler.minimize(optimizer, scaled)
+                optimizer.clear_gradients()
+            else:
+                output = model(input)
+                loss = paddle.mean(output)
+                optimizer.minimize(loss)
+                optimizer.clear_gradients()
+
+        return output, model.parameters()
+
+    def static_adamax_mp(self, use_amp, mp):
+        paddle.enable_static()
+        paddle.seed(100)
+        np.random.seed(100)
+        exe = paddle.static.Executor('gpu')
+        train_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        optimizer = paddle.fluid.optimizer.Adamax(learning_rate=0.001)
+        optimizer._multi_precision = mp
+
+        if use_amp:
+            optimizer = paddle.static.amp.decorate(
+                optimizer,
+                init_loss_scaling=128.0,
+                use_dynamic_loss_scaling=True,
+                use_pure_fp16=True,
+                use_fp16_guard=False,
+            )
+        with paddle.static.program_guard(train_program, startup_program):
+            if use_amp:
+                data = paddle.static.data(
+                    shape=[2, 2], name='X', dtype='float16'
+                )
+            else:
+                data = paddle.static.data(
+                    shape=[2, 2], name='X', dtype='float32'
+                )
+            hidden = paddle.static.nn.fc(x=data, size=10)
+            loss = paddle.mean(hidden)
+            optimizer.minimize(loss)
+        exe.run(startup_program)
+
+        if use_amp:
+            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
+            x = np.random.random(size=(2, 2)).astype('float16')
+        else:
+            x = np.random.random(size=(2, 2)).astype('float32')
+        out = []
+        for idx in range(5):
+            (loss_data,) = exe.run(
+                train_program, feed={"X": x}, fetch_list=[loss.name]
+            )
+            out.append(loss_data)
+        return out
+
+    def test_main(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+        "Test dygraph mode"
+        output1_dy, params1_dy = self.dygraph_adamax_mp(use_amp=True, mp=True)
+        output2_dy, params2_dy = self.dygraph_adamax_mp(use_amp=False, mp=False)
+        np.testing.assert_allclose(
+            output1_dy.astype('float32').numpy(),
+            output2_dy.astype('float32').numpy(),
+            rtol=1e-05,
+            atol=0.1,
+        )
+        for idx in range(len(params1_dy)):
+            np.testing.assert_allclose(
+                params1_dy[idx].astype('float32').numpy(),
+                params2_dy[idx].astype('float32').numpy(),
+                rtol=1e-05,
+                atol=0.1,
+            )
+        "Test static mode"
+        output1_st = self.static_adamax_mp(use_amp=True, mp=True)
+        output2_st = self.static_adamax_mp(use_amp=False, mp=False)
+        for idx in range(len(output1_st)):
+            np.testing.assert_allclose(
+                output1_st[idx].astype('float32'),
+                output2_st[idx].astype('float32'),
+                rtol=1e-05,
+                atol=0.1,
+            )
+
+
 if __name__ == "__main__":
     unittest.main()
python/paddle/optimizer/adamax.py

@@ -12,11 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
+
+import paddle
+
 from paddle import _C_ops
 
-from ..fluid import framework
+from ..fluid import core, framework, unique_name
 from ..fluid.dygraph import no_grad
 from ..fluid.framework import name_scope
+from ..fluid.layer_helper import LayerHelper
 from .optimizer import Optimizer
 
 __all__ = []

@@ -164,26 +168,104 @@ class Adamax(Optimizer):
         self._beta1 = beta1
         self._beta2 = beta2
         self._epsilon = epsilon
+        self._multi_precision = False
+        self._master_weights = {}
         self._default_dict = {
             'beta1': beta1,
             'beta2': beta2,
             'epsilon': epsilon,
         }
 
+    def _add_moments_pows(self, p):
+        acc_dtype = p.dtype
+        if self._is_dtype_fp16_or_bf16(acc_dtype):
+            acc_dtype = core.VarDesc.VarType.FP32
+        self._add_accumulator(self._moment_acc_str, p, dtype=acc_dtype)
+        self._add_accumulator(self._inf_norm_acc_str, p, dtype=acc_dtype)
+        self._add_accumulator(
+            name=self._beta1_pow_acc_str,
+            param=p,
+            fill_value=self._beta1,
+            shape=[1],
+        )
+
+    def _create_master_weight(self, param):
+        if param.name in self._master_weights:
+            var = self._master_weights[param.name]
+        else:
+            assert isinstance(self.helper, LayerHelper)
+
+            var_name = param.name + "_fp32_master"
+            var_name = unique_name.generate(var_name)
+            var = paddle.static.create_global_var(
+                name=var_name,
+                shape=param.shape,
+                value=0,
+                dtype='float32',
+                persistable=True,
+            )
+            block = self.helper.startup_program.global_block()
+            block.append_op(
+                type="cast",
+                inputs={"X": [param]},
+                outputs={"Out": [var]},
+                attrs={
+                    "in_dtype": param.dtype,
+                    "out_dtype": core.VarDesc.VarType.FP32,
+                },
+            )
+            self._master_weights[param.name] = var
+
+        return var
+
     def _create_accumulators(self, block, parameters):
         if isinstance(parameters, dict):
             parameters = self._update_param_group(parameters)
 
         # Create accumulator tensors for first moment and infinity norm
         for p in parameters:
-            self._add_accumulator(self._moment_acc_str, p)
-            self._add_accumulator(self._inf_norm_acc_str, p)
-            self._add_accumulator(
-                name=self._beta1_pow_acc_str,
-                param=p,
-                fill_value=self._beta1,
-                shape=[1],
-            )
+            if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
+                master_p = self._create_master_weight(p)
+                self._add_moments_pows(master_p)
+                continue
+            if (
+                p.dtype == core.VarDesc.VarType.FP16
+                and not self._multi_precision
+            ):
+                warnings.warn(
+                    "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
+                    "Consider using multi_precision=True option of the Adam optimizer."
+                )
+            self._add_moments_pows(p)
+
+    def _get_accumulator(self, name, param):
+        """Utility function to fetch an accumulator for a parameter
+
+        Args:
+            name: name of the accumulator
+            param: parameter variable for which accumulator is to be fetched
+
+        Returns:
+            accumulator variable for the parameter
+        """
+        if self._name is not None:
+            name = self._name + "_" + name
+        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
+            param.dtype
+        )
+        target_param = (
+            self._master_weights[param.name] if find_master else param
+        )
+        target_name = target_param.name
+        if (
+            name not in self._accumulators
+            or target_name not in self._accumulators[name]
+        ):
+            raise Exception(
+                "Accumulator {} does not exist for parameter {}".format(
+                    name, target_name
+                )
+            )
+        return self._accumulators[name][target_name]
 
     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)

@@ -194,10 +276,20 @@ class Adamax(Optimizer):
         inf_norm = self._get_accumulator(
             self._inf_norm_acc_str, param_and_grad[0]
         )
+        find_master = (
+            self._multi_precision
+            and param_and_grad[0].dtype == core.VarDesc.VarType.FP16
+        )
+        master_weight = (
+            self._master_weights[param_and_grad[0].name]
+            if find_master
+            else None
+        )
         beta1_pow_acc = self._get_accumulator(
             self._beta1_pow_acc_str, param_and_grad[0]
         )
         if framework.in_dygraph_mode():
             _C_ops.adamax_(
                 param_and_grad[0],

@@ -206,32 +298,42 @@ class Adamax(Optimizer):
                 moment,
                 inf_norm,
                 beta1_pow_acc,
+                master_weight,
                 self._beta1,
                 self._beta2,
                 self._epsilon,
+                find_master,
             )
         else:
             # create the adamax optimize op
+            inputs = {
+                "Param": param_and_grad[0],
+                "Grad": param_and_grad[1],
+                "LearningRate": self._create_param_lr(param_and_grad),
+                "Moment": moment,
+                "InfNorm": inf_norm,
+                "Beta1Pow": beta1_pow_acc,
+            }
+            outputs = {
+                "ParamOut": param_and_grad[0],
+                "MomentOut": moment,
+                "InfNormOut": inf_norm,
+            }
+            if find_master:
+                inputs["MasterParam"] = master_weight
+                outputs["MasterParamOut"] = master_weight
+            attrs = {
+                "beta1": self._beta1,
+                "beta2": self._beta2,
+                "epsilon": self._epsilon,
+                "multi_precision": find_master,
+            }
             adamax_op = block.append_op(
                 type=self.type,
-                inputs={
-                    "Param": param_and_grad[0],
-                    "Grad": param_and_grad[1],
-                    "LearningRate": self._create_param_lr(param_and_grad),
-                    "Moment": moment,
-                    "InfNorm": inf_norm,
-                    "Beta1Pow": beta1_pow_acc,
-                },
-                outputs={
-                    "ParamOut": param_and_grad[0],
-                    "MomentOut": moment,
-                    "InfNormOut": inf_norm,
-                },
-                attrs={
-                    "beta1": self._beta1,
-                    "beta2": self._beta2,
-                    "epsilon": self._epsilon,
-                },
+                inputs=inputs,
+                outputs=outputs,
+                attrs=attrs,
                 stop_gradient=True,
             )