Unverified commit d15b490a — authored on Jul 14, 2022 by Yuang Liu, committed via GitHub on Jul 14, 2022.
[operator migration] Migrate merged momentum cpu/gpu kernels (#44300)
Parent: 84b72c5f
Showing 11 changed files with 538 additions and 24 deletions (+538, -24).
paddle/fluid/operators/optimizers/merged_momentum_op.cc               +1   -5
paddle/fluid/operators/optimizers/merged_momentum_op_mlu.cc           +7   -1
paddle/fluid/operators/optimizers/merged_momentum_op_npu.cc           +6   -1
paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h  +1   -1
paddle/fluid/platform/macros.h                                        +0   -6
paddle/phi/core/macros.h                                              +6   -0
paddle/phi/kernels/cpu/merged_momentum_kernel.cc                      +10  -10
paddle/phi/kernels/gpu/merged_momentum_kernel.cu                      +25  -0
paddle/phi/kernels/impl/merged_momentum_impl.h                        +400 -0
paddle/phi/kernels/merged_momentum_kernel.h                           +42  -0
paddle/phi/ops/compat/merged_momentum_sig.cc                          +40  -0
paddle/fluid/operators/optimizers/merged_momentum_op.cc

```diff
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 
 namespace paddle {
 namespace operators {
@@ -103,7 +103,3 @@ namespace plat = paddle::platform;
 REGISTER_OP_WITHOUT_GRADIENT(merged_momentum,
                              ops::MergedMomentumOp,
                              ops::MergedMomentumOpMaker);
-REGISTER_OP_CPU_KERNEL(
-    merged_momentum,
-    ops::MergedMomentumOpKernel<phi::CPUContext, float>,
-    ops::MergedMomentumOpKernel<phi::CPUContext, double>);
```
paddle/fluid/operators/optimizers/merged_momentum_op_mlu.cc

```diff
@@ -12,8 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/mlu/mlu_baseop.h"
-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
+#include "paddle/fluid/platform/for_range.h"
+#include "paddle/fluid/platform/macros.h"
+#include "paddle/phi/kernels/impl/momentum_kernel_impl.h"
 
 namespace paddle {
 namespace operators {
```
paddle/fluid/operators/optimizers/merged_momentum_op_npu.cc

```diff
@@ -12,8 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
+#include "paddle/fluid/platform/for_range.h"
+#include "paddle/fluid/platform/macros.h"
+#include "paddle/phi/kernels/impl/momentum_kernel_impl.h"
 
 namespace paddle {
```
paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h

```diff
@@ -17,7 +17,7 @@
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/for_range.h"
-#include "paddle/fluid/platform/macros.h"
+#include "paddle/phi/core/macros.h"
 
 namespace paddle {
 namespace operators {
```
paddle/fluid/platform/macros.h

```diff
@@ -29,9 +29,3 @@ limitations under the License. */
 #define FLT_MAX __FLT_MAX__
 #endif  // __FLT_MAX__
 #endif  // PADDLE_WITH_MUSL
-
-#if defined(__NVCC__) || defined(__HIPCC__)
-#define PADDLE_RESTRICT __restrict__
-#else
-#define PADDLE_RESTRICT
-#endif
```
paddle/phi/core/macros.h

```diff
@@ -53,4 +53,10 @@ namespace phi {
 #define PD_CONCATENATE2(arg1, arg2) arg1##arg2
 #define PD_EXPAND(x) x
 
+#if defined(__NVCC__) || defined(__HIPCC__)
+#define PADDLE_RESTRICT __restrict__
+#else
+#define PADDLE_RESTRICT
+#endif
+
 }  // namespace phi
```
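The `PADDLE_RESTRICT` definition is moved verbatim from the fluid header into `paddle/phi/core/macros.h` so that phi kernels no longer pull in a fluid dependency. As a minimal sketch of what the macro buys (the function below is hypothetical, not from this commit): on `__NVCC__`/`__HIPCC__` builds it expands to the `__restrict__` no-aliasing hint, and it compiles away everywhere else.

```cpp
// Hypothetical illustration of PADDLE_RESTRICT usage; not part of this diff.
#include "paddle/phi/core/macros.h"

// With __restrict__, the compiler may assume a, b, and out never overlap,
// so it can keep values in registers and reorder loads/stores freely.
void axpy(int n, float alpha,
          const float* PADDLE_RESTRICT a,
          const float* PADDLE_RESTRICT b,
          float* PADDLE_RESTRICT out) {
  for (int i = 0; i < n; ++i) {
    out[i] = alpha * a[i] + b[i];
  }
}
```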
paddle/fluid/operators/optimizers/merged_momentum_op.cu → paddle/phi/kernels/cpu/merged_momentum_kernel.cc (renamed)

```diff
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,13 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/merged_momentum_impl.h"
 
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-
-REGISTER_OP_CUDA_KERNEL(
-    merged_momentum,
-    ops::MergedMomentumOpKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::MergedMomentumOpKernel<plat::CUDADeviceContext, float>,
-    ops::MergedMomentumOpKernel<plat::CUDADeviceContext, double>);
+PD_REGISTER_KERNEL(merged_momentum,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::MergedMomentumKernel,
+                   float,
+                   double) {}
```
paddle/phi/kernels/gpu/merged_momentum_kernel.cu (new file, mode 100644)

```cpp
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/merged_momentum_impl.h"

PD_REGISTER_KERNEL(merged_momentum,
                   GPU,
                   ALL_LAYOUT,
                   phi::MergedMomentumKernel,
                   phi::dtype::float16,
                   float,
                   double) {}
```
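The two registrations above show the general fluid-to-phi pattern: the device-specific `REGISTER_OP_*_KERNEL` macros are replaced by one `PD_REGISTER_KERNEL(op_name, backend, layout, kernel_fn, dtypes...)` call per backend, which instantiates the templated kernel once per listed dtype. A minimal sketch of the same pattern for a hypothetical elementwise kernel follows; `scale_by_two` and `ScaleByTwoKernel` are made-up names, not part of this commit.

```cpp
// Hypothetical sketch of the phi kernel + registration pattern used above.
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

// A phi kernel is a free function template over (dtype, backend context):
// inputs come first, outputs are raw pointers at the end.
template <typename T, typename Context>
void ScaleByTwoKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      DenseTensor* out) {
  const T* x_data = x.data<T>();
  // Assumes out's dims were already set by shape inference, as usual in phi.
  T* out_data = dev_ctx.template Alloc<T>(out);
  for (int64_t i = 0; i < x.numel(); ++i) {
    out_data[i] = x_data[i] + x_data[i];
  }
}

}  // namespace phi

// One registration per backend; each listed dtype gets an instantiation.
PD_REGISTER_KERNEL(
    scale_by_two, CPU, ALL_LAYOUT, phi::ScaleByTwoKernel, float, double) {}
```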
paddle/fluid/operators/optimizers/merged_momentum_op.h → paddle/phi/kernels/impl/merged_momentum_impl.h (renamed)

```diff
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -14,19 +14,18 @@
 #pragma once
 
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
-#include "paddle/fluid/platform/for_range.h"
-#include "paddle/fluid/platform/macros.h"
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/core/macros.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
+#include "paddle/phi/kernels/impl/momentum_kernel_impl.h"
+#include "paddle/phi/kernels/merged_momentum_kernel.h"
 
-namespace paddle {
-namespace operators {
+namespace phi {
 
 template <typename T>
-using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;
+using MultiPrecisionType = typename phi::dtype::MPTypeTrait<T>::Type;
 
 template <typename MT, uint32_t kParamNum, bool kHasMasterParams>
 struct MergedMomentumMasterParams {
```
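`MultiPrecisionType<T>` now resolves through `phi::dtype::MPTypeTrait` instead of the fluid `details::MPTypeTrait`; the semantics are unchanged. A hedged illustration, assuming the trait's conventional mapping (reduced-precision parameter types widen to `float` for master weights, wider types map to themselves):

```cpp
// Illustration only, not part of this diff.
#include <type_traits>

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"

// float16 parameters keep float master weights...
static_assert(
    std::is_same<phi::dtype::MPTypeTrait<phi::dtype::float16>::Type,
                 float>::value,
    "float16 widens to float");
// ...while float parameters need no widening.
static_assert(std::is_same<phi::dtype::MPTypeTrait<float>::Type, float>::value,
              "float is its own multi-precision type");
```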
```diff
@@ -84,68 +83,62 @@ struct MergedMomentumKernelParam
   }
 };
 
-template <typename DeviceContext, typename T>
-class MergedMomentumOpKernel : public framework::OpKernel<T> {
-  using MPType = typename operators::details::MPTypeTrait<T>::Type;
-
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    const bool multi_precision = ctx.Attr<bool>("multi_precision");
-    if (multi_precision) {
-      InnerCompute<MPType>(ctx, multi_precision);
-    } else {
-      InnerCompute<T>(ctx, multi_precision);
-    }
-  }
-
- private:
-  template <typename MT>
-  void InnerCompute(const framework::ExecutionContext &ctx,
-                    const bool multi_precision) const {
-    auto params = ctx.MultiInput<framework::Tensor>("Param");
-    auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
+template <typename MT, typename Context, typename MPType, typename T>
+void MergedMomentumInnerCompute(
+    const Context &ctx,
+    const std::vector<const DenseTensor *> &params,
+    const std::vector<const DenseTensor *> &grads,
+    const std::vector<const DenseTensor *> &velocitys,
+    const std::vector<const DenseTensor *> &lrs,
+    const paddle::optional<std::vector<const DenseTensor *>> &master_params_opt,
+    float mu,
+    bool use_nesterov,
+    const std::vector<std::string> &regularization_methods,
+    const std::vector<float> &regularization_coeffs,
+    float rescale_grad,
+    const bool multi_precision,
+    std::vector<DenseTensor *> params_out,
+    std::vector<DenseTensor *> velocitys_out,
+    std::vector<DenseTensor *> master_params_out) {
   size_t n = params.size();
   PADDLE_ENFORCE_EQ(n,
                     params_out.size(),
-                    platform::errors::InvalidArgument(
+                    phi::errors::InvalidArgument(
                         "The size of Output(ParamOut) must be equal to "
                         "Input(Param), but got the size of Output(ParamOut) "
                         "is %d, the size of Input(Param) is %d.",
                         params_out.size(),
                         n));
   for (size_t i = 0; i < n; ++i) {
     PADDLE_ENFORCE_EQ(params[i],
                       params_out[i],
-                      platform::errors::InvalidArgument(
-                          "The size of Input(Param) and Output(ParamOut) "
+                      phi::errors::InvalidArgument(
+                          "Input(Param) and Output(ParamOut) "
                           "must be the same Tensors."));
   }
 
-  auto grads = ctx.MultiInput<framework::Tensor>("Grad");
   PADDLE_ENFORCE_EQ(
       n,
       grads.size(),
-      platform::errors::InvalidArgument(
+      phi::errors::InvalidArgument(
           "The size of Input(Grad) must be equal to Input(Param), but got "
           "the size of Input(Grad) is %d, the size of Input(Param) is %d.",
           grads.size(),
           n));
 
-  auto velocitys = ctx.MultiInput<framework::Tensor>("Velocity");
   PADDLE_ENFORCE_EQ(n,
                     velocitys.size(),
-                    platform::errors::InvalidArgument(
+                    phi::errors::InvalidArgument(
                         "The size of Input(Velocity) must be equal to "
                         "Input(Param), but got the size of Input(Velocity) "
                         "is %d, the size of Input(Param) is %d.",
                         velocitys.size(),
                         n));
 
-  auto velocitys_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
   PADDLE_ENFORCE_EQ(
       n,
       velocitys_out.size(),
-      platform::errors::InvalidArgument(
+      phi::errors::InvalidArgument(
           "The size of Output(VelocityOut) must be "
           "equal to Input(Param), but got the size of Output(VelocityOut) is "
           "%d, the size of Input(Param) is %d.",
```
```diff
@@ -154,19 +147,17 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
   for (size_t i = 0; i < n; ++i) {
     PADDLE_ENFORCE_EQ(velocitys[i],
                       velocitys_out[i],
-                      platform::errors::InvalidArgument(
+                      phi::errors::InvalidArgument(
                           "Input(Velocity) and Output(VelocityOut) must be "
                           "the same Tensors."));
   }
 
-  auto master_params = ctx.MultiInput<framework::Tensor>("MasterParam");
-  auto master_params_out =
-      ctx.MultiOutput<framework::Tensor>("MasterParamOut");
   if (multi_precision) {
+    auto master_params = master_params_opt.get();
     PADDLE_ENFORCE_EQ(n,
                       master_params.size(),
-                      platform::errors::InvalidArgument(
+                      phi::errors::InvalidArgument(
                           "The size of Input(MasterParam) must be "
                           "equal to Input(Param), but got the size of Input(MasterParam) "
                           "is %d, the size of Input(Param) is %d.",
@@ -175,7 +166,7 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(n,
                       master_params_out.size(),
-                      platform::errors::InvalidArgument(
+                      phi::errors::InvalidArgument(
                           "The size of Output(MasterParamOut) must be equal to "
                           "Input(MasterParam), but got the size of Output(MasterParamOut) "
                           "is %d, the size of Input(Param) is %d.",
@@ -184,27 +175,23 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
     for (size_t i = 0; i < n; ++i) {
       PADDLE_ENFORCE_EQ(master_params[i],
                         master_params_out[i],
-                        platform::errors::InvalidArgument(
+                        phi::errors::InvalidArgument(
                             "Input(MasterParam) and Output(MasterParamOut) "
                             "must be the same Tensors."));
       PADDLE_ENFORCE_NOT_NULL(master_params[i],
-                              platform::errors::InvalidArgument(
+                              phi::errors::InvalidArgument(
                                   "Input(MasterParam) must be provided when "
                                   "multi_precision=True."));
     }
-  } else {
-    master_params.clear();
-    master_params_out.clear();
   }
 
-  auto mu = ctx.Attr<float>("mu");
-  auto rescale_grad = ctx.Attr<float>("rescale_grad");
-  auto lrs = ctx.MultiInput<framework::Tensor>("LearningRate");
   if (lrs.size() != 1) {
     PADDLE_ENFORCE_EQ(
         n,
         lrs.size(),
-        platform::errors::InvalidArgument(
+        phi::errors::InvalidArgument(
            "If the size of Input(LearningRate) is not 1, the size of "
            "Input(LearningRate) must be "
            "equal to Input(Param), but got the size of Input(LearningRate) "
@@ -212,16 +199,11 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
            lrs.size(),
            n));
   }
 
-  auto use_nesterov = ctx.Attr<bool>("use_nesterov");
-  auto regularization_methods =
-      ctx.Attr<std::vector<std::string>>("regularization_method");
-  auto regularization_coeffs =
-      ctx.Attr<std::vector<float>>("regularization_coeff");
   if (regularization_methods.size() != 0) {
     PADDLE_ENFORCE_EQ(
         n,
         regularization_methods.size(),
-        platform::errors::InvalidArgument(
+        phi::errors::InvalidArgument(
            "The size of Attr(regularization_method) must be equal "
            "to Input(Param), but got the size of "
            "Attr(regularization_method) is %d, the size of Input(Param) is "
@@ -231,7 +213,7 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(
         n,
         regularization_coeffs.size(),
-        platform::errors::InvalidArgument(
+        phi::errors::InvalidArgument(
            "The size of Attr(regularization_coeff) must be equal "
            "to Input(Param), but got the size of Attr(regularization_coeff) "
            "is %d, the size of Input(Param) is %d.",
```
```diff
@@ -245,8 +227,6 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
            << ", regularization_coeffs.size(): "
            << regularization_coeffs.size();
 
-  auto &dev_ctx = ctx.template device_context<DeviceContext>();
-
   if (lrs.size() == 1 && use_nesterov == false &&
       regularization_methods.size() == 0) {
 #define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \
@@ -273,7 +253,7 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
           kMultiPrecision ? master_params_out[j + start]->data<MT>() \
                           : nullptr); \
     } \
-    platform::ForRange<DeviceContext> for_range(dev_ctx, max_size); \
+    phi::funcs::ForRange<Context> for_range(ctx, max_size); \
     for_range(kernel_params); \
     VLOG(10) << "Launch MergedMomentum kernel " << i << " " \
              << kernel_params.param_num; \
@@ -299,10 +279,10 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
       auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0];
       const MT *master_in_data =
-          multi_precision ? master_params[idx]->data<MT>() : nullptr;
+          multi_precision ? master_params_opt.get()[idx]->data<MT>() : nullptr;
       MT *master_out_data =
           multi_precision ? master_params_out[idx]->data<MT>() : nullptr;
-      if (platform::is_cpu_place(ctx.GetPlace())) {
+      if (paddle::platform::is_cpu_place(ctx.GetPlace())) {
         phi::CPUDenseMomentumFunctor<MT> functor;
         functor(params[idx],
                 grads[idx],
@@ -315,10 +295,9 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
                 params_out[idx],
                 velocitys_out[idx]);
         VLOG(10) << "Launch MergedMomentum cpu kernel.";
-      } else if (platform::is_gpu_place(ctx.GetPlace())) {
-        platform::ForRange<DeviceContext> for_range(
-            static_cast<const DeviceContext &>(ctx.device_context()),
-            params[idx]->numel());
+      } else if (paddle::platform::is_gpu_place(ctx.GetPlace())) {
+        phi::funcs::ForRange<Context> for_range(
+            static_cast<const Context &>(ctx), params[idx]->numel());
 #define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \
   phi::DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor( \
       params[idx]->data<T>(), \
@@ -343,8 +322,7 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
       } else {
         PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(phi::UseNesterov,
                                               phi::RegularizationType::kNONE);
         VLOG(10) << "Launch MergedMomentum gpu kernel use_nesterov kNONE.";
       }
     } else {
       if (regularization_flag == phi::RegularizationType::kL2DECAY) {
```
```diff
@@ -363,8 +341,60 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
       VLOG(10) << "Launch MergedMomentum kernel with multi_lr and regularization.";
     }
   }
-};
+}
 
+template <typename T, typename Context>
+void MergedMomentumKernel(
+    const Context &dev_ctx,
+    const std::vector<const DenseTensor *> &param,
+    const std::vector<const DenseTensor *> &grad,
+    const std::vector<const DenseTensor *> &velocity,
+    const std::vector<const DenseTensor *> &learning_rate,
+    const paddle::optional<std::vector<const DenseTensor *>> &master_param,
+    float mu,
+    bool use_nesterov,
+    const std::vector<std::string> &regularization_method,
+    const std::vector<float> &regularization_coeff,
+    bool multi_precision,
+    float rescale_grad,
+    std::vector<DenseTensor *> param_out,
+    std::vector<DenseTensor *> velocity_out,
+    std::vector<DenseTensor *> master_param_out) {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  if (multi_precision) {
+    MergedMomentumInnerCompute<MPType, Context, MPType, T>(
+        dev_ctx,
+        param,
+        grad,
+        velocity,
+        learning_rate,
+        master_param,
+        mu,
+        use_nesterov,
+        regularization_method,
+        regularization_coeff,
+        rescale_grad,
+        multi_precision,
+        param_out,
+        velocity_out,
+        master_param_out);
+  } else {
+    MergedMomentumInnerCompute<T, Context, MPType, T>(
+        dev_ctx,
+        param,
+        grad,
+        velocity,
+        learning_rate,
+        master_param,
+        mu,
+        use_nesterov,
+        regularization_method,
+        regularization_coeff,
+        rescale_grad,
+        multi_precision,
+        param_out,
+        velocity_out,
+        master_param_out);
+  }
+}
 
-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi
```
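For reference, each (param, grad, velocity) triple handled by the merged kernel follows the standard momentum update implemented by `phi::DenseMomentumFunctor` and `phi::CPUDenseMomentumFunctor` elsewhere in Paddle; the merged kernel's contribution is batching many small parameter tensors into few launches. A simplified scalar sketch of that update (plain momentum, no regularization; this paraphrases the semantics and is not code from this commit):

```cpp
// Simplified per-element momentum update for one parameter tensor.
// Names and the free-standing function are illustrative only.
#include <cstdint>

void momentum_update(int64_t n,
                     float lr,
                     float mu,
                     float rescale_grad,
                     bool use_nesterov,
                     const float* grad,
                     float* param,
                     float* velocity) {
  for (int64_t i = 0; i < n; ++i) {
    const float g = rescale_grad * grad[i];
    const float v = mu * velocity[i] + g;         // velocity_out = mu * v + g
    velocity[i] = v;
    param[i] -= use_nesterov ? lr * (g + mu * v)  // Nesterov look-ahead step
                             : lr * v;            // vanilla momentum step
  }
}
```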
paddle/phi/kernels/merged_momentum_kernel.h (new file, mode 100644)

```cpp
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void MergedMomentumKernel(
    const Context& dev_ctx,
    const std::vector<const DenseTensor*>& param,
    const std::vector<const DenseTensor*>& grad,
    const std::vector<const DenseTensor*>& velocity,
    const std::vector<const DenseTensor*>& learning_rate,
    const paddle::optional<std::vector<const DenseTensor*>>& master_param,
    float mu,
    bool use_nesterov,
    const std::vector<std::string>& regularization_method,
    const std::vector<float>& regularization_coeff,
    bool multi_precision,
    float rescale_grad,
    std::vector<DenseTensor*> param_out,
    std::vector<DenseTensor*> velocity_out,
    std::vector<DenseTensor*> master_param_out);

}  // namespace phi
```
paddle/phi/ops/compat/merged_momentum_sig.cc (new file, mode 100644)

```cpp
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

KernelSignature MergedMomentumOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "merged_momentum",
      {"Param", "Grad", "Velocity", "LearningRate", "MasterParam"},
      {"mu",
       "use_nesterov",
       "regularization_method",
       "regularization_coeff",
       "multi_precision",
       "rescale_grad"},
      {"ParamOut", "VelocityOut", "MasterParamOut"});
}

}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(merged_momentum,
                           phi::MergedMomentumOpArgumentMapping);
```
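This compat signature is what keeps programs built against the old fluid op working: at execution time the fluid op's named inputs, attributes, and outputs are mapped, in order, onto the parameters of the phi `MergedMomentumKernel` declared above. The pattern is the same for any migrated op; here is a hedged sketch for a hypothetical single-tensor op (`scale_by_two` and its mapping function are made-up names, not from this commit):

```cpp
// Hypothetical argument mapping illustrating the {inputs}, {attrs},
// {outputs} structure used by merged_momentum_sig.cc above.
#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

KernelSignature ScaleByTwoOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  // The order must match the phi kernel's parameter order:
  // inputs first, then attributes, then outputs.
  return KernelSignature("scale_by_two", {"X"}, {}, {"Out"});
}

}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(scale_by_two, phi::ScaleByTwoOpArgumentMapping);
```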