未验证 提交 d15b490a 编写于 作者: Y Yuang Liu 提交者: GitHub

[operator migration] Migrate merged momentum cpu/gpu kernels (#44300)

上级 84b72c5f
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
......@@ -103,7 +103,3 @@ namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(merged_momentum,
ops::MergedMomentumOp,
ops::MergedMomentumOpMaker);
REGISTER_OP_CPU_KERNEL(merged_momentum,
ops::MergedMomentumOpKernel<phi::CPUContext, float>,
ops::MergedMomentumOpKernel<phi::CPUContext, double>);
......@@ -12,8 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/phi/kernels/impl/momentum_kernel_impl.h"
namespace paddle {
namespace operators {
......
......@@ -12,8 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/phi/kernels/impl/momentum_kernel_impl.h"
namespace paddle {
......
......@@ -17,7 +17,7 @@
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/phi/core/macros.h"
namespace paddle {
namespace operators {
......
......@@ -29,9 +29,3 @@ limitations under the License. */
#define FLT_MAX __FLT_MAX__
#endif // __FLT_MAX__
#endif // PADDLE_WITH_MUSL
#if defined(__NVCC__) || defined(__HIPCC__)
#define PADDLE_RESTRICT __restrict__
#else
#define PADDLE_RESTRICT
#endif
......@@ -53,4 +53,10 @@ namespace phi {
#define PD_CONCATENATE2(arg1, arg2) arg1##arg2
#define PD_EXPAND(x) x
#if defined(__NVCC__) || defined(__HIPCC__)
#define PADDLE_RESTRICT __restrict__
#else
#define PADDLE_RESTRICT
#endif
} // namespace phi
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/merged_momentum_impl.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
merged_momentum,
ops::MergedMomentumOpKernel<plat::CUDADeviceContext, plat::float16>,
ops::MergedMomentumOpKernel<plat::CUDADeviceContext, float>,
ops::MergedMomentumOpKernel<plat::CUDADeviceContext, double>);
PD_REGISTER_KERNEL(merged_momentum,
CPU,
ALL_LAYOUT,
phi::MergedMomentumKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/merged_momentum_impl.h"
PD_REGISTER_KERNEL(merged_momentum,
GPU,
ALL_LAYOUT,
phi::MergedMomentumKernel,
phi::dtype::float16,
float,
double) {}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -14,19 +14,18 @@
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/momentum_kernel_impl.h"
#include "paddle/phi/kernels/merged_momentum_kernel.h"
namespace paddle {
namespace operators {
namespace phi {
template <typename T>
using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;
using MultiPrecisionType = typename phi::dtype::MPTypeTrait<T>::Type;
template <typename MT, uint32_t kParamNum, bool kHasMasterParams>
struct MergedMomentumMasterParams {
......@@ -84,171 +83,152 @@ struct MergedMomentumKernelParam
}
};
template <typename DeviceContext, typename T>
class MergedMomentumOpKernel : public framework::OpKernel<T> {
using MPType = typename operators::details::MPTypeTrait<T>::Type;
template <typename MT, typename Context, typename MPType, typename T>
void MergedMomentumInnerCompute(
const Context &ctx,
const std::vector<const DenseTensor *> &params,
const std::vector<const DenseTensor *> &grads,
const std::vector<const DenseTensor *> &velocitys,
const std::vector<const DenseTensor *> &lrs,
const paddle::optional<std::vector<const DenseTensor *>> &master_params_opt,
float mu,
bool use_nesterov,
const std::vector<std::string> &regularization_methods,
const std::vector<float> &regularization_coeffs,
float rescale_grad,
const bool multi_precision,
std::vector<DenseTensor *> params_out,
std::vector<DenseTensor *> velocitys_out,
std::vector<DenseTensor *> master_params_out) {
size_t n = params.size();
PADDLE_ENFORCE_EQ(n,
params_out.size(),
phi::errors::InvalidArgument(
"The size of Output(ParamOut) must be equal to "
"Input(Param), but got the size of Output(ParamOut) "
"is %d, the size of Input(Param) is %d.",
params_out.size(),
n));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(
params[i],
params_out[i],
phi::errors::InvalidArgument("Input(Param) and Output(ParamOut) "
"must be the same Tensors."));
}
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const bool multi_precision = ctx.Attr<bool>("multi_precision");
if (multi_precision) {
InnerCompute<MPType>(ctx, multi_precision);
} else {
InnerCompute<T>(ctx, multi_precision);
}
PADDLE_ENFORCE_EQ(
n,
grads.size(),
phi::errors::InvalidArgument(
"The size of Input(Grad) must be equal to Input(Param), but got "
"the size of Input(Grad) is %d, the size of Input(Param) is %d.",
grads.size(),
n));
PADDLE_ENFORCE_EQ(n,
velocitys.size(),
phi::errors::InvalidArgument(
"The size of Input(Velocity) must be equal to "
"Input(Param), but got the size of Input(Velocity) "
"is %d, the size of Input(Param) is %d.",
velocitys.size(),
n));
PADDLE_ENFORCE_EQ(
n,
velocitys_out.size(),
phi::errors::InvalidArgument(
"The size of Output(VelocityOut) must be "
"equal to Input(Param), but got the size of Output(VelocityOut) is "
"%d, the size of Input(Param) is %d.",
velocitys_out.size(),
n));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(velocitys[i],
velocitys_out[i],
phi::errors::InvalidArgument(
"Input(Velocity) and Output(VelocityOut) must be "
"the same Tensors."));
}
private:
template <typename MT>
void InnerCompute(const framework::ExecutionContext &ctx,
const bool multi_precision) const {
auto params = ctx.MultiInput<framework::Tensor>("Param");
auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
size_t n = params.size();
PADDLE_ENFORCE_EQ(n,
params_out.size(),
platform::errors::InvalidArgument(
"The size of Output(ParamOut) must be equal to "
"Input(Param), but got the size of Output(ParamOut) "
"is %d, the size of Input(Param) is %d.",
params_out.size(),
n));
if (multi_precision) {
auto master_params = master_params_opt.get();
PADDLE_ENFORCE_EQ(
n,
master_params.size(),
phi::errors::InvalidArgument(
"The size of Input(MasterParam) must be "
"equal to Input(Param), but got the size of Input(MasterParam) "
"is %d, the size of Input(Param) is %d.",
master_params.size(),
n));
PADDLE_ENFORCE_EQ(
n,
master_params_out.size(),
phi::errors::InvalidArgument(
"The size of Output(MasterParamOut) must be equal to "
"Input(MasterParam), but got the size of Output(MasterParamOut) "
"is %d, the size of Input(Param) is %d.",
master_params_out.size(),
n));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(params[i],
params_out[i],
platform::errors::InvalidArgument(
"The size of Input(Param) and Output(ParamOut) "
PADDLE_ENFORCE_EQ(master_params[i],
master_params_out[i],
phi::errors::InvalidArgument(
"Input(MasterParam) and Output(MasterParamOut) "
"must be the same Tensors."));
PADDLE_ENFORCE_NOT_NULL(master_params[i],
phi::errors::InvalidArgument(
"Input(MasterParam) must be provided when "
"multi_precision=True."));
}
} else {
master_params_out.clear();
}
auto grads = ctx.MultiInput<framework::Tensor>("Grad");
if (lrs.size() != 1) {
PADDLE_ENFORCE_EQ(
n,
grads.size(),
platform::errors::InvalidArgument(
"The size of Input(Grad) must be equal to Input(Param), but got "
"the size of Input(Grad) is %d, the size of Input(Param) is %d.",
grads.size(),
lrs.size(),
phi::errors::InvalidArgument(
"If the size of Input(LearningRate) is not 1, the size of "
"Input(LearningRate) must be "
"equal to Input(Param), but got the size of Input(LearningRate) "
"is %d, the size of Input(Param) is %d.",
lrs.size(),
n));
auto velocitys = ctx.MultiInput<framework::Tensor>("Velocity");
PADDLE_ENFORCE_EQ(n,
velocitys.size(),
platform::errors::InvalidArgument(
"The size of Input(Velocity) must be equal to "
"Input(Param), but got the size of Input(Velocity) "
"is %d, the size of Input(Param) is %d.",
velocitys.size(),
n));
auto velocitys_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
}
if (regularization_methods.size() != 0) {
PADDLE_ENFORCE_EQ(
n,
velocitys_out.size(),
platform::errors::InvalidArgument(
"The size of Output(VelocityOut) must be "
"equal to Input(Param), but got the size of Output(VelocityOut) is "
"%d, the size of Input(Param) is %d.",
velocitys_out.size(),
regularization_methods.size(),
phi::errors::InvalidArgument(
"The size of Attr(regularization_method) must be equal "
"to Input(Param), but got the size of "
"Attr(regularization_method) is %d, the size of Input(Param) is "
"%d.",
regularization_methods.size(),
n));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(velocitys[i],
velocitys_out[i],
platform::errors::InvalidArgument(
"Input(Velocity) and Output(VelocityOut) must be "
"the same Tensors."));
}
auto master_params = ctx.MultiInput<framework::Tensor>("MasterParam");
auto master_params_out =
ctx.MultiOutput<framework::Tensor>("MasterParamOut");
if (multi_precision) {
PADDLE_ENFORCE_EQ(
n,
master_params.size(),
platform::errors::InvalidArgument(
"The size of Input(MasterParam) must be "
"equal to Input(Param), but got the size of Input(MasterParam) "
"is %d, the size of Input(Param) is %d.",
master_params.size(),
n));
PADDLE_ENFORCE_EQ(
n,
master_params_out.size(),
platform::errors::InvalidArgument(
"The size of Output(MasterParamOut) must be equal to "
"Input(MasterParam), but got the size of Output(MasterParamOut) "
"is %d, the size of Input(Param) is %d.",
master_params_out.size(),
n));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(master_params[i],
master_params_out[i],
platform::errors::InvalidArgument(
"Input(MasterParam) and Output(MasterParamOut) "
"must be the same Tensors."));
PADDLE_ENFORCE_NOT_NULL(master_params[i],
platform::errors::InvalidArgument(
"Input(MasterParam) must be provided when "
"multi_precision=True."));
}
} else {
master_params.clear();
master_params_out.clear();
}
auto mu = ctx.Attr<float>("mu");
auto rescale_grad = ctx.Attr<float>("rescale_grad");
auto lrs = ctx.MultiInput<framework::Tensor>("LearningRate");
if (lrs.size() != 1) {
PADDLE_ENFORCE_EQ(
n,
lrs.size(),
platform::errors::InvalidArgument(
"If the size of Input(LearningRate) is not 1, the size of "
"Input(LearningRate) must be "
"equal to Input(Param), but got the size of Input(LearningRate) "
"is %d, the size of Input(Param) is %d.",
lrs.size(),
n));
}
auto use_nesterov = ctx.Attr<bool>("use_nesterov");
auto regularization_methods =
ctx.Attr<std::vector<std::string>>("regularization_method");
auto regularization_coeffs =
ctx.Attr<std::vector<float>>("regularization_coeff");
if (regularization_methods.size() != 0) {
PADDLE_ENFORCE_EQ(
n,
regularization_methods.size(),
platform::errors::InvalidArgument(
"The size of Attr(regularization_method) must be equal "
"to Input(Param), but got the size of "
"Attr(regularization_method) is %d, the size of Input(Param) is "
"%d.",
regularization_methods.size(),
n));
PADDLE_ENFORCE_EQ(
n,
regularization_coeffs.size(),
platform::errors::InvalidArgument(
"The size of Attr(regularization_coeff) must be equal "
"to Input(Param), but got the size of Attr(regularization_coeff) "
"is %d, the size of Input(Param) is %d.",
regularization_coeffs.size(),
n));
}
VLOG(5) << "use_nesterov: " << use_nesterov
<< ", regularization_methods.size(): "
<< regularization_methods.size()
<< ", regularization_coeffs.size(): "
<< regularization_coeffs.size();
PADDLE_ENFORCE_EQ(
n,
regularization_coeffs.size(),
phi::errors::InvalidArgument(
"The size of Attr(regularization_coeff) must be equal "
"to Input(Param), but got the size of Attr(regularization_coeff) "
"is %d, the size of Input(Param) is %d.",
regularization_coeffs.size(),
n));
}
auto &dev_ctx = ctx.template device_context<DeviceContext>();
VLOG(5) << "use_nesterov: " << use_nesterov
<< ", regularization_methods.size(): "
<< regularization_methods.size()
<< ", regularization_coeffs.size(): "
<< regularization_coeffs.size();
if (lrs.size() == 1 && use_nesterov == false &&
regularization_methods.size() == 0) {
if (lrs.size() == 1 && use_nesterov == false &&
regularization_methods.size() == 0) {
#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \
MergedMomentumKernelParam<T, MT, kMultiPrecision> kernel_params; \
constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \
......@@ -273,52 +253,51 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
kMultiPrecision ? master_params_out[j + start]->data<MT>() \
: nullptr); \
} \
platform::ForRange<DeviceContext> for_range(dev_ctx, max_size); \
phi::funcs::ForRange<Context> for_range(ctx, max_size); \
for_range(kernel_params); \
VLOG(10) << "Launch MergedMomentum kernel " << i << " " \
<< kernel_params.param_num; \
}
if (multi_precision) {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
} else {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
}
#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
if (multi_precision) {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
} else {
for (size_t idx = 0; idx < n; idx++) {
phi::RegularizationType regularization_flag =
regularization_methods.size() > 0 &&
regularization_methods[idx] == "l2_decay"
? phi::RegularizationType::kL2DECAY
: phi::RegularizationType::kNONE;
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
}
#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
} else {
for (size_t idx = 0; idx < n; idx++) {
phi::RegularizationType regularization_flag =
regularization_methods.size() > 0 &&
regularization_methods[idx] == "l2_decay"
? phi::RegularizationType::kL2DECAY
: phi::RegularizationType::kNONE;
MT regularization_coeff = static_cast<MT>(0.0);
if (regularization_coeffs.size() != 0) {
regularization_coeff = static_cast<MT>(regularization_coeffs[idx]);
}
auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0];
MT regularization_coeff = static_cast<MT>(0.0);
if (regularization_coeffs.size() != 0) {
regularization_coeff = static_cast<MT>(regularization_coeffs[idx]);
}
auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0];
const MT *master_in_data =
multi_precision ? master_params[idx]->data<MT>() : nullptr;
MT *master_out_data =
multi_precision ? master_params_out[idx]->data<MT>() : nullptr;
if (platform::is_cpu_place(ctx.GetPlace())) {
phi::CPUDenseMomentumFunctor<MT> functor;
functor(params[idx],
grads[idx],
velocitys[idx],
lr_temp,
static_cast<MT>(mu),
use_nesterov,
regularization_flag,
regularization_coeff,
params_out[idx],
velocitys_out[idx]);
VLOG(10) << "Launch MergedMomentum cpu kernel.";
} else if (platform::is_gpu_place(ctx.GetPlace())) {
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext &>(ctx.device_context()),
params[idx]->numel());
const MT *master_in_data =
multi_precision ? master_params_opt.get()[idx]->data<MT>() : nullptr;
MT *master_out_data =
multi_precision ? master_params_out[idx]->data<MT>() : nullptr;
if (paddle::platform::is_cpu_place(ctx.GetPlace())) {
phi::CPUDenseMomentumFunctor<MT> functor;
functor(params[idx],
grads[idx],
velocitys[idx],
lr_temp,
static_cast<MT>(mu),
use_nesterov,
regularization_flag,
regularization_coeff,
params_out[idx],
velocitys_out[idx]);
VLOG(10) << "Launch MergedMomentum cpu kernel.";
} else if (paddle::platform::is_gpu_place(ctx.GetPlace())) {
phi::funcs::ForRange<Context> for_range(
static_cast<const Context &>(ctx), params[idx]->numel());
#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \
phi::DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor( \
params[idx]->data<T>(), \
......@@ -334,37 +313,88 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
velocitys_out[idx]->data<MT>(), \
master_out_data); \
for_range(functor);
if (use_nesterov) {
if (regularization_flag == phi::RegularizationType::kL2DECAY) {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
phi::UseNesterov, phi::RegularizationType::kL2DECAY);
VLOG(10)
<< "Launch MergedMomentum gpu kernel use_nesterov kL2DECAY.";
} else {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
phi::UseNesterov, phi::RegularizationType::kNONE);
VLOG(10)
<< "Launch MergedMomentum gpu kernel use_nesterov kNONE.";
}
if (use_nesterov) {
if (regularization_flag == phi::RegularizationType::kL2DECAY) {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
phi::UseNesterov, phi::RegularizationType::kL2DECAY);
VLOG(10)
<< "Launch MergedMomentum gpu kernel use_nesterov kL2DECAY.";
} else {
if (regularization_flag == phi::RegularizationType::kL2DECAY) {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
phi::NoNesterov, phi::RegularizationType::kL2DECAY);
VLOG(10)
<< "Launch MergedMomentum gpu kernel no_nesterov kL2DECAY.";
} else {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
phi::NoNesterov, phi::RegularizationType::kNONE);
VLOG(10) << "Launch MergedMomentum gpu kernel no_nesterov kNONE.";
}
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
phi::UseNesterov, phi::RegularizationType::kNONE);
VLOG(10) << "Launch MergedMomentum gpu kernel use_nesterov kNONE.";
}
} else {
if (regularization_flag == phi::RegularizationType::kL2DECAY) {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
phi::NoNesterov, phi::RegularizationType::kL2DECAY);
VLOG(10)
<< "Launch MergedMomentum gpu kernel no_nesterov kL2DECAY.";
} else {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
phi::NoNesterov, phi::RegularizationType::kNONE);
VLOG(10) << "Launch MergedMomentum gpu kernel no_nesterov kNONE.";
}
}
}
VLOG(10)
<< "Launch MergedMomentum kernel with multi_lr and regularization.";
}
VLOG(10)
<< "Launch MergedMomentum kernel with multi_lr and regularization.";
}
};
}
template <typename T, typename Context>
void MergedMomentumKernel(
const Context &dev_ctx,
const std::vector<const DenseTensor *> &param,
const std::vector<const DenseTensor *> &grad,
const std::vector<const DenseTensor *> &velocity,
const std::vector<const DenseTensor *> &learning_rate,
const paddle::optional<std::vector<const DenseTensor *>> &master_param,
float mu,
bool use_nesterov,
const std::vector<std::string> &regularization_method,
const std::vector<float> &regularization_coeff,
bool multi_precision,
float rescale_grad,
std::vector<DenseTensor *> param_out,
std::vector<DenseTensor *> velocity_out,
std::vector<DenseTensor *> master_param_out) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
if (multi_precision) {
MergedMomentumInnerCompute<MPType, Context, MPType, T>(
dev_ctx,
param,
grad,
velocity,
learning_rate,
master_param,
mu,
use_nesterov,
regularization_method,
regularization_coeff,
rescale_grad,
multi_precision,
param_out,
velocity_out,
master_param_out);
} else {
MergedMomentumInnerCompute<T, Context, MPType, T>(dev_ctx,
param,
grad,
velocity,
learning_rate,
master_param,
mu,
use_nesterov,
regularization_method,
regularization_coeff,
rescale_grad,
multi_precision,
param_out,
velocity_out,
master_param_out);
}
}
} // namespace operators
} // namespace paddle
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MergedMomentumKernel(
const Context& dev_ctx,
const std::vector<const DenseTensor*>& param,
const std::vector<const DenseTensor*>& grad,
const std::vector<const DenseTensor*>& velocity,
const std::vector<const DenseTensor*>& learning_rate,
const paddle::optional<std::vector<const DenseTensor*>>& master_param,
float mu,
bool use_nesterov,
const std::vector<std::string>& regularization_method,
const std::vector<float>& regularization_coeff,
bool multi_precision,
float rescale_grad,
std::vector<DenseTensor*> param_out,
std::vector<DenseTensor*> velocity_out,
std::vector<DenseTensor*> master_param_out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature MergedMomentumOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"merged_momentum",
{"Param", "Grad", "Velocity", "LearningRate", "MasterParam"},
{"mu",
"use_nesterov",
"regularization_method",
"regularization_coeff",
"multi_precision",
"rescale_grad"},
{
"ParamOut",
"VelocityOut",
"MasterParamOut",
});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(merged_momentum,
phi::MergedMomentumOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册