Unverified commit d15b490a, authored by Yuang Liu and committed by GitHub

[operator migration] Migrate merged momentum cpu/gpu kernels (#44300)

Parent 84b72c5f
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 namespace paddle {
 namespace operators {
@@ -103,7 +103,3 @@ namespace plat = paddle::platform;
 REGISTER_OP_WITHOUT_GRADIENT(merged_momentum,
                              ops::MergedMomentumOp,
                              ops::MergedMomentumOpMaker);
-REGISTER_OP_CPU_KERNEL(merged_momentum,
-                       ops::MergedMomentumOpKernel<phi::CPUContext, float>,
-                       ops::MergedMomentumOpKernel<phi::CPUContext, double>);
@@ -12,8 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/mlu/mlu_baseop.h"
-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
+#include "paddle/fluid/platform/for_range.h"
+#include "paddle/fluid/platform/macros.h"
+#include "paddle/phi/kernels/impl/momentum_kernel_impl.h"
 namespace paddle {
 namespace operators {
......
@@ -12,8 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
+#include "paddle/fluid/platform/for_range.h"
+#include "paddle/fluid/platform/macros.h"
 #include "paddle/phi/kernels/impl/momentum_kernel_impl.h"
 namespace paddle {
......
@@ -17,7 +17,7 @@
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/for_range.h"
-#include "paddle/fluid/platform/macros.h"
+#include "paddle/phi/core/macros.h"
 namespace paddle {
 namespace operators {
......
@@ -29,9 +29,3 @@ limitations under the License. */
 #define FLT_MAX __FLT_MAX__
 #endif  // __FLT_MAX__
 #endif  // PADDLE_WITH_MUSL
-#if defined(__NVCC__) || defined(__HIPCC__)
-#define PADDLE_RESTRICT __restrict__
-#else
-#define PADDLE_RESTRICT
-#endif
@@ -53,4 +53,10 @@ namespace phi {
 #define PD_CONCATENATE2(arg1, arg2) arg1##arg2
 #define PD_EXPAND(x) x
+#if defined(__NVCC__) || defined(__HIPCC__)
+#define PADDLE_RESTRICT __restrict__
+#else
+#define PADDLE_RESTRICT
+#endif
 }  // namespace phi
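The PADDLE_RESTRICT helper now lives next to the other phi macros: it expands to __restrict__ only when compiled by NVCC or HIPCC and to nothing otherwise, so pointer-heavy kernel signatures stay portable on host-only builds. A minimal sketch of the intended usage pattern follows; the ScaleExample function and its parameters are illustrative assumptions, not code from this commit.

// Illustrative only (not part of this commit): PADDLE_RESTRICT marks
// non-aliasing pointers. Under __NVCC__/__HIPCC__ it becomes __restrict__,
// letting the compiler schedule loads/stores more aggressively; on a
// host-only build it expands to nothing and the signature is unchanged.
template <typename T>
void ScaleExample(const T *PADDLE_RESTRICT x,
                  T *PADDLE_RESTRICT y,
                  T alpha,
                  int n) {
  for (int i = 0; i < n; ++i) {
    y[i] = alpha * x[i];
  }
}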
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,13 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(
-    merged_momentum,
-    ops::MergedMomentumOpKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::MergedMomentumOpKernel<plat::CUDADeviceContext, float>,
-    ops::MergedMomentumOpKernel<plat::CUDADeviceContext, double>);
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/merged_momentum_impl.h"
+PD_REGISTER_KERNEL(merged_momentum,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::MergedMomentumKernel,
+                   float,
+                   double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/merged_momentum_impl.h"
PD_REGISTER_KERNEL(merged_momentum,
GPU,
ALL_LAYOUT,
phi::MergedMomentumKernel,
phi::dtype::float16,
float,
double) {}
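Only the GPU registration lists phi::dtype::float16, mirroring the old REGISTER_OP_CUDA_KERNEL dtype list. The float16 path leans on the MPTypeTrait promotion used by the MultiPrecisionType alias in the migrated impl header below; the compile-time checks here are an illustration of the assumed promotion rules and are not part of the commit.

#include <type_traits>
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"

// Assumed promotion rules (illustrative, not from this commit): float16 is
// expected to accumulate in float via MPTypeTrait, while float and double map
// to themselves, which is why the multi_precision branch instantiates the
// inner compute with MPType.
static_assert(std::is_same<phi::dtype::MPTypeTrait<phi::dtype::float16>::Type,
                           float>::value,
              "float16 math is expected to run in float");
static_assert(std::is_same<phi::dtype::MPTypeTrait<float>::Type, float>::value,
              "float is expected to map to itself");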
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -14,19 +14,18 @@
 #pragma once
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
-#include "paddle/fluid/platform/for_range.h"
-#include "paddle/fluid/platform/macros.h"
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/core/macros.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
 #include "paddle/phi/kernels/impl/momentum_kernel_impl.h"
+#include "paddle/phi/kernels/merged_momentum_kernel.h"
-namespace paddle {
-namespace operators {
+namespace phi {
 template <typename T>
-using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;
+using MultiPrecisionType = typename phi::dtype::MPTypeTrait<T>::Type;
 template <typename MT, uint32_t kParamNum, bool kHasMasterParams>
 struct MergedMomentumMasterParams {
@@ -84,171 +83,152 @@ struct MergedMomentumKernelParam
   }
 };
template <typename DeviceContext, typename T> template <typename MT, typename Context, typename MPType, typename T>
class MergedMomentumOpKernel : public framework::OpKernel<T> { void MergedMomentumInnerCompute(
using MPType = typename operators::details::MPTypeTrait<T>::Type; const Context &ctx,
const std::vector<const DenseTensor *> &params,
const std::vector<const DenseTensor *> &grads,
const std::vector<const DenseTensor *> &velocitys,
const std::vector<const DenseTensor *> &lrs,
const paddle::optional<std::vector<const DenseTensor *>> &master_params_opt,
float mu,
bool use_nesterov,
const std::vector<std::string> &regularization_methods,
const std::vector<float> &regularization_coeffs,
float rescale_grad,
const bool multi_precision,
std::vector<DenseTensor *> params_out,
std::vector<DenseTensor *> velocitys_out,
std::vector<DenseTensor *> master_params_out) {
size_t n = params.size();
PADDLE_ENFORCE_EQ(n,
params_out.size(),
phi::errors::InvalidArgument(
"The size of Output(ParamOut) must be equal to "
"Input(Param), but got the size of Output(ParamOut) "
"is %d, the size of Input(Param) is %d.",
params_out.size(),
n));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(
params[i],
params_out[i],
phi::errors::InvalidArgument("Input(Param) and Output(ParamOut) "
"must be the same Tensors."));
}
public: PADDLE_ENFORCE_EQ(
void Compute(const framework::ExecutionContext &ctx) const override { n,
const bool multi_precision = ctx.Attr<bool>("multi_precision"); grads.size(),
if (multi_precision) { phi::errors::InvalidArgument(
InnerCompute<MPType>(ctx, multi_precision); "The size of Input(Grad) must be equal to Input(Param), but got "
} else { "the size of Input(Grad) is %d, the size of Input(Param) is %d.",
InnerCompute<T>(ctx, multi_precision); grads.size(),
} n));
PADDLE_ENFORCE_EQ(n,
velocitys.size(),
phi::errors::InvalidArgument(
"The size of Input(Velocity) must be equal to "
"Input(Param), but got the size of Input(Velocity) "
"is %d, the size of Input(Param) is %d.",
velocitys.size(),
n));
PADDLE_ENFORCE_EQ(
n,
velocitys_out.size(),
phi::errors::InvalidArgument(
"The size of Output(VelocityOut) must be "
"equal to Input(Param), but got the size of Output(VelocityOut) is "
"%d, the size of Input(Param) is %d.",
velocitys_out.size(),
n));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(velocitys[i],
velocitys_out[i],
phi::errors::InvalidArgument(
"Input(Velocity) and Output(VelocityOut) must be "
"the same Tensors."));
} }
private: if (multi_precision) {
template <typename MT> auto master_params = master_params_opt.get();
void InnerCompute(const framework::ExecutionContext &ctx, PADDLE_ENFORCE_EQ(
const bool multi_precision) const { n,
auto params = ctx.MultiInput<framework::Tensor>("Param"); master_params.size(),
auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut"); phi::errors::InvalidArgument(
size_t n = params.size(); "The size of Input(MasterParam) must be "
PADDLE_ENFORCE_EQ(n, "equal to Input(Param), but got the size of Input(MasterParam) "
params_out.size(), "is %d, the size of Input(Param) is %d.",
platform::errors::InvalidArgument( master_params.size(),
"The size of Output(ParamOut) must be equal to " n));
"Input(Param), but got the size of Output(ParamOut) " PADDLE_ENFORCE_EQ(
"is %d, the size of Input(Param) is %d.", n,
params_out.size(), master_params_out.size(),
n)); phi::errors::InvalidArgument(
"The size of Output(MasterParamOut) must be equal to "
"Input(MasterParam), but got the size of Output(MasterParamOut) "
"is %d, the size of Input(Param) is %d.",
master_params_out.size(),
n));
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(params[i], PADDLE_ENFORCE_EQ(master_params[i],
params_out[i], master_params_out[i],
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The size of Input(Param) and Output(ParamOut) " "Input(MasterParam) and Output(MasterParamOut) "
"must be the same Tensors.")); "must be the same Tensors."));
PADDLE_ENFORCE_NOT_NULL(master_params[i],
phi::errors::InvalidArgument(
"Input(MasterParam) must be provided when "
"multi_precision=True."));
} }
} else {
master_params_out.clear();
}
auto grads = ctx.MultiInput<framework::Tensor>("Grad"); if (lrs.size() != 1) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
n, n,
grads.size(), lrs.size(),
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The size of Input(Grad) must be equal to Input(Param), but got " "If the size of Input(LearningRate) is not 1, the size of "
"the size of Input(Grad) is %d, the size of Input(Param) is %d.", "Input(LearningRate) must be "
grads.size(), "equal to Input(Param), but got the size of Input(LearningRate) "
"is %d, the size of Input(Param) is %d.",
lrs.size(),
n)); n));
}
auto velocitys = ctx.MultiInput<framework::Tensor>("Velocity"); if (regularization_methods.size() != 0) {
PADDLE_ENFORCE_EQ(n,
velocitys.size(),
platform::errors::InvalidArgument(
"The size of Input(Velocity) must be equal to "
"Input(Param), but got the size of Input(Velocity) "
"is %d, the size of Input(Param) is %d.",
velocitys.size(),
n));
auto velocitys_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
n, n,
velocitys_out.size(), regularization_methods.size(),
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The size of Output(VelocityOut) must be " "The size of Attr(regularization_method) must be equal "
"equal to Input(Param), but got the size of Output(VelocityOut) is " "to Input(Param), but got the size of "
"%d, the size of Input(Param) is %d.", "Attr(regularization_method) is %d, the size of Input(Param) is "
velocitys_out.size(), "%d.",
regularization_methods.size(),
n)); n));
for (size_t i = 0; i < n; ++i) { PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(velocitys[i], n,
velocitys_out[i], regularization_coeffs.size(),
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"Input(Velocity) and Output(VelocityOut) must be " "The size of Attr(regularization_coeff) must be equal "
"the same Tensors.")); "to Input(Param), but got the size of Attr(regularization_coeff) "
} "is %d, the size of Input(Param) is %d.",
regularization_coeffs.size(),
auto master_params = ctx.MultiInput<framework::Tensor>("MasterParam"); n));
auto master_params_out = }
ctx.MultiOutput<framework::Tensor>("MasterParamOut");
if (multi_precision) {
PADDLE_ENFORCE_EQ(
n,
master_params.size(),
platform::errors::InvalidArgument(
"The size of Input(MasterParam) must be "
"equal to Input(Param), but got the size of Input(MasterParam) "
"is %d, the size of Input(Param) is %d.",
master_params.size(),
n));
PADDLE_ENFORCE_EQ(
n,
master_params_out.size(),
platform::errors::InvalidArgument(
"The size of Output(MasterParamOut) must be equal to "
"Input(MasterParam), but got the size of Output(MasterParamOut) "
"is %d, the size of Input(Param) is %d.",
master_params_out.size(),
n));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(master_params[i],
master_params_out[i],
platform::errors::InvalidArgument(
"Input(MasterParam) and Output(MasterParamOut) "
"must be the same Tensors."));
PADDLE_ENFORCE_NOT_NULL(master_params[i],
platform::errors::InvalidArgument(
"Input(MasterParam) must be provided when "
"multi_precision=True."));
}
} else {
master_params.clear();
master_params_out.clear();
}
auto mu = ctx.Attr<float>("mu");
auto rescale_grad = ctx.Attr<float>("rescale_grad");
auto lrs = ctx.MultiInput<framework::Tensor>("LearningRate");
if (lrs.size() != 1) {
PADDLE_ENFORCE_EQ(
n,
lrs.size(),
platform::errors::InvalidArgument(
"If the size of Input(LearningRate) is not 1, the size of "
"Input(LearningRate) must be "
"equal to Input(Param), but got the size of Input(LearningRate) "
"is %d, the size of Input(Param) is %d.",
lrs.size(),
n));
}
auto use_nesterov = ctx.Attr<bool>("use_nesterov");
auto regularization_methods =
ctx.Attr<std::vector<std::string>>("regularization_method");
auto regularization_coeffs =
ctx.Attr<std::vector<float>>("regularization_coeff");
if (regularization_methods.size() != 0) {
PADDLE_ENFORCE_EQ(
n,
regularization_methods.size(),
platform::errors::InvalidArgument(
"The size of Attr(regularization_method) must be equal "
"to Input(Param), but got the size of "
"Attr(regularization_method) is %d, the size of Input(Param) is "
"%d.",
regularization_methods.size(),
n));
PADDLE_ENFORCE_EQ(
n,
regularization_coeffs.size(),
platform::errors::InvalidArgument(
"The size of Attr(regularization_coeff) must be equal "
"to Input(Param), but got the size of Attr(regularization_coeff) "
"is %d, the size of Input(Param) is %d.",
regularization_coeffs.size(),
n));
}
VLOG(5) << "use_nesterov: " << use_nesterov
<< ", regularization_methods.size(): "
<< regularization_methods.size()
<< ", regularization_coeffs.size(): "
<< regularization_coeffs.size();
auto &dev_ctx = ctx.template device_context<DeviceContext>(); VLOG(5) << "use_nesterov: " << use_nesterov
<< ", regularization_methods.size(): "
<< regularization_methods.size()
<< ", regularization_coeffs.size(): "
<< regularization_coeffs.size();
if (lrs.size() == 1 && use_nesterov == false && if (lrs.size() == 1 && use_nesterov == false &&
regularization_methods.size() == 0) { regularization_methods.size() == 0) {
#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \ #define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \
MergedMomentumKernelParam<T, MT, kMultiPrecision> kernel_params; \ MergedMomentumKernelParam<T, MT, kMultiPrecision> kernel_params; \
constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \ constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \
@@ -273,52 +253,51 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
kMultiPrecision ? master_params_out[j + start]->data<MT>() \ kMultiPrecision ? master_params_out[j + start]->data<MT>() \
: nullptr); \ : nullptr); \
} \ } \
platform::ForRange<DeviceContext> for_range(dev_ctx, max_size); \ phi::funcs::ForRange<Context> for_range(ctx, max_size); \
for_range(kernel_params); \ for_range(kernel_params); \
VLOG(10) << "Launch MergedMomentum kernel " << i << " " \ VLOG(10) << "Launch MergedMomentum kernel " << i << " " \
<< kernel_params.param_num; \ << kernel_params.param_num; \
} }
if (multi_precision) { if (multi_precision) {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true); PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
} else {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
}
#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
} else { } else {
for (size_t idx = 0; idx < n; idx++) { PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
phi::RegularizationType regularization_flag = }
regularization_methods.size() > 0 && #undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
regularization_methods[idx] == "l2_decay" } else {
? phi::RegularizationType::kL2DECAY for (size_t idx = 0; idx < n; idx++) {
: phi::RegularizationType::kNONE; phi::RegularizationType regularization_flag =
regularization_methods.size() > 0 &&
regularization_methods[idx] == "l2_decay"
? phi::RegularizationType::kL2DECAY
: phi::RegularizationType::kNONE;
MT regularization_coeff = static_cast<MT>(0.0); MT regularization_coeff = static_cast<MT>(0.0);
if (regularization_coeffs.size() != 0) { if (regularization_coeffs.size() != 0) {
regularization_coeff = static_cast<MT>(regularization_coeffs[idx]); regularization_coeff = static_cast<MT>(regularization_coeffs[idx]);
} }
auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0]; auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0];
const MT *master_in_data = const MT *master_in_data =
multi_precision ? master_params[idx]->data<MT>() : nullptr; multi_precision ? master_params_opt.get()[idx]->data<MT>() : nullptr;
MT *master_out_data = MT *master_out_data =
multi_precision ? master_params_out[idx]->data<MT>() : nullptr; multi_precision ? master_params_out[idx]->data<MT>() : nullptr;
if (platform::is_cpu_place(ctx.GetPlace())) { if (paddle::platform::is_cpu_place(ctx.GetPlace())) {
phi::CPUDenseMomentumFunctor<MT> functor; phi::CPUDenseMomentumFunctor<MT> functor;
functor(params[idx], functor(params[idx],
grads[idx], grads[idx],
velocitys[idx], velocitys[idx],
lr_temp, lr_temp,
static_cast<MT>(mu), static_cast<MT>(mu),
use_nesterov, use_nesterov,
regularization_flag, regularization_flag,
regularization_coeff, regularization_coeff,
params_out[idx], params_out[idx],
velocitys_out[idx]); velocitys_out[idx]);
VLOG(10) << "Launch MergedMomentum cpu kernel."; VLOG(10) << "Launch MergedMomentum cpu kernel.";
} else if (platform::is_gpu_place(ctx.GetPlace())) { } else if (paddle::platform::is_gpu_place(ctx.GetPlace())) {
platform::ForRange<DeviceContext> for_range( phi::funcs::ForRange<Context> for_range(
static_cast<const DeviceContext &>(ctx.device_context()), static_cast<const Context &>(ctx), params[idx]->numel());
params[idx]->numel());
#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \ #define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \
phi::DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor( \ phi::DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor( \
params[idx]->data<T>(), \ params[idx]->data<T>(), \
@@ -334,37 +313,88 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
velocitys_out[idx]->data<MT>(), \ velocitys_out[idx]->data<MT>(), \
master_out_data); \ master_out_data); \
for_range(functor); for_range(functor);
if (use_nesterov) { if (use_nesterov) {
if (regularization_flag == phi::RegularizationType::kL2DECAY) { if (regularization_flag == phi::RegularizationType::kL2DECAY) {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL( PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
phi::UseNesterov, phi::RegularizationType::kL2DECAY); phi::UseNesterov, phi::RegularizationType::kL2DECAY);
VLOG(10) VLOG(10)
<< "Launch MergedMomentum gpu kernel use_nesterov kL2DECAY."; << "Launch MergedMomentum gpu kernel use_nesterov kL2DECAY.";
} else {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
phi::UseNesterov, phi::RegularizationType::kNONE);
VLOG(10)
<< "Launch MergedMomentum gpu kernel use_nesterov kNONE.";
}
} else { } else {
if (regularization_flag == phi::RegularizationType::kL2DECAY) { PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL( phi::UseNesterov, phi::RegularizationType::kNONE);
phi::NoNesterov, phi::RegularizationType::kL2DECAY); VLOG(10) << "Launch MergedMomentum gpu kernel use_nesterov kNONE.";
VLOG(10) }
<< "Launch MergedMomentum gpu kernel no_nesterov kL2DECAY."; } else {
} else { if (regularization_flag == phi::RegularizationType::kL2DECAY) {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL( PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
phi::NoNesterov, phi::RegularizationType::kNONE); phi::NoNesterov, phi::RegularizationType::kL2DECAY);
VLOG(10) << "Launch MergedMomentum gpu kernel no_nesterov kNONE."; VLOG(10)
} << "Launch MergedMomentum gpu kernel no_nesterov kL2DECAY.";
} else {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
phi::NoNesterov, phi::RegularizationType::kNONE);
VLOG(10) << "Launch MergedMomentum gpu kernel no_nesterov kNONE.";
} }
} }
} }
VLOG(10)
<< "Launch MergedMomentum kernel with multi_lr and regularization.";
} }
VLOG(10)
<< "Launch MergedMomentum kernel with multi_lr and regularization.";
} }
}; }
template <typename T, typename Context>
void MergedMomentumKernel(
const Context &dev_ctx,
const std::vector<const DenseTensor *> &param,
const std::vector<const DenseTensor *> &grad,
const std::vector<const DenseTensor *> &velocity,
const std::vector<const DenseTensor *> &learning_rate,
const paddle::optional<std::vector<const DenseTensor *>> &master_param,
float mu,
bool use_nesterov,
const std::vector<std::string> &regularization_method,
const std::vector<float> &regularization_coeff,
bool multi_precision,
float rescale_grad,
std::vector<DenseTensor *> param_out,
std::vector<DenseTensor *> velocity_out,
std::vector<DenseTensor *> master_param_out) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
if (multi_precision) {
MergedMomentumInnerCompute<MPType, Context, MPType, T>(
dev_ctx,
param,
grad,
velocity,
learning_rate,
master_param,
mu,
use_nesterov,
regularization_method,
regularization_coeff,
rescale_grad,
multi_precision,
param_out,
velocity_out,
master_param_out);
} else {
MergedMomentumInnerCompute<T, Context, MPType, T>(dev_ctx,
param,
grad,
velocity,
learning_rate,
master_param,
mu,
use_nesterov,
regularization_method,
regularization_coeff,
rescale_grad,
multi_precision,
param_out,
velocity_out,
master_param_out);
}
}
-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MergedMomentumKernel(
const Context& dev_ctx,
const std::vector<const DenseTensor*>& param,
const std::vector<const DenseTensor*>& grad,
const std::vector<const DenseTensor*>& velocity,
const std::vector<const DenseTensor*>& learning_rate,
const paddle::optional<std::vector<const DenseTensor*>>& master_param,
float mu,
bool use_nesterov,
const std::vector<std::string>& regularization_method,
const std::vector<float>& regularization_coeff,
bool multi_precision,
float rescale_grad,
std::vector<DenseTensor*> param_out,
std::vector<DenseTensor*> velocity_out,
std::vector<DenseTensor*> master_param_out);
} // namespace phi
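For orientation, the sketch below shows how the kernel declared above could be called directly from C++ for a single FP32 parameter group. Everything in it is a hypothetical usage example rather than code from this PR: the tensors p0/g0/v0/lr are assumed to be pre-initialized phi::DenseTensor objects, the device context is assumed to be fully set up with an allocator, and the impl header is assumed to be included so the template definition is visible. The in-place contract enforced in merged_momentum_impl.h still applies: each entry of param_out/velocity_out must be the very same tensor as the corresponding param/velocity input.

// Hypothetical call site (not part of this PR); p0, g0, v0, lr are assumed
// pre-initialized CPU phi::DenseTensor objects.
phi::CPUContext dev_ctx;  // assumed to already have an allocator attached
std::vector<const phi::DenseTensor *> params{&p0};
std::vector<const phi::DenseTensor *> grads{&g0};
std::vector<const phi::DenseTensor *> velocities{&v0};
std::vector<const phi::DenseTensor *> lrs{&lr};
// In-place update: the output vectors alias the input tensors.
std::vector<phi::DenseTensor *> params_out{&p0};
std::vector<phi::DenseTensor *> velocities_out{&v0};

phi::MergedMomentumKernel<float, phi::CPUContext>(
    dev_ctx, params, grads, velocities, lrs,
    /*master_param=*/paddle::optional<std::vector<const phi::DenseTensor *>>(),
    /*mu=*/0.9f,
    /*use_nesterov=*/false,
    /*regularization_method=*/{},
    /*regularization_coeff=*/{},
    /*multi_precision=*/false,
    /*rescale_grad=*/1.0f,
    params_out, velocities_out,
    /*master_param_out=*/{});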
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature MergedMomentumOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"merged_momentum",
{"Param", "Grad", "Velocity", "LearningRate", "MasterParam"},
{"mu",
"use_nesterov",
"regularization_method",
"regularization_coeff",
"multi_precision",
"rescale_grad"},
{
"ParamOut",
"VelocityOut",
"MasterParamOut",
});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(merged_momentum,
phi::MergedMomentumOpArgumentMapping);