Unverified commit d15b490a, authored by Yuang Liu, committed by GitHub

[operator migration] Migrate merged momentum cpu/gpu kernels (#44300)

Parent 84b72c5f
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 
 namespace paddle {
 namespace operators {
@@ -103,7 +103,3 @@ namespace plat = paddle::platform;
 REGISTER_OP_WITHOUT_GRADIENT(merged_momentum,
                              ops::MergedMomentumOp,
                              ops::MergedMomentumOpMaker);
-
-REGISTER_OP_CPU_KERNEL(merged_momentum,
-                       ops::MergedMomentumOpKernel<phi::CPUContext, float>,
-                       ops::MergedMomentumOpKernel<phi::CPUContext, double>);
@@ -12,8 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/mlu/mlu_baseop.h"
-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
+#include "paddle/fluid/platform/for_range.h"
+#include "paddle/fluid/platform/macros.h"
+#include "paddle/phi/kernels/impl/momentum_kernel_impl.h"
 
 namespace paddle {
 namespace operators {
...
@@ -12,8 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
+#include "paddle/fluid/platform/for_range.h"
+#include "paddle/fluid/platform/macros.h"
 #include "paddle/phi/kernels/impl/momentum_kernel_impl.h"
 
 namespace paddle {
...
@@ -17,7 +17,7 @@
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/for_range.h"
-#include "paddle/fluid/platform/macros.h"
+#include "paddle/phi/core/macros.h"
 
 namespace paddle {
 namespace operators {
...
@@ -29,9 +29,3 @@ limitations under the License. */
 #define FLT_MAX __FLT_MAX__
 #endif  // __FLT_MAX__
 #endif  // PADDLE_WITH_MUSL
-
-#if defined(__NVCC__) || defined(__HIPCC__)
-#define PADDLE_RESTRICT __restrict__
-#else
-#define PADDLE_RESTRICT
-#endif
@@ -53,4 +53,10 @@ namespace phi {
 #define PD_CONCATENATE2(arg1, arg2) arg1##arg2
 #define PD_EXPAND(x) x
 
+#if defined(__NVCC__) || defined(__HIPCC__)
+#define PADDLE_RESTRICT __restrict__
+#else
+#define PADDLE_RESTRICT
+#endif
+
 }  // namespace phi
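As context for the macro being moved above, here is a minimal usage sketch (the function below is illustrative, not part of this PR): under NVCC/HIPCC, PADDLE_RESTRICT expands to __restrict__, promising the compiler that the annotated pointers do not alias so device code can be scheduled more aggressively; on other toolchains it expands to nothing, so one signature builds everywhere.

#include "paddle/phi/core/macros.h"

// Hypothetical example: with the no-alias promise from PADDLE_RESTRICT,
// the compiler is free to vectorize or reorder the loads and stores.
template <typename T>
void AxpyExample(int n, T a, const T *PADDLE_RESTRICT x,
                 T *PADDLE_RESTRICT y) {
  for (int i = 0; i < n; ++i) {
    y[i] += a * x[i];
  }
}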
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,13 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-
-REGISTER_OP_CUDA_KERNEL(
-    merged_momentum,
-    ops::MergedMomentumOpKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::MergedMomentumOpKernel<plat::CUDADeviceContext, float>,
-    ops::MergedMomentumOpKernel<plat::CUDADeviceContext, double>);
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/merged_momentum_impl.h"
+
+PD_REGISTER_KERNEL(merged_momentum,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::MergedMomentumKernel,
+                   float,
+                   double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/merged_momentum_impl.h"
PD_REGISTER_KERNEL(merged_momentum,
GPU,
ALL_LAYOUT,
phi::MergedMomentumKernel,
phi::dtype::float16,
float,
double) {}
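For readers unfamiliar with the phi registration macro, the following is an annotated restatement of the GPU registration above, assuming the standard argument order (kernel name, backend, layout, kernel function, then the dtype list). It is a sketch only, not an additional registration; registering the same name/backend pair twice would collide at static-initialization time.

PD_REGISTER_KERNEL(merged_momentum,            // op name the signature fn maps to
                   GPU,                        // backend key
                   ALL_LAYOUT,                 // kernel is layout-agnostic
                   phi::MergedMomentumKernel,  // templated kernel function
                   phi::dtype::float16,        // one instantiation per dtype
                   float,
                   double) {}                  // empty customization body

One registration per backend replaces the per-type REGISTER_OP_CUDA_KERNEL / REGISTER_OP_CPU_KERNEL entries that the old fluid files listed by hand.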
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -14,19 +14,18 @@
 #pragma once
 
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
-#include "paddle/fluid/platform/for_range.h"
-#include "paddle/fluid/platform/macros.h"
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/core/macros.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
 #include "paddle/phi/kernels/impl/momentum_kernel_impl.h"
+#include "paddle/phi/kernels/merged_momentum_kernel.h"
 
-namespace paddle {
-namespace operators {
+namespace phi {
 
 template <typename T>
-using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;
+using MultiPrecisionType = typename phi::dtype::MPTypeTrait<T>::Type;
 
 template <typename MT, uint32_t kParamNum, bool kHasMasterParams>
 struct MergedMomentumMasterParams {
@@ -84,68 +83,62 @@ struct MergedMomentumKernelParam
   }
 };
 
-template <typename DeviceContext, typename T>
-class MergedMomentumOpKernel : public framework::OpKernel<T> {
-  using MPType = typename operators::details::MPTypeTrait<T>::Type;
-
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    const bool multi_precision = ctx.Attr<bool>("multi_precision");
-    if (multi_precision) {
-      InnerCompute<MPType>(ctx, multi_precision);
-    } else {
-      InnerCompute<T>(ctx, multi_precision);
-    }
-  }
-
- private:
-  template <typename MT>
-  void InnerCompute(const framework::ExecutionContext &ctx,
-                    const bool multi_precision) const {
-    auto params = ctx.MultiInput<framework::Tensor>("Param");
-    auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
-    size_t n = params.size();
+template <typename MT, typename Context, typename MPType, typename T>
+void MergedMomentumInnerCompute(
+    const Context &ctx,
+    const std::vector<const DenseTensor *> &params,
+    const std::vector<const DenseTensor *> &grads,
+    const std::vector<const DenseTensor *> &velocitys,
+    const std::vector<const DenseTensor *> &lrs,
+    const paddle::optional<std::vector<const DenseTensor *>> &master_params_opt,
+    float mu,
+    bool use_nesterov,
+    const std::vector<std::string> &regularization_methods,
+    const std::vector<float> &regularization_coeffs,
+    float rescale_grad,
+    const bool multi_precision,
+    std::vector<DenseTensor *> params_out,
+    std::vector<DenseTensor *> velocitys_out,
+    std::vector<DenseTensor *> master_params_out) {
+  size_t n = params.size();
   PADDLE_ENFORCE_EQ(n,
                     params_out.size(),
-                    platform::errors::InvalidArgument(
+                    phi::errors::InvalidArgument(
                         "The size of Output(ParamOut) must be equal to "
                         "Input(Param), but got the size of Output(ParamOut) "
                         "is %d, the size of Input(Param) is %d.",
                         params_out.size(),
                         n));
   for (size_t i = 0; i < n; ++i) {
-    PADDLE_ENFORCE_EQ(params[i],
-                      params_out[i],
-                      platform::errors::InvalidArgument(
-                          "The size of Input(Param) and Output(ParamOut) "
-                          "must be the same Tensors."));
+    PADDLE_ENFORCE_EQ(
+        params[i],
+        params_out[i],
+        phi::errors::InvalidArgument("Input(Param) and Output(ParamOut) "
+                                     "must be the same Tensors."));
   }
 
-  auto grads = ctx.MultiInput<framework::Tensor>("Grad");
   PADDLE_ENFORCE_EQ(
       n,
       grads.size(),
-      platform::errors::InvalidArgument(
+      phi::errors::InvalidArgument(
           "The size of Input(Grad) must be equal to Input(Param), but got "
           "the size of Input(Grad) is %d, the size of Input(Param) is %d.",
          grads.size(),
          n));
 
-  auto velocitys = ctx.MultiInput<framework::Tensor>("Velocity");
   PADDLE_ENFORCE_EQ(n,
                     velocitys.size(),
-                    platform::errors::InvalidArgument(
+                    phi::errors::InvalidArgument(
                         "The size of Input(Velocity) must be equal to "
                         "Input(Param), but got the size of Input(Velocity) "
                         "is %d, the size of Input(Param) is %d.",
                         velocitys.size(),
                         n));
 
-  auto velocitys_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
   PADDLE_ENFORCE_EQ(
       n,
       velocitys_out.size(),
-      platform::errors::InvalidArgument(
+      phi::errors::InvalidArgument(
          "The size of Output(VelocityOut) must be "
          "equal to Input(Param), but got the size of Output(VelocityOut) is "
          "%d, the size of Input(Param) is %d.",
@@ -154,19 +147,17 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
   for (size_t i = 0; i < n; ++i) {
     PADDLE_ENFORCE_EQ(velocitys[i],
                       velocitys_out[i],
-                      platform::errors::InvalidArgument(
+                      phi::errors::InvalidArgument(
                           "Input(Velocity) and Output(VelocityOut) must be "
                           "the same Tensors."));
   }
 
-  auto master_params = ctx.MultiInput<framework::Tensor>("MasterParam");
-  auto master_params_out =
-      ctx.MultiOutput<framework::Tensor>("MasterParamOut");
   if (multi_precision) {
+    auto master_params = master_params_opt.get();
     PADDLE_ENFORCE_EQ(
         n,
         master_params.size(),
-        platform::errors::InvalidArgument(
+        phi::errors::InvalidArgument(
             "The size of Input(MasterParam) must be "
             "equal to Input(Param), but got the size of Input(MasterParam) "
             "is %d, the size of Input(Param) is %d.",
@@ -175,7 +166,7 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(
         n,
         master_params_out.size(),
-        platform::errors::InvalidArgument(
+        phi::errors::InvalidArgument(
             "The size of Output(MasterParamOut) must be equal to "
             "Input(MasterParam), but got the size of Output(MasterParamOut) "
             "is %d, the size of Input(Param) is %d.",
@@ -184,27 +175,23 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
     for (size_t i = 0; i < n; ++i) {
       PADDLE_ENFORCE_EQ(master_params[i],
                         master_params_out[i],
-                        platform::errors::InvalidArgument(
+                        phi::errors::InvalidArgument(
                             "Input(MasterParam) and Output(MasterParamOut) "
                             "must be the same Tensors."));
       PADDLE_ENFORCE_NOT_NULL(master_params[i],
-                              platform::errors::InvalidArgument(
+                              phi::errors::InvalidArgument(
                                   "Input(MasterParam) must be provided when "
                                   "multi_precision=True."));
     }
   } else {
-    master_params.clear();
     master_params_out.clear();
   }
 
-  auto mu = ctx.Attr<float>("mu");
-  auto rescale_grad = ctx.Attr<float>("rescale_grad");
-  auto lrs = ctx.MultiInput<framework::Tensor>("LearningRate");
   if (lrs.size() != 1) {
     PADDLE_ENFORCE_EQ(
         n,
         lrs.size(),
-        platform::errors::InvalidArgument(
+        phi::errors::InvalidArgument(
            "If the size of Input(LearningRate) is not 1, the size of "
            "Input(LearningRate) must be "
            "equal to Input(Param), but got the size of Input(LearningRate) "
@@ -212,16 +199,11 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
            lrs.size(),
            n));
   }
 
-  auto use_nesterov = ctx.Attr<bool>("use_nesterov");
-  auto regularization_methods =
-      ctx.Attr<std::vector<std::string>>("regularization_method");
-  auto regularization_coeffs =
-      ctx.Attr<std::vector<float>>("regularization_coeff");
   if (regularization_methods.size() != 0) {
     PADDLE_ENFORCE_EQ(
         n,
         regularization_methods.size(),
-        platform::errors::InvalidArgument(
+        phi::errors::InvalidArgument(
            "The size of Attr(regularization_method) must be equal "
            "to Input(Param), but got the size of "
            "Attr(regularization_method) is %d, the size of Input(Param) is "
@@ -231,7 +213,7 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(
         n,
         regularization_coeffs.size(),
-        platform::errors::InvalidArgument(
+        phi::errors::InvalidArgument(
            "The size of Attr(regularization_coeff) must be equal "
            "to Input(Param), but got the size of Attr(regularization_coeff) "
            "is %d, the size of Input(Param) is %d.",
@@ -245,8 +227,6 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
           << ", regularization_coeffs.size(): "
          << regularization_coeffs.size();
 
-  auto &dev_ctx = ctx.template device_context<DeviceContext>();
-
   if (lrs.size() == 1 && use_nesterov == false &&
       regularization_methods.size() == 0) {
 #define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \
@@ -273,7 +253,7 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
         kMultiPrecision ? master_params_out[j + start]->data<MT>() \
                         : nullptr);                                \
   }                                                                \
-  platform::ForRange<DeviceContext> for_range(dev_ctx, max_size);  \
+  phi::funcs::ForRange<Context> for_range(ctx, max_size);          \
   for_range(kernel_params);                                        \
   VLOG(10) << "Launch MergedMomentum kernel " << i << " "          \
            << kernel_params.param_num;                             \
@@ -299,10 +279,10 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
     auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0];
     const MT *master_in_data =
-        multi_precision ? master_params[idx]->data<MT>() : nullptr;
+        multi_precision ? master_params_opt.get()[idx]->data<MT>() : nullptr;
     MT *master_out_data =
         multi_precision ? master_params_out[idx]->data<MT>() : nullptr;
-    if (platform::is_cpu_place(ctx.GetPlace())) {
+    if (paddle::platform::is_cpu_place(ctx.GetPlace())) {
       phi::CPUDenseMomentumFunctor<MT> functor;
       functor(params[idx],
               grads[idx],
@@ -315,10 +295,9 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
              params_out[idx],
              velocitys_out[idx]);
       VLOG(10) << "Launch MergedMomentum cpu kernel.";
-    } else if (platform::is_gpu_place(ctx.GetPlace())) {
-      platform::ForRange<DeviceContext> for_range(
-          static_cast<const DeviceContext &>(ctx.device_context()),
-          params[idx]->numel());
+    } else if (paddle::platform::is_gpu_place(ctx.GetPlace())) {
+      phi::funcs::ForRange<Context> for_range(
+          static_cast<const Context &>(ctx), params[idx]->numel());
 #define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \
   phi::DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor(   \
       params[idx]->data<T>(),                                         \
@@ -343,8 +322,7 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
       } else {
         PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
             phi::UseNesterov, phi::RegularizationType::kNONE);
-        VLOG(10)
-            << "Launch MergedMomentum gpu kernel use_nesterov kNONE.";
+        VLOG(10) << "Launch MergedMomentum gpu kernel use_nesterov kNONE.";
       }
     } else {
       if (regularization_flag == phi::RegularizationType::kL2DECAY) {
@@ -363,8 +341,60 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
       VLOG(10)
          << "Launch MergedMomentum kernel with multi_lr and regularization.";
     }
   }
-  }
-};
+}
+
+template <typename T, typename Context>
+void MergedMomentumKernel(
+    const Context &dev_ctx,
+    const std::vector<const DenseTensor *> &param,
+    const std::vector<const DenseTensor *> &grad,
+    const std::vector<const DenseTensor *> &velocity,
+    const std::vector<const DenseTensor *> &learning_rate,
+    const paddle::optional<std::vector<const DenseTensor *>> &master_param,
+    float mu,
+    bool use_nesterov,
+    const std::vector<std::string> &regularization_method,
+    const std::vector<float> &regularization_coeff,
+    bool multi_precision,
+    float rescale_grad,
+    std::vector<DenseTensor *> param_out,
+    std::vector<DenseTensor *> velocity_out,
+    std::vector<DenseTensor *> master_param_out) {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  if (multi_precision) {
+    MergedMomentumInnerCompute<MPType, Context, MPType, T>(
+        dev_ctx,
+        param,
+        grad,
+        velocity,
+        learning_rate,
+        master_param,
+        mu,
+        use_nesterov,
+        regularization_method,
+        regularization_coeff,
+        rescale_grad,
+        multi_precision,
+        param_out,
+        velocity_out,
+        master_param_out);
+  } else {
+    MergedMomentumInnerCompute<T, Context, MPType, T>(dev_ctx,
+                                                      param,
+                                                      grad,
+                                                      velocity,
+                                                      learning_rate,
+                                                      master_param,
+                                                      mu,
+                                                      use_nesterov,
+                                                      regularization_method,
+                                                      regularization_coeff,
+                                                      rescale_grad,
+                                                      multi_precision,
+                                                      param_out,
+                                                      velocity_out,
+                                                      master_param_out);
+  }
+}
 
-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi
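A small sanity sketch of the multi_precision dispatch above (illustrative, not part of this PR, assuming phi::dtype::MPTypeTrait behaves as the amp_type_traits.h header included by the new impl suggests): fp16 parameters are updated with fp32 master math, while fp32/fp64 are their own math type, which is why both branches can share MergedMomentumInnerCompute.

#include <type_traits>

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"

// Hypothetical compile-time checks, not present in this diff.
static_assert(std::is_same<phi::dtype::MPTypeTrait<phi::dtype::float16>::Type,
                           float>::value,
              "float16 accumulates in float when multi_precision is set");
static_assert(std::is_same<phi::dtype::MPTypeTrait<float>::Type, float>::value,
              "float is already its own multi-precision type");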
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MergedMomentumKernel(
const Context& dev_ctx,
const std::vector<const DenseTensor*>& param,
const std::vector<const DenseTensor*>& grad,
const std::vector<const DenseTensor*>& velocity,
const std::vector<const DenseTensor*>& learning_rate,
const paddle::optional<std::vector<const DenseTensor*>>& master_param,
float mu,
bool use_nesterov,
const std::vector<std::string>& regularization_method,
const std::vector<float>& regularization_coeff,
bool multi_precision,
float rescale_grad,
std::vector<DenseTensor*> param_out,
std::vector<DenseTensor*> velocity_out,
std::vector<DenseTensor*> master_param_out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature MergedMomentumOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"merged_momentum",
{"Param", "Grad", "Velocity", "LearningRate", "MasterParam"},
{"mu",
"use_nesterov",
"regularization_method",
"regularization_coeff",
"multi_precision",
"rescale_grad"},
{
"ParamOut",
"VelocityOut",
"MasterParamOut",
});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(merged_momentum,
phi::MergedMomentumOpArgumentMapping);
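For readers new to the phi dispatch layer, a hedged restatement of what this mapping does: it pairs the legacy operator's string-keyed inputs, attributes, and outputs with the positional parameters of phi::MergedMomentumKernel declared above, so the old merged_momentum op and the new kernel remain interchangeable without per-op glue code.

// Correspondence implied by the KernelSignature above (annotation only):
//   "Param"        -> param            "ParamOut"       -> param_out
//   "Grad"         -> grad             "VelocityOut"    -> velocity_out
//   "Velocity"     -> velocity         "MasterParamOut" -> master_param_out
//   "LearningRate" -> learning_rate
//   "MasterParam"  -> master_param (optional; only read when multi_precision)
//   attrs: mu, use_nesterov, regularization_method,
//          regularization_coeff, multi_precision, rescale_grad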