Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
d55ee95f
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d55ee95f
编写于
7月 12, 2022
作者:
Z
zhangbo9674
提交者:
GitHub
7月 12, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Phi] Migrate merged_adam_op into Phi (#44184)
* remov merged_adam_op to phi * refine code
上级
636c6347
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
345 addition
and
367 deletion
+345
-367
paddle/fluid/operators/optimizers/merged_adam_op.cc
paddle/fluid/operators/optimizers/merged_adam_op.cc
+17
-13
paddle/fluid/operators/optimizers/merged_adam_op.cu
paddle/fluid/operators/optimizers/merged_adam_op.cu
+0
-230
paddle/fluid/operators/optimizers/merged_adam_op.h
paddle/fluid/operators/optimizers/merged_adam_op.h
+0
-124
paddle/phi/infermeta/multiary.cc
paddle/phi/infermeta/multiary.cc
+21
-0
paddle/phi/infermeta/multiary.h
paddle/phi/infermeta/multiary.h
+21
-0
paddle/phi/kernels/adam_kernel.h
paddle/phi/kernels/adam_kernel.h
+23
-0
paddle/phi/kernels/cpu/adam_kernel.cc
paddle/phi/kernels/cpu/adam_kernel.cc
+104
-0
paddle/phi/kernels/gpu/adam_kernel.cu
paddle/phi/kernels/gpu/adam_kernel.cu
+112
-0
paddle/phi/ops/compat/merged_adam_sig.cc
paddle/phi/ops/compat/merged_adam_sig.cc
+47
-0
未找到文件。
paddle/fluid/operators/optimizers/merged_adam_op.cc
浏览文件 @
d55ee95f
...
@@ -10,7 +10,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -10,7 +10,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/optimizers/merged_adam_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -21,8 +25,6 @@ class MergedAdamOp : public framework::OperatorWithKernel {
...
@@ -21,8 +25,6 @@ class MergedAdamOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
param_dtype
=
auto
param_dtype
=
...
@@ -128,13 +130,15 @@ $$
...
@@ -128,13 +130,15 @@ $$
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
merged_adam
,
ops
::
MergedAdamOp
,
DECLARE_INFER_SHAPE_FUNCTOR
(
merged_adam
,
ops
::
MergedAdamOpMaker
);
MergedAdamInferMetaFunctor
,
REGISTER_OP_WITHOUT_GRADIENT
(
merged_adamw
,
PD_INFER_META
(
phi
::
MergedAdamInferMeta
));
ops
::
MergedAdamOp
,
ops
::
MergedAdamOpMaker
);
REGISTER_OPERATOR
(
merged_adam
,
REGISTER_OP_CPU_KERNEL
(
merged_adam
,
ops
::
MergedAdamOp
,
ops
::
MergedAdamOpKernel
<
phi
::
CPUContext
,
float
>
,
ops
::
MergedAdamOpMaker
,
ops
::
MergedAdamOpKernel
<
phi
::
CPUContext
,
double
>
);
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
MergedAdamInferMetaFunctor
);
paddle/fluid/operators/optimizers/merged_adam_op.cu
已删除
100644 → 0
浏览文件 @
636c6347
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/merged_adam_op.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
,
typename
MT
>
__global__
void
AdamKernelREG
(
MT
beta1
,
MT
beta2
,
MT
epsilon
,
MT
beta1_pow_
,
MT
beta2_pow_
,
const
MT
*
moment1
,
MT
*
moment1_out
,
const
MT
*
moment2
,
MT
*
moment2_out
,
const
MT
*
lr_
,
const
T
*
grad
,
const
T
*
param
,
T
*
param_out
,
const
MT
*
master_param
,
MT
*
master_param_out
,
int
ndim
)
{
MT
lr
=
*
lr_
;
MT
beta1_pow
=
beta1_pow_
;
MT
beta2_pow
=
beta2_pow_
;
int
id
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(;
id
<
ndim
;
id
+=
gridDim
.
x
*
blockDim
.
x
)
{
MT
p
=
master_param
?
master_param
[
id
]
:
static_cast
<
MT
>
(
param
[
id
]);
MT
g
=
static_cast
<
MT
>
(
grad
[
id
]);
MT
mom1
=
static_cast
<
MT
>
(
moment1
[
id
]);
MT
mom2
=
static_cast
<
MT
>
(
moment2
[
id
]);
mom1
=
beta1
*
mom1
+
(
static_cast
<
MT
>
(
1.0
)
-
beta1
)
*
g
;
mom2
=
beta2
*
mom2
+
(
static_cast
<
MT
>
(
1.0
)
-
beta2
)
*
g
*
g
;
MT
denom
=
(
sqrt
(
mom2
)
/
sqrt
(
static_cast
<
MT
>
(
1.0
)
-
beta2_pow
))
+
epsilon
;
p
+=
(
mom1
/
denom
)
*
(
-
(
lr
/
(
static_cast
<
MT
>
(
1.0
)
-
beta1_pow
)));
moment1_out
[
id
]
=
mom1
;
moment2_out
[
id
]
=
mom2
;
param_out
[
id
]
=
static_cast
<
T
>
(
p
);
if
(
master_param_out
)
{
master_param_out
[
id
]
=
p
;
}
}
}
template
<
typename
T
,
typename
MT
>
__global__
void
AdamKernelMEM
(
MT
beta1
,
MT
beta2
,
MT
epsilon
,
const
MT
*
beta1_pow_
,
const
MT
*
beta2_pow_
,
const
MT
*
moment1
,
MT
*
moment1_out
,
const
MT
*
moment2
,
MT
*
moment2_out
,
const
MT
*
lr_
,
const
T
*
grad
,
const
T
*
param
,
T
*
param_out
,
const
MT
*
master_param
,
MT
*
master_param_out
,
int
ndim
)
{
MT
lr
=
*
lr_
;
MT
beta1_pow
=
*
beta1_pow_
;
MT
beta2_pow
=
*
beta2_pow_
;
int
id
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(;
id
<
ndim
;
id
+=
gridDim
.
x
*
blockDim
.
x
)
{
MT
p
=
master_param
?
master_param
[
id
]
:
static_cast
<
MT
>
(
param
[
id
]);
MT
g
=
static_cast
<
MT
>
(
grad
[
id
]);
MT
mom1
=
static_cast
<
MT
>
(
moment1
[
id
]);
MT
mom2
=
static_cast
<
MT
>
(
moment2
[
id
]);
mom1
=
beta1
*
mom1
+
(
static_cast
<
MT
>
(
1.0
)
-
beta1
)
*
g
;
mom2
=
beta2
*
mom2
+
(
static_cast
<
MT
>
(
1.0
)
-
beta2
)
*
g
*
g
;
MT
denom
=
(
sqrt
(
mom2
)
/
sqrt
(
static_cast
<
MT
>
(
1.0
)
-
beta2_pow
))
+
epsilon
;
p
+=
(
mom1
/
denom
)
*
(
-
(
lr
/
(
static_cast
<
MT
>
(
1.0
)
-
beta1_pow
)));
moment1_out
[
id
]
=
mom1
;
moment2_out
[
id
]
=
mom2
;
param_out
[
id
]
=
static_cast
<
T
>
(
p
);
if
(
master_param_out
)
{
master_param_out
[
id
]
=
p
;
}
}
}
template
<
typename
T
>
__global__
void
UpdateBetaPow
(
T
beta1
,
T
beta2
,
const
T
*
beta1_pow_
,
const
T
*
beta2_pow_
,
T
*
beta1_pow_out
,
T
*
beta2_pow_out
)
{
*
beta1_pow_out
=
beta1
*
beta1_pow_
[
0
];
*
beta2_pow_out
=
beta2
*
beta2_pow_
[
0
];
}
template
<
typename
T
>
class
MergedAdamOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
using
MPDType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
auto
param
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"Param"
);
auto
grad
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"Grad"
);
auto
lr
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"LearningRate"
);
auto
mom1
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"Moment1"
);
auto
mom2
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"Moment2"
);
auto
beta1_pow
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"Beta1Pow"
);
auto
beta2_pow
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"Beta2Pow"
);
auto
param_out
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
mom1_out
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
"Moment1Out"
);
auto
mom2_out
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
"Moment2Out"
);
auto
beta1_pow_out
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
"Beta1PowOut"
);
auto
beta2_pow_out
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
"Beta2PowOut"
);
MPDType
beta1
=
static_cast
<
MPDType
>
(
ctx
.
Attr
<
float
>
(
"beta1"
));
MPDType
beta2
=
static_cast
<
MPDType
>
(
ctx
.
Attr
<
float
>
(
"beta2"
));
MPDType
epsilon
=
static_cast
<
MPDType
>
(
ctx
.
Attr
<
float
>
(
"epsilon"
));
bool
use_global_beta_pow
=
ctx
.
Attr
<
bool
>
(
"use_global_beta_pow"
);
VLOG
(
4
)
<<
"use_global_beta_pow:"
<<
use_global_beta_pow
;
const
bool
multi_precision
=
ctx
.
Attr
<
bool
>
(
"multi_precision"
);
auto
master_param
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"MasterParam"
);
auto
master_param_out
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
"MasterParamOut"
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
size_t
param_num
=
param
.
size
();
for
(
size_t
idx
=
0
;
idx
<
param_num
;
idx
++
)
{
const
MPDType
*
master_in_data
=
multi_precision
?
master_param
[
idx
]
->
data
<
MPDType
>
()
:
nullptr
;
MPDType
*
master_out_data
=
multi_precision
?
master_param_out
[
idx
]
->
mutable_data
<
MPDType
>
(
ctx
.
GetPlace
())
:
nullptr
;
// update param and moment
int
threads
=
512
;
int
blocks
=
(
param
[
idx
]
->
numel
()
+
threads
-
1
)
/
threads
;
if
(
beta1_pow
[
idx
]
->
place
()
==
platform
::
CPUPlace
()
&&
beta2_pow
[
idx
]
->
place
()
==
platform
::
CPUPlace
())
{
// Compute with betapow in REG
AdamKernelREG
<
T
,
MPDType
><<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
beta1
,
beta2
,
epsilon
,
*
beta1_pow
[
idx
]
->
data
<
MPDType
>
(),
*
beta2_pow
[
idx
]
->
data
<
MPDType
>
(),
mom1
[
idx
]
->
data
<
MPDType
>
(),
mom1_out
[
idx
]
->
mutable_data
<
MPDType
>
(
ctx
.
GetPlace
()),
mom2
[
idx
]
->
data
<
MPDType
>
(),
mom2_out
[
idx
]
->
mutable_data
<
MPDType
>
(
ctx
.
GetPlace
()),
lr
[
idx
]
->
data
<
MPDType
>
(),
grad
[
idx
]
->
data
<
T
>
(),
param
[
idx
]
->
data
<
T
>
(),
param_out
[
idx
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
master_in_data
,
master_out_data
,
param
[
idx
]
->
numel
());
if
(
!
use_global_beta_pow
)
{
// Cpu update
beta1_pow_out
[
idx
]
->
mutable_data
<
MPDType
>
(
platform
::
CPUPlace
())[
0
]
=
beta1
*
beta1_pow
[
idx
]
->
data
<
MPDType
>
()[
0
];
beta2_pow_out
[
idx
]
->
mutable_data
<
MPDType
>
(
platform
::
CPUPlace
())[
0
]
=
beta2
*
beta2_pow
[
idx
]
->
data
<
MPDType
>
()[
0
];
}
}
else
{
AdamKernelMEM
<
T
,
MPDType
><<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
beta1
,
beta2
,
epsilon
,
beta1_pow
[
idx
]
->
data
<
MPDType
>
(),
beta2_pow
[
idx
]
->
data
<
MPDType
>
(),
mom1
[
idx
]
->
data
<
MPDType
>
(),
mom1_out
[
idx
]
->
mutable_data
<
MPDType
>
(
ctx
.
GetPlace
()),
mom2
[
idx
]
->
data
<
MPDType
>
(),
mom2_out
[
idx
]
->
mutable_data
<
MPDType
>
(
ctx
.
GetPlace
()),
lr
[
idx
]
->
data
<
MPDType
>
(),
grad
[
idx
]
->
data
<
T
>
(),
param
[
idx
]
->
data
<
T
>
(),
param_out
[
idx
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
master_in_data
,
master_out_data
,
param
[
idx
]
->
numel
());
if
(
!
use_global_beta_pow
)
{
// Update with gpu
UpdateBetaPow
<
MPDType
><<<
1
,
32
,
0
,
dev_ctx
.
stream
()
>>>
(
beta1
,
beta2
,
beta1_pow
[
idx
]
->
data
<
MPDType
>
(),
beta2_pow
[
idx
]
->
data
<
MPDType
>
(),
beta1_pow_out
[
idx
]
->
mutable_data
<
MPDType
>
(
ctx
.
GetPlace
()),
beta2_pow_out
[
idx
]
->
mutable_data
<
MPDType
>
(
ctx
.
GetPlace
()));
}
}
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
merged_adam
,
ops
::
MergedAdamOpCUDAKernel
<
float
>
,
ops
::
MergedAdamOpCUDAKernel
<
double
>
,
ops
::
MergedAdamOpCUDAKernel
<
plat
::
float16
>
);
paddle/fluid/operators/optimizers/merged_adam_op.h
已删除
100644 → 0
浏览文件 @
636c6347
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/phi/kernels/funcs/adam_functors.h"
namespace
paddle
{
namespace
operators
{
namespace
scatter
=
paddle
::
operators
::
math
::
scatter
;
template
<
typename
DeviceContext
,
typename
T
>
class
MergedAdamOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
param
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"Param"
);
size_t
n
=
param
.
size
();
auto
grad
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"Grad"
);
PADDLE_ENFORCE_EQ
(
n
,
grad
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The size of Input(Grad) must be equal to "
"Input(Param), but got the size of Input(Grad) "
"is %d, the size of Input(Param) is %d."
,
grad
.
size
(),
n
));
auto
lr
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"LearningRate"
);
PADDLE_ENFORCE_EQ
(
n
,
lr
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The size of Input(LearningRate) must be equal to "
"Input(Param), but got the size of Input(LearningRate) "
"is %d, the size of Input(Param) is %d."
,
lr
.
size
(),
n
));
auto
mom1
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"Moment1"
);
PADDLE_ENFORCE_EQ
(
n
,
mom1
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The size of Input(Moment1) must be equal to "
"Input(Param), but got the size of Input(Moment1) "
"is %d, the size of Input(Param) is %d."
,
mom1
.
size
(),
n
));
auto
mom2
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"Moment2"
);
PADDLE_ENFORCE_EQ
(
n
,
mom2
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The size of Input(Moment2) must be equal to "
"Input(Param), but got the size of Input(Moment2) "
"is %d, the size of Input(Param) is %d."
,
mom2
.
size
(),
n
));
auto
beta1_pow
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"Beta1Pow"
);
PADDLE_ENFORCE_EQ
(
n
,
beta1_pow
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The size of Input(Beta1Pow) must be equal to "
"Input(Param), but got the size of Input(Beta1Pow) "
"is %d, the size of Input(Param) is %d."
,
beta1_pow
.
size
(),
n
));
auto
beta2_pow
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"Beta2Pow"
);
PADDLE_ENFORCE_EQ
(
n
,
beta2_pow
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The size of Input(Beta2Pow) must be equal to "
"Input(Param), but got the size of Input(Beta2Pow) "
"is %d, the size of Input(Param) is %d."
,
beta2_pow
.
size
(),
n
));
auto
param_out
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
mom1_out
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
"Moment1Out"
);
auto
mom2_out
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
"Moment2Out"
);
auto
beta1_pow_out
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
"Beta1PowOut"
);
auto
beta2_pow_out
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
"Beta2PowOut"
);
T
beta1
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"beta1"
));
T
beta2
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"beta2"
));
T
epsilon
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"epsilon"
));
bool
use_global_beta_pow
=
ctx
.
Attr
<
bool
>
(
"use_global_beta_pow"
);
VLOG
(
4
)
<<
"use_global_beta_pow:"
<<
use_global_beta_pow
;
size_t
param_num
=
param
.
size
();
for
(
size_t
idx
=
0
;
idx
<
param_num
;
idx
++
)
{
phi
::
funcs
::
AdamFunctor
<
T
,
phi
::
funcs
::
CPUAdam
>
functor
(
beta1
,
beta2
,
epsilon
,
beta1_pow
[
idx
]
->
data
<
T
>
(),
beta2_pow
[
idx
]
->
data
<
T
>
(),
mom1
[
idx
]
->
data
<
T
>
(),
mom1_out
[
idx
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
mom2
[
idx
]
->
data
<
T
>
(),
mom2_out
[
idx
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
lr
[
idx
]
->
data
<
T
>
(),
grad
[
idx
]
->
data
<
T
>
(),
param
[
idx
]
->
data
<
T
>
(),
param_out
[
idx
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
functor
(
param
[
idx
]
->
numel
());
if
(
!
use_global_beta_pow
)
{
beta1_pow_out
[
idx
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
())[
0
]
=
beta1
*
beta1_pow
[
idx
]
->
data
<
T
>
()[
0
];
beta2_pow_out
[
idx
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
())[
0
]
=
beta2
*
beta2_pow
[
idx
]
->
data
<
T
>
()[
0
];
}
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/phi/infermeta/multiary.cc
浏览文件 @
d55ee95f
...
@@ -1528,6 +1528,27 @@ void LogspaceInferMeta(const MetaTensor& start,
...
@@ -1528,6 +1528,27 @@ void LogspaceInferMeta(const MetaTensor& start,
out
->
set_dtype
(
start
.
dtype
());
out
->
set_dtype
(
start
.
dtype
());
}
}
void
MergedAdamInferMeta
(
const
std
::
vector
<
const
MetaTensor
*>&
param
,
const
std
::
vector
<
const
MetaTensor
*>&
grad
,
const
std
::
vector
<
const
MetaTensor
*>&
learning_rate
,
const
std
::
vector
<
const
MetaTensor
*>&
moment1
,
const
std
::
vector
<
const
MetaTensor
*>&
moment2
,
const
std
::
vector
<
const
MetaTensor
*>&
beta1_pow
,
const
std
::
vector
<
const
MetaTensor
*>&
beta2_pow
,
const
paddle
::
optional
<
std
::
vector
<
const
MetaTensor
*>>&
master_param
,
const
Scalar
&
beta1
,
const
Scalar
&
beta2
,
const
Scalar
&
epsilon
,
bool
multi_precision
,
bool
use_global_beta_pow
,
std
::
vector
<
MetaTensor
*>
param_out
,
std
::
vector
<
MetaTensor
*>
moment1_out
,
std
::
vector
<
MetaTensor
*>
moment2_out
,
std
::
vector
<
MetaTensor
*>
beta1_pow_out
,
std
::
vector
<
MetaTensor
*>
beta2_pow_out
,
std
::
vector
<
MetaTensor
*>
master_param_out
)
{}
void
MeshgridInferMeta
(
const
std
::
vector
<
const
MetaTensor
*>&
inputs
,
void
MeshgridInferMeta
(
const
std
::
vector
<
const
MetaTensor
*>&
inputs
,
std
::
vector
<
MetaTensor
*>
outputs
)
{
std
::
vector
<
MetaTensor
*>
outputs
)
{
const
size_t
inputs_num
=
inputs
.
size
();
const
size_t
inputs_num
=
inputs
.
size
();
...
...
paddle/phi/infermeta/multiary.h
浏览文件 @
d55ee95f
...
@@ -234,6 +234,27 @@ void LogspaceInferMeta(const MetaTensor& start,
...
@@ -234,6 +234,27 @@ void LogspaceInferMeta(const MetaTensor& start,
const
MetaTensor
&
base
,
const
MetaTensor
&
base
,
MetaTensor
*
out
);
MetaTensor
*
out
);
void
MergedAdamInferMeta
(
const
std
::
vector
<
const
MetaTensor
*>&
param
,
const
std
::
vector
<
const
MetaTensor
*>&
grad
,
const
std
::
vector
<
const
MetaTensor
*>&
learning_rate
,
const
std
::
vector
<
const
MetaTensor
*>&
moment1
,
const
std
::
vector
<
const
MetaTensor
*>&
moment2
,
const
std
::
vector
<
const
MetaTensor
*>&
beta1_pow
,
const
std
::
vector
<
const
MetaTensor
*>&
beta2_pow
,
const
paddle
::
optional
<
std
::
vector
<
const
MetaTensor
*>>&
master_param
,
const
Scalar
&
beta1
,
const
Scalar
&
beta2
,
const
Scalar
&
epsilon
,
bool
multi_precision
,
bool
use_global_beta_pow
,
std
::
vector
<
MetaTensor
*>
param_out
,
std
::
vector
<
MetaTensor
*>
moment1_out
,
std
::
vector
<
MetaTensor
*>
moment2_out
,
std
::
vector
<
MetaTensor
*>
beta1_pow_out
,
std
::
vector
<
MetaTensor
*>
beta2_pow_out
,
std
::
vector
<
MetaTensor
*>
master_param_out
);
void
MeshgridInferMeta
(
const
std
::
vector
<
const
MetaTensor
*>&
inputs
,
void
MeshgridInferMeta
(
const
std
::
vector
<
const
MetaTensor
*>&
inputs
,
std
::
vector
<
MetaTensor
*>
outputs
);
std
::
vector
<
MetaTensor
*>
outputs
);
...
...
paddle/phi/kernels/adam_kernel.h
浏览文件 @
d55ee95f
...
@@ -44,4 +44,27 @@ void AdamDenseKernel(const Context& dev_ctx,
...
@@ -44,4 +44,27 @@ void AdamDenseKernel(const Context& dev_ctx,
DenseTensor
*
beta2_pow_out
,
DenseTensor
*
beta2_pow_out
,
DenseTensor
*
master_param_outs
);
DenseTensor
*
master_param_outs
);
template
<
typename
T
,
typename
Context
>
void
MergedAdamKernel
(
const
Context
&
dev_ctx
,
const
std
::
vector
<
const
DenseTensor
*>&
param
,
const
std
::
vector
<
const
DenseTensor
*>&
grad
,
const
std
::
vector
<
const
DenseTensor
*>&
learning_rate
,
const
std
::
vector
<
const
DenseTensor
*>&
moment1
,
const
std
::
vector
<
const
DenseTensor
*>&
moment2
,
const
std
::
vector
<
const
DenseTensor
*>&
beta1_pow
,
const
std
::
vector
<
const
DenseTensor
*>&
beta2_pow
,
const
paddle
::
optional
<
std
::
vector
<
const
DenseTensor
*>>&
master_param
,
const
Scalar
&
beta1
,
const
Scalar
&
beta2
,
const
Scalar
&
epsilon
,
bool
multi_precision
,
bool
use_global_beta_pow
,
std
::
vector
<
DenseTensor
*>
param_out
,
std
::
vector
<
DenseTensor
*>
moment1_out
,
std
::
vector
<
DenseTensor
*>
moment2_out
,
std
::
vector
<
DenseTensor
*>
beta1_pow_out
,
std
::
vector
<
DenseTensor
*>
beta2_pow_out
,
std
::
vector
<
DenseTensor
*>
master_param_out
);
}
// namespace phi
}
// namespace phi
paddle/phi/kernels/cpu/adam_kernel.cc
浏览文件 @
d55ee95f
...
@@ -167,7 +167,111 @@ void AdamDenseKernel(const Context& dev_ctx,
...
@@ -167,7 +167,111 @@ void AdamDenseKernel(const Context& dev_ctx,
}
}
}
}
template
<
typename
T
,
typename
Context
>
void
MergedAdamKernel
(
const
Context
&
dev_ctx
,
const
std
::
vector
<
const
DenseTensor
*>&
param
,
const
std
::
vector
<
const
DenseTensor
*>&
grad
,
const
std
::
vector
<
const
DenseTensor
*>&
learning_rate
,
const
std
::
vector
<
const
DenseTensor
*>&
moment1
,
const
std
::
vector
<
const
DenseTensor
*>&
moment2
,
const
std
::
vector
<
const
DenseTensor
*>&
beta1_pow
,
const
std
::
vector
<
const
DenseTensor
*>&
beta2_pow
,
const
paddle
::
optional
<
std
::
vector
<
const
DenseTensor
*>>&
master_param
,
const
Scalar
&
beta1
,
const
Scalar
&
beta2
,
const
Scalar
&
epsilon
,
bool
multi_precision
,
bool
use_global_beta_pow
,
std
::
vector
<
DenseTensor
*>
param_out
,
std
::
vector
<
DenseTensor
*>
moment1_out
,
std
::
vector
<
DenseTensor
*>
moment2_out
,
std
::
vector
<
DenseTensor
*>
beta1_pow_out
,
std
::
vector
<
DenseTensor
*>
beta2_pow_out
,
std
::
vector
<
DenseTensor
*>
master_param_out
)
{
size_t
param_num
=
param
.
size
();
PADDLE_ENFORCE_EQ
(
param_num
,
grad
.
size
(),
errors
::
InvalidArgument
(
"The size of Input(grad) must be equal to "
"Input(param), but got the size of Input(grad) "
"is %d, the size of Input(param) is %d."
,
grad
.
size
(),
param_num
));
PADDLE_ENFORCE_EQ
(
param_num
,
learning_rate
.
size
(),
errors
::
InvalidArgument
(
"The size of Input(learning_rate) must be equal to "
"Input(param), but got the size of Input(learning_rate) "
"is %d, the size of Input(param) is %d."
,
learning_rate
.
size
(),
param_num
));
PADDLE_ENFORCE_EQ
(
param_num
,
moment1
.
size
(),
errors
::
InvalidArgument
(
"The size of Input(moment1) must be equal to "
"Input(param), but got the size of Input(moment1) "
"is %d, the size of Input(param) is %d."
,
moment1
.
size
(),
param_num
));
PADDLE_ENFORCE_EQ
(
param_num
,
moment2
.
size
(),
errors
::
InvalidArgument
(
"The size of Input(moment2) must be equal to "
"Input(param), but got the size of Input(moment2) "
"is %d, the size of Input(param) is %d."
,
moment2
.
size
(),
param_num
));
PADDLE_ENFORCE_EQ
(
param_num
,
beta1_pow
.
size
(),
errors
::
InvalidArgument
(
"The size of Input(beta1_pow) must be equal to "
"Input(param), but got the size of Input(beta1_pow) "
"is %d, the size of Input(param) is %d."
,
beta1_pow
.
size
(),
param_num
));
PADDLE_ENFORCE_EQ
(
param_num
,
beta2_pow
.
size
(),
errors
::
InvalidArgument
(
"The size of Input(beta2_pow) must be equal to "
"Input(param), but got the size of Input(beta2_pow) "
"is %d, the size of Input(param) is %d."
,
beta2_pow
.
size
(),
param_num
));
T
beta1_
=
beta1
.
to
<
T
>
();
T
beta2_
=
beta2
.
to
<
T
>
();
T
epsilon_
=
epsilon
.
to
<
T
>
();
for
(
size_t
idx
=
0
;
idx
<
param_num
;
idx
++
)
{
phi
::
funcs
::
AdamFunctor
<
T
,
phi
::
funcs
::
CPUAdam
>
functor
(
beta1_
,
beta2_
,
epsilon_
,
beta1_pow
[
idx
]
->
data
<
T
>
(),
beta2_pow
[
idx
]
->
data
<
T
>
(),
moment1
[
idx
]
->
data
<
T
>
(),
dev_ctx
.
template
Alloc
<
T
>(
moment1_out
[
idx
]),
moment2
[
idx
]
->
data
<
T
>
(),
dev_ctx
.
template
Alloc
<
T
>(
moment2_out
[
idx
]),
learning_rate
[
idx
]
->
data
<
T
>
(),
grad
[
idx
]
->
data
<
T
>
(),
param
[
idx
]
->
data
<
T
>
(),
dev_ctx
.
template
Alloc
<
T
>(
param_out
[
idx
]));
functor
(
param
[
idx
]
->
numel
());
if
(
!
use_global_beta_pow
)
{
dev_ctx
.
template
Alloc
<
T
>(
beta1_pow_out
[
idx
])[
0
]
=
beta1_
*
beta1_pow
[
idx
]
->
data
<
T
>
()[
0
];
dev_ctx
.
template
Alloc
<
T
>(
beta2_pow_out
[
idx
])[
0
]
=
beta2_
*
beta2_pow
[
idx
]
->
data
<
T
>
()[
0
];
}
}
}
}
// namespace phi
}
// namespace phi
PD_REGISTER_KERNEL
(
adam
,
CPU
,
ALL_LAYOUT
,
phi
::
AdamDenseKernel
,
float
,
double
)
{
PD_REGISTER_KERNEL
(
adam
,
CPU
,
ALL_LAYOUT
,
phi
::
AdamDenseKernel
,
float
,
double
)
{
}
}
PD_REGISTER_KERNEL
(
merged_adam
,
CPU
,
ALL_LAYOUT
,
phi
::
MergedAdamKernel
,
float
,
double
)
{}
paddle/phi/kernels/gpu/adam_kernel.cu
浏览文件 @
d55ee95f
...
@@ -265,6 +265,106 @@ void AdamDenseKernel(const Context& dev_ctx,
...
@@ -265,6 +265,106 @@ void AdamDenseKernel(const Context& dev_ctx,
}
}
}
}
template
<
typename
T
,
typename
Context
>
void
MergedAdamKernel
(
const
Context
&
dev_ctx
,
const
std
::
vector
<
const
DenseTensor
*>&
param
,
const
std
::
vector
<
const
DenseTensor
*>&
grad
,
const
std
::
vector
<
const
DenseTensor
*>&
learning_rate
,
const
std
::
vector
<
const
DenseTensor
*>&
moment1
,
const
std
::
vector
<
const
DenseTensor
*>&
moment2
,
const
std
::
vector
<
const
DenseTensor
*>&
beta1_pow
,
const
std
::
vector
<
const
DenseTensor
*>&
beta2_pow
,
const
paddle
::
optional
<
std
::
vector
<
const
DenseTensor
*>>&
master_param
,
const
Scalar
&
beta1
,
const
Scalar
&
beta2
,
const
Scalar
&
epsilon
,
bool
multi_precision
,
bool
use_global_beta_pow
,
std
::
vector
<
DenseTensor
*>
param_out
,
std
::
vector
<
DenseTensor
*>
moment1_out
,
std
::
vector
<
DenseTensor
*>
moment2_out
,
std
::
vector
<
DenseTensor
*>
beta1_pow_out
,
std
::
vector
<
DenseTensor
*>
beta2_pow_out
,
std
::
vector
<
DenseTensor
*>
master_param_out
)
{
using
MPDType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
VLOG
(
4
)
<<
"use_global_beta_pow:"
<<
use_global_beta_pow
;
MPDType
beta1_
=
beta1
.
to
<
MPDType
>
();
MPDType
beta2_
=
beta2
.
to
<
MPDType
>
();
MPDType
epsilon_
=
epsilon
.
to
<
MPDType
>
();
size_t
param_num
=
param
.
size
();
for
(
size_t
idx
=
0
;
idx
<
param_num
;
idx
++
)
{
const
MPDType
*
master_in_data
=
multi_precision
?
master_param
.
get
()[
idx
]
->
data
<
MPDType
>
()
:
nullptr
;
MPDType
*
master_out_data
=
multi_precision
?
dev_ctx
.
template
Alloc
<
MPDType
>(
master_param_out
[
idx
])
:
nullptr
;
// update param and moment
int
threads
=
512
;
int
blocks
=
(
param
[
idx
]
->
numel
()
+
threads
-
1
)
/
threads
;
if
(
beta1_pow
[
idx
]
->
place
()
==
CPUPlace
()
&&
beta2_pow
[
idx
]
->
place
()
==
CPUPlace
())
{
// Compute with betapow in REG
AdamKernelREG
<
T
,
MPDType
><<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
beta1_
,
beta2_
,
epsilon_
,
*
beta1_pow
[
idx
]
->
data
<
MPDType
>
(),
*
beta2_pow
[
idx
]
->
data
<
MPDType
>
(),
moment1
[
idx
]
->
data
<
MPDType
>
(),
dev_ctx
.
template
Alloc
<
MPDType
>(
moment1_out
[
idx
]),
moment2
[
idx
]
->
data
<
MPDType
>
(),
dev_ctx
.
template
Alloc
<
MPDType
>(
moment2_out
[
idx
]),
learning_rate
[
idx
]
->
data
<
MPDType
>
(),
grad
[
idx
]
->
data
<
T
>
(),
param
[
idx
]
->
data
<
T
>
(),
dev_ctx
.
template
Alloc
<
T
>(
param_out
[
idx
]),
master_in_data
,
master_out_data
,
param
[
idx
]
->
numel
());
if
(
!
use_global_beta_pow
)
{
// Cpu update
dev_ctx
.
template
HostAlloc
<
MPDType
>(
beta1_pow_out
[
idx
])[
0
]
=
beta1_
*
beta1_pow
[
idx
]
->
data
<
MPDType
>
()[
0
];
dev_ctx
.
template
HostAlloc
<
MPDType
>(
beta2_pow_out
[
idx
])[
0
]
=
beta2_
*
beta2_pow
[
idx
]
->
data
<
MPDType
>
()[
0
];
}
}
else
{
AdamKernelMEM
<
T
,
MPDType
><<<
blocks
,
threads
,
0
,
dev_ctx
.
stream
()
>>>
(
beta1_
,
beta2_
,
epsilon_
,
beta1_pow
[
idx
]
->
data
<
MPDType
>
(),
beta2_pow
[
idx
]
->
data
<
MPDType
>
(),
moment1
[
idx
]
->
data
<
MPDType
>
(),
dev_ctx
.
template
Alloc
<
MPDType
>(
moment1_out
[
idx
]),
moment2
[
idx
]
->
data
<
MPDType
>
(),
dev_ctx
.
template
Alloc
<
MPDType
>(
moment2_out
[
idx
]),
learning_rate
[
idx
]
->
data
<
MPDType
>
(),
grad
[
idx
]
->
data
<
T
>
(),
param
[
idx
]
->
data
<
T
>
(),
dev_ctx
.
template
Alloc
<
T
>(
param_out
[
idx
]),
master_in_data
,
master_out_data
,
param
[
idx
]
->
numel
());
if
(
!
use_global_beta_pow
)
{
// Update with gpu
UpdateBetaPow
<
MPDType
><<<
1
,
32
,
0
,
dev_ctx
.
stream
()
>>>
(
beta1_
,
beta2_
,
beta1_pow
[
idx
]
->
data
<
MPDType
>
(),
beta2_pow
[
idx
]
->
data
<
MPDType
>
(),
dev_ctx
.
template
Alloc
<
MPDType
>(
beta1_pow_out
[
idx
]),
dev_ctx
.
template
Alloc
<
MPDType
>(
beta2_pow_out
[
idx
]));
}
}
}
}
}
// namespace phi
}
// namespace phi
PD_REGISTER_KERNEL
(
adam
,
PD_REGISTER_KERNEL
(
adam
,
...
@@ -279,3 +379,15 @@ PD_REGISTER_KERNEL(adam,
...
@@ -279,3 +379,15 @@ PD_REGISTER_KERNEL(adam,
kernel
->
InputAt
(
6
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
6
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
8
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
8
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
}
}
PD_REGISTER_KERNEL
(
merged_adam
,
GPU
,
ALL_LAYOUT
,
phi
::
MergedAdamKernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{
// Skip beta1_pow, beta2_pow data transform
kernel
->
InputAt
(
5
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
kernel
->
InputAt
(
6
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
}
paddle/phi/ops/compat/merged_adam_sig.cc
0 → 100644
浏览文件 @
d55ee95f
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/small_vector.h"
namespace
phi
{
KernelSignature
MergedAdamOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
paddle
::
small_vector
<
const
char
*>
in_names
=
{
"Param"
,
"Grad"
,
"LearningRate"
,
"Moment1"
,
"Moment2"
,
"Beta1Pow"
,
"Beta2Pow"
,
"MasterParam"
};
paddle
::
small_vector
<
const
char
*>
out_names
=
{
"ParamOut"
,
"Moment1Out"
,
"Moment2Out"
,
"Beta1PowOut"
,
"Beta2PowOut"
,
"MasterParamOut"
};
paddle
::
small_vector
<
const
char
*>
attr_names
=
{
"beta1"
,
"beta2"
,
"epsilon"
,
"multi_precision"
,
"use_global_beta_pow"
};
return
KernelSignature
(
"merged_adam"
,
std
::
move
(
in_names
),
std
::
move
(
attr_names
),
std
::
move
(
out_names
));
}
}
// namespace phi
PD_REGISTER_ARG_MAPPING_FN
(
merged_adam
,
phi
::
MergedAdamOpArgumentMapping
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录